def test_predict_with_data_module(a_data_module, a_pl_module):
    """Predict on the pool through a data module, with and without replication.

    Named distinctly from the fixture-free ``test_predict`` in this module:
    two module-level functions with the same name means the later ``def``
    rebinds the name, and pytest would only ever collect one of them.
    """
    trainer = BaalTrainer(dataset=a_data_module.active_dataset, max_epochs=3, default_root_dir='/tmp')
    active_set = a_data_module.active_dataset

    predictions = trainer.predict_on_dataset(a_pl_module, a_data_module.pool_dataloader())
    # The prediction output must have as many entries as the pool has samples.
    assert len(predictions) == len(active_set.pool)

    # Replicate = False works too!
    # NOTE(review): this mutates the fixture's hparams in place; assumes the
    # a_pl_module fixture is function-scoped so other tests are unaffected — confirm.
    a_pl_module.hparams.replicate_in_memory = False
    predictions = trainer.predict_on_dataset(a_pl_module, a_data_module.pool_dataloader())
    assert len(predictions) == len(active_set.pool)
def test_predict():
    """Predict on the pool, then check that loading a checkpoint restores labelling.

    A checkpoint is taken after ``label_randomly(10)``. After five more samples
    are labelled, ``on_load_checkpoint`` with a copy of that checkpoint must
    bring ``len(active_set)`` back to its value before the extra labelling.
    """
    hparams = HParams()
    active_set = ActiveLearningDataset(DummyDataset())
    active_set.label_randomly(10)
    model = DummyPytorchLightning(active_set, hparams)

    checkpoint = model.on_save_checkpoint({})
    # The callback gets its own deep copy, so it never shares state with `checkpoint`.
    reset_callback = ResetCallback(copy.deepcopy(checkpoint))
    trainer = BaalTrainer(dataset=active_set, max_epochs=3, default_root_dir='/tmp',
                          callbacks=[reset_callback])
    trainer.model = model

    predictions = trainer.predict_on_dataset()
    assert len(predictions) == len(active_set.pool)
    assert 'active_dataset' in checkpoint

    labelled_before = len(active_set)
    # Copy taken before the extra labelling, then handed to on_load_checkpoint.
    snapshot = copy.deepcopy(checkpoint)
    active_set.label_randomly(5)
    model.on_load_checkpoint(snapshot)
    assert len(active_set) == labelled_before