Example #1
def create_data_loaders(train_transformation, eval_transformation, datadir,
                        args):
    traindir = os.path.join(datadir, args.train_subdir)
    evaldir = os.path.join(datadir, args.eval_subdir)
    print([args.exclude_unlabeled, args.labeled_batch_size])
    assert_exactly_one([args.exclude_unlabeled, args.labeled_batch_size])
    dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)
    if args.labels:
        with open(args.labels) as f:
            labels = dict(line.split(' ') for line in f.read().splitlines())
        labeled_idxs, unlabeled_idxs = data.relabel_dataset(dataset, labels)
    assert len(dataset.imgs) == len(labeled_idxs) + len(unlabeled_idxs)  # assumes args.labels was provided above

    if args.unsup_augment is not None:
        print("Augmenting Unsupervised Data with {}".format(
            args.unsup_augment))
        if args.unsup_augment == 'cifar100':
            _dataset_config = datasets.__dict__['cifar100']()
            _traindir = os.path.join(_dataset_config['datadir'],
                                     args.train_subdir)
            _dataset = torchvision.datasets.ImageFolder(
                _traindir, _dataset_config['train_transformation'])
            data.relabel_dataset(_dataset, {})
            concat_dataset = torch.utils.data.ConcatDataset(
                [dataset, _dataset])
            extra_idxs = list(range(len(dataset), len(dataset) + len(_dataset)))
            unlabeled_idxs += extra_idxs
            print(concat_dataset.cumulative_sizes)  # [50000, 100000]
            dataset = concat_dataset

    if args.exclude_unlabeled or len(unlabeled_idxs) == 0:
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    elif args.labeled_batch_size:
        if len(unlabeled_idxs) == 0:  # unreachable here: the branch above already catches len(unlabeled_idxs) == 0
            # for num labels = num images case
            print("Setting unlabeled idxs = labeled idxs")
            unlabeled_idxs = labeled_idxs
        print("len(labeled_idxs)", len(labeled_idxs))
        batch_sampler = data.TwoStreamBatchSampler(unlabeled_idxs,
                                                   labeled_idxs,
                                                   args.batch_size,
                                                   args.labeled_batch_size)
    else:
        assert False, "labeled batch size {}".format(args.labeled_batch_size)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)
    train_loader_len = len(train_loader)
    eval_loader = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(evaldir, eval_transformation),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2 * args.workers,  # Needs images twice as fast
        pin_memory=True,
        drop_last=False)
    return train_loader, eval_loader, train_loader_len
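All of the variants above guard their flag combinations with assert_exactly_one, a small helper from the mean-teacher codebase that rejects ambiguous configurations (e.g. passing both --exclude-unlabeled and --labeled-batch-size). A minimal sketch of the check; the assertion message format is an assumption:

def assert_exactly_one(lst):
    # exactly one of the mutually exclusive options may be truthy
    assert sum(int(bool(el)) for el in lst) == 1, ", ".join(str(el) for el in lst)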
Example #2
def create_data_loaders(train_transformation, eval_transformation, datadir,
                        args):
    traindir = os.path.join(datadir, args.train_subdir)
    evaldir = os.path.join(datadir, args.eval_subdir)

    assert_exactly_one([args.exclude_unlabeled, args.labeled_batch_size])

    dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)

    if args.labels:
        with open(args.labels) as f:
            labels = dict(line.split(' ') for line in f.read().splitlines())
        labeled_idxs, unlabeled_idxs = data.relabel_dataset(dataset, labels)

    if args.exclude_unlabeled:
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    elif args.labeled_batch_size:
        batch_sampler = data.TwoStreamBatchSampler(unlabeled_idxs,
                                                   labeled_idxs,
                                                   args.batch_size,
                                                   args.labeled_batch_size)
    else:
        assert False, "labeled batch size {}".format(args.labeled_batch_size)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler)

    eval_loader = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(evaldir, eval_transformation),
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False)

    return train_loader, eval_loader
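Every example funnels its labels through data.relabel_dataset, which rewrites the targets of an ImageFolder-style dataset in place and returns the labeled and unlabeled index sets. A sketch of the behaviour these snippets rely on, modeled on the mean-teacher reference implementation (NO_LABEL is its sentinel target, conventionally -1); note that Example #10 uses a modified fork that also returns validation indices:

import os

NO_LABEL = -1  # sentinel target assigned to unlabeled samples

def relabel_dataset(dataset, labels):
    # labels maps file names to class names; files not listed become unlabeled
    unlabeled_idxs = []
    for idx in range(len(dataset.imgs)):
        path, _ = dataset.imgs[idx]
        filename = os.path.basename(path)
        if filename in labels:
            dataset.imgs[idx] = path, dataset.class_to_idx[labels[filename]]
            del labels[filename]
        else:
            dataset.imgs[idx] = path, NO_LABEL
            unlabeled_idxs.append(idx)
    if labels:  # every file name in the labels file must exist in the dataset
        raise LookupError("labels file references {} unknown files".format(len(labels)))
    labeled_idxs = sorted(set(range(len(dataset.imgs))) - set(unlabeled_idxs))
    return labeled_idxs, unlabeled_idxs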
Example #3
def create_data_loaders_cifar100(train_transformation, eval_transformation, datadir, args, wtiny=False):
    # creating data loaders for CIFAR-100 with an option to add tiny images as unlabeled data
    traindir = os.path.join(datadir, args.train_subdir)
    evaldir = os.path.join(datadir, args.eval_subdir)
    print([args.exclude_unlabeled, args.labeled_batch_size])
    assert_exactly_one([args.exclude_unlabeled, args.labeled_batch_size])

    dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)

    if args.labels:
        with open(args.labels) as f:
            labels = dict(line.split(' ') for line in f.read().splitlines())
        labeled_idxs, unlabeled_idxs = data.relabel_dataset(dataset, labels)
    assert len(dataset.imgs) == len(labeled_idxs) + len(unlabeled_idxs)  # assumes args.labels was provided above
    orig_ds_size = len(dataset.imgs)

    if args.exclude_unlabeled or (len(unlabeled_idxs) == 0 and args.unsup_augment is None):
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    elif args.labeled_batch_size:
        if len(unlabeled_idxs) == 0:  # in case of using all labels
            assert len(labeled_idxs) == 50000, 'Only supporting this case for now'
        print("len(labeled_idxs)", len(labeled_idxs))
        if args.unsup_augment is not None:
            print("Unsupervised Augmentation with CIFAR Tiny Images")
            from mean_teacher.tinyimages import TinyImages
            if args.unsup_augment == 'tiny_500k':
                extra_unlab_dataset = TinyImages(transform=train_transformation, which='500k')
            elif args.unsup_augment == 'tiny_237k':
                extra_unlab_dataset = TinyImages(transform=train_transformation, which='237k')
            elif args.unsup_augment == 'tiny_all':
                extra_unlab_dataset = TinyImages(transform=train_transformation, which='tiny_all')
            dataset = ConcatDataset([dataset, extra_unlab_dataset])
            unlabeled_idxs += [orig_ds_size + i for i in range(len(extra_unlab_dataset))]
            print("New unlabeled indices length", len(unlabeled_idxs))
        if args.unsup_augment is None:
            assert args.limit_unlabeled is None, 'With no unsup augmentation, limit_unlabeled should be None'
        batch_sampler = data.TwoStreamBatchSampler(
            unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size,
            unlabeled_size_limit=args.limit_unlabeled)
    else:
        assert False, "labeled batch size {}".format(args.labeled_batch_size)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)

    train_loader_len = len(train_loader)
    eval_loader = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(evaldir, eval_transformation),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2 * args.workers,  # Needs images twice as fast
        pin_memory=True,
        drop_last=False)
    return train_loader, eval_loader, train_loader_len
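The semi-supervised branches all hand the two index sets to data.TwoStreamBatchSampler, which mixes a fixed number of labeled indices into every batch. A condensed sketch following the mean-teacher implementation; the unlabeled_size_limit keyword seen in Example #3 is a fork-specific extension and is omitted here:

import itertools
import numpy as np
from torch.utils.data.sampler import Sampler

class TwoStreamBatchSampler(Sampler):
    """Yields batches mixing two index sets: an epoch is one pass over the
    primary (unlabeled) indices, while the secondary (labeled) indices are
    reshuffled and recycled as often as needed."""
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

    def __iter__(self):
        primary_iter = iter(np.random.permutation(self.primary_indices))
        secondary_iter = itertools.chain.from_iterable(
            np.random.permutation(self.secondary_indices) for _ in itertools.count())
        return (primary_batch + secondary_batch
                for (primary_batch, secondary_batch)
                in zip(grouper(primary_iter, self.primary_batch_size),
                       grouper(secondary_iter, self.secondary_batch_size)))

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size

def grouper(iterable, n):
    # collect items into fixed-length chunks, dropping the last incomplete one
    args = [iter(iterable)] * n
    return (list(group) for group in zip(*args))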
Example #4
def create_data_loaders(train_transformation,
                        eval_transformation,
                        datadir,
                        args):
    traindir = os.path.join(datadir, args.train_subdir)
    evaldir = os.path.join(datadir, args.eval_subdir)

    assert_exactly_one([args.exclude_unlabeled, args.labeled_batch_size, args.fully_supervised])

    dataset = db_semisuper.DBSS(traindir, train_transformation)

    if not args.fully_supervised and args.labels:
        with open(args.labels) as f:
            labels = dict(line.split(' ') for line in f.read().splitlines())
        labeled_idxs, unlabeled_idxs = data.relabel_dataset(dataset, labels)

    if args.exclude_unlabeled:
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    elif args.fully_supervised:
        sampler = SubsetRandomSampler(range(len(dataset)))
        dataset.labeled_idx = range(len(dataset))
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    elif args.labeled_batch_size:
        batch_sampler = data.TwoStreamBatchSampler(
            unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    else:
        assert False, "labeled batch size {}".format(args.labeled_batch_size)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)

    train_loader_noshuff = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size * 2,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False)

    eval_dataset = torchvision.datasets.ImageFolder(evaldir, eval_transformation)
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False)

    return train_loader, eval_loader, train_loader_noshuff, dataset
Example #5
def create_data_loaders(dataconfig, args):
    assert_exactly_one([args.exclude_unlabeled, args.labeled_batch_size])

    meta = cdc.ASVSpoof19Meta(data_dir=dataconfig['root'],
                              meta_dir=dataconfig['processed_meta'],
                              folds_num=1,  # default
                              attack_type=dataconfig['attack_type'])

    fl_train = meta.fold_list(fold=1, data_split=cdc.ASVSpoof19Meta.DataSplit.train)
    dataset = cdd.ArkDataGenerator(data_file=dataconfig['feat_storage'],
                                   fold_list=fl_train,
                                   transform=dataconfig['train_trans'],
                                   rand_slides=True)

    if args.labels:
        with open(args.labels) as f:
            labels = dict(line.split('\t') for line in f.read().splitlines())
        labeled_idxs, unlabeled_idxs = data.relabel_dataset(dataset, labels)

    if args.exclude_unlabeled:
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    elif args.labeled_batch_size:
        batch_sampler = data.TwoStreamBatchSampler(
            unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    else:
        assert False, "labeled batch size {}".format(args.labeled_batch_size)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)

    #####
    fl_eval = meta.fold_list(fold=1, data_split=cdc.ASVSpoof19Meta.DataSplit.validation)  # TODO note val == train
    eval_data = cdd.ArkDataGenerator(data_file=dataconfig['feat_storage'],
                                     fold_list=fl_eval,
                                     transform=dataconfig['eval_trans'],
                                     rand_slides=True)
    #####

    eval_loader = torch.utils.data.DataLoader(
        eval_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2 * args.workers,  # Needs images twice as fast
        pin_memory=True,
        drop_last=False)

    return train_loader, eval_loader
Example #6
def create_data_loaders(train_transformation,
						eval_transformation,
						test_transformation,
						datadir,
						args):
	traindir = os.path.join(datadir, args.train_subdir)
	evaldir = os.path.join(datadir, args.eval_subdir)
	testdir = os.path.join(datadir, 'test/')

	assert_exactly_one([args.exclude_unlabeled, args.labeled_batch_size])

	dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)

	if args.labels:  # open the labels file: 10% of the data is labeled
		root = '/mnt/HDD1/Frederic/Mean-teacher-based'
		label_path  = root+'/data-local/labels/'+args.dataset+'/'+args.labels
		with open(label_path) as f:
			labels = dict(line.split(' ') for line in f.read().splitlines())
		labeled_idxs, unlabeled_idxs = data.relabel_dataset(dataset, labels)
	if args.exclude_unlabeled:
		sampler = SubsetRandomSampler(labeled_idxs)
		batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
	elif args.labeled_batch_size:
		batch_sampler = data.TwoStreamBatchSampler(
			unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
	else:
		assert False, "labeled batch size {}".format(args.labeled_batch_size)
	# dataset.imgs is a list of (path, label) tuples, e.g. [("data-local/1752.jp", -1), ("data-local/177.jpg", 8)]
	train_loader = torch.utils.data.DataLoader(dataset,
											   batch_sampler=batch_sampler,
											   num_workers=args.workers,
											   pin_memory=True)

	eval_loader = torch.utils.data.DataLoader(
		torchvision.datasets.ImageFolder(evaldir, eval_transformation),
		batch_size=args.batch_size,
		shuffle=False,
		num_workers=2 * args.workers,  # Needs images twice as fast
		pin_memory=True,
		drop_last=False)

	test_loader = torch.utils.data.DataLoader(
		torchvision.datasets.ImageFolder(testdir, test_transformation),
		batch_size=args.batch_size,
		shuffle=False,
		num_workers=2 * args.workers,  # Needs images twice as fast
		pin_memory=True)

	class_names = dataset.classes
	return train_loader, eval_loader, test_loader, class_names
Example #7
def create_data_loaders(train_transformation, eval_transformation, datadir,
                        args):
    traindir = os.path.join(datadir, args.train_subdir)

    dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)

    if args.labels:
        with open(args.labels) as f:
            labels = dict(line.split(' ') for line in f.read().splitlines())
        labeled_idxs, _ = data.relabel_dataset(dataset, labels)

    sampler = SubsetRandomSampler(labeled_idxs)
    batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)

    return train_loader
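A note on the labels file itself: most variants parse it with line.split(' '), so the expected format is one "filename classname" pair per line, separated by a single space (Example #5 splits on tabs instead); a file name containing a space would make the dict(...) construction raise a ValueError. A hypothetical three-line labels file for an ImageFolder dataset:

img_00001.png airplane
img_00002.png dog
img_00003.png cat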
Example #8
def create_data_loaders(train_transformation,
                        eval_transformation,
                        datadir,
                        args):
    traindir = os.path.join(datadir, args.train_subdir)
    evaldir = os.path.join(datadir, args.eval_subdir)

    assert_exactly_one([args.exclude_unlabeled, args.labeled_batch_size])

    dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)

    if args.labels:
        with open(args.labels) as f:
            labels = dict(line.split(' ') for line in f.read().splitlines())
        labeled_idxs, unlabeled_idxs = data.relabel_dataset(dataset, labels)

    if args.exclude_unlabeled:
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    elif args.labeled_batch_size:
        batch_sampler = data.TwoStreamBatchSampler(
            unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    else:
        assert False, "labeled batch size {}".format(args.labeled_batch_size)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)

    eval_loader = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(evaldir, eval_transformation),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2 * args.workers,  # Needs images twice as fast
        pin_memory=True,
        drop_last=False)

    return train_loader, eval_loader
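For orientation, a hypothetical call into the Example #8 variant; the attribute names mirror the CLI flags used throughout, but every value below is illustrative only:

from types import SimpleNamespace
import torchvision.transforms as transforms

# illustrative arguments only; the real projects build these with argparse
args = SimpleNamespace(
    train_subdir='train+val',
    eval_subdir='val',
    labels='data-local/labels/cifar10/1000_balanced_labels/00.txt',
    exclude_unlabeled=False,
    labeled_batch_size=62,
    batch_size=256,
    workers=4)

transform = transforms.ToTensor()
train_loader, eval_loader = create_data_loaders(
    transform, transform, 'data-local/images/cifar/cifar10/by-image', args)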
Example #9
File: main.py Project: michellehan/ivc
def create_data_loaders(train_transformation, eval_transformation, args):

    ############ training / testing dirs use the same test dataset in the official split
    print('Training Dataset: %s' % (args.train_dir))
    print('Validation Dataset: %s' % (args.val_dir))

    ############ Customized training dataset
    train_dataset = datasets.IVCdataset(args.train_csv, args.train_dir,
                                        train_transformation)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,  ### no customized sampler, just batch size
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    ############ NOT EDITED FOR flag != 'full' ###############################
    if True:  # if args.flag == 'full':
        print('train loader for training on all labeled data!')
    #     train_loader = torch.utils.data.DataLoader(train_dataset,
    #                                                    batch_size=args.batch_size,      ### no customized sampler, just batch size
    #                                                    shuffle=True,
    #                                                    num_workers=args.workers,
    #                                                    pin_memory=True,
    #                                                    drop_last=True)

    else:
        sub_traindir = os.path.join(
            args.csvdir, 'train_val_official_%.2f_%s_cls%d.csv' %
            (args.train_portion, args.flag, args.num_classes))
        print('Change to Use Subset Training Dataset: %s' % (sub_traindir))
        sub_train_dataset = datasets.ChestXRayDataset(sub_traindir,
                                                      args.datadir,
                                                      train_transformation)

        if args.batch_size == args.labeled_batch_size:
            print(
                'train loader for training on subset labeled data (NO unlabeled data)!'
            )
            train_loader = torch.utils.data.DataLoader(
                sub_train_dataset,
                batch_size=args.batch_size,  ### no customized sampler, just batch size
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        else:
            print(
                'train loader for training on subset labeled data (INCLUDE unlabeled data)!'
            )
            ### assign NO_LABEL to unlabeled samples
            labeled_idxs, unlabeled_idxs = data.relabel_dataset(
                dataset=train_dataset, labeled_dataset=sub_train_dataset)
            batch_sampler = data.TwoStreamBatchSampler(
                unlabeled_indices=unlabeled_idxs,
                labeled_indices=labeled_idxs,
                batch_size=args.batch_size,
                labeled_batch_size=args.labeled_batch_size)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_sampler=batch_sampler,
                num_workers=args.workers,
                pin_memory=True)
    ############ END: NOT EDITED FOR flag != 'full' ##############################

    ############ Customized validation dataset
    val_dataset = datasets.IVCdataset(args.val_csv, args.val_dir,
                                      eval_transformation)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2 * args.workers,  # Needs images twice as fast
        pin_memory=True,
        drop_last=False
    )  # set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size

    args.class_to_idx = train_dataset.class_to_idx
    return train_loader, val_loader
Example #10
    def create_data_loaders(self, train_transformation, eval_transformation,
                            datadir, args):
        """
        Creates the dataset loaders
        :param train_transformation:
        :param eval_transformation:
        :param datadir:
        :param args:
        :return:
        """
        logger.info("Loading data from: " + datadir)
        traindir = os.path.join(datadir, self.args.train_subdir)
        evaldir = os.path.join(datadir, self.args.eval_subdir)
        assert_exactly_one(
            [self.args.exclude_unlabeled, self.args.labeled_batch_size])
        dataset = torchvision.datasets.ImageFolder(traindir,
                                                   train_transformation)

        if self.args.labels:

            with open(self.args.labels) as f:
                labels = dict(
                    line.split(' ') for line in f.read().splitlines())
                # takes the file names in the labels dictionary as labeled data, and the rest as unlabeled
                # modified to cap the number of unlabeled observations, to study behaviour with different amounts of unlabeled data
                labeled_idxs, unlabeled_idxs, validation_idxs, dataset = data.relabel_dataset(
                    dataset, labels)
                logger.info("Number of labeled training observations: " +
                            str(len(labeled_idxs)))
                logger.info("Number of labeled validation observations: " +
                            str(len(validation_idxs)))
                logger.info("Number of unlabeled observations: " +
                            str(len(unlabeled_idxs)))
                if (len(labeled_idxs) < self.args.batch_size
                        or len(validation_idxs) < self.args.batch_size
                        or len(unlabeled_idxs) < self.args.batch_size):
                    logger.warning(
                        "Batch size is larger than one of the data subsets")

        if self.args.exclude_unlabeled:
            logger.info("Not using unlabeled data")
            sampler = SubsetRandomSampler(labeled_idxs)
            batch_sampler = BatchSampler(sampler,
                                         self.args.batch_size,
                                         drop_last=False)
        elif self.args.labeled_batch_size:
            logger.info("Using unlabeled data")
            batch_sampler = data.TwoStreamBatchSampler(
                unlabeled_idxs, labeled_idxs, self.args.batch_size,
                self.args.labeled_batch_size)
        else:
            assert False, "labeled batch size {}".format(
                self.args.labeled_batch_size)

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=self.args.workers,
            pin_memory=True)
        # evaluation loader
        sampler_eval = SubsetRandomSampler(validation_idxs)
        # drop_last discards the final incomplete batch; pin_memory speeds up host-to-GPU transfers
        batch_sampler_eval = BatchSampler(sampler_eval,
                                          self.args.batch_size,
                                          drop_last=False)
        eval_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler_eval,
            num_workers=self.args.workers,
            pin_memory=True)
        return train_loader, eval_loader
Example #11
def create_data_loaders(train_transformation, eval_transformation, datadir,
                        args):
    traindir = os.path.join(datadir, args.train_subdir)

    evaldir = os.path.join(datadir, args.eval_subdir)

    assert_exactly_one([
        args.exclude_unlabeled, args.labeled_batch_size, args.fully_supervised
    ])

    ###
    #Numpy 180+CH images
    ###
    #----------------------------------------------------------------------------------------------------
    if args.dataset in ('s1s2glcm16', 's1s2glcm8', 's18', 's28', 'glcm8'):

        def npy_loader(path):
            #sample = torch.from_numpy(np.load(path))
            sample = np.load(path)
            return sample

        dataset = db_semisuper.DBSS(root=traindir,
                                    transform=train_transformation,
                                    loader=npy_loader)

        if not args.fully_supervised and args.labels:
            with open(args.labels) as f:
                labels = dict(
                    line.split(' ') for line in f.read().splitlines())
            labeled_idxs, unlabeled_idxs = data.relabel_dataset(
                dataset, labels)

        if args.exclude_unlabeled:
            sampler = SubsetRandomSampler(labeled_idxs)
            batch_sampler = BatchSampler(sampler,
                                         args.batch_size,
                                         drop_last=True)
        elif args.fully_supervised:
            sampler = SubsetRandomSampler(range(len(dataset)))
            dataset.labeled_idx = range(len(dataset))
            batch_sampler = BatchSampler(sampler,
                                         args.batch_size,
                                         drop_last=True)
        elif args.labeled_batch_size:
            batch_sampler = data.TwoStreamBatchSampler(unlabeled_idxs,
                                                       labeled_idxs,
                                                       args.batch_size,
                                                       args.labeled_batch_size)
        else:
            assert False, "labeled batch size {}".format(
                args.labeled_batch_size)

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=0,  #args.workers
            pin_memory=True)

        train_loader_noshuff = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size * 2,
            shuffle=False,
            num_workers=0,  # args.workers
            pin_memory=True,
            drop_last=False)

        eval_dataset = torchvision.datasets.DatasetFolder(
            root=evaldir,
            loader=npy_loader,
            transform=eval_transformation,
            extensions=('.npy',))

        eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.test_batch_size,
            shuffle=False,
            num_workers=0,  #args.workers
            pin_memory=True,
            drop_last=False)

        return train_loader, eval_loader, train_loader_noshuff, dataset


    ###
    # Regular image propagation
    ###
    #----------------------------------------------------------------------------------------------------

    else:
        dataset = db_semisuper.DBSS(
            traindir, train_transformation
        )  # torchvision imagefolder object in original paper, here custom made

        if not args.fully_supervised and args.labels:
            with open(args.labels) as f:
                labels = dict(
                    line.split(' ') for line in f.read().splitlines())
            labeled_idxs, unlabeled_idxs = data.relabel_dataset(
                dataset, labels)

        if args.exclude_unlabeled:
            sampler = SubsetRandomSampler(labeled_idxs)
            batch_sampler = BatchSampler(sampler,
                                         args.batch_size,
                                         drop_last=True)
        elif args.fully_supervised:
            sampler = SubsetRandomSampler(range(len(dataset)))
            dataset.labeled_idx = range(len(dataset))
            batch_sampler = BatchSampler(sampler,
                                         args.batch_size,
                                         drop_last=True)
        elif args.labeled_batch_size:
            batch_sampler = data.TwoStreamBatchSampler(unlabeled_idxs,
                                                       labeled_idxs,
                                                       args.batch_size,
                                                       args.labeled_batch_size)
        else:
            assert False, "labeled batch size {}".format(
                args.labeled_batch_size)

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=0,  #args.workers
            pin_memory=True)

        train_loader_noshuff = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size * 2,
            shuffle=False,
            num_workers=0,  # args.workers
            pin_memory=True,
            drop_last=False)

        eval_dataset = torchvision.datasets.ImageFolder(
            evaldir, eval_transformation)
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.test_batch_size,
            shuffle=False,
            num_workers=0,  #args.workers
            pin_memory=True,
            drop_last=False)

        return train_loader, eval_loader, train_loader_noshuff, dataset