📌 What is a GRU?
A GRU (Gated Recurrent Unit) is a simplified variant of the RNN that captures sequence dependencies through a reset gate and an update gate. Compared with an LSTM, a GRU has fewer parameters and trains faster, and it is commonly used in natural language processing and time-series forecasting.
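For reference, these are the GRU update equations in the convention used by PyTorch's nn.GRU documentation, where $r_t$ is the reset gate, $z_t$ the update gate, $\sigma$ the sigmoid function, and $\odot$ element-wise multiplication:

$$
\begin{aligned}
r_t &= \sigma(W_{ir}x_t + b_{ir} + W_{hr}h_{t-1} + b_{hr})\\
z_t &= \sigma(W_{iz}x_t + b_{iz} + W_{hz}h_{t-1} + b_{hz})\\
n_t &= \tanh\big(W_{in}x_t + b_{in} + r_t \odot (W_{hn}h_{t-1} + b_{hn})\big)\\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
$$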
🧠 Code Example: Implementing a GRU with PyTorch
1. Import the libraries
```python
import torch
import torch.nn as nn
```
2. Define the GRU model
```python
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True expects inputs shaped (batch_size, sequence_length, input_size)
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Initial hidden state: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.gru(x, h0)
        # Take the output of the last time step for classification
        out = self.fc(out[:, -1, :])
        return out
```
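As a quick sanity check (a minimal sketch with arbitrary sizes), you can push a random batch through the model and confirm the output shape is (batch_size, num_classes):

```python
model = GRUModel(input_size=10, hidden_size=20, num_layers=2, num_classes=2)
x = torch.randn(4, 15, 10)  # (batch_size=4, sequence_length=15, input_size=10)
print(model(x).shape)       # torch.Size([4, 2])
```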
3. Training loop
```python
model = GRUModel(input_size=10, hidden_size=20, num_layers=2, num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Dummy data so the loop runs end to end; replace with a real DataLoader in practice
inputs = torch.randn(32, 15, 10)     # (batch_size, sequence_length, input_size)
labels = torch.randint(0, 2, (32,))  # class indices, as expected by CrossEntropyLoss

for epoch in range(100):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
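After training, inference is typically done in eval mode with gradients disabled. A minimal sketch, where test_inputs is a hypothetical held-out batch:

```python
model.eval()
with torch.no_grad():
    test_inputs = torch.randn(8, 15, 10)            # hypothetical test batch
    predictions = model(test_inputs).argmax(dim=1)  # predicted class per sample
print(predictions)
```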
⚠️ Notes
- With batch_first=True, the GRU input shape should be (batch_size, sequence_length, input_size), as shown in the sketch below.
- Model performance can often be improved by tuning hidden_size and num_layers.
- GRUs are well suited to shorter sequences; for long sequences, consider an LSTM or a Transformer instead.
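To illustrate the shape convention (a minimal sketch): with batch_first=True the input is (batch, seq, feature), whereas PyTorch's default layout is (seq, batch, feature).

```python
gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
x = torch.randn(4, 15, 10)  # (batch_size=4, sequence_length=15, input_size=10)
out, h_n = gru(x)
print(out.shape)  # torch.Size([4, 15, 20]) -- hidden states for all time steps
print(h_n.shape)  # torch.Size([2, 4, 20])  -- final hidden state of each layer
```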