📌 什么是 GRU?

GRU(Gated Recurrent Unit)是一种简化版的 RNN,通过重置门更新门机制捕捉序列依赖关系。相比 LSTM,GRU 参数更少,训练更快,常用于自然语言处理和时间序列预测。

Gated_Recurrent_Unit

🧠 代码示例:使用 PyTorch 实现 GRU

1. 导入库

import torch
import torch.nn as nn

2. 定义 GRU 模型

class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.gru(x, h0)
        out = self.fc(out[:, -1, :])  # 取最后一个时间步的输出
        return out

3. 训练循环

model = GRUModel(input_size=10, hidden_size=20, num_layers=2, num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

⚠️ 注意事项

  • GRU 的输入形状应为 (batch_size, sequence_length, input_size)
  • 可通过调整 hidden_sizenum_layers 改善模型性能
  • 适合处理较短序列数据,长序列建议使用 LSTM 或 Transformer

📄 延伸阅读

想深入了解 LSTM?点击这里查看对比教程!
PyTorch 官方文档中关于 GRU 的详细说明:GRU 官方文档