コード例 #1
0
ファイル: test_loss.py プロジェクト: chncwang/mindspore
def test_multi_class_dice_loss():
    """ test_multi_class_dice_loss """
    loss = nn.MultiClassDiceLoss(weights=None,
                                 ignore_indiex=None,
                                 activation="softmax")
    y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]),
                    mstype.float32)
    y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
    loss(y_pred, y)
コード例 #2
0
ファイル: test_loss.py プロジェクト: chncwang/mindspore
def test_multi_class_dice_loss_init_activation2():
    """ test_multi_class_dice_loss """
    with pytest.raises(KeyError):
        loss = nn.MultiClassDiceLoss(weights=None,
                                     ignore_indiex=None,
                                     activation='www')
        y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]),
                        mstype.float32)
        y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
        loss(y_pred, y)
コード例 #3
0
ファイル: test_loss.py プロジェクト: chncwang/mindspore
def test_multi_class_dice_loss_check_shape():
    """ test_multi_class_dice_loss """
    loss = nn.MultiClassDiceLoss(weights=None,
                                 ignore_indiex=None,
                                 activation="softmax")
    y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]),
                    mstype.float32)
    y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
    with pytest.raises(ValueError):
        loss(y_pred, y)