model
Model Loader
Este módulo proporciona una función y clases para cargar y usar un modelo de solucionador AlphaCube preentrenado.
Function:
load_model(model_id, cache_dir)
: Carga el modelo de solucionador AlphaCube preentrenado.
Classes:
Model
: La arquitectura MLP para el solucionador AlphaCube.LinearBlock
: Un bloque de construcción para la arquitectura MLP.
load_model
def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)
Carga el modelo de solucionador AlphaCube preentrenado.
Argumentos:
model_id
str - Identificador para la variante del modelo a cargar ("small", "base" o "large").cache_dir
str - Directorio para almacenar en caché el modelo descargado.
Retorna:
nn.Module
- Modelo de solucionador AlphaCube cargado.
Model_v1
class Model_v1(nn.Module)
La arquitectura MLP para el solucionador de cubo de Rubik computacionalmente óptimo introducido en el siguiente artículo: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
Paso hacia adelante (forward pass) del modelo.
Argumentos:
inputs
torch.Tensor - Tensor de entrada que representa el estado del problema.
Retorna:
torch.Tensor
- Distribución predicha sobre las posibles soluciones.
Model
class Model(nn.Module)
Una arquitectura mejor que Model
.
Cambios:
- Elimina la activación ReLU de la primera capa (
embedding
), que tenía el problema del ReLU moribundo (dying ReLU). - Siguiendo la convención reciente, la capa
embedding
no cuenta como una capa oculta.
reset_parameters
def reset_parameters()
Inicializa todos los pesos de tal manera que las varianzas de activación sean aproximadamente 1.0, con los términos de sesgo y los pesos de la capa de salida siendo ceros.
forward
def forward(inputs)
Paso hacia adelante (forward pass) del modelo.
Argumentos:
inputs
torch.Tensor - Tensor de entrada que representa el estado del problema.
Retorna:
torch.Tensor
- Distribución predicha sobre las posibles soluciones.
LinearBlock
class LinearBlock(nn.Module)
Un bloque de construcción para la arquitectura MLP.
Este bloque consiste en una capa lineal seguida de una activación ReLU y normalización por lotes (batch normalization).
forward
def forward(inputs)
Paso hacia adelante (forward pass) del bloque lineal.
Argumentos:
inputs
torch.Tensor - Tensor de entrada.
Retorna:
torch.Tensor
- Tensor de salida después de la transformación lineal, activación ReLU y normalización por lotes.