search
このモジュールは、ビームサーチを実行し、与えられた状態に対する解を見つけるための関数を提供します。
関数:
beam_search: ルービックキューブ環境でビームサーチを実行し、解を見つけます。
MAX_BATCH_SIZE
一度にDNNを順伝播させる状態の最大数。
beam_search
def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)
与えられたスクランブル状態に対する解を見つけるためにビームサーチを実行します。
引数:
envCube3 - スクランブルされた状態を表すルービックキューブ環境。modeltorch.nn.Module - 各状態に対する次の手の確率分布を予測するために使用されるDNN。beam_widthint - 探索の各ステップで保持する候補の最大数。ergonomic_biasdict or None - 利用可能な場合、手の人間工学的なバイアスを指定する辞書。extra_depthsint - 最初の解が見つかった深さを超えて探索する追加の深さの数。max_depthint - 探索する最大の深さ。神の数(HTMにおけるルービックキューブでは20)以上である必要があります。
戻り値:
dict|None: 少なくとも1つの解が見つかった場合、以下のキーを持つ辞書:
"solutions": 探索中に見つかった最適解またはそれに近い解のリスト。"num_nodes": 探索中に展開されたノードの総数。"time": 探索にかかった時間(秒)。
それ以外の場合は None。
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
人間工学的バイアスが提供された場合に初期化します。
引数:
ergonomic_biasdict or None - 利用可能な場合、手の人間工学的なバイアスを指定する辞書。envCube3 - スクランブルされた状態を表すルービックキューブ環境。
戻り値:
ergonomic_biasnumpy.ndarray - 利用可能な場合、手の人間工学的バイアス。envCube3 - スクランブルされた状態を表すルービックキューブ環境。
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
各状態に対する次の手の確率分布を予測します。
引数:
modeltorch.nn.Module - 各状態に対する次の手の確率分布を予測するために使用されるDNN。batch_xnumpy.ndarray - 状態のバッチ。ergonomic_biasdict or None - 利用可能な場合、手の人間工学的なバイアスを指定する辞書。envCube3 - スクランブルされた状態を表すルービックキューブ環境。
戻り値:
batch_logprobnumpy.ndarray - 各状態の各手の対数確率。
注記
何らかの理由で、自動混合精度を使用した推論は、
単純な半精度(model.half()を使用)よりもわずかに高速です。
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
予測された次の手の確率で候補パスを展開します。
引数:
candidatesdict - 候補パス、累積確率、および状態を含む辞書。batch_logprobnumpy.ndarray - 各状態の各手の対数確率。envCube3 - スクランブルされた状態を表すルービックキューブ環境。depthint - 探索の現在の深さ。beam_widthint - 探索の各ステップで保持する候補の最大数。
戻り値:
candidatesdict - 更新された、候補パス、累積確率、および状態を含む辞書。
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
前の手に基づいて枝刈りする候補のインデックスを取得します。
引数:
candidates_pathsnumpy.ndarray - 候補状態のパス。allow_widebool - ワイドムーブを許可するかどうか。depthint - 探索の現在の深さ。
戻り値:
prune_idxnumpy.ndarray - 枝刈りする候補のインデックス。
注記
numba.jitを使用すると、実際にはこの関数が遅くなります。
_update_states
def _update_states(candidate_states, candidate_paths, env)
展開されたパスに基づいて状態を更新します。
引数:
candidate_statesnumpy.ndarray - 候補状態の状態。candidate_pathsnumpy.ndarray - 候補状態のパス。envCube3 - スクランブルされた状態を表すルービックキューブ環境。
戻り値:
candidate_statesnumpy.ndarray - 更新された候補状態の状態。
_map_state
def _map_state(candidate_states, target_ix, source_ix)
バッチレベルでステッカーの置換を実行します。
引数:
candidate_statesnumpy.ndarray - 候補状態の状態。target_ixnumpy.ndarray - ステッカー置換のターゲットインデックス。source_ixnumpy.ndarray - ステッカー置換のソースインデックス。
戻り値:
candidate_statesnumpy.ndarray - 更新された候補状態の状態。
注記
numba.jitを使用すると、実際にはこの関数が遅くなります。