def test_conv_fwt():
    """Test multiple convolution fwt, for various levels and padding options.

    For every Daubechies wavelet db1-db5, padding mode ("reflect", "zero") and
    decomposition level (1, 2, 3 and ``None``), the concatenated ptwt
    coefficients of the first batch element are compared against
    ``pywt.wavedec``. Perfect reconstruction through ``waverec`` is verified
    once per wavelet/mode for level-3 and level-4 decompositions.
    """
    generator = MackeyGenerator(batch_size=2, tmax=128, delta_t=1, device="cpu")
    mackey_data_1 = torch.squeeze(generator())
    for wavelet_string in ["db1", "db2", "db3", "db4", "db5"]:
        wavelet = pywt.Wavelet(wavelet_string)
        for mode in ["reflect", "zero"]:
            for level in [1, 2, 3, None]:
                ptcoeff = wavedec(mackey_data_1, wavelet, level=level, mode=mode)
                # pywt works on a single 1d signal, so compare batch element 0.
                pycoeff = pywt.wavedec(
                    mackey_data_1[0, :].numpy(), wavelet, level=level, mode=mode
                )
                cptcoeff = torch.cat(ptcoeff, -1)[0, :]
                cpycoeff = np.concatenate(pycoeff, -1)
                err = np.mean(np.abs(cpycoeff - cptcoeff.numpy()))
                # level=None requests the maximum decomposition depth.
                level_label = "max" if level is None else level
                print(
                    f"{wavelet_string} coefficient error scale {level_label}:",
                    err,
                    ["ok" if err < 1e-4 else "failed!"],
                    "mode",
                    mode,
                )
                assert np.allclose(cptcoeff.numpy(), cpycoeff, atol=1e-6)

            # The reconstruction checks use fixed levels and do not depend on
            # ``level`` above, so run them once per wavelet/mode combination
            # instead of repeating identical work for every level.
            for rec_level in [3, 4]:
                res = waverec(
                    wavedec(mackey_data_1, wavelet, level=rec_level, mode=mode),
                    wavelet,
                )
                err = torch.mean(torch.abs(mackey_data_1 - res)).numpy()
                print(
                    f"{wavelet_string} reconstruction error scale {rec_level}:",
                    err,
                    ["ok" if err < 1e-4 else "failed!"],
                    "mode",
                    mode,
                )
                assert np.allclose(mackey_data_1.numpy(), res.numpy())
def test_conv_fwt_db5_lvl3():
    """Check a level-3 db5 conv-fwt against pywt, plus its reconstruction.

    Runs once for "reflect" and once for "zero" padding. For each mode the
    level-3 coefficients of batch element 0 must match ``pywt.wavedec``, and
    ``waverec`` must reproduce the input from level-3 and level-4
    decompositions.
    """
    signal = torch.squeeze(
        MackeyGenerator(batch_size=2, tmax=128, delta_t=1, device="cpu")()
    )
    db5 = pywt.Wavelet("db5")
    # (decomposition level, label printed with the reconstruction error)
    reconstruction_cases = (
        (3, "db5 reconstruction error scale 3:"),
        (4, "db5 reconstruction error scale 4:"),
    )

    for mode in ["reflect", "zero"]:
        torch_coeffs = wavedec(signal, db5, level=3, mode=mode)
        numpy_coeffs = pywt.wavedec(signal[0, :].numpy(), db5, level=3, mode=mode)
        flat_torch = torch.cat(torch_coeffs, -1)[0, :].numpy()
        flat_numpy = np.concatenate(numpy_coeffs, -1)

        mean_err = np.mean(np.abs(flat_numpy - flat_torch))
        verdict = "ok" if mean_err < 1e-4 else "failed!"
        print("db5 coefficient error scale 3:", mean_err, [verdict], "mode", mode)
        assert np.allclose(flat_numpy, flat_torch, atol=1e-6)

        for rec_level, label in reconstruction_cases:
            rebuilt = waverec(wavedec(signal, db5, level=rec_level, mode=mode), db5)
            mean_err = torch.mean(torch.abs(signal - rebuilt)).numpy()
            verdict = "ok" if mean_err < 1e-4 else "failed!"
            print(label, mean_err, [verdict], "mode", mode)
            assert np.allclose(signal.numpy(), rebuilt.numpy())
