示例#1
0
 def setUpClass(cls) -> None:
     """Build the shared fixtures once for every test in the class.

     A single ``TestDatasets`` instance is kept on ``cls.ds``. Its tensor,
     string and alphabet datasets, plus whatever ``get_mnist()`` and
     ``get_omniglot()`` return, are each wrapped in a ``MetaDataset`` and
     stored as class attributes.
     """
     source = TestDatasets()
     cls.ds = source
     # Small in-memory datasets provided directly as attributes.
     cls.meta_tensor_dataset = MetaDataset(source.tensor_dataset)
     cls.meta_str_dataset = MetaDataset(source.str_dataset)
     cls.meta_alpha_dataset = MetaDataset(source.alphabet_dataset)
     # NOTE(review): get_mnist/get_omniglot may load or download data; their
     # cost is not visible from here.
     cls.mnist_dataset = MetaDataset(source.get_mnist())
     cls.omniglot_dataset = MetaDataset(source.get_omniglot())
示例#2
0
 def test_fails_with_non_torch_dataset(self):
     """MetaDataset must reject an input that is not a torch dataset.

     The expectation is enforced with ``assertRaises``. unittest ignores a
     test method's return value, so returning True/False cannot signal
     failure.
     """
     # A raw NumPy array is not a torch Dataset; construction should raise.
     with self.assertRaises(TypeError):
         MetaDataset(np.random.randn(100, 100))
示例#3
0
 def test_load_data(self):
     """LoadData should turn every sampled task into tensors.

     Builds a random ``TensorDataset`` of NUM_DATA samples whose labels are
     drawn from ``[0, Y_SHAPE)``. It samples NUM_TASKS tasks through
     ``LoadData`` and checks that both ``task[0]`` and ``task[1]`` are
     ``torch.Tensor`` instances.
     """
     data = torch.randn(NUM_DATA, X_SHAPE)
     labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     dataset = MetaDataset(TensorDataset(data, labels))
     task_dataset = TaskDataset(dataset,
                                task_transforms=[LoadData(dataset)],
                                num_tasks=NUM_TASKS)
     for task in task_dataset:
         # assertIsInstance names the actual type on failure, unlike
         # assertTrue(isinstance(...)), which only reports "False is not true".
         self.assertIsInstance(task[0], torch.Tensor)
         self.assertIsInstance(task[1], torch.Tensor)
示例#4
0
    def __init__(self, root=None, transforms=None):
        """Index an ``ImageFolder`` directory for label-based lookups.

        Args:
            root: directory passed straight to ``ImageFolder``.
            transforms: stored unchanged on ``self.transforms``. It is not
                handed to ``ImageFolder`` here.
        """
        folder = ImageFolder(root=root)
        meta = MetaDataset(folder)

        # Image records exposed by ImageFolder via its ``imgs`` attribute.
        self.img_list = folder.imgs
        # Label <-> sample-index lookup tables computed by MetaDataset.
        self.labels_to_indices = meta.labels_to_indices
        self.indices_to_labels = meta.indices_to_labels
        # NOTE(review): hard-coded; assumes the dataset uses exactly these five
        # labels. Confirm against the directory contents.
        self.labels = [0, 1, 2, 3, 4]
        self.transforms = transforms
