Aller au contenu principal

search

This module provides a function to perform beam search and find solutions for a given state.

Function: beam_search: Effectue une recherche en faisceau pour trouver des solutions dans un environnement de Rubik's Cube.

MAX_BATCH_SIZE

Le nombre maximum d'états à traiter en une seule passe avant (forward-pass) à travers un DNN.

def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)

Effectue une recherche en faisceau pour trouver des solutions pour un état mélangé donné.

Arguments:

  • env Cube3 - L'environnement Rubik's Cube représentant l'état mélangé.
  • model torch.nn.Module - Le DNN utilisé pour prédire la distribution de probabilité des prochains mouvements pour chaque état.
  • beam_width int - Le nombre maximum de candidats à conserver à chaque étape de la recherche.
  • ergonomic_bias dict ou None - Un dictionnaire spécifiant un biais ergonomique pour les mouvements, si disponible.
  • extra_depths int - Le nombre de profondeurs supplémentaires à explorer au-delà de la profondeur de la première solution trouvée.
  • max_depth int - La profondeur maximale de recherche, doit être égale ou supérieure au Nombre de Dieu (20 pour le Rubik's Cube en métrique HTM).

Retourne:

  • dict | None: Si au moins une solution est trouvée, un dictionnaire avec les clés suivantes :
  1. "solutions": Une liste des solutions optimales ou quasi-optimales trouvées pendant la recherche.
  2. "num_nodes": Le nombre total de nœuds explorés pendant la recherche.
  3. "time": Le temps (en secondes) nécessaire pour terminer la recherche.

Sinon, None.

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

Initialise le biais ergonomique s'il est fourni.

Arguments:

  • ergonomic_bias dict ou None - Un dictionnaire spécifiant un biais ergonomique pour les mouvements, si disponible.
  • env Cube3 - L'environnement Rubik's Cube représentant l'état mélangé.

Retourne:

  • ergonomic_bias numpy.ndarray - Le biais ergonomique pour les mouvements, si disponible.
  • env Cube3 - L'environnement Rubik's Cube représentant l'état mélangé.

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

Prédit la distribution de probabilité des prochains mouvements pour chaque état.

Arguments:

  • model torch.nn.Module - Le DNN utilisé pour prédire la distribution de probabilité des prochains mouvements pour chaque état.
  • batch_x numpy.ndarray - Un lot d'états.
  • ergonomic_bias dict ou None - Un dictionnaire spécifiant un biais ergonomique pour les mouvements, si disponible.
  • env Cube3 - L'environnement Rubik's Cube représentant l'état mélangé.

Retourne:

  • batch_logprob numpy.ndarray - Le logarithme de la probabilité de chaque mouvement pour chaque état.
remarque

Pour certaines raisons, l'inférence avec Automatic Mixed Precision est légèrement plus rapide que la simple demi-précision (avec model.half()).

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

Étend les chemins candidats avec les probabilités prédites des prochains mouvements.

Arguments:

  • candidates dict - Un dictionnaire contenant les chemins candidats, les probabilités cumulées et les états.
  • batch_logprob numpy.ndarray - Le logarithme de la probabilité de chaque mouvement pour chaque état.
  • env Cube3 - L'environnement Rubik's Cube représentant l'état mélangé.
  • depth int - La profondeur actuelle de la recherche.
  • beam_width int - Le nombre maximum de candidats à conserver à chaque étape de la recherche.

Retourne:

  • candidates dict - Le dictionnaire mis à jour contenant les chemins candidats, les probabilités cumulées et les états.

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

Obtient les indices des candidats à élaguer en fonction des mouvements précédents.

Arguments:

  • candidates_paths numpy.ndarray - Les chemins des états candidats.
  • allow_wide bool - Indique s'il faut autoriser les mouvements larges (wide moves).
  • depth int - La profondeur actuelle de la recherche.

Retourne:

  • prune_idx numpy.ndarray - Les indices des candidats à élaguer.
remarque

L'utilisation de numba.jit ralentit en réalité cette fonction.

_update_states

def _update_states(candidate_states, candidate_paths, env)

Met à jour les états en fonction des chemins étendus.

Arguments:

  • candidate_states numpy.ndarray - Les états des candidats.
  • candidate_paths numpy.ndarray - Les chemins des états candidats.
  • env Cube3 - L'environnement Rubik's Cube représentant l'état mélangé.

Retourne:

  • candidate_states numpy.ndarray - Les états mis à jour des candidats.

_map_state

def _map_state(candidate_states, target_ix, source_ix)

Effectue le remplacement des autocollants (stickers) au niveau du lot (batch).

Arguments:

  • candidate_states numpy.ndarray - Les états des candidats.
  • target_ix numpy.ndarray - Les indices cibles pour le remplacement des autocollants.
  • source_ix numpy.ndarray - Les indices sources pour le remplacement des autocollants.

Retourne:

  • candidate_states numpy.ndarray - Les états mis à jour des candidats.
remarque

L'utilisation de numba.jit ralentit en réalité cette fonction.