search
Dieses Modul stellt eine Funktion zur Verfügung, um eine Beam-Suche durchzuführen und Lösungen für einen gegebenen Zustand zu finden.
Funktion:
beam_search
: Führt eine Beam-Suche durch, um Lösungen in einer Rubik's Cube Umgebung zu finden.
MAX_BATCH_SIZE
Die maximale Anzahl von Zuständen, die gleichzeitig durch ein DNN vorwärts propagiert werden.
beam_search
def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)
Führt eine Beam-Suche durch, um Lösungen für einen gegebenen verscrambelten Zustand zu finden.
Argumente:
env
Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.model
torch.nn.Module - DNN, das verwendet wird, um die Wahrscheinlichkeitsverteilung der nächsten Züge für jeden Zustand vorherzusagen.beam_width
int - Die maximale Anzahl von Kandidaten, die bei jedem Schritt der Suche beibehalten werden.ergonomic_bias
dict oder None - Ein Wörterbuch, das den ergonomischen Bias für Züge angibt, falls verfügbar.extra_depths
int - Die Anzahl zusätzlicher Tiefen, die über die Tiefe der ersten Lösung hinaus durchsucht werden sollen.max_depth
int - Die maximale Suchtiefe, sollte gleich oder größer als God's Number sein (20 für Rubik's Cube in HTM).
Rückgabe:
dict
|None
: Mit mindestens einer Lösung, ein Wörterbuch mit folgenden Schlüsseln:
"solutions"
: Eine Liste der optimalen oder nahezu optimalen Lösungen, die während der Suche gefunden wurden."num_nodes"
: Die Gesamtzahl der während der Suche expandierten Knoten."time"
: Die für die Suche benötigte Zeit (in Sekunden).
Andernfalls None
.
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
Initialisiert den ergonomischen Bias, falls vorhanden.
Argumente:
ergonomic_bias
dict oder None - Ein Wörterbuch, das den ergonomischen Bias für Züge angibt, falls verfügbar.env
Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.
Rückgabe:
ergonomic_bias
numpy.ndarray - Der ergonomische Bias für Züge, falls verfügbar.env
Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
Sagt die Wahrscheinlichkeitsverteilung der nächsten Züge für jeden Zustand vorher.
Argumente:
model
torch.nn.Module - DNN, das verwendet wird, um die Wahrscheinlichkeitsverteilung der nächsten Züge für jeden Zustand vorherzusagen.batch_x
numpy.ndarray - Batch von Zuständen.ergonomic_bias
dict oder None - Ein Wörterbuch, das den ergonomischen Bias für Züge angibt, falls verfügbar.env
Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.
Rückgabe:
batch_logprob
numpy.ndarray - Die logarithmische Wahrscheinlichkeit jedes Zuges für jeden Zustand.
Inferenz mit Automatic Mixed Prevision ist aus irgendeinem Grund etwas schneller als die einfache Halbpräzision (mit model.half()
).
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
Erweitert Kandidatenpfade mit den vorhergesagten Wahrscheinlichkeiten der nächsten Züge.
Argumente:
candidates
dict - Ein Wörterbuch, das Kandidatenpfade, kumulative Wahrscheinlichkeiten und Zustände enthält.batch_logprob
numpy.ndarray - Die logarithmische Wahrscheinlichkeit jedes Zuges für jeden Zustand.env
Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.depth
int - Die aktuelle Tiefe der Suche.beam_width
int - Die maximale Anzahl von Kandidaten, die bei jedem Schritt der Suche beibehalten werden.
Rückgabe:
candidates
dict - Das aktualisierte Wörterbuch, das Kandidatenpfade, kumulative Wahrscheinlichkeiten und Zustände enthält.
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
Ermittelt die Indizes der Kandidaten, die basierend auf vorherigen Zügen entfernt werden sollen.
Argumente:
candidates_paths
numpy.ndarray - Die Pfade der Kandidatenzustände.allow_wide
bool - Ob weite Züge erlaubt sind.depth
int - Die aktuelle Tiefe der Suche.
Rückgabe:
prune_idx
numpy.ndarray - Die Indizes der zu entfernenden Kandidaten.
Die Verwendung von numba.jit
verlangsamt diese Funktion tatsächlich.
_update_states
def _update_states(candidate_states, candidate_paths, env)
Aktualisiert Zustände basierend auf den erweiterten Pfaden.
Argumente:
candidate_states
numpy.ndarray - Die Zustände der Kandidatenzustände.candidate_paths
numpy.ndarray - Die Pfade der Kandidatenzustände.env
Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.
Rückgabe:
candidate_states
numpy.ndarray - Die aktualisierten Zustände der Kandidatenzustände.
_map_state
def _map_state(candidate_states, target_ix, source_ix)
Führt einen Sticker-Austausch auf Batch-Ebene durch.
Argumente:
candidate_states
numpy.ndarray - Die Zustände der Kandidatenzustände.target_ix
numpy.ndarray - Die Zielindizes für den Sticker-Austausch.source_ix
numpy.ndarray - Die Quellindizes für den Sticker-Austausch.
Rückgabe:
candidate_states
numpy.ndarray - Die aktualisierten Zustände der Kandidatenzustände.
Die Verwendung von numba.jit
verlangsamt diese Funktion tatsächlich.