Перейти к основному содержимому

model

Загрузчик модели

Этот модуль предоставляет функцию и классы для загрузки и использования предварительно обученной модели решателя AlphaCube.

Функция: load_model(model_id, cache_dir): Загружает предварительно обученную модель решателя AlphaCube.

Классы:

  • Model: Архитектура MLP для решателя AlphaCube.
  • LinearBlock: Строительный блок для архитектуры MLP.

load_model

def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)

Загружает предварительно обученную модель решателя AlphaCube.

Аргументы:

  • model_id str - Идентификатор варианта модели для загрузки ("small", "base" или "large").
  • cache_dir str - Директория для кэширования загруженной модели.

Возвращает:

  • nn.Module - Загруженная модель решателя AlphaCube.

Model_v1

class Model_v1(nn.Module)

Архитектура MLP для вычислительно-оптимального решателя кубика Рубика, представленная в следующей статье: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

Прямой проход модели.

Аргументы:

  • inputs torch.Tensor - Входной тензор, представляющий состояние задачи.

Возвращает:

  • torch.Tensor - Предсказанное распределение по возможным решениям.

Model

class Model(nn.Module)

Архитектура, превосходящая Model.

Изменения:

  • Удалена активация ReLU из первого слоя (embedding), который страдал от проблемы затухающего ReLU.
  • В соответствии с последними соглашениями, слой embedding не считается скрытым слоем.

reset_parameters

def reset_parameters()

Инициализирует все веса так, чтобы дисперсии активаций были примерно равны 1.0, а смещения и веса выходного слоя были равны нулю.

forward

def forward(inputs)

Прямой проход модели.

Аргументы:

  • inputs torch.Tensor - Входной тензор, представляющий состояние задачи.

Возвращает:

  • torch.Tensor - Предсказанное распределение по возможным решениям.

LinearBlock

class LinearBlock(nn.Module)

Строительный блок для архитектуры MLP.

Этот блок состоит из линейного слоя, за которым следуют активация ReLU и батч-нормализация.

forward

def forward(inputs)

Прямой проход линейного блока.

Аргументы:

  • inputs torch.Tensor - Входной тензор.

Возвращает:

  • torch.Tensor - Выходной тензор после линейного преобразования, активации ReLU и батч-нормализации.