model
Model Loader
Dieses Modul stellt eine Funktion und Klassen zum Laden und Verwenden eines vortrainierten AlphaCube Solver-Modells bereit.
Funktion:
load_model(model_id, cache_dir)
: Lädt das vortrainierte AlphaCube Solver-Modell.
Klassen:
Model
: Die MLP-Architektur für den AlphaCube Solver.LinearBlock
: Ein Baustein für die MLP-Architektur.
load_model
def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))
Lädt das vortrainierte AlphaCube Solver-Modell.
Argumente:
model_id
str - Identifikator für die zu ladende Modellvariante ("small", "base" oder "large").cache_dir
str - Verzeichnis zum Zwischenspeichern des heruntergeladenen Modells.
Rückgabewert:
nn.Module
- Geladenes AlphaCube Solver-Modell.
Model_v1
class Model_v1(nn.Module)
Die MLP-Architektur für den rechenoptimalen Rubik's Cube Solver, der in folgendem Paper vorgestellt wurde: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
Vorwärtsdurchlauf des Modells.
Argumente:
inputs
torch.Tensor - Eingabetensor, der den Problemzustand repräsentiert.
Rückgabewert:
torch.Tensor
- Vorhergesagte Verteilung über mögliche Lösungen.
Model
class Model(nn.Module)
Eine Architektur, die besser ist als Model
.
Änderungen:
- Entfernung der ReLU-Aktivierung aus der ersten Schicht (
embedding
), die das Problem der sterbenden ReLU hatte. - Der Konvention folgend zählt die
embedding
-Schicht nicht als eine versteckte Schicht.
reset_parameters
def reset_parameters()
Initialisiert alle Gewichte so, dass die Aktivierungsvarianzen ungefähr 1,0 betragen, wobei die Bias-Terme und die Gewichte der Ausgabeschicht Nullen sind.
forward
def forward(inputs)
Vorwärtsdurchlauf des Modells.
Argumente:
inputs
torch.Tensor - Eingabetensor, der den Problemzustand repräsentiert.
Rückgabewert:
torch.Tensor
- Vorhergesagte Verteilung über mögliche Lösungen.
LinearBlock
class LinearBlock(nn.Module)
Ein Baustein für die MLP-Architektur.
Dieser Block besteht aus einer linearen Schicht, gefolgt von einer ReLU-Aktivierung und Batch-Normalisierung.
forward
def forward(inputs)
Vorwärtsdurchlauf des linearen Blocks.
Argumente:
inputs
torch.Tensor - Eingabetensor.
Rückgabewert:
torch.Tensor
- Ausgabetensor nach linearer Transformation, ReLU-Aktivierung und Batch-Normalisierung.