model
Cargador de Modelo
Este módulo proporciona una función y clases para cargar y usar un modelo pre-entrenado de solucionador AlphaCube.
Función:
load_model(model_id, cache_dir)
: Carga el modelo pre-entrenado de solucionador AlphaCube.
Clases:
Model
: La arquitectura MLP para el solucionador AlphaCube.LinearBlock
: Un bloque de construcción para la arquitectura MLP.
load_model
def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))
Carga el modelo pre-entrenado de solucionador AlphaCube.
Argumentos:
model_id
str - Identificador para la variante del modelo a cargar ("small", "base", o "large").cache_dir
str - Directorio para almacenar en caché el modelo descargado.
Retorna:
nn.Module
- Modelo cargado de solucionador AlphaCube.
Model_v1
class Model_v1(nn.Module)
La arquitectura MLP para el solucionador óptimo en cómputo del Cubo de Rubik introducido en el siguiente artículo: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
Paso hacia adelante del modelo.
Argumentos:
inputs
torch.Tensor - Tensor de entrada que representa el estado del problema.
Retorna:
torch.Tensor
- Distribución predicha sobre las posibles soluciones.
Model
class Model(nn.Module)
Una arquitectura mejor que Model
.
Cambios:
- Elimina la activación ReLU de la primera capa (
embedding
), que tenía el problema de ReLU moribundo. - Siguiendo la convención reciente, la capa
embedding
no cuenta como una capa oculta.
reset_parameters
def reset_parameters()
Inicializa todos los pesos de manera que las varianzas de activación sean aproximadamente 1.0, con términos de sesgo y los pesos de la capa de salida siendo ceros.
forward
def forward(inputs)
Paso hacia adelante del modelo.
Argumentos:
inputs
torch.Tensor - Tensor de entrada que representa el estado del problema.
Retorna:
torch.Tensor
- Distribución predicha sobre las posibles soluciones.
LinearBlock
class LinearBlock(nn.Module)
Un bloque de construcción para la arquitectura MLP.
Este bloque consiste en una capa lineal seguida de activación ReLU y normalización por lotes.
forward
def forward(inputs)
Paso hacia adelante del bloque lineal.
Argumentos:
inputs
torch.Tensor - Tensor de entrada.
Retorna:
torch.Tensor
- Tensor de salida después de la transformación lineal, activación ReLU y normalización por lotes.