This article was last updated on <span id="expire-date"></span> days ago, the information described in the article may be outdated.
Keras 中的TensorBoard
回调可以方便的用来记录模型训练过程中的各种参数变化,以在训练后对模型进行分析。在 Keras 的官方示例中,展示了如何将TensorBoard
回调用在自定义的 Keras 模型中。但是在实际使用中,如果你像我一样继承了写 PyTorch 的习惯,将整个 Keras 模型模块化成多个子模型的话,就会发现TensorBoard
回调在这个时候失效了。这个时候就需要对TensorBoard
回调的功能做增强,让它支持 Keras 模型中的子模型。
原理
通过分析TensorBoard
回调的代码可知,_log_weights
函数是用来执行权重记录操作的。
1 | def _log_weights(self, epoch): |
其中函数内第4行和第5行的两个 for 循环分别对模型中的 layers 和 layer 中的权重做循环。然而如果定义的 Keras 模型中包含子模块,那么获取到的layer
就是相应的子模型,此时下一步获取到的weight
就不是模型中真正用到的各种网络。因此需要视实际定义的 Keras 模型的结构,对该函数的功能进行修改。
示例
假设一个完整的Keras模型的结构如图
graph LR A[MainModel] --> B[SubModel1] A --> C[SubModel2] A --> D[SubModel3] B --> E(TimeDistributed) B --> F(TimeDistributed) C --> G(ConvLSTM2D) C --> H(ConvLSTM2D) D --> I(Conv2DTranspose) D --> J(Conv2DTranspose) E --> K(Conv2D) F --> L(Conv2D)
模型MainModel
包含三个子模型,SubModel1
、SubModel2
和SubModel3
,这三个子模型又别包含 Keras 中的一些网络层。为了改进_log_weights
函数的功能,我们可以新写一个类继承自TensorBoard
回调。
1 | from keras.callbacks import TensorBoard |
这个回调中仅重写了_log_weights
函数,以保证其他功能与原回调完全一致。不同的地方主要在以下几个地方:
- 在原本的第一层 for 循环内,额外添加了一层循环对子模型中的 layers 做循环,以正确拿到真正要监控参数的网络。
- 由于使用了
TimeDistributed
,于是添加了类型判断以从TimeDistributed
中取出包装的网络。 - 由于不同网络的
weight_name
均为kernel
或bias
,因此将 layer 的 name 作为前缀添加到histogram_weight_name
中,以正确区分不同网络的参数。
将该回调代替原本的TensorBoard
回调,即可正确记录MainModel
中各种网络的参数变化了。
Author: Syize
Permalink: https://blog.syize.cn/2025/09/02/tensorboard-support-custom-keras-model/
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Syizeのblog!
Comments