def test_pl_step():
    """One ``BaalTrainer.step()`` call should label exactly ``query_size`` items.

    The active set starts with 10 randomly labelled samples; the trainer is
    configured with ``ndata_to_label=hparams.query_size``.
    """
    hparams = HParams()
    al_dataset = ActiveLearningDataset(DummyDataset())
    al_dataset.label_randomly(10)
    model = DummyPytorchLightning(al_dataset, hparams)

    # Capture a checkpoint dict from the model; ResetCallback gets its own
    # deep copy so it does not share the object returned here.
    checkpoint = model.on_save_checkpoint({})
    reset_cb = ResetCallback(copy.deepcopy(checkpoint))

    trainer = BaalTrainer(
        dataset=al_dataset,
        max_epochs=3,
        default_root_dir='/tmp',
        ndata_to_label=hparams.query_size,
        callbacks=[reset_cb],
    )
    trainer.model = model

    labelled_before = len(al_dataset)
    trainer.step()
    labelled_after = len(al_dataset)
    assert labelled_after - labelled_before == hparams.query_size
def test_predict():
    """Check pool predictions and checkpoint round-tripping of the active set.

    * ``predict_on_dataset`` must yield one entry per pool sample.
    * The model checkpoint must carry an ``'active_dataset'`` entry.
    * Reloading that checkpoint after extra labelling must bring the
      labelled count back to its earlier value.
    """
    checkpoint_in = {}
    hparams = HParams()
    al_dataset = ActiveLearningDataset(DummyDataset())
    al_dataset.label_randomly(10)
    model = DummyPytorchLightning(al_dataset, hparams)
    checkpoint = model.on_save_checkpoint(checkpoint_in)

    trainer = BaalTrainer(
        dataset=al_dataset,
        max_epochs=3,
        default_root_dir='/tmp',
        callbacks=[ResetCallback(copy.deepcopy(checkpoint))],
    )
    trainer.model = model

    # Prediction phase: one result per unlabelled (pool) sample.
    predictions = trainer.predict_on_dataset()
    assert len(predictions) == len(al_dataset.pool)
    assert 'active_dataset' in checkpoint

    # Restore phase: copy the checkpoint *before* labelling more samples
    # (presumably so later dataset mutation cannot leak into it), then
    # reload it and expect the labelled count to revert.
    n_labelled = len(al_dataset)
    restore_from = copy.deepcopy(checkpoint)
    al_dataset.label_randomly(5)
    model.on_load_checkpoint(restore_from)
    assert len(al_dataset) == n_labelled