示例#1
0
def test_2d_wavedec_rec():
    """Ensure pywt.wavedec2 and ptwt.wavedec2 produce the same coefficients.

    For every combination of decomposition level (including ``None``) and
    wavelet in haar/db2..db5, check that:

    * ptwt's ``wavedec2`` yields the same number of coefficient entries and
      the same per-entry shapes as ``pywt.wavedec2`` with ``mode="reflect"``;
    * the flattened coefficients agree numerically (``np.allclose``);
    * ptwt's ``waverec2`` inverts ``wavedec2`` (reconstruction matches input).
    """
    # The test image does not depend on the loop variables, so build it once.
    # The three color channels are moved to the front (axis 0) and the
    # resulting float64 array is fed to both libraries.
    face = np.transpose(
        scipy.misc.face()[256:(512 + 64), 256:(512 + 64)],
        [2, 0, 1]).astype(np.float64)
    for level in [1, 2, 3, 4, 5, None]:
        for wavelet_str in ["haar", "db2", "db3", "db4", "db5"]:
            wavelet = pywt.Wavelet(wavelet_str)
            coeff2d = wavedec2(torch.from_numpy(face),
                               wavelet,
                               mode="reflect",
                               level=level)
            pywt_coeff2d = pywt.wavedec2(face,
                                         wavelet,
                                         mode="reflect",
                                         level=level)
            # Compare the entry count first so a level mismatch is reported
            # clearly instead of surfacing as an IndexError below.
            assert len(coeff2d) == len(pywt_coeff2d), (
                "pywt and ptwt should produce the same number of levels.")
            for pos, coeffs in enumerate(pywt_coeff2d):
                if isinstance(coeffs, tuple):
                    # Detail coefficients arrive as a tuple of arrays; compare
                    # each one. Axis 1 of ptwt's tensors is squeezed before the
                    # comparison (presumably a singleton channel axis that
                    # pywt does not produce -- TODO confirm).
                    for tuple_pos, tuple_el in enumerate(coeffs):
                        assert (
                            tuple_el.shape == torch.squeeze(
                                coeff2d[pos][tuple_pos], 1).shape
                        ), "pywt and ptwt should produce the same shapes."
                else:
                    assert (coeffs.shape == torch.squeeze(coeff2d[pos],
                                                          1).shape
                            ), "pywt and ptwt should produce the same shapes."
            flat_coeff_list_pywt = np.concatenate(
                flatten_2d_coeff_lst(pywt_coeff2d), -1)
            flat_coeff_list_ptwt = torch.cat(flatten_2d_coeff_lst(coeff2d), -1)
            cerr = np.mean(
                np.abs(flat_coeff_list_pywt - flat_coeff_list_ptwt.numpy()))
            print(
                "wavelet",
                wavelet_str,
                "level",
                str(level),
                "coeff err,",
                cerr,
                ["ok" if cerr < 1e-4 else "failed!"],
            )
            assert np.allclose(flat_coeff_list_pywt,
                               flat_coeff_list_ptwt.numpy())
            # Round trip: the inverse transform must reproduce the input image.
            rec = waverec2(coeff2d, wavelet)
            rec = rec.numpy().squeeze()
            err_img = np.abs(face - rec)
            err = np.mean(err_img)
            print(
                "wavelet",
                wavelet_str,
                "level",
                str(level),
                "rec   err,",
                err,
                ["ok" if err < 1e-4 else "failed!"],
            )
            assert np.allclose(face, rec)