示例#5
0
 def test_filter_labels(self):
     """FilterLabels should keep only the randomly chosen half of the labels.

     Samples ``Y_SHAPE // 2`` distinct labels. It then checks that every
     label in every task produced by FilterLabels + LoadData is one of them.
     """
     inputs = torch.randn(NUM_DATA, X_SHAPE)
     targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     keep = random.sample(list(range(Y_SHAPE)), k=Y_SHAPE // 2)
     meta = MetaDataset(TensorDataset(inputs, targets))
     pipeline = [FilterLabels(meta, keep), LoadData(meta)]
     tasks = TaskDataset(meta, task_transforms=pipeline, num_tasks=NUM_TASKS)
     for task in tasks:
         # Compare as plain Python ints against the sampled label list.
         for label in task[1].tolist():
             self.assertTrue(label in keep)
示例#6
0
 def test_n_ways(self):
     """NWays(n) should yield tasks with exactly ``n`` distinct labels, n = 1..9."""
     inputs = torch.randn(NUM_DATA, X_SHAPE)
     targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     meta = MetaDataset(TensorDataset(inputs, targets))
     for ways in range(1, 10):
         pipeline = [NWays(meta, n=ways), LoadData(meta)]
         tasks = TaskDataset(meta, task_transforms=pipeline, num_tasks=NUM_TASKS)
         for task in tasks:
             # bincount tallies each label; the non-empty bins are the
             # distinct labels present in the task.
             counts = task[1].bincount()
             distinct = (counts > 0).sum()
             self.assertEqual(distinct, ways)
示例#7
0
 def test_remap_labels(self):
     """RemapLabels should map each task's labels onto exactly ``0..ways-1``.

     For ways = 1..4, tasks are built from NWays + LoadData + RemapLabels.
     The set of labels in each task must equal ``range(ways)``.
     """
     data = torch.randn(NUM_DATA, X_SHAPE)
     labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     dataset = MetaDataset(TensorDataset(data, labels))
     for ways in range(1, 5):
         task_dataset = TaskDataset(dataset,
                                    task_transforms=[
                                        NWays(dataset, ways),
                                        LoadData(dataset),
                                        RemapLabels(dataset)
                                    ],
                                    num_tasks=NUM_TASKS)
         for task in task_dataset:
             # Checking only that each of 0..ways-1 is present would let
             # un-remapped (out-of-range) labels slip through. Compare the
             # full label set instead. int() handles tensor or int elements.
             observed = {int(label) for label in task[1]}
             self.assertEqual(observed, set(range(ways)))
 def __init__(self):
     """Assemble the model, optimizer and task generator used for training.

     ``Net`` is wrapped in ``l2l.algorithms.MAML(lr=1e-2, first_order=True)``
     and optimized with AdamW at lr=5e-3, with a cross-entropy loss. A
     pretrained ``bert-base-chinese`` model is kept as ``self.encoder``.
     Tasks (15 ways, 10 shots, 1000 tasks) are drawn from the data returned
     by ``load_data()``.
     """
     net = Net()
     self.loss_fn = nn.CrossEntropyLoss()
     net.to(device)
     # NOTE(review): unlike ``net``, the encoder is not moved to ``device``
     # here. Confirm that is intended.
     self.encoder = BertModel.from_pretrained('bert-base-chinese')
     self.meta_model = l2l.algorithms.MAML(net, lr=1e-2, first_order=True)
     self.optim = AdamW(self.meta_model.parameters(), lr=5e-3)

     features, targets = load_data()
     self.dataset = TensorDataset(T.LongTensor(features), T.LongTensor(targets))
     self.metaset = MetaDataset(self.dataset)
     self.task_generator = TaskGenerator(
         self.metaset,
         ways=15,
         shots=10,
         classes=None,
         tasks=1000,
     )
示例#9
0
 def test_k_shots(self):
     """KShots(k) should give Y_SHAPE labels exactly ``k`` samples each.

     This is checked for k = 1..9, both with and without replacement.
     """
     inputs = torch.randn(NUM_DATA, X_SHAPE)
     targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     meta = MetaDataset(TensorDataset(inputs, targets))
     for replacement in (False, True):
         for shots in range(1, 10):
             pipeline = [
                 KShots(meta, k=shots, replacement=replacement),
                 LoadData(meta),
             ]
             tasks = TaskDataset(meta, task_transforms=pipeline, num_tasks=NUM_TASKS)
             for task in tasks:
                 # The number of labels occurring exactly `shots` times
                 # must equal Y_SHAPE.
                 counts = task[1].bincount()
                 self.assertEqual((counts == shots).sum(), Y_SHAPE)
示例#10
0
 def setUpClass(cls) -> None:
     """Create shared fixtures once for the whole test class.

     Stores a ``TestDatasets`` instance on ``cls.ds``. It wraps (at least)
     that object's tensor, string and alphabet datasets in ``MetaDataset``
     instances stored as class attributes.
     """
     cls.ds = TestDatasets()
     cls.meta_tensor_dataset = MetaDataset(cls.ds.tensor_dataset)
     cls.meta_str_dataset = MetaDataset(cls.ds.str_dataset)
     cls.meta_alpha_dataset = MetaDataset(cls.ds.alphabet_dataset)