def test_meta_dataloader():
    """A MetaDataLoader over Sinusoid tasks batches whole Task objects."""
    num_tasks, tasks_per_batch, samples_per_task = 1000, 4, 10
    sinusoid = Sinusoid(samples_per_task, num_tasks=num_tasks, noise_std=None)
    loader = MetaDataLoader(sinusoid, batch_size=tasks_per_batch)

    # The meta loader is a regular DataLoader and its length counts batches of tasks.
    assert isinstance(loader, DataLoader)
    assert len(loader) == num_tasks // tasks_per_batch  # 1000 / 4 == 250

    first_batch = next(iter(loader))
    assert isinstance(first_batch, list)
    assert len(first_batch) == tasks_per_batch

    first_task = first_batch[0]
    assert isinstance(first_task, Task)
    assert len(first_task) == samples_per_task
def test_meta_dataloader_task_loader():
    """A single task taken from a meta-batch can feed a plain DataLoader."""
    sinusoid = Sinusoid(10, num_tasks=1000, noise_std=None)
    meta_loader = MetaDataLoader(sinusoid, batch_size=4)
    task = next(iter(meta_loader))[0]

    task_loader = DataLoader(task, batch_size=5)
    inputs, targets = next(iter(task_loader))
    assert len(task_loader) == 2  # 10 samples split into batches of 5

    # The default PyTorch collation hands back tensors, not numpy arrays.
    for tensor in (inputs, targets):
        assert isinstance(tensor, torch.Tensor)
    for tensor in (inputs, targets):
        assert tensor.shape == (5, 1)
def get_episode_loader(dataset, datapath, ways, shots, test_shots, batch_size,
                       split, download=True, shuffle=True, num_workers=0):
    """Create an episode data loader for a torchmeta dataset.

    Args:
        dataset: String. Name of the dataset to use. One of 'omniglot',
            'miniimagenet', 'tieredimagenet', 'cub', 'cifar_fs',
            'doublemnist', 'triplemnist'.
        datapath: String. Path where the datasets are stored.
        ways: Integer. Number of ways N.
        shots: Integer. Number of shots K for the support set.
        test_shots: Integer. Number of query examples. It is forwarded to
            the dataset helper as ``test_shots``.
        batch_size: Integer. Number of tasks per iteration. If falsy
            (e.g. None or 0), tasks are collated with ``collate_task``
            instead of ``collate_task_batch``.
        split: String. One of 'train', 'val', 'test'.
        download: Boolean. Whether to download the data.
        shuffle: Boolean. Whether to shuffle episodes.
        num_workers: Integer. Passed through to ``MetaDataLoader`` as
            ``num_workers``.

    Returns:
        A ``MetaDataLoader`` over episodes of the requested dataset/split.

    Raises:
        ValueError: If ``split`` or ``dataset`` is not a supported value.
    """
    # Validate the split first. Otherwise every meta_* flag below would be
    # False and the failure would surface deep inside the dataset helper.
    valid_splits = ('train', 'val', 'test')
    if split not in valid_splits:
        raise ValueError("Invalid split {!r}. Please choose from {}".format(
            split, list(valid_splits)))

    # Map dataset names to their helper constructors.
    dataset_funcs = {
        'omniglot': omniglot,
        'miniimagenet': miniimagenet,
        'tieredimagenet': tieredimagenet,
        'cub': cub,
        'cifar_fs': cifar_fs,
        'doublemnist': doublemnist,
        'triplemnist': triplemnist,
    }
    try:
        dataset_func = dataset_funcs[dataset]
    except KeyError:
        raise ValueError("No such dataset available. Please choose from {}"
                         .format(list(dataset_funcs))) from None

    # Collect arguments that are the same for all possible sub-datasets.
    kwargs = {
        'download': download,
        'meta_train': split == 'train',
        'meta_val': split == 'val',
        'meta_test': split == 'test',
        'shuffle': shuffle,
    }

    # Create the dataset of labelled episodes.
    meta_dataset = dataset_func(datapath, ways=ways, shots=shots,
                                test_shots=test_shots, **kwargs)
    print('Supervised data loader for {}:{}.'.format(dataset, split))

    # Standard supervised meta-learning dataloader. Host memory is pinned
    # only when a CUDA device is available to consume it.
    collate_fn = collate_task_batch if batch_size else collate_task
    return MetaDataLoader(meta_dataset, batch_size=batch_size,
                          collate_fn=collate_fn, num_workers=num_workers,
                          pin_memory=torch.cuda.is_available())