def test_infinite_tasks(self):
    """A TaskDataset without num_tasks has length 1 and samples fresh tasks forever."""
    inputs = th.randn(NUM_DATA, X_SHAPE)
    targets = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    base = TensorDataset(inputs, targets)
    tasks = TaskDataset(
        base,
        task_transforms=[LoadData(base), random_subset])
    # Infinite task streams report a length of exactly 1.
    self.assertEqual(len(tasks), 1)
    previous = tasks.sample()
    # Consecutive samples should never repeat; check a handful then stop.
    for step, current in enumerate(tasks):
        self.assertFalse(task_equal(previous, current))
        previous = current
        if step > 4:
            break
def test_instanciation(self):
    """A TaskDataset built with num_tasks reports exactly that many tasks."""
    inputs = th.randn(NUM_DATA, X_SHAPE)
    targets = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    base = TensorDataset(inputs, targets)
    tasks = TaskDataset(base,
                        task_transforms=[LoadData(base)],
                        num_tasks=NUM_TASKS)
    self.assertEqual(len(tasks), NUM_TASKS)
def test_load_data(self):
    """LoadData yields tasks whose inputs and labels are both torch Tensors."""
    inputs = torch.randn(NUM_DATA, X_SHAPE)
    targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    meta = MetaDataset(TensorDataset(inputs, targets))
    tasks = TaskDataset(meta,
                        task_transforms=[LoadData(meta)],
                        num_tasks=NUM_TASKS)
    for task in tasks:
        self.assertTrue(isinstance(task[0], torch.Tensor))
        self.assertTrue(isinstance(task[1], torch.Tensor))
