model
Model Loader
This module provides a function and classes for loading and using a pre-trained AlphaCube solver model.
Function:
load_model(model_id, cache_dir)
: Load the pre-trained AlphaCube solver model.
Classes:
Model
: The MLP architecture for the AlphaCube solver.
LinearBlock
: A building block for the MLP architecture.
load_model
def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))
Load the pre-trained AlphaCube solver model.
Arguments:
model_id
str - Identifier for the model variant to load ("small", "base", or "large").
cache_dir
str - Directory to cache the downloaded model.
Returns:
nn.Module
- Loaded AlphaCube solver model.
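A minimal usage sketch, assuming the module is importable as alphacube.model and that the model accepts a flattened one-hot state encoding; the 324-feature input shape is an assumption, not part of this documented contract:

```python
import torch
from alphacube.model import load_model  # assumed import path

# Downloads on first call, then reuses the cache under ~/.cache/alphacube.
model = load_model(model_id="small")
model.eval()

# Hypothetical input: a batch of one flattened one-hot cube state
# (54 stickers x 6 colors = 324 features; this shape is an assumption).
state = torch.zeros(1, 324)
with torch.no_grad():
    logits = model(state)
```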
Model_v1
class Model_v1(nn.Module)
The MLP architecture for the compute-optimal Rubik's Cube solver introduced in the following paper: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
Forward pass of the model.
Arguments:
inputs
torch.Tensor - Input tensor representing the problem state.
Returns:
torch.Tensor
- Predicted distribution over possible solutions.
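A hedged sketch of how the input tensor might be constructed; the 54-sticker, 6-color one-hot layout is an assumption based on the standard 3x3x3 cube representation, not something this docstring specifies:

```python
import torch
import torch.nn.functional as F

# Hypothetical scrambled state: one color index (0-5) per sticker.
sticker_colors = torch.randint(0, 6, (54,))
inputs = F.one_hot(sticker_colors, num_classes=6)  # (54, 6)
inputs = inputs.flatten().float().unsqueeze(0)     # (1, 324), batch of one
```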
Model
class Model(nn.Module)
An architecture better than Model_v1.
Changes:
- Remove ReLU activation from the first layer (embedding), which had the dying-ReLU problem.
- Following the recent convention, the embedding layer does not count as one hidden layer.
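A minimal sketch of the described design, with hypothetical layer sizes; only the two changes listed above (a plain linear embedding without ReLU, and the embedding not counted among the hidden layers) come from the docstring:

```python
import torch.nn as nn

class ModelSketch(nn.Module):
    def __init__(self, input_dim=324, hidden_dim=1024, num_hidden=2, output_dim=18):
        super().__init__()
        # Plain linear embedding: no ReLU here, avoiding the dying-ReLU problem.
        self.embedding = nn.Linear(input_dim, hidden_dim)
        # Hidden layers (Linear -> ReLU -> BatchNorm, as in LinearBlock below),
        # counted separately from the embedding.
        self.layers = nn.Sequential(*[
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
            )
            for _ in range(num_hidden)
        ])
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
        return self.output(self.layers(self.embedding(inputs)))
```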
reset_parameters
def reset_parameters()
Initialize all weights such that the activation variances are approximately 1.0, with bias terms and the output-layer weights set to zero.
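A hedged sketch of one way to realize this scheme; Kaiming-style scaling is a standard choice for keeping post-ReLU activation variance near 1.0, but whether the actual implementation uses it is an assumption:

```python
import torch.nn as nn

def reset_parameters_sketch(model: nn.Module, output_layer: nn.Linear):
    # Scale weights for approximately unit activation variance under ReLU
    # (Kaiming init is one standard way to do this); zero all bias terms.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            nn.init.zeros_(module.bias)
    # Zero the output-layer weights so the initial prediction is uniform.
    nn.init.zeros_(output_layer.weight)
```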
forward
def forward(inputs)
Forward pass of the model.
Arguments:
inputs
torch.Tensor - Input tensor representing the problem state.
Returns:
torch.Tensor
- Predicted distribution over possible solutions.
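A hedged sketch of interpreting the output; treating it as a distribution over the 18 half-turn-metric face moves is an assumption about the output encoding, and the softmax assumes the forward pass returns unnormalized logits (it preserves the ranking either way):

```python
import torch
from alphacube.model import load_model  # assumed import path

# Assumed move ordering; the actual index-to-move mapping is not documented here.
MOVES = ["U", "U'", "U2", "D", "D'", "D2", "L", "L'", "L2",
         "R", "R'", "R2", "F", "F'", "F2", "B", "B'", "B2"]

model = load_model()
model.eval()
inputs = torch.zeros(1, 324)  # hypothetical encoded state (see the sketch above)
with torch.no_grad():
    probs = torch.softmax(model(inputs), dim=-1)  # distribution over moves
best_move = MOVES[probs.argmax(dim=-1).item()]
```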
LinearBlock
class LinearBlock(nn.Module)
A building block for the MLP architecture.
This block consists of a linear layer followed by ReLU activation and batch normalization.
forward
def forward(inputs)
Forward pass of the linear block.
Arguments:
inputs
torch.Tensor - Input tensor.
Returns:
torch.Tensor
- Output tensor after linear transformation, ReLU activation, and batch normalization.
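A minimal sketch matching this description (linear transformation, then ReLU, then batch normalization); the constructor parameters are hypothetical:

```python
import torch.nn as nn

class LinearBlockSketch(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(out_features)

    def forward(self, inputs):
        # Linear transformation -> ReLU activation -> batch normalization.
        return self.bn(self.relu(self.linear(inputs)))
```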