Beispiel #1
0
    def test_compute(self, input_params, expected_avg):
        """Feed two batches to ``MeanDice`` and check the averaged score.

        ``input_params`` is forwarded unchanged to the ``MeanDice``
        constructor; ``expected_avg`` is compared against ``compute()`` to
        4 decimal places.

        NOTE(review): another method named ``test_compute`` may exist in the
        same class; if so, the later definition shadows this one and this
        test never runs — confirm they live in separate classes/files.
        """
        metric = MeanDice(**input_params)

        batches = [
            # (prediction, label): the label differs from the prediction in
            # the second entry along the leading dimension.
            (
                torch.Tensor([[[0], [1]], [[1], [0]]]),
                torch.Tensor([[[0], [1]], [[0], [1]]]),
            ),
            # (prediction, label): identical tensors.
            (
                torch.Tensor([[[0], [1]], [[1], [0]]]),
                torch.Tensor([[[0], [1]], [[1], [0]]]),
            ),
        ]
        for prediction, label in batches:
            metric.update([prediction, label])

        self.assertAlmostEqual(metric.compute(), expected_avg, places=4)

    def test_compute(self, input_params, expected_avg, details_shape):
        """Attach ``MeanDice`` to an engine, feed two batches, check results.

        The first batch is given as stacked tensors and the second as lists
        of per-entry tensors, so both input forms pass through ``update``.
        Afterwards the averaged score is compared to ``expected_avg`` (4
        decimal places), and the shape of the entry stored under
        ``engine.state.metric_details["mean_dice"]`` must equal
        ``details_shape``.
        """
        metric = MeanDice(**input_params)

        # The engine is never run; its step function is a do-nothing stub.
        # It exists so the metric can be attached under the name that the
        # details lookup at the end of the test uses.
        def _noop_step(engine, batch):
            pass

        engine = Engine(_noop_step)
        metric.attach(engine=engine, name="mean_dice")

        stacked_pred = torch.Tensor([[[0], [1]], [[1], [0]]])
        stacked_label = torch.Tensor([[[0], [1]], [[0], [1]]])
        metric.update([stacked_pred, stacked_label])

        listed_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]
        listed_label = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]
        metric.update([listed_pred, listed_label])

        self.assertAlmostEqual(metric.compute(), expected_avg, places=4)
        # Presumably populated by compute() via the attached engine state —
        # read only after compute() has been called.
        details = engine.state.metric_details["mean_dice"]
        self.assertTupleEqual(tuple(details.shape), details_shape)