def test_ill_opts(self):
    """MaskedLoss raises ``ValueError`` for a non-callable loss and for bad shape combinations."""
    # A non-callable ``loss`` is rejected when the wrapper is constructed.
    with self.assertRaises(ValueError):
        MaskedLoss(loss=[])

    dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5)
    # Build the wrapper outside the assertion contexts so that only the forward
    # call can satisfy ``assertRaises``; a failure in the constructor would
    # otherwise be mistaken for the expected forward-time error.
    masked = MaskedLoss(loss=dice_loss)

    # input/target have 1 channel while the mask has 3 channels.
    with self.assertRaises(ValueError):
        masked(input=torch.zeros((3, 1, 2, 2)), target=torch.zeros((3, 1, 2, 2)), mask=torch.zeros((3, 3, 2, 2)))
    # input has 3 channels, target has 2, and the mask has 3.
    with self.assertRaises(ValueError):
        masked(input=torch.zeros((3, 3, 2, 2)), target=torch.zeros((3, 2, 2, 2)), mask=torch.zeros((3, 3, 2, 2)))
def test_script(self):
    """Run ``test_script_save`` on a MaskedLoss built from the first test case's parameters."""
    # Only the constructor parameters are needed; the expected values are unused here.
    input_param, _ = TEST_CASES[0]
    size = [3, 3, 5, 5]
    # Random 0/1 channels collapsed by argmax into a (3, 1, 5, 5) class-index label.
    label = torch.randint(low=0, high=2, size=size)
    label = torch.argmax(label, dim=1, keepdim=True)
    pred = torch.randn(size)
    loss = MaskedLoss(**input_param)
    test_script_save(loss, pred, label)
def test_shape(self, input_param, expected_val):
    """Compare MaskedLoss output with reference values, first without and then with a random mask.

    ``expected_val[0]`` holds the accepted values for the unmasked call and
    ``expected_val[1]`` those for the masked call; the computed loss must be
    ``np.allclose`` to at least one value in the relevant entry.
    """
    size = [3, 3, 5, 5]
    # Random 0/1 channels collapsed by argmax into a (3, 1, 5, 5) class-index label.
    label = torch.randint(low=0, high=2, size=size)
    label = torch.argmax(label, dim=1, keepdim=True)
    pred = torch.randn(size)

    def check(mask, accepted):
        # Build a fresh wrapper per call, exactly as before, and report the
        # actual loss on failure instead of a bare "False is not true".
        out = MaskedLoss(**input_param)(pred, label, mask).detach().cpu().numpy()
        self.assertTrue(
            any(np.allclose(out, value) for value in accepted),
            msg=f"loss {out} matches none of the accepted values {accepted}",
        )

    check(None, expected_val[0])
    # Keep this draw after ``pred``: the reference values presumably depend on
    # the RNG draw order (seeding is done outside this method — confirm).
    mask = torch.randint(low=0, high=2, size=label.shape)
    check(mask, expected_val[1])
def __init__(self, *args, **kwargs) -> None:
    """
    Initialise the parent loss and wrap its ``forward`` in a :py:class:`MaskedLoss`.

    All positional and keyword arguments are passed unchanged to the parent
    constructor; they follow :py:class:`monai.losses.DiceLoss`.
    """
    super().__init__(*args, **kwargs)
    # ``super().forward`` is the parent class's ``forward`` bound to this
    # instance. Wrapping it (rather than ``self.forward``) means MaskedLoss
    # calls the parent implementation even if this class overrides ``forward``.
    parent_forward = super().forward
    self.spatial_weighted = MaskedLoss(loss=parent_forward)