model
Chargeur de modèle
Ce module fournit une fonction et des classes pour charger et utiliser un modèle de solveur AlphaCube pré-entraîné.
Fonction :
load_model(model_id, cache_dir) : Charge le modèle de solveur AlphaCube pré-entraîné.
Classes :
Model: L'architecture MLP pour le solveur AlphaCube.LinearBlock: Un bloc de construction pour l'architecture MLP.
load_model
def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)
Charge le modèle de solveur AlphaCube pré-entraîné.
Arguments :
model_idstr - Identifiant de la variante du modèle à charger (« small », « base » ou « large »).cache_dirstr - Répertoire pour mettre en cache le modèle télécharg é.
Retourne :
nn.Module- Modèle de solveur AlphaCube chargé.
Model_v1
class Model_v1(nn.Module)
L'architecture MLP pour le solveur de Rubik's Cube à calcul optimal introduite dans l'article suivant : https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
Passe avant du modèle.
Arguments :
inputstorch.Tensor - Tenseur d'entrée représentant l'état du problème.
Retourne :
torch.Tensor- Distribution prédite sur les solutions possibles.
Model
class Model(nn.Module)
Une architecture meilleure que Model.
Changements :
- Suppression de l'activation ReLU de la première couche (
embedding), qui souffrait du problème du « dying ReLU » (ReLU mourant). - Suivant la convention récente, la couche
embeddingne compte pas comme une couche cachée.
reset_parameters
def reset_parameters()
Initialise tous les poids de sorte que les variances d'activation soient approximativement de 1.0, avec les termes de biais et les poids de la couche de sortie étant à zéro.
forward
def forward(inputs)
Passe avant du modèle.
Arguments :
inputstorch.Tensor - Tenseur d'entrée représentant l'état du problème.
Retourne :
torch.Tensor- Distribution prédite sur les solutions possibles.
LinearBlock
class LinearBlock(nn.Module)
Un bloc de construction pour l'architecture MLP.
Ce bloc se compose d'une couche linéaire suivie d'une activation ReLU et d'une normalisation par lots (batch normalization).
forward
def forward(inputs)
Passe avant du bloc linéaire.
Arguments :
inputstorch.Tensor - Tenseur d'entrée.
Retourne :
torch.Tensor- Tenseur de sortie après transformation linéaire, activation ReLU et normalisation par lots.