Ejemplo n.º 1
0
class SSLModuleTest(unittest.TestCase):
    """Fit ``TestSSLModule`` for one epoch and check what it saw."""

    def setUp(self):
        d1_len = 100
        d2_len = 1000
        # The assertions in ``test_epoch`` expect items from the labelled
        # dataset to be even and items from the unlabelled one to be odd.
        d1 = SSLTestDataset(labeled=True, length=d1_len)
        d2 = SSLTestDataset(labeled=False, length=d2_len)
        dataset = ConcatDataset([d1, d2])

        self.al_dataset = ActiveLearningDataset(dataset)
        # Label the first ``d1_len`` indices, i.e. every item coming from d1.
        self.al_dataset.label(list(range(d1_len)))

    def test_epoch(self):
        hparams = {
            'p': None,
            'num_steps': None,
            'batch_size': 10,
            'workers': 0
        }

        module = TestSSLModule(self.al_dataset, Namespace(**hparams))
        trainer = Trainer(max_epochs=1,
                          num_sanity_val_steps=0,
                          progress_bar_refresh_rate=0,
                          logger=False,
                          checkpoint_callback=False)
        trainer.fit(module)

        # With p=None the test expects as many labelled as unlabelled items,
        # and each group to come only from its own source dataset.
        assert len(module.labeled_data) == len(module.unlabeled_data)
        assert torch.all(torch.tensor(module.labeled_data) % 2 == 0)
        assert torch.all(torch.tensor(module.unlabeled_data) % 2 != 0)
Ejemplo n.º 2
0
def test_arrowds():
    """Wrap the GLUE SST-2 test split (an Arrow dataset) in ActiveLearningDataset."""
    dataset = HFdata.load_dataset('glue', 'sst2')['test']
    dataset = ActiveLearningDataset(dataset)
    dataset.label(np.arange(10))
    assert len(dataset) == 10
    assert len(dataset.pool) == 1811
    data = dataset.pool[0]
    # ``all`` is required here: a bare list of booleans is truthy whenever it
    # is non-empty, so it would not check the keys at all.
    assert all(k in {'idx', 'label', 'sentence'} for k in data)
Ejemplo n.º 3
0
 def test_transform(self):
     """Pool items and labelled items must go through different transforms."""
     # Constant transforms make it obvious which one was applied to an item.
     train_tf = Lambda(lambda k: 1)
     eval_tf = Lambda(lambda k: 0)
     active = ActiveLearningDataset(MyDataset(train_tf),
                                    eval_tf,
                                    make_unlabelled=lambda x: (x[0], -1))
     active.label(np.arange(10))

     expected_pool = [(0, -1) for _ in np.arange(10, 100)]
     expected_labelled = [(1, idx) for idx in np.arange(10)]
     pool_items = active.pool
     assert np.equal([item for item in pool_items], expected_pool).all()
     assert np.equal([item for item in active], expected_labelled).all()
Ejemplo n.º 4
0
    def test_transform(self):
        """Pool items use the ``pool_specifics`` transform; labelled items use the dataset's own."""
        # Constant transforms make it obvious which one was applied to an item.
        train_transform = Lambda(lambda k: 1)
        test_transform = Lambda(lambda k: 0)
        dataset = ActiveLearningDataset(MyDataset(train_transform), make_unlabelled=lambda x: (x[0], -1),
                                        pool_specifics={'transform': test_transform})
        dataset.label(np.arange(10))
        pool = dataset.pool
        assert np.equal([i for i in pool], [(0, -1) for i in np.arange(10, 100)]).all()
        assert np.equal([i for i in dataset], [(1, i) for i in np.arange(10)]).all()

        # An unknown ``pool_specifics`` key must raise ValueError by the time
        # the pool is accessed.
        with pytest.raises(ValueError):
            ActiveLearningDataset(MyDataset(train_transform), pool_specifics={'whatever': 123}).pool
