This article was last updated on <span id="expire-date"></span> days ago, the information described in the article may be outdated.
翻译自斯坦福大学网站上的一篇文章,A detailed example of how to use data generators with Keras。主要介绍了在面对大量数据时,如何通过编写数据生成器(DataGenerator)来实现动态加载数据的功能。
目的
你是否有过由于训练数据量过大导致内存消耗太多,因此想找一个魔法般的功能来解决这个问题并且不影响已有的代码?在这个大数据的时代,”大数据”变得越来越重要。
有时候即使你拥有顶配的设备,设备的内存可能也无法存放下所有的训练数据。这个时候,我们就需要寻找其他的解决方法。本篇文章将展示如何借助 Python 中高度封装的深度学习库Keras
实现实时并发的数据生成功能,并将该功能嵌入到你的深度学习模型中。
前言
在阅读这篇文章之前,也许你的Keras
脚本代码结构大致如下
1 | import numpy as np |
这篇文章所涉及到的所有更改都是关于修改第5行加载数据的代码的。可以看出,原始代码存在的问题是有可能无法将所有训练使用的样本一次性全部加载进内存中。
为了达到最终的目的,让我们一步一步来介绍如何搭建一个数据生成器。顺便一提,后面的代码都被尽量编写的结构清晰明了,也许很适合用在你自己的项目中。你可以直接复制粘贴,并自己补全其中省略的部分。
声明
在开始之前,让我们做一些声明以防止在后文中出现有歧义的地方。
假定后面所说的ID
代表了Python中的一个字符串,该字符串对应了一个样本。并且为了清晰的分辨不同的样本以及它们的标签,数据被以下列的方式存储:
- 一个名叫
partition
的字典,其中包括partition['train']
:存储训练集及样本对应的ID
partition['validation']
:存储测试集样本对应的ID
- 一个名叫
labels
的字典,ID
作为其中的key
,ID
对应的样本的标签作为value
,可以通过labels[ID]
的形式获取到样本对应的标签。
举个例子,假设我们的训练集有三个样本,对应的ID
分别是'id-1', 'id-2', 'id-3'
,对应的标签分别是0, 1, 2
;测试集有一个样本,ID
为id-4
,标签为1
。这种情况下,partition
和labels
分别为
1 | >>> partition |
1 | labels |
另外,为了践行模块化的准则,后面的Keras
代码和自定义类会写在不同的Python文件中,文件树如下
1 | folder/ |
其中data/
是存放数据集的目录。
最后的最后,本文中的代码编写尽量保证了泛用性和最简化,因此你应该可以轻易的将其更改以适用于你自己的数据集。
Data generator
现在,我们真正开始编写Python类DataGenerator
,并将其用于Keras
模型中来实时提供训练数据。
首先我们需要编写类的初始化函数,我们将DataGenerator
设计为keras.utils.Sequence
的子类以便可以使用其某些便捷的特性例如多进程。
1 | def __init__(self, list_IDs, labels, batch_size=32, dim=(32,32,32), n_channels=1, |
初始化函数接受很多重要的参数,例如数据的ID
(list_IDs
)和对应的标签labels
,数据的维度大小dim
(例如三维、大小是32×32×32
的数据为dims=(32, 32)
),数据的通道数n_channels
(处理图像数据等多通道数据时很有用),数据的种类数n_classes
,训练使用的数据块(batch
)的大小batch_size
,以及生成数据时是否打乱数据shuffle
。
最后调用的方法on_epoch_end
会在类初始化以及每次训练周期(epoch
)结束时调用,用来生成取数据ID
使用的索引。如果shuffle=True
,则会打乱每次生成的索引。
1 | def on_epoch_end(self): |
打乱索引是为了在每次训练周期向模型喂数据时保证喂的数据是不那么相似的,能够保证最后训练出来的模型更加健壮。
另外一个关键的函数是__data_generation
,这个函数的工作便是生成每次训练使用的batch
,它接受一个参数list_IDs_temp
,这个列表包含了用于生成batch
的样本的ID
。
1 | def __data_generation(self, list_IDs_temp): |
这个函数会实时的读取对应样本的npy
文件。由于我们的代码对多进程友好,因此你可以在这里做更多更复杂的工作而无需担心数据的生成成为训练过程中的瓶颈(感觉瓶颈变成磁盘IO了)。
这里我们还使用了Keras
的keras.utils.to_categorical
函数来把存储在y
中的数字标签转化为了适用于分类的二进制的形式(举个例子,对于6分类的问题,标签3
对应的形式为001000
)。
下面我们将上面提到的三个部分整合到一起我们首先需要实现获取训练时当前batch
的索引的功能(即这是所有batch
中的第几个batch
)。这个功能我们通过魔术方法__len__
实现。
1 | def __len__(self): |
通常这个长度的计算公式如下
这样在每次训练周期模型最多遍历所有batch
一次。
当模型利用index
取对应的batch
时,DataGenerator
需要可以返回对应的batch
,这个功能通过魔术方法__getitem__
实现。
1 | def __getitem__(self, index): |
将上面的代码整合如下
1 | import numpy as np |
Keras script
现在我们需要修改Keras
代码以使其可以使用我们定义好的生成器。
1 | import numpy as np |
如你所见,我们使用了model.fit_generator
方法,该方法可以接受我们构造的生成器,接下来的事情Keras
会帮我们搞定!
另外需要注意的是,我们将use_multiprocessing
设置为True
来启用了多进程的特性,使用的进程数在workers
那里进行了指定。如果你的workers
设置的足够高,那么你的CPU就会努力的工作,这样训练的瓶颈就只会在GPU对神经网络的误差进行传播的过程中(而不是数据的生成)。(还有磁盘IO啦)
结论
没了,就这些。现在运行你的Keras
脚本,你就会发现你的CPU和GPU同时在工作。你还可以到GitHub上看一看将此生成器应用到具体个例上的代码,data generation和Keras脚本都在哟。
Author: Syize
Permalink: https://blog.syize.cn/2023/08/04/generators-in-keras/
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Syizeのblog!
Comments