search
このモジュールは、ビーム探索を実行し、与えられた状態に対する解を見つける関数を提供します。
関数:
beam_search
: ルービックキューブ環境でビーム探索を実行し、解を見つけます。
MAX_BATCH_SIZE
一度にDNNを通過する状態の最大数。
beam_search
def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)
与えられたスクランブル状態に対する解を見つけ るためにビーム探索を実行します。
引数:
env
Cube3 - スクランブル状態を表すルービックキューブ環境。model
torch.nn.Module - すべての状態に対する次の手の確率分布を予測するために使用されるDNN。beam_width
int - 探索の各ステップで保持する候補の最大数。ergonomic_bias
dict or None - 利用可能な場合、手の人間工学的バイアスを指定する辞書。extra_depths
int - 最初の解の深さを超えて探索する追加の深さの数。max_depth
int - 探索する最大の深さ。God's Number(ルービックキューブの場合はHTMで20)以上でなければなりません。
戻り値:
dict
|None
: 少なくとも1つの解がある場合、以下のキーを持つ辞書:
"solutions"
: 探索中に見つかった最適または準最適な解のリスト。"num_nodes"
: 探索中に展開されたノードの総数。"time"
: 探索の完了に要した時間(秒)。
それ以外の場合はNone
。
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
提供された場合、人間工学的バイアスを初期化します。
引数:
ergonomic_bias
dict or None - 利用可能な場合、手の人間工学的バイアスを指定する辞書。env
Cube3 - スクランブル状態を表すルービックキューブ環境。
戻り値:
ergonomic_bias
numpy.ndarray - 利用可能な場合、手の人間工学的バイアス。env
Cube3 - スクランブル状態を表すルービックキューブ環境。
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
すべての状態に対する次の手の確率分布を予測します。
引数:
model
torch.nn.Module - すべての状態に対する次の手の確率分布を予測するために使用されるDNN。batch_x
numpy.ndarray - 状態のバッチ。ergonomic_bias
dict or None - 利用可能な場合、手の人間工学的バイアスを指定する辞書。env
Cube3 - スクランブル状態を表すルービックキューブ環境。
戻り値:
batch_logprob
numpy.ndarray - 各状態の各手の対数確率。
注記
何らかの理由で、Automatic Mixed Previsionを使用した推論は、単純な半精度(model.half()
を使用)よりもわずかに高速です。
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
次の手の予測確率で候補パスを拡張します。
引数:
candidates
dict - 候補パス、累積確率、状態を含む辞書。batch_logprob
numpy.ndarray - 各状態の各手の対数確率。env
Cube3 - スクランブル状態を表すルービックキューブ環境。depth
int - 探索の現在の深さ。beam_width
int - 探索の各ステップで保持する候補の最大数。
戻り値:
candidates
dict - 候補パス、累積確率、状態を含む更新された辞書。
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
前の手に基づいて、剪定する候補のインデックスを取得します。
引数:
candidates_paths
numpy.ndarray - 候補状態のパス。allow_wide
bool - ワイド手を許可するかどうか。depth
int - 探索の現在の深さ。
戻り値:
prune_idx
numpy.ndarray - 剪定する候補のインデックス。
注記
numba.jit
を使用すると、実際にはこの関数の速度が低下します。
_update_states
def _update_states(candidate_states, candidate_paths, env)
拡張されたパスに基づいて状 態を更新します。
引数:
candidate_states
numpy.ndarray - 候補状態の状態。candidate_paths
numpy.ndarray - 候補状態のパス。env
Cube3 - スクランブル状態を表すルービックキューブ環境。
戻り値:
candidate_states
numpy.ndarray - 更新された候補状態の状態。
_map_state
def _map_state(candidate_states, target_ix, source_ix)
バッチレベルでステッカーの置換を実行します。
引数:
candidate_states
numpy.ndarray - 候補状態の状態。target_ix
numpy.ndarray - ステッカー置換のターゲットインデックス。source_ix
numpy.ndarray - ステッカー置換のソースインデックス。
戻り値:
candidate_states
numpy.ndarray - 更新された候補状態の状態。
注記
numba.jit
を使用すると、実際にはこの関数の速度が低下します。