def test_fwt_ifwt_mackey_haar():
    """Test the Haar case for a long signal.

    A 9-level ``MatrixWavedec`` of a 512-sample Mackey-Glass batch is compared
    coefficient-array by coefficient-array against ``pywt.wavedec``. The
    coefficients are then inverted with ``MatrixWaverec`` and must reproduce
    the input.
    """
    wavelet = pywt.Wavelet("haar")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = MackeyGenerator(batch_size=2, tmax=512, delta_t=1, device=device)
    pt_data = torch.squeeze(generator())
    coeffs_max = pywt.wavedec(pt_data.cpu().numpy(), wavelet, level=9)
    matrix_wavedec = MatrixWavedec(wavelet, 9)
    coeffs_mat_max = matrix_wavedec(pt_data)
    # Both decompositions must yield the same number of coefficient arrays,
    # otherwise the pairwise comparison below would silently truncate.
    assert len(coeffs_mat_max) == len(coeffs_max)
    # Compare every coefficient array (including the finest detail level).
    # A relative tolerance is used because coarse-level Haar coefficients
    # grow with the level.
    test_lst = [
        np.allclose(py_coeff, mat_coeff.cpu().numpy(), atol=1e-5)
        for py_coeff, mat_coeff in zip(coeffs_max, coeffs_mat_max)
    ]
    print(test_lst)
    # Fail the test if any level disagrees with pywt.
    assert all(test_lst)
    # test the inverse fwt.
    matrix_waverec = MatrixWaverec(wavelet)
    reconstructed_data = matrix_waverec(coeffs_mat_max)
    err1 = torch.mean(torch.abs(pt_data - reconstructed_data))
    print("abs ifwt reconstruction error", err1)
    assert np.allclose(pt_data.cpu().numpy(), reconstructed_data.cpu().numpy())
def test_conv_fwt():
    """Test multiple convolution fwt, for various levels and padding options.

    For every Daubechies wavelet db1-db5 and both ``"reflect"`` and ``"zero"``
    padding:

    * the concatenated coefficients of the first batch element must match
      ``pywt.wavedec`` at levels 1, 2, 3 and ``None``;
    * ``waverec(wavedec(...))`` at levels 3 and 4 must reproduce the input.
    """
    generator = MackeyGenerator(batch_size=2, tmax=128, delta_t=1, device="cpu")
    mackey_data_1 = torch.squeeze(generator())
    for wavelet_string in ["db1", "db2", "db3", "db4", "db5"]:
        # The wavelet only depends on its name; build it once per name.
        wavelet = pywt.Wavelet(wavelet_string)
        for mode in ["reflect", "zero"]:
            for level in [1, 2, 3, None]:
                ptcoeff = wavedec(mackey_data_1, wavelet, level=level, mode=mode)
                pycoeff = pywt.wavedec(
                    mackey_data_1[0, :].numpy(), wavelet, level=level, mode=mode
                )
                cptcoeff = torch.cat(ptcoeff, -1)[0, :]
                cpycoeff = np.concatenate(pycoeff, -1)
                err = np.mean(np.abs(cpycoeff - cptcoeff.numpy()))
                print(
                    f"{wavelet_string} coefficient error level {level}:",
                    err,
                    ["ok" if err < 1e-4 else "failed!"],
                    "mode",
                    mode,
                )
                assert np.allclose(cptcoeff.numpy(), cpycoeff, atol=1e-6)
            # The round trip does not depend on ``level`` above, so check it
            # once per (wavelet, mode) instead of once per level.
            for rec_level in (3, 4):
                res = waverec(
                    wavedec(mackey_data_1, wavelet, level=rec_level, mode=mode),
                    wavelet,
                )
                err = torch.mean(torch.abs(mackey_data_1 - res)).numpy()
                print(
                    f"{wavelet_string} reconstruction error scale {rec_level}:",
                    err,
                    ["ok" if err < 1e-4 else "failed!"],
                    "mode",
                    mode,
                )
                assert np.allclose(mackey_data_1.numpy(), res.numpy())
