This module provides a function to perform beam search and find solutions for a given state.

Function: beam_search: Perform beam search to find solutions in a Rubik's Cube environment.

MAX_BATCH_SIZE

The maximum number of states to forward-pass through the DNN at a time.

def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)

Performs beam search to find solutions for a given scrambled state.

Arguments:

  • env Cube3 - The Rubik's Cube environment representing the scrambled state.
  • model torch.nn.Module - DNN used to predict the probability distribution of next moves for every state.
  • beam_width int - The maximum number of candidates to keep at each step of the search.
  • ergonomic_bias dict or None - A dictionary specifying ergonomic bias for moves, if available.
  • extra_depths int - The number of additional depths to search beyond the first solution's depth.
  • max_depth int - The maximum depth to search, should be equal to or greater than God's Number (20 for Rubik's Cube in HTM).

Returns:

  • dict | None - If at least one solution is found, a dictionary with the following keys:
  1. "solutions": A list of optimal or near-optimal solutions found during the search.
  2. "num_nodes": The total number of nodes expanded during the search.
  3. "time": The time taken (in seconds) to complete the search.

Otherwise, None.
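The core loop of a beam search like the one above can be sketched on a toy problem. The function and state encoding below are illustrative only (integers on a number line, with a fake "model" scoring moves by distance to the goal); the real `beam_search` operates on vectorized cube states and a DNN, but the structure is the same: expand every move of every candidate, accumulate log-probabilities, keep the top `beam_width` paths, and stop at the first solution.

```python
def toy_beam_search(start, goal, beam_width, max_depth):
    """Minimal sketch of the beam-search loop on a toy 1-D state space.

    States are integers, moves are -1/+1, and the toy "model" scores a
    move by how close it brings the state to the goal. This mirrors the
    structure of beam_search: expand all moves, accumulate scores, keep
    the top `beam_width` candidates, return the first solution's path.
    """
    # Each candidate: (state, path taken, cumulative score)
    candidates = [(start, [], 0.0)]
    for _depth in range(max_depth):
        expanded = []
        for state, path, score in candidates:
            for move in (-1, +1):
                nxt = state + move
                # Toy scoring: prefer moves that reduce distance to goal.
                expanded.append((nxt, path + [move], score - abs(nxt - goal)))
        # Keep only the top `beam_width` candidates by cumulative score.
        expanded.sort(key=lambda c: c[2], reverse=True)
        candidates = expanded[:beam_width]
        for state, path, _ in candidates:
            if state == goal:
                return path  # first (shortest found) solution
    return None
```

The real implementation additionally tracks expanded-node counts, elapsed time, and keeps searching up to `extra_depths` beyond the first solution's depth to collect near-optimal alternatives.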

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

Initialize ergonomic bias if provided.

Arguments:

  • ergonomic_bias dict or None - A dictionary specifying ergonomic bias for moves, if available.
  • env Cube3 - The Rubik's Cube environment representing the scrambled state.

Returns:

  • ergonomic_bias numpy.ndarray or None - The ergonomic bias for moves converted to an array, or None if no bias was provided.
  • env Cube3 - The Rubik's Cube environment representing the scrambled state.

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

Predict the probability distribution of next moves for every state.

Arguments:

  • model torch.nn.Module - DNN used to predict the probability distribution of next moves for every state.
  • batch_x numpy.ndarray - Batch of states.
  • ergonomic_bias dict or None - A dictionary specifying ergonomic bias for moves, if available.
  • env Cube3 - The Rubik's Cube environment representing the scrambled state.

Returns:

  • batch_logprob numpy.ndarray - The log probability of each move for each state.
note

Inference with Automatic Mixed Precision is, in practice, slightly faster than plain half-precision inference (with model.half()).
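A minimal sketch of AMP inference as the note describes it. The function and tensor names here are illustrative, not the library's; the only assumption about `model` is that it maps a batch of states to per-move logits. `torch.autocast` runs matmul-heavy ops in reduced precision while keeping numerically sensitive ops in float32, which is where the speed/accuracy trade-off versus a blanket `model.half()` comes from.

```python
import torch

@torch.inference_mode()
def predict_logprob(model, batch_x):
    """Sketch of batched inference under Automatic Mixed Precision.

    `model` is assumed to map a float tensor of states (N, D) to
    per-move logits (N, M). Names are illustrative.
    """
    device_type = "cuda" if batch_x.is_cuda else "cpu"
    # float16 on GPU, bfloat16 on CPU (CPU autocast supports bfloat16).
    dtype = torch.float16 if batch_x.is_cuda else torch.bfloat16
    with torch.autocast(device_type=device_type, dtype=dtype):
        logits = model(batch_x)
    # Normalize to log-probabilities over the move set in float32.
    return torch.log_softmax(logits.float(), dim=-1)
```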

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

Expand candidate paths with the predicted probabilities of next moves.

Arguments:

  • candidates dict - A dictionary containing candidate paths, cumulative probabilities, and states.
  • batch_logprob numpy.ndarray - The log probability of each move for each state.
  • env Cube3 - The Rubik's Cube environment representing the scrambled state.
  • depth int - The current depth of the search.
  • beam_width int - The maximum number of candidates to keep at each step of the search.

Returns:

  • candidates dict - The updated dictionary containing candidate paths, cumulative probabilities, and states.
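The expansion-and-pruning step can be done fully vectorized. The sketch below (names hypothetical) shows the key idea: broadcasting adds each path's cumulative log-probability to all of its possible next moves at once, and `argpartition` selects the top `beam_width` expanded paths without sorting the whole batch.

```python
import numpy as np

def expand_and_prune(cum_logprob, batch_logprob, beam_width):
    """Vectorized sketch of one beam-expansion step (names hypothetical).

    cum_logprob:   (N,)   cumulative log-probability of each candidate path
    batch_logprob: (N, M) predicted log-probability of each of M next moves
    Returns flat indices (candidate * M + move) of the top `beam_width`
    expanded paths, ranked by cumulative log-probability, best first.
    """
    # Broadcasting: each path's score is added to all M of its moves.
    expanded = (cum_logprob[:, None] + batch_logprob).ravel()
    k = min(beam_width, expanded.size)
    # argpartition finds the top-k in O(N*M); only those k get sorted.
    top = np.argpartition(expanded, -k)[-k:]
    return top[np.argsort(expanded[top])[::-1]]
```

From a returned flat index, `index // M` recovers the parent candidate and `index % M` the move taken, which is how the path and state arrays can be regathered for the next depth.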

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

Get the indices of candidates to prune based on previous moves.

Arguments:

  • candidates_paths numpy.ndarray - The move paths of the candidates.
  • allow_wide bool - Whether to allow wide moves.
  • depth int - The current depth of the search.

Returns:

  • prune_idx numpy.ndarray - The indices of candidates to prune.
note

Using numba.jit actually slows down this function.
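A sketch of the kind of pruning this performs, under an assumed move encoding (the actual encoding and wide-move handling are not shown in this document): if each move is encoded as `face * 3 + rotation`, two consecutive moves on the same face collapse into at most one move, so any path ending in such a pair is redundant and can be pruned.

```python
import numpy as np

def get_prune_idx(candidate_paths, depth):
    """Sketch of redundancy pruning (the move encoding is an assumption).

    Assume each move is encoded as face * 3 + rotation, so two moves act
    on the same face iff their encodings fall in the same group of three.
    A path whose last two moves turn the same face is redundant: a
    shorter path reaches the same state, so the candidate is pruned.
    """
    if depth < 1:
        return np.empty(0, dtype=np.int64)
    last = candidate_paths[:, depth]
    prev = candidate_paths[:, depth - 1]
    same_face = (last // 3) == (prev // 3)
    return np.where(same_face)[0]
```

Because this is a handful of vectorized numpy operations on small integer arrays, the numba compilation overhead noted above outweighs any per-call gain.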

_update_states

def _update_states(candidate_states, candidate_paths, env)

Update states based on the expanded paths.

Arguments:

  • candidate_states numpy.ndarray - The states of the candidates.
  • candidate_paths numpy.ndarray - The move paths of the candidates.
  • env Cube3 - The Rubik's Cube environment representing the scrambled state.

Returns:

  • candidate_states numpy.ndarray - The updated states of the candidates.

_map_state

def _map_state(candidate_states, target_ix, source_ix)

Perform sticker replacement on the batch level.

Arguments:

  • candidate_states numpy.ndarray - The states of the candidates.
  • target_ix numpy.ndarray - The target indices for sticker replacement.
  • source_ix numpy.ndarray - The source indices for sticker replacement.

Returns:

  • candidate_states numpy.ndarray - The updated states of the candidates.
note

Using numba.jit actually slows down this function.
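Batch-level sticker replacement reduces to one numpy advanced-indexing assignment, which is a sketch of why numba offers no benefit here. Each row is one flattened cube state, and a move is a fixed permutation of sticker positions; the shapes and names below are illustrative.

```python
import numpy as np

def map_state(candidate_states, target_ix, source_ix):
    """Sticker replacement as batched fancy indexing (a sketch).

    Each row of `candidate_states` is one flattened cube state, and a
    move is a fixed permutation of sticker positions. NumPy advanced
    indexing on the right-hand side returns a copy, so this is correct
    even when `target_ix` and `source_ix` overlap.
    """
    candidate_states[:, target_ix] = candidate_states[:, source_ix]
    return candidate_states
```

Since this single assignment is already executed entirely inside numpy's C loops, wrapping it in `numba.jit` only adds dispatch and compilation overhead, consistent with the note above.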