Example #1
0
def test_bald_gpu(classification_task):
    """Check that `BALDGPUWrapper` output matches `BALD.get_uncertainties`.

    The wrapper's `predict_on_dataset` result is compared against the
    uncertainties `BALD` computes from the unwrapped model's predictions,
    using the same positional prediction arguments and the same torch seed.
    """
    seed = 1337
    # Positional arguments forwarded verbatim to both `predict_on_dataset` calls.
    predict_args = (4, 10, False, 4)

    torch.manual_seed(seed)
    model, test_set = classification_task
    gpu_wrapper = BALDGPUWrapper(model)

    gpu_scores = gpu_wrapper.predict_on_dataset(test_set, *predict_args)
    assert len(test_set) == gpu_scores.shape[0]

    reference = BALD()
    # Re-seed so the second prediction pass starts from the same RNG state.
    torch.manual_seed(seed)
    raw_predictions = model.predict_on_dataset(test_set, *predict_args)
    reference_scores = reference.get_uncertainties(raw_predictions)
    assert np.allclose(gpu_scores, reference_scores, rtol=1e-5, atol=1e-5)
Example #2
0
def test_bald_gpu_seg(segmentation_task):
    """Check `BALDGPUWrapper` against `BALD` with `reduction='sum'`.

    Same comparison as the classification variant, but on the segmentation
    fixture and using the generator-based reference path
    (`predict_on_dataset_generator` -> `get_uncertainties_generator`).
    """
    seed = 1337
    # Positional arguments forwarded verbatim to both prediction calls.
    predict_args = (4, 10, False, 4)

    torch.manual_seed(seed)
    model, test_set = segmentation_task
    gpu_wrapper = BALDGPUWrapper(model, reduction='sum')

    gpu_scores = gpu_wrapper.predict_on_dataset(test_set, *predict_args)
    assert len(test_set) == gpu_scores.shape[0]

    reference = BALD(reduction='sum')
    # Re-seed so the reference pass starts from the same RNG state.
    torch.manual_seed(seed)
    prediction_stream = model.predict_on_dataset_generator(test_set, *predict_args)
    reference_scores = reference.get_uncertainties_generator(prediction_stream)
    assert np.allclose(gpu_scores, reference_scores, rtol=1e-5, atol=1e-5)
Example #3
0
def test_heuristic_reductio_check(distributions):
    """Check that ranking with `reduction='none'` raises a descriptive `ValueError`.

    Calling a `BALD(reduction='none')` heuristic on `distributions` is expected
    to raise, and the error text must mention that a sequence with more than
    one dimension cannot be ordered.
    """
    np.random.seed(1337)
    heuristic = BALD(reduction='none')
    with pytest.raises(ValueError) as e_info:
        heuristic(distributions)
    # The message check must live outside the `with` block: statements placed
    # after the raising call inside the block are never executed, and
    # `e_info.value` is only populated once the block has exited.
    assert "Can't order sequence with more than 1 dimension." in str(
        e_info.value)
Example #4
0
def main(hparams):
    """Run an active-learning experiment on CIFAR10 with VGG16 and BALD.

    Args:
        hparams: Configuration object. This function reads `data_root`,
            `n_gpus` and `query_size` from it and also passes it to `VGG16`.
    """
    augmenting_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.ToTensor()]
    )
    plain_transform = transforms.Compose([transforms.ToTensor()])

    # The training split uses the augmenting transform; the pool is
    # configured with the plain one through `pool_specifics`.
    cifar_train = CIFAR10(
        hparams.data_root, train=True, transform=augmenting_transform, download=True
    )
    active_set = ActiveLearningDataset(
        cifar_train, pool_specifics={'transform': plain_transform}
    )
    active_set.label_randomly(10)

    heuristic = BALD()
    model = VGG16(active_set, hparams)

    # Training changes the model's weights, so a deep copy of the initial
    # state is taken up front and given to the reset callback.
    initial_state = copy.deepcopy(model.state_dict())
    trainer = BaalTrainer(
        max_epochs=3,
        default_root_dir=hparams.data_root,
        gpus=hparams.n_gpus,
        distributed_backend='dp' if hparams.n_gpus > 1 else None,
        callbacks=[ResetCallback(initial_state)],
    )
    loop = ActiveLearningLoop(
        active_set,
        get_probabilities=trainer.predict_on_dataset_generator,
        heuristic=heuristic,
        ndata_to_label=hparams.query_size,
    )

    max_steps = 100
    for al_step in range(max_steps):
        print(f'Step {al_step} Dataset size {len(active_set)}')
        trainer.fit(model)
        # `loop.step()` returns a falsy value when the loop should stop.
        if not loop.step():
            break
