Aller au contenu principal

model

Module de chargement de modèle

Ce module fournit une fonction et des classes pour charger et utiliser un modèle pré-entraîné de solveur AlphaCube.

Fonction : load_model(model_id, cache_dir) : Charge le modèle pré-entraîné de solveur AlphaCube.

Classes :

  • Model : L'architecture MLP pour le solveur AlphaCube.
  • LinearBlock : Un bloc de construction pour l'architecture MLP.

load_model

def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))

Charge le modèle pré-entraîné de solveur AlphaCube.

Arguments:

  • model_id str - Identifiant de la variante du modèle à charger ("small", "base" ou "large").
  • cache_dir str - Répertoire pour mettre en cache le modèle téléchargé.

Retourne:

  • nn.Module - Modèle de solveur AlphaCube chargé.

Model_v1

class Model_v1(nn.Module)

L'architecture MLP pour le solveur de Rubik's Cube optimisé pour le calcul introduit dans l'article suivant : https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

Passe avant du modèle.

Arguments:

  • inputs torch.Tensor - Tenseur d'entrée représentant l'état du problème.

Retourne:

  • torch.Tensor - Distribution prédite sur les solutions possibles.

Model

class Model(nn.Module)

Une architecture meilleure que Model.

Changements:

  • Suppression de l'activation ReLU de la première couche (embedding), qui avait le problème de ReLU mourant.
  • Suivant la convention récente, la couche embedding ne compte pas comme une couche cachée.

reset_parameters

def reset_parameters()

Initialise tous les poids de sorte que les variances d'activation soient d'environ 1.0, les termes de biais et les poids de la couche de sortie étant des zéros.

forward

def forward(inputs)

Passe avant du modèle.

Arguments:

  • inputs torch.Tensor - Tenseur d'entrée représentant l'état du problème.

Retourne:

  • torch.Tensor - Distribution prédite sur les solutions possibles.

LinearBlock

class LinearBlock(nn.Module)

Un bloc de construction pour l'architecture MLP.

Ce bloc est constitué d'une couche linéaire suivie d'une activation ReLU et d'une normalisation par lots.

forward

def forward(inputs)

Passe avant du bloc linéaire.

Arguments:

  • inputs torch.Tensor - Tenseur d'entrée.

Retourne:

  • torch.Tensor - Tenseur de sortie après transformation linéaire, activation ReLU et normalisation par lots.