env
Rubik's Cube Environment
该模块定义了 Cube3
类,用于表示半圈转动度量 (Half-Turn Metric) 下的 3x3x3 魔方。
Class:
Cube3
: 用于表示半圈转动度量下 3x3x3 魔方的类。
Cube3
class Cube3()
一个用于表示半圈转动度量 (HTM) 下 3x3x3 魔方的类。
该类提供了使用半圈转动度量来操作和求解 3x3x3 魔方的方法。 它定义了魔方的初始状态、目标状态、可用转动以及魔方操作方法。
表示法:
面的顺序:
0
2 5 3 4
1每个面上贴纸的顺序:
2 5 8
1 4 7
[0] 3 6状态索引 (每个面从
9 * (n-1)
开始):2 5 8
1 4 7
[0] 3 6
20 23 26 47 50 53 29 32 35 38 41 44
19 22 25 46 49 52 28 31 34 37 40 43
[18] 21 24 [45] 48 51 [27] 30 33 [36] 39 42
11 14 17
10 13 16
[9] 12 15颜色 (
indices // 9
):0 0 0
0 0 0
0 0 0
2 2 2 5 5 5 3 3 3 4 4 4
2 2 2 5 5 5 3 3 3 4 4 4
2 2 2 5 5 5 3 3 3 4 4 4
1 1 1
1 1 1
1 1 1
属性:
state
ndarray - 当前魔方状态,表示为贴纸颜色的数组。GOAL
ndarray - 固定的目标状态,表示为贴纸颜色的数组。moves
list - 可能的魔方转动列表(面和方向)。allow_wide
bool - 指示是否允许宽转动的标志。max_depth
int - 数据生成器的最大打乱深度。sticker_target
dict - 一个将转动字符串映射到目标贴纸索引列表的字典。sticker_source
dict - 一个将转动字符串映射到源贴纸索引列表的字典。sticker_target_ix
ndarray - 一个将转动索引映射到普通转动的目标贴纸索引的二维 numpy 数组。sticker_source_ix
ndarray - 一个将转动索引映射到普通转动的源贴纸索引的二维 numpy 数组。sticker_target_ix_wide
ndarray - 一个将转动索引映射到宽转动的目标贴纸索引的二维 numpy 数组。sticker_source_ix_wide
ndarray - 一个将转动索引映射到宽转动的源贴纸索引的二维 numpy 数组。
show
def show(
flat=False,
palette=["white", "yellow", "orange1", "red", "blue", "green"]
)
显示魔方的当前状态。
参数:
flat
bool - 是否以平面形式显示状态。palette
list - 用于表示贴纸的颜色列表。
validate
def validate(
state=None,
centered=True
)
验证魔方的状态和排列。
参数:
centered
bool - 中心块是否应位于中心位置。
抛出:
ValueError
- 如果魔方的状态或排列无效。
reset
def reset()
将魔方状态重置为已解开状态。
reset_axes
def reset_axes()
根据给定的中心块颜色重置颜色索引。 当应用宽转动(fat moves)或指定了意外的视角时很有用。
is_solved
def is_solved()
检查魔方是否处于已解开状态。
finger
def finger(move)
使用转动字符串在魔方状态上应用单次转动。
参数:
move
str - HTM 标记法中的转动字符串。
finger_ix
def finger_ix(ix)
使用转动索引应用单次转动,比 .finger
执行速度更快。
检查转动索引是对应普通转动 (ix < 18) 还是宽转动,并使用预先计算的索引数组应用状态更改。
参数:
ix
int - 要应用的转动的索引。
apply_scramble
def apply_scramble(scramble)
将一系列转动(打乱公式)应用于魔方状态。
参数:
scramble
str or list - HTM 标记法中的转动序列或列表。
__iter__
def __iter__()
创建一个无限生成器,用于生成打乱状态和解法序列。
此方法旨在用于模型训练。在每次迭代中,它会生成一个新的 max_depth
步的随机打乱,并避免无意义的转动序列。
它会产生状态历史以及导致每个状态的相应转动。
产生:
tuple[np.ndarray, np.ndarray]: 一个包含以下内容的元组:
- X (np.ndarray): 一个 (max_depth, 54) 的魔方状态数组。
- y (np.ndarray): 一个 (max_depth,) 的生成这些状态的转动索引数组。
__vectorize_moves
def __vectorize_moves()
将贴纸组替换操作向量化,以加快计算速度。
此方法定义了 self.sticker_target
和 self.sticker_source
来管理贴纸颜色(目标被源替换)。
它们定义了目标和源贴纸的索引,以便可以对转动进行向量化。
Dataset
class Dataset(torch.utils.data.Dataset)
用于无限产生随机打乱的伪数据集类。
Examplebatch_size = 1024
dl = get_dataloader(batch_size)
for i, (batch_x, batch_y) in zip(range(1000), dl):
batch_x, batch_y = batch_x.to(device), batch_y.device().reshape(-1)
get_dataloader
def get_dataloader(
batch_size,
num_workers=min(os.cpu_count(), 32), max_depth=20,
**dl_kwargs
)
创建一个 DataLoader 实例,用于生成随机的魔方打乱。
参数:
batch_size
int - 每批的样本数量。num_workers
int, optional - 用于数据加载的工作进程数。默认为 CPU 核心数或 32(超过此值收益将递减)中较小的值。max_depth
int, optional - 打乱的最大深度。默认为 20。**dl_kwargs
- 传递给 DataLoader 构造函数的其他关键字参数。
返回:
torch.utils.data.DataLoader
- 一个 DataLoader 实例,它会产生一批批的随机打乱。