The attention mechanism is a key technique in deep learning. Below are common implementations:

1. Basic Self-Attention Implementation

import math

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        # Linear projections that map the input to queries, keys, and values
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        # x: (batch_size, seq_len, embed_size)
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(x.size(-1))
        attention_weights = torch.softmax(attention_scores, dim=-1)
        return torch.matmul(attention_weights, V)
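For reference, the forward pass above computes scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V, with d_k equal to the embedding size. A minimal usage sketch follows; the batch size, sequence length, and embedding size are arbitrary illustration values, not taken from the original text:

# Hypothetical shapes chosen only for illustration
batch_size, seq_len, embed_size = 2, 5, 64
x = torch.randn(batch_size, seq_len, embed_size)

attention = SelfAttention(embed_size)
out = attention(x)
print(out.shape)  # torch.Size([2, 5, 64]) -- same shape as the input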

2. Transformer Block Example

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, num_heads, ff_size):
        super().__init__()
        # num_heads is kept in the signature for API parity, but this
        # simplified block uses the single-head SelfAttention defined above
        self.attention = SelfAttention(embed_size)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, ff_size),
            nn.ReLU(),
            nn.Linear(ff_size, embed_size)
        )

    def forward(self, x):
        # Residual connection + layer norm around the attention sub-layer
        x = self.norm1(x + self.attention(x))
        # Residual connection + layer norm around the feed-forward sub-layer
        x = self.norm2(x + self.feed_forward(x))
        return x
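A quick shape check for the block above, again with hypothetical hyperparameters (the 4x ratio between ff_size and embed_size is a common convention, assumed here only for illustration):

# Hypothetical hyperparameters chosen only for illustration
embed_size, num_heads, ff_size = 64, 8, 256
block = TransformerBlock(embed_size, num_heads, ff_size)

x = torch.randn(2, 5, embed_size)  # (batch, seq_len, embed_size)
out = block(x)
print(out.shape)  # torch.Size([2, 5, 64]) -- shape is preserved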


3. Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        # Each head is an independent full-size SelfAttention module
        self.heads = nn.ModuleList([SelfAttention(embed_size) for _ in range(num_heads)])
        # Project the concatenated head outputs back to embed_size
        self.fc_out = nn.Linear(num_heads * embed_size, embed_size)

    def forward(self, x):
        # Concatenate per-head outputs along the feature dimension, then project
        return self.fc_out(torch.cat([head(x) for head in self.heads], dim=-1))
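A usage sketch with hypothetical dimensions:

# Hypothetical dimensions chosen only for illustration
embed_size, num_heads = 64, 8
mha = MultiHeadAttention(embed_size, num_heads)

x = torch.randn(2, 5, embed_size)
out = mha(x)
print(out.shape)  # torch.Size([2, 5, 64]) -- projected back to embed_size

Note that the standard formulation (and PyTorch's built-in nn.MultiheadAttention) splits embed_size into num_heads sub-spaces of size embed_size // num_heads instead of running num_heads full-size attention modules; the version above trades efficiency for readability.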
