model
Moduł ładowania modelu
Ten moduł udostępnia funkcję i klasy do ładowania i używania wstępnie wytrenowanego modelu solvera AlphaCube.
Funkcja:
load_model(model_id, cache_dir)
: Ładuje wstępnie wytrenowany model solvera AlphaCube.
Klasy:
Model
: Architektura MLP dla solvera AlphaCube.LinearBlock
: Blok budulcowy dla architektury MLP.
load_model
def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))
Ładuje wstępnie wytrenowany model solvera AlphaCube.
Argumenty:
model_id
str - Identyfikator wariantu modelu do załadowania ("small", "base" lub "large").cache_dir
str - Katalog do przechowywania pobranego modelu.
Zwraca:
nn.Module
- Załadowany model solvera AlphaCube.
Model_v1
class Model_v1(nn.Module)
Architektura MLP dla optymalnego obliczeniowo solvera kostki Rubika wprowadzona w następującym artykule: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
Przejście w przód modelu.
Argumenty:
inputs
torch.Tensor - Tensor wejściowy reprezentujący stan problemu.
Zwraca:
torch.Tensor
- Przewidywany rozkład możliwych rozwiązań.
Model
class Model(nn.Module)
Architektura lepsza niż Model
.
Zmiany:
- Usunięcie aktywacji ReLU z pierwszej warstwy (
embedding
), która miała problem z zanikającym ReLU. - Zgodnie z najnowszą konwencją, warstwa
embedding
nie jest liczona jako jedna warstwa ukryta.
reset_parameters
def reset_parameters()
Inicjalizuje wszystkie wagi tak, aby wariancje aktywacji wynosiły w przybliżeniu 1.0, przy czym wyrazy bias i wagi warstwy wyjściowej są zerami.
forward
def forward(inputs)
Przejście w przód modelu.
Argumenty:
inputs
torch.Tensor - Tensor wejściowy reprezentujący stan problemu.
Zwraca:
torch.Tensor
- Przewidywany rozkład możliwych rozwiązań.
LinearBlock
class LinearBlock(nn.Module)
Blok budulcowy dla architektury MLP.
Ten blok składa się z warstwy liniowej, po której następuje aktywacja ReLU i normalizacja wsadowa.
forward
def forward(inputs)
Przejście w przód bloku liniowego.
Argumenty:
inputs
torch.Tensor - Tensor wejściowy.
Zwraca:
torch.Tensor
- Tensor wyjściowy po transformacji liniowej, aktywacji ReLU i normalizacji wsadowej.