这个页面将为你介绍如何使用 TensorFlow 进行简单的 MNIST 数据集分类。MNIST 是一个手写数字的数据库,包含了 0 到 9 的手写数字图片。

MNIST 数据集简介

MNIST 数据集包含 60,000 个训练样本和 10,000 个测试样本。每个样本都是一个 28x28 的灰度图像,代表一个手写数字。

安装 TensorFlow

在开始之前,请确保你已经安装了 TensorFlow。你可以通过以下命令安装:

pip install tensorflow

示例代码

以下是一个简单的 MNIST 分类示例代码:

import tensorflow as tf

# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 归一化像素值
x_train, x_test = x_train / 255.0, x_test / 255.0

# 创建模型
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

# 编译模型
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# 训练模型
model.fit(x_train, y_train, epochs=5)

# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print('\nTest accuracy:', test_acc)

扩展阅读

如果你对 TensorFlow 感兴趣,可以阅读以下教程:

图片示例

下面是一个 MNIST 数据集中的示例图片:

MNIST sample image