def test_focal_loss_input():
    """Check that FocalLoss rejects labels whose shape does not match the logits.

    The logits are a (3, 2) float32 tensor while the labels are a single
    (1, 1) int32 entry; calling the loss on this pair must raise ValueError.
    """
    logits = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
    # One label row against three logit rows -- deliberately mismatched.
    labels = Tensor([[1]], mstype.int32)
    loss_fn = nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')
    with pytest.raises(ValueError):
        loss_fn(logits, labels)
def test_focal_loss():
    """Smoke-test FocalLoss built with its default constructor arguments.

    Feeds (3, 2) float32 logits and (3, 1) int32 labels and only checks
    that the forward call completes without raising.
    """
    logits = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
    labels = Tensor([[1], [1], [0]], mstype.int32)
    # No assertion on the returned value: success means no exception.
    nn.FocalLoss()(logits, labels)