search
This module provides a function to perform beam search and find solutions for a given state.
Function:
- beam_search: Perform beam search to find solutions in a Rubik's Cube environment.
MAX_BATCH_SIZE
The maximum number of states passed forward through the DNN at a time.
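A cap like MAX_BATCH_SIZE is typically applied by splitting the candidate states into consecutive chunks, each small enough for one forward pass. This is a minimal sketch. The helper name `iter_batches` and the value 4 are hypothetical and chosen only for illustration; the module defines its own constant.

```python
# Hypothetical value for illustration; the module defines its own MAX_BATCH_SIZE.
MAX_BATCH_SIZE = 4


def iter_batches(states, max_batch_size=MAX_BATCH_SIZE):
    """Yield consecutive slices of `states` with at most `max_batch_size` items each."""
    for start in range(0, len(states), max_batch_size):
        yield states[start:start + max_batch_size]


# Ten hypothetical 54-sticker states, forwarded in chunks of at most 4.
states = [[0] * 54 for _ in range(10)]
sizes = [len(batch) for batch in iter_batches(states)]
```

Here `sizes` is `[4, 4, 2]`. In the real search, each chunk would be one model call, and the per-chunk outputs would be concatenated.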
beam_search
def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)
Performs beam search to find solutions for a given scrambled state.
Arguments:
- env (Cube3): The Rubik's Cube environment representing the scrambled state.
- model (torch.nn.Module): DNN used to predict the probability distribution of next moves for every state.
- beam_width (int): The maximum number of candidates to keep at each step of the search.
- ergonomic_bias (dict or None): A dictionary specifying ergonomic bias for moves, if available.
- extra_depths (int): The number of additional depths to search beyond the first solution's depth.
- max_depth (int): The maximum depth to search; it should be equal to or greater than God's Number (20 for the Rubik's Cube in HTM).
Returns:
dict | None: If at least one solution is found, a dictionary with the following keys:
- "solutions": A list of optimal or near-optimal solutions found during the search.
- "num_nodes": The total number of nodes expanded during the search.
- "time": The time taken (in seconds) to complete the search.
Otherwise, None.
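The overall search loop can be sketched generically in plain Python. This is not the module's implementation. The callback-style signature (`moves`, `apply_move`, `logprob_fn`, `is_goal`) and the toy integer puzzle are hypothetical stand-ins for Cube3 and the DNN. The beam_width, extra_depths, and max_depth semantics follow the description above.

```python
import math
import time


def beam_search(start, moves, apply_move, logprob_fn, is_goal,
                beam_width, extra_depths=0, max_depth=100):
    """Generic beam search sketch; the callback signature is hypothetical.

    Keeps the `beam_width` paths with the highest cumulative log-probability
    at each depth, continues `extra_depths` levels past the first solution,
    and returns None if nothing is found within `max_depth`.
    """
    t0 = time.perf_counter()
    beam = [(0.0, [], start)]  # (cumulative log-prob, path, state)
    solutions, num_nodes, stop_depth = [], 0, None
    for depth in range(1, max_depth + 1):
        expanded = []
        for score, path, state in beam:
            logprobs = logprob_fn(state)  # dict: move -> log-probability
            for move in moves:
                expanded.append((score + logprobs[move], path + [move],
                                 apply_move(state, move)))
        num_nodes += len(expanded)
        expanded.sort(key=lambda c: c[0], reverse=True)
        beam = expanded[:beam_width]  # keep only the best candidates
        solutions += [path for _, path, state in beam if is_goal(state)]
        if solutions and stop_depth is None:
            stop_depth = depth + extra_depths
        if stop_depth is not None and depth >= stop_depth:
            break
    if not solutions:
        return None
    return {"solutions": solutions, "num_nodes": num_nodes,
            "time": time.perf_counter() - t0}


# Toy puzzle (hypothetical): reach TARGET from 1 using "dbl" (x2) and "inc" (+1).
TARGET = 10


def apply_toy(state, move):
    return state * 2 if move == "dbl" else state + 1


def toy_logprob(state):
    # Stand-in for the DNN: prefers doubling while it does not overshoot.
    prefer_dbl = 2 * state <= TARGET
    return {"dbl": math.log(0.9 if prefer_dbl else 0.1),
            "inc": math.log(0.1 if prefer_dbl else 0.9)}


def is_target(state):
    return state == TARGET


result = beam_search(1, ["dbl", "inc"], apply_toy, toy_logprob, is_target,
                     beam_width=5)
```

With a beam width of 5, this toy run returns the single path dbl, dbl, inc, dbl (1 to 2 to 4 to 5 to 10) after expanding 24 nodes. Keeping only the best candidates bounds the work per depth to beam_width times the number of moves. The trade-off is that a solution outside the beam is never found, so optimality is not guaranteed.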
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
Initialize ergonomic bias if provided.
Arguments:
- ergonomic_bias (dict or None): A dictionary specifying ergonomic bias for moves, if available.
- env (Cube3): The Rubik's Cube environment representing the scrambled state.
Returns:
- ergonomic_bias (numpy.ndarray): The ergonomic bias for moves, if available.
- env (Cube3): The Rubik's Cube environment representing the scrambled state.
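One plausible way to turn the ergonomic-bias dictionary into the array that _reflect_setup returns is to look up a weight for each move in the environment's move order. The helper name, the `moves` list, and the default weight of 1.0 for unlisted moves are assumptions for illustration. The document does not specify the dictionary format.

```python
import numpy as np


def ergonomic_bias_to_array(ergonomic_bias, moves):
    """Hypothetical helper: map {move: weight} onto an array aligned with `moves`.

    Assumption: moves missing from the dictionary get a neutral weight of 1.0.
    """
    if ergonomic_bias is None:
        return None
    return np.array([ergonomic_bias.get(m, 1.0) for m in moves], dtype=np.float64)


moves = ["U", "U'", "U2", "R", "R'", "R2"]  # hypothetical subset of the move set
bias = ergonomic_bias_to_array({"R": 0.9, "U2": 0.5}, moves)
```

An array indexed in the same order as the model's outputs can be combined with a whole batch of predictions by broadcasting, instead of per-move dictionary lookups.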
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
Predict the probability distribution of next moves for every state.
Arguments:
- model (torch.nn.Module): DNN used to predict the probability distribution of next moves for every state.
- batch_x (numpy.ndarray): Batch of states.
- ergonomic_bias (dict or None): A dictionary specifying ergonomic bias for moves, if available.
- env (Cube3): The Rubik's Cube environment representing the scrambled state.
Returns:
- batch_logprob (numpy.ndarray): The log probability of each move for each state.
Note: Inference with Automatic Mixed Precision (AMP) is slightly faster than plain half precision (model.half()), although the reason is unclear.
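The core of predict can be sketched with NumPy standing in for the model's output logits: a row-wise log-softmax, optionally reweighted by the ergonomic bias. The function name is hypothetical. Treating the bias as a per-move multiplicative weight on probabilities (added in log space, then renormalized) is an assumption, not the documented behavior. The real function runs the torch model under torch.inference_mode() with AMP, which this sketch omits.

```python
import numpy as np


def logprob_with_bias(logits, ergonomic_bias=None):
    """Row-wise log-softmax, optionally reweighted by a per-move bias (sketch).

    Assumption (hypothetical): `ergonomic_bias` is a positive array of shape
    (num_moves,) multiplied into the probabilities, i.e. added in log space.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    logprob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    if ergonomic_bias is not None:
        logprob = logprob + np.log(ergonomic_bias)
        # Renormalize so each row is again a valid log-distribution.
        logprob = logprob - np.log(np.exp(logprob).sum(axis=1, keepdims=True))
    return logprob


logits = np.array([[2.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
plain = logprob_with_bias(logits)
biased = logprob_with_bias(logits, np.array([1.0, 1.0, 0.5]))
```

For the uniform second row, a bias of [1, 1, 0.5] shifts the move probabilities from one third each to [0.4, 0.4, 0.2].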
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
Expand candidate paths with the predicted probabilities of next moves.
Arguments:
- candidates (dict): A dictionary containing candidate paths, cumulative probabilities, and states.
- batch_logprob (numpy.ndarray): The log probability of each move for each state.
- env (Cube3): The Rubik's Cube environment representing the scrambled state.
- depth (int): The current depth of the search.
- beam_width (int): The maximum number of candidates to keep at each step of the search.
Returns:
- candidates (dict): The updated dictionary containing candidate paths, cumulative probabilities, and states.
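A single expansion step of update_candidates can be sketched as follows. Add each candidate's cumulative log-probability to the log-probability of every possible next move, flatten the result, and keep the beam_width best. The helper name, argument layout, and return values below are hypothetical. The documented function keeps these fields in a dictionary and also updates the states.

```python
import numpy as np


def expand_and_prune(cum_logprob, paths, batch_logprob, beam_width):
    """One beam-search expansion step (hypothetical helper, sketch only).

    cum_logprob: (N,) cumulative log-probabilities of current candidates
    paths: (N, depth) move indices taken so far
    batch_logprob: (N, M) log-probability of each of M moves per candidate
    """
    n, m = batch_logprob.shape
    scores = (cum_logprob[:, None] + batch_logprob).ravel()  # shape (N * M,)
    k = min(beam_width, scores.size)
    top = np.argpartition(-scores, k - 1)[:k]  # k best, in arbitrary order
    top = top[np.argsort(-scores[top])]        # sort survivors, best first
    parent, move = np.divmod(top, m)           # source row and move index
    new_paths = np.concatenate([paths[parent], move[:, None]], axis=1)
    return scores[top], new_paths, parent


cum = np.array([0.0, -1.0])
paths = np.array([[0], [1]])
blp = np.log(np.array([[0.7, 0.2, 0.1], [0.5, 0.3, 0.2]]))
scores, new_paths, parent = expand_and_prune(cum, paths, blp, beam_width=3)
```

np.argpartition selects the top k in linear time, so only the k survivors need a full sort. The `parent` indices tell the caller which states to copy forward before applying each new move.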
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
Get the indices of candidates to prune based on previous moves.
Arguments:
- candidates_paths (numpy.ndarray): The move paths of the candidates.
- allow_wide (bool): Whether to allow wide moves.
- depth (int): The current depth of the search.
Returns:
- prune_idx (numpy.ndarray): The indices of candidates to prune.
Note: Using numba.jit actually slows down this function.
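The kind of rule _get_prune_idx applies can be sketched for face turns only. The move encoding is hypothetical: move index // 3 gives the face, with faces ordered U, D, L, R, F, B so that face ^ 1 is the opposite face. Wide moves (allow_wide) are ignored. The sketch flags two redundancies:

- consecutive turns of the same face, which collapse into one move;
- commuting opposite-face turns taken in the non-canonical order.

```python
import numpy as np


def redundant_move_idx(paths):
    """Return row indices whose last two moves are redundant (sketch).

    Hypothetical encoding: move index // 3 gives the face, and faces are
    ordered U, D, L, R, F, B so that face ^ 1 is the opposite face.
    """
    if paths.shape[1] < 2:
        return np.empty(0, dtype=np.int64)
    prev_face = paths[:, -2] // 3
    last_face = paths[:, -1] // 3
    same_face = last_face == prev_face  # e.g. R then R' is one move in disguise
    # Opposite faces commute, so keep only one order (U D is kept, D U is pruned).
    opposite_out_of_order = (last_face == (prev_face ^ 1)) & (last_face < prev_face)
    return np.flatnonzero(same_face | opposite_out_of_order)


# Rows: R R', U D, D U, L U (with U=0..2, D=3..5, L=6..8, R=9..11).
paths = np.array([[9, 10], [0, 3], [3, 0], [6, 0]])
prune_idx = redundant_move_idx(paths)
```

Rows 0 and 2 are flagged. Pruning these paths before the next forward pass keeps the beam from spending slots on move sequences that reach the same state as a shorter or already-counted sequence.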
_update_states
def _update_states(candidate_states, candidate_paths, env)
Update states based on the expanded paths.
Arguments:
- candidate_states (numpy.ndarray): The states of the candidates.
- candidate_paths (numpy.ndarray): The move paths of the candidates.
- env (Cube3): The Rubik's Cube environment representing the scrambled state.
Returns:
- candidate_states (numpy.ndarray): The updated candidate states.
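_update_states can be sketched by grouping candidates by their last move. Each move's sticker permutation is then applied to many rows with one vectorized assignment. The (target_ix, source_ix) representation of a move and the toy 4-sticker state are assumptions for illustration.

```python
import numpy as np


def update_states(states, paths, move_defs):
    """Apply each candidate's last move to its state (sketch).

    move_defs: hypothetical list where move_defs[m] = (target_ix, source_ix),
    meaning new_state[target_ix] = old_state[source_ix] for move m.
    """
    last_moves = paths[:, -1]
    states = states.copy()
    for m in np.unique(last_moves):
        rows = np.flatnonzero(last_moves == m)
        target_ix, source_ix = move_defs[m]
        sub = states[rows]                     # fancy indexing returns a copy
        sub[:, target_ix] = sub[:, source_ix]  # RHS is gathered before the write
        states[rows] = sub
    return states


# Toy 4-sticker "puzzle": move 0 swaps stickers 0 and 1, move 1 cycles 1, 2, 3.
move_defs = [
    (np.array([0, 1]), np.array([1, 0])),
    (np.array([1, 2, 3]), np.array([3, 1, 2])),
]
states = np.array([[0, 1, 2, 3], [0, 1, 2, 3]])
paths = np.array([[0], [1]])
out = update_states(states, paths, move_defs)
```

With a fixed move set, the loop runs at most once per move rather than once per candidate. This keeps the per-depth Python overhead small even for large beams.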
_map_state
def _map_state(candidate_states, target_ix, source_ix)
Perform sticker replacement at the batch level.
Arguments:
- candidate_states (numpy.ndarray): The states of the candidates.
- target_ix (numpy.ndarray): The target indices for sticker replacement.
- source_ix (numpy.ndarray): The source indices for sticker replacement.
Returns:
- candidate_states (numpy.ndarray): The updated candidate states.
Note: Using numba.jit actually slows down this function.
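This is a minimal sketch of batch-level sticker replacement with NumPy fancy indexing. It assumes the semantics new_state[target_ix] = old_state[source_ix] for every row, and it modifies the array in place before returning it; whether the real function copies first is not stated. NumPy gathers the indexed right-hand side into a temporary array before writing, so cyclic permutations are applied correctly without an explicit copy.

```python
import numpy as np


def map_state(candidate_states, target_ix, source_ix):
    """Batch-level sticker replacement (sketch; modifies and returns the array).

    Writes candidate_states[:, target_ix] = candidate_states[:, source_ix] for
    every row. The fancy-indexed RHS is materialized first, so overlapping
    indices (cycles) do not read already-overwritten stickers.
    """
    candidate_states[:, target_ix] = candidate_states[:, source_ix]
    return candidate_states


states = np.arange(8).reshape(2, 4)
target_ix = np.array([0, 1, 2, 3])
source_ix = np.array([3, 0, 1, 2])  # a 4-cycle of stickers
once = map_state(states.copy(), target_ix, source_ix)
restored = states.copy()
for _ in range(4):  # applying a 4-cycle four times gives the identity
    map_state(restored, target_ix, source_ix)
```

A single assignment handles every row at once, which is why this primitive vectorizes well across the whole beam.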