def test_pl_step():
    """Check that one ``trainer.step()`` grows ``len(active_set)`` by ``query_size``."""
    hparams = HParams()
    active_set = ActiveLearningDataset(DummyDataset())
    active_set.label_randomly(10)
    model = DummyPytorchLightning(active_set, hparams)

    # State captured right after construction; the callback receives its own
    # deep copy so later changes to this dict cannot leak into it.
    initial_state = model.on_save_checkpoint({})
    reset_callback = ResetCallback(copy.deepcopy(initial_state))

    trainer = BaalTrainer(
        dataset=active_set,
        max_epochs=3,
        default_root_dir='/tmp',
        ndata_to_label=hparams.query_size,
        callbacks=[reset_callback],
    )
    trainer.model = model

    size_before = len(active_set)
    trainer.step()
    newly_added = len(active_set) - size_before

    assert newly_added == hparams.query_size
def test_predict():
    """Check pool prediction length and checkpoint restore of the active set.

    Two things are verified:

    * ``trainer.predict_on_dataset()`` returns as many items as
      ``active_set.pool`` holds;
    * after five more random labels, loading a copy of the previously saved
      checkpoint brings ``len(active_set)`` back to its earlier value.
    """
    hparams = HParams()
    active_set = ActiveLearningDataset(DummyDataset())
    active_set.label_randomly(10)
    model = DummyPytorchLightning(active_set, hparams)

    saved_state = model.on_save_checkpoint({})
    reset_callback = ResetCallback(copy.deepcopy(saved_state))
    trainer = BaalTrainer(
        dataset=active_set,
        max_epochs=3,
        default_root_dir='/tmp',
        callbacks=[reset_callback],
    )
    trainer.model = model

    predictions = trainer.predict_on_dataset()
    assert len(predictions) == len(active_set.pool)
    assert 'active_dataset' in saved_state

    # Record the size and copy the checkpoint *before* labelling more samples,
    # so the copy is taken while the set still has its earlier size.
    expected_size = len(active_set)
    state_snapshot = copy.deepcopy(saved_state)
    active_set.label_randomly(5)

    model.on_load_checkpoint(state_snapshot)
    assert len(active_set) == expected_size