Exemplo n.º 1
0
def test_brisque_loss_raises_if_wrong_reduction(x_grey: torch.Tensor, device: str) -> None:
    """Check that BRISQUELoss accepts only its supported reduction modes.

    Each of 'mean', 'sum' and 'none' must evaluate without raising; every
    other value tried here (None, an unknown string, an int) is expected to
    raise ``KeyError``.
    """
    supported = ('mean', 'sum', 'none')
    unsupported = (None, 'n', 2)

    # Every supported reduction must evaluate cleanly.
    for reduction in supported:
        BRISQUELoss(reduction=reduction)(x_grey.to(device))

    # Each unsupported value gets its own `raises` context so that one
    # accepted value cannot hide behind another that does raise.
    for reduction in unsupported:
        with pytest.raises(KeyError):
            BRISQUELoss(reduction=reduction)(x_grey.to(device))
Exemplo n.º 2
0
def test_brisque_loss_if_works_with_rgb(x_rgb: torch.Tensor, device: str) -> None:
    """Check that BRISQUELoss back-propagates a finite gradient for RGB input.

    Args:
        x_rgb: RGB image fixture; it is copied, so the fixture is never mutated.
        device: Device the copy is moved to before evaluating the loss.
    """
    # detach() first so the copy is always a fresh leaf tensor, even if the
    # fixture itself requires grad; otherwise `.grad` would stay None.
    x_rgb_grad = x_rgb.detach().clone().to(device)
    x_rgb_grad.requires_grad_()
    loss_value = BRISQUELoss()(x_rgb_grad)
    loss_value.backward()
    # Check for None separately: torch.isfinite(None) would raise a TypeError
    # instead of producing a readable assertion failure.
    assert x_rgb_grad.grad is not None, 'Expected non None gradient of leaf variable'
    assert torch.isfinite(x_rgb_grad.grad).all(), \
        f'Expected finite gradient of leaf variable, got {x_rgb_grad.grad}'
Exemplo n.º 3
0
def test_brisque_loss_if_works_with_grey(prediction_grey: torch.Tensor,
                                         device: str) -> None:
    """Check that BRISQUELoss back-propagates a finite gradient for greyscale input.

    Args:
        prediction_grey: Greyscale image fixture; it is copied, so the fixture
            is never mutated.
        device: Device the copy is moved to before evaluating the loss.
    """
    # detach() first so the copy is always a fresh leaf tensor, even if the
    # fixture itself requires grad; otherwise `.grad` would stay None.
    prediction_grey_grad = prediction_grey.detach().clone().to(device)
    prediction_grey_grad.requires_grad_()
    loss_value = BRISQUELoss()(prediction_grey_grad)
    loss_value.backward()
    # Check for None separately: torch.isfinite(None) would raise a TypeError
    # instead of producing a readable assertion failure.
    assert prediction_grey_grad.grad is not None, 'Expected non None gradient of leaf variable'
    assert torch.isfinite(prediction_grey_grad.grad).all(), \
        f'Expected finite gradient of leaf variable, got {prediction_grey_grad.grad}'