model
Module de chargement de modèle
Ce module fournit une fonction et des classes pour charger et utiliser un modèle pré-entraîné de solveur AlphaCube.
Fonction :
load_model(model_id, cache_dir)
: Charge le modèle pré-entraîné de solveur AlphaCube.
Classes :
Model
: L'architecture MLP pour le solveur AlphaCube.LinearBlock
: Un bloc de construction pour l'architecture MLP.
load_model
def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))
Charge le modèle pré-entraîné de solveur AlphaCube.
Arguments:
model_id
str - Identifiant de la variante du modèle à charger ("small", "base" ou "large").cache_dir
str - Répertoire pour mettre en cache le modèle téléchargé.
Retourne:
nn.Module
- Modèle de solveur AlphaCube chargé.
Model_v1
class Model_v1(nn.Module)
L'architecture MLP pour le solveur de Rubik's Cube optimisé pour le calcul introduit dans l'article suivant : https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
Passe avant du modèle.
Arguments:
inputs
torch.Tensor - Tenseur d'entrée représentant l'état du problème.
Retourne:
torch.Tensor
- Distribution prédite sur les solutions possibles.
Model
class Model(nn.Module)
Une architecture meilleure que Model
.
Changements:
- Suppression de l'activation ReLU de la première couche (
embedding
), qui avait le problème de ReLU mourant. - Suivant la convention récente, la couche
embedding
ne compte pas comme une couche cachée.
reset_parameters
def reset_parameters()
Initialise tous les poids de sorte que les variances d'activation soient d'environ 1.0, les termes de biais et les poids de la couche de sortie étant des zéros.
forward
def forward(inputs)
Passe avant du modèle.
Arguments:
inputs
torch.Tensor - Tenseur d'entrée représentant l'état du problème.
Retourne:
torch.Tensor
- Distribution prédite sur les solutions possibles.
LinearBlock
class LinearBlock(nn.Module)
Un bloc de construction pour l'architecture MLP.
Ce bloc est constitué d'une couche linéaire suivie d'une activation ReLU et d'une normalisation par lots.
forward
def forward(inputs)
Passe avant du bloc linéaire.
Arguments:
inputs
torch.Tensor - Tenseur d'entrée.
Retourne:
torch.Tensor
- Tenseur de sortie après transformation linéaire, activation ReLU et normalisation par lots.