def test_vsi_loss(input_tensors: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """Backward pass through VSILoss must yield finite gradients on the input.

    Guards against NaN/Inf appearing in the gradient, which would signal
    numerical instability somewhere inside the VSI computation.
    """
    x, y = input_tensors
    x.requires_grad_()
    criterion = VSILoss(data_range=1.)
    loss = criterion(x.to(device), y.to(device))
    loss.backward()
    # Any non-finite entry in the gradient fails the whole check.
    assert torch.isfinite(x.grad).all(), \
        f'Expected finite gradient values after back propagation, got {x.grad}'
def test_vsi_loss_zero_for_equal_input(input_tensors: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """VSILoss of a tensor against its own clone must be numerically zero."""
    x, _ = input_tensors
    # Compare the tensor with an independent copy of itself.
    y = x.clone()
    x.requires_grad_()
    criterion = VSILoss(data_range=1.)
    loss = criterion(x.to(device), y.to(device))
    # isclose (rather than exact equality) tolerates floating-point round-off.
    assert torch.isclose(loss, torch.zeros_like(loss)), \
        f'Expected loss equals zero for identical inputs, got {loss}'