예제 #1
0
def test_lpips_computes_grad(x, y, device: str) -> None:
    """Check that the LPIPS loss back-propagates a gradient to its first input.

    ``x`` is flagged as requiring grad (in place), both inputs are copied to
    ``device``, an ``LPIPS`` instance is evaluated on them, and the resulting
    loss is back-propagated. The test fails with ``NONE_GRAD_ERR_MSG`` if the
    original ``x`` ends up without a ``.grad``.
    """
    # NOTE(review): requires_grad_() mutates the tensor passed in; if x is a
    # fixture shared across tests, this flag (and any accumulated .grad) leaks
    # into them -- confirm the fixture's scope.
    x.requires_grad_()
    metric = LPIPS()
    # The gradient is checked on the original x, not on the device copy, so
    # it has to flow back through the differentiable .to(device) call.
    moved_x, moved_y = x.to(device), y.to(device)
    metric(moved_x, moved_y).backward()
    assert x.grad is not None, NONE_GRAD_ERR_MSG
예제 #2
0
def test_lpips_computes_grad(prediction: torch.Tensor, target: torch.Tensor,
                             device: str) -> None:
    """Check that the LPIPS loss back-propagates a gradient to ``prediction``.

    ``prediction`` is flagged as requiring grad (in place), both tensors are
    copied to ``device``, and the LPIPS loss computed on them is
    back-propagated. The test fails with ``NONE_GRAD_ERR_MSG`` if the original
    ``prediction`` tensor has no ``.grad`` afterwards.
    """
    # NOTE(review): this mutates the caller's tensor; if `prediction` is a
    # shared fixture, the requires_grad flag and .grad persist beyond this
    # test -- confirm the fixture's scope.
    prediction.requires_grad_()
    criterion = LPIPS()
    # Evaluate on device copies; autograd routes the gradient back through
    # .to(device) to the original leaf tensor that the assertion inspects.
    device_inputs = (prediction.to(device), target.to(device))
    loss_value = criterion(*device_inputs)
    loss_value.backward()
    assert prediction.grad is not None, NONE_GRAD_ERR_MSG