コード例 #1
0
 def get_train_split(self, transform=None, num_samples=None):
     """Return a shuffled ``Subset`` of the STL10 training data.

     Index lists are mapped through ``self.idx_labeled`` / ``self.idx_unlabeled``
     (pre-computed index permutations stored on the instance), so repeated calls
     produce the same ordering.

     Args:
         transform: transform applied to each image (may be ``None``).
         num_samples: number of samples to draw; when ``None`` the class-level
             ``NUM_TRN_SAMPLES`` / ``TOT_TRN_SAMPLES`` configuration decides.

     Returns:
         ``torch.utils.data.Subset`` over a freshly constructed ``STL10`` dataset.
     """
     if num_samples is None:
         # Choose raw indices into the underlying torchvision split.  In the
         # combined 'train+unlabeled' split, labeled images start at offset
         # STL10_NUM_UNLABELED; samples beyond the labeled budget are drawn
         # from the unlabeled prefix.
         indices = (list(range(self.STL10_NUM_UNLABELED, self.STL10_NUM_UNLABELED + self.NUM_TRN_SAMPLES)) + ([] if self.TOT_TRN_SAMPLES <= self.NUM_TRN_SAMPLES else list(range(self.TOT_TRN_SAMPLES - self.NUM_TRN_SAMPLES)))) if self.STL10_USE_LBL_SPLIT \
          else (range(self.STL10_NUM_UNLABELED, self.STL10_NUM_UNLABELED + self.TOT_TRN_SAMPLES) if self.TOT_TRN_SAMPLES <= self.STL10_NUM_LABELED
                else (list(range(self.STL10_NUM_UNLABELED, self.STL10_NUM_UNLABELED + self.STL10_NUM_LABELED)) + list(range(self.TOT_TRN_SAMPLES - self.STL10_NUM_LABELED))))
         # Use the plain 'train' split only when every requested sample is labeled.
         split = 'train' if (
             self.TOT_TRN_SAMPLES <= self.NUM_TRN_SAMPLES
             and self.STL10_USE_LBL_SPLIT) or (
                 self.TOT_TRN_SAMPLES <= self.STL10_NUM_LABELED
                 and not self.STL10_USE_LBL_SPLIT) else 'train+unlabeled'
         # Map raw indices through the stored permutations; indices below
         # STL10_NUM_UNLABELED address the unlabeled part, the rest the labeled part.
         shuffled_indices = [self.idx_labeled[i] for i in indices] if split=='train' \
          else [self.idx_unlabeled[i] if i < self.STL10_NUM_UNLABELED else self.STL10_NUM_UNLABELED + self.idx_labeled[i-self.STL10_NUM_UNLABELED] for i in indices]
         return Subset(
             STL10(root=self.DATASET_FOLDER,
                   split=split,
                   download=True,
                   transform=transform), shuffled_indices)
     # Explicit num_samples path.
     # NOTE(review): unlike the branch above, the middle case here is NOT offset
     # by STL10_NUM_UNLABELED — range(self.STL10_NUM_UNLABELED, num_samples) is
     # empty whenever num_samples <= STL10_NUM_LABELED < STL10_NUM_UNLABELED.
     # Compare with line using STL10_NUM_UNLABELED + TOT_TRN_SAMPLES above;
     # confirm this asymmetry is intended.
     indices = range(num_samples) if self.STL10_USE_LBL_SPLIT \
      else (range(self.STL10_NUM_UNLABELED, num_samples) if num_samples <= self.STL10_NUM_LABELED
            else (list(range(self.STL10_NUM_UNLABELED, self.STL10_NUM_UNLABELED + self.STL10_NUM_LABELED)) + list(range(num_samples - self.STL10_NUM_LABELED))))
     split = 'train' if self.STL10_USE_LBL_SPLIT or (
         num_samples <= self.STL10_NUM_LABELED
         and not self.STL10_USE_LBL_SPLIT) else 'train+unlabeled'
     shuffled_indices = [self.idx_labeled[i] for i in indices] if split=='train' \
      else [self.idx_unlabeled[i] if i < self.STL10_NUM_UNLABELED else self.STL10_NUM_UNLABELED + self.idx_labeled[i-self.STL10_NUM_UNLABELED] for i in indices]
     return Subset(
         STL10(root=self.DATASET_FOLDER,
               split=split,
               download=True,
               transform=transform), shuffled_indices)
