search
Ten moduł udostępnia funkcję do wykonywania przeszukiwania wiązkowego (beam search) i znajdowania rozwiązań dla danego stanu.
Funkcja:
beam_search
: Wykonuje przeszukiwanie wiązkowe, aby znaleźć rozwiązania w środowisku kostki Rubika.
MAX_BATCH_SIZE
Maksymalna liczba stanów przekazywanych jednocześnie przez sieć neuronową (DNN).
beam_search
def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)
Wykonuje przeszukiwanie wiązkowe, aby znaleźć rozwiązania dla danego pomieszanego stanu.
Argumenty:
env
Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.model
torch.nn.Module - Sieć neuronowa (DNN) używana do przewidywania rozkładu prawdopodobieństwa następnych ruchów dla każdego stanu.beam_width
int - Maksymalna liczba kandydatów do zachowania na każdym kroku przeszukiwania.ergonomic_bias
dict lub None - Słownik określający ergonomiczne obciążenie dla ruchów, jeśli jest dostępny.extra_depths
int - Liczba dodatkowych głębokości do przeszukania poza głębokością pierwszego rozwiązania.max_depth
int - Maksymalna głębokość przeszukiwania, powinna być równa lub większa od liczby Boga (20 dla kostki Rubika w HTM).
Zwraca:
dict
|None
: Przy co najmniej jednym rozwiązaniu, słownik z następującymi kluczami:
"solutions"
: Lista optymalnych lub prawie optymalnych rozwiązań znalezionych podczas przeszukiwania."num_nodes"
: Całkowita liczba węzłów rozszerzonych podczas przeszukiwania."time"
: Czas (w sekundach) potrzebny na ukończenie przeszukiwania.
W przeciwnym razie, None
.
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
Inicjalizuje ergonomiczne obciążenie, jeśli jest podane.
Argumenty:
ergonomic_bias
dict lub None - Słownik określający ergonomiczne obciążenie dla ruchów, jeśli jest dostępny.env
Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.
Zwraca:
ergonomic_bias
numpy.ndarray - Ergonomiczne obciążenie dla ruchów, jeśli jest dostępne.env
Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
Przewiduje rozkład prawdopodobieństwa następnych ruchów dla każdego stanu.
Argumenty:
model
torch.nn.Module - Sieć neuronowa (DNN) używana do przewidywania rozkładu prawdopodobieństwa następnych ruchów dla każdego stanu.batch_x
numpy.ndarray - Partia stanów.ergonomic_bias
dict lub None - Słownik określający ergonomiczne obciążenie dla ruchów, jeśli jest dostępny.env
Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.
Zwraca:
batch_logprob
numpy.ndarray - Logarytm prawdopodobieństwa każdego ruchu dla każdego stanu.
Wnioskowanie z Automatic Mixed Prevision jest nieco szybsze niż
prosta precyzja połówkowa (z model.half()
) z pewnych powodów.
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
Rozszerza ścieżki kandydatów o przewidywane prawdopodobieństwa następnych ruchów.
Argumenty:
candidates
dict - Słownik zawierający ścieżki kandydatów, skumulowane prawdopodobieństwa i stany.batch_logprob
numpy.ndarray - Logarytm prawdopodobieństwa każdego ruchu dla każdego stanu.env
Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.depth
int - Aktualna głębokość przeszukiwania.beam_width
int - Maksymalna liczba kandydatów do zachowania na każdym kroku przeszukiwania.
Zwraca:
candidates
dict - Zaktualizowany słownik zawierający ścieżki kandydatów, skumulowane prawdopodobieństwa i stany.
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
Pobiera indeksy kandydatów do przycięcia na podstawie poprzednich ruchów.
Argumenty:
candidates_paths
numpy.ndarray - Ścieżki stanów kandydatów.allow_wide
bool - Czy zezwolić na szerokie ruchy.depth
int - Aktualna głębokość przeszukiwania.
Zwraca:
prune_idx
numpy.ndarray - Indeksy kandydatów do przycięcia.
Użycie numba.jit
w rzeczywistości spowalnia tę funkcję.
_update_states
def _update_states(candidate_states, candidate_paths, env)
Aktualizuje stany na podstawie rozszerzonych ścieżek.
Argumenty:
candidate_states
numpy.ndarray - Stany kandydatów.candidate_paths
numpy.ndarray - Ścieżki stanów kandydatów.env
Cube3 - Środowisko kostki Rubika reprezentujące pomieszany stan.
Zwraca:
candidate_states
numpy.ndarray - Zaktualizowane stany kandydatów.
_map_state
def _map_state(candidate_states, target_ix, source_ix)
Wykonuje zastępowanie naklejek na poziomie partii.
Argumenty:
candidate_states
numpy.ndarray - Stany kandydatów.target_ix
numpy.ndarray - Indeksy docelowe do zastąpienia naklejek.source_ix
numpy.ndarray - Indeksy źródłowe do zastąpienia naklejek.
Zwraca:
candidate_states
numpy.ndarray - Zaktualizowane stany kandydatów.
Użycie numba.jit
w rzeczywistości spowalnia tę funkcję.