def test_one_cls_dice_one_cls(self):
    """
    Check that one_cls_dice stays below 1 for unrelated random data.

    The prediction is a two-channel volume of uniform random scores and the
    target is an independent random volume of labels drawn from {0, 1}, so
    neither label is expected to reach a perfect dice of 1.
    """
    n_batch, n_channels = 5, 2  # two channels -> labels 0 and 1
    volume_shape = (10, 10, 10)

    scores = np.random.rand(n_batch, n_channels, *volume_shape)
    labels = np.random.randint(0, 2, (n_batch, *volume_shape), 'int')

    prediction = Variable(torch.Tensor(scores))
    ground_truth = Variable(torch.Tensor(labels))

    dice_label_1 = one_cls_dice(prediction, ground_truth, 1)
    dice_label_0 = one_cls_dice(prediction, ground_truth, 0)

    assert dice_label_1 < 1
    assert dice_label_0 < 1
예제 #2
0
def dice_cat2(output, target):
    """
    Calculate the dice score for the category with label index 2.

    Thin wrapper that forwards to one_cls_dice with label_idx fixed to 2.

    :param output: Output dimension: Batch x Channel x X x Y (x Z) float
    :param target: Target dimension: Batch x X x Y (x Z) int:[0, Channel]
    :return: whatever one_cls_dice returns for label index 2
    """
    category_index = 2
    return one_cls_dice(output, target, label_idx=category_index)
    def test_one_cls_dice_identical(self):
        """
        Check one_cls_dice when the target is the argmax of the output.

        The target is built with np.argmax over the channel axis of the
        output. The test expects a dice of 1 for every label index below the
        channel count and a dice of 0 for label indices that argmax can never
        produce because the output has too few channels.
        """
        batch = 5
        x = 10
        y = 10
        z = 10
        tol = 0.001

        # Two channels: argmax only yields labels 0 and 1, so labels 2 and 3
        # never occur in the target.
        class_num = 2
        output = np.random.rand(batch, class_num, x, y, z)
        target = np.argmax(output, 1)

        v_output = Variable(torch.Tensor(output))
        v_target = Variable(torch.Tensor(target))
        dice1 = one_cls_dice(v_output, v_target, 1)
        dice2 = one_cls_dice(v_output, v_target, 2)
        dice3 = one_cls_dice(v_output, v_target, 3)
        dice0 = one_cls_dice(v_output, v_target, 0)

        assert math.isclose(dice1, 1, rel_tol=tol)
        # A relative tolerance scales with the magnitude of the operands, so
        # against an expected value of 0 it only accepts exact equality.
        # abs_tol makes the "close to zero" check meaningful.
        assert math.isclose(dice2, 0, rel_tol=tol, abs_tol=tol)
        assert math.isclose(dice3, 0, rel_tol=tol, abs_tol=tol)
        assert math.isclose(dice0, 1, rel_tol=tol)

        # Four channels: labels 0..3 are all possible argmax results.
        class_num = 4
        output = np.random.rand(batch, class_num, x, y, z)
        target = np.argmax(output, 1)

        v_output = Variable(torch.Tensor(output))
        v_target = Variable(torch.Tensor(target))
        dice1 = one_cls_dice(v_output, v_target, 1)
        dice2 = one_cls_dice(v_output, v_target, 2)
        dice3 = one_cls_dice(v_output, v_target, 3)
        dice0 = one_cls_dice(v_output, v_target, 0)

        assert math.isclose(dice1, 1, rel_tol=tol)
        assert math.isclose(dice2, 1, rel_tol=tol)
        assert math.isclose(dice3, 1, rel_tol=tol)
        assert math.isclose(dice0, 1, rel_tol=tol)