Zum Hauptinhalt springen

model

Model Loader

Dieses Modul stellt eine Funktion und Klassen zum Laden und Verwenden eines vortrainierten AlphaCube-Lösermodells bereit.

Funktion: load_model(model_id, cache_dir): Lädt das vortrainierte AlphaCube-Lösermodell.

Klassen:

  • Model: Die MLP-Architektur für den AlphaCube-Löser.
  • LinearBlock: Ein Baustein für die MLP-Architektur.

load_model

def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)

Lädt das vortrainierte AlphaCube-Lösermodell.

Argumente:

  • model_id str - Bezeichner für die zu ladende Modellvariante ("small", "base" oder "large").
  • cache_dir str - Verzeichnis zum Zwischenspeichern des heruntergeladenen Modells.

Rückgabewert:

  • nn.Module - Geladenes AlphaCube-Lösermodell.

Model_v1

class Model_v1(nn.Module)

Die MLP-Architektur für den rechenoptimalen Zauberwürfel-Löser, der in folgendem Artikel vorgestellt wird: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

Forward-Pass des Modells.

Argumente:

  • inputs torch.Tensor - Eingabe-Tensor, der den Problemzustand darstellt.

Rückgabewert:

  • torch.Tensor - Vorhergesagte Verteilung über mögliche Lösungen.

Model

class Model(nn.Module)

Eine Architektur, die besser ist als Model.

Änderungen:

  • Entfernen der ReLU-Aktivierung aus der ersten Schicht (embedding), die das Problem des „sterbenden ReLU“ hatte.
  • Der jüngsten Konvention folgend, zählt die embedding-Schicht nicht als eine versteckte Schicht.

reset_parameters

def reset_parameters()

Initialisiert alle Gewichte so, dass die Aktivierungsvarianzen ungefähr 1.0 betragen, wobei die Bias-Terme und die Gewichte der Ausgabeschicht Nullen sind

forward

def forward(inputs)

Forward-Pass des Modells.

Argumente:

  • inputs torch.Tensor - Eingabe-Tensor, der den Problemzustand darstellt.

Rückgabewert:

  • torch.Tensor - Vorhergesagte Verteilung über mögliche Lösungen.

LinearBlock

class LinearBlock(nn.Module)

Ein Baustein für die MLP-Architektur.

Dieser Block besteht aus einer linearen Schicht, gefolgt von ReLU-Aktivierung und Batch-Normalisierung.

forward

def forward(inputs)

Forward-Pass des linearen Blocks.

Argumente:

  • inputs torch.Tensor - Eingabe-Tensor.

Rückgabewert:

  • torch.Tensor - Ausgabe-Tensor nach linearer Transformation, ReLU-Aktivierung und Batch-Normalisierung.