def _create_mapping_loader(config, dataset_class, partitions):
    """Build a non-shuffled dataloader over the given test partitions.

    Each split name in ``partitions`` is instantiated through
    ``dataset_class`` with ``purpose="test"``. When
    ``config.use_doersch_datasets`` is set, each one is wrapped in
    ``DoerschDataset``. The results are concatenated in the order given.

    Args:
        config: Experiment config; reads ``use_doersch_datasets`` and
            ``batch_sz``.
        dataset_class: Callable accepting ``config``, ``split`` and
            ``purpose`` keyword arguments.
        partitions: Iterable of split names to include.

    Returns:
        A ``torch.utils.data.DataLoader`` over the concatenated datasets
        with batch size ``config.batch_sz``, no shuffling, no worker
        processes, and the final (possibly short) batch kept.
    """

    def _make_test_dataset(split):
        # "test" purpose: the dataset returns testing tuples (image, label).
        dataset = dataset_class(config=config, split=split, purpose="test")
        if config.use_doersch_datasets:
            dataset = DoerschDataset(config, dataset)
        return dataset

    combined = ConcatDataset([_make_test_dataset(split) for split in partitions])

    # Shuffling is pointless here because this data is not trained on.
    return torch.utils.data.DataLoader(
        combined,
        batch_size=config.batch_sz,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )


def _create_dataloaders(config, dataset_class):
    """Build ``config.num_dataloaders`` training dataloaders of equal length.

    Unlike in clustering, each dataloader here returns pairs of images - we
    need the matrix relation between them. Every dataloader gets its own
    freshly constructed datasets over ``config.train_partitions`` (built with
    ``purpose="train"``, optionally wrapped in ``DoerschDataset``).

    Shuffling is enabled only when there is exactly one dataloader. With
    several, they stay unshuffled, presumably so the training loop can zip
    them in lockstep (NOTE(review): confirm against the caller). All
    dataloaders must yield the same number of batches.

    Args:
        config: Experiment config; reads ``num_dataloaders``,
            ``train_partitions``, ``use_doersch_datasets`` and
            ``dataloader_batch_sz``.
        dataset_class: Callable accepting ``config``, ``split`` and
            ``purpose`` keyword arguments.

    Returns:
        A list of ``config.num_dataloaders`` ``torch.utils.data.DataLoader``
        objects.

    Raises:
        ValueError: If ``config.num_dataloaders`` is less than 1, or if the
            dataloaders do not all have the same number of batches.
    """
    num_dataloaders = config.num_dataloaders
    if num_dataloaders < 1:
        # Without this guard the function would fail later with an opaque
        # IndexError on ``dataloaders[0]``.
        raise ValueError(
            "config.num_dataloaders must be at least 1, got %r"
            % (num_dataloaders,)
        )

    do_shuffle = num_dataloaders == 1
    dataloaders = []
    for d_i in range(num_dataloaders):
        print(
            "Creating paired dataloader %d out of %d time %s"
            % (d_i, num_dataloaders, datetime.now())
        )
        sys.stdout.flush()

        train_dataloader = torch.utils.data.DataLoader(
            _build_paired_train_dataset(config, dataset_class),
            batch_size=config.dataloader_batch_sz,
            shuffle=do_shuffle,
            num_workers=0,
            drop_last=False,
        )

        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``. Comparing against the first loader suffices because
        # every earlier loader already matched it.
        if dataloaders and len(train_dataloader) != len(dataloaders[0]):
            raise ValueError(
                "dataloader %d has %d batches but dataloader 0 has %d; "
                "paired dataloaders must have equal lengths"
                % (d_i, len(train_dataloader), len(dataloaders[0]))
            )

        dataloaders.append(train_dataloader)

    num_train_batches = len(dataloaders[0])
    print("Length of paired datasets vector %d" % len(dataloaders))
    print("Number of batches per epoch: %d" % num_train_batches)
    sys.stdout.flush()

    return dataloaders


def _build_paired_train_dataset(config, dataset_class):
    """Concatenate fresh "train"-purpose datasets for every training partition.

    The "train" purpose returns training tuples, not including labels.
    """
    train_imgs_list = []
    for train_partition in config.train_partitions:
        train_imgs_curr = dataset_class(
            config=config, split=train_partition, purpose="train"
        )
        if config.use_doersch_datasets:
            train_imgs_curr = DoerschDataset(config, train_imgs_curr)
        train_imgs_list.append(train_imgs_curr)
    return ConcatDataset(train_imgs_list)