コード例 #1
0
ファイル: test_masked_loss.py プロジェクト: Nic-Ma/MONAI
    def test_ill_opts(self):
        """MaskedLoss must raise ``ValueError`` for invalid options and mismatched shapes.

        Covers three cases:
        * ``loss=[]`` (an empty list is not a callable loss) at construction time;
        * input/target with 1 channel but a mask with 3 channels;
        * input with 3 channels but target with 2 channels.
        """
        with self.assertRaises(ValueError):
            MaskedLoss(loss=[])

        dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5)
        # Construct the wrapper outside the assertion contexts so that only the
        # forward call is expected to raise; previously a constructor error would
        # also have satisfied these checks.
        masked = MaskedLoss(loss=dice_loss)
        with self.assertRaises(ValueError):
            masked(input=torch.zeros((3, 1, 2, 2)), target=torch.zeros((3, 1, 2, 2)), mask=torch.zeros((3, 3, 2, 2)))
        with self.assertRaises(ValueError):
            masked(input=torch.zeros((3, 3, 2, 2)), target=torch.zeros((3, 2, 2, 2)), mask=torch.zeros((3, 3, 2, 2)))
コード例 #2
0
 def test_script(self):
     """Run ``test_script_save`` on a MaskedLoss built from the first test case.

     ``test_script_save`` is presumably the project's TorchScript save/load
     round-trip helper (TODO confirm); only the constructor parameters of
     ``TEST_CASES[0]`` are used, so its expected values are discarded.
     """
     input_param, _ = TEST_CASES[0]
     size = [3, 3, 5, 5]
     # Random 0/1 tensor reduced with argmax over dim 1 to a per-pixel channel
     # index map; keepdim=True keeps a singleton channel dimension.
     label = torch.randint(low=0, high=2, size=size)
     label = torch.argmax(label, dim=1, keepdim=True)
     pred = torch.randn(size)
     loss = MaskedLoss(**input_param)
     test_script_save(loss, pred, label)
コード例 #3
0
ファイル: test_masked_loss.py プロジェクト: Nic-Ma/MONAI
    def test_shape(self, input_param, expected_val):
        """Compare MaskedLoss output, without and then with a random mask, to expected values.

        ``expected_val[0]`` holds the two acceptable results for the unmasked call and
        ``expected_val[1]`` the two acceptable results for the masked call; matching
        either candidate (via ``np.allclose``) is sufficient.
        """
        shape = [3, 3, 5, 5]
        # Random 0/1 tensor reduced with argmax over dim 1 to a per-pixel channel
        # index map; keepdim=True keeps a singleton channel dimension.
        target = torch.argmax(torch.randint(low=0, high=2, size=shape), dim=1, keepdim=True)
        logits = torch.randn(shape)

        def _matches(mask, candidates):
            # Fresh MaskedLoss per evaluation, result moved to a NumPy array.
            value = MaskedLoss(**input_param)(logits, target, mask).detach().cpu().numpy()
            return np.allclose(value, candidates[0]) or np.allclose(value, candidates[1])

        self.assertTrue(_matches(None, expected_val[0]))

        # The mask is drawn only after the unmasked evaluation, keeping RNG order unchanged.
        random_mask = torch.randint(low=0, high=2, size=target.shape)
        self.assertTrue(_matches(random_mask, expected_val[1]))
コード例 #4
0
 def __init__(self, *args, **kwargs) -> None:
     """
     Args follow :py:class:`monai.losses.DiceLoss`.

     All arguments are forwarded unchanged to the parent initialiser. The parent's
     ``forward`` is then wrapped in a ``MaskedLoss`` and stored as
     ``self.spatial_weighted``.
     """
     super().__init__(*args, **kwargs)
     # super().forward resolves past this class in the MRO, so the wrapper calls
     # the base-class loss rather than any forward override defined here.
     parent_forward = super().forward
     self.spatial_weighted = MaskedLoss(loss=parent_forward)