model
Model Loader
该模块提供了一个函数和多个类,用于加载和使用预训练的 AlphaCube 求解器模型。
Function:
load_model(model_id, cache_dir)
:加载预训练的 AlphaCube 求解器模型。
Classes:
Model
:AlphaCube 求解器的 MLP 架构。LinearBlock
:MLP 架构的构建块。
load_model
def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)
加载预训练的 AlphaCube 求解器模型。
Arguments:
model_id
str - 要加载的模型变体的标识符(“small”、“base”或“large”)。cache_dir
str - 用于缓存下载模型的目录。
Returns:
nn.Module
- 已加载的 AlphaCube 求解器模型。
Model_v1
class Model_v1(nn.Module)
在以下论文中介绍的计算最优魔方求解器的 MLP 架构: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
模型的前向传播。
Arguments:
inputs
torch.Tensor - 代表问题状态的输入张量。
Returns:
torch.Tensor
- 预测的可能解的分布。
Model
class Model(nn.Module)
一个比 Model
更好的架构。
Changes:
- 从第一层(
embedding
)中移除 ReLU 激活函数,该层存在 ReLU 死亡问题。 - 遵循最近的惯例,
embedding
层不计为一个隐藏层。
reset_parameters
def reset_parameters()
初始化所有权重,使激活方差约等于 1.0,同时偏置项和输出层权重为零
forward
def forward(inputs)
模型的前向传播。
Arguments:
inputs
torch.Tensor - 代表问题状态的输入张量。
Returns:
torch.Tensor
- 预测的可能解的分布。
LinearBlock
class LinearBlock(nn.Module)
MLP 架构的构建块。
该块由一个线性层、一个 ReLU 激活函数和一个批归一化层组成。
forward
def forward(inputs)
线性块的前向传播。
Arguments:
inputs
torch.Tensor - 输入张量。
Returns:
torch.Tensor
- 经过线性变换、ReLU 激活和批归一化后的输出张量。