def test_conv_fwt_db5_lvl3():
    """Compare a level-3 db5 conv-fwt with pywt and check the round trip.

    For both ``"reflect"`` and ``"zero"`` padding, the concatenated level-3
    coefficients of the first batch element must agree with
    ``pywt.wavedec``. Reconstructions from level-3 and level-4 decompositions
    must reproduce the input signal.
    """
    generator = MackeyGenerator(batch_size=2, tmax=128, delta_t=1, device="cpu")
    signal = torch.squeeze(generator())
    wavelet = pywt.Wavelet("db5")
    # (decomposition level, message printed with its reconstruction error)
    rec_checks = (
        (3, "db5 reconstruction error scale 3:"),
        (4, "db5 reconstruction error scale 4:"),
    )
    for mode in ["reflect", "zero"]:
        pt_coeffs = wavedec(signal, wavelet, level=3, mode=mode)
        py_coeffs = pywt.wavedec(signal[0, :].numpy(), wavelet, level=3, mode=mode)
        pt_flat = torch.cat(pt_coeffs, -1)[0, :].numpy()
        py_flat = np.concatenate(py_coeffs, -1)
        coeff_err = np.mean(np.abs(py_flat - pt_flat))
        print(
            "db5 coefficient error scale 3:",
            coeff_err,
            ["ok" if coeff_err < 1e-4 else "failed!"],
            "mode",
            mode,
        )
        assert np.allclose(py_flat, pt_flat, atol=1e-6)
        for rec_level, label in rec_checks:
            rec = waverec(wavedec(signal, wavelet, level=rec_level, mode=mode), wavelet)
            rec_err = torch.mean(torch.abs(signal - rec)).numpy()
            print(label, rec_err, ["ok" if rec_err < 1e-4 else "failed!"], "mode", mode)
            assert np.allclose(signal.numpy(), rec.numpy())
def test_fwt_ifwt_mackey_db2():
    """Round-trip a long Mackey-Glass signal through the db2 matrix transform.

    A level-4 ``MatrixWavedec`` is inverted with ``MatrixWaverec``. The mean
    absolute reconstruction error must stay below 1e-6.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    wavelet = pywt.Wavelet("db2")
    generator = MackeyGenerator(batch_size=2, tmax=512, delta_t=1, device=device)
    # The transforms below run on the CPU copy of the generated data.
    signal = torch.squeeze(generator()).cpu()
    coefficients = MatrixWavedec(wavelet, 4)(signal)
    restored = MatrixWaverec(wavelet)(coefficients)
    mean_abs_err = torch.abs(signal - restored).mean()
    print("reconstruction error:", mean_abs_err)
    assert mean_abs_err < 1e-6
def test_orth_wavelet():
    """Check perfect reconstruction with a ``SoftOrthogonalWavelet`` built from db5.

    The db5 filter bank from pywt (rec_lo, rec_hi, dec_lo, dec_hi, in that
    order) is wrapped in ``SoftOrthogonalWavelet``. A
    ``waverec(wavedec(...))`` round trip must then reproduce the input signal.
    """
    generator = MackeyGenerator(batch_size=2, tmax=64, delta_t=1, device="cpu")
    mackey_data_1 = torch.squeeze(generator())
    db5 = pywt.Wavelet("db5")
    filter_bank = (db5.rec_lo, db5.rec_hi, db5.dec_lo, db5.dec_hi)
    orthwave = SoftOrthogonalWavelet(*(torch.tensor(taps) for taps in filter_bank))
    # Detach so the result can be converted to numpy.
    restored = waverec(wavedec(mackey_data_1, orthwave), orthwave).detach()
    err = torch.mean(torch.abs(mackey_data_1 - restored)).numpy()
    print("orth reconstruction error scale 4:", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(restored.numpy(), mackey_data_1.numpy())
def test_conv_fwt_haar_lvl4():
    """Test a fourth level Haar wavelet conv-fwt.

    Checks the level-4 coefficients of the first batch element against
    ``pywt.wavedec``. Also checks that ``waverec`` inverts that same level-4
    decomposition.
    """
    generator = MackeyGenerator(batch_size=2, tmax=64, delta_t=1, device="cpu")
    mackey_data_1 = torch.squeeze(generator())
    wavelet = pywt.Wavelet("haar")
    ptcoeff = wavedec(mackey_data_1, wavelet, level=4)
    pycoeff = pywt.wavedec(mackey_data_1[0, :].numpy(), wavelet, level=4)
    ptwt_coeff = torch.cat(ptcoeff, -1)[0, :].numpy()
    pywt_coeff = np.concatenate(pycoeff)
    err = np.mean(np.abs(pywt_coeff - ptwt_coeff))
    print("haar coefficient error scale 4:", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(pywt_coeff, ptwt_coeff, atol=1e-06)
    # Reconstruct from the level-4 coefficients verified above, so the round
    # trip matches the "scale 4" label. Calling wavedec without ``level``
    # would not pin the decomposition depth.
    reconstruction = waverec(ptcoeff, wavelet)
    err = torch.mean(torch.abs(mackey_data_1 - reconstruction)).numpy()
    print("haar reconstruction error scale 4:", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(reconstruction.numpy(), mackey_data_1.numpy())