pytorch的dataset与dataloader解析

2023-05-08,,

整理一下pytorch获取的流程:

    创建Dataset对象
    创建DataLoader对象,装载有dataset对象
    循环DataLoader对象,DataLoader.__iter__返回的是DataLoaderIter对象
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for data in dataloader:
....

根据源码分析:torch.utils.data

1 - Dataset:

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
""" def __getitem__(self, index):
raise NotImplementedError def __len__(self):
raise NotImplementedError def __add__(self, other):
return ConcatDataset([self, other])

Dataset这是一个抽象类,不能实例化,需要重写类方法,关键点有两个:

__getitem__ 这个很重要,规定了如何读数据,比如常用的transform
__len__ 这个就是返回数据集的长度,比如:return len(self.data)

2 - DataLoader:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)

先看一下主要参数:

dataset:就是 torch.utils.data.Dataset 类的实例。也就是说为了使用 DataLoader 类,需要先定义一个 torch.utils.data.Dataset 类的实例。
batch_size:每一个批次需要加载的训练样本个数。
shuffle:如果设置为 True 表示训练样本数据会被随机打乱,默认值为 False。一般会设置为 True 。
sampler:自定义从数据集中取样本的策略,如果指定这个参数,那么 shuffle 必须为 False 。从源码中可以看到,如果指定了该参数,同时 shuffle 设定为 True,DataLoader 的 __init__ 函数就会抛出一个异常 。
batch_sampler:与 sampler 类似,但是一次只返回一个 batch 的 indices(索引),需要注意的是,一旦指定了这个参数,那么 batch_size,shuffle,sampler,drop_last 就不能再指定了。源码中同样做了限制。
num_workers:表示会使用多少个线程来加载训练数据;默认值为 0,表示数据加载直接在主线程中进行。
collate_fn:对每一个 batch 的数据做一些你想要的操作。一个例子,https://zhuanlan.zhihu.com/p/346332974
pin_memory:把数据转移到和 GPU 相关联的 CPU 内存,加速 GPU 载入数据的速度。
drop_last:比如你的batch_size设置为 32,而一个 epoch 只有 100 个样本;如果设置为 True,那么训练的时候后面的 4 个就被扔掉了。如果为 False(默认),那么会继续正常执行,只是最后的 batch_size 会小一点。
timeout:加载一个 batch 数据的超时时间。
worker_init_fn:指定每个数据加载线程的入口函数。

源码分析:

class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last') if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle') if batch_sampler is None:
if sampler is None:
if shuffle:
# dataset.__len__() 在 Sampler 中被使用。
# 目的是生成一个 长度为 len(dataset) 的 序列索引(随机的)。
sampler = RandomSampler(dataset)
else:
# dataset.__len__() 在 Sampler 中被使用。
# 目的是生成一个 长度为 len(dataset) 的 序列索引(顺序的)。
sampler = SequentialSampler(dataset)
# Sampler 是个迭代器,一次之只返回一个 索引
# BatchSampler 也是个迭代器,但是一次返回 batch_size 个 索引
batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler
self.batch_sampler = batch_sampler def __iter__(self):
return DataLoaderIter(self) def __len__(self):
return len(self.batch_sampler) 

可以发现__iter__返回的是DataLoaderIter

3 - DataLoaderIter

先看init初始化:

if self.num_workers > 0:
self.worker_init_fn = loader.worker_init_fn

# 定义了workers相同数量个Queue并放置在index_queues这个list中,
# 这些Queue与worker一一对应,用来给worker传递“工作内容”
self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]

# worker_queue_idx用于下一个工作的workre序号,主进程轮询使用不同workers
self.worker_queue_idx = 0

# 各个workre将自己所取得的数据传递给wokrker_result_queue,供主进程fetch
self.worker_result_queue = multiprocessing.SimpleQueue() # 记录当前时刻分配了多少个任务(可能有处于等待状态的任务)
self.batches_outstanding = 0
self.worker_pids_set = False
self.shutdown = False
# 发送出去数据的编号
self.send_idx = 0
# 接受到数据的编号
self.rcvd_idx = 0 # 缓存区
self.reorder_dict = {}
self.workers = [
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, self.index_queues[i],
self.worker_result_queue, self.collate_fn, base_seed + i,
self.worker_init_fn, i))
for i in range(self.num_workers)]
# 初始化相应的进程,目标函数为_worker_loop
# 参数:dataset(用于数据读取),index_queues[i]为worker对应的index_queue
# 以及用于输出的queue # 此处主要用于数据读取后的pin_memory操作,不影响多进程主逻辑,暂不展开
if self.pin_memory or self.timeout > 0:
...
else:
self.data_queue = self.worker_result_queue
for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
# 将父进程设置为守护进程,保证父进程结束后,worker进程也结束,必须设置在start之前
w.start() # 下面是一些系统信号处理逻辑,对这方面我还不太熟悉就不介绍了。
_update_worker_pids(id(self), tuple(w.pid for w in self.workers))
_set_SIGCHLD_handler()
self.worker_pids_set = True # 初始化后生成2*num_workers数量个prefetch的数据,使dataloader提前工作,提升整体效率。
# prime the prefetch loop
for _ in range(2 * self.num_workers):
self._put_indices()

init过程有两个函数,一个是worker_loop,另个是put_indices

a. 先看worker_loop:

def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
global _use_shared_memory
_use_shared_memory = True # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal happened again already.
# https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
_set_worker_signal_handlers() torch.set_num_threads(1)
random.seed(seed)
torch.manual_seed(seed) if init_fn is not None:
init_fn(worker_id) # 父进程状态监测
watchdog = ManagerWatchdog() # 死循环查询是否有任务传进来
while True:
try:
# 从index_queue获取相应数据
r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
except queue.Empty:
if watchdog.is_alive():
continue
else:
break
if r is None:
break
idx, batch_indices = r
try:
# 获得以后for循环进行读取数据读取,此处和单进程的工作原理是一样的
# 因此时间花费和batchsize数量呈线性关系
samples = collate_fn([dataset[i] for i in batch_indices])
# 经过collate_fn后变成torch.Tensor
except Exception:
# 异常处理
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
# 通过data_queue传回处理好的batch数据
data_queue.put((idx, samples))
# 显示删除中间变量,降低内存消耗
del samples

这里就是不停地轮询,从index_queues队列里获得索引,然后通过collate_fn函数和索引获取tensor,然后塞入data_queue

b. 再看put_indices

def _put_indices(self):
assert self.batches_outstanding < 2 * self.num_workers
# 默认设定是只允许分配2*num_workers个任务,保证内存等资源不被耗尽
indices = next(self.sample_iter, None)
# 从sample_iter中拿到dataset中下一轮次的索引,用于fetch数据
if indices is None:
return
self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
# 轮询选择worker,找到其对应的队列,向其中发送工作内容(数据编号,数据索引)
self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
# worker_queue_idx自增
self.batches_outstanding += 1
# 任务分配数+1
self.send_idx += 1
# 已发送任务总数+1(下批数据编号) 

这个就是把索引塞进队列index_queues

以上就是init,当for循环时,会调用next:

c. __next__返回一个batch

def __next__(self):
if self.num_workers == 0: # same-process loading (主进程阻塞式读取数据)
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch # check if the next sample has already been generated
# 先查看数据是否在缓存dict中
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
# 异常处理
if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
while True:
assert (not self.shutdown and self.batches_outstanding > 0)
# 阻塞式的从data_queue里面获取处理好的批数据
idx, batch = self._get_batch()
# 任务数减一
self.batches_outstanding -= 1
# 这一步可能会造成的周期阻塞现象
# 每次获取data以后,要校验和rcvd_idx是否一致
# 若不一致,则先把获取到的数据放到reorder_dict这个缓存dict中,继续死循环
# 直到获取到相应的idx编号于rcvd_idx可以对应上,并将数据返回
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
return self._process_next_batch(batch)

__next__里的while True,要从data_queue里面读到的数据idx和rcvd_idx一致才将数据返回。因此可能会存在如下这种情况:

假设num_workers=8,现在发送了8个数据给相应的worker,此时send_idx=8,rcvd_idx=0。过了一段时间以后,{1,2,3,5,6,7}进程数据准备完毕,此时主进程从data_queue读取到相关的数据,但由于和rcvd_idx不匹配,只能将其放在缓存里。直到send_idx=0数据准备齐以后,才能将数据返回出去,随后从缓存中弹出2,3的数据,之后又阻塞等待idx=4的数据。即输出的数据必须保持顺序性!因此在worker变多,出现这种逆序现象可能性会更大,这种现象也会出现在非num_workrers次迭代,只要相应的rcvd_idx没有得到相关数据,则主进程就会一直等待。

d. process_next_batch

def _process_next_batch(self, batch):
# 序号对上以后,rcvd_idx自加1
self.rcvd_idx += 1
# 添加一个fetchdata任务给worker
self._put_indices()
if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch

  

这个函数注意的是,只有在__next__中,idx == self.rcvd_idx时才会调用,也就是可能出现多个worker已经准备好了,但是只能放在缓存区,并且无法向index_queues塞入索引,使worker无法保持活跃状态。

最后对于for循环从dataloader获取data总体流程:

for epoch in range(num_epoches):
for data in dataloader:

对于这个for,其实就是调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter,如果是num_worker>0,init里就会创建多线程,并且有两个队列,一个是存放dataset的索引index_queues,一个是从index_queues里拿到索引,调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用), 然后用collate_fn来把它们打包成batch,放到data_queue队列里,反复调用DataLoaderIter 的__next__,从data_queue中获取batch。

参考:

Pytorch数据读取(Dataset, DataLoader, DataLoaderIter) https://zhuanlan.zhihu.com/p/30934236

PyTorch 之 Dataset 和 Dataloader https://zhuanlan.zhihu.com/p/339675188

PyTorch36.DataLoader源代码剖析 https://zhuanlan.zhihu.com/p/169497395

PyTorch DataLoader初探 https://zhuanlan.zhihu.com/p/91521705

一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系 https://zhuanlan.zhihu.com/p/76893455

pytorch的dataset与dataloader解析的相关教程结束。

《pytorch的dataset与dataloader解析.doc》

下载本文的Word格式文档,以方便收藏与打印。