model
Carregador de Modelo
Este módulo fornece uma função e classes para carregar e usar um modelo de solucionador AlphaCube pré-treinado.
Função:
load_model(model_id, cache_dir)
: Carrega o modelo de solucionador AlphaCube pré-treinado.
Classes:
Model
: A arquitetura MLP para o solucionador AlphaCube.LinearBlock
: Um bloco de construção para a arquitetura MLP.
load_model
def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)
Carrega o modelo de solucionador AlphaCube pré-treinado.
Argumentos:
model_id
str - Identificador para a variante do modelo a ser carregada ("small", "base" ou "large").cache_dir
str - Diretório para armazenar em cache o modelo baixado.
Retorna:
nn.Module
- Modelo de solucionador AlphaCube carregado.
Model_v1
class Model_v1(nn.Module)
A arquitetura MLP para o solucionador de Cubo de Rubik computacionalmente ótimo introduzido no seguinte artigo: https://openreview.net/forum?id=bnBeNFB27b
forward
def forward(inputs)
Passagem direta (forward pass) do modelo.
Argumentos:
inputs
torch.Tensor - Tensor de entrada representando o estado do problema.
Retorna:
torch.Tensor
- Distribuição prevista sobre as possíveis soluções.
Model
class Model(nn.Module)
Uma arquitetura melhor que a Model
.
Mudanças:
- Remove a ativação ReLU da primeira camada (
embedding
), que tinha o problema do ReLU moribundo (dying ReLU). - Seguindo a convenção recente, a camada
embedding
não conta como uma camada oculta.
reset_parameters
def reset_parameters()
Inicializa todos os pesos de forma que as variâncias de ativação sejam aproximadamente 1.0, com os termos de bias e os pesos da camada de saída sendo zeros
forward
def forward(inputs)
Passagem direta (forward pass) do modelo.
Argumentos:
inputs
torch.Tensor - Tensor de entrada representando o estado do problema.
Retorna:
torch.Tensor
- Distribuição prevista sobre as possíveis soluções.
LinearBlock
class LinearBlock(nn.Module)
Um bloco de construção para a arquitetura MLP.
Este bloco consiste em uma camada linear seguida por ativação ReLU e normalização de lote (batch normalization).
forward
def forward(inputs)
Passagem direta (forward pass) do bloco linear.
Argumentos:
inputs
torch.Tensor - Tensor de entrada.
Retorna:
torch.Tensor
- Tensor de saída após transformação linear, ativação ReLU e normalização de lote.