def test_filter_labels(self):
    """FilterLabels restricts every sampled task to the chosen label subset."""
    inputs = torch.randn(NUM_DATA, X_SHAPE)
    targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    # Keep a random half of the available classes.
    chosen_labels = random.sample(list(range(Y_SHAPE)), k=Y_SHAPE // 2)
    meta = MetaDataset(TensorDataset(inputs, targets))
    transforms = [
        FilterLabels(meta, chosen_labels),
        LoadData(meta),
    ]
    tasks = TaskDataset(meta,
                        task_transforms=transforms,
                        num_tasks=NUM_TASKS)
    for task in tasks:
        for label in task[1]:
            self.assertTrue(label in chosen_labels)
def test_n_ways(self):
    """NWays(n) yields tasks containing exactly n distinct classes."""
    inputs = torch.randn(NUM_DATA, X_SHAPE)
    targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    meta = MetaDataset(TensorDataset(inputs, targets))
    for ways in range(1, 10):
        transforms = [NWays(meta, n=ways), LoadData(meta)]
        tasks = TaskDataset(meta,
                            task_transforms=transforms,
                            num_tasks=NUM_TASKS)
        for task in tasks:
            counts = task[1].bincount()
            # Count classes that actually appear in the task.
            present = len(counts) - (counts == 0).sum()
            self.assertEqual(present, ways)
def test_remap_labels(self):
    """RemapLabels maps an n-way task's labels onto the range [0, n)."""
    inputs = torch.randn(NUM_DATA, X_SHAPE)
    targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    meta = MetaDataset(TensorDataset(inputs, targets))
    for ways in range(1, 5):
        transforms = [
            NWays(meta, ways),
            LoadData(meta),
            RemapLabels(meta),
        ]
        tasks = TaskDataset(meta,
                            task_transforms=transforms,
                            num_tasks=NUM_TASKS)
        for task in tasks:
            # Every remapped label 0..ways-1 must occur in the task.
            for label in range(ways):
                self.assertTrue(label in task[1])
def test_dataloader(self):
    """A DataLoader over a TaskDataset batches tasks to the meta batch size."""
    inputs = th.randn(NUM_DATA, X_SHAPE)
    targets = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    base = TensorDataset(inputs, targets)
    tasks = TaskDataset(
        base,
        task_transforms=[LoadData(base), random_subset],
        num_tasks=NUM_TASKS)
    loader = DataLoader(tasks,
                        shuffle=True,
                        batch_size=META_BSZ,
                        num_workers=WORKERS,
                        drop_last=True)
    for batch in loader:
        self.assertEqual(batch[0].shape, (META_BSZ, X_SHAPE))
        self.assertEqual(batch[1].shape, (META_BSZ, 1))
def test_task_transforms(self):
    """Task transforms shape both the task description and the loaded data."""
    inputs = th.randn(NUM_DATA, X_SHAPE)
    targets = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    base = TensorDataset(inputs, targets)
    tasks = TaskDataset(
        base,
        task_transforms=[LoadData(base), random_subset],
        num_tasks=NUM_TASKS)
    for task in tasks:
        # Tests transforms on the task_description
        self.assertEqual(len(task[0]), SUBSET_SIZE)
        self.assertEqual(len(task[1]), SUBSET_SIZE)
        # Tests transforms on the data
        self.assertEqual(task[0].size(1), X_SHAPE)
        self.assertLessEqual(task[1].max(), Y_SHAPE - 1)
        self.assertGreaterEqual(task[1].max(), 0)
def test_k_shots(self):
    """KShots(k) yields exactly k samples per class, with and without replacement."""
    inputs = torch.randn(NUM_DATA, X_SHAPE)
    targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    meta = MetaDataset(TensorDataset(inputs, targets))
    for replacement in [False, True]:
        for shots in range(1, 10):
            transforms = [
                KShots(meta, k=shots, replacement=replacement),
                LoadData(meta),
            ]
            tasks = TaskDataset(meta,
                                task_transforms=transforms,
                                num_tasks=NUM_TASKS)
            for task in tasks:
                counts = task[1].bincount()
                # Every class must appear exactly `shots` times.
                exact = (counts == shots).sum()
                self.assertEqual(exact, Y_SHAPE)
def test_task_caching(self):
    """A finite TaskDataset caches tasks: iteration and indexing are repeatable."""
    inputs = th.randn(NUM_DATA, X_SHAPE)
    targets = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    base = TensorDataset(inputs, targets)
    tasks = TaskDataset(base,
                        task_transforms=[LoadData(base)],
                        num_tasks=NUM_TASKS)
    # First pass materializes (and caches) every task.
    cached = []
    count = 0
    for count, task in enumerate(tasks, 1):
        cached.append(task)
    self.assertEqual(count, NUM_TASKS)
    # A second iteration must replay the identical tasks.
    for reference, replay in zip(cached, tasks):
        self.assertTrue(task_equal(reference, replay))
    # Random access must also hit the cache.
    for index in range(NUM_TASKS):
        self.assertTrue(task_equal(cached[index], tasks[index]))
def build_taskset(datasets, k=4):
    """Build a learn2learn TaskDataset of 2-way, k-shot tasks.

    Arguments:
        datasets (list): datasets to create training tasks from; each is
            wrapped in a MicroDataset and merged into one UnionMetaDataset.
        k (int): number of samples per class in each task (default: 4).

    Returns:
        TaskDataset: tasks drawn from the union of all input datasets,
        collated with `collate`.
    """
    union = l2l.data.UnionMetaDataset(
        [l2l.data.MetaDataset(MicroDataset(t)) for t in datasets])
    dataset = l2l.data.MetaDataset(union)
    transforms = [
        l2l.data.transforms.NWays(dataset, n=2),
        l2l.data.transforms.KShots(dataset, k=k, replacement=True),
        l2l.data.transforms.LoadData(dataset),
    ]
    # Each input dataset contributes 2 classes; sample one task per
    # ordered pair of distinct classes: N * (N - 1) tasks total.
    num_classes = len(datasets) * 2
    return TaskDataset(dataset,
                       transforms,
                       num_tasks=num_classes * (num_classes - 1),
                       task_collate=collate)