Example #5
0
def test_combine_heuristics_reorder_list():
    """Check `CombineHeuristics.reorder_indices` on precomputed chunked scores.

    Only the reordering step is exercised: the uncertainty values are given
    directly, chunk by chunk, for two heuristics. Index 0 carries the highest
    value in both score sets, and the expected ranking is `[0, 1, 2]`.
    """
    bald_chunks = (np.array([0.98]), np.array([0.87, 0.68]))
    variance_chunks = (np.array([0.76]), np.array([0.63, 0.48]))

    # One entry per chunk, each holding [bald_scores, variance_scores].
    streaming_prediction = [
        list(chunk_pair) for chunk_pair in zip(bald_chunks, variance_chunks)
    ]

    combined = CombineHeuristics(
        [BALD(), Variance()],
        weights=[0.5, 0.5],
        reduction='mean',
    )
    ranks = combined.reorder_indices(streaming_prediction)

    expected = [0, 1, 2]
    assert np.all(ranks == expected), "Combine Heuristics is not right {}".format(ranks)
Example #6
0
def test_combine_heuristics_uncertainty_generator():
    """Check that `CombineHeuristics` gives the same result on chunked input.

    The uncertainties computed from the full arrays must be close to those
    computed from the chunked versions, and ranking the chunked input must
    produce `[1, 2, 0]`.
    """
    def fresh_chunks():
        # Rebuilt for every consumer, as the original test did before reuse
        # (presumably because `chunks` yields one-shot iterators -- confirm).
        return [chunks(distributions_3d, 2), chunks(distributions_5d, 2)]

    np.random.seed(1337)
    chunked_input = fresh_chunks()
    full_input = [distributions_3d, distributions_5d]

    combined = CombineHeuristics(
        [BALD(), Variance()],
        weights=[0.5, 0.5],
        reduction='mean',
    )

    full_scores = combined.get_uncertainties(full_input)
    chunked_scores = combined.get_uncertainties(chunked_input)
    assert np.allclose(full_scores, chunked_scores)

    ranks = combined(fresh_chunks())
    assert np.all(ranks == [1, 2, 0]), "Combine Heuristics is not right {}".format(ranks)
Example #7
0
def test_bald(distributions, reduction):
    """Check BALD rankings on in-memory and chunked (streaming) inputs.

    Verifies that:
    * `get_uncertainties` and `get_uncertainties_generator` agree,
    * both the direct and the chunked call rank the samples as `[1, 2, 0]`,
    * a heuristic built with `threshold=0.1` selects at least one value
      `<= 0.1` from `distributions`,
    * a heuristic built with `0.99` as first positional argument does not
      reproduce the `[1, 2, 0]` ordering.

    Args:
        distributions: Fixture with the predictions to rank.
        reduction: Fixture/parameter forwarded to `BALD(reduction=...)`.
    """
    np.random.seed(1338)

    bald = BALD(reduction=reduction)
    marg = bald(distributions)
    str_marg = bald(chunks(distributions, 2))

    assert np.allclose(
        bald.get_uncertainties(distributions),
        bald.get_uncertainties_generator(chunks(distributions, 2)),
    )

    assert np.all(marg == [1, 2, 0]), "BALD is not right {}".format(marg)
    # Report the streaming ranks here; the original message printed `marg`,
    # which hid the actual failing value.
    assert np.all(str_marg == [1, 2, 0]), "StreamingBALD is not right {}".format(str_marg)

    bald = BALD(threshold=0.1, reduction=reduction)
    marg = bald(distributions)
    assert np.any(distributions[marg] <= 0.1)

    # NOTE(review): the first positional argument is presumably a shuffle
    # proportion that randomizes the ranking -- confirm against BALD's signature.
    bald = BALD(0.99, reduction=reduction)
    marg = bald(distributions)

    # Unlikely, but not 100% sure
    assert np.any(marg != [1, 2, 0])
