def test_brisque_loss_raises_if_wrong_reduction(x_grey: torch.Tensor, device: str) -> None:
    """BRISQUELoss runs with every supported reduction and rejects unknown ones with KeyError."""
    tensor = x_grey.to(device)
    # All supported reduction modes must evaluate without raising.
    for reduction in ('mean', 'sum', 'none'):
        BRISQUELoss(reduction=reduction)(tensor)
    # Unsupported reduction values must surface as KeyError.
    for reduction in (None, 'n', 2):
        with pytest.raises(KeyError):
            BRISQUELoss(reduction=reduction)(tensor)
def test_brisque_loss_if_works_with_rgb(x_rgb: torch.Tensor, device: str) -> None:
    """BRISQUELoss on an RGB input is differentiable: backward yields a finite gradient.

    Fix: the original asserted `torch.isfinite(grad).all()` but its failure message
    claimed "non None gradient" — and if `.grad` were actually None, `torch.isfinite`
    would raise TypeError instead of failing the assert. Check None-ness explicitly,
    then finiteness, each with a message matching its own condition.
    """
    x_rgb_grad = x_rgb.clone().to(device)
    x_rgb_grad.requires_grad_()
    loss_value = BRISQUELoss()(x_rgb_grad)
    loss_value.backward()
    # Backward must populate the leaf tensor's gradient before values are inspected.
    assert x_rgb_grad.grad is not None, 'Expected non None gradient of leaf variable'
    assert torch.isfinite(x_rgb_grad.grad).all(), \
        f'Expected finite gradient values, got {x_rgb_grad.grad}'
def test_brisque_loss_if_works_with_grey(prediction_grey: torch.Tensor, device: str) -> None:
    """BRISQUELoss on a greyscale input is differentiable: backward yields a finite gradient.

    Fix: the original asserted `torch.isfinite(grad).all()` but its failure message
    claimed "non None gradient" — and if `.grad` were actually None, `torch.isfinite`
    would raise TypeError instead of failing the assert. Check None-ness explicitly,
    then finiteness, each with a message matching its own condition.
    """
    prediction_grey_grad = prediction_grey.clone().to(device)
    prediction_grey_grad.requires_grad_()
    loss_value = BRISQUELoss()(prediction_grey_grad)
    loss_value.backward()
    # Backward must populate the leaf tensor's gradient before values are inspected.
    assert prediction_grey_grad.grad is not None, 'Expected non None gradient of leaf variable'
    assert torch.isfinite(prediction_grey_grad.grad).all(), \
        f'Expected finite gradient values, got {prediction_grey_grad.grad}'