def test_infinite_tasks(self):
    # With no num_tasks given, the dataset is treated as an endless task
    # stream: consecutive draws must never repeat the previous task.
    features = th.randn(NUM_DATA, X_SHAPE)
    targets = th.randint(0, Y_SHAPE, (NUM_DATA, ))
    base = TensorDataset(features, targets)
    endless_tasks = TaskDataset(
        base, task_transforms=[LoadData(base), random_subset])
    self.assertEqual(len(endless_tasks), 1)
    previous = endless_tasks.sample()
    # Draw exactly six tasks; zip exhausts range(6) before pulling a 7th.
    for _, current in zip(range(6), endless_tasks):
        self.assertFalse(task_equal(previous, current))
        previous = current
 def test_instanciation(self):
     # A TaskDataset built with an explicit num_tasks reports that count
     # as its length.
     x = th.randn(NUM_DATA, X_SHAPE)
     y = th.randint(0, Y_SHAPE, (NUM_DATA, ))
     tensor_dataset = TensorDataset(x, y)
     transforms = [LoadData(tensor_dataset)]
     tasks = TaskDataset(tensor_dataset,
                         task_transforms=transforms,
                         num_tasks=NUM_TASKS)
     self.assertEqual(len(tasks), NUM_TASKS)
Example #3
0
 def test_load_data(self):
     # Every task produced with the LoadData transform should expose its
     # data (task[0]) and labels (task[1]) as torch tensors.
     data = torch.randn(NUM_DATA, X_SHAPE)
     labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     dataset = MetaDataset(TensorDataset(data, labels))
     task_dataset = TaskDataset(dataset,
                                task_transforms=[LoadData(dataset)],
                                num_tasks=NUM_TASKS)
     for task in task_dataset:
         # assertIsInstance names the offending type on failure, unlike
         # assertTrue(isinstance(...)) which only reports "False is not true".
         self.assertIsInstance(task[0], torch.Tensor)
         self.assertIsInstance(task[1], torch.Tensor)
