メインコンテンツまでスキップ

model

モデルローダー

このモジュールは、事前学習済みのAlphaCubeソルバーモデルを読み込んで使用するための関数とクラスを提供します。

関数: load_model(model_id, cache_dir): 事前学習済みのAlphaCubeソルバーモデルを読み込みます。

クラス:

  • Model: AlphaCubeソルバーのMLPアーキテクチャ。
  • LinearBlock: MLPアーキテクチャのビルディングブロック。

load_model

def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))

事前学習済みのAlphaCubeソルバーモデルを読み込みます。

引数:

  • model_id str - 読み込むモデルのバリアントを識別するID("small"、"base"、または"large")。
  • cache_dir str - ダウンロードしたモデルをキャッシュするディレクトリ。

戻り値:

  • nn.Module - 読み込まれたAlphaCubeソルバーモデル。

Model_v1

class Model_v1(nn.Module)

以下の論文で紹介された、計算に最適化されたルービックキューブソルバーのMLPアーキテクチャ: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

モデルの順伝播。

引数:

  • inputs torch.Tensor - 問題の状態を表す入力テンソル。

戻り値:

  • torch.Tensor - 可能な解に対する予測分布。

Model

class Model(nn.Module)

Modelよりも優れたアーキテクチャ。

変更点:

  • 最初の層(embedding)からReLU活性化を削除。これにはdying ReLUの問題があった。
  • 最近の慣例に従い、embedding層は隠れ層としてはカウントされない

reset_parameters

def reset_parameters()

活性化の分散がおよそ1.0になるようにすべての重みを初期化し、バイアス項と出力層の重みをゼロにする。

forward

def forward(inputs)

モデルの順伝播。

引数:

  • inputs torch.Tensor - 問題の状態を表す入力テンソル。

戻り値:

  • torch.Tensor - 可能な解に対する予測分布。

LinearBlock

class LinearBlock(nn.Module)

MLPアーキテクチャのビルディングブロック。

このブロックは、線形層の後にReLU活性化とバッチ正規化が続く構成になっています。

forward

def forward(inputs)

線形ブロックの順伝播。

引数:

  • inputs torch.Tensor - 入力テンソル。

戻り値:

  • torch.Tensor - 線形変換、ReLU活性化、バッチ正規化後の出力テンソル。