コード例 #2
0
    def train_dataloader_mixed(self, batch_size, transforms=None):
        """Train loader over the STL10 'unlabeled' + 'train' splits, each minus
        its validation hold-out (`unlabeled_val_split` / `train_val_split`).

        Args:
            batch_size: batch size of the returned DataLoader.
            transforms: transform applied to every image; defaults to
                ``self._default_transforms()``.

        Returns:
            A shuffled DataLoader over the concatenated training portions.
        """
        if transforms is None:
            transforms = self._default_transforms()

        unlabeled_dataset = STL10(self.save_path,
                                  split='unlabeled',
                                  download=False,
                                  transform=transforms)
        unlabeled_length = len(unlabeled_dataset)
        unlabeled_dataset, _ = random_split(unlabeled_dataset, [
            unlabeled_length - self.unlabeled_val_split,
            self.unlabeled_val_split
        ])

        labeled_dataset = STL10(self.save_path,
                                split='train',
                                download=False,
                                transform=transforms)
        labeled_length = len(labeled_dataset)
        labeled_dataset, _ = random_split(
            labeled_dataset,
            [labeled_length - self.train_val_split, self.train_val_split])

        # BUG FIX: ConcatDataset expects a single list of datasets; passing two
        # positional datasets raises a TypeError at runtime.
        dataset = ConcatDataset([unlabeled_dataset, labeled_dataset])
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=self.num_workers,
                            drop_last=True,
                            pin_memory=True)
        return loader
コード例 #3
0
def make_data(data_path, dataset, jitter):
    """Build (train_loader, test_loader), batch size 256, for `dataset`.

    STL10 is special-cased because its constructor takes `split=` instead of
    the `train=` flag used by the other torchvision datasets.
    """
    train_tf = get_train_transforms(jitter, dataset)
    test_tf = ToTensor()

    if dataset == STL10:
        train_ds = STL10(data_path, split='train', download=True,
                         transform=train_tf)
        test_ds = STL10(data_path, split='test', download=True,
                        transform=test_tf)
    else:
        train_ds = dataset(data_path, train=True, download=True,
                           transform=train_tf)
        test_ds = dataset(data_path, train=False, download=True,
                          transform=test_tf)

    # Both loaders share the same configuration (test is shuffled too,
    # matching the original behavior).
    loader_kwargs = dict(batch_size=256, shuffle=True, num_workers=2)
    train_data_loader = torch.utils.data.DataLoader(train_ds, **loader_kwargs)
    test_data_loader = torch.utils.data.DataLoader(test_ds, **loader_kwargs)
    return train_data_loader, test_data_loader
