search
Ez a modul egy függvényt biztosít a nyalábalapú keresés (beam search) végrehajtására és egy adott állapot megoldásainak megtalálására.
Függvény:
beam_search
: Nyalábalapú keresést hajt végre megoldások keresésére egy Rubik-kocka környezetben.
MAX_BATCH_SIZE
Az egyszerre egy DNN-en előrecsatolással (forward-pass) feldolgozott állapotok maximális száma.
beam_search
def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)
Nyalábalapú keresést (beam search) hajt végre egy adott kevert állapot megoldásainak megtalálására.
Argumentumok:
env
Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.model
torch.nn.Module - A DNN, amely minden állapothoz megjósolja a következő lépések valószínűségi eloszlását.beam_width
int - A keresés minden lépésében megtartandó jelöltek maximális száma.ergonomic_bias
dict vagy None - Egy szótár, amely megadja a lépések ergonomikus torzítását, ha van ilyen.extra_depths
int - Az első megoldás mélységén túli további keresési mélységek száma.max_depth
int - A maximális keresési mélység, amelynek egyenlőnek vagy nagyobbnak kell lennie Isten számánál (a Rubik-kocka esetében ez 20 HTM-ben).
Visszatérési érték:
dict
|None
: Ha legalább egy megoldás található, egy szótár a következő kulcsokkal:
"solutions"
: A keresés során talált optimális vagy közel optimális megoldások listája."num_nodes"
: A keresés során kibontott csomópontok teljes száma."time"
: A keresés befejezéséhez szükséges idő (másodpercben).
Ellenkező esetben None
.
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
Inicializálja az ergonomikus torzítást, ha meg van adva.
Argumentumok:
ergonomic_bias
dict vagy None - Egy szótár, amely megadja a lépések ergonomikus torzítását, ha van ilyen.env
Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.
Visszatérési érték:
ergonomic_bias
numpy.ndarray - A lépések ergonomikus torzítása, ha van ilyen.env
Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
Megjósolja a következő lépések valószínűségi eloszlását minden állapothoz.
Argumentumok:
model
torch.nn.Module - A DNN, amely minden állapothoz megjósolja a következő lépések valószínűségi eloszlását.batch_x
numpy.ndarray - Állapotok kötege (batch).ergonomic_bias
dict vagy None - Egy szótár, amely megadja a lépések ergonomikus torzítását, ha van ilyen.env
Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.
Visszatérési érték:
batch_logprob
numpy.ndarray - Az egyes lépések logaritmikus valószínűsége minden állapothoz.
Az Automatikus Kevert Pontosságú (Automatic Mixed Precision) inferencia valamilyen okból kifolyólag valamivel gyorsabb, mint az egyszerű félpontosságú (a model.half()
használatával).
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
Kibővíti a jelölt útvonalakat a következő lépések megjósolt valószínűségeivel.
Argumentumok:
candidates
dict - Egy szótár, amely a jelölt útvonalakat, a kumulatív valószínűségeket és az állapotokat tartalmazza.batch_logprob
numpy.ndarray - Az egyes lépések logaritmikus valószínűsége minden állapothoz.env
Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.depth
int - A keresés aktuális mélysége.beam_width
int - A keresés minden lépésében megtartandó jelöltek maximális száma.
Visszatérési érték:
candidates
dict - A frissített szótár, amely a jelölt útvonalakat, a kumulatív valószínűségeket és az állapotokat tartalmazza.
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
Lekéri a korábbi lépések alapján metszendő jelöltek indexeit.
Argumentumok:
candidates_paths
numpy.ndarray - A jelölt állapotok útvonalai.allow_wide
bool - Meghatározza, hogy engedélyezettek-e a széles (wide) mozdulatok.depth
int - A keresés aktuális mélysége.
Visszatérési érték:
prune_idx
numpy.ndarray - A metszendő jelöltek indexei.
A numba.jit
használata valójában lelassítja ezt a függvényt.
_update_states
def _update_states(candidate_states, candidate_paths, env)
Frissíti az állapotokat a kibővített útvonalak alapján.
Argumentumok:
candidate_states
numpy.ndarray - A jelölt állapotok állapotai.candidate_paths
numpy.ndarray - A jelölt állapotok útvonalai.env
Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.
Visszatérési érték:
candidate_states
numpy.ndarray - A jelölt állapotok frissített állapotai.
_map_state
def _map_state(candidate_states, target_ix, source_ix)
Matricacsere végrehajtása kötegszinten.
Argumentumok:
candidate_states
numpy.ndarray - A jelölt állapotok állapotai.target_ix
numpy.ndarray - A matricacsere célindexei.source_ix
numpy.ndarray - A matricacsere forrásindexei.
Visszatérési érték:
candidate_states
numpy.ndarray - A jelölt állapotok frissített állapotai.
A numba.jit
használata valójában lelassítja ezt a függvényt.