Przejdź do głównej zawartości

model

Moduł ładowania modeli

Ten moduł dostarcza funkcję i klasy do ładowania i używania wstępnie wytrenowanego modelu solwera AlphaCube.

Funkcja: load_model(model_id, cache_dir): Ładuje wstępnie wytrenowany model solwera AlphaCube.

Klasy:

  • Model: Architektura MLP dla solwera AlphaCube.
  • LinearBlock: Blok budulcowy dla architektury MLP.

load_model

def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)

Ładuje wstępnie wytrenowany model solwera AlphaCube.

Argumenty:

  • model_id str - Identyfikator wariantu modelu do załadowania ("small", "base" lub "large").
  • cache_dir str - Katalog do przechowywania pobranego modelu w pamięci podręcznej.

Zwraca:

  • nn.Module - Załadowany model solwera AlphaCube.

Model_v1

class Model_v1(nn.Module)

Architektura MLP dla optymalnego obliczeniowo solwera Kostki Rubika, wprowadzona w następującej publikacji: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

Przejście w przód modelu.

Argumenty:

  • inputs torch.Tensor - Tensor wejściowy reprezentujący stan problemu.

Zwraca:

  • torch.Tensor - Przewidywany rozkład prawdopodobieństwa dla możliwych rozwiązań.

Model

class Model(nn.Module)

Architektura lepsza niż Model.

Zmiany:

  • Usunięcie aktywacji ReLU z pierwszej warstwy (embedding), która miała problem "umierającego ReLU" (dying ReLU).
  • Zgodnie z najnowszą konwencją, warstwa embedding nie jest liczona jako jedna warstwa ukryta.

reset_parameters

def reset_parameters()

Inicjalizuje wszystkie wagi tak, aby wariancje aktywacji wynosiły w przybliżeniu 1.0, przy czym wartości bias i wagi warstwy wyjściowej są zerowe.

forward

def forward(inputs)

Przejście w przód modelu.

Argumenty:

  • inputs torch.Tensor - Tensor wejściowy reprezentujący stan problemu.

Zwraca:

  • torch.Tensor - Przewidywany rozkład prawdopodobieństwa dla możliwych rozwiązań.

LinearBlock

class LinearBlock(nn.Module)

Blok budulcowy dla architektury MLP.

Ten blok składa się z warstwy liniowej, po której następuje aktywacja ReLU i normalizacja wsadowa (batch normalization).

forward

def forward(inputs)

Przejście w przód bloku liniowego.

Argumenty:

  • inputs torch.Tensor - Tensor wejściowy.

Zwraca:

  • torch.Tensor - Tensor wyjściowy po transformacji liniowej, aktywacji ReLU i normalizacji wsadowej.