メインコンテンツまでスキップ

Model Loader

このモジュールは、学習済みのAlphaCubeソルバーモデルを読み込んで使用するための関数とクラスを提供します。

Function: load_model(model_id, cache_dir): 学習済みのAlphaCubeソルバーモデルを読み込みます。

Classes:

  • Model: AlphaCubeソルバーのMLPアーキテクチャです。
  • LinearBlock: MLPアーキテクチャの構成要素です。

load_model

def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)

学習済みのAlphaCubeソルバーモデルを読み込みます。

引数:

  • model_id str - 読み込むモデルのバリアント("small"、"base"、または "large")の識別子です。
  • cache_dir str - ダウンロードしたモデルをキャッシュするディレクトリです。

戻り値:

  • nn.Module - 読み込まれたAlphaCubeソルバーモデルです。

Model_v1

class Model_v1(nn.Module)

以下の論文で紹介された、計算量最適化されたルービックキューブソルバーのためのMLPアーキテクチャです: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

モデルのフォワードパスです。

引数:

  • inputs torch.Tensor - 問題の状態を表す入力テンソルです。

戻り値:

  • torch.Tensor - 予測された解候補の分布です。

Model

class Model(nn.Module)

Model_v1よりも優れたアーキテクチャです。

変更点:

  • 最初の層(embedding)から、Dying ReLU問題を引き起こしていたReLU活性化関数を削除しました。
  • 最近の慣例に従い、embedding層は隠れ層の1つとしてカウントしません

reset_parameters

def reset_parameters()

活性化の分散が約1.0になるようにすべての重みを初期化し、バイアス項と出力層の重みはゼロに初期化します。

forward

def forward(inputs)

モデルのフォワードパスです。

引数:

  • inputs torch.Tensor - 問題の状態を表す入力テンソルです。

戻り値:

  • torch.Tensor - 予測された解候補の分布です。

LinearBlock

class LinearBlock(nn.Module)

MLPアーキテクチャの構成要素です。

このブロックは、線形層、ReLU活性化関数、バッチ正規化で構成されています。

forward

def forward(inputs)

線形ブロックのフォワードパスです。

引数:

  • inputs torch.Tensor - 入力テンソルです。

戻り値:

  • torch.Tensor - 線形変換、ReLU活性化、バッチ正規化後の出力テンソルです。