torch.utils.data

译者:shuziP

校验:shuziP

PyTorch数据加载程序的核心是 torch.utils.data.DataLoader 类。它表示在数据集上可迭代的Python,并支持

这些选项是由DataLoader的构造函数参数配置的,具有签名:


​ DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, ​ batch_sampler=None, num_workers=0, collate_fn=None, ​ pin_memory=False, drop_last=False, timeout=0, ​ worker_init_fn=None)

下面几节将详细描述这些选项的功能和用法。

数据集类型

DataLoader构造函数最重要的参数是dataset,它表示要从中加载数据的dataset对象。PyTorch支持两种不同类型的数据集:

版图式数据集

版图式数据集实现了 __getitem__()__len__() 协议,并表示从(可能不是完整的)索引/键到数据样本的映射。

例如,当使用 dataset[idx]访问这样的数据集时,可以从磁盘上的文件夹中读取 idx-th i图像及其对应的标签。

参见Dataset了解更多细节。

可迭代式的数据集

可迭代式数据集的 一个子类的实例IterableDataset实现了 __iter__() 协议和代表了数据样本可迭代。这种类型的数据集特别适合这样的情况:随机读取非常高代价,甚至是不可能的,并且批大小取决于获取的数据。

例如,这样的数据集在被访问 iter(dataset),可以返回从数据库、远程服务器甚至实时生成的日志读取的数据流。

参见 IterableDataset了解更多详情。

注意

当使用 multi-process data loading. 的IterableDataset 时。在每个工作进程上复制相同的数据集对象,因此必须对副本进行不同的配置,以避免重复数据。有关如何实现此目的,请参见IterableDataset 文档。

数据加载顺序和采样器

对于iterable风格的数据集,数据加载顺序完全由用户定义的iterable控制。这允许更容易地实现块读取和动态批处理大小(例如,每次生成一个批处理样例)。

本节的其余部分涉及map-style datasetstorch.utils.data.Sampler 类用于指定数据加载中使用的索引/键的顺序。它们表示数据集索引上的可迭代对象。例如,在随机梯度像样(SGD)的常见情况下,一个 Sampler 可以随机排列一个索引列表,并一次产生一个,或产生一小部分用于小型批量SGD的索引。

顺序采样器或打乱采样器将根据 DataLoader 的’ shuffle ‘参数自动构建。或者,用户可以使用’sampler’参数来指定一个自定义的Sampler 对象,该对象每次都会生成下一个要获取的索引/键。

一个自定义的 Sampler ,一次生成一批索引的列表,可以作为’ batch_sampler ‘参数传递。自动批处理也可以通过“batch_size”和“drop_last”参数启用。参见下一节 获得更多的细节。

请注意

“sampler”和“batch_sampler”都与迭代式数据集不兼容,因为这样的数据集没有键或索引的概念。

加载批处理和非批处理数据

DataLoader支持通过参数batch_size、drop_last和batch_sampler将单个获取的数据样本自动整理成批。

自动批处理(默认)

这是最常见的情况,它对应于获取少量数据并将其整理成成批的样本,即,包含一个维度为批处理维度(通常是第一个维度)的张量。

当“batch_size”(默认为“1”)不是“None”时,数据加载器将生成成批的样本,而不是单个样本。“batch_size”和“drop_last”参数用于指定数据加载器如何获取批量数据集键。对于地图样式的数据集,用户也可以指定“batch_sampler”,它一次生成一个键列表。

请注意

“batch_size”和“drop_last”参数主要用于从“sampler”构造“batch_sampler”。对于地图样式的数据集,“采样器”要么由用户提供,要么基于“shuffle”参数构造。对于迭代式数据集,“采样器”是一个虚拟的无限数据集。有关采样器的更多信息,请参见本节

请注意

当从具有多个处理的迭代式数据集中获取数据时,drop_last参数将删除每个工作区的数据集副本的最后一批未完成的数据。

使用来自sampler的索引获取样本列表之后,作为collate_fn参数传递的函数被用来将样本列表整理成批量。

在这种情况下,从一个地图样式的数据集加载大致相当于:

  1. for indices in batch_sampler:
  2. yield collate_fn([dataset[i] for i in indices])

和从一个迭代式数据集加载大致相当于:

  1. dataset_iter = iter(dataset)
  2. for indices in batch_sampler:
  3. yield collate_fn([next(dataset_iter) for _ in indices])