Ejemplo n.º 5
0
class SSLDatasetTest(unittest.TestCase):
    """Tests for ``SemiSupervisedIterator`` over a partly-labelled dataset."""

    def setUp(self):
        d1_len = 100
        d2_len = 1000
        # The assertions below expect items of the labelled dataset to be even
        # and items of the unlabelled dataset to be odd.
        d1 = SSLTestDataset(labeled=True, length=d1_len)
        d2 = SSLTestDataset(labeled=False, length=d2_len)
        dataset = ConcatDataset([d1, d2])

        self.al_dataset = ActiveLearningDataset(dataset)
        # Label the first ``d1_len`` indices, i.e. every item coming from d1.
        self.al_dataset.label(list(range(d1_len)))

        self.ss_iterator = SemiSupervisedIterator(self.al_dataset,
                                                  p=None,
                                                  num_steps=None,
                                                  batch_size=10)

    @staticmethod
    def _split_batches(ss_iterator):
        """Drain ``ss_iterator``; return ``(labeled_items, unlabeled_items)`` lists."""
        labeled_data = []
        unlabeled_data = []
        for batch in ss_iterator:
            is_labeled = SemiSupervisedIterator.is_labeled(batch)
            items = SemiSupervisedIterator.get_batch(batch)
            (labeled_data if is_labeled else unlabeled_data).extend(items)
        return labeled_data, unlabeled_data

    def test_epoch(self):
        labeled_data, unlabeled_data = self._split_batches(self.ss_iterator)

        labeled_data = torch.tensor(labeled_data)
        unlabeled_data = torch.tensor(unlabeled_data)

        # With p=None the test expects equal amounts of labelled and
        # unlabelled items, each coming only from its own source dataset.
        assert len(labeled_data) == len(unlabeled_data)
        assert torch.all(labeled_data % 2 == 0)
        assert torch.all(unlabeled_data % 2 != 0)

    def test_p(self):
        ss_iterator = SemiSupervisedIterator(self.al_dataset,
                                             p=0.1,
                                             num_steps=None,
                                             batch_size=10)
        labeled_data, unlabeled_data = self._split_batches(ss_iterator)

        total = len(labeled_data) + len(unlabeled_data)
        assert total > 0, "iterator yielded no items"
        l_ratio = len(labeled_data) / total
        u_ratio = len(unlabeled_data) / total
        # For p=0.1 the bounds expect roughly 10% labelled items, with slack.
        assert l_ratio < .15
        assert u_ratio > 0.85

    def test_no_pool(self):
        # Every item is labelled, so only labelled items may be yielded.
        d1 = SSLTestDataset(labeled=True, length=100)
        al_dataset = ActiveLearningDataset(d1)
        al_dataset.label_randomly(100)
        ss_iterator = SemiSupervisedIterator(al_dataset,
                                             p=0.1,
                                             num_steps=None,
                                             batch_size=10)
        labeled_data, unlabeled_data = self._split_batches(ss_iterator)

        total = len(labeled_data) + len(unlabeled_data)
        assert total > 0, "iterator yielded no items"
        l_ratio = len(labeled_data) / total
        u_ratio = len(unlabeled_data) / total
        assert l_ratio == 1
        assert u_ratio == 0
