Aller au contenu principal

search

Ce module fournit une fonction pour effectuer une recherche en faisceau (beam search) et trouver des solutions pour un état donné.

Fonction : beam_search : Effectue une recherche en faisceau pour trouver des solutions dans un environnement de Rubik's Cube.

MAX_BATCH_SIZE

Le nombre maximum d'états passés en une seule fois à travers un réseau de neurones profond (DNN).

def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)

Effectue une recherche en faisceau pour trouver des solutions pour un état mélangé donné.

Arguments :

  • env Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.
  • model torch.nn.Module - Le DNN utilisé pour prédire la distribution de probabilité des prochains mouvements pour chaque état.
  • beam_width int - Le nombre maximum de candidats à conserver à chaque étape de la recherche.
  • ergonomic_bias dict ou None - Un dictionnaire spécifiant le biais ergonomique pour les mouvements, s'il est disponible.
  • extra_depths int - Le nombre de profondeurs supplémentaires à explorer au-delà de la profondeur de la première solution.
  • max_depth int - La profondeur maximale de recherche, doit être supérieure ou égale au nombre de Dieu (20 pour le Rubik's Cube en HTM).

Retourne :

  • dict | None : Avec au moins une solution, un dictionnaire avec les clés suivantes :
  1. "solutions" : Une liste des solutions optimales ou quasi-optimales trouvées pendant la recherche.
  2. "num_nodes" : Le nombre total de nœuds développés pendant la recherche.
  3. "time" : Le temps pris (en secondes) pour compléter la recherche.

Sinon, None.

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

Initialise le biais ergonomique s'il est fourni.

Arguments :

  • ergonomic_bias dict ou None - Un dictionnaire spécifiant le biais ergonomique pour les mouvements, s'il est disponible.
  • env Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.

Retourne :

  • ergonomic_bias numpy.ndarray - Le biais ergonomique pour les mouvements, s'il est disponible.
  • env Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

Prédit la distribution de probabilité des prochains mouvements pour chaque état.

Arguments :

  • model torch.nn.Module - Le DNN utilisé pour prédire la distribution de probabilité des prochains mouvements pour chaque état.
  • batch_x numpy.ndarray - Lot d'états.
  • ergonomic_bias dict ou None - Un dictionnaire spécifiant le biais ergonomique pour les mouvements, s'il est disponible.
  • env Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.

Retourne :

  • batch_logprob numpy.ndarray - Le log de la probabilité de chaque mouvement pour chaque état.
remarque

L'inférence avec Automatic Mixed Prevision est légèrement plus rapide que la simple demi-précision (avec model.half()) pour certaines raisons.

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

Développe les chemins candidats avec les probabilités prédites des prochains mouvements.

Arguments :

  • candidates dict - Un dictionnaire contenant les chemins candidats, les probabilités cumulées et les états.
  • batch_logprob numpy.ndarray - Le log de la probabilité de chaque mouvement pour chaque état.
  • env Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.
  • depth int - La profondeur actuelle de la recherche.
  • beam_width int - Le nombre maximum de candidats à conserver à chaque étape de la recherche.

Retourne :

  • candidates dict - Le dictionnaire mis à jour contenant les chemins candidats, les probabilités cumulées et les états.

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

Obtient les indices des candidats à élaguer en fonction des mouvements précédents.

Arguments :

  • candidates_paths numpy.ndarray - Les chemins des états candidats.
  • allow_wide bool - Indique s'il faut autoriser les mouvements larges.
  • depth int - La profondeur actuelle de la recherche.

Retourne :

  • prune_idx numpy.ndarray - Les indices des candidats à élaguer.
remarque

L'utilisation de numba.jit ralentit en fait cette fonction.

_update_states

def _update_states(candidate_states, candidate_paths, env)

Met à jour les états en fonction des chemins développés.

Arguments :

  • candidate_states numpy.ndarray - Les états des états candidats.
  • candidate_paths numpy.ndarray - Les chemins des états candidats.
  • env Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.

Retourne :

  • candidate_states numpy.ndarray - Les états mis à jour des états candidats.

_map_state

def _map_state(candidate_states, target_ix, source_ix)

Effectue le remplacement des autocollants au niveau du lot.

Arguments :

  • candidate_states numpy.ndarray - Les états des états candidats.
  • target_ix numpy.ndarray - Les indices cibles pour le remplacement des autocollants.
  • source_ix numpy.ndarray - Les indices sources pour le remplacement des autocollants.

Retourne :

  • candidate_states numpy.ndarray - Les états mis à jour des états candidats.
remarque

L'utilisation de numba.jit ralentit en fait cette fonction.