Example #1
def load_dataset(args,
                 INPUT_SIZE=[112, 112],
                 RGB_MEAN=[0.5, 0.5, 0.5],
                 RGB_STD=[0.5, 0.5, 0.5],
                 val_datasets=[
                     'lfw', 'cfp_ff', 'cfp_fp', 'agedb_30', 'calfw', 'cplfw',
                     'vgg2_fp'
                 ]):
    train_transform = transforms.Compose([
        transforms.Resize(
            [int(128 * INPUT_SIZE[0] / 112),
             int(128 * INPUT_SIZE[0] / 112)]),
        transforms.RandomCrop([INPUT_SIZE[0], INPUT_SIZE[1]]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=RGB_MEAN, std=RGB_STD),
    ])
    train_data = dset.ImageFolder(
        os.path.join(args.data_path, 'CASIA-maxpy-align'), train_transform)
    weights = torch.DoubleTensor(
        make_weights_for_balanced_classes(train_data.imgs,
                                          len(train_data.classes)))
    if args.distributed:
        from catalyst.data.sampler import DistributedSamplerWrapper
        train_sampler = DistributedSamplerWrapper(
            WeightedRandomSampler(weights, len(weights)))
    else:
        train_sampler = WeightedRandomSampler(weights, len(weights))
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(
            [int(128 * INPUT_SIZE[0] / 112),
             int(128 * INPUT_SIZE[0] / 112)]),
        transforms.CenterCrop([INPUT_SIZE[0], INPUT_SIZE[1]]),
        transforms.ToTensor(),
        transforms.Normalize(mean=RGB_MEAN, std=RGB_STD)
    ])
    val_loaders = []
    for name in val_datasets:
        carray = bcolz.carray(rootdir=os.path.join(args.data_path, name),
                              mode='r')
        # reorder channels (likely BGR -> RGB) and map from [-1, 1] back to [0, 1] for ToPILImage
        val_data_tensor = torch.tensor(carray[:, [2, 1, 0], :, :]) * 0.5 + 0.5
        val_data = TensorsDataset(val_data_tensor, val_transform)
        val_loader = torch.utils.data.DataLoader(val_data,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 sampler=None)
        issame = np.load('{}/{}_list.npy'.format(args.data_path, name))
        val_loaders.append((name, val_loader, issame))

    return train_loader, val_loaders
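
The example above assumes a make_weights_for_balanced_classes helper that is not shown. A minimal sketch of the standard recipe for it (the name and signature come from the call site above; the body is an assumption and the repo's actual helper may differ):

def make_weights_for_balanced_classes(images, num_classes):
    # images is ImageFolder.imgs: a list of (path, class_index) tuples.
    counts = [0] * num_classes
    for _, label in images:
        counts[label] += 1
    total = float(sum(counts))
    # Weight each sample inversely to its class frequency so that
    # WeightedRandomSampler draws classes roughly uniformly.
    class_weights = [total / counts[i] for i in range(num_classes)]
    return [class_weights[label] for _, label in images]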
Example #2
def _force_make_distributed_loader(loader: DataLoader) -> DataLoader:
    """
    Transfers loader to distributed mode. Experimental feature.

    Args:
        loader: pytorch dataloader

    Returns:
        DataLoader: pytorch dataloader with distributed sampler.
    """
    from catalyst.data.sampler import DistributedSamplerWrapper

    # Wrap an existing custom sampler; fall back to a plain DistributedSampler
    # when the loader has no sampler of its own.
    sampler = (DistributedSampler(dataset=loader.dataset)
               if getattr(loader, "sampler", None) is None
               else DistributedSamplerWrapper(sampler=loader.sampler))
    loader = DataLoader(
        dataset=copy(loader.dataset),
        batch_size=loader.batch_size,
        # shuffle=loader.shuffle,
        sampler=sampler,
        # batch_sampler=loader.batch_sampler,
        num_workers=loader.num_workers,
        # collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
    )
    return loader
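
A hedged usage sketch for the function above, assuming torch.distributed.init_process_group() has already been called (the toy dataset and batch size are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
loader = DataLoader(dataset, batch_size=8)
dist_loader = _force_make_distributed_loader(loader)
# dist_loader now shards the data across ranks: a plain loader gets a
# DistributedSampler, a loader with a custom sampler gets a
# DistributedSamplerWrapper around it.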
Example #3
def _get_balanced_train_dataloader(self, dataset, drop_last=False):
    sampler = WeightedRandomSampler(dataset.sample_weights,
                                    len(dataset.sample_weights))
    if is_initialized():
        sampler = DistributedSamplerWrapper(sampler)
    return DataLoader(
        dataset,
        sampler=sampler,
        batch_size=self.datarc["batch_size"],
        drop_last=drop_last,
        num_workers=self.datarc["num_workers"],
        collate_fn=dataset.collate_fn,
    )
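
The pattern in Examples #1 and #3 exists because torch's DistributedSampler can only shard a plain dataset; it cannot compose with another sampler such as WeightedRandomSampler. DistributedSamplerWrapper closes that gap by treating the inner sampler's output as the dataset to shard. A simplified sketch of the idea (not a drop-in replacement for catalyst's implementation, which also resamples the inner sampler each epoch):

from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

class _SamplerAsDataset(Dataset):
    """Expose the indices produced by a sampler as a dataset."""
    def __init__(self, sampler):
        self.indices = list(sampler)
    def __len__(self):
        return len(self.indices)
    def __getitem__(self, i):
        return self.indices[i]

class TinyDistributedSamplerWrapper(DistributedSampler):
    """Shard another sampler's indices across ranks (simplified)."""
    def __init__(self, sampler, **kwargs):
        # kwargs: num_replicas/rank/shuffle; without them an initialized
        # process group is required, as with any DistributedSampler.
        super().__init__(_SamplerAsDataset(sampler), **kwargs)
    def __iter__(self):
        for pos in super().__iter__():  # positions chosen for this rank
            yield self.dataset[pos]     # mapped back to real indices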
Example #4
def _mp_fn(rank, flags):
    device = xm.xla_device()
    net.to(device)

    train_sampler = DistributedSamplerWrapper(sampler=BalanceClassSampler(
        labels=train_dataset.get_labels(), mode="downsampling"),
                                              num_replicas=xm.xrt_world_size(),
                                              rank=xm.get_ordinal(),
                                              shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TrainGlobalConfig.batch_size,
        sampler=train_sampler,
        pin_memory=False,
        drop_last=True,
        num_workers=TrainGlobalConfig.num_workers,
    )
    if rank == 0:
        time.sleep(1)

    fitter = TPUFitter(model=net, device=device, config=TrainGlobalConfig)
    fitter.fit(train_loader)
    if os.path.isfile(args.val_file) and args.val_tune == 1:
        validation_sampler = torch.utils.data.distributed.DistributedSampler(
            validation_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
        validation_loader = torch.utils.data.DataLoader(
            validation_dataset,
            batch_size=TrainGlobalConfig.batch_size,
            sampler=validation_sampler,
            pin_memory=False,
            drop_last=False,
            num_workers=TrainGlobalConfig.num_workers)
        validation_tune_sampler = torch.utils.data.distributed.DistributedSampler(
            validation_tune_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
        validation_tune_loader = torch.utils.data.DataLoader(
            validation_tune_dataset,
            batch_size=TrainGlobalConfig.batch_size,
            sampler=validation_tune_sampler,
            pin_memory=False,
            drop_last=False,
            num_workers=TrainGlobalConfig.num_workers)
        fitter.run_tuning_and_inference(validation_tune_loader)
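
Functions named _mp_fn like the one above are conventionally the per-process entry point for torch_xla's multiprocessing launcher. A hedged launch sketch (the flags payload is whatever the function expects; nprocs=8 assumes an 8-core TPU):

import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == "__main__":
    flags = {}  # hypothetical run configuration
    # "fork" is the usual start method on Colab/Kaggle TPU runtimes.
    xmp.spawn(_mp_fn, args=(flags,), nprocs=8, start_method="fork")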
Example #5
def _change_dl(k, dl, shuffle):
    old_dl = dl
    train_sampler = DistributedSamplerWrapper(
        sampler=BalanceClassSampler(labels=k.train_dataset.get_labels(),
                                    mode="downsampling"),
        num_replicas=8,  # hardcoded instead of xm.xrt_world_size()
        rank=0,  # hardcoded instead of xm.get_ordinal(); every replica sees the same 1/8 of the data
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        k.train_dataset,
        batch_size=k.config.batch_size,
        sampler=train_sampler,
        pin_memory=True,
        drop_last=True,
        num_workers=k.config.num_workers,
    )
    new_dl = train_loader

    return old_dl, new_dl, train_sampler
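
The hardcoded num_replicas=8 / rank=0 above means every process draws from the same shard, which is what the inline comments complain about. A hedged correction, reusing the names from this example and assuming torch_xla is available, queries the runtime instead:

import torch_xla.core.xla_model as xm

train_sampler = DistributedSamplerWrapper(
    sampler=BalanceClassSampler(labels=k.train_dataset.get_labels(),
                                mode="downsampling"),
    num_replicas=xm.xrt_world_size(),  # total number of TPU cores
    rank=xm.get_ordinal(),             # this process's core index
    shuffle=True)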
Example #6
def _get_loader(
    dataset: Dataset,
    sampler: Sampler,
    initial_seed: int,
    params: DictConfig,
) -> DataLoader:
    params = OmegaConf.to_container(params, resolve=True)
    per_gpu_scaling = params.pop("per_gpu_scaling", False)
    params["dataset"] = dataset
    distributed_rank = get_rank()
    distributed = distributed_rank > -1
    if per_gpu_scaling and not distributed:
        num_gpus = max(1, torch.cuda.device_count())
        assert "batch_size" in params, "loader config must contain 'batch_size' key"
        assert "num_workers" in params, "loader config must contain 'num_workers' key"
        params["batch_size"] *= num_gpus
        params["num_workers"] *= num_gpus
    if distributed:
        if sampler is not None:
            if not isinstance(sampler, DistributedSampler):
                sampler = DistributedSamplerWrapper(sampler=sampler)
        else:
            sampler = DistributedSampler(dataset=params["dataset"])
    params["shuffle"] = params.get("shuffle", False) and sampler is None
    params["sampler"] = sampler
    worker_init_fn = params.pop("worker_init_fn", None)
    if worker_init_fn is None:
        params["worker_init_fn"] = lambda x: set_global_seed(initial_seed + x)
    else:
        params["worker_init_fn"] = hydra.utils.get_method(worker_init_fn)
    collate_fn = params.pop("collate_fn", None)
    if collate_fn is None:
        params["collate_fn"] = None
    else:
        params["collate_fn"] = hydra.utils.get_method(collate_fn)
    loader: DataLoader = DataLoader(**params)
    return loader
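
A hedged usage sketch for _get_loader, assuming a non-distributed run (get_rank() returns -1), that OmegaConf is installed, and that the catalyst utilities the function uses (get_rank, set_global_seed) are in scope; the config keys mirror the ones the function pops:

import torch
from omegaconf import OmegaConf
from torch.utils.data import TensorDataset

params = OmegaConf.create({
    "batch_size": 32,
    "num_workers": 2,
    "per_gpu_scaling": True,  # scale batch_size/num_workers by GPU count
    "shuffle": True,
})
dataset = TensorDataset(torch.randn(256, 10))
loader = _get_loader(dataset, sampler=None, initial_seed=42, params=params)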
Example #7
def get_loaders_from_params(
    batch_size: int = 1,
    num_workers: int = 0,
    drop_last: bool = False,
    per_gpu_scaling: bool = False,
    loaders_params: Dict[str, Any] = None,
    samplers_params: Dict[str, Any] = None,
    initial_seed: int = 42,
    get_datasets_fn: Callable = None,
    **data_params,
) -> "OrderedDict[str, DataLoader]":
    """
    Creates pytorch dataloaders from datasets and additional parameters.

    Args:
        batch_size: ``batch_size`` parameter
            from ``torch.utils.data.DataLoader``
        num_workers: ``num_workers`` parameter
            from ``torch.utils.data.DataLoader``
        drop_last: ``drop_last`` parameter
            from ``torch.utils.data.DataLoader``
        per_gpu_scaling: boolean flag,
            if ``True``, uses ``batch_size=batch_size*num_available_gpus``
        loaders_params (Dict[str, Any]): additional loaders parameters
        samplers_params (Dict[str, Any]): additional sampler parameters
        initial_seed: initial seed for ``torch.utils.data.DataLoader``
            workers
        get_datasets_fn(Callable): callable function to get dictionary with
            ``torch.utils.data.Datasets``
        **data_params: additional data parameters
            or dictionary with ``torch.utils.data.Datasets`` to use for
            pytorch dataloaders creation

    Returns:
        OrderedDict[str, DataLoader]: dictionary with
            ``torch.utils.data.DataLoader``

    Raises:
        NotImplementedError: if datasource is out of `Dataset` or dict
        ValueError: if batch_sampler option is mutually
            exclusive with distributed
    """
    from catalyst.data.sampler import DistributedSamplerWrapper

    default_batch_size = batch_size
    default_num_workers = num_workers
    loaders_params = loaders_params or {}
    assert isinstance(
        loaders_params,
        dict), f"`loaders_params` should be a Dict. " f"Got: {loaders_params}"
    samplers_params = samplers_params or {}
    assert isinstance(
        samplers_params,
        dict), f"`samplers_params` should be a Dict. Got: {samplers_params}"

    distributed_rank = get_rank()
    distributed = distributed_rank > -1

    if get_datasets_fn is not None:
        datasets = get_datasets_fn(**data_params)
    else:
        datasets = dict(**data_params)

    loaders = OrderedDict()
    for name, datasource in datasets.items():  # noqa: WPS426
        assert isinstance(
            datasource,
            (Dataset, dict
             )), f"{datasource} should be Dataset or Dict. Got: {datasource}"

        loader_params = loaders_params.pop(name, {})
        assert isinstance(loader_params,
                          dict), f"{loader_params} should be Dict"

        sampler_params = samplers_params.pop(name, None)
        if sampler_params is None:
            if isinstance(datasource, dict) and "sampler" in datasource:
                sampler = datasource.pop("sampler", None)
            else:
                sampler = None
        else:
            sampler = SAMPLER.get_from_params(**sampler_params)
            if isinstance(datasource, dict) and "sampler" in datasource:
                datasource.pop("sampler", None)

        batch_size = loader_params.pop("batch_size", default_batch_size)
        num_workers = loader_params.pop("num_workers", default_num_workers)

        if per_gpu_scaling and not distributed:
            num_gpus = max(1, torch.cuda.device_count())
            batch_size *= num_gpus
            num_workers *= num_gpus

        loader_params = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "pin_memory": torch.cuda.is_available(),
            "drop_last": drop_last,
            **loader_params,
        }

        if isinstance(datasource, Dataset):
            loader_params["dataset"] = datasource
        elif isinstance(datasource, dict):
            assert "dataset" in datasource, "You need to specify dataset for dataloader"
            loader_params = merge_dicts(datasource, loader_params)
        else:
            raise NotImplementedError

        if distributed:
            if sampler is not None:
                if not isinstance(sampler, DistributedSampler):
                    sampler = DistributedSamplerWrapper(sampler=sampler)
            else:
                sampler = DistributedSampler(dataset=loader_params["dataset"])

        loader_params["shuffle"] = name.startswith("train") and sampler is None

        loader_params["sampler"] = sampler

        if "batch_sampler" in loader_params:
            if distributed:
                raise ValueError("batch_sampler option is mutually "
                                 "exclusive with distributed")

            for k in ("batch_size", "shuffle", "sampler", "drop_last"):
                loader_params.pop(k, None)

        if "worker_init_fn" not in loader_params:
            loader_params["worker_init_fn"] = lambda x: set_global_seed(
                initial_seed + x)

        loaders[name] = DataLoader(**loader_params)

    return loaders
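
A hedged usage sketch for get_loaders_from_params in a non-distributed run; only names starting with "train" are shuffled, per the name.startswith("train") rule above:

import torch
from torch.utils.data import TensorDataset

train_ds = TensorDataset(torch.randn(128, 10), torch.randint(0, 2, (128,)))
valid_ds = TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,)))

loaders = get_loaders_from_params(
    batch_size=16,
    num_workers=2,
    train=train_ds,  # datasets arrive via **data_params
    valid=valid_ds,
)
# loaders["train"] shuffles each epoch; loaders["valid"] keeps dataset order.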
Example #8
    def _mp_fn(rank, flags, k=k):
        device = xm.xla_device(devkind='TPU')
        logger.debug("%s used for xla_device" % device)
        net = k.model
        net.to(device)
        logger.debug("%s used for xla_device, to device done" % device)

        train_sampler = DistributedSamplerWrapper(
            sampler=BalanceClassSampler(labels=k.train_dataset.get_labels(), mode="downsampling"),
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True
        )
        train_loader = torch.utils.data.DataLoader(
            k.train_dataset,
            batch_size=TrainGlobalConfig.batch_size,
            sampler=train_sampler,
            pin_memory=False,
            drop_last=True,
            num_workers=TrainGlobalConfig.num_workers,
        )
        validation_sampler = torch.utils.data.distributed.DistributedSampler(
            k.validation_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False
        )
        validation_loader = torch.utils.data.DataLoader(
            k.validation_dataset,
            batch_size=TrainGlobalConfig.batch_size,
            sampler=validation_sampler,
            pin_memory=False,
            drop_last=False,
            num_workers=TrainGlobalConfig.num_workers
        )
        validation_tune_sampler = torch.utils.data.distributed.DistributedSampler(
            k.validation_tune_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True
        )
        validation_tune_loader = torch.utils.data.DataLoader(
            k.validation_tune_dataset,
            batch_size=TrainGlobalConfig.batch_size,
            sampler=validation_tune_sampler,
            pin_memory=False,
            drop_last=False,
            num_workers=TrainGlobalConfig.num_workers
        )
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            k.test_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False
        )
        test_loader = torch.utils.data.DataLoader(
            k.test_dataset,
            batch_size=TrainGlobalConfig.batch_size,
            sampler=test_sampler,
            pin_memory=False,
            drop_last=False,
            num_workers=TrainGlobalConfig.num_workers
        )

        logger.debug("rank: %d. Will create TPU Fitter", rank)

        if rank == 0:
            time.sleep(1)

        fitter = TPUFitter(model=net, device=device, config=TrainGlobalConfig)
        fitter.fit(train_loader, validation_loader)
        fitter.run_tuning_and_inference(test_loader, validation_tune_loader)
Example #9
def train_model(data, fold_no, log=False):
    seed_everything(FLAGS["seed"])

    def get_datasets(data):
        X_train, y_train, X_val, y_val = data
        datasets = {}
        datasets["train"] = MelanomaDataset(X_train,
                                            y_train,
                                            istrain=True,
                                            transforms=get_train_transforms())
        datasets["valid"] = MelanomaDataset(X_val,
                                            y_val,
                                            istrain=False,
                                            transforms=get_valid_transforms())
        return datasets

    datasets = SERIAL_EXEC.run(lambda: get_datasets(data))

    if xm.is_master_ordinal() and log:  # is_master_ordinal is a function and must be called
        writer = SummaryWriter()
        # writer.add_hparams(FLAGS)

    # re-unpack: the y_train inside get_datasets was local to that closure
    X_train, y_train, X_val, y_val = data
    labels_vcount = y_train["target"].value_counts()
    class_counts = [
        labels_vcount[0].astype(np.float32),
        labels_vcount[1].astype(np.float32),
    ]
    num_samples = sum(class_counts)
    class_weights = [
        num_samples / class_counts[i] for i in range(len(class_counts))
    ]
    weights = [
        class_weights[y_train["target"].values[i]]
        for i in range(int(num_samples))
    ]
    wrsampler = WeightedRandomSampler(torch.DoubleTensor(weights),
                                      int(num_samples))
    # BalanceClassSampler(labels=y_train['target'].values, mode="downsampling"),

    # DistributedSamplerWrapper
    train_sampler = DistributedSamplerWrapper(
        sampler=wrsampler,  # sampler=wrsampler,# datasets['train'],
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
    )
    validation_sampler = DistributedSampler(
        datasets["valid"],
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False,
    )
    train_loader = DataLoader(
        datasets["train"],
        batch_size=FLAGS["batch_size"],
        num_workers=FLAGS["num_workers"],
        sampler=train_sampler,
        drop_last=True,
    )
    val_loader = DataLoader(
        datasets["valid"],
        batch_size=FLAGS["batch_size"],
        num_workers=FLAGS["num_workers"],
        sampler=validation_sampler,
        drop_last=True,
    )

    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=FLAGS["learning_rate"] * xm.xrt_world_size(),
        weight_decay=FLAGS["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=0.5,
        cooldown=1,
        mode="min",
        patience=2,
        verbose=True,
        min_lr=1e-8,
    )

    criterion = sigmoid_focal_loss

    def train_one_epoch(loader):
        model.train()
        running_loss = 0
        max_idx = 0
        xm.master_print("-" * 40)
        xm.master_print("Step\t|\tTime")
        xm.master_print("-" * 40)
        for idx, (images, targets) in enumerate(loader):
            optimizer.zero_grad()
            y_pred = model(images.float())
            loss = criterion(y_pred, targets)
            running_loss += float(loss)
            loss.backward()
            xm.optimizer_step(optimizer)
            # call xm.mark_step() every step when doing gradient accumulation
            max_idx = float(idx)
            if idx % FLAGS["log_steps"] == 0 and idx != 0:
                xm.master_print("({})\t|\t{}".format(
                    idx, time.asctime(time.localtime())))
        xm.master_print("-" * 40)
        return running_loss / (max_idx + 1)

    def val_one_epoch(loader):
        model.eval()
        running_loss = 0
        max_idx = 0
        roc_auc_scores = RocAucMeter()
        with torch.no_grad():
            for idx, (images, targets) in enumerate(loader):
                y_pred = model(images.float())
                loss = criterion(y_pred, targets)
                running_loss += float(loss)
                max_idx = float(idx)
                roc_auc_scores.update(targets, y_pred)  # [:, 1]
        score = roc_auc_scores.avg
        return running_loss / (max_idx + 1), score

    def _reduce_fn(x):
        return np.array(x).mean()

    best_score = 0
    xm.master_print("=" * 26 + f"Fold #{fold_no} started" + "=" * 27)
    for epoch in range(0, FLAGS["num_epochs"]):
        xm.master_print("-" * 26 + f"Epoch #{epoch+1} started" + "-" * 26)
        xm.master_print(f"Epoch start {time.asctime(time.localtime())}")
        train_start = time.time()
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loss = train_one_epoch(para_loader.per_device_loader(device))
        xm.master_print(f"finished training epoch {epoch+1}")
        elapsed_time = int(time.time() - train_start)
        xm.master_print(
            f"elapsed time: {(elapsed_time)//60}mins {(elapsed_time)%60}s")
        reduced_loss = xm.mesh_reduce("train_loss", train_loss, _reduce_fn)
        xm.master_print(f"reduced loss {reduced_loss:.5f}")
        if xm.is_master_ordinal() and log:
            writer.add_scalar("train/loss", reduced_loss, epoch + 1)

        if (epoch + 1) % FLAGS["val_freq"] == 0:
            val_start = time.time()
            para_loader = pl.ParallelLoader(val_loader, [device])
            val_loss, auc_score = val_one_epoch(
                para_loader.per_device_loader(device))
            xm.master_print(f"finished validating epoch {epoch+1}")
            reduced_val_loss = xm.mesh_reduce("val_loss", val_loss, _reduce_fn)
            reduced_auc_score = xm.mesh_reduce("auc_score", auc_score,
                                               _reduce_fn)
            scheduler.step(reduced_val_loss)
            xm.master_print(f"reduced val loss {reduced_val_loss:.5f}")
            xm.master_print(f"reduced auc score {reduced_auc_score:.5f}")
            val_elapsed_time = int(time.time() - val_start)
            xm.master_print(
                f"elapsed time: {(val_elapsed_time)//60}mins {(val_elapsed_time)%60}s"
            )
            if xm.is_master_ordinal() and log:
                writer.add_scalar("val/loss", reduced_val_loss, epoch + 1)
                writer.add_scalar("val/roc_auc", reduced_auc_score, epoch + 1)
            if (best_score < reduced_auc_score
                    or (best_score - reduced_auc_score) < 0.005):
                best_score = reduced_auc_score
                file_name = f"./{FLAGS['exp_name']}_fold_{fold_no+1}_epoch_{epoch+1}_auc_{reduced_auc_score:.5f}.pth"
                xm.save(model.state_dict(), file_name)
                xm.master_print(f"saved model...")
                xm.master_print(f"new best score: {best_score:.5f}")
                # xser.save(model.state_dict(), file_name, master_only=True)

        xm.master_print(f"Epoch end {time.asctime(time.localtime())}")
        xm.master_print("-" * 27 + f"Epoch #{epoch+1} ended" + "-" * 26)

    xm.master_print("=" * 28 + f"Fold #{fold_no} ended" + "=" * 27)
Example #10
def get_loaders_from_params(
    batch_size: int = 1,
    num_workers: int = 0,
    drop_last: bool = False,
    per_gpu_scaling: bool = False,
    loaders_params: Dict[str, Any] = None,
    samplers: "OrderedDict[str, Sampler]" = None,
    datasets: "OrderedDict[str, Union[Dataset, dict]]" = None,
    initial_seed: int = 42,
) -> "OrderedDict[str, DataLoader]":
    """
    Creates pytorch dataloaders from datasets and additional parameters.

    Args:
        batch_size: ``batch_size`` parameter
            from ``torch.utils.data.DataLoader``
        num_workers: ``num_workers`` parameter
            from ``torch.utils.data.DataLoader``
        drop_last: ``drop_last`` parameter
            from ``torch.utils.data.DataLoader``
        per_gpu_scaling: boolean flag,
            if ``True``, scales batch_size in proportion to the number of GPUs
        loaders_params: additional loaders parameters
        samplers: additional sampler parameters
        initial_seed: initial seed for ``torch.utils.data.DataLoader``
            workers
        datasets: ordered dictionary with ``torch.utils.data.Dataset``

    Returns:
        OrderedDict[str, DataLoader]: dictionary with
            ``torch.utils.data.DataLoader``

    Raises:
        NotImplementedError: if datasource is out of ``Dataset`` or dict
        ValueError: if batch_sampler option is mutually
            exclusive with distributed
    """
    from catalyst.data.sampler import DistributedSamplerWrapper

    default_batch_size = batch_size
    default_num_workers = num_workers
    loaders_params = copy.deepcopy(loaders_params) or {}
    assert isinstance(loaders_params,
                      dict), (f"`loaders_params` should be a Dict. "
                              f"Got: {loaders_params}")
    samplers = copy.deepcopy(samplers) or {}
    assert isinstance(samplers,
                      dict), f"`samplers` should be a Dict. Got: {samplers}"
    datasets = datasets if datasets is not None else {}

    distributed_rank = get_rank()
    distributed = distributed_rank > -1

    loaders = OrderedDict()
    for name, datasource in datasets.items():  # noqa: WPS426
        assert isinstance(
            datasource,
            (Dataset, dict
             )), f"{datasource} should be Dataset or Dict. Got: {datasource}"

        loader_params = loaders_params.pop(name, {})
        assert isinstance(loader_params,
                          dict), f"{loader_params} should be Dict"

        sampler: Sampler = None
        if isinstance(datasource, dict) and "sampler" in datasource:
            sampler = datasource.pop("sampler", None)
        sampler = samplers.pop(name, sampler)

        batch_size = loader_params.pop("batch_size", default_batch_size)
        num_workers = loader_params.pop("num_workers", default_num_workers)

        if per_gpu_scaling and not distributed:
            num_gpus = max(1, torch.cuda.device_count())
            batch_size *= num_gpus
            num_workers *= num_gpus
        elif not per_gpu_scaling and distributed:
            world_size = get_distributed_params().pop("world_size", 1)
            if batch_size % world_size == 0:
                batch_size = int(batch_size / world_size)
            else:
                raise ValueError(
                    "For this distributed mode with per_gpu_scaling = False "
                    "you need to have batch_size divisible by number of GPUs")

        loader_params = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "pin_memory": torch.cuda.is_available(),
            "drop_last": drop_last,
            **loader_params,
        }

        if isinstance(datasource, Dataset):
            loader_params["dataset"] = datasource
        elif isinstance(datasource, dict):
            assert "dataset" in datasource, "You need to specify dataset for dataloader"
            loader_params = merge_dicts(datasource, loader_params)
        else:
            raise NotImplementedError

        if distributed:
            if sampler is not None:
                if not isinstance(sampler, DistributedSampler):
                    sampler = DistributedSamplerWrapper(sampler=sampler)
            else:
                sampler = DistributedSampler(dataset=loader_params["dataset"])

        loader_params["shuffle"] = name.startswith("train") and sampler is None

        loader_params["sampler"] = sampler

        if "batch_sampler" in loader_params:
            if distributed:
                raise ValueError("batch_sampler option is mutually "
                                 "exclusive with distributed")

            for k in ("batch_size", "shuffle", "sampler", "drop_last"):
                loader_params.pop(k, None)

        if "worker_init_fn" not in loader_params:
            loader_params["worker_init_fn"] = partial(
                _worker_init_fn, initial_seed=initial_seed)

        loaders[name] = DataLoader(**loader_params)

    return loaders
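
Unlike Example #7, which builds worker_init_fn as a lambda, Example #10 references a module-level _worker_init_fn through functools.partial; lambdas cannot be pickled, which matters when DataLoader workers are spawned rather than forked. A minimal sketch matching the call site, with the body mirroring the lambda in Example #7 (set_global_seed is the catalyst utility assumed in scope there):

from functools import partial

def _worker_init_fn(worker_id: int, initial_seed: int) -> None:
    # Seed each DataLoader worker deterministically.
    set_global_seed(initial_seed + worker_id)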