Example #1
0
    def test_prob_map_generator(self, name, size):
        """Run ProbMapProducer as an evaluator handler and check the saved probability map.

        ``name`` and ``size`` come from the test's parameterization (not visible here).
        A ``TestDataset(name, size)`` is evaluated through a ValidationHandler attached to
        a no-op engine. The test then loads ``<output_dir>/<name>.npy`` and expects its
        diagonal, cast to int, to equal ``1..size``.
        """
        # set up dataset
        dataset = TestDataset(name, size)
        data_loader = DataLoader(dataset, batch_size=1)

        # set up engine: the step does nothing; running the engine only drives the
        # attached ValidationHandler
        def inference(_engine, batch):
            pass

        engine = Engine(inference)

        # add ProbMapProducer to the evaluator's handlers; the test expects it to write
        # "<name>.npy" into output_dir
        output_dir = os.path.join(os.path.dirname(__file__), "testing_data")
        output_path = os.path.join(output_dir, name + ".npy")
        # Remove any map left over from a previous run, so the assertion below can only
        # pass on a file written during this run.
        try:
            os.remove(output_path)
        except FileNotFoundError:
            pass
        prob_map_gen = ProbMapProducer(output_dir=output_dir)

        evaluator = TestEvaluator(torch.device("cpu:0"), data_loader, size, val_handlers=[prob_map_gen])

        # set up validation handler; the validator is bound after attach()
        # (presumably to exercise set_validator)
        validation = ValidationHandler(interval=1, validator=None)
        validation.attach(engine)
        validation.set_validator(validator=evaluator)

        engine.run(data_loader)

        prob_map = np.load(output_path)
        self.assertListEqual(np.diag(prob_map).astype(int).tolist(), list(range(1, size + 1)))

    def test_content(self):
        """Run 5 training epochs with a ValidationHandler(interval=2) and inspect the evaluator.

        After training, the evaluator's state is expected to report ``max_epochs == 4``
        and ``epoch_length == 8``; the evaluator reads the same 8 dummy items.
        """
        items = [0] * 8

        # evaluator over a torch DataLoader that wraps the same dummy items
        evaluator = TestEvaluator(torch.device("cpu:0"), torch.utils.data.DataLoader(Dataset(items)))

        # trainer whose step does nothing; it only drives the attached handler
        def _noop_step(engine, batch):
            pass

        trainer = Engine(_noop_step)
        validation_handler = ValidationHandler(interval=2, validator=evaluator)
        validation_handler.attach(trainer)

        trainer.run(items, max_epochs=5)

        state = evaluator.state
        self.assertEqual(state.max_epochs, 4)
        self.assertEqual(state.epoch_length, 8)