跳到主要内容

search

This module provides a function to perform beam search and find solutions for a given state.

Function: beam_search: 在魔方环境中执行集束搜索以寻找解法。

MAX_BATCH_SIZE

一次通过 DNN 进行前向传播的最大状态数。

def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)

为给定的打乱状态执行集束搜索以寻找解法。

参数

  • env Cube3 - 代表打乱状态的魔方环境。
  • model torch.nn.Module - 用于预测每个状态下一步转动概率分布的 DNN。
  • beam_width int - 在搜索的每一步中保留的最大候选数量。
  • ergonomic_bias dict or None - 一个指定转动的人体工程学偏差的字典(如果可用)。
  • extra_depths int - 在找到第一个解法后,需要额外搜索的深度。
  • max_depth int - 搜索的最大深度,应等于或大于上帝之数(对于半圈计算法的魔方为 20)。

返回

  • dict | None: 如果找到至少一个解法,则返回一个包含以下键的字典:
  1. "solutions": 搜索过程中找到的最优或接近最优解法的列表。
  2. "num_nodes": 搜索过程中扩展的节点总数。
  3. "time": 完成搜索所花费的时间(秒)。

否则返回 None

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

如果提供了人体工程学偏差,则进行初始化。

参数

  • ergonomic_bias dict or None - 一个指定转动的人体工程学偏差的字典(如果可用)。
  • env Cube3 - 代表打乱状态的魔方环境。

返回

  • ergonomic_bias numpy.ndarray - 转动的人体工程学偏差(如果可用)。
  • env Cube3 - 代表打乱状态的魔方环境。

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

预测每个状态下一步转动的概率分布。

参数

  • model torch.nn.Module - 用于预测每个状态下一步转动概率分布的 DNN。
  • batch_x numpy.ndarray - 一批状态。
  • ergonomic_bias dict or None - 一个指定转动的人体工程学偏差的字典(如果可用)。
  • env Cube3 - 代表打乱状态的魔方环境。

返回

  • batch_logprob numpy.ndarray - 每个状态下每次转动的对数概率。
备注

由于某些原因,使用自动混合精度进行推理比简单的半精度(使用 model.half())要快一些。

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

使用预测的下一步转动概率来扩展候选路径。

参数

  • candidates dict - 一个包含候选路径、累积概率和状态的字典。
  • batch_logprob numpy.ndarray - 每个状态下每次转动的对数概率。
  • env Cube3 - 代表打乱状态的魔方环境。
  • depth int - 当前搜索的深度。
  • beam_width int - 在搜索的每一步中保留的最大候选数量。

返回

  • candidates dict - 更新后的字典,包含候选路径、累积概率和状态。

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

根据之前的转动获取要修剪的候选者的索引。

参数

  • candidates_paths numpy.ndarray - 候选状态的路径。
  • allow_wide bool - 是否允许宽转。
  • depth int - 当前搜索的深度。

返回

  • prune_idx numpy.ndarray - 要修剪的候选者的索引。
备注

使用 numba.jit 实际上会减慢这个函数的速度。

_update_states

def _update_states(candidate_states, candidate_paths, env)

根据扩展的路径更新状态。

参数

  • candidate_states numpy.ndarray - 候选状态的状态。
  • candidate_paths numpy.ndarray - 候选状态的路径。
  • env Cube3 - 代表打乱状态的魔方环境。

返回

  • candidate_states numpy.ndarray - 更新后的候选状态的状态。

_map_state

def _map_state(candidate_states, target_ix, source_ix)

在批处理级别执行贴纸替换。

参数

  • candidate_states numpy.ndarray - 候选状态的状态。
  • target_ix numpy.ndarray - 贴纸替换的目标索引。
  • source_ix numpy.ndarray - 贴纸替换的源索引。

返回

  • candidate_states numpy.ndarray - 更新后的候选状态的状态。
备注

使用 numba.jit 实际上会减慢这个函数的速度。