env
Rubik's Cube Environment
This module defines the Cube3 class representing a 3x3x3 Rubik's Cube in the Half-Turn Metric.
Class:
Cube3: A class for a 3x3x3 Rubik's Cube in the Half-Turn Metric.
Cube3
class Cube3()
A class for 3x3x3 Rubik's Cube in Half-Turn Metric (HTM).
This class provides methods to manipulate a 3x3x3 Rubik's Cube in the half-turn metric and to check whether it is solved. It defines the cube's initial and goal states, the available moves, and methods for manipulating the cube.
Representation:
Order of faces (unfolded net):
      0
    2 5 3 4
      1
Order of stickers on each face ([0] marks the first sticker of the face):
     2  5  8
     1  4  7
    [0] 3  6
Indices of state (the n-th face, counting from 1, starts at 9 * (n-1); equivalently, face f in the 0-based order above occupies indices 9*f to 9*f + 8):
                  2   5   8
                  1   4   7
                 [0]  3   6
     20  23  26  47  50  53  29  32  35  38  41  44
     19  22  25  46  49  52  28  31  34  37  40  43
    [18] 21  24 [45] 48  51 [27] 30  33 [36] 39  42
                 11  14  17
                 10  13  16
                 [9] 12  15
Colors (indices // 9):
          0 0 0
          0 0 0
          0 0 0
    2 2 2 5 5 5 3 3 3 4 4 4
    2 2 2 5 5 5 3 3 3 4 4 4
    2 2 2 5 5 5 3 3 3 4 4 4
          1 1 1
          1 1 1
          1 1 1
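The layout above implies that the solved state is simply each sticker index integer-divided by 9, and that every face's center is its local sticker 4. The following is a minimal sketch under that assumption; the names GOAL, face_slice and CENTERS here are illustrative and not necessarily how the module builds them.

```python
import numpy as np

# Assumption: the solved state assigns color f to all 9 stickers of face f,
# i.e. color = index // 9, matching the "Colors (indices // 9)" diagram.
GOAL = np.arange(54) // 9

def face_slice(f: int) -> slice:
    """Index range of face f (0-based face order) in the flat 54-sticker state."""
    return slice(9 * f, 9 * f + 9)

# The center of each face is local sticker 4 (middle of the 3x3 grid),
# so the six centers sit at indices 4, 13, 22, 31, 40, 49.
CENTERS = np.array([9 * f + 4 for f in range(6)])

print(GOAL[face_slice(5)])  # [5 5 5 5 5 5 5 5 5]
print(GOAL[CENTERS])        # [0 1 2 3 4 5]
```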
Attributes:
state
ndarray - Current cube state represented as an array of sticker colors.
GOAL
ndarray - Fixed goal state represented as an array of sticker colors.
moves
list - List of possible cube moves (face and direction).
allow_wide
bool - Flag indicating whether wide moves are allowed.
sticker_target
dict - A dictionary mapping move strings to lists of target sticker indices.
sticker_source
dict - A dictionary mapping move strings to lists of source sticker indices.
sticker_target_ix
ndarray - A 2D numpy array mapping move indices to target sticker indices for normal moves.
sticker_source_ix
ndarray - A 2D numpy array mapping move indices to source sticker indices for normal moves.
sticker_target_ix_wide
ndarray - A 2D numpy array mapping move indices to target sticker indices for wide moves.
sticker_source_ix_wide
ndarray - A 2D numpy array mapping move indices to source sticker indices for wide moves.
show
def show(flat=False, palette=["white", "yellow", "orange1", "red", "blue", "green"])
Display the cube's current state.
Arguments:
flat
bool - Whether to display the state in flat form.
palette
list - List of colors for representing stickers.
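To make the index layout concrete, here is a hypothetical text-only renderer that draws the unfolded net from a 54-sticker state. It is not the module's show() method; it only assumes the face order and per-face sticker order given under Representation, and it picks one letter per color to echo the default palette (white, yellow, orange, red, blue, green).

```python
import numpy as np

# Display order of the 9 stickers within one face, read off the
# "Order of stickers on each face" diagram (top row 2 5 8, bottom row 0 3 6).
FACE_ROWS = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]

def render_net(state, letters="WYORBG"):
    """Return the unfolded net as text, one letter per sticker (hypothetical helper)."""
    def face_lines(f):
        # Three strings of three letters each, for face f.
        return ["".join(letters[state[9 * f + i]] for i in row) for row in FACE_ROWS]

    pad = " " * 3  # face 0 and face 1 sit above/below face 5, the 2nd face in the middle row
    lines = [pad + line for line in face_lines(0)]
    middle = [face_lines(f) for f in (2, 5, 3, 4)]
    lines += ["".join(parts) for parts in zip(*middle)]
    lines += [pad + line for line in face_lines(1)]
    return "\n".join(lines)

print(render_net(np.arange(54) // 9))
```

For the solved state this prints three rows of "WWW", three rows of "OOOGGGRRRBBB" and three rows of "YYY", matching the Colors diagram.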
validate
def validate(state=None, centered=True)
Validate the cube's state and arrangement.
Arguments:
centered
bool - Whether the center stickers are required to be in their default (centered) positions.
Raises:
ValueError
- If the cube's state or arrangement is invalid.
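A few necessary conditions for a valid state follow directly from the representation. The sketch below checks only those; it is a hypothetical helper, not the module's validate(), and it deliberately skips piece-level solvability checks (corner twist, edge flip, permutation parity) that a full validator would also need.

```python
import numpy as np

def check_sticker_counts(state, centered=True):
    """Partial sanity check on a 54-sticker state (hypothetical helper).

    Raises ValueError unless there are 54 stickers, each of the 6 colors
    appears exactly 9 times, and (if centered) every center shows its own
    face index.
    """
    state = np.asarray(state)
    if state.shape != (54,):
        raise ValueError(f"expected 54 stickers, got shape {state.shape}")
    counts = np.bincount(state, minlength=6)
    if len(counts) != 6 or not np.all(counts == 9):
        raise ValueError(f"each color must appear 9 times, got {counts.tolist()}")
    # Centers are local sticker 4 of each face: indices 4, 13, ..., 49.
    if centered and not np.array_equal(state[4::9], np.arange(6)):
        raise ValueError("centers are not in their default positions")
```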
reset
def reset()
Resets the cube state to the solved state.
reset_axes
def reset_axes()
Reassign color indices according to the current center colors. Useful after wide (fat) moves have been applied or when the cube is viewed from an unexpected orientation.
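One way to realize "reassign color indices according to the center colors" is a lookup table that renames each color to the index of the face whose center currently shows it. The following is a hypothetical sketch of that idea, not necessarily how reset_axes() is implemented; it assumes the six centers show six distinct colors.

```python
import numpy as np

def recolor_by_centers(state):
    """Rename colors so each face's center gets that face's index (hypothetical)."""
    state = np.asarray(state)
    centers = state[4::9]  # center color currently on faces 0..5
    if sorted(centers.tolist()) != list(range(6)):
        raise ValueError("centers must show six distinct colors 0..5")
    mapping = np.empty(6, dtype=state.dtype)
    mapping[centers] = np.arange(6)  # old color -> new color
    return mapping[state]
```

For example, if every sticker of color 0 were relabeled 1 and vice versa, the centers of faces 0 and 1 would read 1 and 0, and this mapping would swap the labels back.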
is_solved
def is_solved()
Checks if the cube is in the solved state.
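reset() and is_solved() amount to copying and comparing against the goal state. The class below is a hypothetical minimal holder illustrating that pattern; it assumes GOAL = np.arange(54) // 9 as in the Colors diagram.

```python
import numpy as np

class TinyCubeState:
    """Hypothetical minimal holder illustrating reset() and is_solved()."""

    GOAL = np.arange(54) // 9

    def __init__(self):
        self.reset()

    def reset(self):
        # Copy so that later in-place moves never mutate GOAL itself.
        self.state = self.GOAL.copy()

    def is_solved(self):
        # Solved means every sticker matches the goal state exactly.
        return bool(np.array_equal(self.state, self.GOAL))
```

The copy in reset() matters because moves are applied in place; without it, the first move would silently corrupt the goal.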
finger
def finger(move)
Apply a single move to the cube state using a move string.
Arguments:
move
str - Move string in HTM notation.
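In standard HTM notation a move string is a face letter optionally followed by "'" (counter-clockwise quarter turn) or "2" (half turn). The hypothetical parser below shows that decomposition; the accepted letters and the quarter-turn encoding are assumptions, and wide moves (see allow_wide) are not handled.

```python
def parse_move(move: str):
    """Split an HTM move string into (face letter, clockwise quarter turns).

    Hypothetical helper. A counter-clockwise quarter turn is encoded as
    three clockwise quarter turns.
    """
    if not move or move[0] not in "UDLRFB":
        raise ValueError(f"unknown face in move {move!r}")
    turns = {"": 1, "'": 3, "2": 2}.get(move[1:])
    if turns is None:
        raise ValueError(f"unknown modifier in move {move!r}")
    return move[0], turns

print(parse_move("R"), parse_move("U'"), parse_move("F2"))
```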
finger_ix
def finger_ix(ix)
Apply a single move by its move index, which is faster than parsing a move string.
Arguments:
ix
int - Index of the move.
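With the 2D index tables sticker_target_ix and sticker_source_ix, applying move ix reduces to a single NumPy fancy-indexing assignment. The sketch below uses a hypothetical function name and a toy 3-sticker cycle. Because the right-hand side of an advanced-indexing expression is a copy, all source colors are read before any target is overwritten. A table with one row per move works because every outer-face turn, quarter or half, relocates the same 20 stickers (8 on the turned face plus 12 on its neighbours).

```python
import numpy as np

def apply_move_ix(state, target_ix, source_ix, ix):
    """Apply move `ix` in place: each target sticker takes its source's color.

    Hypothetical stand-in for finger_ix(); target_ix and source_ix are 2D
    arrays with one row of sticker indices per move.
    """
    state[target_ix[ix]] = state[source_ix[ix]]
    return state

# Toy table with a single "move" that cycles stickers 0 -> 1 -> 2 -> 0.
target_ix = np.array([[1, 2, 0]])
source_ix = np.array([[0, 1, 2]])
s = np.array([10, 20, 30, 40])
apply_move_ix(s, target_ix, source_ix, 0)
print(s)  # [30 10 20 40]
```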
apply_scramble
def apply_scramble(scramble)
Applies a sequence of moves (scramble) to the cube state.
Arguments:
scramble
str or list - Sequence of moves, given either as a string in HTM notation (e.g. "R U R' U'") or as a list of move strings.
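Accepting either a string or a list usually comes down to normalizing the input before looping over it. The following is a hypothetical sketch that assumes whitespace-separated moves in the string form; the callback stands in for a per-move method such as finger.

```python
def normalize_scramble(scramble):
    """Return the scramble as a list of move strings (hypothetical helper).

    Assumes moves in a string are separated by whitespace.
    """
    if isinstance(scramble, str):
        return scramble.split()
    return list(scramble)

def apply_scramble_with(finger, scramble):
    """Call `finger` once per move, in order (hypothetical stand-in)."""
    for move in normalize_scramble(scramble):
        finger(move)

applied = []
apply_scramble_with(applied.append, "R U2 F'")
print(applied)  # ['R', 'U2', "F'"]
```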
__vectorize_moves
def __vectorize_moves()
Vectorizes the sticker group replacement operations for faster computation.
This method defines self.sticker_target and self.sticker_source to manage sticker colors: each target sticker is replaced by the color of its source sticker. Because they hold the indices of the target and source stickers, moves can be applied as vectorized array operations.
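A move can be described as a set of sticker cycles and converted into target/source index arrays. The helper below is a hypothetical sketch of that conversion, not the module's __vectorize_moves(). The toy example rotates the 8 outer stickers of one face clockwise as drawn in the per-face sticker-order diagram; the 12 neighbouring stickers that a real turn also moves are omitted.

```python
import numpy as np

def cycles_to_target_source(cycles):
    """Convert sticker cycles into (target, source) index arrays (hypothetical).

    A cycle (a, b, c, d) means the color at a moves to b, b to c, c to d and
    d back to a, so target = (b, c, d, a) and source = (a, b, c, d).
    """
    target, source = [], []
    for cycle in cycles:
        source.extend(cycle)
        target.extend(cycle[1:] + cycle[:1])
    return np.array(target), np.array(source)

# Corners 0 -> 2 -> 8 -> 6 and edges 1 -> 5 -> 7 -> 3 of a single face.
t, s = cycles_to_target_source([(0, 2, 8, 6), (1, 5, 7, 3)])
face = np.arange(9)
face[t] = face[s]
print(face)  # [6 3 0 7 4 1 8 5 2]
```

Stacking such per-move arrays with np.stack yields 2D tables in the shape described for sticker_target_ix and sticker_source_ix.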
Dataset
class Dataset(torch.utils.data.Dataset)
Pseudo-dataset class that yields random scrambles indefinitely.
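Example:
    import torch
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 1024
    dl = get_dataloader(batch_size)
    # The dataset is infinite, so zip() with range() caps the loop at 1000 batches.
    for i, (batch_x, batch_y) in zip(range(1000), dl):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device).reshape(-1)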
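An endless source of random scrambles can be written as a generator that never terminates. The sketch below is hypothetical and only produces move lists, not the (batch_x, batch_y) tensors the real Dataset yields. It assumes the face letters U, D, L, R, B, F and one common sampling rule (never turning the same face twice in a row); the actual sampling rules may differ.

```python
import random

FACES = "UDLRBF"
MODIFIERS = ["", "'", "2"]

def random_scramble(depth, rng=random):
    """Return `depth` random HTM moves, never repeating the previous face (hypothetical)."""
    moves, last_face = [], None
    for _ in range(depth):
        face = rng.choice([f for f in FACES if f != last_face])
        moves.append(face + rng.choice(MODIFIERS))
        last_face = face
    return moves

def scramble_stream(max_depth, seed=None):
    """Yield random scrambles of depth 1..max_depth forever."""
    rng = random.Random(seed)
    while True:
        yield random_scramble(rng.randint(1, max_depth), rng)

gen = scramble_stream(20, seed=0)
print(next(gen))
```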
get_dataloader
def get_dataloader(batch_size, num_workers=min(os.cpu_count(), 32), max_depth=20, **dl_kwargs)
Create a DataLoader instance for generating random Rubik's Cube scrambles.
Arguments:
batch_size
int - The number of samples per batch.
num_workers
int, optional - The number of worker processes to use for data loading. Defaults to the smaller of the number of CPU cores and 32 (beyond which returns diminish).
max_depth
int, optional - The maximum depth of the scrambles. Defaults to 20.
**dl_kwargs
- Additional keyword arguments to pass to the DataLoader constructor.
Returns:
torch.utils.data.DataLoader
- A DataLoader instance that yields batches of random scrambles.
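The default num_workers=min(os.cpu_count(), 32) has one edge case: os.cpu_count() returns None when the core count cannot be determined, and min(None, 32) raises TypeError. Also note that a default argument is evaluated once, when the function is defined. The hypothetical helper below computes the same default defensively and could be passed explicitly as num_workers.

```python
import os

def default_num_workers(cap=32):
    """Smaller of the CPU core count and `cap`, falling back to 1 if unknown."""
    return min(os.cpu_count() or 1, cap)

print(default_num_workers())
```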