logo

TensorFlow 模型保存和加载

王哲峰 / 2022-07-15


目录

tf.train.Checkpoint: 变量的保存与恢复

很多时候, 希望在模型训练完成后能将训练好的参数(变量)保存起来, 这样在需要使用模型的其他地方载入模型和参数, 就能直接得到训练好的模型, 保存模型有很多中方式:

import pickle

tf.train.Checkpoint 介绍

TensorFlow 提供了 tf.train.Checkpoint 这一强大的变量保存与恢复类, 提供的方法可以保存和恢复 TensorFlow 中的大部分对象, 比如下面类的实例都可以被保存:

# 保存训练好的模型, 先声明一个 Checkpoint
model = TrainedModel()
checkpoint = tf.train.Checkpoint(myAwesomeModel = model, myAwesomeOptimizer = optimizer)
checkpoint.save(save_path_with_prefix)

# 载入保存的训练模型
model_to_be_restored = MyModel()  # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel = model_to_be_restored)
checkpoint.restore(save_path_with_prefix_and_index)

# 为了载入最近的一个模型文件, 返回目录下最近一次检查点的文件名
tf.train.latest_checkpoint(save_path)

.. note::

tf.train.Checkpoint 代码框架

1.train.py 模型训练阶段


# 训练好的模型
model = MyModel()

# 实例化 Checkpoint, 指定保存对象为 model(如果需要保存 Optimizer 的参数也可以加入)
checkpoint = tf.train.Checkpoint(myModel = model)
manager = tf.train.CheckpointManager(checkpoint, directory = "./save", checkpoint_name = "model.ckpt", max_to_keep = 10)

# ...(模型训练代码)

# 模型训练完毕后将参数保存到文件(也可以在模型训练过程中每隔一段时间就保存一次)
if manager:
    manager.save(checkpoint_number = 100)
else:
    checkpoint.save("./save/model.ckpt")

2.test.py 模型使用阶段


# 要使用的模型
model = MyModel()

# 实例化 Checkpoint, 指定恢复对象为 model
checkpoint = tf.train.Checkpoint(myModel = model)

# 从文件恢复模型参数
checkpoint.restore(tf.train.latest_checkpoint("./save))

# ...(模型使用代码)

.. note::

使用 SaveModel 完整导出模型

作为模型导出格式的 SaveModel 包含了一个 TensorFlow 程序的完整信息: 不仅包含参数的权值, 还包含计算的流程(计算图)。 当模型导出为 SaveModel 文件时, 无须模型的源代码即可再次运行模型, 这使得 SaveModel 尤其适用于模型的分享和部署。

Keras 模型均可以方便地导出为 SaveModel 格式。不过需要注意的是, 因为 SaveModel 基于计算图, 所以对于通过继承 tf.keras.Model 类建立的 Keras 模型来说, 需要导出为 SaveModel 格式的方法(比如 call) 都需要 使用 @tf.function 修饰。

语法:

# 保存
tf.saved_model.save(model, "保存的目标文件夹名称")

# 载入
model = tf.saved_model.load("保存的目标文件夹名称")

示例:


Keras 自有的模型导出格式

示例:

curl -LO https://raw.githubcontent.com/keras-team/keras/master/examples/mnist_cnn.py
model.save("mnist_cnn.h5")

import keras

keras.models.load_model("mnist_cnn.h5")