model
Modell Betöltő
Ez a modul egy függvényt és osztályokat biztosít egy előre betanított AlphaCube megoldó modell betöltéséhez és használatához.
Függvény:
load_model(model_id, cache_dir)
: Betölti az előre betanított AlphaCube megoldó modellt.
Osztályok:
Model
: Az MLP architektúra az AlphaCube megoldóhoz.LinearBlock
: Egy építőelem az MLP architektúrához.
load_model
def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)
Betölti az előre betanított AlphaCube megoldó modellt.
Argumentumok:
model_id
str - A betöltendő modellváltozat azonosítója ("small", "base" vagy "large").cache_dir
str - Könyvtár a letöltött modell gyorsítótárazásához.
Visszatérési érték:
nn.Module
- A betöltött AlphaCube megoldó modell.
Model_v1
class Model_v1(nn.Module)
Az alábbi tanulmányban bemutatott, számítási szempontból optimális Rubik-kocka megoldó MLP architektúrája: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
A modell előrecsatolási (forward pass) fázisa.
Argumentumok:
inputs
torch.Tensor - A probléma állapotát reprezentáló bemeneti tenzor.
Visszatérési érték:
torch.Tensor
- A lehetséges megoldások feletti prediktált eloszlás.
Model
class Model(nn.Module)
Egy, a Model_v1
-nél jobb architektúra.
Változtatások:
- A ReLU aktiváció eltávolítása az első rétegből (
embedding
), amely a "halott ReLU" problémával küzdött. - A legújabb konvenciókat követve az
embedding
réteg nem számít egy rejtett rétegnek.
reset_parameters
def reset_parameters()
Minden súly inicializálása úgy, hogy az aktivációs varianciák körülbelül 1.0 legyenek, a torzítási (bias) tagok és a kimeneti réteg súlyai pedig nullák.
forward
def forward(inputs)
A modell előrecsatolási (forward pass) fázisa.
Argumentumok:
inputs
torch.Tensor - A probléma állapotát reprezentáló bemeneti tenzor.
Visszatérési érték:
torch.Tensor
- A lehetséges megoldások feletti prediktált eloszlás.
LinearBlock
class LinearBlock(nn.Module)
Egy építőelem az MLP architektúrához.
Ez a blokk egy lineáris rétegből, majd egy ReLU aktivációból és egy kötegelt normalizálásból (batch normalization) áll.
forward
def forward(inputs)
A lineáris blokk előrecsatolási (forward pass) fázisa.
Argumentumok:
inputs
torch.Tensor - Bemeneti tenzor.
Visszatérési érték:
torch.Tensor
- Kimeneti tenzor a lineáris transzformáció, a ReLU aktiváció és a kötegelt normalizálás után.