def test_ill_opts(self):
    """Invalid loss options must raise ``ValueError``.

    Each case supplies exactly one bad option and then calls the loss on
    matching ``(1, 3, 3, 3, 3)`` inputs. The bad options are an unknown or
    ``None`` kernel type, ``kernel_size=4``, and an unknown or ``None``
    reduction. The error may come from construction or from the call; both
    happen inside the assertion context.
    """
    pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float)
    target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float)
    invalid_options = (
        {"kernel_type": "unknown"},
        {"kernel_type": None},
        {"kernel_size": 4},
        {"reduction": "unknown"},
        {"reduction": None},
    )
    for options in invalid_options:
        # subTest reports every failing option rather than aborting at the first one.
        with self.subTest(**options):
            # An empty regex in assertRaisesRegex matches any message; plain
            # assertRaises expresses the same expectation without the no-op pattern.
            with self.assertRaises(ValueError):
                LocalNormalizedCrossCorrelationLoss(in_channels=3, **options)(pred, target)
def test_ill_shape(self):
    """Inputs incompatible with the configured loss must raise ``ValueError``.

    The loss is built with ``in_channels=3`` and ``ndim=3``. Each call below
    breaks exactly one requirement: channel count, tensor rank, or
    pred/target shape agreement.
    """
    loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3)

    # in_channels mismatch: channel dimension is 2, the loss was built for 3.
    wrong_channels = torch.ones((1, 2, 3, 3, 3), dtype=torch.float)
    with self.assertRaises(ValueError):
        loss(wrong_channels, wrong_channels.clone())

    # ndim mismatch: 4-D tensors passed to a loss configured with ndim=3.
    wrong_rank = torch.ones((1, 3, 3, 3), dtype=torch.float)
    with self.assertRaises(ValueError):
        loss(wrong_rank, wrong_rank.clone())

    # pred/target shape mismatch: spatial sizes differ (3 vs 4).
    with self.assertRaises(ValueError):
        loss(
            torch.ones((1, 3, 3, 3, 3), dtype=torch.float),
            torch.ones((1, 3, 4, 4, 4), dtype=torch.float),
        )
def test_shape(self, input_param, input_data, expected_val):
    """Check the loss value produced for one parameterized case.

    Args:
        input_param: keyword arguments forwarded to the loss constructor.
        input_data: keyword arguments forwarded to the loss call
            (presumably ``pred``/``target`` tensors; confirm against the cases).
        expected_val: expected loss value, compared with ``rtol=1e-5``.
    """
    loss = LocalNormalizedCrossCorrelationLoss(**input_param)
    # Call the module instance rather than ``.forward`` so registered hooks run.
    result = loss(**input_data)
    np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)