Przejdź do głównej zawartości

model

Moduł ładowania modelu

Ten moduł udostępnia funkcję i klasy do ładowania i używania wstępnie wytrenowanego modelu solvera AlphaCube.

Funkcja: load_model(model_id, cache_dir): Ładuje wstępnie wytrenowany model solvera AlphaCube.

Klasy:

  • Model: Architektura MLP dla solvera AlphaCube.
  • LinearBlock: Blok budulcowy dla architektury MLP.

load_model

def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))

Ładuje wstępnie wytrenowany model solvera AlphaCube.

Argumenty:

  • model_id str - Identyfikator wariantu modelu do załadowania ("small", "base" lub "large").
  • cache_dir str - Katalog do przechowywania pobranego modelu.

Zwraca:

  • nn.Module - Załadowany model solvera AlphaCube.

Model_v1

class Model_v1(nn.Module)

Architektura MLP dla optymalnego obliczeniowo solvera kostki Rubika wprowadzona w następującym artykule: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

Przejście w przód modelu.

Argumenty:

  • inputs torch.Tensor - Tensor wejściowy reprezentujący stan problemu.

Zwraca:

  • torch.Tensor - Przewidywany rozkład możliwych rozwiązań.

Model

class Model(nn.Module)

Architektura lepsza niż Model.

Zmiany:

  • Usunięcie aktywacji ReLU z pierwszej warstwy (embedding), która miała problem z zanikającym ReLU.
  • Zgodnie z najnowszą konwencją, warstwa embedding nie jest liczona jako jedna warstwa ukryta.

reset_parameters

def reset_parameters()

Inicjalizuje wszystkie wagi tak, aby wariancje aktywacji wynosiły w przybliżeniu 1.0, przy czym wyrazy bias i wagi warstwy wyjściowej są zerami.

forward

def forward(inputs)

Przejście w przód modelu.

Argumenty:

  • inputs torch.Tensor - Tensor wejściowy reprezentujący stan problemu.

Zwraca:

  • torch.Tensor - Przewidywany rozkład możliwych rozwiązań.

LinearBlock

class LinearBlock(nn.Module)

Blok budulcowy dla architektury MLP.

Ten blok składa się z warstwy liniowej, po której następuje aktywacja ReLU i normalizacja wsadowa.

forward

def forward(inputs)

Przejście w przód bloku liniowego.

Argumenty:

  • inputs torch.Tensor - Tensor wejściowy.

Zwraca:

  • torch.Tensor - Tensor wyjściowy po transformacji liniowej, aktywacji ReLU i normalizacji wsadowej.