search
Este módulo fornece uma função para realizar uma busca em feixe (beam search) e encontrar soluções para um determinado estado.
Função:
beam_search
: Realiza uma busca em feixe para encontrar soluções em um ambiente de Cubo de Rubik.
MAX_BATCH_SIZE
O número máximo de estados a serem processados (forward-pass) por uma DNN de uma só vez.
beam_search
def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)
Realiza uma busca em feixe para encontrar soluções para um determinado estado embaralhado.
Argumentos:
env
Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.model
torch.nn.Module - A DNN usada para prever a distribuição de probabilidade dos próximos movimentos para cada estado.beam_width
int - O número máximo de candidatos a serem mantidos em cada passo da busca.ergonomic_bias
dict ou None - Um dicionário que especifica o viés ergonômico para os movimentos, se disponível.extra_depths
int - O número de profundidades adicionais a serem pesquisadas além da profundidade da primeira solução encontrada.max_depth
int - A profundidade máxima da busca, deve ser igual ou maior que o Número de Deus (20 para o Cubo de Rubik em HTM).
Retorna:
dict
|None
: Com pelo menos uma solução, um dicionário com as seguintes chaves:
"solutions"
: Uma lista de soluções ótimas ou quase ótimas encontradas durante a busca."num_nodes"
: O número total de nós expandidos durante a busca."time"
: O tempo gasto (em segundos) para completar a busca.
Caso contrário, None
.
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
Inicializa o viés ergonômico, se fornecido.
Argumentos:
ergonomic_bias
dict ou None - Um dicionário que especifica o viés ergonômico para os movimentos, se disponível.env
Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.
Retorna:
ergonomic_bias
numpy.ndarray - O viés ergonômico para os movimentos, se disponível.env
Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
Prevê a distribuição de probabilidade dos próximos movimentos para cada estado.
Argumentos:
model
torch.nn.Module - A DNN usada para prever a distribuição de probabilidade dos próximos movimentos para cada estado.batch_x
numpy.ndarray - Lote (batch) de estados.ergonomic_bias
dict ou None - Um dicionário que especifica o viés ergonômico para os movimentos, se disponível.env
Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.
Retorna:
batch_logprob
numpy.ndarray - O logaritmo da probabilidade de cada movimento para cada estado.
A inferência com Precisão Mista Automática (Automatic Mixed Precision) é ligeiramente mais rápida do que
a de meia precisão simples (com model.half()
) por algumas razões.
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
Expande os caminhos candidatos com as probabilidades previstas dos próximos movimentos.
Argumentos:
candidates
dict - Um dicionário contendo os caminhos candidatos, as probabilidades acumuladas e os estados.batch_logprob
numpy.ndarray - O logaritmo da probabilidade de cada movimento para cada estado.env
Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.depth
int - A profundidade atual da busca.beam_width
int - O número máximo de candidatos a serem mantidos em cada passo da busca.
Retorna:
candidates
dict - O dicionário atualizado contendo os caminhos candidatos, as probabilidades acumuladas e os estados.
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
Obtém os índices dos candidatos a serem podados (prune) com base nos movimentos anteriores.
Argumentos:
candidates_paths
numpy.ndarray - Os caminhos dos estados candidatos.allow_wide
bool - Se deve permitir movimentos largos (wide moves).depth
int - A profundidade atual da busca.
Retorna:
prune_idx
numpy.ndarray - Os índices dos candidatos a serem podados.
O uso de numba.jit
na verdade torna esta função mais lenta.
_update_states
def _update_states(candidate_states, candidate_paths, env)
Atualiza os estados com base nos caminhos expandidos.
Argumentos:
candidate_states
numpy.ndarray - Os estados dos estados candidatos.candidate_paths
numpy.ndarray - Os caminhos dos estados candidatos.env
Cube3 - O ambiente do Cubo de Rubik que representa o estado embaralhado.
Retorna:
candidate_states
numpy.ndarray - Os estados atualizados dos estados candidatos.
_map_state
def _map_state(candidate_states, target_ix, source_ix)
Realiza a substituição de adesivos (stickers) em nível de lote (batch).
Argumentos:
candidate_states
numpy.ndarray - Os estados dos estados candidatos.target_ix
numpy.ndarray - Os índices de destino para a substituição de adesivos.source_ix
numpy.ndarray - Os índices de origem para a substituição de adesivos.
Retorna:
candidate_states
numpy.ndarray - Os estados atualizados dos estados candidatos.
O uso de numba.jit
na verdade torna esta função mais lenta.