search
Ez a modul egy függvényt biztosít a nyalábalapú keresés (beam search) végrehajtására és egy adott állapot megoldásainak megtalálására.
Függvény:
beam_search: Nyalábalapú keresést hajt végre megoldások keresésére egy Rubik-kocka környezetben.
MAX_BATCH_SIZE
Az egyszerre egy DNN-en előrecsatolással (forward-pass) feldolgozott állapotok maximális száma.
beam_search
def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)
Nyalábalapú keresést (beam search) hajt végre egy adott kevert állapot megoldásainak megtalálására.
Argumentumok:
envCube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.modeltorch.nn.Module - A DNN, amely minden állapothoz megjósolja a következő lépések valószínűségi eloszlását.beam_widthint - A keresés minden lépésében megtartandó jelöltek maximális száma.ergonomic_biasdict vagy None - Egy szótár, amely megadja a lépések ergonomikus torzítását, ha van ilyen.extra_depthsint - Az első megoldás mélységén túli további keresési mélységek száma.max_depthint - A maximális keresési mélység, amelynek egyenlőnek vagy nagyobbnak kell lennie Isten számánál (a Rubik-kocka esetében ez 20 HTM-ben).
Visszatérési érték:
dict|None: Ha legalább egy megoldás található, egy szótár a következő kulcsokkal:
"solutions": A keresés során talált optimális vagy közel optimális megoldások listája."num_nodes": A keresés során kibontott csomópontok teljes száma."time": A keresés befejezéséhez szükséges idő (másodpercben).
Ellenkező esetben None.
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
Inicializálja az ergonomikus torzítást, ha meg van adva.
Argumentumok:
ergonomic_biasdict vagy None - Egy szótár, amely megadja a lépések ergonomikus torzítását, ha van ilyen.envCube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.
Visszatérési érték:
ergonomic_biasnumpy.ndarray - A lépések ergonomikus torzítása, ha van ilyen.envCube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
Megjósolja a következő lépések valószínűségi eloszlását minden állapothoz.
Argumentumok:
modeltorch.nn.Module - A DNN, amely minden állapothoz megjósolja a következő lépések valószínűségi eloszlását.batch_xnumpy.ndarray - Állapotok kötege (batch).ergonomic_biasdict vagy None - Egy szótár, amely megadja a lépések ergonomikus torzítását, ha van ilyen.envCube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.
Visszatérési érték:
batch_logprobnumpy.ndarray - Az egyes lépések logaritmikus valószínűsége minden állapothoz.
Az Automatikus Kevert Pontosságú (Automatic Mixed Precision) inferencia valamilyen okból kifolyólag valamivel gyorsabb, mint az egyszerű félpontosságú (a model.half() használatával).
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
Kibővíti a jelölt útvonalakat a következő lépések megjósolt valószínűségeivel.
Argumentumok:
candidatesdict - Egy szótár, amely a jelölt útvonalakat, a kumulatív valószínűségeket és az állapotokat tartalmazza.batch_logprobnumpy.ndarray - Az egyes lépések logaritmikus valószínűsége minden állapothoz.envCube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.depthint - A keresés aktuális mélysége.beam_widthint - A keresés minden lépésében megtartandó jelöltek maximális száma.
Visszatérési érték:
candidatesdict - A frissített szótár, amely a jelölt útvonalakat, a kumulatív valószínűségeket és az állapotokat tartalmazza.
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
Lekéri a korábbi lépések alapján metszendő jelöltek indexeit.
Argumentumok:
candidates_pathsnumpy.ndarray - A jelölt állapotok útvonalai.allow_widebool - Meghatározza, hogy engedélyezettek-e a széles (wide) mozdulatok.depthint - A keresés aktuális mélysége.
Visszatérési érték:
prune_idxnumpy.ndarray - A metszendő jelöltek indexei.
A numba.jit használata valójában lelassítja ezt a függvényt.
_update_states
def _update_states(candidate_states, candidate_paths, env)
Frissíti az állapotokat a kibővített útvonalak alapján.
Argumentumok:
candidate_statesnumpy.ndarray - A jelölt állapotok állapotai.candidate_pathsnumpy.ndarray - A jelölt állapotok útvonalai.envCube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.
Visszatérési érték:
candidate_statesnumpy.ndarray - A jelölt állapotok frissített állapotai.
_map_state
def _map_state(candidate_states, target_ix, source_ix)
Matricacsere végrehajtása kötegszinten.
Argumentumok:
candidate_statesnumpy.ndarray - A jelölt állapotok állapotai.target_ixnumpy.ndarray - A matricacsere célindexei.source_ixnumpy.ndarray - A matricacsere forrásindexei.
Visszatérési érték:
candidate_statesnumpy.ndarray - A jelölt állapotok frissített állapotai.
A numba.jit használata valójában lelassítja ezt a függvényt.