TensorFlow 2 中文文档 - 保存与加载模型
TensorFlow2 文档系列文章链接:
- TensorFlow 2 / 2.0 中文文档 (Jul 9, 2019)
- TensorFlow 2 中文文档 - MNIST 图像分类 (Jul 9, 2019)
- TensorFlow 2 中文文档 - IMDB 文本分类 (Jul 9, 2019)
- TensorFlow 2 中文文档 - 特征工程结构化数据分类 (Jul 9, 2019)
- TensorFlow 2 中文文档 - 回归预测燃油效率 (Jul 11, 2019)
- TensorFlow 2 中文文档 - 过拟合与欠拟合 (Jul 12, 2019)
- TensorFlow 2 中文文档 - 保存与加载模型 (Jul 13, 2019)
- TensorFlow 2 中文文档 - 卷积神经网络分类 CIFAR-10 (Jul 19, 2019)
- TensorFlow 2 中文文档 - TFHub 迁移学习 (Jul 19, 2019)
- TensorFlow 2 中文文档 - RNN LSTM 文本分类 (Jul 22, 2019)
源代码/数据集已上传到 Github - tensorflow2-docs-zh
TF2.0 TensorFlow 2 / 2.0 中文文档:保存与加载模型 Save and Restore model
主要内容:使用 tf.keras
接口训练、保存、加载模型,数据集选用 MNIST 。
1 | $ pip install -q tensorflow==2.0.0-beta1 |
准备训练数据
1 | import tensorflow as tf |
搭建模型
1 | def create_model(): |
自动保存 checkpoints
这样做,一是训练结束后得到了训练好的模型,使用得不必再重新训练,二是训练过程被中断,可以从断点处继续训练。
设置tf.keras.callbacks.ModelCheckpoint
回调可以实现这一点。
1 | # 存储模型的文件名,语法与 str.format 一致 |
1 | Epoch 00010: saving model to training_2/cp-0010.ckpt |
加载权重:
1 | latest = tf.train.latest_checkpoint(checkpoint_dir) |
1 | 1000/1000 [===] - 0s 90us/sample - loss: 0.4703 - accuracy: 0.8780 |
手动保存权重
1 | # 手动保存权重 |
1 | 1000/1000 [===] - 0s 90us/sample - loss: 0.4703 - accuracy: 0.8780 |
保存整个模型
上面的示例仅仅保存了模型中的权重(weights),模型和优化器都可以一起保存,包括权重(weights)、模型配置(architecture)和优化器配置(optimizer configuration)。这样做的好处是,当你恢复模型时,完全不依赖于原来搭建模型的代码。
保存完整的模型有很多应用场景,比如在浏览器中使用 TensorFlow.js 加载运行,比如在移动设备上使用 TensorFlow Lite 加载运行。
HDF5
直接调用model.save
即可保存为 HDF5 格式的文件。
1 | model.save('my_model.h5') |
从 HDF5 中恢复完整的模型。
1 | new_model = models.load_model('my_model.h5') |
1 | 1000/1000 [===] - 0s 90us/sample - loss: 0.4703 - accuracy: 0.8780 |
saved_model
保存为saved_model
格式。
1 | import time |
恢复模型并预测
1 | new_model = tf.keras.experimental.load_from_saved_model(saved_model_path) |
1 | (1000, 10) |
saved_model
格式的模型可以直接用来预测(predict),但是 saved_model 没有保存优化器配置,如果要使用evaluate
方法,则需要先 compile。
1 | new_model.compile(optimizer=model.optimizer, |
1 | 1000/1000 [===] - 0s 90us/sample - loss: 0.4703 - accuracy: 0.8780 |
最后
TensorFlow 中还有其他的方式可以保存模型。
- Saving in eager eager 模型保存模型
- Save and Restore – low-level 的接口。
返回文档首页
完整代码:Github - save_restore_model.ipynb
参考文档:Save and restore models
附 推荐
- 一篇文章入门 Python
专题: TensorFlow2 文档
本文发表于 2019-07-13,最后修改于 2022-09-10。
本站永久域名「 geektutu.com 」,也可搜索「 极客兔兔 」找到我。
上一篇 « TensorFlow 2 中文文档 - 过拟合与欠拟合 下一篇 » TensorFlow 2 中文文档 - 卷积神经网络分类 CIFAR-10