基础结构

import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_len, num_layers=1, num_epochs=10):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, 
                                          num_encoder_layers=num_encoder_layers, 
                                          num_decoder_layers=num_decoder_layers, 
                                          dim_feedforward=dim_feedforward)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.max_seq_len = max_seq_len
        self.num_epochs = num_epochs

    def forward(self, src, tgt):
        src_emb = self.embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        tgt_emb = self.embedding(tgt) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        out = self.transformer(src_emb, tgt_emb)
        return self.fc_out(out)

关键组件说明

  • 🧠 嵌入层:将输入转换为固定维度向量
  • 📌 位置编码:通过 torch.nn.Transformer 自动处理
  • 🔄 Transformer 核心:包含多头注意力机制与前馈网络
  • 📈 训练参数:设置 num_epochs 控制训练轮数

扩展学习

如需了解不同框架实现对比,可查看 code_examples/transformer/tensorflow 路径中的 TensorFlow 示例
如需了解训练技巧,可参考 code_examples/transformer/training 路径中的详细说明

Transformer_模型结构
PyTorch_Logo