def test_ill_opts(self):
     """Check that each invalid constructor option makes the loss raise ``ValueError``.

     The loss is both constructed and applied inside the ``assertRaises``
     context, so the error may surface at construction time or during the call.
     """
     pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float)
     target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float)
     # Every case keeps in_channels=3 and identical pred/target shapes, so
     # presumably only the listed option is responsible for the error.
     ill_options = (
         {"kernel_type": "unknown"},
         {"kernel_type": None},
         {"kernel_size": 4},  # presumably rejected because it is even -- confirm
         {"reduction": "unknown"},
         {"reduction": None},
     )
     for opts in ill_options:
         # subTest labels the offending option in the failure report; plain
         # assertRaises replaces assertRaisesRegex with an empty (match-all) pattern.
         with self.subTest(**opts), self.assertRaises(ValueError):
             LocalNormalizedCrossCorrelationLoss(in_channels=3, **opts)(pred, target)
 def test_ill_shape(self):
     """Check that inputs whose shapes disagree with the loss config or each other raise ``ValueError``."""
     loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3)
     ill_shapes = {
         # channel size 2 does not match in_channels=3
         "in_channels mismatch": ((1, 2, 3, 3, 3), (1, 2, 3, 3, 3)),
         # 4-D input although the loss was built with ndim=3
         "ndim mismatch": ((1, 3, 3, 3), (1, 3, 3, 3)),
         # pred and target have different spatial sizes
         "pred/target shape mismatch": ((1, 3, 3, 3, 3), (1, 3, 4, 4, 4)),
     }
     for case, (pred_shape, target_shape) in ill_shapes.items():
         # Build tensors outside assertRaises so only the loss call can satisfy it.
         pred = torch.ones(pred_shape, dtype=torch.float)
         target = torch.ones(target_shape, dtype=torch.float)
         with self.subTest(case), self.assertRaises(ValueError):
             # Invoke via __call__ (as normal callers do) rather than .forward;
             # presumably the loss is a torch.nn.Module, so this also runs hooks.
             loss(pred, target)
 def test_shape(self, input_param, input_data, expected_val):
     """Compare the loss value computed from ``input_data`` against ``expected_val``.

     Despite the name, this checks the numeric result (rtol=1e-5), not only
     its shape. ``input_param`` configures the loss; ``input_data`` is passed
     as keyword arguments to ``forward``.

     NOTE(review): the three data arguments are presumably supplied by a
     parameterization decorator outside this view -- confirm.
     """
     loss_fn = LocalNormalizedCrossCorrelationLoss(**input_param)
     computed = loss_fn.forward(**input_data).detach().cpu().numpy()
     np.testing.assert_allclose(computed, expected_val, rtol=1e-5)