def test_misconfig(self, input_params):
    """Check that an incompatible MeanDice configuration is rejected.

    Building the metric from ``input_params`` and feeding it one batch must
    raise a ``ValueError`` whose message contains ``'compatib'``.
    NOTE(review): ``input_params`` is presumably supplied by a parameterization
    decorator outside this view — confirm.
    """
    with self.assertRaisesRegex(ValueError, 'compatib'):
        # Construction stays inside the context: the error may come from the
        # constructor or from update(), and either must satisfy the assertion.
        metric = MeanDice(**input_params)
        prediction = torch.Tensor([[0, 1], [1, 0]])
        label = torch.ones((2, 1))
        metric.update([prediction, label])
def test_shape_mismatch(self, input_params, _expected):
    """Check that update() rejects labels whose shape does not fit the prediction.

    The prediction is a 2x2 tensor. Labels of shape (2, 3) and (3, 2) must each
    raise either ``AssertionError`` or ``ValueError``. ``_expected`` is
    accepted only to match the shared parameter list and is not used.
    """
    metric = MeanDice(**input_params)
    for bad_shape in ((2, 3), (3, 2)):
        with self.assertRaises((AssertionError, ValueError)):
            # A fresh prediction per case, mirroring the original's independent setup.
            prediction = torch.Tensor([[0, 1], [1, 0]])
            label = torch.ones(bad_shape)
            metric.update([prediction, label])
def test_compute(self, input_params, expected_avg):
    """Check the averaged MeanDice value after two accumulated batches.

    Both batches use the same 2x2 prediction. The first batch's label is all
    ones with shape (2, 1); the second's is [[1.], [0.]]. The value returned by
    compute() must be approximately equal to ``expected_avg``.
    """
    metric = MeanDice(**input_params)
    batches = (
        (torch.Tensor([[0, 1], [1, 0]]), torch.ones((2, 1))),
        (torch.Tensor([[0, 1], [1, 0]]), torch.Tensor([[1.], [0.]])),
    )
    for prediction, label in batches:
        metric.update([prediction, label])
    self.assertAlmostEqual(metric.compute(), expected_avg)