示例#1
0
def test_integration():
    """Run the full active-learning cycle end to end on dummy data.

    A VGG16 wrapped in ``ModelWrapper`` is trained on the labelled set,
    evaluated on a second dummy set, and then ``ActiveLearningLoop.step``
    is asked to label more samples using the BALD heuristic. The test checks
    that at least one parameter tensor changes on every completed step and
    that the loop ends at step index 4.
    """
    preprocessing = Compose([Resize((64, 64)), ToTensor()])
    cifar10_train = DummyDataset(preprocessing)
    cifar10_test = DummyDataset(preprocessing)

    al_dataset = ActiveLearningDataset(
        cifar10_train, pool_specifics={'transform': preprocessing})
    al_dataset.label_randomly(10)

    use_cuda = False
    network = vgg.vgg16(pretrained=False, num_classes=10)

    criterion = nn.CrossEntropyLoss()
    # The optimizer is built on the raw network, before it gets wrapped.
    optimizer = optim.SGD(
        network.parameters(),
        lr=0.001,
        momentum=0.9,
        weight_decay=0.0005,
    )

    # Wrap the network so it exposes the train/test/predict helpers.
    wrapper = ModelWrapper(network, criterion)

    # The loop scores the pool with BALD and labels the selected samples.
    active_loop = ActiveLearningLoop(
        al_dataset,
        wrapper.predict_on_dataset,
        heuristic=heuristics.BALD(),
        ndata_to_label=10,
        batch_size=10,
        iterations=10,
        use_cuda=use_cuda,
        workers=4,
    )

    def snapshot():
        # Clone so later in-place optimizer updates do not alias the copy.
        return [param.clone() for param in wrapper.model.parameters()]

    loader_kwargs = {'batch_size': 10, 'use_cuda': use_cuda, 'workers': 2}

    for step in range(10):
        before = snapshot()
        wrapper.train_on_dataset(al_dataset,
                                 optimizer=optimizer,
                                 epoch=5,
                                 **loader_kwargs)
        wrapper.test_on_dataset(cifar10_test, **loader_kwargs)

        # A falsy return from step() ends the experiment.
        if not active_loop.step():
            break

        after = snapshot()
        # Training must have moved at least one parameter tensor.
        assert any(
            not np.allclose(old.detach(), new.detach())
            for old, new in zip(before, after)
        )
    assert step == 4  # 10 + (4 * 10) = 50, so it stops at iterations 4
示例#2
0
def test_deprecation():
    """Check that the legacy ``ndata_to_label`` keyword triggers a deprecation.

    Building an ``ActiveLearningLoop`` with ``ndata_to_label`` must emit a
    ``DeprecationWarning`` whose message mentions the keyword.
    """
    heur = heuristics.BALD()
    ds = MyDataset()
    dataset = ActiveLearningDataset(ds, make_unlabelled=lambda x: -1)
    with warnings.catch_warnings(record=True) as caught:
        # Ambient filters may ignore DeprecationWarning or report it only
        # once per location; force every warning to be recorded here.
        warnings.simplefilter("always")
        ActiveLearningLoop(dataset,
                           get_probs_iter,
                           heur,
                           ndata_to_label=10,
                           dummy_param=1)

    # Filter by category rather than relying on the deprecation being the
    # last warning emitted; other warnings may be raised afterwards.
    deprecations = [
        record for record in caught
        if issubclass(record.category, DeprecationWarning)
    ]
    assert deprecations, "no DeprecationWarning was recorded"
    assert any(
        "ndata_to_label" in str(record.message) for record in deprecations
    )
示例#3
0
def test_file_saving(tmpdir):
    """Check that a step dumps its uncertainty file into ``uncertainty_folder``.

    After labelling 10 items and running one step, exactly one file is
    expected in the folder. Its path must mention ``pool=90`` and
    ``labelled=10``, and its pickled content must hold 90 uncertainty values
    plus a dataset state that differs from the current one by 10 labels.

    Args:
        tmpdir: pytest temporary-directory fixture; converted to ``str``
            before being handed to the loop.
    """
    tmpdir = str(tmpdir)
    heur = heuristics.BALD()
    ds = MyDataset()
    dataset = ActiveLearningDataset(ds, make_unlabelled=lambda x: -1)
    active_loop = ActiveLearningLoop(dataset,
                                     get_probs_iter,
                                     heur,
                                     uncertainty_folder=tmpdir,
                                     query_size=10,
                                     dummy_param=1)
    dataset.label_randomly(10)
    active_loop.step()

    # List the folder once so both checks look at the same snapshot.
    saved = os.listdir(tmpdir)
    assert len(saved) == 1
    saved_path = pjoin(tmpdir, saved[0])
    assert "pool=90" in saved_path and "labelled=10" in saved_path

    # The file was produced by this test, so unpickling it is safe; the
    # context manager makes sure the handle is closed.
    with open(saved_path, 'rb') as handle:
        data = pickle.load(handle)
    assert len(data['uncertainty']) == 90
    # The diff between the current state and the step before is the newly labelled item.
    assert (data['dataset']['labelled'] != dataset.labelled).sum() == 10
示例#4
0
    return None


def get_probs_iter(pool, dummy_param=None):
    """Yield one fake prediction array per element of ``pool``.

    Each yielded array has shape ``(1, 3, 10)`` and is all zeros except for
    the row at index ``x % 3`` of the second axis, which is set to one.

    Args:
        pool: Sized iterable whose elements support ``% 3``.
        dummy_param: Must not be ``None``; asserted when iteration starts.

    Yields:
        A freshly allocated array for every element of ``pool``. Nothing is
        yielded when ``pool`` is empty.
    """
    assert dummy_param is not None
    if len(pool) == 0:
        # Ending a generator early simply produces an empty iteration.
        return
    for item in pool:
        hot_row = item % 3
        probs = np.zeros([1, 3, 10])
        probs[:, hot_row, :] = 1
        yield probs


@pytest.mark.parametrize('heur', [
    heuristics.Random(),
    heuristics.BALD(),
    heuristics.Entropy(),
    heuristics.Variance(reduction='sum')
])
def test_should_stop(heur):
    dataset = ActiveLearningDataset(MyDataset(), make_unlabelled=lambda x: -1)
    active_loop = ActiveLearningLoop(dataset,
                                     get_probs,
                                     heur,
                                     query_size=10,
                                     dummy_param=1)
    dataset.label_randomly(10)
    step = 0
    for _ in range(15):
        flg = active_loop.step()
        step += 1