コード例 #4
0
def convert_into_batch(batch_size, is_download=False):
    '''
    Convert the STL10 dataset (train+unlabeled combined with test) into batches.

    :param is_download: if "True", the STL10 dataset is downloaded to a dataset directory
    :param batch_size: batch size of the combined dataset

    :return: (dataloader, class_labels, batch_size)
    '''
    # Light augmentation: crop roughly 88/96 of the image, flip, jitter, then
    # normalize to [-1, 1].
    transform = transforms.Compose([
        transforms.RandomResizedCrop(64, scale=(88 / 96, 1.0), ratio=(1., 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_set = STL10(root='./dataset', split='train+unlabeled', download=is_download, transform=transform)
    test_set = STL10(root='./dataset', split='test', download=is_download, transform=transform)

    # '+' on torchvision datasets produces a ConcatDataset.
    dataset = train_set + test_set

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # explicit encoding so the class-name list is read identically on any locale
    with open('./dataset/stl10_binary/class_names.txt', 'rt', encoding='utf-8') as f:
        class_labels = [s.strip() for s in f.readlines()]

    return (dataloader, class_labels, batch_size)
コード例 #5
0
    def downloader(self):
        """Populate self.train_set / self.test_set for the configured dataset.

        Supports 'mnist', 'stl10' and 'cifar10'; anything else just prints a
        not-found message and leaves the attributes untouched.
        """
        # dataset name -> (class, train-split kwargs, test-split kwargs);
        # STL10 selects splits by name, the others by the `train` flag.
        specs = {
            'mnist': (MNIST, {'train': True}, {'train': False}),
            'stl10': (STL10, {'split': 'train'}, {'split': 'test'}),
            'cifar10': (CIFAR10, {'train': True}, {'train': False}),
        }

        spec = specs.get(self.dataset)
        if spec is None:
            print('{} dataset: Not Found !!'.format(self.dataset))
            return

        dataset_cls, train_kwargs, test_kwargs = spec
        self.train_set = dataset_cls(self.data_dir,
                                     transform=self.tranform,
                                     download=self.is_download,
                                     **train_kwargs)
        self.test_set = dataset_cls(self.data_dir,
                                    transform=self.tranform,
                                    download=self.is_download,
                                    **test_kwargs)
コード例 #6
0
    def process(self):
        """Convert STL10 (train + test) into superpixel graphs and save the
        collated result to the first processed path."""
        splits = [
            STL10(self.data_dir,
                  split=name,
                  download=True,
                  transform=self.base_transform)
            for name in ('train', 'test')
        ]
        full_dataset = ConcatDataset(datasets=splits)

        # attach the class label to each graph produced by base_transform
        data_list = []
        for graph, label in tqdm(full_dataset,
                                 desc="Generating superpixels",
                                 colour="GREEN"):
            graph.y = torch.tensor(label)
            data_list.append(graph)

        if self.pre_filter is not None:
            data_list = [item for item in data_list if self.pre_filter(item)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(item) for item in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
コード例 #7
0
 def prepare_data(self):
      """
      Downloads the unlabeled, train and test splits of STL10.
      """
      for split_name in ('unlabeled', 'train', 'test'):
          STL10(self.data_dir, split=split_name, download=True, transform=transform_lib.ToTensor())
コード例 #8
0
 def download(self):
      """Fetch the STL10 train and test splits into self.data_dir."""
      for split_name in ('train', 'test'):
          STL10(self.data_dir,
                split=split_name,
                download=True,
                transform=self.base_transform)
コード例 #9
0
    def __init__(self):
        """Load the STL10 train and test splits and expose their concatenation
        as ``self.data``."""
        train = STL10(root='./data/',
                      split='train',
                      transform=transforms.ToTensor(),
                      download=True)

        # BUG FIX: the test split previously omitted download=True, so a fresh
        # machine crashed here. The archive is shared between splits, so no
        # extra download occurs when the data already exists.
        test = STL10(root='./data/',
                     split='test',
                     transform=transforms.ToTensor(),
                     download=True)

        self.data = ConcatDataset([train, test])
コード例 #10
0
 def _load_stl(self):
      """Return {'train': ..., 'test': ...} STL10 datasets; the train set is
      wrapped in the project's Dataset class (optionally index-aware)."""
      train_set = Dataset(
          STL10(root='.stl',
                split='train',
                download=True,
                transform=get_transform(96, 12, 32, True)['train']),
          with_index=self.with_index)
      test_set = STL10(root='.stl',
                       split='test',
                       download=True,
                       transform=get_transform(96, 12, 32, True)['test'])
      return {'train': train_set, 'test': test_set}
コード例 #11
0
ファイル: stl10.py プロジェクト: zibuyu2018/Dassl.pytorch
def download_and_prepare(root):
    """Download STL10 and export each split's images into per-split
    directories under `root`."""
    # download=True on the first call fetches the whole archive;
    # the other splits reuse the already-extracted files.
    splits = {
        'train': STL10(root, split='train', download=True),
        'test': STL10(root, split='test'),
        'unlabeled': STL10(root, split='unlabeled'),
    }

    for name, dataset in splits.items():
        extract_and_save_image(dataset, osp.join(root, name))
コード例 #12
0
 def prepare_data(self):
      """Download the unlabeled, train and test STL10 splits to self.save_path."""
      for split_name in ('unlabeled', 'train', 'test'):
          STL10(self.save_path,
                split=split_name,
                download=True,
                transform=transform_lib.ToTensor())
コード例 #13
0
def stl10(root):
    """Return (trainset, testset) STL10 datasets, resized to 224 and
    normalized, with labels mirrored onto the `.targets` attribute."""
    from torchvision.datasets import STL10
    normalize = transforms.Normalize((0.5071, 0.4867, 0.4408),
                                     (0.2675, 0.2565, 0.2761))
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_set = STL10(root, split='train', transform=transform, download=True)
    test_set = STL10(root, split='test', transform=transform)
    # expose labels under the attribute name other datasets use
    for dataset in (train_set, test_set):
        dataset.targets = dataset.labels
    return train_set, test_set
コード例 #14
0
    def load_data(self, dataset):
        """Build (train_loader, test_loader) for the named dataset.

        Args:
            dataset: one of "mnist", "fashion-mnist", "cifar", "stl".

        Returns:
            Tuple of shuffled DataLoaders (train batch 128, test batch 64).

        Raises:
            ValueError: if `dataset` is not a supported name (previously an
                unknown name fell through and crashed with UnboundLocalError
                at DataLoader construction).
        """
        data_transform = transforms.Compose([transforms.ToTensor()])
        if dataset == "mnist":
            train = MNIST(root="./data",
                          train=True,
                          transform=data_transform,
                          download=True)
            test = MNIST(root="./data",
                         train=False,
                         transform=data_transform,
                         download=True)
        elif dataset == "fashion-mnist":
            train = FashionMNIST(root="./data",
                                 train=True,
                                 transform=data_transform,
                                 download=True)
            test = FashionMNIST(root="./data",
                                train=False,
                                transform=data_transform,
                                download=True)
        elif dataset == "cifar":
            train = CIFAR10(root="./data",
                            train=True,
                            transform=data_transform,
                            download=True)
            test = CIFAR10(root="./data",
                           train=False,
                           transform=data_transform,
                           download=True)
        elif dataset == "stl":
            # NOTE(review): hard-coded machine-specific absolute path;
            # consider making this configurable.
            train = STL10(root="/data02/Atin/STL10/",
                          split="unlabeled",
                          transform=data_transform,
                          download=True)
            test = STL10(root="/data02/Atin/STL10/",
                         split="test",
                         transform=data_transform,
                         download=True)
        else:
            raise ValueError("Unknown dataset: {!r}".format(dataset))

        train_loader = torch.utils.data.DataLoader(train,
                                                   batch_size=128,
                                                   shuffle=True,
                                                   num_workers=0)
        test_loader = torch.utils.data.DataLoader(test,
                                                  batch_size=64,
                                                  shuffle=True,
                                                  num_workers=0)

        return train_loader, test_loader
コード例 #15
0
    def val_dataloader_mixed(self):
        """
        Loads the validation portion held out of the 'unlabeled' split together
        with the validation portion held out of the labeled 'train' split and
        serves their concatenation.

        unlabeled_val = `unlabeled_val_split` samples of 'unlabeled'
        labeled_val = `train_val_split` samples of 'train'
        full_val = unlabeled_val + labeled_val

        Uses `self.batch_size` and `self.val_transforms` (or the default
        transforms when none are set).
        """
        transforms = self.default_transforms(
        ) if self.val_transforms is None else self.val_transforms
        unlabeled_dataset = STL10(self.data_dir,
                                  split='unlabeled',
                                  download=False,
                                  transform=transforms)
        unlabeled_length = len(unlabeled_dataset)
        # seeded generator keeps this partition identical to the training-side split
        _, unlabeled_dataset = random_split(
            unlabeled_dataset, [
                unlabeled_length - self.unlabeled_val_split,
                self.unlabeled_val_split
            ],
            generator=torch.Generator().manual_seed(self.seed))

        labeled_dataset = STL10(self.data_dir,
                                split='train',
                                download=False,
                                transform=transforms)
        labeled_length = len(labeled_dataset)
        _, labeled_dataset = random_split(
            labeled_dataset,
            [labeled_length - self.train_val_split, self.train_val_split],
            generator=torch.Generator().manual_seed(self.seed))

        # BUG FIX: ConcatDataset expects a single list of datasets; passing two
        # positional datasets raises a TypeError at runtime.
        dataset = ConcatDataset([unlabeled_dataset, labeled_dataset])
        loader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            shuffle=False,
                            num_workers=self.num_workers,
                            drop_last=True,
                            pin_memory=True)
        return loader
コード例 #16
0
    def _grab_dataset(self):
        """Return the evaluation dataset selected by ``self.hparams.dataset``,
        wrapped with the validation-set transform.

        Raises:
            ValueError: for an unrecognized dataset name (previously this fell
                through to ``return dataset`` and crashed with
                UnboundLocalError).
        """
        transform = self.val_set_transform

        if self.hparams.dataset == "CIFAR10":
            dataset = CIFAR10(root=self.hparams.dataset_dir,
                              train=False,
                              transform=transform,
                              download=True)

        elif self.hparams.dataset == "CIFAR100":
            dataset = CIFAR100(root=self.hparams.dataset_dir,
                               train=False,
                               transform=transform,
                               download=True)

        elif self.hparams.dataset == 'STL10':
            dataset = STL10(root=self.hparams.dataset_dir,
                            split='test',
                            transform=transform,
                            download=True)

        elif self.hparams.dataset == 'COVID':
            dataset = torchvision.datasets.ImageFolder(
                root=self.hparams.dataset_dir + 'test', transform=transform)

        elif self.hparams.dataset == 'ImageNet100B' or self.hparams.dataset == 'imagenet-100B':
            dataset = ImageNet100(root=self.hparams.dataset_dir,
                                  split='val',
                                  transform=transform)

        else:
            raise ValueError(
                'Unknown dataset: {!r}'.format(self.hparams.dataset))

        return dataset
コード例 #17
0
def get_stl10(root, transform_train=None, transform_val=None, download=True):
    """Return STL10 (labeled train, unlabeled, dev/test, None) datasets.

    Args:
        root: dataset root directory.
        transform_train: transform for the labeled and unlabeled training sets.
        transform_val: transform for the dev (test-split) set.
        download: whether to download the archive if missing.  BUG FIX: this
            parameter was previously accepted but ignored (all calls
            hard-coded download=True); it is now honored.  The default keeps
            the old behavior.
    """
    training_set = STL10(root,
                         split='train',
                         download=download,
                         transform=transform_train)
    dev_set = STL10(root, split='test', download=download, transform=transform_val)
    unl_set = STL10(root,
                    split='unlabeled',
                    download=download,
                    transform=transform_train)

    print(
        f"#Labeled: {len(training_set)} #Unlabeled: {len(unl_set)} #Val: {len(dev_set)} #Test: None"
    )
    return training_set, unl_set, dev_set, None
コード例 #18
0
 def get_test_split(self, transform=None, num_samples=None):
      """Return a Subset of the first `num_samples` STL10 test images
      (defaults to self.NUM_TST_SAMPLES)."""
      if num_samples is None:
          num_samples = self.NUM_TST_SAMPLES
      test_set = STL10(root=self.DATASET_FOLDER,
                       split='test',
                       download=True,
                       transform=transform)
      return Subset(test_set, range(num_samples))
コード例 #19
0
    def val_dataloader(self):
        """
        Loads the portion of the 'unlabeled' split held out for validation
        (the last `unlabeled_val_split` samples of a seeded random split),
        served unshuffled with `self.batch_size`.
        """
        if self.val_transforms is None:
            transforms = self.default_transforms()
        else:
            transforms = self.val_transforms

        full_dataset = STL10(self.data_dir, split='unlabeled', download=False, transform=transforms)
        total = len(full_dataset)
        # seeded so the partition matches the one used for training
        _, dataset_val = random_split(
            full_dataset,
            [total - self.unlabeled_val_split, self.unlabeled_val_split],
            generator=torch.Generator().manual_seed(self.seed))

        return DataLoader(dataset_val,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=self.num_workers,
                          pin_memory=True)
コード例 #20
0
    def train_dataloader(self):
        """
        Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`.
        """
        if self.train_transforms is None:
            transforms = self.default_transforms()
        else:
            transforms = self.train_transforms

        full_dataset = STL10(self.data_dir,
                             split='unlabeled',
                             download=False,
                             transform=transforms)
        total = len(full_dataset)
        # seeded so the partition matches the validation-side split
        dataset_train, _ = random_split(
            full_dataset,
            [total - self.unlabeled_val_split, self.unlabeled_val_split],
            generator=torch.Generator().manual_seed(self.seed))

        return DataLoader(dataset_train,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers,
                          drop_last=True,
                          pin_memory=True)
コード例 #21
0
 def stl(dataset_root, split: str = None):
      """Load an STL10 split with AMDIM train transforms and return a
      95%/5% (train, val) random split.

      Args:
          dataset_root: dataset root directory.
          split: STL10 split name; defaults to 'unlabeled'.  BUG FIX: this
              parameter was previously accepted but ignored — 'unlabeled' was
              always loaded.  It is now honored; the default preserves the
              old behavior.
      """
      dataset = STL10(root=dataset_root,
                      split=split or 'unlabeled',
                      transform=amdim_transforms.AMDIMTrainTransformsSTL10(),
                      download=True)
      # hold out 5% for validation; for the 100k-sample unlabeled split this
      # reproduces the original fixed 95000/5000 partition.
      n_val = len(dataset) // 20
      tng_split, val_split = random_split(dataset, [len(dataset) - n_val, n_val])
      return tng_split, val_split
コード例 #22
0
ファイル: datasets.py プロジェクト: bsafacicek/rca
def get_loader_STL(batchsize):
    """Build CIFAR10-compatible STL10 train/test DataLoaders.

    Removes the 'monkey' class (STL10 label 7, which has no CIFAR10
    counterpart), remaps remaining labels to CIFAR10's class indices and
    mean-downscales images from 96x96 to 32x32.

    :param batchsize: batch size for both loaders
    :return: (trainloader, testloader)
    """
    transform_train = noaug_cifar()

    trainset = STL10(root='./data', split='train', download=True, transform=transform_train)
    testset = STL10(root='./data', split='test', download=True, transform=noaug_cifar())
    print (trainset.data.shape, len(trainset.labels))
    print (testset.data.shape, len(testset.labels))

    ### remove monkey samples (label 7) from both splits
    trainset.labels = np.array(trainset.labels)
    final_inds_train = np.where(trainset.labels != 7)[0]
    trainset.data = trainset.data[final_inds_train]
    trainset.labels = trainset.labels[final_inds_train]
    testset.labels = np.array(testset.labels)
    final_inds_test = np.where(testset.labels != 7)[0]
    testset.data = testset.data[final_inds_test]
    testset.labels = testset.labels[final_inds_test]
    print (trainset.data.shape, len(trainset.labels))
    print (testset.data.shape, len(testset.labels))

    ### change label indexes to be the same as cifar10; the deepcopy keeps the
    ### original labels so the in-place remapping is not order-dependent
    labels_train = deepcopy(trainset.labels)
    trainset.labels[labels_train==1] = 2
    trainset.labels[labels_train==2] = 1
    trainset.labels[labels_train==8] = 7
    trainset.labels[labels_train==9] = 8
    labels_test = deepcopy(testset.labels)
    testset.labels[labels_test==1] = 2
    testset.labels[labels_test==2] = 1
    testset.labels[labels_test==8] = 7
    testset.labels[labels_test==9] = 8

    ### downscale images N x 96 x 96 x 3 -> N x 32 x 32 x 3 via 3x3 local mean
    trainset.data = downscale_local_mean(trainset.data, (1, 1, 3, 3)).astype(np.uint8)
    testset.data = downscale_local_mean(testset.data, (1, 1, 3, 3)).astype(np.uint8)

    print (trainset.data.shape, len(trainset.labels))
    print (testset.data.shape, len(testset.labels))

    trainloader = DataLoader(trainset, batch_size=batchsize, shuffle=True, num_workers=0)
    testloader = DataLoader(testset, batch_size=batchsize, shuffle=False, num_workers=0)

    print("STL train min=%f, max=%f" % (trainset.data.min(), trainset.data.max()))
    print("STL test min=%f, max=%f" % (testset.data.min(), testset.data.max()))

    return trainloader, testloader
コード例 #23
0
def get_dataset():
    """Return (STL10 train set, CIFAR10 train set) with labels remapped onto
    a shared 6-class scheme: classes 1 and 2 are swapped on the STL10 side
    and every label >= 5 is collapsed into class 5 on both sides."""

    train_dataset = STL10('../data',
                          split='train',
                          download=True,
                          transform=transforms.Compose([
                              transforms.Resize(32),
                              transforms.RandomHorizontalFlip(p=0.5),
                              transforms.RandomCrop(28),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5),
                                                   (0.5, 0.5, 0.5))
                          ]))

    # NOTE(review): despite the name, this is CIFAR10 with train=True —
    # confirm a training split is really intended as the "test" dataset.
    test_dataset = CIFAR10('../data',
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.Resize(32),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.RandomCrop(28),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5),
                                                    (0.5, 0.5, 0.5))
                           ]))

    def map_labels_stl_to_cifar10(labels_stl):
        # Remap STL10 labels toward CIFAR10 indices:
        #   1 -> 2 and 2 -> 1 (the two datasets order these classes differently)
        #   every label >= 5 collapses into a single class 5
        # (an earlier comment claimed a 6<->7 swap; the code does not do that)

        new_labels = np.zeros(len(labels_stl))
        for idx, l in enumerate(labels_stl):
            if l == 1:
                new_labels[idx] = 2
            elif l == 2:
                new_labels[idx] = 1
            elif l >= 5:
                new_labels[idx] = 5
            else:
                new_labels[idx] = l

        return np.asarray(new_labels)

    def apply_one_hot(labels):
        # NOTE(review): misleading name — this does NOT one-hot encode;
        # it clamps labels >= 5 to class 5 (same collapse as above,
        # without the 1<->2 swap).
        new_labels = np.zeros(len(labels))
        for idx, l in enumerate(labels):
            if l >= 5:
                new_labels[idx] = 5
            else:
                new_labels[idx] = l

        return np.asarray(new_labels)

    train_dataset.labels = map_labels_stl_to_cifar10(train_dataset.labels)
    # NOTE(review): 'train_labels' is the legacy CIFAR10 attribute; newer
    # torchvision versions expose '.targets' instead — confirm compatibility.
    test_dataset.train_labels = apply_one_hot(test_dataset.train_labels)

    return train_dataset, test_dataset