Example #8
0
def test_heuristics_reorder_list():
    """Check `reorder_indices` of each heuristic on precomputed chunked scores.

    Only the reordering step is exercised: the scores are supplied directly
    as three chunks holding 0.98 | 0.87, 0.68 | 0.96, 0.54. `BALD`, `Variance`
    and `Entropy` are expected to return `[0, 3, 1, 2, 4]`, `Margin` and
    `Certainty` the reverse, and `Random` any ranking with five entries.
    """
    streaming_prediction = [
        np.array([0.98]),
        np.array([0.87, 0.68]),
        np.array([0.96, 0.54]),
    ]

    # (heuristic class, expected ranking, failure message template)
    ordered_cases = (
        (BALD, [0, 3, 1, 2, 4], "reorder list for BALD is not right {}"),
        (Variance, [0, 3, 1, 2, 4], "reorder list for Variance is not right {}"),
        (Entropy, [0, 3, 1, 2, 4], "reorder list for Entropy is not right {}"),
        (Margin, [4, 2, 1, 3, 0], "reorder list for Margin is not right {}"),
        (Certainty, [4, 2, 1, 3, 0], "reorder list for Certainty is not right {}"),
    )
    for heuristic_cls, expected, failure_template in ordered_cases:
        ranks = heuristic_cls().reorder_indices(streaming_prediction)
        assert np.all(ranks == expected), failure_template.format(ranks)

    # Random has no fixed ordering, so only the number of ranks is checked.
    ranks = Random().reorder_indices(streaming_prediction)
    assert ranks.size == 5, "reorder list for Random is not right {}".format(
        ranks)
Example #9
0
    def wrapped(_, logits):
        return logits

    probability_distribution = wrapped(None, logits)
    assert np.alltrue((probability_distribution >= 0)
                      & (probability_distribution <= 1)).all()


def test_that_precomputed_passes_back_predictions():
    """Check that calling `Precomputed` on ranks returns equal values."""
    heuristic = Precomputed()
    given_ranks = np.arange(10)
    returned_ranks = heuristic(given_ranks)
    assert (returned_ranks == given_ranks).all()


@pytest.mark.parametrize('heuristic1, heuristic2, weights',
                         [(BALD(), Variance(), [0.7, 0.3]),
                          (BALD(), Entropy(reduction='mean'), [0.9, 0.8]),
                          (Entropy(), Variance(), [4, 8]),
                          (Certainty(), Variance(), [9, 2]),
                          (Certainty(), Certainty(reduction='mean'), [1, 3])])
def test_combine_heuristics(heuristic1, heuristic2, weights):
    np.random.seed(1337)
    predictions = [distributions_3d, distributions_5d]

    if isinstance(heuristic1,
                  Certainty) and not isinstance(heuristic2, Certainty):
        with pytest.raises(Exception) as e_info:
            heuristics = CombineHeuristics([heuristic1, heuristic2],
                                           weights=weights,
                                           reduction='mean')
            assert 'heuristics should have the same value for `revesed` parameter' in str(
Example #10
0
    nn.Dropout(),
    nn.Linear(512, 512),
    nn.Dropout(),
    nn.Linear(512, 10),
)
model = patch_module(model)  # Set dropout layers for MC-Dropout.
if use_cuda:
    model = model.cuda()
# Pair the model with its loss; `wrapper.predict_on_dataset` is what the
# active learning loop below uses to obtain predictions.
wrapper = ModelWrapper(model=model, criterion=nn.CrossEntropyLoss())
# SGD over the patched model's parameters, with momentum and weight decay.
optimizer = optim.SGD(model.parameters(),
                      lr=0.001,
                      momentum=0.9,
                      weight_decay=5e-4)

