def val_dataloader_mixed(self):
    """Build the validation loader over the held-out unlabeled + labeled STL10 data.

    Two held-out subsets are combined:

    * ``self.unlabeled_val_split`` samples taken from the 'unlabeled' split.
    * ``self.train_val_split`` samples taken from the 'train' (labeled) split.

    Each subset is the *second* part of a ``random_split`` whose generator is
    seeded with ``self.seed``. The same seed and split sizes are used when
    building the training loader. With the same seed and sizes, the two
    loaders select complementary subsets.

    Configuration is read from the instance rather than from arguments:
    ``data_dir``, ``val_transforms`` (falling back to ``default_transforms()``
    when it is None), ``unlabeled_val_split``, ``train_val_split``, ``seed``,
    ``batch_size`` and ``num_workers``.

    Returns:
        A ``DataLoader`` over ``ConcatDataset(unlabeled_val, labeled_val)``
        with ``shuffle=False``, ``drop_last=True`` and ``pin_memory=True``.
    """
    # Only build the default pipeline when no explicit one was configured.
    if self.val_transforms is None:
        image_transform = self.default_transforms()
    else:
        image_transform = self.val_transforms

    def held_out_part(split_name, holdout_size):
        # download=False: the STL10 files are expected to already be on disk.
        full_split = STL10(self.data_dir, split=split_name, download=False, transform=image_transform)
        keep_size = len(full_split) - holdout_size
        # A fresh generator with a fixed seed makes the permutation reproducible,
        # so the held-out tail is the same on every call.
        seeded_gen = torch.Generator().manual_seed(self.seed)
        _, held_out = random_split(full_split, [keep_size, holdout_size], generator=seeded_gen)
        return held_out

    unlabeled_val = held_out_part('unlabeled', self.unlabeled_val_split)
    labeled_val = held_out_part('train', self.train_val_split)

    # NOTE(review): combination semantics depend on which ConcatDataset this file
    # imports (torch's takes one iterable; a custom *datasets variant zips/cycles).
    combined = ConcatDataset(unlabeled_val, labeled_val)

    # NOTE(review): drop_last=True on a validation loader discards up to
    # batch_size - 1 samples each epoch — confirm this is intended.
    return DataLoader(
        combined,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=True,
        pin_memory=True,
    )
def train_dataloader_mixed(self):
    """Build the training loader over unlabeled + labeled STL10 data.

    Two subsets are combined:

    * The 'unlabeled' split minus the ``self.unlabeled_val_split`` samples
      held out for validation.
    * The 'train' (labeled) split minus the ``self.train_val_split`` samples
      held out for validation.

    Each subset is the *first* part of a ``random_split`` whose generator is
    seeded with ``self.seed``. The held-out second parts are therefore
    reproducible, and they are what the validation loader consumes when it
    uses the same seed and sizes.

    Configuration is read from the instance rather than from arguments:
    ``data_dir``, ``train_transforms`` (falling back to ``default_transforms()``
    when it is None), ``unlabeled_val_split``, ``train_val_split``, ``seed``,
    ``batch_size`` and ``num_workers``.

    Returns:
        A ``DataLoader`` over ``ConcatDataset(unlabeled_train, labeled_train)``
        with ``shuffle=True``, ``drop_last=True`` and ``pin_memory=True``.
    """
    # Only build the default pipeline when no explicit one was configured.
    if self.train_transforms is None:
        image_transform = self.default_transforms()
    else:
        image_transform = self.train_transforms

    def kept_part(split_name, holdout_size):
        # download=False: the STL10 files are expected to already be on disk.
        full_split = STL10(self.data_dir, split=split_name, download=False, transform=image_transform)
        keep_size = len(full_split) - holdout_size
        # Fixed-seed generator so the train/holdout partition is reproducible.
        seeded_gen = torch.Generator().manual_seed(self.seed)
        kept, _ = random_split(full_split, [keep_size, holdout_size], generator=seeded_gen)
        return kept

    unlabeled_train = kept_part('unlabeled', self.unlabeled_val_split)
    labeled_train = kept_part('train', self.train_val_split)

    # NOTE(review): combination semantics depend on which ConcatDataset this file
    # imports (torch's takes one iterable; a custom *datasets variant zips/cycles).
    combined = ConcatDataset(unlabeled_train, labeled_train)

    return DataLoader(
        combined,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=True,
        pin_memory=True,
    )