跳到主要内容

env

Rubik's Cube Environment

该模块定义了 Cube3 类,用于表示半圈转动度量 (Half-Turn Metric) 下的 3x3x3 魔方。

Class: Cube3: 用于表示半圈转动度量下 3x3x3 魔方的类。

Cube3

class Cube3()

一个用于表示半圈转动度量 (HTM) 下 3x3x3 魔方的类。

该类提供了使用半圈转动度量来操作和求解 3x3x3 魔方的方法。 它定义了魔方的初始状态、目标状态、可用转动以及魔方操作方法。

表示法:

面的顺序:

   0
2 5 3 4
1

每个面上贴纸的顺序:

  2   5   8
1 4 7
[0] 3 6

状态索引 (每个面从 9 * (n-1) 开始):

                2   5   8
1 4 7
[0] 3 6
20 23 26 47 50 53 29 32 35 38 41 44
19 22 25 46 49 52 28 31 34 37 40 43
[18] 21 24 [45] 48 51 [27] 30 33 [36] 39 42
11 14 17
10 13 16
[9] 12 15

颜色 (indices // 9):

                0   0   0
0 0 0
0 0 0
2 2 2 5 5 5 3 3 3 4 4 4
2 2 2 5 5 5 3 3 3 4 4 4
2 2 2 5 5 5 3 3 3 4 4 4
1 1 1
1 1 1
1 1 1

属性:

  • state ndarray - 当前魔方状态,表示为贴纸颜色的数组。
  • GOAL ndarray - 固定的目标状态,表示为贴纸颜色的数组。
  • moves list - 可能的魔方转动列表(面和方向)。
  • allow_wide bool - 指示是否允许宽转动的标志。
  • max_depth int - 数据生成器的最大打乱深度。
  • sticker_target dict - 一个将转动字符串映射到目标贴纸索引列表的字典。
  • sticker_source dict - 一个将转动字符串映射到贴纸索引列表的字典。
  • sticker_target_ix ndarray - 一个将转动索引映射到普通转动的目标贴纸索引的二维 numpy 数组。
  • sticker_source_ix ndarray - 一个将转动索引映射到普通转动的贴纸索引的二维 numpy 数组。
  • sticker_target_ix_wide ndarray - 一个将转动索引映射到宽转动的目标贴纸索引的二维 numpy 数组。
  • sticker_source_ix_wide ndarray - 一个将转动索引映射到宽转动的贴纸索引的二维 numpy 数组。

show

def show(
flat=False,
palette=["white", "yellow", "orange1", "red", "blue", "green"]
)

显示魔方的当前状态。

参数:

  • flat bool - 是否以平面形式显示状态。
  • palette list - 用于表示贴纸的颜色列表。

validate

def validate(
state=None,
centered=True
)

验证魔方的状态和排列。

参数:

  • centered bool - 中心块是否应位于中心位置。

抛出:

  • ValueError - 如果魔方的状态或排列无效。

reset

def reset()

将魔方状态重置为已解开状态。

reset_axes

def reset_axes()

根据给定的中心块颜色重置颜色索引。 当应用宽转动(fat moves)或指定了意外的视角时很有用。

is_solved

def is_solved()

检查魔方是否处于已解开状态。

finger

def finger(move)

使用转动字符串在魔方状态上应用单次转动。

参数:

  • move str - HTM 标记法中的转动字符串。

finger_ix

def finger_ix(ix)

使用转动索引应用单次转动,比 .finger 执行速度更快。 检查转动索引是对应普通转动 (ix < 18) 还是宽转动,并使用预先计算的索引数组应用状态更改。

参数:

  • ix int - 要应用的转动的索引。

apply_scramble

def apply_scramble(scramble)

将一系列转动(打乱公式)应用于魔方状态。

参数:

  • scramble str or list - HTM 标记法中的转动序列或列表。

__iter__

def __iter__()

创建一个无限生成器,用于生成打乱状态和解法序列。

此方法旨在用于模型训练。在每次迭代中,它会生成一个新的 max_depth 步的随机打乱,并避免无意义的转动序列。 它会产生状态历史以及导致每个状态的相应转动。

产生:

tuple[np.ndarray, np.ndarray]: 一个包含以下内容的元组:

  • X (np.ndarray): 一个 (max_depth, 54) 的魔方状态数组。
  • y (np.ndarray): 一个 (max_depth,) 的生成这些状态的转动索引数组。

__vectorize_moves

def __vectorize_moves()

将贴纸组替换操作向量化,以加快计算速度。 此方法定义了 self.sticker_targetself.sticker_source 来管理贴纸颜色(目标被源替换)。 它们定义了目标和源贴纸的索引,以便可以对转动进行向量化。

Dataset

class Dataset(torch.utils.data.Dataset)

用于无限产生随机打乱的伪数据集类。

Example
batch_size = 1024
dl = get_dataloader(batch_size)
for i, (batch_x, batch_y) in zip(range(1000), dl):
batch_x, batch_y = batch_x.to(device), batch_y.device().reshape(-1)

get_dataloader

def get_dataloader(
batch_size,
num_workers=min(os.cpu_count(), 32), max_depth=20,
**dl_kwargs
)

创建一个 DataLoader 实例,用于生成随机的魔方打乱。

参数:

  • batch_size int - 每批的样本数量。
  • num_workers int, optional - 用于数据加载的工作进程数。默认为 CPU 核心数或 32(超过此值收益将递减)中较小的值。
  • max_depth int, optional - 打乱的最大深度。默认为 20。
  • **dl_kwargs - 传递给 DataLoader 构造函数的其他关键字参数。

返回:

  • torch.utils.data.DataLoader - 一个 DataLoader 实例,它会产生一批批的随机打乱。