search
Ce module fournit une fonction pour effectuer une recherche en faisceau (beam search) et trouver des solutions pour un état donné.
Fonction :
beam_search
: Effectue une recherche en faisceau pour trouver des solutions dans un environnement de Rubik's Cube.
MAX_BATCH_SIZE
Le nombre maximum d'états passés en une seule fois à travers un réseau de neurones profond (DNN).
beam_search
def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)
Effectue une recherche en faisceau pour trouver des solutions pour un état mélangé donné.
Arguments :
env
Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.model
torch.nn.Module - Le DNN utilisé pour prédire la distribution de probabilité des prochains mouvements pour chaque état.beam_width
int - Le nombre maximum de candidats à conserver à chaque étape de la recherche.ergonomic_bias
dict ou None - Un dictionnaire spécifiant le biais ergonomique pour les mouvements, s'il est disponible.extra_depths
int - Le nombre de profondeurs supplémentaires à explorer au-delà de la profondeur de la première solution.max_depth
int - La profondeur maximale de recherche, doit être supérieure ou égale au nombre de Dieu (20 pour le Rubik's Cube en HTM).
Retourne :
dict
|None
: Avec au moins une solution, un dictionnaire avec les clés suivantes :
"solutions"
: Une liste des solutions optimales ou quasi-optimales trouvées pendant la recherche."num_nodes"
: Le nombre total de nœuds développés pendant la recherche."time"
: Le temps pris (en secondes) pour compléter la recherche.
Sinon, None
.
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
Initialise le biais ergonomique s'il est fourni.
Arguments :
ergonomic_bias
dict ou None - Un dictionnaire spécifiant le biais ergonomique pour les mouvements, s'il est disponible.env
Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.
Retourne :
ergonomic_bias
numpy.ndarray - Le biais ergonomique pour les mouvements, s'il est disponible.env
Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
Prédit la distribution de probabilité des prochains mouvements pour chaque état.
Arguments :
model
torch.nn.Module - Le DNN utilisé pour prédire la distribution de probabilité des prochains mouvements pour chaque état.batch_x
numpy.ndarray - Lot d'états.ergonomic_bias
dict ou None - Un dictionnaire spécifiant le biais ergonomique pour les mouvements, s'il est disponible.env
Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.
Retourne :
batch_logprob
numpy.ndarray - Le log de la probabilité de chaque mouvement pour chaque état.
L'inférence avec Automatic Mixed Prevision est légèrement plus rapide que la simple demi-précision (avec model.half()
) pour certaines raisons.
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
Développe les chemins candidats avec les probabilités prédites des prochains mouvements.
Arguments :
candidates
dict - Un dictionnaire contenant les chemins candidats, les probabilités cumulées et les états.batch_logprob
numpy.ndarray - Le log de la probabilité de chaque mouvement pour chaque état.env
Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.depth
int - La profondeur actuelle de la recherche.beam_width
int - Le nombre maximum de candidats à conserver à chaque étape de la recherche.
Retourne :
candidates
dict - Le dictionnaire mis à jour contenant les chemins candidats, les probabilités cumulées et les états.
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
Obtient les indices des candidats à élaguer en fonction des mouvements précédents.
Arguments :
candidates_paths
numpy.ndarray - Les chemins des états candidats.allow_wide
bool - Indique s'il faut autoriser les mouvements larges.depth
int - La profondeur actuelle de la recherche.
Retourne :
prune_idx
numpy.ndarray - Les indices des candidats à élaguer.
L'utilisation de numba.jit
ralentit en fait cette fonction.
_update_states
def _update_states(candidate_states, candidate_paths, env)
Met à jour les états en fonction des chemins développés.
Arguments :
candidate_states
numpy.ndarray - Les états des états candidats.candidate_paths
numpy.ndarray - Les chemins des états candidats.env
Cube3 - L'environnement du Rubik's Cube représentant l'état mélangé.
Retourne :
candidate_states
numpy.ndarray - Les états mis à jour des états candidats.
_map_state
def _map_state(candidate_states, target_ix, source_ix)
Effectue le remplacement des autocollants au niveau du lot.
Arguments :
candidate_states
numpy.ndarray - Les états des états candidats.target_ix
numpy.ndarray - Les indices cibles pour le remplacement des autocollants.source_ix
numpy.ndarray - Les indices sources pour le remplacement des autocollants.
Retourne :
candidate_states
numpy.ndarray - Les états mis à jour des états candidats.
L'utilisation de numba.jit
ralentit en fait cette fonction.