1. 教程概述 📚

MNIST 是经典的机器学习数据集,包含 60,000 张 28x28 的手写数字图像,常用于入门深度学习和神经网络。通过本教程,你将学习如何使用 TensorFlow Keras 实现一个简单的图像分类模型。

2. 步骤详解 🧱

2.1 导入必要的库

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt

🧠 TensorFlow Keras 提供了高级 API,简化了模型构建流程。

2.2 加载和预处理数据

# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist.load_data()
(x_train, y_train), (x_test, y_test) = mnist

# 数据归一化
x_train = x_train / 255.0
x_test = x_test / 255.0

📊 你可以通过 TensorFlow 数据集指南 了解更多数据处理技巧。

2.3 构建模型

model = models.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax')
])

🧱 模型结构如上图所示,包含全连接层和Dropout层以防止过拟合。

mnist_data

2.4 编译和训练模型

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

📈 训练完成后,模型将在测试集上达到约 98% 的准确率。

2.5 评估和预测

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"测试准确率: {test_acc}")

# 预测示例
predictions = model.predict(x_test)

🔍 你可以在 Keras 模型调优指南 中探索更复杂的预测方法。

3. 扩展学习 📈

4. 小贴士 📌

  • 使用 plt.imshow() 可以直观查看图像数据
  • 尝试调整 epochs 数量观察训练效果变化
  • 更多实战案例请查看 TensorFlow 教程合集
神经网络