def test_compute(self, input_params, expected_avg):
    """Accumulate two (prediction, label) batches and check the averaged Dice score.

    Args:
        input_params: keyword arguments forwarded to ``MeanDice``.
        expected_avg: value ``compute()`` should return, compared to 4 decimal places.
    """
    metric = MeanDice(**input_params)

    # Both batches share the same prediction; the labels differ, so one
    # batch disagrees with the prediction on its second item and one matches.
    batches = (
        (
            torch.Tensor([[[0], [1]], [[1], [0]]]),
            torch.Tensor([[[0], [1]], [[0], [1]]]),
        ),
        (
            torch.Tensor([[[0], [1]], [[1], [0]]]),
            torch.Tensor([[[0], [1]], [[1], [0]]]),
        ),
    )
    for prediction, label in batches:
        metric.update([prediction, label])

    result = metric.compute()
    self.assertAlmostEqual(result, expected_avg, places=4)
def test_compute(self, input_params, expected_avg, details_shape):
    """Check the averaged Dice score and the per-item details stored on the engine.

    Args:
        input_params: keyword arguments forwarded to ``MeanDice``.
        expected_avg: value ``compute()`` should return, compared to 4 decimal places.
        details_shape: expected shape of ``engine.state.metric_details["mean_dice"]``.
    """
    metric = MeanDice(**input_params)

    # The engine is never run in this test; it only has to exist so the metric
    # can be attached and the details can later be read from its state.
    def _noop_step(_engine, _batch):
        pass

    evaluator = Engine(_noop_step)
    metric.attach(engine=evaluator, name="mean_dice")

    # First update: prediction and label each passed as a single stacked tensor.
    stacked_pred = torch.Tensor([[[0], [1]], [[1], [0]]])
    stacked_label = torch.Tensor([[[0], [1]], [[0], [1]]])
    metric.update([stacked_pred, stacked_label])

    # Second update: prediction and label each passed as a list of per-item tensors.
    listed_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]
    listed_label = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]
    metric.update([listed_pred, listed_label])

    result = metric.compute()
    self.assertAlmostEqual(result, expected_avg, places=4)

    details = evaluator.state.metric_details["mean_dice"]
    self.assertTupleEqual(tuple(details.shape), details_shape)