def test_2d_wavedec_rec():
    """Ensure pywt.wavedec2 and ptwt.wavedec2 produce the same coefficients.

    For every combination of decomposition level and Daubechies-family
    wavelet, this checks that:

    * ptwt's ``wavedec2`` yields coefficients with the same shapes and values
      as ``pywt.wavedec2`` (both using ``mode="reflect"``), and
    * ptwt's ``waverec2`` inverts ``wavedec2``, reconstructing the input.
    """
    # The test image does not depend on the loop variables, so load, crop and
    # cast it once instead of on every iteration. The [2, 0, 1] transpose
    # moves the last axis (presumably the RGB channels) to the front.
    # NOTE(review): scipy.misc.face is deprecated in newer SciPy releases
    # (replacement: scipy.datasets.face) - confirm the pinned SciPy version.
    face = np.transpose(
        scipy.misc.face()[256:(512 + 64), 256:(512 + 64)], [2, 0, 1]
    ).astype(np.float64)

    for level in [1, 2, 3, 4, 5, None]:
        for wavelet_str in ["haar", "db2", "db3", "db4", "db5"]:
            wavelet = pywt.Wavelet(wavelet_str)
            coeff2d = wavedec2(
                torch.from_numpy(face), wavelet, mode="reflect", level=level
            )
            pywt_coeff2d = pywt.wavedec2(
                face, wavelet, mode="reflect", level=level
            )

            # pywt returns the approximation as an array and each detail
            # level as a tuple of arrays; ptwt mirrors that nesting. The
            # ptwt tensors are squeezed along dim 1 before comparing shapes.
            for pos, coeffs in enumerate(pywt_coeff2d):
                if isinstance(coeffs, tuple):
                    for tuple_pos, tuple_el in enumerate(coeffs):
                        assert (
                            tuple_el.shape
                            == torch.squeeze(coeff2d[pos][tuple_pos], 1).shape
                        ), "pywt and ptwt should produce the same shapes."
                else:
                    assert (
                        coeffs.shape == torch.squeeze(coeff2d[pos], 1).shape
                    ), "pywt and ptwt should produce the same shapes."

            # Compare coefficient values after flattening both pyramids.
            flat_coeff_list_pywt = np.concatenate(
                flatten_2d_coeff_lst(pywt_coeff2d), -1
            )
            flat_coeff_list_ptwt = torch.cat(
                flatten_2d_coeff_lst(coeff2d), -1
            ).numpy()
            cerr = np.mean(np.abs(flat_coeff_list_pywt - flat_coeff_list_ptwt))
            print(
                "wavelet",
                wavelet_str,
                "level",
                str(level),
                "coeff err,",
                cerr,
                ["ok" if cerr < 1e-4 else "failed!"],
            )
            assert np.allclose(flat_coeff_list_pywt, flat_coeff_list_ptwt)

            # The inverse transform must reconstruct the original image.
            rec = waverec2(coeff2d, wavelet).numpy().squeeze()
            err = np.mean(np.abs(face - rec))
            print(
                "wavelet",
                wavelet_str,
                "level",
                str(level),
                "rec err,",
                err,
                ["ok" if err < 1e-4 else "failed!"],
            )
            assert np.allclose(face, rec)
