def test_reduction_for_dim_none(reduction):
    """A ``reduction`` argument must trigger a UserWarning when ``dim`` is None.

    Checked for both the module-style ``PSNR`` and the functional ``psnr``.
    """
    expected_msg = f"The `reduction={reduction}` will not have any effect when `dim` is None."
    sample = _inputs[0]

    with pytest.warns(UserWarning, match=expected_msg):
        PSNR(reduction=reduction, dim=None)

    with pytest.warns(UserWarning, match=expected_msg):
        psnr(sample.preds, sample.target, reduction=reduction, dim=None)
def test_psnr_base_e_wider_range(pred, target):
    """``psnr`` with base e and data_range=4 must match skimage's PSNR rescaled to base e."""
    actual = psnr(
        pred=torch.tensor(pred),
        target=torch.tensor(target),
        data_range=4,
        base=np.e,
    )
    # skimage reports 10 * log10(...); multiplying by ln(10) turns it into 10 * ln(...).
    reference = ski_psnr(np.array(pred), np.array(target), data_range=4) * np.log(10)
    expected = torch.tensor(reference, dtype=torch.float32)
    assert torch.allclose(actual, expected, atol=1e-3)
def test_psnr_with_skimage(pred, target):
    """``psnr`` must agree with skimage's PSNR (data_range=3) to within 1e-3."""
    data_range = 3
    actual = psnr(pred=torch.tensor(pred), target=torch.tensor(target), data_range=data_range)
    reference = ski_psnr(np.array(pred), np.array(target), data_range=data_range)
    assert torch.allclose(actual, torch.tensor(reference, dtype=torch.float), atol=1e-3)
def psnr_lightning(real, fake, return_per_frame=False, normalize_range=True):
    """Compute PSNR of ``fake`` against ``real`` using ``PF.psnr``.

    Rank-3 inputs get two leading singleton axes and rank-4 inputs get one, so
    the first two axes can be flattened together and axis 1 can be iterated
    per "frame". NOTE(review): the rank-5 layout is presumably
    (batch, frames, channels, height, width); confirm against callers.

    Args:
        real: Reference tensor. Only ``real``'s rank is inspected for the
            promotion, so ``fake`` is expected to have the same shape.
        fake: Tensor compared against ``real``.
        return_per_frame: If True, also return a dict mapping each index along
            axis 1 to the PSNR computed on that slice.
        normalize_range: If True, map both tensors through ``(x + 1) / 2``,
            presumably from [-1, 1] to [0, 1].

    Returns:
        The PSNR over all frames, as a NumPy value. When ``return_per_frame``
        is True, returns ``(psnr_batch, psnr_per_frame)`` instead.

    NOTE(review): ``data_range`` is not passed to ``PF.psnr``, so the library
    default decides it. Confirm that this is intended.
    """
    if real.dim() == 3:
        real, fake = real[None, None], fake[None, None]
    elif real.dim() == 4:
        real, fake = real[None], fake[None]

    if normalize_range:
        real = (real + 1.) / 2.
        fake = (fake + 1.) / 2.

    def _score(pred, target):
        # detach(): Tensor.numpy() raises on tensors that require grad, which
        # is the case whenever ``fake`` is a model output still in the graph.
        return PF.psnr(pred, target).detach().cpu().numpy()

    # Merge the first two axes so every frame of every sample is scored together.
    psnr_batch = _score(
        fake.reshape(-1, *fake.shape[2:]),
        real.reshape(-1, *real.shape[2:]),
    )
    if not return_per_frame:
        return psnr_batch

    psnr_per_frame = {
        i: _score(fake[:, i].contiguous(), real[:, i].contiguous())
        for i in range(real.shape[1])
    }
    return psnr_batch, psnr_per_frame
def test_v1_5_metric_regress():
    """Every regression metric scheduled for removal must emit the v1.5.0 deprecation warning.

    Each metric has a warn-once ``_warned`` flag. For classes it sits on
    ``__init__``; for functionals it sits on the function itself. The flag is
    cleared before each call so the warning is raised again.
    """
    removal_msg = 'It will be removed in v1.5.0'

    def construct(metric_cls):
        # Reset the constructor's warn-once flag, then instantiate under the check.
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_msg):
            metric_cls()

    def evaluate(metric_fn, *args):
        # Reset the functional's warn-once flag, then call it under the check.
        metric_fn._warned = False
        with pytest.deprecated_call(match=removal_msg):
            return metric_fn(*args)

    for metric_cls in (ExplainedVariance, MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError):
        construct(metric_cls)

    target = torch.tensor([3, -0.5, 2, 7])
    preds = torch.tensor([2.5, 0.0, 2, 8])
    score = evaluate(explained_variance, preds, target)
    assert torch.allclose(score, torch.tensor(0.9572), atol=1e-4)

    x = torch.tensor([0., 1, 2, 3])
    y = torch.tensor([0., 1, 2, 2])
    assert evaluate(mean_absolute_error, x, y) == 0.25
    assert evaluate(mean_relative_error, x, y) == 0.125
    assert evaluate(mean_squared_error, x, y) == 0.25
    score = evaluate(mean_squared_log_error, x, y)
    assert torch.allclose(score, torch.tensor(0.0207), atol=1e-4)

    for metric_cls in (PSNR, R2Score, SSIM):
        construct(metric_cls)

    preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
    score = evaluate(psnr, preds, target)
    assert torch.allclose(score, torch.tensor(2.5527), atol=1e-4)

    target = torch.tensor([3, -0.5, 2, 7])
    preds = torch.tensor([2.5, 0.0, 2, 8])
    score = evaluate(r2score, preds, target)
    assert torch.allclose(score, torch.tensor(0.9486), atol=1e-4)

    # NOTE(review): unseeded random input; the expected 0.9219 relies on SSIM(x, 0.75x)
    # being nearly independent of x. Confirm a global seed is set elsewhere.
    preds = torch.rand([16, 1, 16, 16])
    target = preds * 0.75
    score = evaluate(ssim, preds, target)
    assert torch.allclose(score, torch.tensor(0.9219), atol=1e-4)
def test_missing_data_range():
    """``data_range=None`` combined with ``dim=0`` must raise ``ValueError``.

    Checked for both the module-style ``PSNR`` and the functional ``psnr``.
    """
    sample = _inputs[0]

    with pytest.raises(ValueError):
        PSNR(data_range=None, dim=0)

    with pytest.raises(ValueError):
        psnr(sample.preds, sample.target, data_range=None, dim=0)
def forward(self, preds, target):
    """Return the PSNR of ``preds`` against ``target`` after un-normalizing both.

    ``self.scale`` is passed to ``psnr`` as ``data_range``.
    """
    restored_preds, restored_target = self.unnormalize(preds), self.unnormalize(target)
    return psnr(restored_preds, restored_target, data_range=self.scale)