コード例 #24
0
ファイル: natural.py プロジェクト: val-iisc/GANTree
    def get_data(self):
        """Load pre-processed STL10 train/test tensors from disk, normalize
        them, and return (train_data, test_data, train_labels, test_labels).

        NOTE(review): torchvision's STL10 download does not create
        'processed/training.pt' / 'processed/test.pt' — that layout belongs to
        MNIST-style datasets.  Presumably a separate preprocessing step writes
        these files; confirm before relying on this method.
        """
        # download the raw archive (result of this constructor is discarded)
        STL10('../data/stl10', download=True)
        train_data, train_labels = tr.load(
            '../data/stl10/processed/training.pt')
        test_data, test_labels = tr.load('../data/stl10/processed/test.pt')

        train_data = normalize_mnist_images(train_data)
        test_data = normalize_mnist_images(test_data)
        return train_data, test_data, train_labels, test_labels
コード例 #25
0
    def stl_train(dataset_root):
        """STL10 'unlabeled' split with AMDIM transforms, partitioned
        95000/5000 into (train, val)."""
        dataset = STL10(root=dataset_root,
                        split='unlabeled',
                        download=True,
                        transform=amdim_transforms.TransformsSTL10())
        train_part, val_part = random_split(dataset, [95000, 5000])
        return train_part, val_part
