def update(self, output: Sequence[Union[torch.Tensor, dict]]):
    """Fold one batch into the running Dice totals.

    ``output`` must contain exactly two entries, unpacked as
    ``(y_pred, y)``. The value returned by ``compute_meandice`` is read
    with ``.item()``, so it is expected to hold a single element. That
    value is weighted by ``len(y)`` before being added to ``self._sum``,
    and ``self._num_examples`` grows by the same ``len(y)``.
    """
    assert len(output) == 2, 'MeanDice metric can only support y_pred and y.'
    prediction, target = output

    batch_mean = compute_meandice(
        prediction,
        target,
        self.include_background,
        self.to_onehot_y,
        self.mutually_exclusive,
        self.add_sigmoid,
        self.logit_thresh,
    )

    # Weight the batch value by the number of targets so the accumulated
    # sum stays consistent with the accumulated example count.
    n_samples = len(target)
    weighted_total = batch_mean.item() * n_samples
    self._sum += weighted_total
    self._num_examples += n_samples
def update(self, output: Sequence[Union[torch.Tensor, dict]]):
    """Accumulate Dice scores element by element, ignoring NaN entries.

    ``output`` must contain exactly two entries, unpacked as
    ``(y_pred, y)``. The result of ``compute_meandice`` is iterated along
    its first dimension; for each element the NaN entries are masked out
    and the mean of the remaining values is added to ``self._sum``, with
    ``self._num_examples`` incremented by one. Elements whose values are
    all NaN contribute nothing and are not counted.

    NOTE(review): presumably each element is one sample holding one score
    per class -- confirm against ``compute_meandice``.
    """
    assert len(output) == 2, 'MeanDice metric can only support y_pred and y.'
    prediction, target = output

    all_scores = compute_meandice(
        prediction,
        target,
        self.include_background,
        self.to_onehot_y,
        self.mutually_exclusive,
        self.add_sigmoid,
        self.logit_thresh,
    )

    for item_scores in all_scores:
        valid = ~torch.isnan(item_scores)
        # Only elements with at least one non-NaN score are counted.
        if valid.sum() != 0:
            self._sum += item_scores[valid].mean().item()
            self._num_examples += 1
def test_value(self, input_data, expected_value):
    """Check the scalar result of ``compute_meandice`` to 4 decimal places.

    ``input_data`` is expanded as keyword arguments; the result is read
    with ``.item()`` and compared against ``expected_value``.
    """
    dice = compute_meandice(**input_data)
    actual = dice.item()
    self.assertAlmostEqual(actual, expected_value, places=4)
def test_nans(self, input_data, expected_value):
    """Check which entries of the ``compute_meandice`` result are NaN.

    ``input_data`` is expanded as keyword arguments; the NaN mask of the
    result (moved to CPU and converted to NumPy) is compared against
    ``expected_value`` with ``np.allclose``.
    """
    dice = compute_meandice(**input_data)
    nan_mask = np.isnan(dice.cpu().numpy())
    mask_matches = np.allclose(nan_mask, expected_value)
    self.assertTrue(mask_matches)
def test_value(self, input_data, expected_value):
    """Check the ``compute_meandice`` result element-wise within 1e-4.

    ``input_data`` is expanded as keyword arguments; the result is moved
    to CPU, converted to NumPy and compared against ``expected_value``
    with an absolute tolerance of ``1e-4``.
    """
    dice = compute_meandice(**input_data)
    actual = dice.cpu().numpy()
    values_match = np.allclose(actual, expected_value, atol=1e-4)
    self.assertTrue(values_match)