search
Este módulo proporciona una función para realizar una búsqueda por haz (beam search) y encontrar soluciones para un estado dado.
Función:
beam_search: Realiza una búsqueda por haz para encontrar soluciones en un entorno de Cubo de Rubik.
MAX_BATCH_SIZE
El número máximo de estados que pasan a través de una DNN en un solo paso hacia adelante (forward-pass).
beam_search
def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)
Realiza una búsqueda por haz para encontrar soluciones para un estado desordenado dado.
Argumentos:
envCube3 - El entorno del Cubo de Rubik que representa el estado desordenado.modeltorch.nn.Module - La DNN utilizada para predecir la distribución de probabilidad de los siguientes movimientos para cada estado.beam_widthint - El número máximo de candidatos a mantener en cada paso de la búsqueda.ergonomic_biasdict o None - Un diccionario que especifica el sesgo ergonómico para los movimientos, si está disponible.extra_depthsint - El número de profundidades adicionales a buscar más allá de la profundidad de la primera solución.max_depthint - La profundidad máxima de búsqueda, debe ser igual o mayor que el Número de Dios (20 para el Cubo de Rubik en HTM).
Retorna:
dict|None: Si se encuentra al menos una solución, un diccionario con las siguientes claves:
"solutions": Una lista de soluciones óptimas o casi óptimas encontradas durante la búsqueda."num_nodes": El número total de nodos expandidos durante la búsqueda."time": El tiempo empleado (en segundos) para completar la búsqueda.
En caso contrario, None.
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
Inicializa el sesgo ergonómico si se proporciona.
Argumentos:
ergonomic_biasdict o None - Un diccionario que especifica el sesgo ergonómico para los movimientos, si está disponible.envCube3 - El entorno del Cubo de Rubik que representa el estado desordenado.
Retorna:
ergonomic_biasnumpy.ndarray - El sesgo ergonómico para los movimientos, si está disponible.envCube3 - El entorno del Cubo de Rubik que representa el estado desordenado.
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
Predice la distribución de probabilidad de los siguientes movimientos para cada estado.
Argumentos:
modeltorch.nn.Module - La DNN utilizada para predecir la distribución de probabilidad de los siguientes movimientos para cada estado.batch_xnumpy.ndarray - Lote de estados.ergonomic_biasdict o None - Un diccionario que especifica el sesgo ergonómico para los movimientos, si está disponible.envCube3 - El entorno del Cubo de Rubik que representa el estado desordenado.
Retorna:
batch_logprobnumpy.ndarray - El logaritmo de la probabilidad de cada movimiento para cada estado.
La inferencia con Precisión Mixta Automática es ligeramente más rápida que
la de precisión media simple (con model.half()) por algunas razones.
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
Expande las rutas candidatas con las probabilidades predichas de los siguientes movimientos.
Argumentos:
candidatesdict - Un diccionario que contiene las rutas candidatas, las probabilidades acumuladas y los estados.batch_logprobnumpy.ndarray - El logaritmo de la probabilidad de cada movimiento para cada estado.envCube3 - El entorno del Cubo de Rubik que representa el estado desordenado.depthint - La profundidad actual de la búsqueda.beam_widthint - El número máximo de candidatos a mantener en cada paso de la búsqueda.
Retorna:
candidatesdict - El diccionario actualizado que contiene las rutas candidatas, las probabilidades acumuladas y los estados.
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
Obtiene los índices de los candidatos a podar basándose en los movimientos anteriores.
Argumentos:
candidates_pathsnumpy.ndarray - Las rutas de los estados candidatos.allow_widebool - Si se permiten movimientos anchos (wide moves).depthint - La profundidad actual de la búsqueda.
Retorna:
prune_idxnumpy.ndarray - Los índices de los candidatos a podar.
Usar numba.jit en realidad ralentiza esta función.
_update_states
def _update_states(candidate_states, candidate_paths, env)
Actualiza los estados basándose en las rutas expandidas.
Argumentos:
candidate_statesnumpy.ndarray - Los estados de los estados candidatos.candidate_pathsnumpy.ndarray - Las rutas de los estados candidatos.envCube3 - El entorno del Cubo de Rubik que representa el estado desordenado.
Retorna:
candidate_statesnumpy.ndarray - Los estados actualizados de los estados candidatos.
_map_state
def _map_state(candidate_states, target_ix, source_ix)
Realiza el reemplazo de pegatinas a nivel de lote.
Argumentos:
candidate_statesnumpy.ndarray - Los estados de los estados candidatos.target_ixnumpy.ndarray - Los índices de destino para el reemplazo de pegatinas.source_ixnumpy.ndarray - Los índices de origen para el reemplazo de pegatinas.
Retorna:
candidate_statesnumpy.ndarray - Los estados actualizados de los estados candidatos.
Usar numba.jit en realidad ralentiza esta función.