def test_conv_fwt_db2_lvl1():
    """Test a first level db2 conv-fwt against pywt, including reconstruction.

    Uses "reflect" padding on the 16-sample ramp 1.0, 2.0, ..., 16.0.
    """
    # float64 ramp 1.0 ... 16.0 (identical to the explicit literal list).
    data = np.arange(1.0, 17.0)
    wavelet = pywt.Wavelet("db2")
    coeffs = pywt.wavedec(data, wavelet, level=1, mode="reflect")
    coeffs2 = wavedec(torch.from_numpy(data), wavelet, level=1, mode="reflect")
    # Both libraries must produce the same number of coefficient arrays;
    # otherwise the concatenated comparison below would be meaningless.
    assert len(coeffs) == len(coeffs2)

    ccoeffs = np.concatenate(coeffs, -1)
    ccoeffs2 = torch.cat(coeffs2, -1).numpy()
    err = np.mean(np.abs(ccoeffs - ccoeffs2))
    print("db2 coefficient error scale 1:", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(ccoeffs, ccoeffs2, atol=1e-6)

    rec = waverec(coeffs2, wavelet)
    err = np.mean(np.abs(data - rec.numpy()))
    print("db2 reconstruction error scale 1:", err, ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(data, rec.numpy())
def wavelet_analysis(self, x):
    """Compute a 1d-analysis transform.

    A singleton dimension is inserted at axis 1 of ``x`` before calling
    ``wavedec`` with ``self.wavelet`` and ``level=self.scales``. The resulting
    coefficient tensors are concatenated along the last axis.

    Args:
        x (torch.tensor): 2d input tensor.

    Returns:
        torch.tensor: All coefficient tensors concatenated along the last
        axis, in the order returned by ``wavedec`` (documented as 2d; confirm
        against ``wavedec``'s output shape).

    Raises:
        AssertionError: If the per-level coefficient lengths differ from
            ``self.coefficient_len_lst`` reversed.
    """
    c_lst = wavedec(x.unsqueeze(1), self.wavelet, level=self.scales)
    shape_lst = [c_el.shape[-1] for c_el in c_lst]
    c_tensor = torch.cat(c_lst, -1)
    # NOTE(review): coefficient_len_lst appears to be stored in the reverse
    # order of wavedec's output list, hence the [::-1]; confirm where it is
    # built.
    assert (
        shape_lst == self.coefficient_len_lst[::-1]
    ), "Wavelet shape assumptions false. This is a bug."
    return c_tensor
def test_conv_fwt_haar_lvl2():
    """Compare a level-2 Haar conv-fwt and its inverse with pywt.

    The input is the float64 ramp 1.0, 2.0, ..., 16.0. Coefficients and the
    reconstructed signal must match the reference within ``np.allclose``.
    """
    signal = np.arange(1.0, 17.0)
    haar = pywt.Wavelet("haar")

    reference = pywt.wavedec(signal, haar, level=2)
    ours = wavedec(torch.from_numpy(signal), haar, level=2)
    assert len(reference) == len(ours)

    flat_reference = np.concatenate(reference)
    flat_ours = torch.cat(ours, -1).squeeze().numpy()
    coeff_err = np.mean(np.abs(flat_reference - flat_ours))
    coeff_verdict = "ok" if coeff_err < 1e-6 else "failed!"
    print("haar coefficient error scale 2", coeff_err, [coeff_verdict])
    assert np.allclose(flat_reference, flat_ours)

    rebuilt = waverec(ours, haar).squeeze().numpy()
    rec_err = np.mean(np.abs(signal - rebuilt))
    rec_verdict = "ok" if rec_err < 1e-6 else "failed!"
    print("haar reconstruction error scale 2", rec_err, [rec_verdict])
    assert np.allclose(signal, rebuilt)