def test_ill_shape(self):
    """Check that mismatched shapes and bad ``reduction`` values are rejected.

    The test expects ``AssertionError`` when ``forward`` receives tensors of
    shapes (2, 2, 3) and (4, 5, 6), and ``ValueError`` when a loss built with
    ``reduction="unknown"`` or ``reduction=None`` is called.
    """
    default_loss = TverskyLoss()
    mismatched_pred = torch.ones((2, 2, 3))
    mismatched_target = torch.ones((4, 5, 6))
    # The empty pattern matches any message: only the exception type is checked.
    with self.assertRaisesRegex(AssertionError, ""):
        default_loss.forward(mismatched_pred, mismatched_target)

    chn_input = torch.ones((1, 1, 3))
    chn_target = torch.ones((1, 1, 3))
    for bad_reduction in ("unknown", None):
        with self.assertRaisesRegex(ValueError, ""):
            TverskyLoss(reduction=bad_reduction)(chn_input, chn_target)
def test_input_warnings(self):
    """Check that each listed option emits a warning on (1, 1, 3) inputs.

    For ``include_background=False``, ``softmax=True`` and ``to_onehot_y=True``
    in turn, building the loss and running ``forward`` must raise a ``Warning``.
    """
    chn_input = torch.ones((1, 1, 3))
    chn_target = torch.ones((1, 1, 3))
    warning_configs = (
        {"include_background": False},
        {"softmax": True},
        {"to_onehot_y": True},
    )
    for config in warning_configs:
        # Construction stays inside the context so a warning raised by either
        # the constructor or ``forward`` satisfies the check.
        with self.assertWarns(Warning):
            criterion = TverskyLoss(**config)
            criterion.forward(chn_input, chn_target)
def test_shape(self, input_param, input_data, expected_val):
    """Compare the scalar loss against ``expected_val`` to 4 decimal places.

    Args:
        input_param: keyword arguments forwarded to the ``TverskyLoss`` constructor.
        input_data: keyword arguments forwarded to ``forward``.
        expected_val: reference value for the loss.

    NOTE(review): the arguments are presumably supplied by a parameterizing
    decorator that is not visible here - confirm.
    """
    criterion = TverskyLoss(**input_param)
    loss_value = criterion.forward(**input_data).item()
    self.assertAlmostEqual(loss_value, expected_val, places=4)
def test_shape(self, input_param, input_data, expected_val):
    """Compare the loss output against ``expected_val`` with ``rtol=1e-4``.

    Args:
        input_param: keyword arguments forwarded to the ``TverskyLoss`` constructor.
        input_data: keyword arguments forwarded to ``forward``.
        expected_val: reference value(s) the result must be close to.

    NOTE(review): the arguments are presumably supplied by a parameterizing
    decorator that is not visible here - confirm.
    """
    criterion = TverskyLoss(**input_param)
    loss_tensor = criterion.forward(**input_data)
    # Detach and move to host memory so the result can be compared as a numpy array.
    actual = loss_tensor.detach().cpu().numpy()
    np.testing.assert_allclose(actual, expected_val, rtol=1e-4)
def test_script(self):
    """Pass a default ``TverskyLoss`` to ``test_script_save`` with all-ones data.

    The same (2, 1, 8, 8) tensor is used for both tensor arguments.
    """
    sample = torch.ones(2, 1, 8, 8)
    criterion = TverskyLoss()
    test_script_save(criterion, sample, sample)