Pular para o conteúdo principal

search

Este módulo fornece uma função para realizar uma busca em feixe (beam search) e encontrar soluções para um determinado estado.

Função: beam_search: Realiza uma busca em feixe para encontrar soluções em um ambiente de Cubo de Rubik.

MAX_BATCH_SIZE

O número máximo de estados a serem processados (forward-pass) por uma DNN de uma só vez.

def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)

Realiza uma busca em feixe para encontrar soluções para um determinado estado embaralhado.

Argumentos:

  • env Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.
  • model torch.nn.Module - A DNN usada para prever a distribuição de probabilidade dos próximos movimentos para cada estado.
  • beam_width int - O número máximo de candidatos a serem mantidos em cada passo da busca.
  • ergonomic_bias dict ou None - Um dicionário que especifica o viés ergonômico para os movimentos, se disponível.
  • extra_depths int - O número de profundidades adicionais a serem pesquisadas além da profundidade da primeira solução encontrada.
  • max_depth int - A profundidade máxima da busca, deve ser igual ou maior que o Número de Deus (20 para o Cubo de Rubik em HTM).

Retorna:

  • dict | None: Com pelo menos uma solução, um dicionário com as seguintes chaves:
  1. "solutions": Uma lista de soluções ótimas ou quase ótimas encontradas durante a busca.
  2. "num_nodes": O número total de nós expandidos durante a busca.
  3. "time": O tempo gasto (em segundos) para completar a busca.

Caso contrário, None.

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

Inicializa o viés ergonômico, se fornecido.

Argumentos:

  • ergonomic_bias dict ou None - Um dicionário que especifica o viés ergonômico para os movimentos, se disponível.
  • env Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.

Retorna:

  • ergonomic_bias numpy.ndarray - O viés ergonômico para os movimentos, se disponível.
  • env Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

Prevê a distribuição de probabilidade dos próximos movimentos para cada estado.

Argumentos:

  • model torch.nn.Module - A DNN usada para prever a distribuição de probabilidade dos próximos movimentos para cada estado.
  • batch_x numpy.ndarray - Lote (batch) de estados.
  • ergonomic_bias dict ou None - Um dicionário que especifica o viés ergonômico para os movimentos, se disponível.
  • env Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.

Retorna:

  • batch_logprob numpy.ndarray - O logaritmo da probabilidade de cada movimento para cada estado.
nota

A inferência com Precisão Mista Automática (Automatic Mixed Precision) é ligeiramente mais rápida do que a de meia precisão simples (com model.half()) por algumas razões.

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

Expande os caminhos candidatos com as probabilidades previstas dos próximos movimentos.

Argumentos:

  • candidates dict - Um dicionário contendo os caminhos candidatos, as probabilidades acumuladas e os estados.
  • batch_logprob numpy.ndarray - O logaritmo da probabilidade de cada movimento para cada estado.
  • env Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.
  • depth int - A profundidade atual da busca.
  • beam_width int - O número máximo de candidatos a serem mantidos em cada passo da busca.

Retorna:

  • candidates dict - O dicionário atualizado contendo os caminhos candidatos, as probabilidades acumuladas e os estados.

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

Obtém os índices dos candidatos a serem podados (prune) com base nos movimentos anteriores.

Argumentos:

  • candidates_paths numpy.ndarray - Os caminhos dos estados candidatos.
  • allow_wide bool - Se deve permitir movimentos largos (wide moves).
  • depth int - A profundidade atual da busca.

Retorna:

  • prune_idx numpy.ndarray - Os índices dos candidatos a serem podados.
nota

O uso de numba.jit na verdade torna esta função mais lenta.

_update_states

def _update_states(candidate_states, candidate_paths, env)

Atualiza os estados com base nos caminhos expandidos.

Argumentos:

  • candidate_states numpy.ndarray - Os estados dos estados candidatos.
  • candidate_paths numpy.ndarray - Os caminhos dos estados candidatos.
  • env Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.

Retorna:

  • candidate_states numpy.ndarray - Os estados atualizados dos estados candidatos.

_map_state

def _map_state(candidate_states, target_ix, source_ix)

Realiza a substituição de adesivos (stickers) em nível de lote (batch).

Argumentos:

  • candidate_states numpy.ndarray - Os estados dos estados candidatos.
  • target_ix numpy.ndarray - Os índices de destino para a substituição de adesivos.
  • source_ix numpy.ndarray - Os índices de origem para a substituição de adesivos.

Retorna:

  • candidate_states numpy.ndarray - Os estados atualizados dos estados candidatos.
nota

O uso de numba.jit na verdade torna esta função mais lenta.