MNIST 是一个经典的手写数字识别数据集,常用于机器学习入门。本教程将使用 TensorFlow 实现一个简单的 CNN 模型来识别 MNIST 图像。📚
步骤概览 🛠️
- 环境准备:安装 TensorFlow 和相关依赖
- 数据加载:从 Keras 数据库加载 MNIST 数据集
- 模型构建:设计卷积神经网络结构
- 训练与评估:训练模型并测试准确率
- 保存模型:导出训练好的模型以备后续使用
代码示例 ✅
import tensorflow as tf
from tensorflow.keras import layers, models
# 加载 MNIST 数据集 📁
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 构建模型 🏗️
model = models.Sequential([
layers.Flatten(input_shape=(28, 28)),
layers.Dense(128, activation='relu'),
layers.Dropout(0.2),
layers.Dense(10)
])
# 编译模型 🔄
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 训练模型 🚀
model.fit(x_train, y_train, epochs=5)
# 评估模型 📊
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"Test accuracy: {test_acc}")