def test_reduction(self):
        """Check SurfaceDistance with ``reduction="mean_channel"``.

        Two samples are accumulated through ``update``; ``compute`` is then
        compared with one reference value per channel (two channels here).
        """
        sur_metric = SurfaceDistance(include_background=True,
                                     reduction="mean_channel")

        def _val_func(engine, batch):
            # No-op step: the engine is never run in this test, it only
            # exists so the metric can be attached to it.
            pass

        engine = Engine(_val_func)
        sur_metric.attach(engine, "surface_distance")

        y_pred, y = TEST_SAMPLE_1
        sur_metric.update([y_pred, y])
        y_pred, y = TEST_SAMPLE_2
        sur_metric.update([y_pred, y])
        # torch.testing.assert_allclose is deprecated; assert_close replaces it.
        # assert_close's float32 defaults are much tighter than the 4-decimal
        # reference values below, so the tolerance is stated explicitly.
        torch.testing.assert_close(
            sur_metric.compute().float(),
            torch.tensor([4.1713, 0.0000]),
            rtol=1e-4,
            atol=1e-4,
        )

    def test_compute(self):
        """Track the running SurfaceDistance value while samples accumulate.

        After every ``update`` the value from ``compute`` is compared with the
        expected running result; samples 3 and 4 are expected to push it to
        infinity.
        """
        metric = SurfaceDistance(include_background=True)

        def _noop_step(engine, batch):
            # The engine is never run here; it only hosts the attached metric.
            pass

        engine = Engine(_noop_step)
        metric.attach(engine, "surface_distance")

        # (sample, expected running value, decimal places for the comparison).
        # 7 places is the default precision of assertAlmostEqual.
        checkpoints = (
            (TEST_SAMPLE_1, 4.17133, 4),
            (TEST_SAMPLE_2, 2.08566, 4),
            (TEST_SAMPLE_3, float("inf"), 7),
            (TEST_SAMPLE_4, float("inf"), 7),
        )
        for sample, expected, places in checkpoints:
            prediction, target = sample
            metric.update([prediction, target])
            self.assertAlmostEqual(metric.compute(), expected, places=places)
 def test_shape_mismatch(self):
     """Updating with mismatched prediction/ground-truth tensors must raise.

     ``y`` is built with shape (1, 1, 10, 10, 10), presumably different from
     the shape of ``TEST_SAMPLE_1[0]`` (TODO confirm against the fixture).
     """
     sur_metric = SurfaceDistance(include_background=True)
     y_pred = TEST_SAMPLE_1[0]
     y = torch.ones((1, 1, 10, 10, 10))
     # Only the call under test sits inside assertRaises, so an error raised
     # while building the inputs cannot be mistaken for the expected failure.
     with self.assertRaises((AssertionError, ValueError)):
         sur_metric.update([y_pred, y])
Ejemplo n.º 4
0
 def test_compute(self):
     """Check the running SurfaceDistance value and its example counter.

     The metric is updated directly (no engine). After each of the first
     three samples the running value is compared with a reference. The
     private ``_num_examples`` counter is then expected to stay at 3 after
     TEST_SAMPLE_4 is added.
     """
     sur_metric = SurfaceDistance(include_background=True)
     y_pred, y = TEST_SAMPLE_1
     sur_metric.update([y_pred, y])
     self.assertAlmostEqual(sur_metric.compute(), 4.17133, places=4)
     y_pred, y = TEST_SAMPLE_2
     sur_metric.update([y_pred, y])
     self.assertAlmostEqual(sur_metric.compute(), 2.08566, places=4)
     y_pred, y = TEST_SAMPLE_3
     sur_metric.update([y_pred, y])
     self.assertAlmostEqual(sur_metric._num_examples, 3) if False else None
     self.assertAlmostEqual(sur_metric.compute(), float("inf"))
     # _num_examples is a count of examples, so it is compared exactly.
     self.assertEqual(sur_metric._num_examples, 3)
     y_pred, y = TEST_SAMPLE_4
     sur_metric.update([y_pred, y])
     # TEST_SAMPLE_4 is expected not to increase the example count.
     self.assertEqual(sur_metric._num_examples, 3)