colossalai.utils.data_sampler.data_parallel_sampler

class colossalai.utils.data_sampler.data_parallel_sampler.DataParallelSampler(dataset, shuffle=False, seed=0, drop_last=False)

A data sampler for distributed data parallelism

Parameters
  • dataset (torch.utils.data.Dataset) – A Dataset instance

  • shuffle (bool, optional) – Whether to shuffle data, defaults to False

  • seed (int, optional) – The random seed, defaults to 0

  • drop_last (bool, optional) – If True, drop the last incomplete batch when the dataset size is not divisible by the batch size. If False, the last batch is kept and will be smaller than the others. Defaults to False
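The effect of drop_last described above comes down to simple arithmetic. The following is a minimal sketch of that arithmetic only, not the library's implementation; the helper name num_batches and the batch_size argument are hypothetical and appear here purely for illustration.

```python
import math


def num_batches(dataset_size, batch_size, drop_last=False):
    """Count the batches produced by one pass over the data (illustrative helper)."""
    if drop_last:
        # The trailing partial batch is discarded, so round down.
        return dataset_size // batch_size
    # The trailing partial batch is kept, so round up.
    return math.ceil(dataset_size / batch_size)


if __name__ == "__main__":
    # 10 samples with a batch size of 4: two full batches plus a partial batch of 2.
    print(num_batches(10, 4, drop_last=True))
    print(num_batches(10, 4, drop_last=False))
```

With 10 samples and a batch size of 4, drop_last=True yields two full batches, while drop_last=False yields a third batch containing the remaining 2 samples.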

set_epoch(epoch)

Sets the epoch for this sampler. When shuffle=True, this ensures that all replicas use a different random ordering for each epoch. If set_epoch is not called, the next iteration of this sampler yields the same ordering as the previous one.

Parameters

epoch (int) – Epoch number.
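To make the role of the epoch concrete, here is a minimal pure-Python sketch of how a data-parallel sampler could derive a per-epoch ordering and split it across replicas. It follows the common pattern of seeding with seed + epoch and is an assumption about the general technique, not a copy of the DataParallelSampler source; the function shard_indices and its rank and world_size arguments are hypothetical.

```python
import random


def shard_indices(num_samples, rank, world_size, shuffle=False, seed=0, epoch=0):
    """Return the dataset indices one replica would visit in a given epoch (sketch)."""
    indices = list(range(num_samples))
    if shuffle:
        # Every replica seeds with (seed + epoch), so all ranks build the same
        # permutation within an epoch and a different one in the next epoch.
        random.Random(seed + epoch).shuffle(indices)
    # Strided slicing hands each rank a disjoint subset of that permutation.
    return indices[rank::world_size]


if __name__ == "__main__":
    for epoch in range(2):
        for rank in range(2):
            print(epoch, rank, shard_indices(8, rank, 2, shuffle=True, epoch=epoch))
```

In a training loop, the equivalent step would be to call set_epoch(epoch) on the sampler at the start of every epoch, before iterating over the dataloader. If the epoch value never changes, the permutation is rebuilt from the same seed and the ordering repeats.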

colossalai.utils.data_sampler.data_parallel_sampler.get_dataloader(dataset, shuffle=False, seed=1024, add_sampler=True, drop_last=False, pin_memory=False, num_workers=0, **kwargs)

Set up a deterministic dataloader, configuring worker seeding, the sampler, and whether the data is shuffled.

Note

When pipeline parallelism is enabled, shuffle must not be True, because shuffling results in a mismatch between the input data on the first stage and the labels on the last stage.

Parameters
  • dataset (torch.utils.data.Dataset) – A torch.utils.data.Dataset object

  • shuffle (bool, optional. Default is False) – Whether to shuffle the dataset

  • seed (int, optional. Default is 1024) – Random seed for the dataloader workers

  • add_sampler (bool, optional. Default is True) – Whether to attach a DataParallelSampler to the dataset

  • drop_last (bool, optional. Default is False) – Drop the last incomplete batch of data

  • pin_memory (bool, optional. Default is False) – Whether to load batches into pinned (page-locked) CPU memory

  • num_workers (int, optional. Default is 0) – Number of worker subprocesses used for data loading; 0 loads data in the main process

Returns

A torch.utils.data.DataLoader object

Return type

torch.utils.data.DataLoader
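The seed parameter above is what makes the loader's workers deterministic. The following is a rough sketch of that idea using only the standard library: a callback that reseeds the random number generator when a worker starts. torch.utils.data.DataLoader accepts such a callback through its worker_init_fn argument. The names make_worker_init_fn and seed_worker are hypothetical, and the policy of giving every worker the same seed is an assumption made for illustration rather than a statement about what get_dataloader does internally.

```python
import random


def make_worker_init_fn(seed=1024):
    """Build a callback that reseeds Python's RNG inside a loader worker (sketch)."""

    def seed_worker(worker_id):
        # Assumed policy: every worker receives the same base seed, so random
        # draws made while loading data repeat from run to run.
        random.seed(seed)

    return seed_worker


if __name__ == "__main__":
    init_fn = make_worker_init_fn(seed=1024)
    init_fn(0)
    first = [random.random() for _ in range(3)]
    init_fn(0)
    second = [random.random() for _ in range(3)]
    print(first == second)  # prints True
```

A real worker callback would typically also seed NumPy and PyTorch generators; they are left out here so the sketch runs with the standard library alone.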