Zum Hauptinhalt springen

search

Dieses Modul stellt eine Funktion zur Verfügung, um eine Beam-Suche durchzuführen und Lösungen für einen gegebenen Zustand zu finden.

Funktion: beam_search: Führt eine Beam-Suche durch, um Lösungen in einer Rubik's Cube Umgebung zu finden.

MAX_BATCH_SIZE

Die maximale Anzahl von Zuständen, die gleichzeitig durch ein DNN vorwärts propagiert werden.

def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)

Führt eine Beam-Suche durch, um Lösungen für einen gegebenen verscrambelten Zustand zu finden.

Argumente:

  • env Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.
  • model torch.nn.Module - DNN, das verwendet wird, um die Wahrscheinlichkeitsverteilung der nächsten Züge für jeden Zustand vorherzusagen.
  • beam_width int - Die maximale Anzahl von Kandidaten, die bei jedem Schritt der Suche beibehalten werden.
  • ergonomic_bias dict oder None - Ein Wörterbuch, das den ergonomischen Bias für Züge angibt, falls verfügbar.
  • extra_depths int - Die Anzahl zusätzlicher Tiefen, die über die Tiefe der ersten Lösung hinaus durchsucht werden sollen.
  • max_depth int - Die maximale Suchtiefe, sollte gleich oder größer als God's Number sein (20 für Rubik's Cube in HTM).

Rückgabe:

  • dict | None: Mit mindestens einer Lösung, ein Wörterbuch mit folgenden Schlüsseln:
  1. "solutions": Eine Liste der optimalen oder nahezu optimalen Lösungen, die während der Suche gefunden wurden.
  2. "num_nodes": Die Gesamtzahl der während der Suche expandierten Knoten.
  3. "time": Die für die Suche benötigte Zeit (in Sekunden).

Andernfalls None.

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

Initialisiert den ergonomischen Bias, falls vorhanden.

Argumente:

  • ergonomic_bias dict oder None - Ein Wörterbuch, das den ergonomischen Bias für Züge angibt, falls verfügbar.
  • env Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.

Rückgabe:

  • ergonomic_bias numpy.ndarray - Der ergonomische Bias für Züge, falls verfügbar.
  • env Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

Sagt die Wahrscheinlichkeitsverteilung der nächsten Züge für jeden Zustand vorher.

Argumente:

  • model torch.nn.Module - DNN, das verwendet wird, um die Wahrscheinlichkeitsverteilung der nächsten Züge für jeden Zustand vorherzusagen.
  • batch_x numpy.ndarray - Batch von Zuständen.
  • ergonomic_bias dict oder None - Ein Wörterbuch, das den ergonomischen Bias für Züge angibt, falls verfügbar.
  • env Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.

Rückgabe:

  • batch_logprob numpy.ndarray - Die logarithmische Wahrscheinlichkeit jedes Zuges für jeden Zustand.
note

Inferenz mit Automatic Mixed Prevision ist aus irgendeinem Grund etwas schneller als die einfache Halbpräzision (mit model.half()).

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

Erweitert Kandidatenpfade mit den vorhergesagten Wahrscheinlichkeiten der nächsten Züge.

Argumente:

  • candidates dict - Ein Wörterbuch, das Kandidatenpfade, kumulative Wahrscheinlichkeiten und Zustände enthält.
  • batch_logprob numpy.ndarray - Die logarithmische Wahrscheinlichkeit jedes Zuges für jeden Zustand.
  • env Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.
  • depth int - Die aktuelle Tiefe der Suche.
  • beam_width int - Die maximale Anzahl von Kandidaten, die bei jedem Schritt der Suche beibehalten werden.

Rückgabe:

  • candidates dict - Das aktualisierte Wörterbuch, das Kandidatenpfade, kumulative Wahrscheinlichkeiten und Zustände enthält.

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

Ermittelt die Indizes der Kandidaten, die basierend auf vorherigen Zügen entfernt werden sollen.

Argumente:

  • candidates_paths numpy.ndarray - Die Pfade der Kandidatenzustände.
  • allow_wide bool - Ob weite Züge erlaubt sind.
  • depth int - Die aktuelle Tiefe der Suche.

Rückgabe:

  • prune_idx numpy.ndarray - Die Indizes der zu entfernenden Kandidaten.
note

Die Verwendung von numba.jit verlangsamt diese Funktion tatsächlich.

_update_states

def _update_states(candidate_states, candidate_paths, env)

Aktualisiert Zustände basierend auf den erweiterten Pfaden.

Argumente:

  • candidate_states numpy.ndarray - Die Zustände der Kandidatenzustände.
  • candidate_paths numpy.ndarray - Die Pfade der Kandidatenzustände.
  • env Cube3 - Die Rubik's Cube Umgebung, die den verscrambelten Zustand repräsentiert.

Rückgabe:

  • candidate_states numpy.ndarray - Die aktualisierten Zustände der Kandidatenzustände.

_map_state

def _map_state(candidate_states, target_ix, source_ix)

Führt einen Sticker-Austausch auf Batch-Ebene durch.

Argumente:

  • candidate_states numpy.ndarray - Die Zustände der Kandidatenzustände.
  • target_ix numpy.ndarray - Die Zielindizes für den Sticker-Austausch.
  • source_ix numpy.ndarray - Die Quellindizes für den Sticker-Austausch.

Rückgabe:

  • candidate_states numpy.ndarray - Die aktualisierten Zustände der Kandidatenzustände.
note

Die Verwendung von numba.jit verlangsamt diese Funktion tatsächlich.