The attention mechanism is a key technique in deep learning. The following are common implementations:
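All three variants below compute scaled dot-product attention, Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, where d_k is the dimension of the keys (here, embed_size).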
1. Basic Self-Attention implementation
import math

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        # Linear projections producing queries, keys, and values.
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        # x: (batch, seq_len, embed_size)
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V.
        # math.sqrt is used because torch.sqrt cannot take an integer tensor.
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(x.size(-1))
        attention_weights = torch.softmax(attention_scores, dim=-1)
        return torch.matmul(attention_weights, V)
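A minimal shape check, assuming a batch of 2 sequences of length 10 with embed_size 64 (all arbitrary illustration values):

attn = SelfAttention(embed_size=64)
x = torch.randn(2, 10, 64)   # (batch, seq_len, embed_size), arbitrary sizes
out = attn(x)
print(out.shape)             # torch.Size([2, 10, 64])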
2. Transformer block example
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, num_heads, ff_size):
        super().__init__()
        # Single-head attention for simplicity; a full block would plug in the
        # MultiHeadAttention from section 3, which is what num_heads is for.
        self.attention = SelfAttention(embed_size)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, ff_size),
            nn.ReLU(),
            nn.Linear(ff_size, embed_size)
        )

    def forward(self, x):
        # Residual connection plus layer normalization around each sub-layer.
        x = self.norm1(x + self.attention(x))
        x = self.norm2(x + self.feed_forward(x))
        return x
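A quick sanity check of the block, again with arbitrary sizes; ff_size = 4 * embed_size follows common practice but is not required:

block = TransformerBlock(embed_size=64, num_heads=1, ff_size=256)
x = torch.randn(2, 10, 64)
print(block(x).shape)        # torch.Size([2, 10, 64]); the block preserves the input shape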
3. Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        # Simplified variant: every head runs over the full embedding width;
        # a standard implementation would split embed_size across the heads.
        self.heads = nn.ModuleList([SelfAttention(embed_size) for _ in range(num_heads)])
        # Project the concatenated head outputs back to embed_size.
        self.fc_out = nn.Linear(num_heads * embed_size, embed_size)

    def forward(self, x):
        return self.fc_out(torch.cat([head(x) for head in self.heads], dim=-1))
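With the final projection, the multi-head module maps the concatenated head outputs back to the input width. A sketch with assumed sizes:

mha = MultiHeadAttention(embed_size=64, num_heads=8)
x = torch.randn(2, 10, 64)
print(mha(x).shape)          # torch.Size([2, 10, 64]) after the output projection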