class SSLModuleTest(unittest.TestCase):
    def setUp(self):
        d1_len = 100
        d2_len = 1000
        d1 = SSLTestDataset(labeled=True, length=d1_len)
        d2 = SSLTestDataset(labeled=False, length=d2_len)
        dataset = ConcatDataset([d1, d2])
        self.al_dataset = ActiveLearningDataset(dataset)
        self.al_dataset.label(list(range(d1_len)))  # Label data from d1 (even numbers)

    def test_epoch(self):
        hparams = {'p': None, 'num_steps': None, 'batch_size': 10, 'workers': 0}
        module = TestSSLModule(self.al_dataset, Namespace(**hparams))
        trainer = Trainer(max_epochs=1,
                          num_sanity_val_steps=0,
                          progress_bar_refresh_rate=0,
                          logger=False,
                          checkpoint_callback=False)
        trainer.fit(module)
        assert len(module.labeled_data) == len(module.unlabeled_data)
        assert torch.all(torch.tensor(module.labeled_data) % 2 == 0)
        assert torch.all(torch.tensor(module.unlabeled_data) % 2 != 0)
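# The fixture below is a hedged sketch, not the project's actual helper: these
# tests only assume an `SSLTestDataset` whose items are even integers when
# `labeled=True` and odd integers otherwise, so labeled and unlabeled batches
# can be told apart by parity.
from torch.utils.data import Dataset


class SSLTestDataset(Dataset):
    def __init__(self, labeled: bool, length: int):
        self.labeled = labeled
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Even values for the labeled split, odd values for the unlabeled one.
        return 2 * idx if self.labeled else 2 * idx + 1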
def test_arrowds():
    dataset = HFdata.load_dataset('glue', 'sst2')['test']
    dataset = ActiveLearningDataset(dataset)
    dataset.label(np.arange(10))
    assert len(dataset) == 10
    assert len(dataset.pool) == 1811
    data = dataset.pool[0]
    assert all(k in ('idx', 'label', 'sentence') for k in data)
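# Hedged sketch of the `MyDataset` fixture assumed by the transform and
# ActiveLearningDataset tests below, inferred from their assertions rather
# than taken from the project: 100 items of the form `(transform(i), i)`,
# falling back to `(i, i)` when no transform is supplied.
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, transform=None):
        self.transform = transform

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        x = self.transform(idx) if self.transform is not None else idx
        return x, idx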
def test_transform(self):
    train_transform = Lambda(lambda k: 1)
    test_transform = Lambda(lambda k: 0)
    dataset = ActiveLearningDataset(MyDataset(train_transform),
                                    test_transform,
                                    make_unlabelled=lambda x: (x[0], -1))
    dataset.label(np.arange(10))
    pool = dataset.pool
    assert np.equal([i for i in pool], [(0, -1) for i in np.arange(10, 100)]).all()
    assert np.equal([i for i in dataset], [(1, i) for i in np.arange(10)]).all()
def test_transform(self):
    train_transform = Lambda(lambda k: 1)
    test_transform = Lambda(lambda k: 0)
    dataset = ActiveLearningDataset(MyDataset(train_transform),
                                    make_unlabelled=lambda x: (x[0], -1),
                                    pool_specifics={'transform': test_transform})
    dataset.label(np.arange(10))
    pool = dataset.pool
    assert np.equal([i for i in pool], [(0, -1) for i in np.arange(10, 100)]).all()
    assert np.equal([i for i in dataset], [(1, i) for i in np.arange(10)]).all()

    with pytest.raises(ValueError) as e:
        ActiveLearningDataset(MyDataset(train_transform),
                              pool_specifics={'whatever': 123}).pool
class SSLDatasetTest(unittest.TestCase):
    def setUp(self):
        d1_len = 100
        d2_len = 1000
        d1 = SSLTestDataset(labeled=True, length=d1_len)
        d2 = SSLTestDataset(labeled=False, length=d2_len)
        dataset = ConcatDataset([d1, d2])
        self.al_dataset = ActiveLearningDataset(dataset)
        self.al_dataset.label(list(range(d1_len)))  # Label data from d1 (even numbers)
        self.ss_iterator = SemiSupervisedIterator(self.al_dataset,
                                                  p=None,
                                                  num_steps=None,
                                                  batch_size=10)

    def test_epoch(self):
        labeled_data = []
        unlabeled_data = []
        for batch_idx, batch in enumerate(self.ss_iterator):
            if SemiSupervisedIterator.is_labeled(batch):
                batch = SemiSupervisedIterator.get_batch(batch)
                labeled_data.extend(batch)
            else:
                batch = SemiSupervisedIterator.get_batch(batch)
                unlabeled_data.extend(batch)
        labeled_data = torch.tensor(labeled_data)
        unlabeled_data = torch.tensor(unlabeled_data)
        assert len(labeled_data) == len(unlabeled_data)
        assert torch.all(labeled_data % 2 == 0)
        assert torch.all(unlabeled_data % 2 != 0)

    def test_p(self):
        ss_iterator = SemiSupervisedIterator(self.al_dataset,
                                             p=0.1,
                                             num_steps=None,
                                             batch_size=10)
        labeled_data = []
        unlabeled_data = []
        for batch_idx, batch in enumerate(ss_iterator):
            if SemiSupervisedIterator.is_labeled(batch):
                batch = SemiSupervisedIterator.get_batch(batch)
                labeled_data.extend(batch)
            else:
                batch = SemiSupervisedIterator.get_batch(batch)
                unlabeled_data.extend(batch)
        total = len(labeled_data) + len(unlabeled_data)
        l_ratio = len(labeled_data) / total
        u_ratio = len(unlabeled_data) / total
        assert l_ratio < 0.15
        assert u_ratio > 0.85

    def test_no_pool(self):
        d1 = SSLTestDataset(labeled=True, length=100)
        al_dataset = ActiveLearningDataset(d1)
        al_dataset.label_randomly(100)
        ss_iterator = SemiSupervisedIterator(al_dataset,
                                             p=0.1,
                                             num_steps=None,
                                             batch_size=10)
        labeled_data = []
        unlabeled_data = []
        for batch_idx, batch in enumerate(ss_iterator):
            if SemiSupervisedIterator.is_labeled(batch):
                batch = SemiSupervisedIterator.get_batch(batch)
                labeled_data.extend(batch)
            else:
                batch = SemiSupervisedIterator.get_batch(batch)
                unlabeled_data.extend(batch)
        total = len(labeled_data) + len(unlabeled_data)
        l_ratio = len(labeled_data) / total
        u_ratio = len(unlabeled_data) / total
        assert l_ratio == 1
        assert u_ratio == 0
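# A minimal usage sketch built only on the iterator API exercised above
# (`is_labeled` / `get_batch`): route labeled batches to a supervised step and
# unlabeled ones to an unsupervised step. `supervised_step` and
# `unsupervised_step` are hypothetical callables, not part of the library.
def run_ssl_epoch(ss_iterator, supervised_step, unsupervised_step):
    for batch in ss_iterator:
        data = SemiSupervisedIterator.get_batch(batch)
        if SemiSupervisedIterator.is_labeled(batch):
            supervised_step(data)
        else:
            unsupervised_step(data)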
class ActiveDatasetTest(unittest.TestCase):
    def setUp(self):
        self.dataset = ActiveLearningDataset(MyDataset(),
                                             make_unlabelled=lambda x: (x[0], -1))

    def test_len(self):
        assert len(self.dataset) == 0
        assert self.dataset.n_unlabelled == 100
        assert len(self.dataset.pool) == 100
        self.dataset.label(0)
        assert len(self.dataset) == self.dataset.n_labelled == 1
        assert self.dataset.n_unlabelled == 99
        assert len(self.dataset.pool) == 99
        self.dataset.label(list(range(99)))
        assert len(self.dataset) == 100
        assert self.dataset.n_unlabelled == 0
        assert len(self.dataset.pool) == 0

        dummy_dataset = ActiveLearningDataset(MyDataset(),
                                              labelled=self.dataset._labelled,
                                              make_unlabelled=lambda x: (x[0], -1))
        assert len(dummy_dataset) == len(self.dataset)
        assert len(dummy_dataset.pool) == len(self.dataset.pool)

        dummy_lbl = torch.from_numpy(self.dataset._labelled.astype(np.float32))
        dummy_dataset = ActiveLearningDataset(MyDataset(),
                                              labelled=dummy_lbl,
                                              make_unlabelled=lambda x: (x[0], -1))
        assert len(dummy_dataset) == len(self.dataset)
        assert len(dummy_dataset.pool) == len(self.dataset.pool)

    def test_pool(self):
        self.dataset._dataset.label = unittest.mock.MagicMock()
        labels_initial = self.dataset.n_labelled
        self.dataset.can_label = False
        self.dataset.label(0, value=np.arange(1, 10))
        self.dataset._dataset.label.assert_not_called()
        labels_next_1 = self.dataset.n_labelled
        assert labels_next_1 == labels_initial + 1

        self.dataset.can_label = True
        self.dataset.label(np.arange(0, 9))
        self.dataset._dataset.label.assert_not_called()
        labels_next_2 = self.dataset.n_labelled
        assert labels_next_1 == labels_next_2
        self.dataset.label(np.arange(0, 9), value=np.arange(1, 10))
        assert self.dataset._dataset.label.called_once_with(np.arange(1, 10))

        # cleanup
        del self.dataset._dataset.label
        self.dataset.can_label = False

        pool = self.dataset.pool
        assert np.equal([i for i in pool], [(i, -1) for i in np.arange(2, 100)]).all()
        assert np.equal([i for i in self.dataset], [(i, i) for i in np.arange(2)]).all()

    def test_get_raw(self):
        # check that get_raw returns the same thing regardless of labelling status
        i_1 = self.dataset.get_raw(5)
        self.dataset.label(5)
        i_2 = self.dataset.get_raw(5)
        assert i_1 == i_2

    def test_state_dict(self):
        state_dict_1 = self.dataset.state_dict()
        assert np.equal(state_dict_1["labeled"], np.full((100,), False)).all()
        self.dataset.label(0)
        assert np.equal(state_dict_1["labeled"],
                        np.concatenate((np.array([True]), np.full((99,), False)))).all()

    def test_transform(self):
        train_transform = Lambda(lambda k: 1)
        test_transform = Lambda(lambda k: 0)
        dataset = ActiveLearningDataset(MyDataset(train_transform),
                                        test_transform,
                                        make_unlabelled=lambda x: (x[0], -1))
        dataset.label(np.arange(10))
        pool = dataset.pool
        assert np.equal([i for i in pool], [(0, -1) for i in np.arange(10, 100)]).all()
        assert np.equal([i for i in dataset], [(1, i) for i in np.arange(10)]).all()

    def test_random(self):
        self.dataset.label_randomly(50)
        assert len(self.dataset) == 50
        assert len(self.dataset.pool) == 50
class ActiveDatasetTest(unittest.TestCase):
    def setUp(self):
        self.dataset = ActiveLearningDataset(MyDataset(),
                                             make_unlabelled=lambda x: (x[0], -1))

    def test_len(self):
        assert len(self.dataset) == 0
        assert self.dataset.n_unlabelled == 100
        assert len(self.dataset.pool) == 100
        self.dataset.label(0)
        assert len(self.dataset) == self.dataset.n_labelled == 1
        assert self.dataset.n_unlabelled == 99
        assert len(self.dataset.pool) == 99
        self.dataset.label(list(range(99)))
        assert len(self.dataset) == 100
        assert self.dataset.n_unlabelled == 0
        assert len(self.dataset.pool) == 0

        dummy_dataset = ActiveLearningDataset(MyDataset(),
                                              labelled=self.dataset._labelled,
                                              make_unlabelled=lambda x: (x[0], -1))
        assert len(dummy_dataset) == len(self.dataset)
        assert len(dummy_dataset.pool) == len(self.dataset.pool)

        dummy_lbl = torch.from_numpy(self.dataset._labelled.astype(np.float32))
        dummy_dataset = ActiveLearningDataset(MyDataset(),
                                              labelled=dummy_lbl,
                                              make_unlabelled=lambda x: (x[0], -1))
        assert len(dummy_dataset) == len(self.dataset)
        assert len(dummy_dataset.pool) == len(self.dataset.pool)

    def test_pool(self):
        self.dataset._dataset.label = unittest.mock.MagicMock()
        labels_initial = self.dataset.n_labelled
        self.dataset.can_label = False
        self.dataset.label(0, value=np.arange(1, 10))
        self.dataset._dataset.label.assert_not_called()
        labels_next_1 = self.dataset.n_labelled
        assert labels_next_1 == labels_initial + 1

        self.dataset.can_label = True
        self.dataset.label(np.arange(0, 9))
        self.dataset._dataset.label.assert_not_called()
        labels_next_2 = self.dataset.n_labelled
        assert labels_next_1 == labels_next_2
        self.dataset.label(np.arange(0, 9), value=np.arange(1, 10))
        assert self.dataset._dataset.label.called_once_with(np.arange(1, 10))

        # cleanup
        del self.dataset._dataset.label
        self.dataset.can_label = False

        pool = self.dataset.pool
        assert np.equal([i for i in pool], [(i, -1) for i in np.arange(2, 100)]).all()
        assert np.equal([i for i in self.dataset], [(i, i) for i in np.arange(2)]).all()

    def test_get_raw(self):
        # check that get_raw returns the same thing regardless of labelling status
        i_1 = self.dataset.get_raw(5)
        self.dataset.label(5)
        i_2 = self.dataset.get_raw(5)
        assert i_1 == i_2

    def test_types(self):
        self.dataset.label_randomly(2)
        assert self.dataset._pool_to_oracle_index(1) == self.dataset._pool_to_oracle_index([1])
        assert self.dataset._oracle_to_pool_index(1) == self.dataset._oracle_to_pool_index([1])

    def test_state_dict(self):
        state_dict_1 = self.dataset.state_dict()
        assert np.equal(state_dict_1["labelled"], np.full((100,), False)).all()
        self.dataset.label(0)
        assert np.equal(state_dict_1["labelled"],
                        np.concatenate((np.array([True]), np.full((99,), False)))).all()

    def test_load_state_dict(self):
        dataset_1 = ActiveLearningDataset(MyDataset(), random_state=50)
        dataset_1.label_randomly(10)
        state_dict1 = dataset_1.state_dict()

        dataset_2 = ActiveLearningDataset(MyDataset(), random_state=None)
        assert dataset_2.n_labelled == 0
        dataset_2.load_state_dict(state_dict1)
        assert dataset_2.n_labelled == 10

        # Check that a second label_randomly call has the same behaviour.
        dataset_1.label_randomly(5)
        dataset_2.label_randomly(5)
        assert np.allclose(dataset_1._labelled, dataset_2._labelled)

    def test_transform(self):
        train_transform = Lambda(lambda k: 1)
        test_transform = Lambda(lambda k: 0)
        dataset = ActiveLearningDataset(MyDataset(train_transform),
                                        pool_specifics={'transform': test_transform},
                                        make_unlabelled=lambda x: (x[0], -1))
        dataset.label(np.arange(10))
        pool = dataset.pool
        assert np.equal([i for i in pool], [(0, -1) for i in np.arange(10, 100)]).all()
        assert np.equal([i for i in dataset], [(1, i) for i in np.arange(10)]).all()

        with pytest.warns(DeprecationWarning) as e:
            ActiveLearningDataset(MyDataset(train_transform), eval_transform=train_transform)
        assert len(e) == 1

        with pytest.raises(ValueError) as e:
            ActiveLearningDataset(MyDataset(train_transform),
                                  pool_specifics={'whatever': 123}).pool

    def test_random(self):
        self.dataset.label_randomly(50)
        assert len(self.dataset) == 50
        assert len(self.dataset.pool) == 50

    def test_random_state(self):
        seed = None
        dataset_1 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_1.label_randomly(10)
        dataset_2 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_2.label_randomly(10)
        assert not np.allclose(dataset_1._labelled, dataset_2._labelled)

        seed = 50
        dataset_1 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_1.label_randomly(10)
        dataset_2 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_2.label_randomly(10)
        assert np.allclose(dataset_1._labelled, dataset_2._labelled)

        seed = np.random.RandomState(50)
        dataset_1 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_1.label_randomly(10)
        dataset_2 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_2.label_randomly(10)
        assert not np.allclose(dataset_1._labelled, dataset_2._labelled)

    def test_label_randomly_full(self):
        dataset_1 = ActiveLearningDataset(MyDataset())
        dataset_1.label_randomly(99)
        assert dataset_1.n_unlabelled == 1
        assert len(dataset_1.pool) == 1
        dataset_1.label_randomly(1)
        assert dataset_1.n_unlabelled == 0
        assert dataset_1.n_labelled == 100
class FileDatasetTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        tmp_dir = tempfile.gettempdir()
        paths = []
        for idx in range(100):
            path = os.path.join(tmp_dir, "{}.png".format(idx))
            Image.fromarray(np.random.randint(0, 100, [10, 10, 3], np.uint8)).save(path)
            paths.append(path)
        cls.paths = paths

    def setUp(self):
        self.lbls = None
        self.transform = Compose([Resize(60), RandomRotation(90), ToTensor()])
        testtransform = Compose([Resize(32), ToTensor()])
        self.dataset = FileDataset(self.paths, self.lbls, transform=self.transform)
        self.lbls = self.generate_labels(len(self.paths), 10)
        self.dataset = FileDataset(self.paths, self.lbls, transform=self.transform)
        self.active = ActiveLearningDataset(self.dataset,
                                            labelled=(np.array(self.lbls) != -1),
                                            pool_specifics={'transform': testtransform})

    def generate_labels(self, n, init_lbls):
        lbls = [-1] * n
        for i in random.sample(range(n), init_lbls):
            lbls[i] = i % 10
        return lbls

    def test_default_label(self):
        dataset = FileDataset(self.paths)
        assert dataset.lbls == [-1] * len(self.paths)

    def test_labelling(self):
        actually_labelled = [i for i, j in enumerate(self.lbls) if j >= 0]
        actually_not_labelled = [i for i, j in enumerate(self.lbls) if j < 0]
        with pytest.warns(UserWarning):
            self.dataset.label(actually_labelled[0], 1)
        self.dataset.label(actually_not_labelled[0], 1)
        assert sum([1 for i, j in enumerate(self.dataset.lbls) if j >= 0]) == 11

    def test_active_labelling(self):
        assert self.active.can_label
        actually_not_labelled = [i for i, j in enumerate(self.lbls) if j < 0]
        actually_labelled = [i for i, j in enumerate(self.lbls) if j >= 0]
        init_length = len(self.active)
        self.active.label(actually_not_labelled[0], 1)
        assert len(self.active) == init_length + 1

        with pytest.warns(UserWarning):
            self.dataset.label(actually_labelled[0], None)
        assert len(self.active) == init_length + 1

    def test_filedataset_segmentation(self):
        target_trans = Compose([
            default_image_load_fn,
            Resize(60), RandomRotation(90), ToTensor()
        ])
        file_dataset = FileDataset(self.paths, self.paths, self.transform, target_trans, seed=1337)
        x, y = file_dataset[0]
        assert np.allclose(x.numpy(), y.numpy())

        out1 = list(DataLoader(file_dataset, batch_size=1, num_workers=3, shuffle=False))
        out2 = list(DataLoader(file_dataset, batch_size=1, num_workers=3, shuffle=False))
        assert all([np.allclose(x1.numpy(), x2.numpy())
                    for (x1, _), (x2, _) in zip(out1, out2)])

        file_dataset = FileDataset(self.paths, self.paths, self.transform, target_trans, seed=None)
        x, y = file_dataset[0]
        assert np.allclose(x.numpy(), y.numpy())

        out1 = list(DataLoader(file_dataset, batch_size=1, num_workers=3, shuffle=False))
        out2 = list(DataLoader(file_dataset, batch_size=1, num_workers=3, shuffle=False))
        assert not all([np.allclose(x1.numpy(), x2.numpy())
                        for (x1, _), (x2, _) in zip(out1, out2)])

    def test_segmentation_pipeline(self):
        class DrawSquare:
            def __init__(self, side):
                self.side = side

            def __call__(self, x, **kwargs):
                x, canvas = x  # x is a [int, ndarray]
                canvas[:self.side, :self.side] = x
                return canvas

        target_trans = BaaLCompose([
            GetCanvas(),
            DrawSquare(3),
            ToPILImage(mode=None),
            Resize(60, interpolation=0),
            RandomRotation(10, resample=NEAREST, fill=0.0),
            PILToLongTensor()
        ])
        file_dataset = FileDataset(self.paths, [1] * len(self.paths), self.transform, target_trans)
        x, y = file_dataset[0]
        assert np.allclose(np.unique(y), [0, 1])
        assert y.shape[1:] == x.shape[1:]
def main(data_path, splits_path, preload, patch_size, batch_size, n_label_start,
         manual_seed, epochs, al_cycles, n_data_to_label, mc_iters,
         base_state_dict_path, heuristic, run, results_dir, balance_al,
         num_classes, save_maps, save_uncerts):
    print("Start of AL experiment using {} heuristic".format(heuristic))
    torch.backends.cudnn.benchmark = True
    if manual_seed:
        torch.manual_seed(manual_seed)
        random.seed(manual_seed)
        np.random.seed(manual_seed)

    train_ds, test_ds, val_ds = load_glas(data_path, splits_path, preload,
                                          patch_size=patch_size)
    # `pool_specifics` (e.g. the evaluation-time transform applied to the pool)
    # is expected to be defined alongside the dataset loading above.
    active_set = ActiveLearningDataset(train_ds, pool_specifics=pool_specifics)

    # Label 4 images, 2 for each class
    active_set.label([0, 1, 2, 3])

    # Load model
    model = UNet(in_channels=3, n_classes=num_classes, dropout=True)
    print(model)
    method_wrapper = MCDropoutUncert(base_model=model,
                                     n_classes=num_classes,
                                     state_dict_path=base_state_dict_path)

    acq_scores = []
    mean_dices = []
    last_cycle = False
    for al_it in range(al_cycles + 1):
        print("##### ACTIVE LEARNING ITERATION {}/{} #####".format(al_it, al_cycles))
        print("Labeled pool size: {}".format(active_set.n_labelled))
        print("Unlabeled pool size: {}".format(active_set.n_unlabelled))

        # Reset model
        method_wrapper.reset_params()

        # train
        method_wrapper.train(train_ds=active_set,
                             val_ds=val_ds,
                             epochs=epochs,
                             batch_size=batch_size,
                             opt_sch_callable=get_optimizer_scheduler)

        # test_metrics = method_wrapper.evaluate(DataLoader(dataset=test_ds, batch_size=1, shuffle=False), test=True)
        test_metrics = method_wrapper.test_bma(dataset=test_ds, n_predictions=20)
        mean_dices.append(test_metrics['mean_dice'])
        print('Test bma mean dice: {}'.format(test_metrics['mean_dice']))

        if last_cycle:
            print("Every sample from the pool has been labeled, closing AL loop.")
            break

        # Make predictions on unlabeled pool
        predictions = method_wrapper.predict(active_set.pool, n_predictions=mc_iters)
        heur = heuristics_dict[heuristic]
        to_label, scores = heur.get_to_label(predictions=predictions,
                                             model=None,
                                             n_to_label=n_data_to_label,
                                             num_classes=num_classes,
                                             balance_al=balance_al)
        if save_maps:
            save_prediction_maps(predictions,
                                 al_it,
                                 map_processor=ComparativeVarianceMap(),
                                 heuristic=heuristic,
                                 uncertainties=scores)

        # Avoids memory issues
        del predictions

        acq_scores.append(scores)

        # Label new samples
        active_set.label(to_label)
        if active_set.n_unlabelled == 0:
            last_cycle = True

    print(mean_dices)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    np.save(os.path.join(results_dir, '{}_{}.npy'.format(heuristic, run)),
            np.array(mean_dices))

    if save_uncerts:
        save_uncert_histogram(acq_scores, heuristic)