def setUpClass(cls) -> None:
    """Build the shared MetaDataset fixtures once for the whole test class."""
    datasets = TestDatasets()
    cls.ds = datasets
    # Wrap each raw dataset so the meta-learning helpers can index by label.
    cls.meta_tensor_dataset = MetaDataset(datasets.tensor_dataset)
    cls.meta_str_dataset = MetaDataset(datasets.str_dataset)
    cls.meta_alpha_dataset = MetaDataset(datasets.alphabet_dataset)
    # Vision benchmarks are fetched through the TestDatasets helper.
    cls.mnist_dataset = MetaDataset(datasets.get_mnist())
    cls.omniglot_dataset = MetaDataset(datasets.get_omniglot())
def test_fails_with_non_torch_dataset(self):
    """MetaDataset must raise TypeError for inputs that are not torch Datasets.

    BUG FIX: the original wrapped the call in try/except/finally where the
    `finally` clause unconditionally executed `return False`, masking the
    `return True` in the except branch -- the success path was unreachable
    (and unittest ignores return values anyway, so neither outcome was ever
    checked). `assertRaises` is the idiomatic way to pin the expected
    exception.
    """
    with self.assertRaises(TypeError):
        MetaDataset(np.random.randn(100, 100))
def test_load_data(self):
    """LoadData should yield (data, label) tensor pairs for every task."""
    inputs = torch.randn(NUM_DATA, X_SHAPE)
    targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    meta = MetaDataset(TensorDataset(inputs, targets))
    tasks = TaskDataset(
        meta,
        task_transforms=[LoadData(meta)],
        num_tasks=NUM_TASKS,
    )
    for sample in tasks:
        x, y = sample[0], sample[1]
        self.assertTrue(isinstance(x, torch.Tensor))
        self.assertTrue(isinstance(y, torch.Tensor))
def __init__(self, root=None, transforms=None):
    """Index an ImageFolder and expose label <-> index lookup tables."""
    folder = ImageFolder(root=root)
    wrapped = MetaDataset(folder)
    self.img_list = folder.imgs
    self.labels_to_indices = wrapped.labels_to_indices
    self.indices_to_labels = wrapped.indices_to_labels
    # Fixed label set known to be present in this dataset.
    self.labels = [0, 1, 2, 3, 4]
    self.transforms = transforms
def test_filter_labels(self):
    """FilterLabels must restrict every task to the chosen label subset."""
    inputs = torch.randn(NUM_DATA, X_SHAPE)
    targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    # Pick a random half of the available classes to keep.
    chosen_labels = random.sample(list(range(Y_SHAPE)), k=Y_SHAPE // 2)
    meta = MetaDataset(TensorDataset(inputs, targets))
    transforms = [
        FilterLabels(meta, chosen_labels),
        LoadData(meta),
    ]
    tasks = TaskDataset(meta, task_transforms=transforms, num_tasks=NUM_TASKS)
    for task in tasks:
        for label in task[1]:
            self.assertTrue(label in chosen_labels)
def test_n_ways(self):
    """NWays must produce tasks containing exactly `n` distinct classes."""
    inputs = torch.randn(NUM_DATA, X_SHAPE)
    targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    meta = MetaDataset(TensorDataset(inputs, targets))
    for ways in range(1, 10):
        tasks = TaskDataset(
            meta,
            task_transforms=[NWays(meta, n=ways), LoadData(meta)],
            num_tasks=NUM_TASKS,
        )
        for task in tasks:
            counts = task[1].bincount()
            # Non-empty bins correspond to the classes present in the task.
            present = len(counts) - (counts == 0).sum()
            self.assertEqual(present, ways)
def test_remap_labels(self):
    """RemapLabels must relabel each task's classes onto 0..ways-1."""
    inputs = torch.randn(NUM_DATA, X_SHAPE)
    targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    meta = MetaDataset(TensorDataset(inputs, targets))
    for ways in range(1, 5):
        transforms = [
            NWays(meta, ways),
            LoadData(meta),
            RemapLabels(meta),
        ]
        tasks = TaskDataset(meta, task_transforms=transforms, num_tasks=NUM_TASKS)
        for task in tasks:
            # Every remapped label 0..ways-1 must occur in the task.
            for expected in range(ways):
                self.assertTrue(expected in task[1])
def __init__(self):
    """Assemble the BERT encoder, MAML meta-learner, optimizer, and task stream."""
    net = Net()
    net.to(device)
    # first_order=True skips second derivatives, keeping adaptation cheap.
    self.meta_model = l2l.algorithms.MAML(net, lr=1e-2, first_order=True)
    self.loss_fn = nn.CrossEntropyLoss()
    self.encoder = BertModel.from_pretrained('bert-base-chinese')
    self.optim = AdamW(self.meta_model.parameters(), lr=5e-3)
    # Build the meta-learning task generator from the locally loaded data.
    features, targets = load_data()
    self.dataset = TensorDataset(T.LongTensor(features), T.LongTensor(targets))
    self.metaset = MetaDataset(self.dataset)
    self.task_generator = TaskGenerator(
        self.metaset,
        ways=15,
        shots=10,
        classes=None,
        tasks=1000,
    )
def test_k_shots(self):
    """KShots must draw exactly `k` samples per class, with and without replacement."""
    inputs = torch.randn(NUM_DATA, X_SHAPE)
    targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    meta = MetaDataset(TensorDataset(inputs, targets))
    for replacement in (False, True):
        for shots in range(1, 10):
            transforms = [
                KShots(meta, k=shots, replacement=replacement),
                LoadData(meta),
            ]
            tasks = TaskDataset(meta, task_transforms=transforms, num_tasks=NUM_TASKS)
            for task in tasks:
                counts = task[1].bincount()
                # Each of the Y_SHAPE classes must appear exactly `shots` times.
                self.assertEqual((counts == shots).sum(), Y_SHAPE)
def setUpClass(cls) -> None:
    """Build the shared MetaDataset fixtures once for the whole test class."""
    datasets = TestDatasets()
    cls.ds = datasets
    # Wrap each raw dataset so the meta-learning helpers can index by label.
    cls.meta_tensor_dataset = MetaDataset(datasets.tensor_dataset)
    cls.meta_str_dataset = MetaDataset(datasets.str_dataset)
    cls.meta_alpha_dataset = MetaDataset(datasets.alphabet_dataset)