Example #1
 def create_sampler(self):
     data = range(self.task_length)
     if self.world_size == 1 and self.num_workers in (0, 1):
         if self.shuffle:
             self.sampler = RandomSampler(data, generator=self.generator)
         else:
             self.sampler = SequentialSampler(data)
     else:
         num_workers = 1 if self.num_workers in (None,
                                                 0) else self.num_workers
         num_replicas = num_workers * self.world_size
         current_seed = self.initial_seed + self.current_task_iteration
         self.sampler = DistributedSampler(data,
                                           num_replicas=num_replicas,
                                           rank=self.worker_rank,
                                           shuffle=self.shuffle,
                                           seed=current_seed)
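Note: this sampler is rebuilt with a fresh seed (initial_seed + current_task_iteration) on every task iteration. The stock PyTorch idiom achieves the same cross-epoch reshuffling by calling set_epoch on the sampler before each epoch. A minimal sketch; train_loader, num_epochs, and train_step are placeholders, not part of the snippet above:

for epoch in range(num_epochs):
    # Without set_epoch, DistributedSampler yields the same order every epoch.
    train_loader.sampler.set_epoch(epoch)
    for batch in train_loader:
        train_step(batch)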
Example #2
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            if self.args.sortish_sampler:
                self.train_dataset.make_sortish_sampler(
                    self.args.per_device_train_batch_size,
                    distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
                )

            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )
Example #3
 def get_data_loader(self, features, batch_size, local_rank, set_type):
     """See base class."""
     input_ids = torch.cat([f.input_ids for f in features], dim=0)
     attention_mask = torch.cat([f.attention_mask for f in features], dim=0)
     token_type_ids = torch.cat([f.token_type_ids for f in features], dim=0)
     label_ids = torch.cat([f.label_ids for f in features], dim=0)
     data = TensorDataset(input_ids, attention_mask, token_type_ids,
                          label_ids)
     if set_type == 'train':
         if local_rank == -1:
             sampler = RandomSampler(data)
         else:
             sampler = DistributedSampler(data)
     else:
         sampler = SequentialSampler(data)
     dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size)
     return dataloader
Example #4
def hogwild(model_class, procs, epochs, arch, distributed, nodes, batches):
    torch.set_num_threads(nodes)
    device = torch.device("cpu")
    model = model_class.to(device)

    if distributed == 'y':
        processes = []
        for rank in range(procs):
            #mp.set_start_method('spawn')
            model.share_memory()
            train_loader = torch.utils.data.DataLoader(
                dataset=trainset,
                batch_size=batches,
                sampler=DistributedSampler(dataset=trainset,
                                           num_replicas=procs,
                                           rank=rank))
            p = mp.Process(target=train,
                           args=(epochs, arch, model, device, train_loader))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        test(model, device, test_loader, arch)
    else:
        train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                                   batch_size=batches,
                                                   shuffle=True)
        train(epochs, arch, model, device, train_loader)
        test(model, device, test_loader, arch)
Example #5
def create_train_loader(cfg, train_d_set):
    if cfg.performance.ddp:
        if cfg.data_loader.weighted_sampler:
            t_sampler = utils.ddp_utils.DistributedWeightedSampler(train_d_set)
        else:
            t_sampler = DistributedSampler(train_d_set)
    elif cfg.data_loader.weighted_sampler:
        train_d_size = len(train_d_set)
        weights = [1.0] * train_d_size
        t_sampler = WeightedRandomSampler(weights=weights, num_samples=train_d_size, replacement=True)
    else:
        t_sampler = RandomSampler(train_d_set)
    train_data_loader = DataLoader(train_d_set, batch_size=cfg.data_loader.batch_size_train,
                                   num_workers=cfg.data_loader.n_workers, drop_last=cfg.data_loader.drop_last,
                                   sampler=t_sampler,
                                   pin_memory=True)
    return train_data_loader
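Note: utils.ddp_utils.DistributedWeightedSampler is a project-specific class whose implementation is not shown here. A minimal sketch of one common way such a sampler is written (an assumption, not the repo's actual code): draw one weighted multinomial sample per epoch with a seed shared by all ranks, then keep only this rank's slice.

import math
import torch
import torch.distributed as dist
from torch.utils.data import Sampler

class DistributedWeightedSampler(Sampler):
    """Hypothetical sketch: weighted sampling with replacement, sharded across ranks."""

    def __init__(self, dataset, weights=None, num_replicas=None, rank=None, seed=0):
        self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size()
        self.rank = rank if rank is not None else dist.get_rank()
        # Uniform weights by default; pass per-sample weights to bias sampling.
        self.weights = torch.as_tensor(
            weights if weights is not None else [1.0] * len(dataset),
            dtype=torch.double)
        self.seed = seed
        self.epoch = 0
        self.num_samples = math.ceil(len(dataset) / self.num_replicas)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # Every rank draws the same global sample, then keeps its own slice.
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.multinomial(self.weights,
                                    self.num_samples * self.num_replicas,
                                    replacement=True,
                                    generator=g)
        return iter(indices[self.rank::self.num_replicas].tolist())

    def __len__(self):
        return self.num_samples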
Example #6
 def get_dataloader(self, mode):
     if self.hp.data.dataset == "synthetic":
         dataset = CopyDataSet(self.hp, self.args)
     elif self.hp.data.dataset == "music":
         dataset = MusicDataset(self.hp, self.args, self.get_pathlist(mode),
                                mode == "train")
     else:
         raise NotImplementedError
     sampler = RandomSampler(dataset) if platform.system() == "Windows"\
             else DistributedSampler(dataset, shuffle=True)
     return DataLoader(dataset,
                       batch_size=self.args.batch_size,
                       shuffle=False,
                       num_workers=self.hp.train.num_workers,
                       pin_memory=True,
                       drop_last=True,
                       worker_init_fn=init_fn,
                       sampler=sampler)
Example #7
 def _create_data_loader(self,
                         data_transform,
                         data_partition,
                         sample_rate=None):
     sample_rate = sample_rate or self.hparams.sample_rate
     dataset = SliceData(root=self.hparams.data_path /
                         f'{self.hparams.challenge}_{data_partition}',
                         transform=data_transform,
                         sample_rate=sample_rate,
                         challenge=self.hparams.challenge)
     sampler = DistributedSampler(dataset)
     return DataLoader(
         dataset=dataset,
         batch_size=self.hparams.batch_size,
         num_workers=8,
         pin_memory=True,
         sampler=sampler,
     )
Example #8
 def evaluate(self):
     dataset_val = build_dataset(image_set='val', args=self.args)
     if self.args.distributed:
         sampler_val = DistributedSampler(dataset_val, shuffle=False)
     else:
         sampler_val = torch.utils.data.SequentialSampler(dataset_val)
     data_loader_val = DataLoader(dataset_val,
                                  self.args.batch_size,
                                  sampler=sampler_val,
                                  drop_last=False,
                                  collate_fn=utils.collate_fn,
                                  num_workers=self.args.num_workers)
     checkpoint = torch.load(self.args.pretrained, map_location='cpu')
     self.model_without_ddp.load_state_dict(checkpoint['model'],
                                            strict=False)
     test_stats = evaluate_hoi(self.args.dataset_file, self.model,
                               self.postprocessors, data_loader_val,
                               self.args.subject_category_id, self.device)
Example #9
    def get_dataloader(self, dataset):
        if self.distributed:
            sampler = DistributedSampler(dataset)
            local_bs = self.args.batch_size // self.world_size
            is_shuffle = False
            loader = DataLoader(dataset,
                                batch_size=local_bs,
                                num_workers=self.args.workers,
                                drop_last=False,
                                shuffle=is_shuffle,
                                pin_memory=True,
                                sampler=sampler)

        else:
            raise NotImplementedError

        return loader
Example #10
 def init_test_dataset(self):
     dist_loader_text = "distributed" if self.args.get_distributed() else ""
     self.get_args().get_logger().debug(
         f"Loading '{dist_loader_text}' Fashion MNIST test data")
     self.test_dataset = datasets.FashionMNIST(
         root=self.get_args().get_data_path(),
         train=False,
         download=True,
         transform=transforms.Compose([transforms.ToTensor()]))
     self.test_sampler = DistributedSampler(
         self.test_dataset,
         rank=self.args.get_rank(),
         num_replicas=self.args.get_world_size(
         )) if self.args.get_distributed() else None
     # self.test_sampler = None
     self.test_loader = DataLoader(self.test_dataset,
                                   batch_size=16,
                                   sampler=self.test_sampler)
Example #11
 def __iter__(self):
     self.random.shuffle(self.data_files)
     for data_file in self.data_files:
         train_data = BertPretrainingPreprocessedDataset(
             input_file=data_file,
             max_predictions_per_seq=self.max_predictions_per_seq)
         train_sampler = DistributedSampler(train_data)
         # print("---")
         # print(os.getpid(), train_sampler.rank, train_sampler.num_replicas, train_sampler.num_samples)
         # print("---")
         train_dataloader = DataLoader(
             dataset=train_data,
             sampler=train_sampler,
             batch_size=self.batch_size,
             shuffle=False,
         )
         for x in train_dataloader:
             yield x
Example #12
def get_loader(args):
    img_preprocess = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    if args.dataset == "cifar10":
        trainset = datasets.CIFAR10(root="./data",
                                    train=True,
                                    download=True,
                                    transform=img_preprocess)
        testset = datasets.CIFAR10(
            root="./data",
            train=False,
            download=True,
            transform=img_preprocess) if args.local_rank in [-1, 0] else None

    else:
        trainset = datasets.CIFAR100(root="./data",
                                     train=True,
                                     download=True,
                                     transform=img_preprocess)
        testset = datasets.CIFAR100(
            root="./data",
            train=False,
            download=True,
            transform=img_preprocess) if args.local_rank in [-1, 0] else None

    train_sampler = RandomSampler(
        trainset) if args.local_rank == -1 else DistributedSampler(trainset)
    test_sampler = SequentialSampler(testset)
    train_loader = DataLoader(trainset,
                              sampler=train_sampler,
                              batch_size=args.train_batch_size,
                              num_workers=4,
                              pin_memory=True)
    test_loader = DataLoader(testset,
                             sampler=test_sampler,
                             batch_size=args.eval_batch_size,
                             num_workers=4,
                             pin_memory=True) if testset is not None else None

    return train_loader, test_loader
Example #13
 def __init__(self, model, opt_eval, env):
     super().__init__(model, opt_eval, env)
     self.batch_sz = opt_eval['batch_size']
     self.fid_batch_size = opt_get(opt_eval, ['fid_batch_size'], 64)
     assert self.batch_sz is not None
     self.dataset = create_dataset(opt_eval['dataset'])
     if torch.distributed.is_available(
     ) and torch.distributed.is_initialized():
         self.sampler = DistributedSampler(self.dataset,
                                           shuffle=False,
                                           drop_last=True)
     else:
         self.sampler = SequentialSampler(self.dataset)
     self.fid_real_samples = opt_eval['dataset'][
         'paths']  # This is assumed to exist for the given dataset.
     assert isinstance(self.fid_real_samples, str)
     self.gd = GaussianDiffusionInferenceInjector(
         opt_eval['diffusion_params'], env)
     self.out_key = opt_eval['diffusion_params']['out']
Example #14
    def get_dataloader(self, dataset, batch_size):
        if dataset is None:
            raise ValueError("Trainer: requires a dataset.")
        if transformers.is_torch_tpu_available():
            sampler = get_tpu_sampler(dataset)
        elif not isinstance(dataset, IterableDataset):
            sampler = (RandomSampler(dataset) if self.args.local_rank == -1
                       else DistributedSampler(dataset))
        else:
            sampler = None

        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
        )

        return data_loader
Example #15
def create_val_loader(conf: DictConfig,
                      rank: Optional[int] = None,
                      num_replicas: Optional[int] = None) -> DataLoader:
    data = create_dataset(conf, 'val')
    bs = conf.loader.batch_size
    num_workers = conf.get('loader.workers', 0)
    sampler = None
    if num_replicas is not None and num_replicas > 1:
        if not isinstance(rank, int):
            raise AttributeError("Rank is missing")
        sampler = DistributedSampler(data,
                                     rank=rank,
                                     num_replicas=num_replicas,
                                     shuffle=False)
    loader = DataLoader(data,
                        sampler=sampler,
                        batch_size=bs,
                        num_workers=num_workers)
    return loader
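Note: create_val_loader only builds the loader; rank and num_replicas are expected from the caller. A hedged sketch of a spawn-based harness that could drive it (the harness itself is an assumption, not taken from the repo; the gloo backend is used so the sketch also runs without GPUs):

import os
import torch.distributed as dist
import torch.multiprocessing as mp
from omegaconf import OmegaConf

def _val_worker(rank, world_size, conf):
    # Env-style rendezvous on localhost; one process per replica.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    loader = create_val_loader(conf, rank=rank, num_replicas=world_size)
    # ... iterate over `loader` and all-reduce the metrics across ranks ...
    dist.destroy_process_group()

if __name__ == "__main__":
    conf = OmegaConf.load("config.yaml")  # placeholder config path
    world_size = 2
    mp.spawn(_val_worker, args=(world_size, conf), nprocs=world_size)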
Example #16
    def run(self, num_processes: int) -> None:
        """
        Args:
            num_processes:
        Returns:
        """

        # Load the Dataset
        mnist_train, mnist_test = self.dataset_loader()
        # Create Train and Test loaders
        if self.data_parallel:
            mode = "Data Parallel"
            train_sampler = DistributedSampler(mnist_train)
            shuffle = False
        else:
            mode = "Non-parallel"
            train_sampler = None
            shuffle = True
        if self.rank in (None, 0):
            self.logger.info(
                f"{mode} mode with {num_processes} proc(s) requested..")
        train_loader = torch.utils.data.DataLoader(
            mnist_train,
            batch_size=self.batch_size_train,
            shuffle=shuffle,
            sampler=train_sampler)
        test_loader = torch.utils.data.DataLoader(
            mnist_test, batch_size=self.batch_size_test, shuffle=True)
        # Train and Test
        if self.data_parallel:
            train_results, test_results = self.run_data_parallel(
                train_loader, test_loader)
        else:
            train_results, test_results = self.run_non_parallel(
                train_loader, test_loader)
        # Save Results
        if self.rank in (None, 0):
            self.store_results(data=train_results,
                               num_processes=num_processes,
                               train=True)
            self.store_results(data=test_results,
                               num_processes=num_processes,
                               train=False)
Example #17
def initialize_data_loader(data_dir, batch_size,
                           num_data_workers) -> Tuple[DataLoader, DataLoader]:
    traindir = os.path.join(data_dir, "train")
    valdir = os.path.join(data_dir, "val")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            # pyre-fixme[16]: Module `transforms` has no attribute
            #  `RandomResizedCrop`.
            transforms.RandomResizedCrop(224),
            # pyre-fixme[16]: Module `transforms` has no attribute
            #  `RandomHorizontalFlip`.
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_data_workers,
        pin_memory=True,
        sampler=train_sampler,
    )
    val_loader = DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]),
        ),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_data_workers,
        pin_memory=True,
    )
    return train_loader, val_loader
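Note: DistributedSampler's defaults read the rank and world size from the current process group, so initialize_data_loader assumes torch.distributed has already been initialized. A minimal sketch of the per-process setup such a loader usually sits in, assuming a torchrun-style launch; the model, paths, and epoch count are placeholders, not part of the example:

import os
import torch
import torch.distributed as dist
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumed launch: torchrun --nproc_per_node=<num_gpus> train.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = DDP(torchvision.models.resnet18().cuda(local_rank),
            device_ids=[local_rank])
train_loader, val_loader = initialize_data_loader(
    data_dir="/path/to/imagenet", batch_size=64, num_data_workers=4)

for epoch in range(1):  # placeholder epoch count
    train_loader.sampler.set_epoch(epoch)  # reshuffle consistently across ranks
    for images, targets in train_loader:
        ...  # forward / backward / optimizer step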
Example #18
    def load_test_dataset(self):
        self.get_args().get_logger().debug("Loading CIFAR100 test data")

        normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])
        test_dataset = datasets.CIFAR100(root=self.get_args().get_data_path(), train=False, download=True,
                                        transform=transform)
        sampler = DistributedSampler(test_dataset, rank=self.args.get_rank(),
                                     num_replicas=self.args.get_world_size()) if self.args.get_distributed() else None
        test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), sampler=sampler)
        self.args.set_sampler(sampler)

        test_data = self.get_tuple_from_data_loader(test_loader)

        self.get_args().get_logger().debug("Finished loading CIFAR10 test data")

        return test_data
Example #19
 def batch_extract(self,
                   imglist,
                   img_prefix,
                   imgs_per_gpu=1,
                   workers_per_gpu=4):
     # build dataset
     dataset = PersonDataset(imglist, img_prefix=img_prefix)
     # get dist info
     rank, world_size = get_dist_info()
     # build data loader
     sampler = DistributedSampler(dataset, world_size, rank, shuffle=False)
     data_loader = DataLoader(dataset,
                              batch_size=imgs_per_gpu,
                              sampler=sampler,
                              num_workers=workers_per_gpu,
                              collate_fn=partial(
                                  collate, samples_per_gpu=imgs_per_gpu),
                              pin_memory=False)
     results = self.multi_gpu_test(data_loader)
     return results
Example #20
 def _get_training_loader(self, folds: List[int], name: str) -> DataLoader:
     if name == "val":
         batch_size = self.params["training_params"].get("test_batch_size")
     else:
         batch_size = self.params["training_params"].get("train_batch_size")
     constructor = getattr(self, f'{name}_constructor')
     setattr(
         self, f'{name}_set',
         constructor(
             data_folder=self.params["data_params"].get("train_path"),
             folds=folds,
             tokenizer=self.tokenzier,
             max_len=self.params["data_params"].get("max_len")))
     dataset = getattr(self, f'{name}_set')
     sampler = DistributedSampler(dataset,
                                  num_replicas=xm.xrt_world_size(),
                                  rank=xm.get_ordinal())
     return DataLoader(dataset=getattr(self, f'{name}_set'),
                       batch_size=batch_size,
                       sampler=sampler)
Example #21
    def _init_train_sampler(self):
        if self.train_dataset is None:
            return None

        if self.local_rank == -1:
            if self.train_weights is None or self.train_weights[
                    'sampler_weights'] is None:
                train_sampler = RandomSampler(self.train_dataset)
            else:
                assert len(self.train_weights['sampler_weights']) == len(
                    self.train_dataset)
                train_sampler = WeightedRandomSampler(
                    self.train_weights['sampler_weights'],
                    len(self.train_dataset))
        else:
            train_sampler = DistributedSampler(self.train_dataset)

        logger.info(f'Used train sampler: {type(train_sampler).__name__}.')

        return train_sampler
Example #22
    def init_test_dataset(self):
        self.get_args().get_logger().debug("Loading CIFAR10 test data")

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        self.test_dataset = datasets.CIFAR10(
            root=self.get_args().get_data_path(),
            train=False,
            download=True,
            transform=transform)
        self.test_sampler = DistributedSampler(
            self.test_dataset,
            rank=self.args.get_rank(),
            num_replicas=self.args.get_world_size(
            )) if self.args.get_distributed() else None
        # self.test_sampler = None
        self.test_loader = DataLoader(self.test_dataset,
                                      batch_size=16,
                                      sampler=self.test_sampler)
Example #23
 def __init__(self,
              model,
              training_set,
              batch_size,
              learning_rate,
              validation_set=None,
              test_set=None,
              checkpoint_dir=None,
              sampler=None):
     self.model = model
     self.checkpoint_dir = checkpoint_dir
     self.sampler = sampler or DistributedSampler(
         training_set, num_replicas=1, rank=0)
     self.train_loader = self._get_data_loader(training_set,
                                               batch_size,
                                               sampler=self.sampler)
     self.validation_loader = self._get_data_loader(validation_set,
                                                    batch_size=None)
     self.test_loader = self._get_data_loader(test_set, batch_size=None)
     self.optimizer = self._get_optimizer(model, learning_rate)
Example #24
def prepare_dataloaders(opt):
    # Get data, data loaders and collate function ready
    trainset = WavLandmarksDataset(opt, mode='train')
    valset = WavLandmarksDataset(opt, mode='test')
    collate_fn = MelLandmarkCollate(opt)

    train_sampler = DistributedSampler(trainset) \
        if opt['distributed_run'] else None  # TODO a better sampler
    train_loader = DataLoader(trainset,
                              num_workers=opt['num_workers'],
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=opt['batch_size'],
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)
    # valset = DataLoader(valset, num_workers=0,
    #                         shuffle=False, batch_size=opt['batch_size'],
    #                         pin_memory=False, collate_fn=collate_fn)
    return train_loader, valset, collate_fn
Example #25
def get_loaders_objectnet(dataroot, imagenet_dataroot, val_batch_size,
                          input_size, workers, num_nodes, local_rank):
    # TODO: pin-memory currently broken for distributed
    pin_memory = False
    # TODO: datasets.ImageNet
    val_data_im = datasets.ImageFolder(
        root=os.path.join(imagenet_dataroot, 'val'),
        transform=get_transform_imagenet(False, input_size))
    # TODO: datasets.ImageNet
    val_data = datasets.ImageFolder(root=os.path.join(dataroot, 'images'),
                                    transform=get_transform_imagenet(
                                        False, input_size))
    val_sampler = DistributedSampler(val_data, num_nodes, local_rank)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=val_batch_size,
                                             sampler=val_sampler,
                                             num_workers=workers,
                                             pin_memory=pin_memory)
    imagenet_to_objectnet, objectnet_to_imagenet, objectnet_both, imagenet_both = objectnet_imagenet_mappings(
        dataroot, val_data, val_data_im)
    return val_loader, imagenet_to_objectnet, objectnet_to_imagenet, objectnet_both, imagenet_both
Example #26
def prepare_data_loader(features, args, mode):

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
    all_valid_ids = torch.tensor([f.valid_ids for f in features], dtype=torch.long)
    all_label_mask = torch.tensor([f.label_mask for f in features], dtype=torch.long)
    data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_valid_ids, all_label_mask)

    if mode == 'train':
        if args.local_rank == -1:
            sampler = RandomSampler(data)
        else:
            sampler = DistributedSampler(data)
        return DataLoader(data, sampler=sampler, batch_size=args.train_batch_size)
    elif mode == 'eval':
        sampler = SequentialSampler(data)
        return DataLoader(data, sampler=sampler, batch_size=args.eval_batch_size)
    else:
        raise ValueError("Invalid mode %s" % mode)
Example #27
def create_train_loader(conf, rank=None, num_replicas=None):
    # type: (DictConfig, Optional[int], Optional[int]) -> Sized
    build_ds = {
        'simple': create_simple_dataset,
        'image_folder': create_image_folder_dataset,
    }
    ds_type = conf.type
    data = build_ds[ds_type](conf, conf.transforms)
    print("Found {} images".format(len(data)))

    sampler = None
    if num_replicas is not None:
        sampler = DistributedSampler(data,
                                     num_replicas=num_replicas,
                                     rank=rank)

    loader = DataLoader(data,
                        sampler=sampler,
                        batch_size=conf.loader.batch_size,
                        num_workers=conf.get('loader.workers', 0),
                        drop_last=True)
    return loader
Example #28
def build_dataloader(dataset, **kwargs):
    kwargs.setdefault('shuffle', None)
    kwargs.setdefault('sampler', None)
    kwargs.setdefault('batch_size', 1)

    sampler = kwargs.pop('sampler')
    shuffle = kwargs.pop('shuffle')
    batch_size = kwargs.pop('batch_size')

    if dist.is_initialized() and sampler is None:
        sampler = DistributedSampler(dataset)
        if batch_size > dist.get_world_size():
            batch_size = batch_size // dist.get_world_size()

    if shuffle is None:
        shuffle = sampler is None

    loader = DataLoader(
        dataset, sampler=sampler, batch_size=batch_size,
        shuffle=shuffle, **kwargs
    )
    return loader, sampler
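Note: build_dataloader returns the sampler alongside the loader so the caller can reshuffle it each epoch, and it turns the given global batch size into a per-process one when torch.distributed is initialized. A hedged usage sketch; dataset and train_one_epoch are placeholders:

loader, sampler = build_dataloader(dataset, batch_size=256, num_workers=4)
for epoch in range(10):
    if sampler is not None:  # only set when running under torch.distributed
        sampler.set_epoch(epoch)
    train_one_epoch(loader)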
Example #29
def load_data(data_folder, tv_prop, batch_size):
    full_ds = datasets.ImageFolder(data_folder + '/train_val_images',
                                   transform=train_transforms)
    train_size = int(tv_prop * len(full_ds))
    val_size = len(full_ds) - train_size
    print(
        f"Using train size {train_size} and val size {val_size} with batch size {batch_size}"
    )
    train_ds, val_ds = random_split(full_ds, [train_size, val_size])
    train_sampler = DistributedSampler(train_ds)
    # val_sampler = DistributedSampler(val_ds)
    train_loader = torch.utils.data.DataLoader(train_ds,
                                               batch_size=batch_size,
                                               sampler=train_sampler,
                                               num_workers=0)
    #    drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds,
        batch_size=batch_size,
        #  sampler=val_sampler,
        num_workers=0)
    return train_loader, val_loader
Example #30
def create_val_loader(conf, rank=None, num_replicas=None, mean=None, std=None):
    # type: (DictConfig, Optional[int], Optional[int], Optional[float_3], Optional[float_3]) -> DataLoader
    show_progress = rank is None or rank == 0
    data = create_dataset(conf, show_progress=show_progress, name="val")

    sampler = None
    if num_replicas is not None:
        sampler = DistributedSampler(data,
                                     num_replicas=num_replicas,
                                     rank=rank,
                                     shuffle=False)

    loader = DataLoader(data,
                        sampler=sampler,
                        batch_size=conf.loader.batch_size,
                        num_workers=conf.get('loader.workers', 0),
                        collate_fn=fast_collate,
                        drop_last=False,
                        shuffle=not sampler)
    if conf.loader.prefetch:
        loader = PrefetchLoader(loader, mean=mean, std=std)
    return loader