Ejemplo n.º 6
0
class ActiveDatasetTest(unittest.TestCase):
    """Unit tests for ``ActiveLearningDataset`` wrapping ``MyDataset``."""

    def setUp(self):
        # Pool items are exposed as ``(x, -1)``, i.e. with their label masked.
        self.dataset = ActiveLearningDataset(MyDataset(),
                                             make_unlabelled=lambda x:
                                             (x[0], -1))

    def test_len(self):
        # The assertions below expect MyDataset() to hold 100 items.
        assert len(self.dataset) == 0
        assert self.dataset.n_unlabelled == 100
        assert len(self.dataset.pool) == 100
        self.dataset.label(0)
        assert len(self.dataset) == self.dataset.n_labelled == 1
        assert self.dataset.n_unlabelled == 99
        assert len(self.dataset.pool) == 99
        self.dataset.label(list(range(99)))
        assert len(self.dataset) == 100
        assert self.dataset.n_unlabelled == 0
        assert len(self.dataset.pool) == 0

        # A dataset rebuilt from an existing labelled mask must agree in size.
        dummy_dataset = ActiveLearningDataset(MyDataset(),
                                              labelled=self.dataset._labelled,
                                              make_unlabelled=lambda x:
                                              (x[0], -1))
        assert len(dummy_dataset) == len(self.dataset)
        assert len(dummy_dataset.pool) == len(self.dataset.pool)

        # Same, with the mask given as a float torch tensor.
        dummy_lbl = torch.from_numpy(self.dataset._labelled.astype(np.float32))
        dummy_dataset = ActiveLearningDataset(MyDataset(),
                                              labelled=dummy_lbl,
                                              make_unlabelled=lambda x:
                                              (x[0], -1))
        assert len(dummy_dataset) == len(self.dataset)
        assert len(dummy_dataset.pool) == len(self.dataset.pool)

    def test_pool(self):
        self.dataset._dataset.label = unittest.mock.MagicMock()
        labels_initial = self.dataset.n_labelled
        # With can_label=False, items are marked labelled without asking the
        # underlying dataset to store a value.
        self.dataset.can_label = False
        self.dataset.label(0, value=np.arange(1, 10))
        self.dataset._dataset.label.assert_not_called()
        labels_next_1 = self.dataset.n_labelled
        assert labels_next_1 == labels_initial + 1
        # With can_label=True but no value, nothing is labelled.
        self.dataset.can_label = True
        self.dataset.label(np.arange(0, 9))
        self.dataset._dataset.label.assert_not_called()
        labels_next_2 = self.dataset.n_labelled
        assert labels_next_1 == labels_next_2
        self.dataset.label(np.arange(0, 9), value=np.arange(1, 10))
        # ``called_once_with`` is not an assertion (it returns a truthy Mock,
        # or raises AttributeError on Python 3.12+); use the real assert method.
        # The final pool check (only indices 0 and 1 labelled) implies a single call.
        self.dataset._dataset.label.assert_called_once()
        # cleanup
        del self.dataset._dataset.label
        self.dataset.can_label = False

        pool = self.dataset.pool
        assert np.equal([i for i in pool],
                        [(i, -1) for i in np.arange(2, 100)]).all()
        assert np.equal([i for i in self.dataset],
                        [(i, i) for i in np.arange(2)]).all()

    def test_get_raw(self):
        # check that get_raw returns the same thing regardless of labelling
        # status
        i_1 = self.dataset.get_raw(5)
        self.dataset.label(5)
        i_2 = self.dataset.get_raw(5)
        assert i_1 == i_2

    def test_state_dict(self):
        state_dict_1 = self.dataset.state_dict()
        assert np.equal(state_dict_1["labeled"], np.full((100, ), False)).all()
        self.dataset.label(0)
        # Take a fresh snapshot instead of relying on ``state_dict`` returning
        # a live reference to the internal mask.
        state_dict_2 = self.dataset.state_dict()
        assert np.equal(
            state_dict_2["labeled"],
            np.concatenate((np.array([True]), np.full((99, ), False)))).all()

    def test_transform(self):
        # Constant transforms make it obvious which one was applied to an item.
        train_transform = Lambda(lambda k: 1)
        test_transform = Lambda(lambda k: 0)
        dataset = ActiveLearningDataset(MyDataset(train_transform),
                                        test_transform,
                                        make_unlabelled=lambda x: (x[0], -1))
        dataset.label(np.arange(10))
        pool = dataset.pool
        assert np.equal([i for i in pool],
                        [(0, -1) for i in np.arange(10, 100)]).all()
        assert np.equal([i for i in dataset],
                        [(1, i) for i in np.arange(10)]).all()

    def test_random(self):
        self.dataset.label_randomly(50)
        assert len(self.dataset) == 50
        assert len(self.dataset.pool) == 50
