search
Ten moduł dostarcza funkcję do przeprowadzania przeszukiwania wiązkowego i znajdowania rozwiązań dla danego stanu.
Funkcja:
beam_search
: Przeprowadza przeszukiwanie wiązkowe w celu znalezienia rozwiązań w środowisku Kostki Rubika.
MAX_BATCH_SIZE
Maksymalna liczba stanów przetwarzanych jednocześnie przez sieć neuronową (DNN) w jednym przejściu w przód (forward-pass).
beam_search
def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)
Przeprowadza przeszukiwanie wiązkowe w celu znalezienia rozwiązań dla danego, pomieszanego stanu.
Argumenty:
env
Cube3 - Środowisko Kostki Rubika reprezentujące pomieszany stan.model
torch.nn.Module - Sieć neuronowa (DNN) używana do przewidywania rozkładu prawdopodobieństwa następnych ruchów dla każdego stanu.beam_width
int - Maksymalna liczba kandydatów do zachowania na każdym kroku przeszukiwania.ergonomic_bias
dict lub None - Słownik określający preferencje ergonomiczne dla ruchów, jeśli jest dostępny.extra_depths
int - Liczba dodatkowych poziomów głębokości do przeszukania po znalezieniu pierwszego rozwiązania.max_depth
int - Maksymalna głębokość przeszukiwania, powinna być równa lub większa niż Liczba Boga (20 dla Kostki Rubika w metryce HTM).
Zwraca:
dict
|None
: Jeśli znaleziono co najmniej jedno rozwiązanie, słownik z następującymi kluczami:
"solutions"
: Lista optymalnych lub niemal optymalnych rozwiązań znalezionych podczas przeszukiwania."num_nodes"
: Całkowita liczba węzłów rozwiniętych podczas przeszukiwania."time"
: Czas (w sekundach) potrzebny na ukończenie przeszukiwania.
W przeciwnym razie None
.
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
Inicjalizuje preferencje ergonomiczne, jeśli zostały podane.
Argumenty:
ergonomic_bias
dict lub None - Słownik określający preferencje ergonomiczne dla ruchów, jeśli jest dostępny.env
Cube3 - Środowisko Kostki Rubika reprezentujące pomieszany stan.
Zwraca:
ergonomic_bias
numpy.ndarray - Preferencje ergonomiczne dla ruchów, jeśli są dostępne.env
Cube3 - Środowisko Kostki Rubika reprezentujące pomieszany stan.
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
Przewiduje rozkład prawdopodobieństwa następnych ruchów dla każdego stanu.
Argumenty:
model
torch.nn.Module - Sieć neuronowa (DNN) używana do przewidywania rozkładu prawdopodobieństwa następnych ruchów dla każdego stanu.batch_x
numpy.ndarray - Paczka (batch) stanów.ergonomic_bias
dict lub None - Słownik określający preferencje ergonomiczne dla ruchów, jeśli jest dostępny.env
Cube3 - Środowisko Kostki Rubika reprezentujące pomieszany stan.
Zwraca:
batch_logprob
numpy.ndarray - Logarytm prawdopodobieństwa każdego ruchu dla każdego stanu.
Z pewnych powodów, wnioskowanie z użyciem Automatycznej Mieszanej Precyzji jest nieco szybsze niż
prosta połówkowa precyzja (z model.half()
).
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
Rozszerza ścieżki kandydatów o przewidziane prawdopodobieństwa następnych ruchów.
Argumenty:
candidates
dict - Słownik zawierający ścieżki kandydatów, skumulowane prawdopodobieństwa i stany.batch_logprob
numpy.ndarray - Logarytm prawdopodobieństwa każdego ruchu dla każdego stanu.env
Cube3 - Środowisko Kostki Rubika reprezentujące pomieszany stan.depth
int - Bieżąca głębokość przeszukiwania.beam_width
int - Maksymalna liczba kandydatów do zachowania na każdym kroku przeszukiwania.
Zwraca:
candidates
dict - Zaktualizowany słownik zawierający ścieżki kandydatów, skumulowane prawdopodobieństwa i stany.
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
Pobiera indeksy kandydatów do odrzucenia (prune) na podstawie poprzednich ruchów.
Argumenty:
candidates_paths
numpy.ndarray - Ścieżki stanów-kandydatów.allow_wide
bool - Czy zezwalać na ruchy szerokie (wide moves).depth
int - Bieżąca głębokość przeszukiwania.
Zwraca:
prune_idx
numpy.ndarray - Indeksy kandydatów do odrzucenia.
Użycie numba.jit
w rzeczywistości spowalnia tę funkcję.
_update_states
def _update_states(candidate_states, candidate_paths, env)
Aktualizuje stany na podstawie rozszerzonych ścieżek.
Argumenty:
candidate_states
numpy.ndarray - Stany kandydatów.candidate_paths
numpy.ndarray - Ścieżki stanów-kandydatów.env
Cube3 - Środowisko Kostki Rubika reprezentujące pomieszany stan.
Zwraca:
candidate_states
numpy.ndarray - Zaktualizowane stany kandydatów.
_map_state
def _map_state(candidate_states, target_ix, source_ix)
Wykonuje zamianę naklejek na poziomie paczki (batch).
Argumenty:
candidate_states
numpy.ndarray - Stany kandydatów.target_ix
numpy.ndarray - Docelowe indeksy do zamiany naklejek.source_ix
numpy.ndarray - Źródłowe indeksy do zamiany naklejek.
Zwraca:
candidate_states
numpy.ndarray - Zaktualizowane stany kandydatów.
Użycie numba.jit
w rzeczywistości spowalnia tę funkcję.