def test_prob_map_generator(self, name, size):
    """Check the probability map saved after a validation run.

    A ``ProbMapProducer`` is registered as a validation handler on a
    ``TestEvaluator``; the evaluator is driven through a ``ValidationHandler``
    attached to an outer engine whose step function does nothing.  After the
    run, ``<name>.npy`` is loaded from the output directory and its diagonal
    (cast to int) must equal ``1, 2, ..., size``.

    Args:
        name: passed to ``TestDataset`` and used as the stem of the ``.npy``
            file that is loaded back.
        size: passed to ``TestDataset`` and ``TestEvaluator``; also the
            length of the expected diagonal.
    """
    # Data: one sample per batch.
    loader = DataLoader(TestDataset(name, size), batch_size=1)

    # Outer engine: its step is a no-op, it only exists to trigger validation.
    def _noop_step(_engine, _batch):
        pass

    trainer = Engine(_noop_step)

    # The producer is pointed at a "testing_data" folder next to this file.
    output_dir = os.path.join(os.path.dirname(__file__), "testing_data")
    producer = ProbMapProducer(output_dir=output_dir)

    device = torch.device("cpu:0")
    evaluator = TestEvaluator(device, loader, size, val_handlers=[producer])

    # The validator is deliberately supplied after attaching, via
    # set_validator(), rather than through the constructor.
    val_handler = ValidationHandler(interval=1, validator=None)
    val_handler.attach(trainer)
    val_handler.set_validator(validator=evaluator)

    trainer.run(loader)

    # Load the saved map and compare its diagonal with 1..size.
    saved_path = os.path.join(output_dir, name + ".npy")
    saved_map = np.load(saved_path)
    actual_diag = np.diag(saved_map).astype(int).tolist()
    expected_diag = list(range(1, size + 1))
    self.assertListEqual(actual_diag, expected_diag)
def test_content(self):
    """Check the evaluator state after interval-based validation.

    An engine with a no-op step runs for 5 epochs over 8 items, with a
    ``ValidationHandler`` (``interval=2``) driving a ``TestEvaluator``.
    Afterwards the evaluator's ``state.max_epochs`` must be 4 and its
    ``state.epoch_length`` must be 8.
    """
    samples = [0] * 8

    # Outer engine whose step does nothing; it only drives the epochs.
    def _noop_step(_engine, _batch):
        pass

    trainer = Engine(_noop_step)

    # Evaluator iterating over the same eight items through a DataLoader.
    device = torch.device("cpu:0")
    val_loader = torch.utils.data.DataLoader(Dataset(samples))
    evaluator = TestEvaluator(device, val_loader)

    # Validate every second epoch of the outer engine.
    val_handler = ValidationHandler(interval=2, validator=evaluator)
    val_handler.attach(trainer)

    trainer.run(samples, max_epochs=5)

    # NOTE(review): 4 is presumably the last epoch (<= 5) at which the
    # interval-2 validation fires - confirm against ValidationHandler.
    self.assertEqual(evaluator.state.max_epochs, 4)
    self.assertEqual(evaluator.state.epoch_length, 8)