model
モデルローダー
このモジュールは、事前学習済みのAlphaCubeソルバーモデルを読み込んで使用するための関数とクラスを提供します。
関数:
load_model(model_id, cache_dir)
: 事前学習済みのAlphaCubeソルバーモデルを読み込みます。
クラス:
Model
: AlphaCubeソルバーのMLPアーキテクチャ。LinearBlock
: MLPアーキテクチャのビルディングブロック。
load_model
def load_model(model_id="small", cache_dir=os.path.expanduser("~/.cache/alphacube"))
事前学習済みのAlphaCubeソルバーモデルを読み込みます。
引数:
model_id
str - 読み込むモデルのバリアントを識別するID("small"、"base"、または"large")。cache_dir
str - ダウンロードしたモデルをキャッシュするディレクトリ。
戻り値:
nn.Module
- 読み込まれたAlphaCubeソルバーモデル。
Model_v1
class Model_v1(nn.Module)
以下の論文で紹介された、計算に最適化されたルービックキューブソルバーのMLPアーキテクチャ: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
モデルの順伝播。
引数:
inputs
torch.Tensor - 問題の状態を表す入力テンソル。
戻り 値:
torch.Tensor
- 可能な解に対する予測分布。
Model
class Model(nn.Module)
Model
よりも優れたアーキテクチャ。
変更点:
- 最初の層(
embedding
)からReLU活性化を削除。これにはdying ReLUの問題があった。 - 最近の 慣例に従い、
embedding
層は隠れ層としてはカウントされない。
reset_parameters
def reset_parameters()
活性化の分散がおよそ1.0になるようにすべての重みを初期化し、バイアス項と出力層の重みをゼロにする。
forward
def forward(inputs)
モデルの順伝播。
引数:
inputs
torch.Tensor - 問題の状態を表す入力テンソル。
戻り値:
torch.Tensor
- 可能な解に対する予測分布。
LinearBlock
class LinearBlock(nn.Module)
MLPアーキテクチャのビルディングブロック。
このブロックは、線形層の後にReLU活性化とバッチ正規化が続く構成になっています。
forward
def forward(inputs)
線形ブロックの順伝播。
引数:
inputs
torch.Tensor - 入力テンソル。
戻り値:
torch.Tensor
- 線形変換、ReLU活性化、バッチ正規化後の出力テンソル。