env
Rubik's Cube Environment
This module defines the Cube3 class representing a 3x3x3 Rubik's Cube in the Half-Turn Metric.
Class:
Cube3: A class for 3x3x3 Rubik's Cube in Half-Turn Metric.
Cube3
class Cube3()
A class for 3x3x3 Rubik's Cube in Half-Turn Metric (HTM).
This class provides methods to manipulate and solve a 3x3x3 Rubik's Cube using the half-turn metric. It defines the cube's initial and goal states, available moves, and methods for cube manipulation.
Representation:

Order of faces:

         0
    2    5    3    4
         1

Order of stickers on each face:

    2    5    8
    1    4    7
  [0]    3    6

Indices of state (the n-th face starts at index 9 * (n-1)):

                   2    5    8
                   1    4    7
                 [0]    3    6
   20   23   26   47   50   53   29   32   35   38   41   44
   19   22   25   46   49   52   28   31   34   37   40   43
 [18]   21   24 [45]   48   51 [27]   30   33 [36]   39   42
                  11   14   17
                  10   13   16
                 [9]   12   15

Colors (indices // 9):

       0 0 0
       0 0 0
       0 0 0
 2 2 2 5 5 5 3 3 3 4 4 4
 2 2 2 5 5 5 3 3 3 4 4 4
 2 2 2 5 5 5 3 3 3 4 4 4
       1 1 1
       1 1 1
       1 1 1
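The index and color diagrams above can be cross-checked with a short NumPy sketch. It relies only on what the diagrams state: face k occupies indices 9k to 9k + 8, the stickers within a face follow the 2 5 8 / 1 4 7 / 0 3 6 layout, and the solved color of sticker i is i // 9. The names `FACE_BLOCKS`, `net_indices` and `render_net` are hypothetical, and this text rendering is not the module's `show` method; it prints only the first letter of each color from the default `show` palette.

```python
import numpy as np

# (row block, column block) of each face in the unfolded net, following the
# face-order diagram: 0 on top; 2, 5, 3, 4 in the middle band; 1 at the bottom.
FACE_BLOCKS = {0: (0, 1), 2: (1, 0), 5: (1, 1), 3: (1, 2), 4: (1, 3), 1: (2, 1)}
PALETTE = ["white", "yellow", "orange1", "red", "blue", "green"]


def net_indices():
    """Return a 9x12 grid of sticker indices, with -1 where the net is empty."""
    grid = np.full((9, 12), -1, dtype=int)
    for face, (br, bc) in FACE_BLOCKS.items():
        for r in range(3):          # r = 0 is the top row of the face
            for c in range(3):      # c = 0 is the left column of the face
                # Sticker 0 is bottom-left; indices run up each column, left to right.
                grid[3 * br + r, 3 * bc + c] = 9 * face + 3 * c + (2 - r)
    return grid


def render_net(state, palette=PALETTE):
    """Text net using the first letter of each palette color."""
    lines = []
    for row in net_indices():
        cells = [palette[state[i]][0].upper() if i >= 0 else " " for i in row]
        lines.append(" ".join(cells).rstrip())
    return "\n".join(lines)


GOAL = np.arange(54) // 9   # solved state: sticker i has color i // 9
print(render_net(GOAL))
```

Printing `net_indices()` reproduces the index diagram, and `render_net(GOAL)` reproduces the color diagram with letters (W, Y, O, R, B, G) in place of color indices.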
Attributes:
state (ndarray) - Current cube state represented as an array of sticker colors.
GOAL (ndarray) - Fixed goal state represented as an array of sticker colors.
moves (list) - List of possible cube moves (face and direction).
allow_wide (bool) - Flag indicating whether wide moves are allowed.
max_depth (int) - The maximum scramble depth for the data generator.
sticker_target (dict) - A dictionary mapping move strings to lists of target sticker indices.
sticker_source (dict) - A dictionary mapping move strings to lists of source sticker indices.
sticker_target_ix (ndarray) - A 2D NumPy array mapping move indices to target sticker indices for normal moves.
sticker_source_ix (ndarray) - A 2D NumPy array mapping move indices to source sticker indices for normal moves.
sticker_target_ix_wide (ndarray) - A 2D NumPy array mapping move indices to target sticker indices for wide moves.
sticker_source_ix_wide (ndarray) - A 2D NumPy array mapping move indices to source sticker indices for wide moves.
show
def show(
flat=False,
palette=["white", "yellow", "orange1", "red", "blue", "green"]
)
Display the cube's current state.
Arguments:
flat (bool) - Whether to display the state in flat form.
palette (list) - List of colors used to represent stickers.
validate
def validate(
state=None,
centered=True
)
Validate the cube's state and arrangement.
Arguments:
centered (bool) - Whether the center stickers are required to be in their default positions.
Raises:
ValueError - If the cube's state or arrangement is invalid.
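The exact checks `validate` performs are not listed here. As a hedged illustration, the sketch below implements only the necessary conditions that follow from the representation: 54 stickers, each of the six colors exactly nine times, and, when `centered=True`, the center sticker 9k + 4 of face k carrying color k. The function name `check_state` is hypothetical. Passing these checks does not prove that a state is reachable; a full solvability test would also need piece-level permutation parity and orientation checks.

```python
import numpy as np


def check_state(state, centered=True):
    """Raise ValueError if `state` violates basic sticker-count or center invariants."""
    state = np.asarray(state)
    if state.shape != (54,):
        raise ValueError(f"expected 54 stickers, got shape {state.shape}")
    if not np.issubdtype(state.dtype, np.integer) or state.min() < 0 or state.max() > 5:
        raise ValueError("sticker colors must be integers in 0..5")
    counts = np.bincount(state, minlength=6)
    if not np.all(counts == 9):
        raise ValueError(f"each color must appear 9 times, got {counts.tolist()}")
    # Centers sit at local index 4 of each face, i.e. global indices 4, 13, ..., 49.
    if centered and not np.array_equal(state[4::9], np.arange(6)):
        raise ValueError(f"centers are not in default positions: {state[4::9].tolist()}")
```

For example, swapping the colors of faces 0 and 1 keeps every count at nine, so it passes with `centered=False` but fails with `centered=True`.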
reset
def reset()
Resets the cube state to the solved state.
reset_axes
def reset_axes()
Reset the color indices according to the current center colors. Useful after wide (fat) moves have been applied or when the cube is given from an unexpected perspective.
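One way to read "reset color indices according to the center colors" is as a relabeling: whatever color currently sits at the center of face k is renamed k everywhere. The sketch below does that with a NumPy lookup table. It is an assumption about the intent of `reset_axes`, not its actual code, and it presumes that the six centers carry six distinct colors. The name `relabel_by_centers` is hypothetical.

```python
import numpy as np


def relabel_by_centers(state):
    """Rename colors so that the center of face k (index 9k + 4) gets color k."""
    state = np.asarray(state)
    centers = state[4::9]                    # current center color of faces 0..5
    if len(set(centers.tolist())) != 6:
        raise ValueError("centers must carry six distinct colors")
    lookup = np.empty(6, dtype=state.dtype)
    lookup[centers] = np.arange(6)           # old color -> new color
    return lookup[state]
```

Under this reading, a solved cube held in a different orientation still has uniformly colored faces, so relabeling maps it back to the goal state.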
is_solved
def is_solved()
Checks if the cube is in the solved state.
finger
def finger(move)
Apply a single move to the cube state using a move string.
Arguments:
move (str) - Move string in HTM notation.
finger_ix
def finger_ix(ix)
Apply a single move by its index; this is faster than .finger.
Checks whether the move index corresponds to a normal move (ix < 18) or a wide move, then applies the state change using precomputed index arrays.
Arguments:
ix (int) - Index of the move to apply.
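The relationship between `.finger` and `.finger_ix` can be sketched as a one-time string-to-index lookup followed by a table-driven update. Everything named here is hypothetical: the face letters and their order in `MOVES`, the offset of 18 used to index the wide-move tables, and the function names. Only the ix < 18 split and the use of precomputed target/source index arrays come from the description above.

```python
import numpy as np

# Hypothetical ordering: 6 faces x (clockwise, counterclockwise, half turn) = 18 normal moves.
MOVES = [face + suffix for face in "UDLRBF" for suffix in ("", "'", "2")]
MOVE_IX = {move: ix for ix, move in enumerate(MOVES)}


def apply_ix(state, ix, target_ix, source_ix, target_ix_wide, source_ix_wide):
    """Overwrite target stickers with source stickers for move `ix` (in place)."""
    if ix < 18:
        target, source = target_ix[ix], source_ix[ix]
    else:
        # Assumption: wide moves live in separate tables, indexed from 0.
        target, source = target_ix_wide[ix - 18], source_ix_wide[ix - 18]
    # state[source] is gathered into a new array before the assignment,
    # so overlapping target/source indices are handled correctly.
    state[target] = state[source]
    return state


def apply_str(state, move, *tables):
    """String front end: one dict lookup, then the index-based fast path."""
    return apply_ix(state, MOVE_IX[move], *tables)
```

One plausible reason for keeping separate tables for wide moves: a wide turn also carries the adjacent middle slice, so it moves more stickers than a face turn, and separate arrays keep each table rectangular.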
apply_scramble
def apply_scramble(scramble)
Applies a sequence of moves (scramble) to the cube state.
Arguments:
scramble (str or list) - Sequence of moves, either as a string in HTM notation or as a list.
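Because `scramble` may be either a string or a list, a natural first step is to normalize it into a list of move tokens. The sketch below shows that, plus the standard way to undo a scramble in HTM notation (reverse the order, swap X and X', keep X2), which is how a generated scramble relates to a solution. Both helper names are hypothetical, and the sketch assumes that moves in a string are separated by whitespace.

```python
def to_moves(scramble):
    """Normalize "R U2 F'" or ["R", "U2", "F'"] into a list of move strings."""
    if isinstance(scramble, str):
        return scramble.split()
    return list(scramble)


def invert(scramble):
    """Return the move sequence that undoes `scramble`."""
    inverse = []
    for move in reversed(to_moves(scramble)):
        if move.endswith("2"):
            inverse.append(move)            # half turns are their own inverse
        elif move.endswith("'"):
            inverse.append(move[:-1])       # X' -> X
        else:
            inverse.append(move + "'")      # X -> X'
    return inverse
```

For example, `invert("R U2 F'")` gives `["F", "U2", "R'"]`.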
__iter__
def __iter__()
Create an infinite generator of scrambled states and the moves that produced them.
This method is intended for model training. On each iteration, it generates a new random scramble of max_depth moves, avoiding trivial move sequences, and yields the history of states together with the move that led to each state.
Yields:
tuple[np.ndarray, np.ndarray]: A tuple containing:
- X (np.ndarray): A (max_depth, 54) array of cube states.
- y (np.ndarray): A (max_depth,) array of move indices that generated the states.
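The description says that scrambles avoid "trivial move sequences" without defining them. A common interpretation, assumed here, is never turning the same face twice in a row: R followed by R' cancels, and R followed by R merges into R2. The sketch samples `max_depth` move indices under that rule, using the hypothetical encoding ix = 3 * face + variant for the 18 normal moves. The real generator may prune other redundancies too, such as commuting opposite-face pairs.

```python
import numpy as np


def sample_scramble(max_depth, rng, n_faces=6):
    """Sample move indices (3 * face + variant) with no face repeated consecutively."""
    moves = np.empty(max_depth, dtype=np.int64)
    prev_face = -1
    for i in range(max_depth):
        # Draw uniformly from the faces other than the previous one.
        face = int(rng.integers(n_faces - 1 if prev_face >= 0 else n_faces))
        if prev_face >= 0 and face >= prev_face:
            face += 1                       # skip over prev_face
        moves[i] = 3 * face + int(rng.integers(3))
        prev_face = face
    return moves


rng = np.random.default_rng(0)
print(sample_scramble(20, rng))
```

Applying these moves one at a time and recording the state after each step yields the (max_depth, 54) array X and the (max_depth,) array y described above.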
__vectorize_moves
def __vectorize_moves()
Vectorizes the sticker group replacement operations for faster computation.
This method defines self.sticker_target and self.sticker_source to manage sticker colors (target is replaced by source).
They define indices of target and source stickers so that the moves can be vectorized.
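The idea of "target is replaced by source" can be illustrated on a toy move. The sketch below derives target/source lists for a quarter turn of face 0's own nine stickers from the per-face layout, builds the inverse and half turn from it, and stacks the lists into 2D arrays so that every move becomes a single NumPy fancy-indexing update. This is a simplification under stated assumptions: the labels "U", "U'" and "U2" and the rotation direction are illustrative, and a real move also cycles 12 stickers on the four adjacent faces, which are ignored here. All names are hypothetical.

```python
import numpy as np


def face_quarter_turn():
    """Target/source lists for a quarter turn of face 0's own nine stickers.

    Uses the per-face layout 2 5 8 / 1 4 7 / 0 3 6 and new[r][c] = old[2-c][r].
    Stickers on adjacent faces, which a real move also cycles, are ignored.
    """
    target, source = [], []
    for r in range(3):
        for c in range(3):
            t, s = 3 * c + (2 - r), 3 * r + c
            if t != s:                  # the center (index 4) never moves
                target.append(t)
                source.append(s)
    return target, source


t_q, s_q = face_quarter_turn()
step = dict(zip(t_q, s_q))              # target -> source for one quarter turn
sticker_target = {"U": t_q, "U'": s_q, "U2": t_q}
sticker_source = {"U": s_q, "U'": t_q,              # swapping roles inverts the move
                  "U2": [step[s] for s in s_q]}     # quarter turn composed twice

# Vectorize: one row per move, so applying a move is one indexing operation.
moves = ["U", "U'", "U2"]
target_ix = np.array([sticker_target[m] for m in moves])
source_ix = np.array([sticker_source[m] for m in moves])


def apply(state, ix):
    """Return a new state with move `ix` applied."""
    new = state.copy()
    new[target_ix[ix]] = state[source_ix[ix]]
    return new
```

Stacking into one array per move type only works when every move touches the same number of stickers, which holds for all face turns.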
Dataset
class Dataset(torch.utils.data.Dataset)
Pseudo dataset class to infinitely yield random scrambles.
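Example:
import torch
from env import get_dataloader  # assumes this module is importable as `env`
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 1024
dl = get_dataloader(batch_size)
for i, (batch_x, batch_y) in zip(range(1000), dl):  # the loader is infinite, so cap it at 1000 batches
    batch_x, batch_y = batch_x.to(device), batch_y.to(device).reshape(-1)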
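The `.reshape(-1)` on `batch_y` follows from how samples are batched. PyTorch's default collation stacks per-sample arrays along a new leading axis, so with X of shape (max_depth, 54) and y of shape (max_depth,) per sample, a batch has shapes (batch_size, max_depth, 54) and (batch_size, max_depth). The NumPy sketch below mimics that stacking with random dummy data standing in for real scrambles, and shows that flattening y gives one move label per state. If a model consumes one state per row, `batch_x` would need the matching reshape to (-1, 54).

```python
import numpy as np

batch_size, max_depth = 4, 20
rng = np.random.default_rng(0)

# Dummy samples shaped like the generator's output: X (max_depth, 54), y (max_depth,).
samples = [
    (rng.integers(0, 6, size=(max_depth, 54)), rng.integers(0, 18, size=max_depth))
    for _ in range(batch_size)
]

# Default collation stacks samples along a new first axis.
batch_x = np.stack([x for x, _ in samples])   # (batch_size, max_depth, 54)
batch_y = np.stack([y for _, y in samples])   # (batch_size, max_depth)

flat_x = batch_x.reshape(-1, 54)              # one state per row
flat_y = batch_y.reshape(-1)                  # one move label per state
print(flat_x.shape, flat_y.shape)
```

Because both reshapes use the same row-major order, row i of `flat_x` stays aligned with label i of `flat_y`.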
get_dataloader
def get_dataloader(
batch_size,
num_workers=min(os.cpu_count(), 32),
max_depth=20,
**dl_kwargs
)
Create a DataLoader instance for generating random Rubik's Cube scrambles.
Arguments:
batch_size (int) - The number of samples per batch.
num_workers (int, optional) - The number of worker processes to use for data loading. Defaults to the number of CPU cores or 32, whichever is smaller; beyond 32 workers, returns diminish.
max_depth (int, optional) - The maximum depth of the scrambles. Defaults to 20.
**dl_kwargs - Additional keyword arguments to pass to the DataLoader constructor.
Returns:
torch.utils.data.DataLoader - A DataLoader instance that yields batches of random scrambles.
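One detail worth guarding in the default `num_workers=min(os.cpu_count(), 32)`: `os.cpu_count()` is documented to return None when the number of CPUs cannot be determined, and `min(None, 32)` raises TypeError. A defensive variant, sketched below with the hypothetical helper `default_num_workers`, falls back to a single worker in that case. Its result could then be passed explicitly as `num_workers`, alongside `batch_size`, `max_depth` and any extra DataLoader keyword arguments.

```python
import os


def default_num_workers(cap=32):
    """CPU count capped at `cap`, falling back to 1 if it cannot be determined."""
    count = os.cpu_count()          # may be None on some platforms
    return min(count if count is not None else 1, cap)


print(default_num_workers())
```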