コード例 #26
0
 def val_dataloader(self):
      """Validation loader: first 500 STL10 test images, batch size 64."""
      dataset = STL10('.',
                      split='test',
                      transform=transforms.ToTensor(),
                      download=True)
      subset = Subset(dataset, torch.arange(500))
      return DataLoader(subset, batch_size=64)
コード例 #27
0
 def train_dataloader(self):
      """Training loaders: [labeled, unlabeled], each the first 500 samples
      of the corresponding torchvision STL10 split, batch size 64."""
      loaders = []
      for split_name in ('train', 'unlabeled'):
          dataset = STL10('.',
                          split=split_name,
                          transform=transforms.ToTensor(),
                          download=True)
          subset = Subset(dataset, torch.arange(500))
          loaders.append(DataLoader(subset, batch_size=64))
      return loaders
コード例 #28
0
ファイル: utils.py プロジェクト: Afanc/colorizing_things
def _load_dataset(folder, split, transform):
    """Construct an STL10 dataset rooted at `folder` for the given split,
    downloading the archive if it is missing."""
    return STL10(root=folder, download=True, split=split, transform=transform)
コード例 #29
0
    def stl_train(dataset_root, patch_size, patch_overlap):
        """STL10 'unlabeled' split with AMDIM patch transforms, partitioned
        95000/5000 into (train, val)."""
        patch_transform = amdim_transforms.TransformsSTL10Patches(
            patch_size=patch_size, overlap=patch_overlap)
        dataset = STL10(root=dataset_root,
                        split='unlabeled',
                        download=True,
                        transform=patch_transform)
        train_part, val_part = random_split(dataset, [95000, 5000])
        return train_part, val_part