def test_2d_haar_lvl1():
    """Test a 2d-Haar wavelet conv-fwt.

    A single-level Haar decomposition from ptwt must match ``pywt.dwt2``, and
    ``waverec2`` must reconstruct the input image.
    """
    # ------------------------- 2d haar wavelet tests -----------------------
    # The [2, 0, 1] transpose moves the last axis (presumably the RGB
    # channels) to the front.
    face = np.transpose(
        scipy.misc.face()[128:(512 + 128), 256:(512 + 256)], [2, 0, 1]
    ).astype(np.float64)
    wavelet = pywt.Wavelet("haar")

    # single level haar - 2d.
    # pywt's "zero" extension mode is paired with ptwt's "constant" mode here;
    # the allclose assertion below is what checks that they agree.
    coeff2d_pywt = pywt.dwt2(face, wavelet, mode="zero")
    coeff2d = wavedec2(
        torch.from_numpy(face), wavelet, level=1, mode="constant"
    )
    flat_list_pywt = np.concatenate(flatten_2d_coeff_lst(coeff2d_pywt), -1)
    flat_list_ptwt = torch.cat(flatten_2d_coeff_lst(coeff2d), -1).numpy()
    cerr = np.mean(np.abs(flat_list_pywt - flat_list_ptwt))
    print("haar 2d coeff err,", cerr, ["ok" if cerr < 1e-4 else "failed!"])
    assert np.allclose(flat_list_pywt, flat_list_ptwt)

    # single level haar - 2d inverse.
    rec = waverec2(coeff2d, wavelet).numpy().squeeze()
    err = np.mean(np.abs(face - rec))
    print("haar 2d rec err", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(rec, face)
def test_2d_db2_lvl1():
    """Test a 2d-db2 wavelet conv-fwt.

    A single-level db2 decomposition from ptwt must match ``pywt.dwt2`` with
    reflect padding, and ``waverec2`` must reconstruct the input image.
    """
    # The [2, 0, 1] transpose moves the last axis (presumably the RGB
    # channels) to the front.
    face = np.transpose(
        scipy.misc.face()[256:(512 + 128), 256:(512 + 128)], [2, 0, 1]
    ).astype(np.float64)
    wavelet = pywt.Wavelet("db2")

    # single level db2 - 2d. The padding mode is passed explicitly on both
    # sides so the comparison does not depend on ptwt's default mode.
    coeff2d_pywt = pywt.dwt2(face, wavelet, mode="reflect")
    coeff2d = wavedec2(torch.from_numpy(face), wavelet, level=1, mode="reflect")
    flat_list_pywt = np.concatenate(flatten_2d_coeff_lst(coeff2d_pywt), -1)
    flat_list_ptwt = torch.cat(flatten_2d_coeff_lst(coeff2d), -1).numpy()
    cerr = np.mean(np.abs(flat_list_pywt - flat_list_ptwt))
    # Label fixed: this test exercises db2, not db5.
    print("db2 2d coeff err,", cerr, ["ok" if cerr < 1e-4 else "failed!"])
    assert np.allclose(flat_list_pywt, flat_list_ptwt)

    # single level db2 - 2d inverse.
    rec = waverec2(coeff2d, wavelet).numpy().squeeze()
    err = np.mean(np.abs(face - rec))
    print("db2 2d rec err,", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(rec, face)
def test_2d_haar_multi():
    """Test a 2d-Haar wavelet level 5 conv-fwt.

    A five-level Haar decomposition from ptwt must match ``pywt.wavedec2``
    with reflect padding, and ``waverec2`` must reconstruct the input image.
    """
    # The [2, 0, 1] transpose moves the last axis (presumably the RGB
    # channels) to the front.
    face = np.transpose(
        scipy.misc.face()[256:(512 + 128), 256:(512 + 128)], [2, 0, 1]
    ).astype(np.float64)
    wavelet = pywt.Wavelet("haar")

    # multi level haar - 2d. The padding mode is passed explicitly on both
    # sides so the comparison does not depend on ptwt's default mode.
    coeff2d_pywt = pywt.wavedec2(face, wavelet, mode="reflect", level=5)
    coeff2d = wavedec2(torch.from_numpy(face), wavelet, level=5, mode="reflect")
    flat_list_pywt = np.concatenate(flatten_2d_coeff_lst(coeff2d_pywt), -1)
    # Convert to numpy explicitly (as the sibling tests do) rather than
    # handing a torch tensor straight to np.allclose.
    flat_list_ptwt = torch.cat(flatten_2d_coeff_lst(coeff2d), -1).numpy()
    cerr = np.mean(np.abs(flat_list_pywt - flat_list_ptwt))
    print(
        "haar 2d scale 5 coeff err,", cerr, ["ok" if cerr < 1e-4 else "failed!"]
    )
    assert np.allclose(flat_list_pywt, flat_list_ptwt)

    # inverse multi level Haar - 2d
    rec = waverec2(coeff2d, wavelet).numpy().squeeze()
    err = np.mean(np.abs(face - rec))
    print("haar 2d scale 5 rec err,", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(rec, face)