MNIST 是一个经典的手写数字识别数据集,常用于入门机器学习和深度学习模型训练。通过本教程,你将学习如何使用 TensorFlow 构建并训练一个简单的神经网络来识别 MNIST 数据。
📚 教程目标
- 使用 TensorFlow 加载 MNIST 数据集
- 构建一个全连接神经网络模型
- 训练并评估模型性能
- 保存模型并进行预测
🧰 步骤指南
1. 导入依赖
import tensorflow as tf
from tensorflow.keras import layers, models
2. 加载数据
# 加载 MNIST 数据集(已内置在 TensorFlow 中)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
3. 数据预处理
# 归一化数据到 [0,1] 范围
x_train = x_train / 255.0
x_test = x_test / 255.0
# 展开维度以适配模型输入
x_train = x_train.reshape(-1, 28*28)
x_test = x_test.reshape(-1, 28*28)
4. 构建模型
model = models.Sequential([
layers.Dense(512, activation='relu', input_shape=(784,)),
layers.Dropout(0.2),
layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
5. 训练模型
model.fit(x_train, y_train, epochs=5)
6. 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"测试准确率: {test_acc:.4f}")
7. 保存与预测
model.save('mnist_model.h5')
# 使用模型进行预测
predictions = model.predict(x_test)
📚 扩展学习
想要深入了解 TensorFlow 的其他用法?可以参考 TensorFlow 入门指南 获取更多基础知识。
对于图像分类任务,也可以尝试 PyTorch 实现 MNIST 进行对比学习。
📌 小贴士
- 使用
tf.keras
可以快速构建模型 - Dropout 层有助于防止过拟合
- 适当调整
epochs
数量可优化训练效果