def test_input_warnings(self):
    """Check that certain constructor options warn on a one-channel input.

    For each option set below, a ``GeneralizedDiceLoss`` is built and its
    ``forward`` is called on a pair of all-ones tensors of shape (1, 1, 3).
    The test expects every such call to emit a ``Warning``. The option sets
    are checked in a fixed order, and the test stops at the first one that
    does not warn.
    """
    # Shape (1, 1, 3): presumably batch=1, a single channel, three elements
    # (the original names used a "chn" prefix). TODO confirm axis meaning.
    single_channel_pred = torch.ones((1, 1, 3))
    single_channel_label = torch.ones((1, 1, 3))
    warning_option_sets = (
        {"include_background": False},
        {"softmax": True},
        {"to_onehot_y": True},
    )
    for options in warning_option_sets:
        # Construction stays inside the context so that a warning raised
        # by __init__ would also satisfy the assertion.
        with self.assertWarns(Warning):
            criterion = GeneralizedDiceLoss(**options)
            criterion.forward(single_channel_pred, single_channel_label)
def test_ill_shape(self):
    """Check that ``forward`` rejects prediction/target tensors of mismatched shape.

    A default-constructed ``GeneralizedDiceLoss`` is given a (1, 2, 3)
    prediction and a (4, 5, 6) target. The call is expected to raise
    ``AssertionError``.
    """
    criterion = GeneralizedDiceLoss()
    # The original used assertRaisesRegex with an empty pattern. That
    # pattern matches any message, so only the exception type matters here.
    with self.assertRaises(AssertionError):
        prediction = torch.ones((1, 2, 3))
        target = torch.ones((4, 5, 6))
        criterion.forward(prediction, target)