Пример #1
0
def test_vsi_loss(input_tensors: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """Check that back-propagating through ``VSILoss`` yields finite gradients.

    Args:
        input_tensors: pair ``(x, y)`` supplied by a fixture; ``x`` is the
            tensor whose gradient is inspected.
        device: device string the loss is evaluated on.
    """
    x, y = input_tensors
    # Work on a detached leaf copy so the fixture's tensor is not mutated
    # (in case the fixture is shared across tests — NOTE(review): confirm scope).
    x = x.detach().clone().requires_grad_()
    # `.to(device)` is differentiable, so gradients flow back to the leaf `x`.
    # Assumes the loss reduces to a scalar (required by `backward()` without args).
    loss = VSILoss(data_range=1.)(x.to(device), y.to(device))
    loss.backward()
    # Fail with a clear message instead of a TypeError from isfinite(None).
    assert x.grad is not None, 'Expected gradient to be populated after back propagation, got None'
    assert torch.isfinite(x.grad).all(), \
        f'Expected finite gradient values after back propagation, got {x.grad}'
Пример #2
0
def test_vsi_loss_zero_for_equal_input(input_tensors: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """Check that ``VSILoss`` is (close to) zero when both inputs are identical.

    Args:
        input_tensors: pair supplied by a fixture; only the first tensor is used.
        device: device string the loss is evaluated on.
    """
    x, _ = input_tensors
    y = x.clone()
    # No backward pass happens here, so gradient tracking is unnecessary; the
    # original `x.requires_grad_()` only mutated the fixture tensor.
    loss = VSILoss(data_range=1.)(x.to(device), y.to(device))
    # `.all()` keeps the assertion well-defined even if the loss is not a
    # single-element tensor.
    assert torch.isclose(loss, torch.zeros_like(loss)).all(), \
        f'Expected loss equals zero for identical inputs, got {loss}'