env
Rubik's Cube 环境
此模块定义了表示半转公制下 3x3x3 魔方的 Cube3 类。
类:
Cube3
:半转公制下 3x3x3 魔方的类。
Cube3
class Cube3()
半转公制(HTM)下 3x3x3 魔方的类。
此类提供了使用半转公制操作和求解 3x3x3 魔方的方法。 它定义了魔方的初始状态和目标状态、可用的转动以及魔方操作的方法。
表示:
面的顺序:
0
2 5 3 4
1每个面上贴纸的顺序:
2 5 8
1 4 7
[0] 3 6状态的索引(每个索引以
9 * (n-1)
开始):2 5 8
1 4 7
[0] 3 6
20 23 26 47 50 53 29 32 35 38 41 44
19 22 25 46 49 52 28 31 34 37 40 43
[18] 21 24 [45] 48 51 [27] 30 33 [36] 39 42
11 14 17
10 13 16
[9] 12 15颜色(
索引 // 9
):0 0 0
0 0 0
0 0 0
2 2 2 5 5 5 3 3 3 4 4 4
2 2 2 5 5 5 3 3 3 4 4 4
2 2 2 5 5 5 3 3 3 4 4 4
1 1 1
1 1 1
1 1 1
属性:
state
ndarray - 表示为贴纸颜色数组的当前魔方状态。GOAL
ndarray - 表示为贴纸颜色数组的固定目标状态。moves
list - 可能的魔方转动列表(面和方向)。allow_wide
bool - 指示是否允许宽转的标志。sticker_target
dict - 将转动字符串映射到目标贴纸索引列表的字典。sticker_source
dict - 将转动字符串映射到源贴纸索引列表的字典。sticker_target_ix
ndarray - 将转动索引映射到普通转动的目标贴纸索引的二维 numpy 数组。sticker_source_ix
ndarray - 将转动索引映射到普通转动的源贴纸索引的二维 numpy 数组。sticker_target_ix_wide
ndarray - 将转动索引映射到宽转的目标贴纸索引的二维 numpy 数组。sticker_source_ix_wide
ndarray - 将转动索引映射到宽转的源贴纸索引的二维 numpy 数组。
show
def show(flat=False, palette=["white", "yellow", "orange1", "red", "blue", "green"])
显示魔方的当前状态。
参数:
flat
bool - 是否以平面形式显示状态。palette
list - 表示贴纸的颜色列表。
validate
def validate(state=None, centered=True)
验证魔方的状态和排列。
参数:
centered
bool - 中心块是否应该居中。
引发:
ValueError
- 如果魔方的状态或排列无效。
reset
def reset()
将魔方状态重置为已还原状态。
reset_axes
def reset_axes()
根据给定的中心块颜色重置颜色索引。 在应用宽转或指定意外视角时很有用。
is_solved
def is_solved()
检查魔方是否处于已还原状态。
finger
def finger(move)
使用转动字符串对魔方状态应用单次转动。
参数:
move
str - HTM 表示法中的转动字符串。
finger_ix
def finger_ix(ix)
使用转动索引应用单次转动以加快执行速度。
参数:
ix
int - 转动的索引。
apply_scramble
def apply_scramble(scramble)
将一系列转动(打乱)应用于魔方状态。
参数:
scramble
str or list - HTM 表示法或列表中的转动序列。
__vectorize_moves
def __vectorize_moves()
对贴纸组替换操作进行矢量化,以加快计算速度。
此方法定义 self.sticker_target
和 self.sticker_source
来管理贴纸颜色(目标被源替换)。
它们定义目标和源贴纸的索引,以便转动可以矢量化。
Dataset
class Dataset(torch.utils.data.Dataset)
无限生成随机打乱的伪数据集类。
示例batch_size = 1024
dl = get_dataloader(batch_size)
for i, (batch_x, batch_y) in zip(range(1000), dl):
batch_x, batch_y = batch_x.to(device), batch_y.device().reshape(-1)
get_dataloader
def get_dataloader(batch_size, num_workers=min(os.cpu_count(), 32), max_depth=20, **dl_kwargs)
为生成随机魔方打乱创建 DataLoader 实例。
参数:
batch_size
int - 每批样本数。num_workers
int, optional - 用于数据加载的工作进程数。 默认为 CPU 核心数或 32(超过此值回报会减少),取较小值。max_depth
int, optional - 打乱的最大深度。默认为 20。**dl_kwargs
- 传递给 DataLoader 构造函数的其他关键字参数。
返回:
torch.utils.data.DataLoader
- 生成随机打乱批次的 DataLoader 实例。