torch.utils.data

在PyTorch数据加载工具的心脏是 torch.utils.data.DataLoader类。它代表了一个数据集的一个Python迭代,与支持

  • 图式和可迭代式的数据集

  • 定制数据加载顺序

  • 自动配料

  • 单和多处理数据加载

  • 自动存储器钉扎。

这些选项由的构造器参数构成的 的DataLoader,其具有签名:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

下面的章节详细描述了影响和这些选项的用法。

数据集类型

的DataLoader构造的最重要的参数是数据集,其指示数据集对象从加载数据。 PyTorch支持两种不同类型的数据集:

  • 图式数据集

  • 迭代式的数据集[HTG1。

地图式的数据集

一种地图风格数据集是一个用于实现__getitem __()__len __()协议,以及表示从(可能是一个地图非一体)索引/键数据样本。

例如,这样的数据集,当与访问数据集[IDX],可以读取IDX个图像和其相应的标签从磁盘上的文件夹。

参见 数据集了解更多详情。

可迭代式的数据集

可迭代式数据集的 一个子类的实例IterableDataset实现了__iter __()协议和代表了数据样本可迭代。这种类型的数据集的特别适合于情况下随机读取是昂贵的,甚至不可能的,并且其中所述批量大小取决于所取的数据。

例如,这样的数据集,称为当ITER(数据集),可以返回数据从数据库中,远程服务器读取的流,或甚至原木实时生成。

参见 IterableDataset了解更多详情。

注意

当使用 IterableDataset与多进程数据加载。相同的数据集对象被复制在每个工作进程,因此副本必须被不同地配置,以避免重复的数据。参见 IterableDataset如何实现这个单证。

数据加载顺序和 取样

用户定义的迭代为可迭代式的数据集,数据加载顺序完全由控制。这允许数据块读取和动态批量大小的更容易实现(例如,通过产生在每个时间成批样品)。

本节的其余部分涉及与图式的数据集的情况。torch.utils.data.Sampler [HTG7类可用于指定在数据加载用于索引/键的序列。他们代表了索引到的数据集可迭代的对象。例如,与随机梯度下降(SGD)的常见情况下, 取样 可以随机置换指数列表和屈服每一次一个,或产生一个少数人的小批量SGD。

一种顺序或改组采样器将基于所述洗牌参数向 的DataLoader自动构造。可替换地,用户可以使用取样参数来指定自定义 取样对象在每个时间产生的下一个索引/键获取。

自定义 取样为在一个时间产生一批指数列表可以作为batch_sampler参数传递。自动配料,也可以通过的batch_sizedrop_last参数启用。参见下一节本更多细节。

Note

既不取样也不batch_sampler是具有可迭代式的数据集兼容,因为这样的数据集没有一个键或索引的概念。

装载成批和非成批数据

的DataLoader支持单个取出的数据样本自动整理成批次经由参数的batch_sizedrop_lastbatch_sampler

自动配料(默认)HTG0]

这是最常见的情况,并且对应于提取数据的minibatch并将它们整理成批处理样品,即,含有与张量一个维度是所述批料尺寸(通常是第一个)。

的batch_size(默认1)不是,数据加载器的产率批量样品,而不是个别样品。 的batch_sizedrop_last参数用于指定数据加载器如何获得数据集密钥的批次。在地图风格数据集,用户可以另外指定batch_sampler,它在一个时间产生密钥的列表。

Note

的batch_sizedrop_last参数基本上被用于一个batch_sampler从构建取样。在地图式的数据集时,取样或者由用户提供的或根据洗牌参数构成。对于迭代式的数据集时,取样是伪无限之一。参见本节上采样的更多细节。

Note

当从迭代式的数据集与取多处理,在drop_last参数下降到最后的非整批生产的每个员工的数据集复制品。

取使用从采样器的索引的样本的列表之后,函数作为collat​​e_fn参数被用来校核样本列表成批通过。

在这种情况下,从图式集装是大致相当于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

并从迭代式集装是大致相当于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

自定义collat​​e_fn可以被用于定制的归类,例如,填充顺序数据至一批最大长度。查看更多关于collat​​e_fn 本节[HTG5。

禁用自动配料

在某些情况下,用户可能希望将在数据集代码手动处理配料,或简单地装载单个样品。例如,它可能更便宜直接加载成批数据(例如,批量从数据库中读取或读取的存储器大块连续)或批量大小是依赖于数据的,或者程序被设计为在单个样品工作。在这些情况下,很可能更好,不使用自动配料(其中collat​​e_fn被用来校核的样品),但让所述数据加载器直接返回的每个成员数据集对象。

当两个的batch_sizebatch_sampler为(默认值batch_sampler已经),自动配料被禁用。从数据集获得的每个样品与作为collat​​e_fn参数传递的功能进行处理。

[HTG0当自动配料被禁用,默认collat​​e_fn简单地NumPy的阵列转换成PyTorch张量,并保持所有其他不变。

In this case, loading from a map-style dataset is roughly equivalent with:

for index in sampler:
    yield collate_fn(dataset[index])

and loading from an iterable-style dataset is roughly equivalent with:

for data in iter(dataset):
    yield collate_fn(data)

查看更多关于collat​​e_fn本节[HTG1。

collat​​e_fn工作

利用collat​​e_fn当自动配料被启用或禁用略有不同。

[HTG0当自动配料被禁用,collat​​e_fn被称为与每个单独的数据样本,并且输出从所述数据加载器的迭代得到。在这种情况下,默认collat​​e_fn简单地转换在PyTorch张量NumPy的阵列。

[HTG0当自动配料使能,collat​​e_fn调用与各时刻的数据样本的一个列表。预计到输入样本整理成批处理从数据加载器的迭代得到。本节的其余部分描述了在这种情况下,默认的collat​​e_fn的行为。

例如,如果每个数据样本包括3通道图像和积分类别标签,即,该数据集的每个元素返回一个元组(图像, class_index),默认collat​​e_fn核对这样元组的列表成批处理图像张量的一个元组和批处理类别标签张量。具体地,默认collat​​e_fn具有以下性质:

  • 它总是预先考虑一个新的维度批次尺寸。

  • 它自动NumPy的阵列和Python数值转换成PyTorch张量。

  • 它保留的数据结构,例如,如果每个样本是一个字典,它输出具有相同的密钥集合,但分批张量作为值(或列表,如果值不能被转换成张量)的字典。同样为列表S,元组S,namedtupleS等

用户可以使用定制collat​​e_fn以实现自定义配料,例如,沿除各种长度,或增加对自定义数据类型支撑件的第一,填充序列以外的尺寸核对。

单和多进程数据载入

A的DataLoader缺省使用单进程数据加载。

内一个Python过程中,全局解释器锁(GIL)防止真正完全并行跨线程Python代码。为了避免与数据加载阻断计算代码,PyTorch提供了一个简单开关通过简单地将参数num_workers设置为一个正整数,以执行多处理数据加载。

单进程的数据加载(默认)

在此模式下,数据被取在相同的工艺做了 的DataLoader 被初始化。因此,数据加载可能会阻止计算。然而,这种模式可被当处理(例如,共享存储器,文件描述符)之间使用共享数据资源(多个)是有限的,或者当整个数据集是小,并且可以完全在内存加载优选的。另外,单进程加载经常显示更加可读的错误的痕迹,因此对于调试是有用的。

多进程数据加载

设置参数num_workers作为正整数将接通的多进程数据加载与装载机的工作进程指定的次数。

在这种模式下,每次迭代一个 的DataLoader,创建(例如,当调用枚举(的DataLoader)),num_workers被创建工作进程。在这一点上,数据集collat​​e_fnworker_init_fn被传递到每个工人,在那里它们被用来初始化,并获取数据。这意味着,数据集访问其内部IO一起,变换(包括collat​​e_fn)在工作进程中运行。

torch.utils.data.get_worker_info() 在一个工作进程返回各种有用的信息(包括工人ID,数据集的副本,初始种子等)在主处理中,并返回。用户可以在数据集中代码中使用此功能和/或worker_init_fn单独配置每个数据集的副本,并确定该代码是否在工作进程运行。例如,这可以是在分片数据集特别有用。

在地图风格数据集,主处理使用取样产生的索引,并将它们发送给工人。因此,任何洗牌随机化,其中通过分配指标来加载引导加载主进程完成。

对于迭代式的数据集,因为每个工作进程得到数据集对象的副本,幼稚多进程加载通常将导致复制的数据。使用 torch.utils.data.get_worker_info()和/或worker_init_fn中,用户可以配置每个复制品独立。 (参见 IterableDataset单证如何实现这一点。)对于类似的原因,在多进程加载时,drop_last参数下降到最后的非整批生产的每个工人的迭代式的数据集副本。

工人被关闭一旦达到迭代结束时,或者当迭代器将变为垃圾收集。

警告

它一般不建议恢复在多进程加载CUDA张量,因为许多微妙之处使用CUDA和多分享CUDA张量(见多处理 CUDA)。相反,我们建议使用自动存储器钉扎(即,设置pin_memory =真),这使得能够快速数据传输到支持CUDA的GPU。

特定于平台的行为

由于工人依靠Python的 多重处理 ,工人发射行为是在Windows上的不同比的Unix。

  • 在Unix,叉()为默认 多处理启动方法。使用叉(),童工通常可以直接通过克隆地址空间中的数据集和Python参数的函数访问。

  • 在Windows中,产卵()为默认 多处理启动方法。使用重生(),另一种解释是推出是运行在主脚本,然后由接收数据集内部职工功能, collat​​e_fn和通过 泡菜序列的其它参数。

这个单独的序列化意味着你应该采取两个步骤,以确保您与Windows兼容,同时使用多进程数据加载:

  • 包裹内你们中的大多数主要脚本代码,如果 __name__ == '__main__':块,使确保它不会再次运行(最有可能产生误差)时,每个工作进程启动。您可以将您的数据集和 的DataLoader实例创建逻辑在这里,因为它并不需要在工人重新执行。

  • 确保任何自定义collat​​e_fnworker_init_fn数据集代码声明顶层定义,__main__检查之外。这确保了他们在工作进程可用。 (这是需要,因为功能酸洗作为参考而已,不是字节码)。

随机性在多进程数据加载

默认情况下,每个工人将具有其PyTorch种子设为base_seed + worker_id,其中base_seed是一个长期的,通过使用其RNG主过程中产生的(从而,消耗了RNG状态强制)。但是,对于其他种子库可以在初始化工人(W.G.,NumPy的),使每个工人返回相同的随机数被复制。 (参见 本部分 在FAQ)。

worker_init_fn,则可以访问PyTorch种子集对每个工人用任一 torch.utils.data.get_worker_info()。种子 torch.initial_seed() ,并用它的数据加载之前种子其他库。

存储器钢钉

主机到GPU副本要快得多,当他们从固定(锁定页)内存起源。参见 使用固定的内存缓冲区 有关何时以及如何一般采用固定内存的更多细节。

为数据加载,使pin_memory =真的DataLoader 将自动把所获取的数据张量在钉扎存储器,并因此能够更快的数据传输到支持CUDA的GPU。

默认存储器锁定逻辑仅识别张量和地图以及包含张量iterables。默认情况下,如果锁定逻辑看到一个批次是一个自定义类型(这将如果您有发生collat​​e_fn 返回一个自定义的间歇式),或者如果每个元素的批量是一个自定义的类型,钉扎逻辑将无法识别它们,并且它会返回该批次(或那些元件)而没有钉扎的存储器。为了使存储器钉扎定制间歇或数据类型,定义上的自定义类型(多个)pin_memory()方法。

请参见下面的例子。

例:

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

classtorch.utils.data.``DataLoader( dataset , batch_size=1 , shuffle=False , sampler=None , batch_sampler=None , num_workers=0 , collate_fn=None , pin_memory=False , drop_last=False , timeout=0 , worker_init_fn=None , multiprocessing_context=None )[source]

数据加载。结合了数据集和采样,并提供了在给定数据集的迭代。

的DataLoader同时支持地图风格和迭代式的数据集与单或多进程加载,定制加载顺序和可选的自动配料(对照)和内存牵制。

有关详细信息,请参见 torch.utils.data文档页面。

Parameters

  • 数据集数据集 ) - 从该数据集到加载数据。

  • 的batch_size INT 可选 ) - 如何每批许多样品加载(默认值:1)。

  • 洗牌 布尔 可选 ) - 设置为为具有在每个历元改组的数据(默认值:)。

  • 取样取样 可选 ) - 定义从数据集中得出样品的策略。如果指定,洗牌必须假 [HTG17。

  • batch_sampler取样 可选 ) - 象取样,但在同一时间返回一批指标。互斥与的batch_size洗牌取样drop_last

  • num_workers INT 可选 ) - 多少子过程用于数据加载。 0意味着数据将在主处理加载。 (默认值:0

  • collat​​e_fn可调用 可选 ) - 合并的样本的列表,以形成小批量张量(S)的。使用从图式集装批处理时使用。

  • pin_memory 布尔 可选 ) - 如果,数据装载将在返回之前复制到张量CUDA固定内存。如果数据元素是一个自定义类型,或你的collat​​e_fn返回一批即自定义类型,见下面的例子。

  • drop_last 布尔 可选 ) - 设置为放弃最后一批不全,如果数据集大小不是由批量大小整除。如果和数据集的大小是不是批量大小整除,则最后一批将较小。 (默认值:

  • 超时数字 可选 ) - 如果是阳性的,对于从工人收集一批的超时值。应始终非负。 (默认值:0

  • worker_init_fn可调用 可选 ) - 如果未,这将是叫上与工人ID每个工人子(在一个int [0, num_workers - 1])作为输入,在播种之后和数据加载之前。 (默认值:

Warning

如果使用菌种启动方法,worker_init_fn不能是unpicklable对象,例如,lambda函数。参见 多处理最佳实践 在PyTorch到多处理相关的更多细节。

Note

LEN(的DataLoader)启发式是基于所使用的取样器的长度。当数据集IterableDataset ,将使用一个无限采样器,其__len__ ()未实现,因为实际的长度取决于两个可迭代以及多进程加载构造。所以,除非他们有地图式的数据集工作,一个不应该查询该方法。参见数据集类型关于这两种类型的数据集的更多细节。

classtorch.utils.data.``Dataset[source]

表示 数据集的抽象类。

表示从键数据样本的地图所有数据集应该继承它。所有子类应该overrite __getitem __(),支持获取对于给定的密钥数据样本。子类还可以任选地覆盖__len __(),预计由返回的数据集的大小许多 取样实施方式和的 的DataLoader的默认选项。

Note

的DataLoader缺省构建一个索引采样能产生整数指数。为了使它与地图式的数据集与非整指数/键的作用,必须提供自定义采样。

classtorch.utils.data.``IterableDataset[source]

可迭代的数据集。

代表数据样本的迭代所有数据集应该继承它。当数据来自一个数据集流的这种形式是特别有用的。

所有子类应该overrite __iter __(),这将返回样本的迭代在该数据集。

当一个子类使用具有 的DataLoader,在数据集中的每个项目将被从得到的 的DataLoader迭代器。当num_workers & GT ; 0,每个工作进程将具有数据集对象的不同拷贝,因此通常希望独立地配置每个拷贝,以避免从工人返回重复数据。get_worker_info(),在一个工作进程调用时,返回关于工人的信息。它可以在任一使用的数据集的__iter __()方法或 的DataLoaderworker_init_fn选项来修改每个副本的行为。

实施例1:在所有工人分裂工作量__iter __()

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

实施例2:使用在所有工人分裂工作量worker_init_fn

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn`that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]

classtorch.utils.data.``TensorDataset( *tensors )[source]

数据集包装张量。

每个样品将沿所述第一维度的索引张量进行检索。

Parameters

*张量 张量 ) - 具有所述第一尺寸的大小相同张量。

classtorch.utils.data.``ConcatDataset( datasets )[source]

数据集作为多个数据集的串联。

这个类是组装不同的现有数据集是有用的。

Parameters

数据集序列 ) - 数据集的列表要连接

classtorch.utils.data.``ChainDataset( datasets )[source]

数据集chainning多个 IterableDataset秒。

这个类是组装不同的现有数据集流是有用的。该chainning操作上即时完成的,因此串联与此类大型数据集将是有效的。

Parameters

数据集IterableDataset 的迭代) - 数据集链接在一起

classtorch.utils.data.``Subset( dataset , indices )[source]

在指定的索引数据集的子集。

Parameters

  • 数据集数据集 ) - 整个数据集

  • 指数序列 ) - 在整个组索引选择的子集

torch.utils.data.``get_worker_info()[source]

返回当前 的DataLoader迭代工作进程的信息。

当一个工人叫,这将返回保证具有以下属性的对象:

  • ID:当前作业人员ID。

  • num_workers:工人的总数。

  • 种子:当前工人随机种子集。此值由主进程RNG和工人的ID来确定。参见 的DataLoader的更多细节的文档。

  • 数据集:数据集对象在 这里 过程的副本。请注意,这将是在不同的进程比一个主处理不同的对象。

当主过程调用,这将返回

Note

当所使用的worker_init_fn传递到 的DataLoader,该方法可以是设置每个工人有用过程不同,例如,使用worker_id配置数据集目的是只读分片数据集的特定部分,或使用种子种子中的数据集的代码(例如,NumPy的)使用其他文库。

torch.utils.data.``random_split( dataset , lengths )[source]

随机分割数据集到给定长度的非重叠的新的数据集。

Parameters

  • 数据集数据集 ) - 数据集要被分割

  • 长度序列 ) - 要产生裂缝的长度

classtorch.utils.data.``Sampler( data_source )[source]

基类的所有取样。

每采样的子类必须提供一个__iter __()的方法,提供一种方式来迭代数据集的元素的索引,和__len __()方法,它返回所返回的迭代器的长度。

Note

__len __()方法并不严格 的DataLoader必需的,但在涉及任何计算预期的 的DataLoader的长度。

classtorch.utils.data.``SequentialSampler( data_source )[source]

顺序地将样品的元素,总是以相同的顺序。

Parameters

DATA_SOURCE数据集 ) - 数据集以从采样

classtorch.utils.data.``RandomSampler( data_source , replacement=False , num_samples=None )[source]

样品元件中随机。如果不更换,然后从一个洗牌的数据集进行采样。如果具有置换,然后用户可指定num_samples绘制。

Parameters

  • data_source ( Dataset) – dataset to sample from

  • 替换 布尔 ) - 样品绘制替换如果,默认=False

  • num_samples INT ) - 样本的数目来绘制,默认=LEN(数据集)。该参数应该当替换是仅被指定。

classtorch.utils.data.``SubsetRandomSampler( indices )[source]

随机样本元素从指数的定列表,无需更换。

Parameters

指数序列 ) - 索引的序列

classtorch.utils.data.``WeightedRandomSampler( weights , num_samples , replacement=True )[source]

样品元素[0,..,LEN(权重)-1]与给定的概率(权重)。

Parameters

  • 权重序列 ) - 权重的顺序,没有必要总结到一个

  • num_samples INT ) - 样本的数目来绘制

  • 替换 布尔 ) - 如果,样品绘制更换。如果不是,他们绘制无需更换,这意味着当指数样本绘制为行,不能再为该行画出。

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]

classtorch.utils.data.``BatchSampler( sampler , batch_size , drop_last )[source]

包装另一个采样,以产生小批量指数。

Parameters

  • 取样取样 ) - 基采样器。

  • 的batch_size INT ) - 小批量的大小。

  • drop_last 布尔 ) - 如果,采样器将下降的最后一批,如果它的规模将是小于的batch_size

Example

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

classtorch.utils.data.distributed.``DistributedSampler( dataset , num_replicas=None , rank=None , shuffle=True )[source]

取样器,限制数据加载到数据集的一个子集。

它与 torch.nn.parallel.DistributedDataParallel 结合特别有用。在这种情况下,每个过程可以通过一个DistributedSampler实例作为的DataLoader采样器,并加载原始数据集即排它的一个子集。

Note

数据集被认为是恒定的大小。

Parameters

  • 数据集 - 数据集用于采样。

  • num_replicas可选 ) - 的参与分布式训练的进程数。

  • 可选 ) - num_replicas内的当前过程的秩。

  • 洗牌可选 ) - 如果为true(默认值),采样器将会洗牌指数

Next Previous


©版权所有2019年,Torch 贡献者。


Copyright © ibooker.org.cn 2019 all right reserved,由 ApacheCN 团队提供支持该文件修订时间: 2019-09-24 01:39:32

results matching ""

    No results matching ""

    results matching ""

      No results matching ""