在深度学习项目中,保存和加载模型是至关重要的步骤。Keras 提供了简单而强大的接口来保存和加载模型,无论是为了部署还是进行进一步的训练。

保存模型

保存模型可以通过 save() 方法实现。以下是一个简单的示例:

from keras.models import Sequential
from keras.layers import Dense

# 创建一个简单的模型
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# 编译模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# 假设我们有一些数据
# X_train, y_train = ...

# 训练模型
# model.fit(X_train, y_train, epochs=10, batch_size=10)

# 保存模型
model.save('/path/to/my_model.h5')

在上面的代码中,model.save() 方法将整个模型保存到一个 .h5 文件中。

加载模型

加载模型同样简单,使用 load_model() 方法即可:

from keras.models import load_model

# 加载模型
loaded_model = load_model('/path/to/my_model.h5')

# 使用加载的模型进行预测
# predictions = loaded_model.predict(X_test)

高级保存选项

Keras 允许你保存模型的不同部分,例如仅保存模型架构或仅保存训练状态:

# 保存模型架构
model.save('/path/to/my_model_architecture.json', save_weights_only=False)

# 加载模型架构
from keras.models import model_from_json
model_architecture = model_from_json(open('/path/to/my_model_architecture.json').read())

# 保存训练状态
model.save_weights('/path/to/my_model_weights.h5')

# 加载训练状态
model.load_weights('/path/to/my_model_weights.h5')

注意事项

  • 保存的 .h5 文件包含了模型的配置、权重和训练历史。
  • 当你加载一个模型时,确保你的环境中安装了所有必要的库,以避免运行时错误。

更多信息,请访问我们的模型保存与加载教程

神经网络