メインコンテンツまでスキップ

search

このモジュールは、ビーム探索を実行し、与えられた状態に対する解を見つける関数を提供します。

関数: beam_search: ルービックキューブ環境でビーム探索を実行し、解を見つけます。

MAX_BATCH_SIZE

一度にDNNを通過する状態の最大数。

def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)

与えられたスクランブル状態に対する解を見つけるためにビーム探索を実行します。

引数:

  • env Cube3 - スクランブル状態を表すルービックキューブ環境。
  • model torch.nn.Module - すべての状態に対する次の手の確率分布を予測するために使用されるDNN。
  • beam_width int - 探索の各ステップで保持する候補の最大数。
  • ergonomic_bias dict or None - 利用可能な場合、手の人間工学的バイアスを指定する辞書。
  • extra_depths int - 最初の解の深さを超えて探索する追加の深さの数。
  • max_depth int - 探索する最大の深さ。God's Number(ルービックキューブの場合はHTMで20)以上でなければなりません。

戻り値:

  • dict | None: 少なくとも1つの解がある場合、以下のキーを持つ辞書:
  1. "solutions": 探索中に見つかった最適または準最適な解のリスト。
  2. "num_nodes": 探索中に展開されたノードの総数。
  3. "time": 探索の完了に要した時間(秒)。

それ以外の場合はNone

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

提供された場合、人間工学的バイアスを初期化します。

引数:

  • ergonomic_bias dict or None - 利用可能な場合、手の人間工学的バイアスを指定する辞書。
  • env Cube3 - スクランブル状態を表すルービックキューブ環境。

戻り値:

  • ergonomic_bias numpy.ndarray - 利用可能な場合、手の人間工学的バイアス。
  • env Cube3 - スクランブル状態を表すルービックキューブ環境。

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

すべての状態に対する次の手の確率分布を予測します。

引数:

  • model torch.nn.Module - すべての状態に対する次の手の確率分布を予測するために使用されるDNN。
  • batch_x numpy.ndarray - 状態のバッチ。
  • ergonomic_bias dict or None - 利用可能な場合、手の人間工学的バイアスを指定する辞書。
  • env Cube3 - スクランブル状態を表すルービックキューブ環境。

戻り値:

  • batch_logprob numpy.ndarray - 各状態の各手の対数確率。
注記

何らかの理由で、Automatic Mixed Previsionを使用した推論は、単純な半精度(model.half()を使用)よりもわずかに高速です。

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

次の手の予測確率で候補パスを拡張します。

引数:

  • candidates dict - 候補パス、累積確率、状態を含む辞書。
  • batch_logprob numpy.ndarray - 各状態の各手の対数確率。
  • env Cube3 - スクランブル状態を表すルービックキューブ環境。
  • depth int - 探索の現在の深さ。
  • beam_width int - 探索の各ステップで保持する候補の最大数。

戻り値:

  • candidates dict - 候補パス、累積確率、状態を含む更新された辞書。

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

前の手に基づいて、剪定する候補のインデックスを取得します。

引数:

  • candidates_paths numpy.ndarray - 候補状態のパス。
  • allow_wide bool - ワイド手を許可するかどうか。
  • depth int - 探索の現在の深さ。

戻り値:

  • prune_idx numpy.ndarray - 剪定する候補のインデックス。
注記

numba.jitを使用すると、実際にはこの関数の速度が低下します

_update_states

def _update_states(candidate_states, candidate_paths, env)

拡張されたパスに基づいて状態を更新します。

引数:

  • candidate_states numpy.ndarray - 候補状態の状態。
  • candidate_paths numpy.ndarray - 候補状態のパス。
  • env Cube3 - スクランブル状態を表すルービックキューブ環境。

戻り値:

  • candidate_states numpy.ndarray - 更新された候補状態の状態。

_map_state

def _map_state(candidate_states, target_ix, source_ix)

バッチレベルでステッカーの置換を実行します。

引数:

  • candidate_states numpy.ndarray - 候補状態の状態。
  • target_ix numpy.ndarray - ステッカー置換のターゲットインデックス。
  • source_ix numpy.ndarray - ステッカー置換のソースインデックス。

戻り値:

  • candidate_states numpy.ndarray - 更新された候補状態の状態。
注記

numba.jitを使用すると、実際にはこの関数の速度が低下します