MNIST 项目是一个手写数字识别的数据集,它包含了 60,000 个训练样本和 10,000 个测试样本。这些样本都是 28x28 像素的灰度图像,其中每个图像都对应一个 0 到 9 的数字。

项目结构

以下是一个简单的 MNIST 项目结构:

  • data/: 存放 MNIST 数据集的文件夹
  • models/: 存放训练好的模型
  • src/: 项目源代码
    • __init__.py
    • data_loader.py: 数据加载相关代码
    • model.py: 模型定义相关代码
    • train.py: 训练模型相关代码
    • test.py: 测试模型相关代码

安装依赖

在开始之前,请确保已安装以下依赖:

  • Python 3.x
  • TensorFlow
  • NumPy

您可以使用以下命令安装 TensorFlow:

pip install tensorflow

数据加载

data_loader.py 文件中,我们定义了一个名为 load_data 的函数,用于加载数据集:

import numpy as np

def load_data():
    # 加载训练数据
    x_train, y_train = np.load('data/train_data.npy'), np.load('data/train_labels.npy')
    # 加载测试数据
    x_test, y_test = np.load('data/test_data.npy'), np.load('data/test_labels.npy')
    return (x_train, y_train), (x_test, y_test)

模型定义

model.py 文件中,我们定义了一个简单的卷积神经网络:

import tensorflow as tf

def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

训练模型

train.py 文件中,我们定义了一个名为 train_model 的函数,用于训练模型:

import tensorflow as tf
from model import create_model

def train_model():
    (x_train, y_train), (x_test, y_test) = load_data()
    model = create_model()
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))
    model.save('models/mnist_model.h5')

测试模型

test.py 文件中,我们定义了一个名为 test_model 的函数,用于测试模型:

import tensorflow as tf
from model import create_model

def test_model():
    (x_train, y_train), (x_test, y_test) = load_data()
    model = create_model()
    model.load_weights('models/mnist_model.h5')
    predictions = model.predict(x_test)
    print('Test accuracy:', np.mean(predictions.argmax(axis=1) == y_test))

扩展阅读

更多关于 MNIST 项目的信息,请访问我们的 MNIST 项目页面

中心图片:MNIST Sample Image{center}