def test_ill_opts(self):
    """Unsupported ``reduction`` options must be rejected with a ValueError.

    The error may come from the constructor or from the call; both happen
    inside the ``assertRaises`` context, so either is accepted.
    """
    displacement = torch.rand(1, 3, 5, 5, 5)
    for bad_reduction in ("unknown", None):
        # assertRaises(ValueError) is equivalent to assertRaisesRegex with an
        # empty pattern: any ValueError message is accepted.
        with self.assertRaises(ValueError):
            BendingEnergyLoss(reduction=bad_reduction)(displacement)
def test_shape(self, input_param, input_data, expected_val):
    """Compare the loss for one parameterized case against a reference value.

    ``input_param`` is unpacked into the ``BendingEnergyLoss`` constructor and
    ``input_data`` into ``forward`` as keyword arguments. The resulting tensor
    is moved to host memory and checked against ``expected_val`` with a
    relative tolerance of 1e-5.
    """
    loss_fn = BendingEnergyLoss(**input_param)
    computed = loss_fn.forward(**input_data)
    computed_np = computed.detach().cpu().numpy()
    np.testing.assert_allclose(computed_np, expected_val, rtol=1e-5)
def test_ill_shape(self):
    """Inputs with an unsupported rank or a too-small spatial size raise ValueError.

    NOTE(review): a second method named ``test_ill_shape`` is defined right
    after this one. If both belong to the same class, the later definition
    replaces this one and these checks never run. Confirm, then rename or
    remove one of them.
    """
    loss = BendingEnergyLoss()
    bad_shapes = (
        # not in 3-d, 4-d, 5-d
        (1, 3),
        (1, 3, 5, 5, 5, 5),
        # spatial_dim < 5 (one too-small axis at a time)
        (1, 3, 4, 5, 5),
        (1, 3, 5, 4, 5),
        (1, 3, 5, 5, 4),
    )
    for shape in bad_shapes:
        # Equivalent to assertRaisesRegex(ValueError, ""): any message matches.
        with self.assertRaises(ValueError):
            loss.forward(torch.ones(shape))
def test_ill_shape(self):
    """Malformed inputs to ``BendingEnergyLoss.forward`` raise specific ValueErrors.

    Three failure modes are covered, each matched against a fragment of its
    error message:
      * tensor rank outside 3-d, 4-d or 5-d;
      * a spatial dimension that is too small (size 4 on any one axis);
      * a channel count that differs from the number of spatial dimensions,
        checked in both directions.

    ``device`` is a module-level name defined outside this view. It is now
    passed to every tensor for consistency.
    """
    loss = BendingEnergyLoss()
    # not in 3-d, 4-d, 5-d
    with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
        loss.forward(torch.ones((1, 3), device=device))
    with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
        loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
    # a spatial dimension of size 4 is rejected, whichever axis it is on
    with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
        loss.forward(torch.ones((1, 3, 4, 5, 5), device=device))
    with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
        loss.forward(torch.ones((1, 3, 5, 4, 5), device=device))
    with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
        loss.forward(torch.ones((1, 3, 5, 5, 4), device=device))
    # number of vector components unequal to number of spatial dims
    # fewer components than spatial dims: 2 channels on a 3-d volume
    with self.assertRaisesRegex(ValueError, "Number of vector components"):
        loss.forward(torch.ones((1, 2, 5, 5, 5), device=device))
    # more components than spatial dims: 3 channels on a 2-d image. This
    # replaces a byte-identical repeat of the case above, which added no
    # coverage. Rank 4 and spatial size 5 pass the earlier checks (see the
    # assertions above), so only the channel mismatch can trigger here.
    with self.assertRaisesRegex(ValueError, "Number of vector components"):
        loss.forward(torch.ones((1, 3, 5, 5), device=device))