예제 #1
0
 def test_ill_opts(self):
     """Check that invalid ``reduction`` options raise ``ValueError``.

     Both the unrecognised string ``"unknown"`` and ``None`` are tried, each
     applied to the same random ``(1, 3, 5, 5, 5)`` prediction. The error may
     come from either construction or the call; both happen inside the
     context manager.
     """
     pred = torch.rand(1, 3, 5, 5, 5)
     for bad_reduction in ("unknown", None):
         # An empty-pattern regex matches any message, so plain assertRaises
         # expresses the same expectation.
         with self.assertRaises(ValueError):
             BendingEnergyLoss(reduction=bad_reduction)(pred)
예제 #2
0
 def test_shape(self, input_param, input_data, expected_val):
     """Compare the computed loss with ``expected_val``.

     ``input_param`` is passed as keyword arguments to the loss constructor.
     ``input_data`` is passed as keyword arguments to ``forward``. The result
     is moved to a NumPy array and must match within a relative tolerance of
     1e-5. Despite the name, this checks values, not only shapes.
     NOTE(review): presumably parameterized by a decorator not visible here.
     """
     loss_fn = BendingEnergyLoss(**input_param)
     computed = loss_fn.forward(**input_data)
     np.testing.assert_allclose(
         computed.detach().cpu().numpy(), expected_val, rtol=1e-5
     )
예제 #3
0
 def test_ill_shape(self):
     """Check that malformed prediction shapes make ``forward`` raise ``ValueError``."""
     loss = BendingEnergyLoss()
     bad_shapes = (
         # tensor rank is not 3, 4 or 5
         (1, 3),
         (1, 3, 5, 5, 5, 5),
         # one spatial dimension is smaller than 5 (each case shrinks a different axis)
         (1, 3, 4, 5, 5),
         (1, 3, 5, 4, 5),
         (1, 3, 5, 5, 4),
     )
     for shape in bad_shapes:
         # Any ValueError message is accepted, as with an empty-pattern regex.
         with self.assertRaises(ValueError):
             loss.forward(torch.ones(shape))
예제 #4
0
    def test_ill_shape(self):
        """Check that malformed predictions raise ``ValueError`` with the expected message.

        Three failure modes are covered:
        - wrong tensor rank;
        - a spatial dimension that is too small;
        - a channel (vector-component) count that differs from the number of
          spatial dimensions.

        Every input is created on the module-level ``device``, so all cases
        exercise the same device.
        """
        loss = BendingEnergyLoss()
        # not in 3-d, 4-d, 5-d
        with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
            loss.forward(torch.ones((1, 3), device=device))
        with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
            loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
        # a spatial dimension of size 4 (each case shrinks a different axis)
        with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
            loss.forward(torch.ones((1, 3, 4, 5, 5), device=device))
        with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
            loss.forward(torch.ones((1, 3, 5, 4, 5), device=device))
        with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
            loss.forward(torch.ones((1, 3, 5, 5, 4), device=device))

        # number of vector components unequal to number of spatial dims
        # 2 components for 3 spatial dims
        with self.assertRaisesRegex(ValueError, "Number of vector components"):
            loss.forward(torch.ones((1, 2, 5, 5, 5), device=device))
        # 3 components for 2 spatial dims: the reverse mismatch, so the two
        # checks are not duplicates
        with self.assertRaisesRegex(ValueError, "Number of vector components"):
            loss.forward(torch.ones((1, 3, 5, 5), device=device))