model

Model Loader

This module provides a function and classes for loading and using a pre-trained AlphaCube solver model.

Function: load_model(model_id, cache_dir): Load the pre-trained AlphaCube solver model.

Classes:

  • Model: The MLP architecture for the AlphaCube solver.
  • LinearBlock: A building block for the MLP architecture.

load_model

def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))

Load the pre-trained AlphaCube solver model.

Arguments:

  • model_id str - Identifier for the model variant to load ("small", "base", or "large").
  • cache_dir str - Directory to cache the downloaded model.

Returns:

  • nn.Module - Loaded AlphaCube solver model.
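
A minimal usage sketch, assuming the module is importable as alphacube.model (the import path is inferred from this page's module name, not confirmed):

from alphacube.model import load_model

# Downloads and caches the weights under ~/.cache/alphacube on first use.
model = load_model(model_id="small")
model.eval()  # inference mode: freezes batch-norm statistics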

Model_v1

class Model_v1(nn.Module)

The MLP architecture for the compute-optimal Rubik's Cube solver introduced in the following paper: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

Forward pass of the model.

Arguments:

  • inputs torch.Tensor - Input tensor representing the problem state.

Returns:

  • torch.Tensor - Predicted distribution over possible solutions.

Model

class Model(nn.Module)

An improved architecture over Model_v1.

Changes:

  • Remove the ReLU activation from the first (embedding) layer, which suffered from the dying-ReLU problem (a sketch of the resulting architecture follows this list).
  • Following recent convention, the embedding layer no longer counts as a hidden layer.
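
A hedged sketch of what these changes imply for the architecture. The layer sizes and move count are illustrative placeholders, not the actual AlphaCube hyperparameters, and the inner blocks mirror the LinearBlock documented below:

import torch.nn as nn

class ModelSketch(nn.Module):
    def __init__(self, input_dim=324, hidden_dim=1024, n_hidden=4, n_moves=18):
        super().__init__()
        # Embedding: a plain linear layer with no ReLU, avoiding the dying-ReLU issue.
        self.embedding = nn.Linear(input_dim, hidden_dim)
        # Hidden layers proper; the embedding above is not counted among them.
        # Each block is Linear -> ReLU -> BatchNorm, as in LinearBlock below.
        self.hidden = nn.Sequential(*[
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
            )
            for _ in range(n_hidden)
        ])
        self.output = nn.Linear(hidden_dim, n_moves)

    def forward(self, inputs):
        return self.output(self.hidden(self.embedding(inputs)))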

reset_parameters

def reset_parameters()

Initialize all weights such that the activation variances are approximately 1.0, with bias terms and output-layer weights set to zero.
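
A hedged guess at such an initialization, written as a method for the ModelSketch above. Kaiming (He) initialization is a standard way to keep ReLU activation variances near 1.0, but the actual scheme may differ:

import torch.nn as nn

def reset_parameters(self):  # intended as a method on ModelSketch above
    for module in self.modules():
        if isinstance(module, nn.Linear):
            # He init preserves activation variance ~1.0 through ReLU layers.
            nn.init.kaiming_normal_(module.weight)
            nn.init.zeros_(module.bias)
    # Zeroed output weights make the initial prediction a uniform distribution.
    nn.init.zeros_(self.output.weight)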

forward

def forward(inputs)

Forward pass of the model.

Arguments:

  • inputs torch.Tensor - Input tensor representing the problem state.

Returns:

  • torch.Tensor - Predicted distribution over possible solutions.
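
A hedged forward-pass example. The 324-dimensional input (54 stickers × 6 colors, one-hot) is an assumption about the state encoding, and whether the output is raw logits or already normalized is not confirmed here:

import torch

model = load_model("small")
model.eval()

batch = torch.zeros(2, 324)  # two dummy states; real inputs encode the cube's stickers
with torch.no_grad():
    out = model(batch)
probs = out.softmax(dim=-1)  # normalize, in case the model returns raw logits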

LinearBlock

class LinearBlock(nn.Module)

A building block for the MLP architecture.

This block consists of a linear layer followed by ReLU activation and batch normalization.

forward

def forward(inputs)

Forward pass of the linear block.

Arguments:

  • inputs torch.Tensor - Input tensor.

Returns:

  • torch.Tensor - Output tensor after linear transformation, ReLU activation, and batch normalization.
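
A minimal sketch consistent with this description (Linear → ReLU → BatchNorm); details such as bias handling in the actual module may differ:

import torch.nn as nn

class LinearBlock(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(out_features)

    def forward(self, inputs):
        # Linear transformation, then ReLU, then batch normalization.
        return self.bn(self.relu(self.linear(inputs)))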