示例#1
0
    def test_ignore_index_loss_with_dice_coeff(self):
        """Cells carrying ``ignore_index`` should not change the Dice score.

        The prediction is a 3x3 map of zeros with a single 1 at the centre.
        It is scored twice: once against a target whose every non-centre cell
        holds the ignore label (-1), and once against a copy of the prediction
        itself. The test asserts the two scores are equal.
        """
        criterion = DiceCoefficient(ignore_index=-1)

        prediction = torch.zeros((3, 3))
        prediction[1, 1] = 1.

        # All cells are labelled -1 (== ignore_index) except the centre.
        masked_target = torch.full((3, 3), -1.0)
        masked_target[1, 1] = 1.
        actual = criterion(prediction, masked_target)

        # Reference score: a target that matches the prediction everywhere.
        expected = criterion(prediction, prediction.clone())

        assert expected == actual
示例#2
0
 def test_dice_coefficient(self):
     """Check that every Dice coefficient lies in the closed interval [0, 1].

     ``_compute_criterion`` is defined outside this method; it presumably
     returns a sequence of scalar scores convertible by NumPy (TODO confirm).
     """
     results = np.asarray(_compute_criterion(DiceCoefficient()))
     # np.all() over an empty array is True, so make sure there is something
     # to check rather than passing vacuously.
     assert results.size > 0
     # Inclusive bounds: a score of exactly 0 or exactly 1 is inside the
     # documented range and must not fail the test.
     assert np.all(results >= 0)
     assert np.all(results <= 1)