Example #1
0
def test_multi_class_dice_loss():
    """Smoke-test MultiClassDiceLoss with a softmax activation.

    Predictions and labels are both 3x2, so the shapes agree; the test only
    checks that evaluating the loss completes without raising.
    """
    pred_values = np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])
    label_values = np.array([[0, 1], [1, 0], [0, 1]])
    # NOTE(review): `ignore_indiex` spelling is kept exactly as passed in the
    # original; presumably it matches the constructor's keyword -- verify
    # against the library before "correcting" it.
    dice_loss = nn.MultiClassDiceLoss(weights=None,
                                      ignore_indiex=None,
                                      activation="softmax")
    predictions = Tensor(pred_values, mstype.float32)
    labels = Tensor(label_values, mstype.float32)
    dice_loss(predictions, labels)
Example #2
0
def test_multi_class_dice_loss_init_activation2():
    """Check that an unknown activation name is rejected at construction.

    ``activation='www'`` is not a recognised activation name, so building
    ``MultiClassDiceLoss`` is expected to raise ``KeyError``.

    The ``pytest.raises`` scope is limited to the constructor on purpose.
    Any code after a raising constructor would be unreachable. If the
    constructor ever stopped raising, the test should fail with a clear
    "DID NOT RAISE" rather than with an unrelated error from a later step.
    """
    with pytest.raises(KeyError):
        nn.MultiClassDiceLoss(weights=None,
                              ignore_indiex=None,
                              activation='www')
Example #3
0
def test_multi_class_dice_loss_check_shape():
    """Check that MultiClassDiceLoss rejects mismatched input shapes.

    Construction must succeed. Calling the loss with 3x2 predictions and
    2x2 labels must raise ValueError.
    """
    criterion = nn.MultiClassDiceLoss(weights=None,
                                      ignore_indiex=None,
                                      activation="softmax")
    logits_3x2 = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]),
                        mstype.float32)
    targets_2x2 = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
    # Only the forward call is inside the raises-block: the shape check is
    # expected to fire when the loss is evaluated, not when it is built.
    with pytest.raises(ValueError):
        criterion(logits_3x2, targets_2x2)