什么是 RNN?

RNN(Recurrent Neural Network)是一种专为序列数据设计的神经网络,通过时间步(Time Step)机制捕捉数据间的时序依赖关系。其核心结构包含:

  • 隐藏层(Hidden Layer):存储序列状态
  • 循环连接:允许信息传递
  • 遗忘门:控制信息保留
循环神经网络结构

RNN 的典型应用场景 ✅

  1. 自然语言处理

    • 情感分析(如:自然语言处理_情感分析)
    • 文本生成(如:聊天机器人、摘要工具)
    • 机器翻译(如:序列到序列模型)
  2. 时间序列预测

    • 股票价格预测(如:时间序列预测)
    • 天气预测
  3. 语音识别

    • 音频到文本转换
自然语言处理_情感分析

快速入门示例 📈

import torch
from torch import nn

class RNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)
        output, (hidden, cell) = self.lstm(embedded)
        return self.fc(output)

扩展学习 🔍

时间序列预测