# We will use BALD as our heuristic as it is a great tradeoff between performance and efficiency.
bald = BALD()
# Set up the active learning loop for our experiments.
al_loop = ActiveLearningLoop(
    dataset=al_dataset,
    get_probabilities=wrapper.predict_on_dataset,
    heuristic=bald,
    query_size=100,  # We will label 100 examples per step.
    # The remaining keyword arguments are KWARGS for predict_on_dataset.
    iterations=20,  # 20 samples for MC-Dropout
    batch_size=32,
    use_cuda=use_cuda,
    verbose=False,
)

# Following Gal 2016, we reset the weights at the beginning of each step.
# A deep copy is taken so this snapshot is independent of the live model state.
initial_weights = deepcopy(model.state_dict())
Example #11
0
    def __init__(
        self,
        labelled: Optional[DataModule] = None,
        heuristic: "AbstractHeuristic" = BALD(),
        map_dataset_to_labelled: Optional[Callable] = dataset_to_non_labelled_tensor,
        filter_unlabelled_data: Optional[Callable] = filter_unlabelled_data,
        initial_num_labels: Optional[int] = None,
        query_size: int = 1,
        val_split: Optional[float] = None,
    ):
        """The `ActiveLearningDataModule` handles data manipulation for ActiveLearning.

        Args:
            labelled: DataModule containing labelled train data for research use-case.
                The labelled data would be masked.
            heuristic: Sorting algorithm used to rank samples on how likely they can help with model performance.
                Note that the default `BALD()` is created once, when the method is defined, so every
                instance built without an explicit heuristic shares that same object.
            map_dataset_to_labelled: Function used to emulate masking on labelled dataset.
            filter_unlabelled_data: Function used to filter the unlabelled data while computing uncertainties.
            initial_num_labels: Number of samples to randomly label to start the training with.
            query_size: Number of samples to be labelled at each Active Learning loop based on the fed heuristic.
            val_split: Float to split train dataset into train and validation set.

        Raises:
            MisconfigurationException: If `labelled` is missing, has no `num_classes`, carries
                validation or prediction inputs, or if `val_split` is outside `[0, 1]`.
        """
        super().__init__(batch_size=1)
        self.labelled = labelled
        self.heuristic = heuristic
        self.map_dataset_to_labelled = map_dataset_to_labelled
        self.filter_unlabelled_data = filter_unlabelled_data
        self.initial_num_labels = initial_num_labels
        self.query_size = query_size
        self.val_split = val_split
        self._dataset: Optional[ActiveLearningDataset] = None

        if not self.labelled:
            raise MisconfigurationException("The labelled `datamodule` should be provided.")

        if not self.labelled.num_classes:
            raise MisconfigurationException("The labelled dataset should be labelled")

        # `self.labelled` is known to be truthy here, so no extra guard is needed.
        if self.labelled._val_input or self.labelled._predict_input:
            raise MisconfigurationException("The labelled `datamodule` should have only train data.")

        # Validate the split ratio up front. Previously the range check was
        # skipped whenever no labelled data existed yet, so an out-of-range
        # value could pass silently at construction time.
        if self.val_split and (self.val_split < 0 or self.val_split > 1):
            raise MisconfigurationException("The `val_split` should be a float between 0 and 1.")

        self._dataset = ActiveLearningDataset(
            self.labelled._train_input, labelled=self.map_dataset_to_labelled(self.labelled._train_input)
        )

        # Without a split ratio or without labelled samples there is nothing to validate on.
        if not self.val_split or not self.has_labelled_data:
            self.val_dataloader = None

        if self.labelled._test_input:
            self.test_dataloader = self._test_dataloader

        # Reuse the wrapped datamodule's batch hook when it defines one.
        if hasattr(self.labelled, "on_after_batch_transfer"):
            self.on_after_batch_transfer = self.labelled.on_after_batch_transfer

        if not self.initial_num_labels:
            warnings.warn(
                "No labels provided for the initial step, the estimated uncertainties are unreliable!", UserWarning
            )
        else:
            self._dataset.label_randomly(self.initial_num_labels)