search
This module provides a function to perform beam search and find solutions for a given state.
Function:
beam_search
: 在魔方环境中执行集束搜索以寻找解法。
MAX_BATCH_SIZE
一次通过 DNN 进行前向传播的最大状态数。
beam_search
def beam_search(
env,
model,
beam_width,
ergonomic_bias=None,
extra_depths=0,
max_depth=100
)
为给定的打乱状态执行集束搜索以寻找解法。
参数:
env
Cube3 - 代表打乱状态的魔方环境。model
torch.nn.Module - 用于预测每个状态下一步转动概率分布的 DNN。beam_width
int - 在搜索的每一步中保留的最大候选数量。ergonomic_bias
dict or None - 一个指定转动的人体工程学偏差的字典(如果可用)。extra_depths
int - 在找到第一个解法后,需要额外搜索的深度。max_depth
int - 搜索的最大深度,应等于或大于上帝之数(对于半圈计算法的魔方为 20)。
返回:
dict
|None
: 如果找到至少一个解法,则返回一个包含以下键的字典:
"solutions"
: 搜索过程中找到的最优或接近最优解法的列表。"num_nodes"
: 搜索过程中扩展的节点总数。"time"
: 完成搜索所花费的时间(秒)。
否则返回 None
。
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
如果提供了人体工程学偏差,则进行初始化。
参数:
ergonomic_bias
dict or None - 一个指定转动的人体工程学偏差的字典(如果可用)。env
Cube3 - 代表打乱状态的魔方环境。
返回:
ergonomic_bias
numpy.ndarray - 转动的人体工程学偏差(如果可用)。env
Cube3 - 代表打乱状态的魔方环境。
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
预测每个状态下一步转动的概率分布。
参数:
model
torch.nn.Module - 用于预测每个状态下一步转动概率分布的 DNN。batch_x
numpy.ndarray - 一批状态。ergonomic_bias
dict or None - 一个指定转动的人体工程学偏差的字典(如果可用)。env
Cube3 - 代表打乱状态的魔方环境。
返回:
batch_logprob
numpy.ndarray - 每个状态下每次转动的对数概率。
备注
由于某些原因,使用自动混合精度进行推理比简单的半精度(使用 model.half()
)要快一些。
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
使用预测的下一步转动概率来扩展候选路径。
参数:
candidates
dict - 一个包含候选路径、累积概率和状态的字典。batch_logprob
numpy.ndarray - 每个状态下每次转动的对数概率。env
Cube3 - 代表打乱状态的魔方环境。depth
int - 当前搜索的深度。beam_width
int - 在搜索的每一步中保留的最大候选数量。
返回:
candidates
dict - 更新后的字典,包含候选路径、累积概率和状态。
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
根据之前的转动获取要修剪的候选者的索引。
参数:
candidates_paths
numpy.ndarray - 候选状态的路径。allow_wide
bool - 是否允许宽转。depth
int - 当前搜索的深度。
返回:
prune_idx
numpy.ndarray - 要修剪的候选者的索引。
备注
使用 numba.jit
实际上会减慢这个函数的速度。
_update_states
def _update_states(candidate_states, candidate_paths, env)
根据扩展的路径更新状态。
参数:
candidate_states
numpy.ndarray - 候选状态的状态。candidate_paths
numpy.ndarray - 候选状态的路径。env
Cube3 - 代表打乱状态的魔方环境。
返回:
candidate_states
numpy.ndarray - 更新后的候选状态的状态。
_map_state
def _map_state(candidate_states, target_ix, source_ix)
在批处理级别执行贴纸替换。
参数:
candidate_states
numpy.ndarray - 候选状态的状态。target_ix
numpy.ndarray - 贴纸替换的目标索引。source_ix
numpy.ndarray - 贴纸替换的源索引。
返回:
candidate_states
numpy.ndarray - 更新后的候选状态的状态。
备注
使用 numba.jit
实际上会减慢这个函数的速度。