Ejemplo n.º 7
0
class ActiveDatasetTest(unittest.TestCase):
    """Unit tests for ``ActiveLearningDataset`` wrapping ``MyDataset``."""

    def setUp(self):
        # Pool items are exposed as ``(x, -1)``, i.e. with their label masked.
        self.dataset = ActiveLearningDataset(MyDataset(),
                                             make_unlabelled=lambda x:
                                             (x[0], -1))

    def test_len(self):
        # The assertions below expect MyDataset() to hold 100 items.
        assert len(self.dataset) == 0
        assert self.dataset.n_unlabelled == 100
        assert len(self.dataset.pool) == 100
        self.dataset.label(0)
        assert len(self.dataset) == self.dataset.n_labelled == 1
        assert self.dataset.n_unlabelled == 99
        assert len(self.dataset.pool) == 99
        self.dataset.label(list(range(99)))
        assert len(self.dataset) == 100
        assert self.dataset.n_unlabelled == 0
        assert len(self.dataset.pool) == 0

        # A dataset rebuilt from an existing labelled mask must agree in size.
        dummy_dataset = ActiveLearningDataset(MyDataset(),
                                              labelled=self.dataset._labelled,
                                              make_unlabelled=lambda x:
                                              (x[0], -1))
        assert len(dummy_dataset) == len(self.dataset)
        assert len(dummy_dataset.pool) == len(self.dataset.pool)

        # Same, with the mask given as a float torch tensor.
        dummy_lbl = torch.from_numpy(self.dataset._labelled.astype(np.float32))
        dummy_dataset = ActiveLearningDataset(MyDataset(),
                                              labelled=dummy_lbl,
                                              make_unlabelled=lambda x:
                                              (x[0], -1))
        assert len(dummy_dataset) == len(self.dataset)
        assert len(dummy_dataset.pool) == len(self.dataset.pool)

    def test_pool(self):
        self.dataset._dataset.label = unittest.mock.MagicMock()
        labels_initial = self.dataset.n_labelled
        # With can_label=False, items are marked labelled without asking the
        # underlying dataset to store a value.
        self.dataset.can_label = False
        self.dataset.label(0, value=np.arange(1, 10))
        self.dataset._dataset.label.assert_not_called()
        labels_next_1 = self.dataset.n_labelled
        assert labels_next_1 == labels_initial + 1
        # With can_label=True but no value, nothing is labelled.
        self.dataset.can_label = True
        self.dataset.label(np.arange(0, 9))
        self.dataset._dataset.label.assert_not_called()
        labels_next_2 = self.dataset.n_labelled
        assert labels_next_1 == labels_next_2
        self.dataset.label(np.arange(0, 9), value=np.arange(1, 10))
        # ``called_once_with`` is not an assertion (it returns a truthy Mock,
        # or raises AttributeError on Python 3.12+); use the real assert method.
        # The final pool check (only indices 0 and 1 labelled) implies a single call.
        self.dataset._dataset.label.assert_called_once()
        # cleanup
        del self.dataset._dataset.label
        self.dataset.can_label = False

        pool = self.dataset.pool
        assert np.equal([i for i in pool],
                        [(i, -1) for i in np.arange(2, 100)]).all()
        assert np.equal([i for i in self.dataset],
                        [(i, i) for i in np.arange(2)]).all()

    def test_get_raw(self):
        # check that get_raw returns the same thing regardless of labelling
        # status
        i_1 = self.dataset.get_raw(5)
        self.dataset.label(5)
        i_2 = self.dataset.get_raw(5)
        assert i_1 == i_2

    def test_types(self):
        # Index conversions must accept a scalar and a one-element list alike.
        self.dataset.label_randomly(2)
        assert self.dataset._pool_to_oracle_index(
            1) == self.dataset._pool_to_oracle_index([1])
        assert self.dataset._oracle_to_pool_index(
            1) == self.dataset._oracle_to_pool_index([1])

    def test_state_dict(self):
        state_dict_1 = self.dataset.state_dict()
        assert np.equal(state_dict_1["labelled"], np.full((100, ),
                                                          False)).all()

        self.dataset.label(0)
        # Take a fresh snapshot instead of relying on ``state_dict`` returning
        # a live reference to the internal mask.
        state_dict_2 = self.dataset.state_dict()
        assert np.equal(
            state_dict_2["labelled"],
            np.concatenate((np.array([True]), np.full((99, ), False)))).all()

    def test_load_state_dict(self):
        dataset_1 = ActiveLearningDataset(MyDataset(), random_state=50)
        dataset_1.label_randomly(10)
        state_dict1 = dataset_1.state_dict()

        dataset_2 = ActiveLearningDataset(MyDataset(), random_state=None)
        assert dataset_2.n_labelled == 0

        dataset_2.load_state_dict(state_dict1)
        assert dataset_2.n_labelled == 10

        # test if the second label_randomly call has the same behaviour
        dataset_1.label_randomly(5)
        dataset_2.label_randomly(5)

        assert np.allclose(dataset_1._labelled, dataset_2._labelled)

    def test_transform(self):
        # Constant transforms make it obvious which one was applied to an item.
        train_transform = Lambda(lambda k: 1)
        test_transform = Lambda(lambda k: 0)
        dataset = ActiveLearningDataset(
            MyDataset(train_transform),
            pool_specifics={'transform': test_transform},
            make_unlabelled=lambda x: (x[0], -1))
        dataset.label(np.arange(10))
        pool = dataset.pool
        assert np.equal([i for i in pool],
                        [(0, -1) for i in np.arange(10, 100)]).all()
        assert np.equal([i for i in dataset],
                        [(1, i) for i in np.arange(10)]).all()

        # The legacy ``eval_transform`` argument must emit exactly one warning.
        with pytest.warns(DeprecationWarning) as record:
            ActiveLearningDataset(MyDataset(train_transform),
                                  eval_transform=train_transform)
        assert len(record) == 1

        # An unknown ``pool_specifics`` key must raise by the time the pool
        # is accessed.
        with pytest.raises(ValueError):
            ActiveLearningDataset(MyDataset(train_transform),
                                  pool_specifics={
                                      'whatever': 123
                                  }).pool

    def test_random(self):
        self.dataset.label_randomly(50)
        assert len(self.dataset) == 50
        assert len(self.dataset.pool) == 50

    def test_random_state(self):
        # No seed: two datasets are expected to pick different items.
        seed = None
        dataset_1 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_1.label_randomly(10)
        dataset_2 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_2.label_randomly(10)
        assert not np.allclose(dataset_1._labelled, dataset_2._labelled)

        # Same integer seed: identical picks.
        seed = 50
        dataset_1 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_1.label_randomly(10)
        dataset_2 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_2.label_randomly(10)
        assert np.allclose(dataset_1._labelled, dataset_2._labelled)

        # One RandomState object shared by both datasets: presumably the second
        # continues the same stream, so the picks are expected to differ.
        seed = np.random.RandomState(50)
        dataset_1 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_1.label_randomly(10)
        dataset_2 = ActiveLearningDataset(MyDataset(), random_state=seed)
        dataset_2.label_randomly(10)
        assert not np.allclose(dataset_1._labelled, dataset_2._labelled)

    def test_label_randomly_full(self):
        dataset_1 = ActiveLearningDataset(MyDataset())
        dataset_1.label_randomly(99)
        assert dataset_1.n_unlabelled == 1
        assert len(dataset_1.pool) == 1
        dataset_1.label_randomly(1)
        assert dataset_1.n_unlabelled == 0
        assert dataset_1.n_labelled == 100