コード例 #30
0
    def train_dataloader_mixed(self):
        """
        Loads the 'unlabeled' and 'train' (labeled) splits, each minus the
        validation portion held out via `unlabeled_val_split` /
        `train_val_split`, and serves their concatenation shuffled with
        `self.batch_size` and `self.train_transforms` (or the defaults).
        """
        transforms = self.default_transforms(
        ) if self.train_transforms is None else self.train_transforms

        unlabeled_dataset = STL10(self.data_dir,
                                  split='unlabeled',
                                  download=False,
                                  transform=transforms)
        unlabeled_length = len(unlabeled_dataset)
        # seeded generator keeps this partition identical to the val-side split
        unlabeled_dataset, _ = random_split(
            unlabeled_dataset, [
                unlabeled_length - self.unlabeled_val_split,
                self.unlabeled_val_split
            ],
            generator=torch.Generator().manual_seed(self.seed))

        labeled_dataset = STL10(self.data_dir,
                                split='train',
                                download=False,
                                transform=transforms)
        labeled_length = len(labeled_dataset)
        labeled_dataset, _ = random_split(
            labeled_dataset,
            [labeled_length - self.train_val_split, self.train_val_split],
            generator=torch.Generator().manual_seed(self.seed))

        # BUG FIX: ConcatDataset expects a single list of datasets; passing two
        # positional datasets raises a TypeError at runtime.
        dataset = ConcatDataset([unlabeled_dataset, labeled_dataset])
        loader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            shuffle=True,
                            num_workers=self.num_workers,
                            drop_last=True,
                            pin_memory=True)
        return loader