model
Moduł ładowania modeli
Ten moduł dostarcza funkcję i klasy do ładowania i używania wstępnie wytrenowanego modelu solwera AlphaCube.
Funkcja:
load_model(model_id, cache_dir)
: Ładuje wstępnie wytrenowany model solwera AlphaCube.
Klasy:
Model
: Architektura MLP dla solwera AlphaCube.LinearBlock
: Blok budulcowy dla architektury MLP.
load_model
def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)
Ładuje wstępnie wytrenowany model solwera AlphaCube.
Argumenty:
model_id
str - Identyfikator wariantu modelu do załadowania ("small", "base" lub "large").cache_dir
str - Katalog do przechowywania pobranego modelu w pamięci podręcznej.
Zwraca:
nn.Module
- Załadowany model solwera AlphaCube.
Model_v1
class Model_v1(nn.Module)
Architektura MLP dla optymalnego obliczeniowo solwera Kostki Rubika, wprowadzona w następującej publikacji: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
Przejście w przód modelu.
Argumenty:
inputs
torch.Tensor - Tensor wejściowy reprezentujący stan problemu.
Zwraca:
torch.Tensor
- Przewidywany rozkład prawdopodobieństwa dla możliwych rozwiązań.
Model
class Model(nn.Module)
Architektura lepsza niż Model
.
Zmiany:
- Usunięcie aktywacji ReLU z pierwszej warstwy (
embedding
), która miała problem "umierającego ReLU" (dying ReLU). - Zgodnie z najnowszą konwencją, warstwa
embedding
nie jest liczona jako jedna warstwa ukryta.
reset_parameters
def reset_parameters()
Inicjalizuje wszystkie wagi tak, aby wariancje aktywacji wynosiły w przybliżeniu 1.0, przy czym wartości bias i wagi warstwy wyjściowej są zerowe.
forward
def forward(inputs)
Przejście w przód modelu.
Argumenty:
inputs
torch.Tensor - Tensor wejściowy reprezentujący stan problemu.
Zwraca:
torch.Tensor
- Przewidywany rozkład prawdopodobieństwa dla możliwych rozwiązań.
LinearBlock
class LinearBlock(nn.Module)
Blok budulcowy dla architektury MLP.
Ten blok składa się z warstwy liniowej, po której następuje aktywacja ReLU i normalizacja wsadowa (batch normalization).
forward
def forward(inputs)
Przejście w przód bloku liniowego.
Argumenty:
inputs
torch.Tensor - Tensor wejściowy.
Zwraca:
torch.Tensor
- Tensor wyjściowy po transformacji liniowej, aktywacji ReLU i normalizacji wsadowej.