Ejemplo n.º 8
0
class FileDatasetTest(unittest.TestCase):
    """Tests for ``FileDataset`` on 100 random 10x10 RGB PNG files."""

    @classmethod
    def setUpClass(cls):
        # Use a private directory so concurrent runs cannot overwrite each
        # other's fixture files, and so the files can be removed afterwards.
        cls.tmp_dir = tempfile.mkdtemp()
        paths = []
        for idx in range(100):
            path = os.path.join(cls.tmp_dir, "{}.png".format(idx))
            Image.fromarray(np.random.randint(0, 100, [10, 10, 3],
                                              np.uint8)).save(path)
            paths.append(path)
        cls.paths = paths

    @classmethod
    def tearDownClass(cls):
        import shutil
        shutil.rmtree(cls.tmp_dir, ignore_errors=True)

    def setUp(self):
        self.transform = Compose([Resize(60), RandomRotation(90), ToTensor()])
        testtransform = Compose([Resize(32), ToTensor()])
        # 10 of the 100 files get a label; the others are -1 (unlabelled).
        self.lbls = self.generate_labels(len(self.paths), 10)
        self.dataset = FileDataset(self.paths,
                                   self.lbls,
                                   transform=self.transform)
        self.active = ActiveLearningDataset(
            self.dataset,
            labelled=(np.array(self.lbls) != -1),
            pool_specifics={'transform': testtransform})

    def generate_labels(self, n, init_lbls):
        """Return ``n`` labels, ``-1`` except at ``init_lbls`` random positions."""
        lbls = [-1] * n
        for i in random.sample(range(n), init_lbls):
            lbls[i] = i % 10
        return lbls

    def test_default_label(self):
        dataset = FileDataset(self.paths)
        assert dataset.lbls == [-1] * len(self.paths)

    def test_labelling(self):
        actually_labelled = [i for i, j in enumerate(self.lbls) if j >= 0]
        actually_not_labelled = [i for i, j in enumerate(self.lbls) if j < 0]
        # Relabelling an already-labelled item only warns.
        with pytest.warns(UserWarning):
            self.dataset.label(actually_labelled[0], 1)
        self.dataset.label(actually_not_labelled[0], 1)
        # 10 labels from setUp plus the newly labelled item.
        assert sum(1 for j in self.dataset.lbls if j >= 0) == 11

    def test_active_labelling(self):
        assert self.active.can_label
        actually_not_labelled = [i for i, j in enumerate(self.lbls) if j < 0]
        actually_labelled = [i for i, j in enumerate(self.lbls) if j >= 0]

        init_length = len(self.active)
        self.active.label(actually_not_labelled[0], 1)
        assert len(self.active) == init_length + 1

        with pytest.warns(UserWarning):
            self.dataset.label(actually_labelled[0], None)
        assert len(self.active) == init_length + 1

    def test_filedataset_segmentation(self):
        target_trans = Compose([
            default_image_load_fn,
            Resize(60),
            RandomRotation(90),
            ToTensor()
        ])
        # With a fixed seed, two passes over the data must be identical.
        file_dataset = FileDataset(self.paths,
                                   self.paths,
                                   self.transform,
                                   target_trans,
                                   seed=1337)
        x, y = file_dataset[0]
        assert np.allclose(x.numpy(), y.numpy())
        out1 = list(
            DataLoader(file_dataset,
                       batch_size=1,
                       num_workers=3,
                       shuffle=False))
        out2 = list(
            DataLoader(file_dataset,
                       batch_size=1,
                       num_workers=3,
                       shuffle=False))
        assert all([
            np.allclose(x1.numpy(), x2.numpy())
            for (x1, _), (x2, _) in zip(out1, out2)
        ])

        # Without a seed, the random rotations are expected to differ.
        file_dataset = FileDataset(self.paths,
                                   self.paths,
                                   self.transform,
                                   target_trans,
                                   seed=None)
        x, y = file_dataset[0]
        assert np.allclose(x.numpy(), y.numpy())
        out1 = list(
            DataLoader(file_dataset,
                       batch_size=1,
                       num_workers=3,
                       shuffle=False))
        out2 = list(
            DataLoader(file_dataset,
                       batch_size=1,
                       num_workers=3,
                       shuffle=False))
        assert not all([
            np.allclose(x1.numpy(), x2.numpy())
            for (x1, _), (x2, _) in zip(out1, out2)
        ])

    def test_segmentation_pipeline(self):
        class DrawSquare:
            """Write the label value into the top-left ``side x side`` corner."""

            def __init__(self, side):
                self.side = side

            def __call__(self, x, **kwargs):
                x, canvas = x  # x is a [int, ndarray]
                canvas[:self.side, :self.side] = x
                return canvas

        target_trans = BaaLCompose([
            GetCanvas(),
            DrawSquare(3),
            ToPILImage(mode=None),
            Resize(60, interpolation=0),
            RandomRotation(10, resample=NEAREST, fill=0.0),
            PILToLongTensor()
        ])
        file_dataset = FileDataset(self.paths, [1] * len(self.paths),
                                   self.transform, target_trans)

        x, y = file_dataset[0]
        assert np.allclose(np.unique(y), [0, 1])
        assert y.shape[1:] == x.shape[1:]