自定义 collate_fn可用于自定义排序规则,例如,将顺序数据填充到批处理的最大长度。参见本节 了解更多关于 collate_fn.的信息。

禁用自动批处理

在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者只加载单个示例。例如,直接加载成批数据(例如,从数据库中批量读取数据或读取连续的内存块),或者批量大小依赖于数据,或者程序设计用于处理单个样本,这样做的成本更低。在这些场景下,最好不要使用自动批处理(其中使用’ collate_fn ‘对样本进行排序),而是让数据加载器直接返回’ dataset ‘对象的每个成员。

当“batch_size”和“batch_sampler”都是“None”(batch_sampler的默认值已经是“None”)时,自动批处理将被禁用。从’ dataset ‘获得的每个样例都使用作为’ collate_fn ‘参数传递的函数进行处理。

当自动批处理被禁用时,默认的’ collate_fn ‘只是将NumPy数组转换为PyTorch张量,而不改变其他内容。

In this case, loading from a map-style dataset is roughly equivalent with:

在这种情况下,从一个map-style dataset加载大致相当于:

  1. for index in sampler:
  2. yield collate_fn(dataset[index])

从一个iterable-style dataset集加载大致相当于:

  1. for data in iter(dataset):
  2. yield collate_fn(data)

这一节更多关于collate_fn。

Working with collate_fn

启用或禁用自动批处理时,’ collate_fn ‘的使用略有不同。

当自动批处理被禁用,’ collate_fn ‘与每个单独的数据样本一起被调用,输出由数据加载器迭代器产生。在本例中,默认’ collate_fn ‘只是转换PyTorch张量中的NumPy数组。

启用自动批处理时,每次使用数据样本列表调用’ collate_fn ‘。预期它会将输入样例整理成一个批,以便从数据加载器迭代器生成。本节的其余部分将在本例中描述默认’ collate_fn ‘的行为。

例如,如果每个数据样本包含一个3通道图像和一个完整的类标签,即,数据集的每个元素都返回一个元组’ (image, class_index) ‘,默认的’ collate_fn ‘将这样的元组列表整理成成批处理的图像张量和成批处理的类标签张量的一个元组。特别是,默认的“collate_fn”具有以下属性:

  • 它总是预先添加一个新的维度作为批处理维度。
  • 它自动将NumPy数组和Python数值转换为PyTorch张量。
  • 它保留了数据结构,例如,如果每个样本是一个字典,它将输出一个字典,该字典具有相同的一组键,但将批量张量作为值(如果不能将值转换为张量,则输出列表)。列表s、元组s、名称元组s也是如此。

用户可以使用自定义的“collate_fn”来实现自定义的批处理,例如,根据第一个维度以外的维度进行排序,填充不同长度的序列,或者添加对自定义数据类型的支持。

Single- and Multi-process Data Loading

一个DataLoader 默认使用单进程数据加载。

在Python进程中,全局解释器锁(GIL)会阻止真正的跨线程完全并行化Python代码。为了避免使用数据加载阻塞计算代码,PyTorch提供了一个简单的开关来执行多进程数据加载,只需将参数’ num_workers ‘设置为正整数。

单进程数据加载(默认)

在这种模式下,在初始化‘ DataLoader ‘的过程中完成数据获取。因此,数据加载可能会阻塞计算。但是,当用于在进程之间共享数据的资源(例如,共享内存、文件描述符)有限时,或者当整个数据集很小并且可以完全加载到内存中时,这种模式可能是首选的。此外,单进程加载通常显示更多可读的错误跟踪,因此对于调试非常有用。

Multi-process data loading多进程数据加载

将参数’ num_workers ‘设置为正整数将打开多进程数据加载,并使用指定的加载工作进程数量。

在这种模式下,每次创建‘ DataLoader ‘的迭代器(例如,当您调用’ enumerate(DataLoader) ‘)时,就会创建’ num_workers ‘工作者进程。此时,’ dataset ‘、’ collate_fn ‘和’ worker_init_fn ‘被传递给每个worker,它们用于初始化和获取数据。这意味着数据集访问及其内部IO、转换(包括’ collate_fn ‘)在工作进程中运行。

torch.utils.data.get_worker_info()返回工作进程中的各种有用信息(包括工作进程id、数据集副本、初始种子等),并在主进程中返回’ None ‘。用户可以在数据集代码和/或’worker_init_fn’中使用这个函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。例如,这对于数据集分片特别有帮助。

