コード例 #1
0
def test_reduction_for_dim_none(reduction):
    """Passing ``reduction`` together with ``dim=None`` must emit a UserWarning.

    Both the module interface (``PSNR``) and the functional interface (``psnr``)
    are checked against the same warning text.
    """
    expected = f"The `reduction={reduction}` will not have any effect when `dim` is None."

    # module interface: the warning is expected while constructing the metric
    with pytest.warns(UserWarning, match=expected):
        PSNR(reduction=reduction, dim=None)

    # functional interface: the warning is expected on the call itself
    sample = _inputs[0]
    with pytest.warns(UserWarning, match=expected):
        psnr(sample.preds, sample.target, reduction=reduction, dim=None)
コード例 #2
0
def test_psnr_base_e_wider_range(pred, target):
    """PSNR computed in base e must match skimage's base-10 PSNR rescaled by ln(10)."""
    pred_t, target_t = torch.tensor(pred), torch.tensor(target)
    score = psnr(pred=pred_t, target=target_t, data_range=4, base=np.e)

    # 10 * log10(v) * ln(10) == 10 * ln(v): converts skimage's base-10 value to base e
    expected = ski_psnr(np.array(pred), np.array(target), data_range=4) * np.log(10)
    assert torch.allclose(score, torch.tensor(expected, dtype=torch.float32), atol=1e-3)
コード例 #3
0
def test_psnr_with_skimage(pred, target):
    """PSNR with the default base must agree with skimage's implementation."""
    reference = torch.tensor(
        ski_psnr(np.array(pred), np.array(target), data_range=3),
        dtype=torch.float,
    )
    score = psnr(pred=torch.tensor(pred), target=torch.tensor(target), data_range=3)
    assert torch.allclose(score, reference, atol=1e-3)
コード例 #4
0
def psnr_lightning(real, fake, return_per_frame=False, normalize_range=True):
    """Compute PSNR between ``fake`` and ``real`` with Lightning's functional ``psnr``.

    Inputs of rank 3 get two leading singleton axes and inputs of rank 4 get one.
    Axis 1 is then iterated as the frame axis. The rank check looks at ``real``
    only, and the same promotion is applied to ``fake``.

    Args:
        real: reference tensor.
        fake: tensor compared against ``real``.
        return_per_frame: when True, also score each index along axis 1 separately.
        normalize_range: when True, map both tensors through ``(x + 1) / 2`` before
            scoring. Presumably this maps [-1, 1] to [0, 1]; confirm with callers.

    Returns:
        The PSNR over all frames as a NumPy value. When ``return_per_frame`` is set,
        returns a tuple of that value and a dict mapping frame index to its PSNR.
    """
    # how many leading singleton axes to add, keyed by the rank of ``real``
    missing_axes = {3: 2, 4: 1}.get(real.dim(), 0)
    for _ in range(missing_axes):
        real = real.unsqueeze(0)
        fake = fake.unsqueeze(0)

    if normalize_range:
        real, fake = (real + 1.) / 2., (fake + 1.) / 2.

    # NOTE(review): no data_range is passed, so PF.psnr's default applies to every
    # call below. Confirm that is intended after the range normalization.
    # Merge the first two axes so every frame of every sequence is scored together.
    overall = PF.psnr(fake.reshape(-1, *fake.shape[2:]),
                      real.reshape(-1, *real.shape[2:])).cpu().numpy()

    if not return_per_frame:
        return overall

    per_frame = {}
    for frame_idx in range(real.shape[1]):
        per_frame[frame_idx] = PF.psnr(fake[:, frame_idx].contiguous(),
                                       real[:, frame_idx].contiguous()).cpu().numpy()
    return overall, per_frame
コード例 #5
0
def test_v1_5_metric_regress():
    """Deprecated regression/image metrics must warn about their removal in v1.5.0.

    Each deprecated callable carries a ``_warned`` flag. The flag is reset to
    ``False`` before every check so the deprecation warning is emitted again.
    The functional variants are also compared against known reference values.
    """
    removal_note = 'It will be removed in v1.5.0'

    def _check_module_deprecation(*metric_classes):
        # Instantiate each class in order; each construction must warn.
        for metric_cls in metric_classes:
            metric_cls.__init__._warned = False
            with pytest.deprecated_call(match=removal_note):
                metric_cls()

    def _call_deprecated(func, *args):
        # Call ``func`` once, require the deprecation warning, and return its result.
        func._warned = False
        with pytest.deprecated_call(match=removal_note):
            return func(*args)

    _check_module_deprecation(
        ExplainedVariance, MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError,
    )

    ev_target = torch.tensor([3, -0.5, 2, 7])
    ev_preds = torch.tensor([2.5, 0.0, 2, 8])
    explained = _call_deprecated(explained_variance, ev_preds, ev_target)
    assert torch.allclose(explained, torch.tensor(0.9572), atol=1e-4)

    x = torch.tensor([0., 1, 2, 3])
    y = torch.tensor([0., 1, 2, 2])
    assert _call_deprecated(mean_absolute_error, x, y) == 0.25
    assert _call_deprecated(mean_relative_error, x, y) == 0.125
    assert _call_deprecated(mean_squared_error, x, y) == 0.25
    msle = _call_deprecated(mean_squared_log_error, x, y)
    assert torch.allclose(msle, torch.tensor(0.0207), atol=1e-4)

    _check_module_deprecation(PSNR, R2Score, SSIM)

    psnr_preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    psnr_target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
    psnr_score = _call_deprecated(psnr, psnr_preds, psnr_target)
    assert torch.allclose(psnr_score, torch.tensor(2.5527), atol=1e-4)

    r2_target = torch.tensor([3, -0.5, 2, 7])
    r2_preds = torch.tensor([2.5, 0.0, 2, 8])
    r2_score = _call_deprecated(r2score, r2_preds, r2_target)
    assert torch.allclose(r2_score, torch.tensor(0.9486), atol=1e-4)

    # The target is a fixed scaling of the random input, which keeps the
    # expected SSIM close to a constant.
    ssim_preds = torch.rand([16, 1, 16, 16])
    ssim_target = ssim_preds * 0.75
    ssim_score = _call_deprecated(ssim, ssim_preds, ssim_target)
    assert torch.allclose(ssim_score, torch.tensor(0.9219), atol=1e-4)
コード例 #6
0
def test_missing_data_range():
    """``data_range=None`` combined with an explicit ``dim`` must raise ValueError.

    Checked for both the module interface (``PSNR``) and the functional one (``psnr``).
    """
    sample = _inputs[0]
    for build in (
        lambda: PSNR(data_range=None, dim=0),
        lambda: psnr(sample.preds, sample.target, data_range=None, dim=0),
    ):
        with pytest.raises(ValueError):
            build()
コード例 #7
0
ファイル: metrics.py プロジェクト: mcbuehler/VariTex
 def forward(self, preds, target):
     """Return the PSNR between ``preds`` and ``target`` after undoing normalization.

     Both inputs pass through ``self.unnormalize`` (``preds`` first) and are scored
     with ``data_range=self.scale``.
     """
     return psnr(self.unnormalize(preds),
                 self.unnormalize(target),
                 data_range=self.scale)