Ugrás a fő tartalomhoz

search

Ez a modul egy függvényt biztosít a nyalábalapú keresés (beam search) végrehajtására és egy adott állapot megoldásainak megtalálására.

Függvény: beam_search: Nyalábalapú keresést hajt végre megoldások keresésére egy Rubik-kocka környezetben.

MAX_BATCH_SIZE

Az egyszerre egy DNN-en előrecsatolással (forward-pass) feldolgozott állapotok maximális száma.

def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)

Nyalábalapú keresést (beam search) hajt végre egy adott kevert állapot megoldásainak megtalálására.

Argumentumok:

  • env Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.
  • model torch.nn.Module - A DNN, amely minden állapothoz megjósolja a következő lépések valószínűségi eloszlását.
  • beam_width int - A keresés minden lépésében megtartandó jelöltek maximális száma.
  • ergonomic_bias dict vagy None - Egy szótár, amely megadja a lépések ergonomikus torzítását, ha van ilyen.
  • extra_depths int - Az első megoldás mélységén túli további keresési mélységek száma.
  • max_depth int - A maximális keresési mélység, amelynek egyenlőnek vagy nagyobbnak kell lennie Isten számánál (a Rubik-kocka esetében ez 20 HTM-ben).

Visszatérési érték:

  • dict | None: Ha legalább egy megoldás található, egy szótár a következő kulcsokkal:
  1. "solutions": A keresés során talált optimális vagy közel optimális megoldások listája.
  2. "num_nodes": A keresés során kibontott csomópontok teljes száma.
  3. "time": A keresés befejezéséhez szükséges idő (másodpercben).

Ellenkező esetben None.

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

Inicializálja az ergonomikus torzítást, ha meg van adva.

Argumentumok:

  • ergonomic_bias dict vagy None - Egy szótár, amely megadja a lépések ergonomikus torzítását, ha van ilyen.
  • env Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.

Visszatérési érték:

  • ergonomic_bias numpy.ndarray - A lépések ergonomikus torzítása, ha van ilyen.
  • env Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

Megjósolja a következő lépések valószínűségi eloszlását minden állapothoz.

Argumentumok:

  • model torch.nn.Module - A DNN, amely minden állapothoz megjósolja a következő lépések valószínűségi eloszlását.
  • batch_x numpy.ndarray - Állapotok kötege (batch).
  • ergonomic_bias dict vagy None - Egy szótár, amely megadja a lépések ergonomikus torzítását, ha van ilyen.
  • env Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.

Visszatérési érték:

  • batch_logprob numpy.ndarray - Az egyes lépések logaritmikus valószínűsége minden állapothoz.
megjegyzés

Az Automatikus Kevert Pontosságú (Automatic Mixed Precision) inferencia valamilyen okból kifolyólag valamivel gyorsabb, mint az egyszerű félpontosságú (a model.half() használatával).

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

Kibővíti a jelölt útvonalakat a következő lépések megjósolt valószínűségeivel.

Argumentumok:

  • candidates dict - Egy szótár, amely a jelölt útvonalakat, a kumulatív valószínűségeket és az állapotokat tartalmazza.
  • batch_logprob numpy.ndarray - Az egyes lépések logaritmikus valószínűsége minden állapothoz.
  • env Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.
  • depth int - A keresés aktuális mélysége.
  • beam_width int - A keresés minden lépésében megtartandó jelöltek maximális száma.

Visszatérési érték:

  • candidates dict - A frissített szótár, amely a jelölt útvonalakat, a kumulatív valószínűségeket és az állapotokat tartalmazza.

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

Lekéri a korábbi lépések alapján metszendő jelöltek indexeit.

Argumentumok:

  • candidates_paths numpy.ndarray - A jelölt állapotok útvonalai.
  • allow_wide bool - Meghatározza, hogy engedélyezettek-e a széles (wide) mozdulatok.
  • depth int - A keresés aktuális mélysége.

Visszatérési érték:

  • prune_idx numpy.ndarray - A metszendő jelöltek indexei.
megjegyzés

A numba.jit használata valójában lelassítja ezt a függvényt.

_update_states

def _update_states(candidate_states, candidate_paths, env)

Frissíti az állapotokat a kibővített útvonalak alapján.

Argumentumok:

  • candidate_states numpy.ndarray - A jelölt állapotok állapotai.
  • candidate_paths numpy.ndarray - A jelölt állapotok útvonalai.
  • env Cube3 - A kevert állapotot reprezentáló Rubik-kocka környezet.

Visszatérési érték:

  • candidate_states numpy.ndarray - A jelölt állapotok frissített állapotai.

_map_state

def _map_state(candidate_states, target_ix, source_ix)

Matricacsere végrehajtása kötegszinten.

Argumentumok:

  • candidate_states numpy.ndarray - A jelölt állapotok állapotai.
  • target_ix numpy.ndarray - A matricacsere célindexei.
  • source_ix numpy.ndarray - A matricacsere forrásindexei.

Visszatérési érték:

  • candidate_states numpy.ndarray - A jelölt állapotok frissített állapotai.
megjegyzés

A numba.jit használata valójában lelassítja ezt a függvényt.