Example #1
    def add_query(self):
        print('\nSelect queries... (query type: {})'.format(
            self.args.query_type))
        start_time = time.time()
        self.netG.eval()
        self.netD.eval()

        # get query & update train/pool
        if self.args.query_type == 'random':
            query_idx = np.random.permutation(len(
                self.pool))[:self.args.per_size]
        elif self.args.query_type == 'gold':
            query_idx = query.gold_acquistiion(self.pool, self.netD, self.args,
                                               self.device)
        else:
            raise NotImplementedError

        self.train_idx = list(set(self.train_idx) | set(query_idx))
        self.pool_idx = list(set(self.pool_idx) - set(query_idx))

        self.trainset = data_utils.Subset(self.base_dataset, self.train_idx)
        self.pool = data_utils.Subset(self.base_dataset, self.pool_idx)

        # print computation time
        query_time = int(time.time() - start_time)
        print('{:d}s elapsed'.format(query_time))
Example #2
def get_train_valid_cifar10_dataloader(data_dir,
                                       batch_size=100,
                                       train_portion=0.99):
    trans = get_img_tranformation()
    full_dataset = torchvision.datasets.CIFAR10(root=data_dir,
                                                train=True,
                                                download=False,
                                                transform=trans)
    num_train = len(full_dataset)
    train_size = int(train_portion * num_train)
    valid_size = len(full_dataset) - train_size
    # split the dataset into train and validation sets non-randomly
    idxs = list(range(num_train))
    train_idxs = idxs[valid_size:]
    valid_idxs = idxs[:valid_size]
    train_set = data.Subset(full_dataset, train_idxs)
    valid_set = data.Subset(full_dataset, valid_idxs)

    trainloader = data.DataLoader(train_set,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=2)
    validloader = data.DataLoader(valid_set,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=2)

    return trainloader, validloader
Example #3
def get_AwA2(augment, dataroot):
    num_classes = 85
    image_shape = (64, 64, 3)
    test_transform = transforms.Compose(
        [transforms.Resize((64, 64)),
         transforms.ToTensor()])
    if augment:
        transformations = [
            transforms.Resize((64, 64)),
            transforms.RandomCrop(size=64, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip()
        ]
    else:
        transformations = [transforms.Resize((64, 64))]
    transformations.extend([transforms.ToTensor()])
    train_transform = transforms.Compose(transformations)

    path = Path(dataroot) / 'data' / 'AwA2' / 'Animals_with_Attributes2'
    dataset = AwA2(path,
                   train_transform=train_transform,
                   test_transform=test_transform)

    train_dataset = data.Subset(dataset, list(dataset.train_idxs))
    test_dataset = data.Subset(dataset, list(dataset.test_idxs))
    pred_bin_mat = torch.tensor(dataset.pred_bin_mat).float()
    return image_shape, num_classes, train_dataset, test_dataset, pred_bin_mat
Example #4
def split(self, condition):
    """
    Splits the dataset according to the given boolean condition. When :code:`pyblaze.nn` is
    imported, this method is available on all :code:`torch.utils.data.Dataset` objects.

    Attention
    ---------
    Do not call this method on iterable datasets.

    Parameters
    ----------
    condition: callable (object) -> bool
        The condition which splits the dataset.

    Returns
    -------
    torch.utils.data.Subset
        The dataset with the items for which the condition evaluated to `true`.
    torch.utils.data.Subset
        The dataset with the items for which the condition evaluated to `false`.
    """
    filter_ = np.array([condition(item) for item in self])
    true_indices = np.where(filter_)[0]
    false_indices = np.where(~filter_)[0]
    return data.Subset(self, true_indices), data.Subset(self, false_indices)
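A minimal usage sketch for the split helper above. It assumes pyblaze.nn has been imported (which, per the docstring, attaches split to every torch.utils.data.Dataset); the toy TensorDataset and label condition are illustrative only.

import torch
from torch.utils import data
import pyblaze.nn  # importing attaches .split() to Dataset objects (per the docstring above)

# toy dataset of (x, y) pairs; the condition is evaluated on each item in turn
xs = torch.randn(100, 3)
ys = torch.randint(0, 2, (100,))
dataset = data.TensorDataset(xs, ys)

positives, negatives = dataset.split(lambda item: item[1].item() == 1)
print(len(positives), len(negatives))  # items with label 1 vs. label 0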
Example #5
    def init_data(self):
        print('Initialize dataset...')
        self.train_transform = data.get_transform(self.args.image_size,
                                                  self.args.train_transform)
        self.test_transform = data.get_transform(self.args.image_size,
                                                 self.args.test_transform)

        # load base dataset
        self.base_dataset, self.test_dataset = data.load_base_dataset(
            self.args)
        self.base_dataset.transform = self.train_transform
        self.test_dataset.transform = self.test_transform

        # split to train/val/pool set
        if self.args.init_size is None:
            self.train_idx = list(range(len(self.base_dataset)))
            self.val_idx = []
            self.pool_idx = []
            self.args.init_size = len(self.base_dataset)
            self.args.per_size = 0
            self.args.max_size = len(self.base_dataset)
        else:
            self.train_idx, self.val_idx, self.pool_idx = data.split_dataset(
                self.base_dataset, self.args.ny, self.args.init_size,
                self.args.val_size)

        if self.args.max_size is None:
            self.args.per_size = 0
            self.args.max_size = self.args.init_size

        # define trainset and pool
        self.trainset = data_utils.Subset(self.base_dataset, self.train_idx)
        self.valset = data_utils.Subset(self.base_dataset, self.val_idx)
        self.pool = data_utils.Subset(self.base_dataset, self.pool_idx)
Example #6
    def __init__(self, context, dataset, criterion):
        self.ctx = context
        self.criterion = criterion

        ds_size = len(dataset)
        test_count = ds_size // 10
        train_count = ds_size - test_count
        test_indices = torch.arange(start=0, end=test_count)
        train_indices = torch.arange(start=test_count, end=ds_size)
        self.train_dataset = td.Subset(dataset, indices=train_indices)
        self.test_dataset = td.Subset(dataset, indices=test_indices)

        log_dir = os.path.join(
            'runs',
            self.ctx.params['meta_name'],
        )
        log_path = os.path.join(
            log_dir,
            'run_{}'.format(datetime.datetime.now().isoformat()),
        )
        if not os.path.exists('runs'):
            os.mkdir('runs')
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)
        print('Writing Tensorboard logs to {}'.format(log_path))
        self.writer = SummaryWriter(log_path)

        print('Writing graph')
        dummy_input = torch.empty(size=self.ctx.model_shape,
                                  dtype=torch.float32).to(self.ctx.device)
        self.writer.add_graph(self.ctx.net, dummy_input)
        del dummy_input
        print('done')
Example #7
    def get_dataloader(self, split, verbose=0):
        if split == 'all_train':
            dataset = self.get_dataset(split='train')
            shuffle = True
            if self.unlab_samples_per_class != -1:
                labels = torch.tensor([y for x, y in dataset])
                indices = torch.arange(len(labels))
                indices = torch.cat([indices[labels == x][:self.unlab_samples_per_class] for x in torch.unique(labels)])
                dataset = data.Subset(dataset, indices)

        elif split == 'lab_train':
            dataset = self.get_dataset(split='train')
            shuffle = True
            if self.samples_per_class != -1:
                labels = torch.tensor([y for x, y in dataset])
                indices = torch.arange(len(labels))
                indices = torch.cat([indices[labels == x][:self.samples_per_class] for x in torch.unique(labels)])
                dataset = data.Subset(dataset, indices)

        elif split == 'test':
            dataset = self.get_dataset(split='test')
            shuffle = False

        elif split == 'valid':
            dataset = self.get_dataset(split='valid')
            shuffle = False

        else:
            raise ValueError('unknown split: {}'.format(split))

        dataloader = data.DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, sampler=None, batch_sampler=None, num_workers=16)

        if verbose > 0:
            print(split, len(dataloader))
            labels = torch.cat([y for x, y in dataloader])
            labels = labels.numpy()
            print(np.unique(labels, return_counts=True))
        return dataloader
Example #8
def train_data(bs):
    """Get data loader for trainning & validating, bs means batch_size."""

    train_ds = VideoZoomDataset(train_dataset_rootdir, VIDEO_SEQUENCE_LENGTH,
                                get_transform(train=True))
    print(train_ds)

    # Split train_ds in train and valid set
    valid_len = int(0.2 * len(train_ds))
    indices = [i for i in range(len(train_ds) - valid_len, len(train_ds))]

    valid_ds = data.Subset(train_ds, indices)
    indices = [i for i in range(len(train_ds) - valid_len)]
    train_ds = data.Subset(train_ds, indices)

    # Define training and validation data loaders
    n_threads = min(4, bs)
    train_dl = data.DataLoader(train_ds,
                               batch_size=bs,
                               shuffle=True,
                               num_workers=n_threads)
    valid_dl = data.DataLoader(valid_ds,
                               batch_size=bs,
                               shuffle=False,
                               num_workers=n_threads)

    return train_dl, valid_dl
Example #9
def train_test_split(dset):
    length = len(dset)

    n_train = int(length * 0.6)
    n_test = int(length * 0.1)
    idx = list(range(length))

    train_idx = idx[:n_train]
    val_idx = idx[n_train:(n_train + n_test)]
    test_idx = idx[(n_train + n_test):]

    train_set = data.Subset(dset, train_idx)
    val_set = data.Subset(dset, val_idx)
    test_set = data.Subset(dset, test_idx)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=100,
                                               shuffle=False)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=100,
                                             shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=100,
                                              shuffle=False)

    #print(len(train_loader), len(val_loader), len(test_loader))
    return train_loader, val_loader, test_loader
Example #10
def main(opt):
    model = LSTM(opt, batch_first=True, dropout=opt.dropout)
    if opt.pre_train:
        model.load_state_dict(torch.load(opt.save_path))
    optimizer = optim.Adam(model.parameters(), opt.learning_rate)
    mseloss = nn.MSELoss()

    dataset = PowerDataset(opt,
                           prepocess_path=opt.prepocess_path,
                           transform=transforms.Compose(
                               [transforms.ToTensor()]))
    train_dataset = data.Subset(dataset, indices=range(8664))
    test_dataset = data.Subset(dataset, indices=range(8664, len(dataset)))
    train_dataloader = data.dataloader.DataLoader(train_dataset,
                                                  num_workers=opt.n_threads,
                                                  batch_size=opt.batch_size,
                                                  shuffle=True)
    test_sampler = data.SequentialSampler(test_dataset)
    test_dataloader = data.dataloader.DataLoader(
        test_dataset,
        num_workers=opt.n_threads,
        batch_size=opt.test_batch_size,
        shuffle=False,
        sampler=test_sampler)

    for e in range(opt.epochs):
        if opt.test_only:
            test(model, test_dataloader)
            break
        print('epoch: ', e)
        train(model, mseloss, optimizer, train_dataloader)
        test(model, test_dataloader)
        torch.save(model.state_dict(), opt.save_path)
Example #11
def train_data(bs):
    """Get data loader for trainning & validating, bs means batch_size."""

    train_ds = ImageColorDataset(train_dataset_rootdir,
                                 get_transform(train=True))
    print(train_ds)

    # Split train_ds in train and valid set
    valid_len = int(0.2 * len(train_ds))
    indices = [i for i in range(len(train_ds) - valid_len, len(train_ds))]

    valid_ds = data.Subset(train_ds, indices)
    indices = [i for i in range(len(train_ds) - valid_len)]
    train_ds = data.Subset(train_ds, indices)

    # Define training and validation data loaders
    train_dl = data.DataLoader(train_ds,
                               batch_size=bs,
                               shuffle=True,
                               num_workers=4)
    valid_dl = data.DataLoader(valid_ds,
                               batch_size=bs,
                               shuffle=False,
                               num_workers=4)

    return train_dl, valid_dl
Example #12
def get_loader(BATCH_SIZE):
    # use two dataset instances so the train/val transforms don't overwrite
    # each other (both Subsets would otherwise share one underlying dataset)
    train_base = Img_data()
    val_base = Img_data()

    indices = list(range(len(train_base)))

    train_data = data.Subset(train_base, indices[:15120])
    val_data = data.Subset(val_base, indices[15120:])

    transform = get_transforms()

    train_data.dataset.set_transform(transform["train"])
    val_data.dataset.set_transform(transform["val"])

    train_loader = data.DataLoader(train_data,
                                   batch_size=BATCH_SIZE,
                                   num_workers=2,
                                   shuffle=True,
                                   pin_memory=True)

    val_loader = data.DataLoader(val_data,
                                 batch_size=BATCH_SIZE,
                                 num_workers=2,
                                 shuffle=False,
                                 pin_memory=True)

    return train_loader, val_loader
Example #13
    def _split_train_val_create_loaders(self, dataset, train_idxs, test_idxs
                                        ) -> Tuple[torch_data.DataLoader, torch_data.DataLoader]:
        train_dataset = torch_data.Subset(dataset, train_idxs)
        val_dataset = torch_data.Subset(dataset, test_idxs)
        train_loader = self._create_loader_with_answers(train_dataset)
        test_loader = self._create_loader_with_answers(val_dataset)
        return train_loader, test_loader
Example #14
def split(full_dataset, n_features, n_attributes):
    """Split dataset making sure that each symbol is represented with equal frequency"""
    assert n_attributes == 2, 'Only implemented for 2 attrs'
    first_dim_indices, second_dim_indices = list(range(n_features)), list(range(n_features))
    random.shuffle(second_dim_indices)
    test_indices = [a * n_features + b for a, b in zip(first_dim_indices, second_dim_indices)]
    train_indices = [i for i in range(n_features * n_features) if i not in test_indices]
    return data.Subset(full_dataset, train_indices), data.Subset(full_dataset, test_indices)
Example #15
def train_test_split_dataset(full_dataset, test_size=0.2):
    """ Splits the given dataset in a train and test set using stratified sampling """
    full_indices = np.arange(len(full_dataset))
    full_targets = np.array([target for _, target in full_dataset.samples])

    train_indices, test_indices = train_test_split(full_indices, test_size=test_size, random_state=41,
                                                   stratify=full_targets)
    return data.Subset(full_dataset, train_indices), data.Subset(full_dataset, test_indices)
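A usage sketch for the stratified split above, assuming full_dataset is a torchvision ImageFolder (or any dataset with a .samples list of (path, target) pairs, which the stratify step relies on); the directory path is hypothetical.

from torchvision import datasets, transforms

folder = datasets.ImageFolder('data/my_images',  # hypothetical image directory
                              transform=transforms.ToTensor())
train_set, test_set = train_test_split_dataset(folder, test_size=0.2)
print(len(train_set), len(test_set))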
Example #16
    def _fixed_split(self, dataset, ratio=0.5):
        assert isinstance(dataset, data.dataset.Dataset)
        n_total = len(dataset)
        thres = int(n_total * ratio)
        id_a = range(len(dataset))[:thres]
        id_b = range(len(dataset))[thres:]
        data_a = data.Subset(dataset, id_a)
        data_b = data.Subset(dataset, id_b)
        return data_a, data_b
Example #17
    def prepare_data(self):
        dataset = SplitDataset(self.text_dataset, self.labeler, 500, 800, 20)

        train_indices, valid_indices = train_test_split(
            np.arange(len(dataset)),
            test_size=self.hparams.test_size,
            random_state=1234)
        self.train_dataset = data.Subset(dataset, train_indices)
        self.valid_dataset = data.Subset(dataset, valid_indices)
Example #18
def load_mnist(args, **kwargs):
    """
    Dataloading function for mnist. Outputs image data in vectorized form: each image is a vector of size 784
    """
    args.dynamic_binarization = False
    args.input_type = "binary"

    flatten = kwargs.get("flatten", False)

    # start processing
    transforms_list = [
        tv.transforms.ToTensor(),
        # tv.transforms.Normalize((0.5,), (0.5,)),
    ]
    args.xdim = (28, 28)
    if flatten:
        transforms_list.append(tv.transforms.Lambda(lambda x: x.view(-1)))
        args.xdim = (784,)
    preprocess = tv.transforms.Compose(transforms_list)
    preprocess = kwargs.get("preprocess", preprocess)

    train_dataset = tv.datasets.MNIST(
        "./data", transform=preprocess, download=True, train=True
    )
    train_len = len(train_dataset)

    if args.val_frac:
        train_indices, validation_indices = train_test_split(
            np.arange(train_len), test_size=args.val_frac, random_state=args.manual_seed
        )
        train_indices = train_indices.tolist()
        validation_indices = validation_indices.tolist()
    else:
        train_indices = np.arange(train_len).tolist()
        validation_indices = []

    train = utdata.Subset(train_dataset, train_indices)
    validation = utdata.Subset(train_dataset, validation_indices)

    # pytorch data loader
    train_loader = data_utils.DataLoader(
        train, batch_size=args.batch_size, shuffle=True, pin_memory=args.pin_memory
    )

    val_loader = data_utils.DataLoader(
        validation,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=args.pin_memory,
    )

    test = tv.datasets.MNIST("./data", transform=preprocess, download=True, train=False)
    test_loader = data_utils.DataLoader(
        test, batch_size=args.batch_size, shuffle=False, pin_memory=args.pin_memory
    )

    return train_loader, val_loader, test_loader, args
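A sketch of the minimal args object this loader expects; argparse.Namespace is used for convenience, and the field values are arbitrary.

from argparse import Namespace

args = Namespace(val_frac=0.1, manual_seed=0, batch_size=128, pin_memory=False)
train_loader, val_loader, test_loader, args = load_mnist(args)
print(args.xdim, len(train_loader.dataset), len(val_loader.dataset))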
Example #19
def k_fold_split_dataset(full_dataset, n_splits=5):
    """ Splits the dataset `n_splits` times using CV and yields tuples (train_dataset, test_dataset) """
    skf = StratifiedKFold(n_splits, random_state=42, shuffle=True)

    full_targets = np.array([target for _, target in full_dataset.samples])
    full_inputs = np.zeros(len(full_dataset))

    for train_indices, test_indices in skf.split(full_inputs, full_targets):
        yield data.Subset(full_dataset, train_indices), data.Subset(full_dataset, test_indices)
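A sketch of how the fold generator above could be consumed, building one DataLoader pair per fold; full_dataset is assumed to expose .samples as in the stratified example earlier, and the batch size is arbitrary.

from torch.utils import data

for fold, (train_ds, test_ds) in enumerate(k_fold_split_dataset(full_dataset, n_splits=5)):
    train_loader = data.DataLoader(train_ds, batch_size=64, shuffle=True)
    test_loader = data.DataLoader(test_ds, batch_size=64, shuffle=False)
    print('fold', fold, len(train_ds), len(test_ds))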
Example #20
    def __init__(self, dataset: data.Dataset):
        super().__init__()
        self.dataset = dataset
        self.training_mask = np.full((len(dataset), ), False)
        self.pool_mask = np.full((len(dataset), ), True)

        # placeholder Subsets; _update_indices() fills in the real index lists
        self.training_dataset = data.Subset(self.dataset, None)
        self.pool_dataset = data.Subset(self.dataset, None)

        self._update_indices()
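The source does not show _update_indices; a plausible sketch, assuming it simply refreshes the two Subset index lists from the boolean masks:

    def _update_indices(self):
        # assumption: rebuild the Subset index lists from the current masks
        self.training_dataset.indices = np.nonzero(self.training_mask)[0].tolist()
        self.pool_dataset.indices = np.nonzero(self.pool_mask)[0].tolist()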
Example #21
    def __init__(self, dataset: data.Dataset):
        super().__init__()
        self.dataset = dataset
        self.active_mask = np.full((len(dataset), ), False)
        self.available_mask = np.full((len(dataset), ), True)

        self.active_dataset = data.Subset(self.dataset, None)
        self.available_dataset = data.Subset(self.dataset, None)

        self._update_indices()
Example #22
    def prepare_data(self):
        text_data = MemoryMapDataset("texts.txt", "slices.pkl")
        dataset = SplitDataset(text_data, 500, 800, 20)

        train_indices, valid_indices = train_test_split(np.arange(
            len(dataset)),
                                                        test_size=200_000,
                                                        random_state=1234)
        self.train_dataset = data.Subset(dataset, train_indices)
        self.valid_dataset = data.Subset(dataset, valid_indices)
Example #23
def train_test_split(torch_dataset, test_size):
    dataset_size = len(torch_dataset)
    indices = list(range(dataset_size))
    n_test = int(dataset_size * test_size)
    n_train = dataset_size - n_test
    train_indices, test_indices = (indices[:n_train], indices[n_train:])

    train_set = data.Subset(torch_dataset, train_indices)
    test_set = data.Subset(torch_dataset, test_indices)

    return train_set, test_set
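A quick sketch exercising the deterministic split above on a toy TensorDataset (a 7/3 split for test_size=0.3):

import torch
from torch.utils import data

dataset = data.TensorDataset(torch.arange(10).float())
train_set, test_set = train_test_split(dataset, test_size=0.3)
print(len(train_set), len(test_set))  # 7 3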
Example #24
    def split(self, dataset):
        """
        Splits the dataset into dev, train and test
        :param dataset: the dataset to split
        :return: DataSets named tuple (dev, train, test)
        """
        dev = data_utils.Subset(dataset, range(len(dataset) * 2 // 10))
        train = data_utils.Subset(dataset, range(0, len(dataset) * 9 // 10))
        test = data_utils.Subset(
            dataset, range(len(dataset) * 9 // 10, len(dataset)))
        return DataSets(dev=dev, train=train, test=test)
Example #25
    def __init__(self, dataset: data.Dataset):
        self.dataset = dataset
        self.total_size = len(dataset)

        self.train_mask = np.full((self.total_size, ), False)
        self.pool_mask = np.full((self.total_size, ), True)

        self.train_data = data.Subset(dataset, None)
        self.pool_data = data.Subset(dataset, None)

        self._update_indices()
Example #26
def load_cifar10(args, **kwargs):
    flatten = kwargs.get("flatten", False)

    # start processing
    transforms_list = [
        tv.transforms.ToTensor(),
        # tv.transforms.Normalize((0.5,), (0.5,)),
    ]
    args.xdim = (3, 32, 32)
    if flatten:
        transforms_list.append(tv.transforms.Lambda(lambda x: x.view(-1)))
        args.xdim = (3072,)  # 3 * 32 * 32
    preprocess = tv.transforms.Compose(transforms_list)
    preprocess = kwargs.get("preprocess", preprocess)

    train_dataset = tv.datasets.CIFAR10(
        "./data", transform=preprocess, download=True, train=True
    )
    train_len = len(train_dataset)

    if args.val_frac:
        train_indices, validation_indices = train_test_split(
            np.arange(train_len), test_size=args.val_frac, random_state=args.manual_seed
        )
        train_indices = train_indices.tolist()
        validation_indices = validation_indices.tolist()
    else:
        train_indices = np.arange(train_len).tolist()
        validation_indices = []

    train = utdata.Subset(train_dataset, train_indices)
    validation = utdata.Subset(train_dataset, validation_indices)

    # pytorch data loader
    train_loader = data_utils.DataLoader(
        train, batch_size=args.batch_size, shuffle=True, pin_memory=args.pin_memory
    )

    val_loader = data_utils.DataLoader(
        validation,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=args.pin_memory,
    )

    test = tv.datasets.CIFAR10(
        "./data", transform=preprocess, download=True, train=False
    )
    test_loader = data_utils.DataLoader(
        test, batch_size=args.batch_size, shuffle=False, pin_memory=args.pin_memory
    )

    return train_loader, val_loader, test_loader, args
Example #27
    def train_test_split(self, dataset: torch_data.Dataset, test_size=0.2):
        dataset_indexes = list(range(len(dataset)))
        train_idxs, test_idxs = train_test_split(dataset_indexes,
                                                 test_size=test_size)

        train_dataset = torch_data.Subset(dataset, train_idxs)
        val_dataset = torch_data.Subset(dataset, test_idxs)
        train_loader = self.loader_builder.build(train_dataset,
                                                 has_answers=True)
        test_loader = self.loader_builder.build(val_dataset)

        return train_loader, test_loader
Example #28
def cifar_loaders(batch_size, shuffle_train = True, shuffle_test = False, train_random_transform = False, normalize_input = False, num_examples = None, test_batch_size=None): 
    if normalize_input:
        # note: the std and mean are slightly changed from the usual CIFAR-10
        # statistics (the originals are kept below for reference)
        mean = [0.485, 0.456, 0.406]
        std = [0.225, 0.225, 0.225]
        # std = [0.2023, 0.1994, 0.2010]
        # mean = [0.4914, 0.4822, 0.4465]
        normalize = transforms.Normalize(mean = mean, std = std)
    else:
        std = [1.0, 1.0, 1.0]
        mean = [0, 0, 0]
        normalize = transforms.Normalize(mean = mean, std = std)
    if train_random_transform:
        if normalize_input:
            train = datasets.CIFAR10('./data', train=True, download=True, 
                transform=transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, 4),
                    transforms.ToTensor(),
                    normalize,
                ]))
        else:
            train = datasets.CIFAR10('./data', train=True, download=True, 
                transform=transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, 4),
                    transforms.ToTensor(),
                ]))
    else:
        train = datasets.CIFAR10('./data', train=True, download=True, 
            transform=transforms.Compose([transforms.ToTensor(),normalize]))
    test = datasets.CIFAR10('./data', train=False, 
        transform=transforms.Compose([transforms.ToTensor(), normalize]))
    
    if num_examples:
        indices = list(range(num_examples))
        train = data.Subset(train, indices)
        test = data.Subset(test, indices)

    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,
        shuffle=shuffle_train, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),6))
    if test_batch_size:
        batch_size = test_batch_size
    test_loader = torch.utils.data.DataLoader(test, batch_size=max(batch_size, 1),
        shuffle=shuffle_test, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),6))
    train_loader.std = std
    test_loader.std = std
    train_loader.mean = mean
    test_loader.mean = mean
    return train_loader, test_loader
Example #29
    def _split(self, valid_rate, shuffle_seed):
        self.indices = list(range(self.dataset_size))
        random.seed(shuffle_seed)
        random.shuffle(self.indices)
        split = int(np.floor((1 - valid_rate) * self.dataset_size))

        self.train_indices, self.valid_indices = self.indices[:split], self.indices[split:]
        self.train_dataset = data.Subset(self, self.train_indices)
        self.valid_dataset = data.Subset(self, self.valid_indices)

        self.train_sampler = data.RandomSampler(self.train_dataset)
        self.valid_sampler = data.SequentialSampler(self.valid_dataset)
        self.test_sampler = data.SequentialSampler(self)
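A sketch of how the samplers prepared above could feed DataLoaders elsewhere in the same class; the method name and batch size are hypothetical.

    def get_loaders(self, batch_size=64):
        # hypothetical helper: pair each Subset with the sampler built in _split
        train_loader = data.DataLoader(self.train_dataset, batch_size=batch_size,
                                       sampler=self.train_sampler)
        valid_loader = data.DataLoader(self.valid_dataset, batch_size=batch_size,
                                       sampler=self.valid_sampler)
        return train_loader, valid_loader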
Example #30
    def create_train_val_slice(self,
                               image_datasets,
                               sample_size=None,
                               val_same_as_train=False):
        img_dataset = image_datasets  # reminder - this is a generator

        # clone the image_datasets_reduced[train] generator for the val
        if val_same_as_train:
            img_dataset['val'] = list(img_dataset['train'])
            # copy the train to val (so the transformations won't occur again)
            # image_datasets_reduced['train'] = image_datasets_reduced['val']

        dataset_sizes = {x: len(img_dataset[x]) for x in ['train', 'val']}

        if sample_size:  # otherwise the whole data set is returned
            sample_n = {
                x: random.sample(list(range(dataset_sizes[x])), sample_size)
                for x in ['train', 'val']
            }
            img_dataset = {
                x: data.Subset(img_dataset[x], sample_n[x])
                for x in ['train', 'val']
            }
            dataset_sizes = {x: len(img_dataset[x]) for x in ['train', 'val']}

        dataloaders = {
            x: data.DataLoader(img_dataset[x],
                               batch_size=self.batch_size,
                               shuffle=True,
                               num_workers=0)
            for x in ['train', 'val']
        }
        return dataloaders, dataset_sizes