对于 map-style 数据集,主进程使用 sampler 生成索引并将它们发送给工作者。因此,任何随机洗牌都是在主进程中完成的,它通过为load分配索引来引导装载。

For iterable-style datasets, since each worker process gets a replica of the dataset object, naive multi-process loading will often result in duplicated data. Using torch.utils.data.get_worker_info() and/or worker_init_fn, users may configure each replica independently. (See IterableDataset documentations for how to achieve this. ) For similar reasons, in multi-process loading, the drop_last argument drops the last non-full batch of each worker’s iterable-style dataset replica.

对于迭代风格的数据集,由于每个工作进程都获得一个“dataset”对象的副本,所以简单的多进程加载通常会导致重复的数据。使用torch.utils.data.get_worker_info()](https://pytorch.org/docs/stable/data.html#torch.utils.data.get_worker_info)’worker_init_fn,,用户可以独立配置每个副本。(参见 IterableDataset 出于类似的原因,在多进程加载过程中,’ drop_last ‘参数会删除每个worker的迭代式数据集副本的最后一批非完整数据。

一旦到达迭代的末尾,或者当迭代器变成垃圾收集时,Workers就会被关闭。

警告

它一般不建议恢复在多进程加载CUDA张量,因为许多微妙之处使用CUDA和多分享CUDA张量(见并行处理 CUDA)。相反,我们建议使用自动存储器钉扎(即,设置pin_memory =真),这使得能够快速数据传输到支持CUDA的GPU。

特定于平台的行为

由于工人依靠Python的多重处理 ,工人启动在Windows上U不同于nix。

  • 在Unix上,fork() 是默认的multiprocessing 启动方法。使用“fork()”,儿童工作者通常可以通过克隆的地址空间直接访问 dataset 和Python参数函数。
  • 在Windows中,产卵()为默认 并行处理启动方法。使用重生(),另一种解释是推出是运行在主脚本,然后由接收数据集内部职工功能, collat​​e_fn和通过 泡菜序列的其它参数。
  • 在Windows上,spawn()是默认的并行处理启动方法(multiprocessing)。使用spawn() ,启动另一个解释器,它运行主脚本,然后启动内部的worker函数,它通过 pickle 序列化接收数据集、collate_fn和其他参数。

这种独立的序列化意味着,你应该采取两个步骤,以确保与Windows兼容,同时使用多进程数据加载:

  • 将主脚本的大部分代码封装在 if __name__ == '__main__': block, 中,以确保在启动每个工作进程时不会再次运行(很可能会产生错误)。您可以将数据集和DataLoader 实例创建逻辑放在这里,因为它不需要在workers中重新执行。
  • 确保任何自定义的collate_fn, worker_init_fn 或数据集代码都被声明为顶层定义,并在 __main__ 检查之外。这确保它们在工作进程中可用。(这是必需的,因为函数仅作为引用进行pickle,而不是作为字节码。)

多进程数据加载的随机性

默认情况下,每个worker将其PyTorch种子设置为base_seed + worker_id,其中base_seed是由使用其RNG的主进程生成的长种子(因此,强制使用RNG状态)。但是,其他库的种子可能在初始化worker (w.g.)时被复制。,导致每个worker返回相同的随机数。(参见FAQ中的这个 部分)。

In worker_init_fn, you may access the PyTorch seed set for each worker with either torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed other libraries before data loading.

worker_init_fn,你可以访问PyTorch种子为每个工具人与 torch.utils.data.get_worker_info().seedtorch.initial_seed(),并使用它的种子数据加载之前其他库。

Memory Pinning

当来自固定(页面锁定)内存时,GPU副本的主机速度要快得多。参见使用固定内存缓冲区了解更多关于何时以及如何使用固定内存的细节。

对于数据加载,将’ pin_memory=True ‘传递给 DataLoader 将自动将获取的数据张量放入固定内存中,从而能够更快地将数据传输到支持cuda的gpu。

默认的内存固定逻辑只识别张量、映射和包含张量的迭代器。默认情况下,如果把逻辑看到一批自定义类型(这将发生如果你有一批“collate_fn”,返回一个自定义类型),或者如果你批的每个元素是一个自定义类型,将逻辑不会认出他们,它会返回这批没有固定的内存(或这些元素)。要为自定义批处理或数据类型启用内存固定,请在自定义类型上定义’ pin_memory() ‘方法。

See the example below.

请参见下面的例子。

例:

  1. class SimpleCustomBatch:
  2. def __init__(self, data):
  3. transposed_data = list(zip(*data))
  4. self.inp = torch.stack(transposed_data[0], 0)
  5. self.tgt = torch.stack(transposed_data[1], 0)
  6. # custom memory pinning method on custom type
  7. def pin_memory(self):
  8. self.inp = self.inp.pin_memory()
  9. self.tgt = self.tgt.pin_memory()
  10. return self
  11. def collate_wrapper(batch):
  12. return SimpleCustomBatch(batch)
  13. inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
  14. tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
  15. dataset = TensorDataset(inps, tgts)
  16. loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
  17. pin_memory=True)
  18. for batch_ndx, sample in enumerate(loader):
  19. print(sample.inp.is_pinned())
  20. print(sample.tgt.is_pinned())