Ejemplo n.º 9
0
def main(data_path, splits_path, preload, patch_size, batch_size,
         n_label_start, manual_seed, epochs, al_cycles, n_data_to_label,
         mc_iters, base_state_dict_path, heuristic, run, results_dir,
         balance_al, num_classes, save_maps, save_uncerts):
    """Run an active-learning experiment for UNet segmentation on GlaS.

    Each cycle does the following:

    * resets and retrains the model on the labelled set;
    * records the mean Dice returned by ``test_bma``;
    * predicts on the unlabelled pool with ``mc_iters`` MC passes;
    * labels ``n_data_to_label`` samples chosen by ``heuristic``.

    The loop runs at most ``al_cycles + 1`` times. It ends early after the
    first cycle that trains on a fully-labelled pool.

    The per-cycle mean Dice values are saved to
    ``<results_dir>/<heuristic>_<run>.npy``.

    NOTE(review): ``n_label_start`` is currently unused; the initial labelled
    set is hard-coded to indices 0-3 below.
    """
    print("Start of AL experiment using {} heuristic".format(heuristic))
    torch.backends.cudnn.benchmark = True

    # NOTE(review): a seed of 0 is falsy and therefore skips seeding.
    if manual_seed:
        torch.manual_seed(manual_seed)
        random.seed(manual_seed)
        np.random.seed(manual_seed)

    train_ds, test_ds, val_ds = load_glas(data_path,
                                          splits_path,
                                          preload,
                                          patch_size=patch_size)
    # NOTE(review): ``pool_specifics`` is neither a parameter nor a local here.
    # It must be a module-level global, or this raises NameError; confirm.
    active_set = ActiveLearningDataset(train_ds, pool_specifics=pool_specifics)

    # Label 4 images, 2 for each class
    active_set.label([0, 1, 2, 3])

    # Load model
    model = UNet(in_channels=3, n_classes=num_classes, dropout=True)
    print(model)

    method_wrapper = MCDropoutUncert(base_model=model,
                                     n_classes=num_classes,
                                     state_dict_path=base_state_dict_path)

    # The heuristic does not change between cycles. Looking it up once also
    # fails fast on an unknown name, before any training.
    heur = heuristics_dict[heuristic]

    acq_scores = []
    mean_dices = []
    last_cycle = False

    for al_it in range(al_cycles + 1):

        print("##### ACTIVE LEARNING ITERATION {}/{}#####".format(
            al_it, al_cycles))
        print("Labeled pool size: {}".format(active_set.n_labelled))
        print("Unlabeled pool size: {}".format(active_set.n_unlabelled))

        # Reset model
        method_wrapper.reset_params()

        # train
        method_wrapper.train(train_ds=active_set,
                             val_ds=val_ds,
                             epochs=epochs,
                             batch_size=batch_size,
                             opt_sch_callable=get_optimizer_scheduler)

        test_metrics = method_wrapper.test_bma(dataset=test_ds,
                                               n_predictions=20)
        mean_dices.append(test_metrics['mean_dice'])
        print('Test bma mean dice: {}'.format(test_metrics['mean_dice']))

        if last_cycle:
            print(
                "Every sample from the pool has been labeled, closing AL loop."
            )
            break

        # Make predictions on unlabeled pool
        predictions = method_wrapper.predict(active_set.pool,
                                             n_predictions=mc_iters)

        to_label, scores = heur.get_to_label(predictions=predictions,
                                             model=None,
                                             n_to_label=n_data_to_label,
                                             num_classes=num_classes,
                                             balance_al=balance_al)

        if save_maps:
            save_prediction_maps(predictions,
                                 al_it,
                                 map_processor=ComparativeVarianceMap(),
                                 heuristic=heuristic,
                                 uncertainties=scores)

        # Avoids memory issues
        del predictions

        acq_scores.append(scores)

        # Label new samples
        active_set.label(to_label)

        # Run one more cycle on the now fully-labelled set, then stop.
        if active_set.n_unlabelled == 0:
            last_cycle = True

    print(mean_dices)

    # makedirs also creates missing parent directories; exist_ok keeps
    # repeated runs working.
    os.makedirs(results_dir, exist_ok=True)

    # os.path.join puts the file inside ``results_dir`` whether or not the
    # path ends with a separator.
    np.save(os.path.join(results_dir, '{}_{}.npy'.format(heuristic, run)),
            np.array(mean_dices))
    if save_uncerts:
        save_uncert_histogram(acq_scores, heuristic)