跳到主要内容

search

这个模块提供了一个函数来执行 beam search 并为给定状态找到解法。

函数: beam_search:在魔方环境中执行 beam search 以找到解法。

MAX_BATCH_SIZE

一次通过 DNN 前向传递的最大状态数。

def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)

执行 beam search 以找到给定打乱状态的解法。

参数

  • env Cube3 - 表示打乱状态的魔方环境。
  • model torch.nn.Module - 用于预测每个状态下一步动作概率分布的 DNN。
  • beam_width int - 搜索过程中每一步保留的最大候选数。
  • ergonomic_bias dict or None - 指定动作人体工程学偏好的字典,如果有的话。
  • extra_depths int - 在找到第一个解法深度之后额外搜索的深度数。
  • max_depth int - 搜索的最大深度,应该等于或大于上帝数(HTM 表示法下的魔方为 20)。

返回

  • dict | None: 如果至少找到一个解法,返回一个包含以下键的字典:
  1. "solutions":搜索过程中找到的最优或近似最优解法列表。
  2. "num_nodes":搜索过程中扩展的节点总数。
  3. "time":完成搜索所花费的时间(以秒为单位)。

否则,返回 None

_reflect_setup

def _reflect_setup(ergonomic_bias, env)

如果提供了人体工程学偏好,则进行初始化。

参数

  • ergonomic_bias dict or None - 指定动作人体工程学偏好的字典,如果有的话。
  • env Cube3 - 表示打乱状态的魔方环境。

返回

  • ergonomic_bias numpy.ndarray - 动作的人体工程学偏好,如果有的话。
  • env Cube3 - 表示打乱状态的魔方环境。

predict

@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)

预测每个状态下一步动作的概率分布。

参数

  • model torch.nn.Module - 用于预测每个状态下一步动作概率分布的 DNN。
  • batch_x numpy.ndarray - 状态批次。
  • ergonomic_bias dict or None - 指定动作人体工程学偏好的字典,如果有的话。
  • env Cube3 - 表示打乱状态的魔方环境。

返回

  • batch_logprob numpy.ndarray - 每个状态下每个动作的对数概率。
备注

出于某些原因,使用 Automatic Mixed Prevision 进行推理比简单的半精度(使用 model.half())略快。

update_candidates

def update_candidates(candidates, batch_logprob, env, depth, beam_width)

使用预测的下一步动作概率扩展候选路径。

参数

  • candidates dict - 包含候选路径、累积概率和状态的字典。
  • batch_logprob numpy.ndarray - 每个状态下每个动作的对数概率。
  • env Cube3 - 表示打乱状态的魔方环境。
  • depth int - 当前搜索深度。
  • beam_width int - 搜索过程中每一步保留的最大候选数。

返回

  • candidates dict - 更新后包含候选路径、累积概率和状态的字典。

_get_prune_idx

def _get_prune_idx(candidates_paths, allow_wide, depth)

根据先前的动作获取要剪枝的候选索引。

参数

  • candidates_paths numpy.ndarray - 候选状态的路径。
  • allow_wide bool - 是否允许宽动作。
  • depth int - 当前搜索深度。

返回

  • prune_idx numpy.ndarray - 要剪枝的候选索引。
备注

使用 numba.jit 实际上会减慢这个函数的速度。

_update_states

def _update_states(candidate_states, candidate_paths, env)

根据扩展的路径更新状态。

参数

  • candidate_states numpy.ndarray - 候选状态的状态。
  • candidate_paths numpy.ndarray - 候选状态的路径。
  • env Cube3 - 表示打乱状态的魔方环境。

返回

  • candidate_states numpy.ndarray - 更新后的候选状态的状态。

_map_state

def _map_state(candidate_states, target_ix, source_ix)

在批次级别执行贴纸替换。

参数

  • candidate_states numpy.ndarray - 候选状态的状态。
  • target_ix numpy.ndarray - 贴纸替换的目标索引。
  • source_ix numpy.ndarray - 贴纸替换的源索引。

返回

  • candidate_states numpy.ndarray - 更新后的候选状态的状态。
备注

使用 numba.jit 实际上会减慢这个函数的速度。