跳到主要内容

model

模型加载器

此模块提供了一个函数和几个类,用于加载和使用预训练的 AlphaCube 求解器模型。

函数: load_model(model_id, cache_dir):加载预训练的 AlphaCube 求解器模型。

类:

  • Model:AlphaCube 求解器的 MLP 架构。
  • LinearBlock:MLP 架构的构建块。

load_model

def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))

加载预训练的 AlphaCube 求解器模型。

参数

  • model_id str - 要加载的模型变体的标识符("small"、"base" 或 "large")。
  • cache_dir str - 缓存下载模型的目录。

返回

  • nn.Module - 加载的 AlphaCube 求解器模型。

Model_v1

class Model_v1(nn.Module)

以下论文中介绍的计算最优 Rubik's Cube 求解器的 MLP 架构: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

模型的前向传递。

参数

  • inputs torch.Tensor - 表示问题状态的输入张量。

返回

  • torch.Tensor - 预测的可能解决方案的分布。

Model

class Model(nn.Module)

一个比 Model 更好的架构。

变化

  • 从第一层(embedding)中移除 ReLU 激活,该层存在 ReLU 死亡问题。
  • 遵循最新的惯例,embedding算作一个隐藏层。

reset_parameters

def reset_parameters()

初始化所有权重,使激活方差近似为 1.0,偏置项和输出层权重为零。

forward

def forward(inputs)

模型的前向传递。

参数

  • inputs torch.Tensor - 表示问题状态的输入张量。

返回

  • torch.Tensor - 预测的可能解决方案的分布。

LinearBlock

class LinearBlock(nn.Module)

MLP 架构的构建块。

该块由线性层、ReLU 激活和批归一化组成。

forward

def forward(inputs)

线性块的前向传递。

参数

  • inputs torch.Tensor - 输入张量。

返回

  • torch.Tensor - 经过线性变换、ReLU 激活和批归一化后的输出张量。