メインコンテンツまでスキップ

search

このモジュールは、ビームサーチを実行し、与えられた状態に対する解を見つけるための関数を提供します。

関数: beam_search: ルービックキューブ環境でビームサーチを実行し、解を見つけます。

MAX_BATCH_SIZE

一度にDNNを順伝播させる状態の最大数。

def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)

与えられたスクランブル状態に対する解を見つけるためにビームサーチを実行します。

引数:

  • env Cube3 - スクランブルされた状態を表すルービックキューブ環境。
  • model torch.nn.Module - 各状態に対する次の手の確率分布を予測するために使用されるDNN。
  • beam_width int - 探索の各ステップで保持する候補の最大数。
  • ergonomic_bias dict or None - 利用可能な場合、手の人間工学的なバイアスを指定する辞書。
  • extra_depths int - 最初の解が見つかった深さを超えて探索する追加の深さの数。
  • max_depth int - 探索する最大の深さ。神の数(HTMにおけるルービックキューブでは20)以上である必要があります。

戻り値:

  • dict | None: 少なくとも1つの解が見つかった場合、以下のキーを持つ辞書:
  1. "solutions": 探索中に見つかった最適解またはそれに近い解のリスト。
  2. "num_nodes": 探索中に展開されたノードの総数。
  3. "time": 探索にかかった時間(秒)。

それ以外の場合は None

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

人間工学的バイアスが提供された場合に初期化します。

引数:

  • ergonomic_bias dict or None - 利用可能な場合、手の人間工学的なバイアスを指定する辞書。
  • env Cube3 - スクランブルされた状態を表すルービックキューブ環境。

戻り値:

  • ergonomic_bias numpy.ndarray - 利用可能な場合、手の人間工学的バイアス。
  • env Cube3 - スクランブルされた状態を表すルービックキューブ環境。

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

各状態に対する次の手の確率分布を予測します。

引数:

  • model torch.nn.Module - 各状態に対する次の手の確率分布を予測するために使用されるDNN。
  • batch_x numpy.ndarray - 状態のバッチ。
  • ergonomic_bias dict or None - 利用可能な場合、手の人間工学的なバイアスを指定する辞書。
  • env Cube3 - スクランブルされた状態を表すルービックキューブ環境。

戻り値:

  • batch_logprob numpy.ndarray - 各状態の各手の対数確率。
注記

何らかの理由で、自動混合精度を使用した推論は、 単純な半精度(model.half()を使用)よりもわずかに高速です。

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

予測された次の手の確率で候補パスを展開します。

引数:

  • candidates dict - 候補パス、累積確率、および状態を含む辞書。
  • batch_logprob numpy.ndarray - 各状態の各手の対数確率。
  • env Cube3 - スクランブルされた状態を表すルービックキューブ環境。
  • depth int - 探索の現在の深さ。
  • beam_width int - 探索の各ステップで保持する候補の最大数。

戻り値:

  • candidates dict - 更新された、候補パス、累積確率、および状態を含む辞書。

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

前の手に基づいて枝刈りする候補のインデックスを取得します。

引数:

  • candidates_paths numpy.ndarray - 候補状態のパス。
  • allow_wide bool - ワイドムーブを許可するかどうか。
  • depth int - 探索の現在の深さ。

戻り値:

  • prune_idx numpy.ndarray - 枝刈りする候補のインデックス。
注記

numba.jitを使用すると、実際にはこの関数が遅くなります。

_update_states

def _update_states(candidate_states, candidate_paths, env)

展開されたパスに基づいて状態を更新します。

引数:

  • candidate_states numpy.ndarray - 候補状態の状態。
  • candidate_paths numpy.ndarray - 候補状態のパス。
  • env Cube3 - スクランブルされた状態を表すルービックキューブ環境。

戻り値:

  • candidate_states numpy.ndarray - 更新された候補状態の状態。

_map_state

def _map_state(candidate_states, target_ix, source_ix)

バッチレベルでステッカーの置換を実行します。

引数:

  • candidate_states numpy.ndarray - 候補状態の状態。
  • target_ix numpy.ndarray - ステッカー置換のターゲットインデックス。
  • source_ix numpy.ndarray - ステッカー置換のソースインデックス。

戻り値:

  • candidate_states numpy.ndarray - 更新された候補状態の状態。
注記

numba.jitを使用すると、実際にはこの関数が遅くなります。