1. 教程概述 📚
MNIST 是经典的机器学习数据集,包含 60,000 张 28x28 的手写数字图像,常用于入门深度学习和神经网络。通过本教程,你将学习如何使用 TensorFlow Keras 实现一个简单的图像分类模型。
2. 步骤详解 🧱
2.1 导入必要的库
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
🧠 TensorFlow Keras 提供了高级 API,简化了模型构建流程。
2.2 加载和预处理数据
# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist.load_data()
(x_train, y_train), (x_test, y_test) = mnist
# 数据归一化
x_train = x_train / 255.0
x_test = x_test / 255.0
📊 你可以通过 TensorFlow 数据集指南 了解更多数据处理技巧。
2.3 构建模型
model = models.Sequential([
layers.Flatten(input_shape=(28, 28)),
layers.Dense(128, activation='relu'),
layers.Dropout(0.2),
layers.Dense(10, activation='softmax')
])
🧱 模型结构如上图所示,包含全连接层和Dropout层以防止过拟合。
2.4 编译和训练模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
📈 训练完成后,模型将在测试集上达到约 98% 的准确率。
2.5 评估和预测
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"测试准确率: {test_acc}")
# 预测示例
predictions = model.predict(x_test)
🔍 你可以在 Keras 模型调优指南 中探索更复杂的预测方法。
3. 扩展学习 📈
- TensorFlow 官方 MNIST 教程(英文)
- 如何用 Keras 进行图像分类(中文)
- 深度学习模型可视化工具(中文)
4. 小贴士 📌
- 使用
plt.imshow()
可以直观查看图像数据 - 尝试调整
epochs
数量观察训练效果变化 - 更多实战案例请查看 TensorFlow 教程合集