Pular para o conteúdo principal

model

Carregador de Modelo

Este módulo fornece uma função e classes para carregar e usar um modelo de solucionador AlphaCube pré-treinado.

Função: load_model(model_id, cache_dir): Carrega o modelo de solucionador AlphaCube pré-treinado.

Classes:

  • Model: A arquitetura MLP para o solucionador AlphaCube.
  • LinearBlock: Um bloco de construção para a arquitetura MLP.

load_model

def load_model(
model_id="small",
cache_dir=os.path.expanduser("~/.cache/alphacube")
)

Carrega o modelo de solucionador AlphaCube pré-treinado.

Argumentos:

  • model_id str - Identificador para a variante do modelo a ser carregada ("small", "base" ou "large").
  • cache_dir str - Diretório para armazenar em cache o modelo baixado.

Retorna:

  • nn.Module - Modelo de solucionador AlphaCube carregado.

Model_v1

class Model_v1(nn.Module)

A arquitetura MLP para o solucionador de Cubo de Rubik computacionalmente ótimo introduzido no seguinte artigo: https://openreview.net/forum?id=bnBeNFB27b

forward

def forward(inputs)

Passagem direta (forward pass) do modelo.

Argumentos:

  • inputs torch.Tensor - Tensor de entrada representando o estado do problema.

Retorna:

  • torch.Tensor - Distribuição prevista sobre as possíveis soluções.

Model

class Model(nn.Module)

Uma arquitetura melhor que a Model.

Mudanças:

  • Remove a ativação ReLU da primeira camada (embedding), que tinha o problema do ReLU moribundo (dying ReLU).
  • Seguindo a convenção recente, a camada embedding não conta como uma camada oculta.

reset_parameters

def reset_parameters()

Inicializa todos os pesos de forma que as variâncias de ativação sejam aproximadamente 1.0, com os termos de bias e os pesos da camada de saída sendo zeros

forward

def forward(inputs)

Passagem direta (forward pass) do modelo.

Argumentos:

  • inputs torch.Tensor - Tensor de entrada representando o estado do problema.

Retorna:

  • torch.Tensor - Distribuição prevista sobre as possíveis soluções.

LinearBlock

class LinearBlock(nn.Module)

Um bloco de construção para a arquitetura MLP.

Este bloco consiste em uma camada linear seguida por ativação ReLU e normalização de lote (batch normalization).

forward

def forward(inputs)

Passagem direta (forward pass) do bloco linear.

Argumentos:

  • inputs torch.Tensor - Tensor de entrada.

Retorna:

  • torch.Tensor - Tensor de saída após transformação linear, ativação ReLU e normalização de lote.