跳到主要内容

model

Model Loader

该模块提供了一个函数和多个类,用于加载和使用预训练的 AlphaCube 求解器模型。

Function: load_model(model_id, cache_dir):加载预训练的 AlphaCube 求解器模型。

Classes:

  • Model:AlphaCube 求解器的 MLP 架构。
  • LinearBlock:MLP 架构的构建块。

load_model

def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)

加载预训练的 AlphaCube 求解器模型。

Arguments:

  • model_id str - 要加载的模型变体的标识符(“small”、“base”或“large”)。
  • cache_dir str - 用于缓存下载模型的目录。

Returns:

  • nn.Module - 已加载的 AlphaCube 求解器模型。

Model_v1

class Model_v1(nn.Module)

在以下论文中介绍的计算最优魔方求解器的 MLP 架构: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

模型的前向传播。

Arguments:

  • inputs torch.Tensor - 代表问题状态的输入张量。

Returns:

  • torch.Tensor - 预测的可能解的分布。

Model

class Model(nn.Module)

一个比 Model 更好的架构。

Changes:

  • 从第一层(embedding)中移除 ReLU 激活函数,该层存在 ReLU 死亡问题。
  • 遵循最近的惯例,embedding计为一个隐藏层。

reset_parameters

def reset_parameters()

初始化所有权重,使激活方差约等于 1.0,同时偏置项和输出层权重为零

forward

def forward(inputs)

模型的前向传播。

Arguments:

  • inputs torch.Tensor - 代表问题状态的输入张量。

Returns:

  • torch.Tensor - 预测的可能解的分布。

LinearBlock

class LinearBlock(nn.Module)

MLP 架构的构建块。

该块由一个线性层、一个 ReLU 激活函数和一个批归一化层组成。

forward

def forward(inputs)

线性块的前向传播。

Arguments:

  • inputs torch.Tensor - 输入张量。

Returns:

  • torch.Tensor - 经过线性变换、ReLU 激活和批归一化后的输出张量。