def test_reduction(self):
        """Check the per-channel result with ``reduction="mean_channel"``.

        The metric is updated with ``TEST_SAMPLE_1`` and ``TEST_SAMPLE_2``
        and the computed value, cast to float, is compared against
        ``[4.1713, 0.0]``.
        """

        def _noop_step(engine, batch):
            """Do nothing; the metric is fed directly through ``update``."""

        metric = SurfaceDistance(reduction="mean_channel",
                                 include_background=True)

        # The engine is only needed so that the metric can be attached.
        evaluator = Engine(_noop_step)
        metric.attach(evaluator, "surface_distance")

        for sample in (TEST_SAMPLE_1, TEST_SAMPLE_2):
            y_pred, y = sample
            metric.update([y_pred, y])

        actual = metric.compute().float()
        expected = torch.tensor([4.1713, 0.0000])
        torch.testing.assert_allclose(actual, expected)
    def test_compute(self):
        """Check ``compute()`` after each of four successive updates.

        The metric is not reset between updates. The expected values are
        4.17133 after ``TEST_SAMPLE_1``, 2.08566 after ``TEST_SAMPLE_2``,
        and infinity after ``TEST_SAMPLE_3`` and ``TEST_SAMPLE_4``.
        """

        def _noop_step(engine, batch):
            """Do nothing; the metric is fed directly through ``update``."""

        metric = SurfaceDistance(include_background=True)

        # The engine is only needed so that the metric can be attached.
        evaluator = Engine(_noop_step)
        metric.attach(evaluator, "surface_distance")

        # (sample, expected value, extra keyword arguments for the assertion)
        steps = (
            (TEST_SAMPLE_1, 4.17133, {"places": 4}),
            (TEST_SAMPLE_2, 2.08566, {"places": 4}),
            (TEST_SAMPLE_3, float("inf"), {}),
            (TEST_SAMPLE_4, float("inf"), {}),
        )
        for sample, expected, assert_kwargs in steps:
            y_pred, y = sample
            metric.update([y_pred, y])
            self.assertAlmostEqual(metric.compute(), expected,
                                   **assert_kwargs)