CLASStorch.utils.data.``DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

数据加载程序。组合一个数据集和一个采样器,并在给定的数据集上提供一个可迭代的。

DataLoader支持地图样式和迭代样式的数据集,支持单进程或多进程加载、自定义加载顺序以及可选的自动批处理(排序)和内存固定。

torch.utils.data 。有关更多详细信息,请参阅数据文档页。

Parameters

  • dataset (Dataset) - 从该数据集到加载数据。

  • batch_size (int, optional)) - 如何每批许多样品加载(默认值:1)。

  • shuffle (bool, optional)) - 设置为为具有在每个历元改组的数据(默认值:)。

  • sampler (Sampler, optional)) - 定义从数据集中得出样品的策略。如果指定,洗牌必须假 [HTG17。

  • batch_sampler (取样 可选 ) - 象取样,但在同一时间返回一批指标。互斥与的batch_size洗牌取样drop_last

  • num_workers ( INT 可选 ) - 多少子过程用于数据加载。 0意味着数据将在主处理加载。 (默认值:0

  • collat​​e_fn (可调用 可选 ) - 合并的样本的列表,以形成小批量张量(S)的。使用从图式集装批处理时使用。

  • pin_memory ( 布尔 可选 ) - 如果,数据装载将在返回之前复制到张量CUDA固定内存。如果数据元素是一个自定义类型,或你的collat​​e_fn返回一批即自定义类型,见下面的例子。

  • drop_last ( 布尔 可选 ) - 设置为放弃最后一批不全,如果数据集大小不是由批量大小整除。如果和数据集的大小是不是批量大小整除,则最后一批将较小。 (默认值:

  • timeout (数字 可选 ) - 如果为正,则为从工作者收集批的超时值。应该是非负的。(默认值:0)

  • worker_init_fn (可调用 可选 ) - 如果不是’ None ‘,则在播种之后和数据加载之前,以工作者id (‘ [0, num_workers - 1] ‘中的int)作为输入,在每个工作者子进程上调用它。(默认:“没有一个”)

Warning

如果使用 spawn 启动方法,则worker_init_fn 不能是一个不可修改的对象,例如lambda函数。有关PyTorch中并行处理的更多细节,请参见Multiprocessing best practices

Note

len(dataloader) 启发式是基于所用采样器的长度。当“dataset”是一个IterableDataset时,将使用一个无限采样器,它的 __len__() 没有实现,因为实际长度取决于可迭代和多进程加载配置。因此,除非使用地图样式的数据集,否则不应该查询此方法。有关这两种数据集的详细信息,请参见 Dataset Types

CLASStorch.utils.data.``Dataset

表示数据集的抽象类。

所有表示从键到数据样本的映射的数据集都应该继承它。所有的子类都应该覆盖__getitem__(),支持为给定的键获取数据样本。子类也可以选择性地覆盖 __len__()预计返回数据集的大小由许多Sampler 实现和默认选项DataLoader.

Note

的DataLoader缺省构建一个索引采样能产生整数指数。为了使它与地图式的数据集与非整指数/键的作用,必须提供自定义采样。

DataLoader 默认情况下构造一个索引采样器,生成完整的索引。要使它与具有非完整索引/键的地图样式数据集一起工作,必须提供自定义采样器。

classtorch.utils.data.``IterableDataset[source]

可迭代的数据集。

代表数据样本的迭代所有数据集应该继承它。当数据来自一个数据集流的这种形式是特别有用的。

所有子类应该overrite __iter __(),这将返回样本的迭代在该数据集。

当一个子类使用具有 的DataLoader,在数据集中的每个项目将被从得到的 的DataLoader迭代器。当num_workers & GT ; 0,每个工作进程将具有数据集对象的不同拷贝,因此通常希望独立地配置每个拷贝,以避免从工人返回重复数据。get_worker_info(),在一个工作进程调用时,返回关于工人的信息。它可以在任一使用的数据集的__iter __()方法或 的DataLoaderworker_init_fn选项来修改每个副本的行为。

实施例1:在所有工人分裂工作量__iter __()

  1. >>> class MyIterableDataset(torch.utils.data.IterableDataset):
  2. ... def __init__(self, start, end):
  3. ... super(MyIterableDataset).__init__()
  4. ... assert end > start, "this example code only works with end >= start"
  5. ... self.start = start
  6. ... self.end = end
  7. ...
  8. ... def __iter__(self):
  9. ... worker_info = torch.utils.data.get_worker_info()
  10. ... if worker_info is None: # single-process data loading, return the full iterator
  11. ... iter_start = self.start
  12. ... iter_end = self.end
  13. ... else: # in a worker process
  14. ... # split workload
  15. ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
  16. ... worker_id = worker_info.id
  17. ... iter_start = self.start + worker_id * per_worker
  18. ... iter_end = min(iter_start + per_worker, self.end)
  19. ... return iter(range(iter_start, iter_end))
  20. ...
  21. >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
  22. >>> ds = MyIterableDataset(start=3, end=7)
  23. >>> # Single-process loading
  24. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
  25. [3, 4, 5, 6]
  26. >>> # Mult-process loading with two worker processes
  27. >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
  28. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
  29. [3, 5, 4, 6]
  30. >>> # With even more workers
  31. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
  32. [3, 4, 5, 6]

实施例2:使用worker_init_fn在所有工人之间分配工作负载:

  1. >>> class MyIterableDataset(torch.utils.data.IterableDataset):
  2. ... def __init__(self, start, end):
  3. ... super(MyIterableDataset).__init__()
  4. ... assert end > start, "this example code only works with end >= start"
  5. ... self.start = start
  6. ... self.end = end
  7. ...
  8. ... def __iter__(self):
  9. ... return iter(range(self.start, self.end))
  10. ...
  11. >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
  12. >>> ds = MyIterableDataset(start=3, end=7)
  13. >>> # Single-process loading
  14. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
  15. [3, 4, 5, 6]
  16. >>>
  17. >>> # Directly doing multi-process loading yields duplicate data
  18. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
  19. [3, 3, 4, 4, 5, 5, 6, 6]
  20. >>> # Define a `worker_init_fn` that configures each dataset copy differently
  21. >>> def worker_init_fn(worker_id):
  22. ... worker_info = torch.utils.data.get_worker_info()
  23. ... dataset = worker_info.dataset # the dataset copy in this worker process
  24. ... overall_start = dataset.start
  25. ... overall_end = dataset.end
  26. ... # configure the dataset to only process the split workload
  27. ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
  28. ... worker_id = worker_info.id
  29. ... dataset.start = overall_start + worker_id * per_worker
  30. ... dataset.end = min(dataset.start + per_worker, overall_end)
  31. ...
  32. >>> # Mult-process loading with the custom `worker_init_fn`
  33. >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
  34. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
  35. [3, 5, 4, 6]
  36. >>> # With even more workers
  37. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
  38. [3, 4, 5, 6]

classtorch.utils.data.``TensorDataset( \tensors* )[source]

数据集包装张量。

每个样品将沿所述第一维度的索引张量进行检索。

Parameters

*tensors (Tensor) - 具有所述第一尺寸的大小相同张量。

classtorch.utils.data.``ConcatDataset( datasets )[source]

数据集作为多个数据集的串联。

这个类是组装不同的现有数据集是有用的。

Parameters

datasets (sequence) 数据集 (序列 ) - 数据集的列表要连接

classtorch.utils.data.``ChainDataset( datasets )[source]

数据集chainning多个 IterableDataset秒。

这个类是组装不同的现有数据集流是有用的。该chainning操作上即时完成的,因此串联与此类大型数据集将是有效的。

Parameters

数据集 (IterableDataset 的迭代) - 数据集链接在一起

classtorch.utils.data.``Subset( dataset , indices )[source]

在指定的索引数据集的子集。

Parameters

  • 数据集 (数据集 ) - 整个数据集

  • 指数 (序列 ) - 在整个组索引选择的子集

torch.utils.data.``get_worker_info()[source]

返回当前 的DataLoader迭代工作进程的信息。

当一个工人叫,这将返回保证具有以下属性的对象:

  • ID:当前作业人员ID。

  • num_workers:工人的总数。

  • 种子:当前工人随机种子集。此值由主进程RNG和工人的ID来确定。参见 的DataLoader的更多细节的文档。

  • 数据集:数据集对象在 这里 过程的副本。请注意,这将是在不同的进程比一个主处理不同的对象。

当主过程调用,这将返回

Note

用于worker_init_fn经过DataLoader时,这种方法可能是有用的设置每个工作进程不同,例如,使用worker_id配置数据集对象只读取一个特定部分的分片数据集,或其他使用种子种子库中使用数据集的代码(例如,NumPy)。

torch.utils.data.``random_split( dataset , lengths )[source]

随机分割数据集到给定长度的非重叠的新的数据集。

Parameters

  • dataset (数据集 ) - 数据集要被分割

  • lengths (序列 ) - 要产生裂缝的长度

classtorch.utils.data.``Sampler( data_source )[source]

基类的所有取样。

每采样的子类必须提供一个 __iter__() 的方法,提供一种方式来迭代数据集的元素的索引,和 __len__() 方法,它返回所返回的迭代器的长度。

Note

__len __()方法并不严格 的DataLoader必需的,但在涉及任何计算预期的 的DataLoader的长度。

classtorch.utils.data.``SequentialSampler( data_source )[source]

顺序地将样品的元素,总是以相同的顺序。

Parameters

DATA_SOURCE (数据集 ) - 数据集以从采样

classtorch.utils.data.``RandomSampler( data_source , replacement=False , num_samples=None )[source]

样品元件中随机。如果不更换,然后从一个洗牌的数据集进行采样。如果具有置换,然后用户可指定num_samples绘制。

Parameters

  • data_source ( Dataset) – dataset to sample from

  • replacement( 布尔 ) - 样品绘制替换如果,默认=False

  • num_samples ( INT ) - 样本的数目来绘制,默认=LEN(数据集)。该参数应该当替换是仅被指定。

classtorch.utils.data.``SubsetRandomSampler( indices )[source]

随机样本元素从指数的定列表,无需更换。

Parameters

indices (sequence) - 索引的序列

classtorch.utils.data.``WeightedRandomSampler( weights , num_samples , replacement=True )[source]

样品元素[0,..,len(weights)-1]` 与给定的概率(权重)。

Parameters

  • weights (序列 ) - 权重的顺序,没有必要总结到一个

  • num_samples ( INT ) - 样本的数目来绘制

  • replacement ( 布尔 ) - 如果,样品绘制更换。如果不是,他们绘制无需更换,这意味着当指数样本绘制为行,不能再为该行画出。

  1. >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
  2. [0, 0, 0, 1, 0]
  3. >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
  4. [0, 1, 4, 3, 2]

包装另一个采样,以产生小批量指数。

Parameters

  • sampler (取样 ) - 基采样器。

  • batch_size (int) - 小批量的大小。

  • drop_last (bool) - 如果,采样器将下降的最后一批,如果它的规模将是小于的batch_size

Example

  1. >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
  2. [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
  3. >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
  4. [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

限制数据加载到数据集子集的采样器。

它与torch.nn.parallel.DistributedDataParallel.特别有用。在这种情况下,每个进程可以将DistributedSampler实例作为DataLoader采样器传递,并加载原始数据集的一个子集,该子集是它独有的。

Note

数据集被认为是恒定的大小。

Parameters

  • dataset - 数据集用于采样。

  • num_replicas (可选 ) - 的参与分布式训练的进程数。

  • rank (可选 ) - num_replicas内的当前过程的秩。

  • shuffle (可选 ) - 如果为true(默认值),采样器将会洗牌指数

Next torch.utils.data - 图1 torch.utils.data - 图2 Previous


©版权所有2019年,Torch 贡献者。