MNIST 是机器学习领域经典的入门数据集,包含 28x28 的手写数字图像。以下将用 TensorFlow 实现一个简单的分类模型,适合初学者快速上手 🚀
步骤概览 📝
数据加载 📁
使用tf.keras.datasets.mnist
加载数据集,包含 60,000 张训练图和 10,000 张测试图模型构建 🛠️
创建一个包含Flatten
、Dense
层的全连接网络训练与评估 📈
使用交叉熵损失函数和 Adam 优化器,准确率可达 98%+
代码示例 🧪
import tensorflow as tf
# 加载数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# 数据预处理
x_train = x_train.reshape(-1, 28*28) / 255.0
x_test = x_test.reshape(-1, 28*28) / 255.0
# 构建模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=5)
# 评估模型
model.evaluate(x_test, y_test)