def test_reduction(self):
    """Check the computed value when ``reduction="mean_channel"`` is used.

    The metric is attached to an ``Engine`` whose step function does
    nothing, updated with ``TEST_SAMPLE_1`` and then ``TEST_SAMPLE_2``, and
    the float-cast result of ``compute()`` is compared against the tensor
    ``[4.1713, 0.0000]``.
    """
    metric = SurfaceDistance(include_background=True, reduction="mean_channel")

    def _noop_step(engine, batch):
        pass

    evaluator = Engine(_noop_step)
    metric.attach(evaluator, "surface_distance")

    # Both samples are accumulated before a single compute() call.
    for sample in (TEST_SAMPLE_1, TEST_SAMPLE_2):
        y_pred, y = sample
        metric.update([y_pred, y])

    result = metric.compute().float()
    expected = torch.tensor([4.1713, 0.0000])
    torch.testing.assert_allclose(result, expected)
def test_compute(self):
    """Check ``compute()`` after each of four successive updates.

    The metric is attached to an ``Engine`` whose step function does
    nothing. Samples are fed one at a time and the running result is
    asserted after every update: 4.17133, then 2.08566 (both to four
    decimal places), then ``inf`` after each of the last two samples.
    """
    metric = SurfaceDistance(include_background=True)

    def _noop_step(engine, batch):
        pass

    evaluator = Engine(_noop_step)
    metric.attach(evaluator, "surface_distance")

    def _feed_and_check(sample, expected, **tolerance):
        # One accumulation step followed by a check of the running value.
        y_pred, y = sample
        metric.update([y_pred, y])
        self.assertAlmostEqual(metric.compute(), expected, **tolerance)

    _feed_and_check(TEST_SAMPLE_1, 4.17133, places=4)
    _feed_and_check(TEST_SAMPLE_2, 2.08566, places=4)
    # assertAlmostEqual short-circuits on equality, so inf == inf passes.
    _feed_and_check(TEST_SAMPLE_3, float("inf"))
    _feed_and_check(TEST_SAMPLE_4, float("inf"))
def test_shape_mismatch(self):
    """Check that ``update`` raises for a mismatched prediction/label pair.

    The prediction is the first element of ``TEST_SAMPLE_1`` and the label
    is an all-ones tensor of shape ``(1, 1, 10, 10, 10)``; presumably their
    shapes differ (``TEST_SAMPLE_1`` is defined outside this view).

    Raises (expected from ``update``):
        AssertionError or ValueError: either is accepted by the test.
    """
    sur_metric = SurfaceDistance(include_background=True)

    # Build the inputs outside the assertRaises block so that only the
    # update() call itself can satisfy the expected-exception check.
    y_pred = TEST_SAMPLE_1[0]
    y = torch.ones((1, 1, 10, 10, 10))

    with self.assertRaises((AssertionError, ValueError)):
        sur_metric.update([y_pred, y])
def test_compute(self):
    """Check the running value and the example counter across four updates.

    Without attaching to an engine, the metric is updated sample by sample:
    ``compute()`` must give 4.17133, then 2.08566 (both to four decimal
    places), then ``inf``. The private ``_num_examples`` counter must be 3
    after the third update and still 3 after the fourth, i.e. the test
    expects ``TEST_SAMPLE_4`` not to be counted (presumably because it
    yields no valid distance; confirm against the handler).
    """
    # NOTE(review): another ``test_compute`` is defined earlier in this
    # file; if both live in the same class, this later definition shadows
    # the earlier one and it never runs. Confirm and rename one of them.
    sur_metric = SurfaceDistance(include_background=True)

    y_pred, y = TEST_SAMPLE_1
    sur_metric.update([y_pred, y])
    self.assertAlmostEqual(sur_metric.compute(), 4.17133, places=4)

    y_pred, y = TEST_SAMPLE_2
    sur_metric.update([y_pred, y])
    self.assertAlmostEqual(sur_metric.compute(), 2.08566, places=4)

    y_pred, y = TEST_SAMPLE_3
    sur_metric.update([y_pred, y])
    self.assertAlmostEqual(sur_metric.compute(), float("inf"))
    # The counter is an exact count, so compare it exactly rather than
    # approximately.
    self.assertEqual(sur_metric._num_examples, 3)

    y_pred, y = TEST_SAMPLE_4
    sur_metric.update([y_pred, y])
    self.assertEqual(sur_metric._num_examples, 3)