model
模型加载器
此模块提供了一个函数和几个类,用于加载和使用预训练的 AlphaCube 求解器模型。
函数:
load_model(model_id, cache_dir)
:加载预训练的 AlphaCube 求解器模型。
类:
Model
:AlphaCube 求解器的 MLP 架构。LinearBlock
:MLP 架构的构建块。
load_model
def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))
加载预训练的 AlphaCube 求解器模型。
参数:
model_id
str - 要加载的模型变体的标识符("small"、"base" 或 "large")。cache_dir
str - 缓存下载模型的目录。
返回:
nn.Module
- 加载的 AlphaCube 求解器模型。
Model_v1
class Model_v1(nn.Module)
以下论文中介绍的计算最优 Rubik's Cube 求解器的 MLP 架构: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
模型的前向传递。
参数:
inputs
torch.Tensor - 表示问题状态的输入张量。
返回:
torch.Tensor
- 预测的可能解决方案的分布。
Model
class Model(nn.Module)
一个比 Model
更好的架构。
变化:
- 从第一层(
embedding
)中移除 ReLU 激活,该层存在 ReLU 死亡问题。 - 遵循最新的惯例,
embedding
层不算作一个隐藏层。
reset_parameters
def reset_parameters()
初始化所有权重,使激活方差近似为 1.0,偏置项和输出层权重为零。
forward
def forward(inputs)
模型的前向传递。
参数:
inputs
torch.Tensor - 表示问题状态的输入张量。
返回:
torch.Tensor
- 预测的可能解决方案的分布。
LinearBlock
class LinearBlock(nn.Module)
MLP 架构的构建块。
该块由线性层、ReLU 激活和批归一化组成。
forward
def forward(inputs)
线性块的前向传递。
参数:
inputs
torch.Tensor - 输入张量。
返回:
torch.Tensor
- 经过线性变换、ReLU 激活和批归一化后的输出张量。