示例#2
0
def test_2d_haar_lvl1():
    """Test a single-level 2d-Haar wavelet conv-fwt against pywt.dwt2.

    Compares ptwt's level-1 ``wavedec2`` with ``pywt.dwt2`` and checks that
    ptwt's ``waverec2`` reconstructs the input image.
    """
    # ------------------------- 2d haar wavelet tests -----------------------
    # Color channels are moved to the front (axis 0) of the float64 image.
    face = np.transpose(scipy.misc.face()[128:(512 + 128), 256:(512 + 256)],
                        [2, 0, 1]).astype(np.float64)
    wavelet = pywt.Wavelet("haar")

    # Single-level Haar decomposition. pywt's "zero" extension mode is paired
    # with ptwt's "constant" mode -- presumably constant padding with value 0
    # on the ptwt side (TODO confirm against ptwt's padding docs).
    coeff2d_pywt = pywt.dwt2(face, wavelet, mode="zero")
    coeff2d = wavedec2(torch.from_numpy(face),
                       wavelet,
                       level=1,
                       mode="constant")
    flat_list_pywt = np.concatenate(flatten_2d_coeff_lst(coeff2d_pywt), -1)
    flat_list_ptwt = torch.cat(flatten_2d_coeff_lst(coeff2d), -1)
    cerr = np.mean(np.abs(flat_list_pywt - flat_list_ptwt.numpy()))
    print("haar 2d coeff err,", cerr, ["ok" if cerr < 1e-4 else "failed!"])
    assert np.allclose(flat_list_pywt, flat_list_ptwt.numpy())

    # Round trip: the inverse transform must reproduce the input image.
    rec = waverec2(coeff2d, wavelet).numpy().squeeze()
    err_img = np.abs(face - rec)
    err = np.mean(err_img)
    print("haar 2d rec err", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(rec, face)
示例#3
0
def test_2d_db2_lvl1():
    """Test a single-level 2d-db2 wavelet conv-fwt against pywt.dwt2.

    Both libraries use reflect padding. The test also checks that ptwt's
    ``waverec2`` inverts the decomposition.
    """
    # Color channels are moved to the front (axis 0) of the float64 image.
    face = np.transpose(scipy.misc.face()[256:(512 + 128), 256:(512 + 128)],
                        [2, 0, 1]).astype(np.float64)
    wavelet = pywt.Wavelet("db2")

    # Single-level db2 decomposition; the padding mode is spelled out on both
    # sides so the two libraries are visibly configured identically.
    coeff2d_pywt = pywt.dwt2(face, wavelet, mode="reflect")
    coeff2d = wavedec2(torch.from_numpy(face), wavelet, level=1,
                       mode="reflect")
    flat_list_pywt = np.concatenate(flatten_2d_coeff_lst(coeff2d_pywt), -1)
    flat_list_ptwt = torch.cat(flatten_2d_coeff_lst(coeff2d), -1)
    cerr = np.mean(np.abs(flat_list_pywt - flat_list_ptwt.numpy()))
    print("db2 2d coeff err,", cerr, ["ok" if cerr < 1e-4 else "failed!"])
    assert np.allclose(flat_list_pywt, flat_list_ptwt.numpy())

    # Single-level db2 inverse; convert to numpy once and reuse the result.
    rec = waverec2(coeff2d, wavelet).numpy().squeeze()
    err = np.mean(np.abs(face - rec))
    print("db2 2d rec err,", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(rec, face)
示例#4
0
def test_2d_haar_multi():
    """Test a multi-level (level 5) 2d-Haar wavelet conv-fwt.

    Compares ptwt's ``wavedec2`` with ``pywt.wavedec2`` (reflect padding,
    five levels) and checks that ptwt's ``waverec2`` reconstructs the input.
    """
    # Color channels are moved to the front (axis 0) of the float64 image.
    face = np.transpose(scipy.misc.face()[256:(512 + 128), 256:(512 + 128)],
                        [2, 0, 1]).astype(np.float64)
    wavelet = pywt.Wavelet("haar")

    # Multi-level Haar decomposition; the padding mode is spelled out on both
    # sides so the two libraries are visibly configured identically.
    coeff2d_pywt = pywt.wavedec2(face, wavelet, mode="reflect", level=5)
    coeff2d = wavedec2(torch.from_numpy(face), wavelet, level=5,
                       mode="reflect")
    flat_list_pywt = np.concatenate(flatten_2d_coeff_lst(coeff2d_pywt), -1)
    flat_list_ptwt = torch.cat(flatten_2d_coeff_lst(coeff2d), -1)
    cerr = np.mean(np.abs(flat_list_pywt - flat_list_ptwt.numpy()))
    print("haar 2d scale 5 coeff err,", cerr,
          ["ok" if cerr < 1e-4 else "failed!"])
    # Compare numpy against numpy, as the sibling tests do, rather than
    # handing a torch tensor to np.allclose.
    assert np.allclose(flat_list_pywt, flat_list_ptwt.numpy())

    # Inverse multi-level Haar - 2d: must reproduce the input image.
    rec = waverec2(coeff2d, wavelet).numpy().squeeze()
    err = np.mean(np.abs(face - rec))
    print("haar 2d scale 5 rec err,", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(rec, face)