def test_mdsi_loss(input_tensors: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """Check that MDSILoss backpropagates finite gradients into its first input.

    The first tensor from ``input_tensors`` is marked as requiring grad, both
    tensors are moved to ``device``, and the loss (``data_range=1.``) is
    differentiated. The gradient is then checked for finiteness.

    NOTE(review): a second ``def test_mdsi_loss`` with the same signature
    follows this one in the file. If both are kept, the later definition
    rebinds the name and this variant is never collected by pytest. Confirm
    which keyword API (``x``/``y`` vs ``prediction``/``target``) MDSILoss
    currently accepts, then rename or delete accordingly.
    """
    first, second = input_tensors
    first.requires_grad_()
    criterion = MDSILoss(data_range=1.)
    loss = criterion(x=first.to(device), y=second.to(device))
    loss.backward()
    assert torch.isfinite(first.grad).all(), f'Expected finite gradient values, got {first.grad}'
def test_mdsi_loss(input_tensors: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """Ensure gradients flowing back from MDSILoss into the prediction are finite.

    The prediction is marked as requiring grad. Both tensors are moved to
    ``device`` and passed by keyword (``prediction=``/``target=``) to
    ``MDSILoss(data_range=1.)``. Back-propagation must yield only finite values.
    """
    pred, tgt = input_tensors
    pred.requires_grad_()
    loss_fn = MDSILoss(data_range=1.)
    pred_on_device = pred.to(device)
    tgt_on_device = tgt.to(device)
    loss = loss_fn(prediction=pred_on_device, target=tgt_on_device)
    loss.backward()
    grad = pred.grad
    assert torch.isfinite(grad).all(), f'Expected finite gradient values, got {grad}'
def test_mdsi_loss_compare_with_matlab(
        input_images_score: Tuple[torch.Tensor, torch.Tensor, dict],
        combination: str,
        device: str) -> None:
    """Compare MDSILoss against a MATLAB reference value and check its gradient.

    ``input_images_score`` unpacks into two images and a mapping of reference
    values. The mapping is indexed by ``combination``, which is why the third
    element is annotated as ``dict`` rather than ``torch.Tensor``. The test
    expects the loss to equal one minus the selected reference value, and the
    gradient w.r.t. ``x`` to be finite.

    NOTE(review): ``data_range=255`` suggests the fixture images are in
    [0, 255]; confirm against the fixture.
    """
    x, y, reference_values = input_images_score
    y_value = reference_values[combination]
    # In-place: marks the leaf tensor ``x`` itself, so ``x.grad`` is populated
    # even after ``x.to(device)`` produces a (possibly different) device copy.
    x.requires_grad_()
    score = MDSILoss(data_range=255, combination=combination)(x=x.to(device), y=y.to(device))
    score.backward()
    # ``.to(score)`` aligns dtype/device of the reference with the loss output.
    assert torch.isclose(score, 1. - y_value.to(score)), (
        f'The estimated value must be equal to MATLAB '
        f'provided one, got {score.item():.8f}, '
        f'while MATLAB equals {1. - y_value}'
    )
    assert torch.isfinite(x.grad).all(), f'Expected finite gradient values, got {x.grad}'