search
This module provides a function to perform beam search and find solutions for a given state.
Function:
beam_search: Perform beam search to find solutions in a Rubik's Cube environment.
MAX_BATCH_SIZE
The maximum number of states passed through the DNN in a single forward pass.
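A cap like this is typically enforced by slicing a large batch into chunks and concatenating the outputs. The following is a minimal sketch of that pattern, not the module's own code: `predict_in_chunks`, `fake_forward`, and the value 4 are hypothetical and chosen only for illustration.

```python
import numpy as np

MAX_BATCH_SIZE = 4  # hypothetical small value, for illustration only


def predict_in_chunks(predict_fn, batch_x, max_batch_size=MAX_BATCH_SIZE):
    """Apply predict_fn to batch_x in slices of at most max_batch_size rows."""
    outputs = [
        predict_fn(batch_x[start:start + max_batch_size])
        for start in range(0, len(batch_x), max_batch_size)
    ]
    return np.concatenate(outputs, axis=0)


# Stand-in for a model forward pass: records chunk sizes and returns row sums.
seen_sizes = []


def fake_forward(chunk):
    seen_sizes.append(len(chunk))
    return chunk.sum(axis=1, keepdims=True)


batch = np.arange(30).reshape(10, 3)
result = predict_in_chunks(fake_forward, batch)
```

Ten states with a cap of four are processed as chunks of 4, 4, and 2, and the concatenated result has one row per input state.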
beam_search
def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)
Performs beam search to find solutions for a given scrambled state.
Arguments:
env (Cube3) - The Rubik's Cube environment representing the scrambled state.
model (torch.nn.Module) - DNN used to predict the probability distribution of next moves for every state.
beam_width (int) - The maximum number of candidates to keep at each step of the search.
ergonomic_bias (dict or None) - A dictionary specifying ergonomic bias for moves, if available.
extra_depths (int) - The number of additional depths to search beyond the first solution's depth.
max_depth (int) - The maximum depth to search. It should be equal to or greater than God's Number (20 for Rubik's Cube in HTM).
Returns:
dict or None: If at least one solution is found, a dictionary with the following keys:
"solutions": A list of optimal or near-optimal solutions found during the search.
"num_nodes": The total number of nodes expanded during the search.
"time": The time taken (in seconds) to complete the search.
Otherwise, None.
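The roles of beam_width, extra_depths, max_depth, and the returned dictionary can be illustrated on a toy puzzle. This is a sketch under stated assumptions, not the module's implementation: the puzzle (sorting a 4-element tuple with adjacent swaps), `toy_model`, and `toy_beam_search` are all hypothetical stand-ins for the Cube3 environment, the DNN, and beam_search.

```python
import math
import time

MOVES = [(0, 1), (1, 2), (2, 3)]  # adjacent swaps on a 4-element tuple
GOAL = (0, 1, 2, 3)


def apply_move(state, move):
    i, j = move
    s = list(state)
    s[i], s[j] = s[j], s[i]
    return tuple(s)


def toy_model(state):
    """Hypothetical stand-in for the DNN: log-probabilities over MOVES."""
    # Favor swaps that fix an out-of-order adjacent pair.
    scores = [2.0 if state[i] > state[j] else 0.0 for i, j in MOVES]
    log_z = math.log(sum(math.exp(s) for s in scores))
    return [s - log_z for s in scores]


def toy_beam_search(start, beam_width, extra_depths=0, max_depth=10):
    t0 = time.monotonic()
    candidates = [(0.0, [], start)]  # (cumulative log-prob, path, state)
    solutions, num_nodes, first_depth = [], 0, None
    for depth in range(1, max_depth + 1):
        expanded = []
        for logp, path, state in candidates:
            num_nodes += 1  # one node expanded per surviving candidate
            for move, move_logp in zip(MOVES, toy_model(state)):
                expanded.append(
                    (logp + move_logp, path + [move], apply_move(state, move))
                )
        # Keep only the beam_width most probable paths.
        expanded.sort(key=lambda c: c[0], reverse=True)
        candidates = expanded[:beam_width]
        solved = [path for _, path, state in candidates if state == GOAL]
        if solved:
            solutions.extend(solved)
            if first_depth is None:
                first_depth = depth
        # Stop extra_depths levels after the first solution was found.
        if first_depth is not None and depth >= first_depth + extra_depths:
            break
    if not solutions:
        return None
    return {
        "solutions": solutions,
        "num_nodes": num_nodes,
        "time": time.monotonic() - t0,
    }


result = toy_beam_search((1, 0, 3, 2), beam_width=3)
```

With beam_width=3 the toy search finds both two-move solutions at depth 2. Capping the search at max_depth=1 yields None, mirroring the documented return contract.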
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
Initialize ergonomic bias if provided.
Arguments:
ergonomic_bias (dict or None) - A dictionary specifying ergonomic bias for moves, if available.
env (Cube3) - The Rubik's Cube environment representing the scrambled state.
Returns:
ergonomic_bias (numpy.ndarray) - The ergonomic bias for moves, if available.
env (Cube3) - The Rubik's Cube environment representing the scrambled state.
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
Predict the probability distribution of next moves for every state.
Arguments:
model (torch.nn.Module) - DNN used to predict the probability distribution of next moves for every state.
batch_x (numpy.ndarray) - Batch of states.
ergonomic_bias (dict or None) - A dictionary specifying ergonomic bias for moves, if available.
env (Cube3) - The Rubik's Cube environment representing the scrambled state.
Returns:
batch_logprob (numpy.ndarray) - The log probability of each move for each state.
Inference with Automatic Mixed Precision is slightly faster than
plain half precision (with model.half()), for reasons that are unclear.
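To make the return value concrete: a per-state log-probability array of this shape can be obtained from raw model outputs with a log-softmax. The sketch below assumes that the model emits one logit per move, which the source does not state. It uses NumPy rather than torch, and `log_softmax` is a hypothetical helper, not part of this module.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax along the last axis."""
    # Subtracting the row maximum avoids overflow in exp().
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


# Two states, three moves each (made-up logits).
logits = np.array([[2.0, 0.0, -1.0], [1000.0, 1000.0, 1000.0]])
batch_logprob = log_softmax(logits)
```

Each row of batch_logprob exponentiates to a distribution that sums to 1, and the large logits in the second row do not overflow.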
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
Expand candidate paths with the predicted probabilities of next moves.
Arguments:
candidates (dict) - A dictionary containing candidate paths, cumulative probabilities, and states.
batch_logprob (numpy.ndarray) - The log probability of each move for each state.
env (Cube3) - The Rubik's Cube environment representing the scrambled state.
depth (int) - The current depth of the search.
beam_width (int) - The maximum number of candidates to keep at each step of the search.
Returns:
candidates (dict) - The updated dictionary containing candidate paths, cumulative probabilities, and states.
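The core of such an expansion step is adding each move's log-probability to its parent's cumulative score and keeping the beam_width best pairs. A minimal sketch of that arithmetic follows. `expand_and_prune` and its flat-array inputs are assumptions made for illustration, and the actual function operates on the candidates dictionary described above.

```python
import numpy as np


def expand_and_prune(cum_logprob, batch_logprob, beam_width):
    """Return (parent_idx, move_idx, scores) of the top beam_width expansions."""
    # Every (candidate, move) pair scores parent log-prob + move log-prob.
    scores = cum_logprob[:, None] + batch_logprob  # (n_candidates, n_moves)
    flat = scores.ravel()
    k = min(beam_width, flat.size)
    # argpartition selects the k largest in linear time; sort only those k
    # so the result is ordered best-first.
    top = np.argpartition(-flat, k - 1)[:k]
    top = top[np.argsort(-flat[top])]
    parent_idx, move_idx = np.unravel_index(top, scores.shape)
    return parent_idx, move_idx, flat[top]


cum_logprob = np.log(np.array([0.6, 0.4]))
batch_logprob = np.log(np.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]))
parent_idx, move_idx, new_scores = expand_and_prune(
    cum_logprob, batch_logprob, beam_width=2
)
```

Here the two surviving paths are candidate 0 followed by move 0 (probability 0.42) and candidate 1 followed by move 2 (probability 0.32).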
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
Get the indices of candidates to prune based on previous moves.
Arguments:
candidates_paths (numpy.ndarray) - The paths of candidate states.
allow_wide (bool) - Whether to allow wide moves.
depth (int) - The current depth of the search.
Returns:
prune_idx (numpy.ndarray) - The indices of candidates to prune.
Using numba.jit actually slows down this function.
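The exact pruning rules are not spelled out here. As one illustration of pruning on previous moves, the sketch below drops any path whose last move undoes the one before it. Both the rule and the move encoding (indices 2k and 2k+1 are a move and its inverse) are assumptions, and `inverse_move_prune_idx` is a hypothetical name.

```python
import numpy as np


def inverse_move_prune_idx(paths):
    """Indices of paths whose last move is the inverse of the previous move."""
    if paths.shape[1] < 2:
        # Paths of length 0 or 1 have no previous move to cancel.
        return np.empty(0, dtype=np.int64)
    # Assumed encoding: flipping the lowest bit maps a move to its inverse.
    return np.flatnonzero(paths[:, -1] == (paths[:, -2] ^ 1))


paths = np.array([[0, 1], [0, 2], [4, 4], [3, 2]])
prune_idx = inverse_move_prune_idx(paths)
kept = np.delete(paths, prune_idx, axis=0)
```

Rows 0 and 3 end in a move followed by its inverse and are removed, and the remaining two paths survive.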
_update_states
def _update_states(candidate_states, candidate_paths, env)
Update states based on the expanded paths.
Arguments:
candidate_states (numpy.ndarray) - The states of the candidates.
candidate_paths (numpy.ndarray) - The paths of the candidates.
env (Cube3) - The Rubik's Cube environment representing the scrambled state.
Returns:
candidate_states (numpy.ndarray) - The updated states of the candidates.
_map_state
def _map_state(candidate_states, target_ix, source_ix)
Perform sticker replacement on the batch level.
Arguments:
candidate_states (numpy.ndarray) - The states of the candidates.
target_ix (numpy.ndarray) - The target indices for sticker replacement.
source_ix (numpy.ndarray) - The source indices for sticker replacement.
Returns:
candidate_states (numpy.ndarray) - The updated states of the candidates.
Using numba.jit actually slows down this function.
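Batch-level sticker replacement can be expressed with NumPy fancy indexing, which moves the same sticker positions in every row at once. The following is a minimal sketch, assuming each row of the state array is one flattened state. `map_state` and the 4-sticker toy layout are hypothetical and do not reflect the Cube3 encoding.

```python
import numpy as np


def map_state(states, target_ix, source_ix):
    """Copy stickers from source_ix to target_ix in every row at once."""
    out = states.copy()
    # Read from the untouched input so no sticker is overwritten before use.
    out[:, target_ix] = states[:, source_ix]
    return out


# Hypothetical move on a 4-sticker state: cycle positions 0 -> 1 -> 2 -> 0.
target_ix = np.array([1, 2, 0])
source_ix = np.array([0, 1, 2])
states = np.array([[10, 11, 12, 13], [20, 21, 22, 23]])
moved = map_state(states, target_ix, source_ix)
```

Both rows are permuted in a single vectorized assignment. Because the cycle has length 3, applying the same move three times restores the original states.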