search
这个模块提供了一个函数来执行 beam search 并为给定状态找到解法。
函数:
beam_search
:在魔方环境中执行 beam search 以找到解法。
MAX_BATCH_SIZE
一次通过 DNN 前向传递的最大状态数。
beam_search
def beam_search(env, model, beam_width, ergonomic_bias=None, extra_depths=0, max_depth=100)
执行 beam search 以找到给定打乱状态的解法。
参数:
env
Cube3 - 表示打乱状态的魔方环境。model
torch.nn.Module - 用于预测每个状态下一步动作概率分布的 DNN。beam_width
int - 搜索过程中每一步保留的最大候选数。ergonomic_bias
dict or None - 指定动作人体工程学偏好的字典,如果有的话。extra_depths
int - 在找到第一个解法深度之后额外搜索的深度数。max_depth
int - 搜索的最大深度,应该等于或大于上帝数(HTM 表示法下的魔方为 20)。
返回:
dict
|None
: 如果至少找到一个解法,返回一个包含以下键的字典:
"solutions"
:搜索过程中找到的最优或近似最优解法列表。"num_nodes"
:搜索过程中扩展的节点总数。"time"
:完成搜索所花费的时间(以秒为单位)。
否则,返回 None
。
_reflect_setup
def _reflect_setup(ergonomic_bias, env)
如果提供了人体工程学偏好,则进行初始化。
参数:
ergonomic_bias
dict or None - 指定动作人体工程学偏好的字典,如果有的话。env
Cube3 - 表示打乱状态的魔方环境。
返回:
ergonomic_bias
numpy.ndarray - 动作的人体工程学偏好,如果有的话。env
Cube3 - 表示打乱状态的魔方环境。
predict
@torch.inference_mode()
def predict(model, batch_x, ergonomic_bias, env)
预测每个状态下一步动作的概率分布。
参数:
model
torch.nn.Module - 用于预测每个状态下一步动作概率分布的 DNN。batch_x
numpy.ndarray - 状态批次。ergonomic_bias
dict or None - 指定动作人体工程学偏好 的字典,如果有的话。env
Cube3 - 表示打乱状态的魔方环境。
返回:
batch_logprob
numpy.ndarray - 每个状态下每个动作的对数概率。
备注
出于某些原因,使用 Automatic Mixed Prevision 进行推理比简单的半精度(使用 model.half()
)略快。
update_candidates
def update_candidates(candidates, batch_logprob, env, depth, beam_width)
使用预测的下一步动作概率扩展候选路径。
参数:
candidates
dict - 包含候选路径、累积概率和状态的字典。batch_logprob
numpy.ndarray - 每个状态下每个动作的对数概率。env
Cube3 - 表示打乱状态的魔方环境。depth
int - 当前搜索深度。beam_width
int - 搜索过程中每一步保留的最大候选数。
返回:
candidates
dict - 更新后包含候选路 径、累积概率和状态的字典。
_get_prune_idx
def _get_prune_idx(candidates_paths, allow_wide, depth)
根据先前的动作获取要剪枝的候选索引。
参数:
candidates_paths
numpy.ndarray - 候选状态的路径。allow_wide
bool - 是否允许宽动作。depth
int - 当前搜索深度。
返回:
prune_idx
numpy.ndarray - 要剪枝的候选索引。
备注
使用 numba.jit
实际上会减慢这个函数的速度。
_update_states
def _update_states(candidate_states, candidate_paths, env)
根据扩展的路径更新状态。
参数:
candidate_states
numpy.ndarray - 候选状态的状态。candidate_paths
numpy.ndarray - 候选状态的路径。env
Cube3 - 表示打乱状态的魔方环境。
返回:
candidate_states
numpy.ndarray - 更新后的候选状态的状态。
_map_state
def _map_state(candidate_states, target_ix, source_ix)
在批次级别执行贴纸替换。
参数:
candidate_states
numpy.ndarray - 候选状态的状态。target_ix
numpy.ndarray - 贴纸替换的目标索引。source_ix
numpy.ndarray - 贴纸替换的源索引。
返回:
candidate_states
numpy.ndarray - 更新后的候选状态的状态。
备注
使用 numba.jit
实际上会减慢这个函数的速度。