Przejdź do głównej zawartości

search

Ten moduł udostępnia funkcję do wykonywania przeszukiwania wiązkowego (beam search) i znajdowania rozwiązań dla danego stanu.

Funkcja: beam_search: Wykonuje przeszukiwanie wiązkowe, aby znaleźć rozwiązania w środowisku kostki Rubika.

MAX_BATCH_SIZE

Maksymalna liczba stanów przekazywanych jednocześnie przez sieć neuronową (DNN).

def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)

Wykonuje przeszukiwanie wiązkowe, aby znaleźć rozwiązania dla danego pomieszanego stanu.

Argumenty:

  • env Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.
  • model torch.nn.Module - Sieć neuronowa (DNN) używana do przewidywania rozkładu prawdopodobieństwa następnych ruchów dla każdego stanu.
  • beam_width int - Maksymalna liczba kandydatów do zachowania na każdym kroku przeszukiwania.
  • ergonomic_bias dict lub None - Słownik określający ergonomiczne obciążenie dla ruchów, jeśli jest dostępny.
  • extra_depths int - Liczba dodatkowych głębokości do przeszukania poza głębokością pierwszego rozwiązania.
  • max_depth int - Maksymalna głębokość przeszukiwania, powinna być równa lub większa od liczby Boga (20 dla kostki Rubika w HTM).

Zwraca:

  • dict | None: Przy co najmniej jednym rozwiązaniu, słownik z następującymi kluczami:
  1. "solutions": Lista optymalnych lub prawie optymalnych rozwiązań znalezionych podczas przeszukiwania.
  2. "num_nodes": Całkowita liczba węzłów rozszerzonych podczas przeszukiwania.
  3. "time": Czas (w sekundach) potrzebny na ukończenie przeszukiwania.

W przeciwnym razie, None.

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

Inicjalizuje ergonomiczne obciążenie, jeśli jest podane.

Argumenty:

  • ergonomic_bias dict lub None - Słownik określający ergonomiczne obciążenie dla ruchów, jeśli jest dostępny.
  • env Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.

Zwraca:

  • ergonomic_bias numpy.ndarray - Ergonomiczne obciążenie dla ruchów, jeśli jest dostępne.
  • env Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

Przewiduje rozkład prawdopodobieństwa następnych ruchów dla każdego stanu.

Argumenty:

  • model torch.nn.Module - Sieć neuronowa (DNN) używana do przewidywania rozkładu prawdopodobieństwa następnych ruchów dla każdego stanu.
  • batch_x numpy.ndarray - Partia stanów.
  • ergonomic_bias dict lub None - Słownik określający ergonomiczne obciążenie dla ruchów, jeśli jest dostępny.
  • env Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.

Zwraca:

  • batch_logprob numpy.ndarray - Logarytm prawdopodobieństwa każdego ruchu dla każdego stanu.
notatka

Wnioskowanie z Automatic Mixed Prevision jest nieco szybsze niż prosta precyzja połówkowa (z model.half()) z pewnych powodów.

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

Rozszerza ścieżki kandydatów o przewidywane prawdopodobieństwa następnych ruchów.

Argumenty:

  • candidates dict - Słownik zawierający ścieżki kandydatów, skumulowane prawdopodobieństwa i stany.
  • batch_logprob numpy.ndarray - Logarytm prawdopodobieństwa każdego ruchu dla każdego stanu.
  • env Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.
  • depth int - Aktualna głębokość przeszukiwania.
  • beam_width int - Maksymalna liczba kandydatów do zachowania na każdym kroku przeszukiwania.

Zwraca:

  • candidates dict - Zaktualizowany słownik zawierający ścieżki kandydatów, skumulowane prawdopodobieństwa i stany.

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

Pobiera indeksy kandydatów do przycięcia na podstawie poprzednich ruchów.

Argumenty:

  • candidates_paths numpy.ndarray - Ścieżki stanów kandydatów.
  • allow_wide bool - Czy zezwolić na szerokie ruchy.
  • depth int - Aktualna głębokość przeszukiwania.

Zwraca:

  • prune_idx numpy.ndarray - Indeksy kandydatów do przycięcia.
notatka

Użycie numba.jit w rzeczywistości spowalnia tę funkcję.

_update_states

def _update_states(candidate_states, candidate_paths, env)

Aktualizuje stany na podstawie rozszerzonych ścieżek.

Argumenty:

  • candidate_states numpy.ndarray - Stany kandydatów.
  • candidate_paths numpy.ndarray - Ścieżki stanów kandydatów.
  • env Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.

Zwraca:

  • candidate_states numpy.ndarray - Zaktualizowane stany kandydatów.

_map_state

def _map_state(candidate_states, target_ix, source_ix)

Wykonuje zastępowanie naklejek na poziomie partii.

Argumenty:

  • candidate_states numpy.ndarray - Stany kandydatów.
  • target_ix numpy.ndarray - Indeksy docelowe do zastąpienia naklejek.
  • source_ix numpy.ndarray - Indeksy źródłowe do zastąpienia naklejek.

Zwraca:

  • candidate_states numpy.ndarray - Zaktualizowane stany kandydatów.
notatka

Użycie numba.jit w rzeczywistości spowalnia tę funkcję.