跳到主要内容

env

Rubik's Cube 环境

此模块定义了表示半转公制下 3x3x3 魔方的 Cube3 类。

类: Cube3:半转公制下 3x3x3 魔方的类。

Cube3

class Cube3()

半转公制(HTM)下 3x3x3 魔方的类。

此类提供了使用半转公制操作和求解 3x3x3 魔方的方法。 它定义了魔方的初始状态和目标状态、可用的转动以及魔方操作的方法。

表示

面的顺序

   0
2 5 3 4
1

每个面上贴纸的顺序

  2   5   8
1 4 7
[0] 3 6

状态的索引(每个索引以 9 * (n-1) 开始):

                2   5   8
1 4 7
[0] 3 6
20 23 26 47 50 53 29 32 35 38 41 44
19 22 25 46 49 52 28 31 34 37 40 43
[18] 21 24 [45] 48 51 [27] 30 33 [36] 39 42
11 14 17
10 13 16
[9] 12 15

颜色索引 // 9):

                0   0   0
0 0 0
0 0 0
2 2 2 5 5 5 3 3 3 4 4 4
2 2 2 5 5 5 3 3 3 4 4 4
2 2 2 5 5 5 3 3 3 4 4 4
1 1 1
1 1 1
1 1 1

属性

  • state ndarray - 表示为贴纸颜色数组的当前魔方状态。
  • GOAL ndarray - 表示为贴纸颜色数组的固定目标状态。
  • moves list - 可能的魔方转动列表(面和方向)。
  • allow_wide bool - 指示是否允许宽转的标志。
  • sticker_target dict - 将转动字符串映射到目标贴纸索引列表的字典。
  • sticker_source dict - 将转动字符串映射到贴纸索引列表的字典。
  • sticker_target_ix ndarray - 将转动索引映射到普通转动的目标贴纸索引的二维 numpy 数组。
  • sticker_source_ix ndarray - 将转动索引映射到普通转动的贴纸索引的二维 numpy 数组。
  • sticker_target_ix_wide ndarray - 将转动索引映射到宽转的目标贴纸索引的二维 numpy 数组。
  • sticker_source_ix_wide ndarray - 将转动索引映射到宽转的贴纸索引的二维 numpy 数组。

show

def show(flat=False, palette=["white", "yellow", "orange1", "red", "blue", "green"])

显示魔方的当前状态。

参数

  • flat bool - 是否以平面形式显示状态。
  • palette list - 表示贴纸的颜色列表。

validate

def validate(state=None, centered=True)

验证魔方的状态和排列。

参数

  • centered bool - 中心块是否应该居中。

引发

  • ValueError - 如果魔方的状态或排列无效。

reset

def reset()

将魔方状态重置为已还原状态。

reset_axes

def reset_axes()

根据给定的中心块颜色重置颜色索引。 在应用宽转或指定意外视角时很有用。

is_solved

def is_solved()

检查魔方是否处于已还原状态。

finger

def finger(move)

使用转动字符串对魔方状态应用单次转动。

参数

  • move str - HTM 表示法中的转动字符串。

finger_ix

def finger_ix(ix)

使用转动索引应用单次转动以加快执行速度。

参数

  • ix int - 转动的索引。

apply_scramble

def apply_scramble(scramble)

将一系列转动(打乱)应用于魔方状态。

参数

  • scramble str or list - HTM 表示法或列表中的转动序列。

__vectorize_moves

def __vectorize_moves()

对贴纸组替换操作进行矢量化,以加快计算速度。 此方法定义 self.sticker_targetself.sticker_source 来管理贴纸颜色(目标被源替换)。 它们定义目标和源贴纸的索引,以便转动可以矢量化。

Dataset

class Dataset(torch.utils.data.Dataset)

无限生成随机打乱的伪数据集类。

示例
batch_size = 1024
dl = get_dataloader(batch_size)
for i, (batch_x, batch_y) in zip(range(1000), dl):
batch_x, batch_y = batch_x.to(device), batch_y.device().reshape(-1)

get_dataloader

def get_dataloader(batch_size, num_workers=min(os.cpu_count(), 32), max_depth=20, **dl_kwargs)

为生成随机魔方打乱创建 DataLoader 实例。

参数

  • batch_size int - 每批样本数。
  • num_workers int, optional - 用于数据加载的工作进程数。 默认为 CPU 核心数或 32(超过此值回报会减少),取较小值。
  • max_depth int, optional - 打乱的最大深度。默认为 20。
  • **dl_kwargs - 传递给 DataLoader 构造函数的其他关键字参数。

返回

  • torch.utils.data.DataLoader - 生成随机打乱批次的 DataLoader 实例。