model
Chargeur de modèle
Ce module fournit une fonction et des classes pour charger et utiliser un modèle de solveur AlphaCube pré-entraîné.
Fonction :
load_model(model_id, cache_dir)
: Charge le modèle de solveur AlphaCube pré-entraîné.
Classes :
Model
: L'architecture MLP pour le solveur AlphaCube.LinearBlock
: Un bloc de construction pour l'architecture MLP.
load_model
def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)
Charge le modèle de solveur AlphaCube pré-entraîné.
Arguments :
model_id
str - Identifiant de la variante du modèle à charger (« small », « base » ou « large »).cache_dir
str - Répertoire pour mettre en cache le modèle télécharg é.
Retourne :
nn.Module
- Modèle de solveur AlphaCube chargé.
Model_v1
class Model_v1(nn.Module)
L'architecture MLP pour le solveur de Rubik's Cube à calcul optimal introduite dans l'article suivant : https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
Passe avant du modèle.
Arguments :
inputs
torch.Tensor - Tenseur d'entrée représentant l'état du problème.
Retourne :
torch.Tensor
- Distribution prédite sur les solutions possibles.
Model
class Model(nn.Module)
Une architecture meilleure que Model
.
Changements :
- Suppression de l'activation ReLU de la première couche (
embedding
), qui souffrait du problème du « dying ReLU » (ReLU mourant). - Suivant la convention récente, la couche
embedding
ne compte pas comme une couche cachée.
reset_parameters
def reset_parameters()
Initialise tous les poids de sorte que les variances d'activation soient approximativement de 1.0, avec les termes de biais et les poids de la couche de sortie étant à zéro.
forward
def forward(inputs)
Passe avant du modèle.
Arguments :
inputs
torch.Tensor - Tenseur d'entrée représentant l'état du problème.
Retourne :
torch.Tensor
- Distribution prédite sur les solutions possibles.
LinearBlock
class LinearBlock(nn.Module)
Un bloc de construction pour l'architecture MLP.
Ce bloc se compose d'une couche linéaire suivie d'une activation ReLU et d'une normalisation par lots (batch normalization).
forward
def forward(inputs)
Passe avant du bloc linéaire.
Arguments :
inputs
torch.Tensor - Tenseur d'entrée.
Retourne :
torch.Tensor
- Tenseur de sortie après transformation linéaire, activation ReLU et normalisation par lots.