Saltar al contenido principal

model

Model Loader

Este módulo proporciona una función y clases para cargar y usar un modelo de solucionador AlphaCube preentrenado.

Function: load_model(model_id, cache_dir): Carga el modelo de solucionador AlphaCube preentrenado.

Classes:

  • Model: La arquitectura MLP para el solucionador AlphaCube.
  • LinearBlock: Un bloque de construcción para la arquitectura MLP.

load_model

def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)

Carga el modelo de solucionador AlphaCube preentrenado.

Argumentos:

  • model_id str - Identificador para la variante del modelo a cargar ("small", "base" o "large").
  • cache_dir str - Directorio para almacenar en caché el modelo descargado.

Retorna:

  • nn.Module - Modelo de solucionador AlphaCube cargado.

Model_v1

class Model_v1(nn.Module)

La arquitectura MLP para el solucionador de cubo de Rubik computacionalmente óptimo introducido en el siguiente artículo: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

Paso hacia adelante (forward pass) del modelo.

Argumentos:

  • inputs torch.Tensor - Tensor de entrada que representa el estado del problema.

Retorna:

  • torch.Tensor - Distribución predicha sobre las posibles soluciones.

Model

class Model(nn.Module)

Una arquitectura mejor que Model.

Cambios:

  • Elimina la activación ReLU de la primera capa (embedding), que tenía el problema del ReLU moribundo (dying ReLU).
  • Siguiendo la convención reciente, la capa embedding no cuenta como una capa oculta.

reset_parameters

def reset_parameters()

Inicializa todos los pesos de tal manera que las varianzas de activación sean aproximadamente 1.0, con los términos de sesgo y los pesos de la capa de salida siendo ceros.

forward

def forward(inputs)

Paso hacia adelante (forward pass) del modelo.

Argumentos:

  • inputs torch.Tensor - Tensor de entrada que representa el estado del problema.

Retorna:

  • torch.Tensor - Distribución predicha sobre las posibles soluciones.

LinearBlock

class LinearBlock(nn.Module)

Un bloque de construcción para la arquitectura MLP.

Este bloque consiste en una capa lineal seguida de una activación ReLU y normalización por lotes (batch normalization).

forward

def forward(inputs)

Paso hacia adelante (forward pass) del bloque lineal.

Argumentos:

  • inputs torch.Tensor - Tensor de entrada.

Retorna:

  • torch.Tensor - Tensor de salida después de la transformación lineal, activación ReLU y normalización por lotes.