让 TensorBoard 支持具有自定义子类的 Keras 模型

Python Keras TensorBoard
Article Directory
  1. 1. 原理
  2. 2. 示例

This article was last updated on <span id="expire-date"></span> days ago, the information described in the article may be outdated.

Keras 中的TensorBoard回调可以方便的用来记录模型训练过程中的各种参数变化,以在训练后对模型进行分析。在 Keras 的官方示例中,展示了如何将TensorBoard回调用在自定义的 Keras 模型中。但是在实际使用中,如果你像我一样继承了写 PyTorch 的习惯,将整个 Keras 模型模块化成多个子模型的话,就会发现TensorBoard回调在这个时候失效了。这个时候就需要对TensorBoard回调的功能做增强,让它支持 Keras 模型中的子模型。

原理

通过分析TensorBoard回调的代码可知,_log_weights函数是用来执行权重记录操作的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def _log_weights(self, epoch):
"""Logs the weights of the Model to TensorBoard."""
with self._train_writer.as_default():
for layer in self.model.layers:
for weight in layer.weights:
weight_name = weight.name.replace(":", "_")
# Add a suffix to prevent summary tag name collision.
histogram_weight_name = f"{weight_name}/histogram"
self.summary.histogram(
histogram_weight_name, weight, step=epoch
)
if self.write_images:
# Add a suffix to prevent summary tag name
# collision.
image_weight_name = f"{weight_name}/image"
self._log_weight_as_image(
weight, image_weight_name, epoch
)
self._train_writer.flush()

其中函数内第4行和第5行的两个 for 循环分别对模型中的 layers 和 layer 中的权重做循环。然而如果定义的 Keras 模型中包含子模块,那么获取到的layer就是相应的子模型,此时下一步获取到的weight就不是模型中真正用到的各种网络。因此需要视实际定义的 Keras 模型的结构,对该函数的功能进行修改。

示例

假设一个完整的Keras模型的结构如图

graph LR
    A[MainModel] --> B[SubModel1]
    A --> C[SubModel2]
    A --> D[SubModel3]
    B --> E(TimeDistributed)
    B --> F(TimeDistributed)
    C --> G(ConvLSTM2D)
    C --> H(ConvLSTM2D)
    D --> I(Conv2DTranspose)
    D --> J(Conv2DTranspose)
    E --> K(Conv2D)
    F --> L(Conv2D)

模型MainModel包含三个子模型,SubModel1SubModel2SubModel3,这三个子模型又别包含 Keras 中的一些网络层。为了改进_log_weights函数的功能,我们可以新写一个类继承自TensorBoard回调。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from keras.callbacks import TensorBoard
from keras.layers import TimeDistributed


class CustomModelSupportedTensorBoard(TensorBoard):

def _log_weights(self, epoch):
with self._train_writer.as_default():
for layer in self.model.layers: # type: ignore

# Add another loop to record the corresponding layers
for sublayer in layer.layers:

# Check if the layer is TimeDistributed layer
if isinstance(sublayer, TimeDistributed):
_sublayer = sublayer._layers[0]
else:
_sublayer = sublayer

prefix_name = _sublayer.name

for weight in _sublayer.weights:
weight_name = weight.name.replace(":", "_")
# Add a suffix to prevent summary tag name collision.
histogram_weight_name = f"{prefix_name}/{weight_name}/histogram"
self.summary.histogram(
histogram_weight_name, weight, step=epoch
)
if self.write_images:
# Add a suffix to prevent summary tag name
# collision.
image_weight_name = f"{weight_name}/image"
self._log_weight_as_image(
weight, image_weight_name, epoch
)
self._train_writer.flush()

这个回调中仅重写了_log_weights函数,以保证其他功能与原回调完全一致。不同的地方主要在以下几个地方:

  1. 在原本的第一层 for 循环内,额外添加了一层循环对子模型中的 layers 做循环,以正确拿到真正要监控参数的网络。
  2. 由于使用了TimeDistributed,于是添加了类型判断以从TimeDistributed中取出包装的网络。
  3. 由于不同网络的weight_name均为kernelbias,因此将 layer 的 name 作为前缀添加到histogram_weight_name中,以正确区分不同网络的参数。

将该回调代替原本的TensorBoard回调,即可正确记录MainModel中各种网络的参数变化了。

Author: Syize

Permalink: https://blog.syize.cn/2025/09/02/tensorboard-support-custom-keras-model/

本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Syizeのblog

Comments