Example #4
0
 def test_filter_labels(self):
     # FilterLabels must restrict every task to a randomly chosen half of
     # the label space.
     data = torch.randn(NUM_DATA, X_SHAPE)
     labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     chosen_labels = random.sample(list(range(Y_SHAPE)), k=Y_SHAPE // 2)
     dataset = MetaDataset(TensorDataset(data, labels))
     task_dataset = TaskDataset(dataset,
                                task_transforms=[
                                    FilterLabels(dataset, chosen_labels),
                                    LoadData(dataset)
                                ],
                                num_tasks=NUM_TASKS)
     for task in task_dataset:
         for label in task[1]:
             # assertIn reports the offending label and the allowed list on
             # failure; membership semantics are the same as `label in ...`.
             self.assertIn(label, chosen_labels)
Example #5
0
 def test_n_ways(self):
     # For each n in 1..9, NWays must produce tasks containing exactly n
     # distinct labels.
     x = torch.randn(NUM_DATA, X_SHAPE)
     y = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     meta = MetaDataset(TensorDataset(x, y))
     for n in range(1, 10):
         tasks = TaskDataset(
             meta,
             task_transforms=[NWays(meta, n=n), LoadData(meta)],
             num_tasks=NUM_TASKS)
         for task in tasks:
             counts = task[1].bincount()
             # Only labels that actually occur in the task are counted.
             present = (counts != 0).sum()
             self.assertEqual(present, n)
Example #6
0
 def test_remap_labels(self):
     # After NWays + LoadData, RemapLabels should relabel each task so that
     # every label in 0..ways-1 appears in it.
     data = torch.randn(NUM_DATA, X_SHAPE)
     labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     dataset = MetaDataset(TensorDataset(data, labels))
     for ways in range(1, 5):
         task_dataset = TaskDataset(dataset,
                                    task_transforms=[
                                        NWays(dataset, ways),
                                        LoadData(dataset),
                                        RemapLabels(dataset)
                                    ],
                                    num_tasks=NUM_TASKS)
         for task in task_dataset:
             for label in range(ways):
                 # assertIn shows the missing label and the task's labels on
                 # failure instead of a bare "False is not true".
                 self.assertIn(label, task[1])
 def test_dataloader(self):
     # A TaskDataset should plug into a standard DataLoader; every full
     # meta-batch (drop_last=True) is checked for its stacked shapes.
     x = th.randn(NUM_DATA, X_SHAPE)
     y = th.randint(0, Y_SHAPE, (NUM_DATA, ))
     base = TensorDataset(x, y)
     tasks = TaskDataset(
         base,
         task_transforms=[LoadData(base), random_subset],
         num_tasks=NUM_TASKS)
     loader = DataLoader(tasks,
                         batch_size=META_BSZ,
                         shuffle=True,
                         drop_last=True,
                         num_workers=WORKERS)
     expected_data_shape = (META_BSZ, X_SHAPE)
     expected_label_shape = (META_BSZ, 1)
     for batch in loader:
         self.assertEqual(batch[0].shape, expected_data_shape)
         self.assertEqual(batch[1].shape, expected_label_shape)
    def test_task_transforms(self):
        # Tasks built with LoadData followed by random_subset must have
        # SUBSET_SIZE samples, X_SHAPE features, and labels in range.
        data = th.randn(NUM_DATA, X_SHAPE)
        labels = th.randint(0, Y_SHAPE, (NUM_DATA, ))
        dataset = TensorDataset(data, labels)
        task_dataset = TaskDataset(
            dataset,
            task_transforms=[LoadData(dataset), random_subset],
            num_tasks=NUM_TASKS)
        for task in task_dataset:
            # Tests transforms on the task_description
            self.assertEqual(len(task[0]), SUBSET_SIZE)
            self.assertEqual(len(task[1]), SUBSET_SIZE)

            # Tests transforms on the data
            self.assertEqual(task[0].size(1), X_SHAPE)
            # Labels must lie in [0, Y_SHAPE - 1]. The upper bound is checked
            # on the largest label and the lower bound on the smallest one;
            # asserting max() >= 0 would let negative labels slip through.
            self.assertLessEqual(task[1].max(), Y_SHAPE - 1)
            self.assertGreaterEqual(task[1].min(), 0)
Example #9
0
 def test_k_shots(self):
     # KShots, with and without replacement, must give every one of the
     # Y_SHAPE classes exactly k samples, for k in 1..9.
     x = torch.randn(NUM_DATA, X_SHAPE)
     y = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     meta = MetaDataset(TensorDataset(x, y))
     for with_replacement in (False, True):
         for k in range(1, 10):
             sampler = KShots(meta, k=k, replacement=with_replacement)
             tasks = TaskDataset(meta,
                                 task_transforms=[sampler, LoadData(meta)],
                                 num_tasks=NUM_TASKS)
             for task in tasks:
                 per_class = task[1].bincount()
                 # Number of classes that received exactly k samples.
                 full_classes = per_class.eq(k).sum()
                 self.assertEqual(full_classes, Y_SHAPE)
Example #10
0
    def test_task_caching(self):
        # A finite TaskDataset must return the same tasks on a second pass,
        # both when iterated and when indexed, as on the first pass.
        data = th.randn(NUM_DATA, X_SHAPE)
        labels = th.randint(0, Y_SHAPE, (NUM_DATA, ))
        dataset = TensorDataset(data, labels)
        task_dataset = TaskDataset(dataset,
                                   task_transforms=[LoadData(dataset)],
                                   num_tasks=NUM_TASKS)
        # Collect the first pass. Counting via len() avoids relying on a
        # loop variable, which would be undefined if no task were yielded.
        tasks = list(task_dataset)
        self.assertEqual(len(tasks), NUM_TASKS)

        # A second pass by iteration must reproduce the first-pass tasks.
        for ref, task in zip(tasks, task_dataset):
            self.assertTrue(task_equal(ref, task))

        # Random access by index must agree with the first pass as well.
        for i, ref in enumerate(tasks):
            self.assertTrue(task_equal(ref, task_dataset[i]))
Example #11
0
def build_taskset(datasets, k=4):
    """Build a learn2learn TaskDataset spanning several datasets.

    Args:
        datasets: list of the datasets to create training tasks from; each
            one is wrapped in ``MicroDataset`` and ``l2l.data.MetaDataset``
            and all are merged with ``l2l.data.UnionMetaDataset``.
        k: number of shots per class passed to ``KShots`` (sampled with
            replacement).

    Returns:
        A ``TaskDataset`` of 2-way, ``k``-shot tasks whose batches are
        assembled with ``collate``.
    """
    union = l2l.data.UnionMetaDataset(
        [l2l.data.MetaDataset(MicroDataset(ds)) for ds in datasets])
    meta_dataset = l2l.data.MetaDataset(union)
    task_transforms = [
        l2l.data.transforms.NWays(meta_dataset, n=2),
        l2l.data.transforms.KShots(meta_dataset, k=k, replacement=True),
        l2l.data.transforms.LoadData(meta_dataset),
    ]
    # NOTE(review): this looks like the number of ordered label pairs,
    # assuming each dataset contributes two classes — confirm.
    n_labels = 2 * len(datasets)
    n_tasks = n_labels * (n_labels - 1)
    return TaskDataset(meta_dataset,
                       task_transforms,
                       num_tasks=n_tasks,
                       task_collate=collate)