Esempio n. 1
0
def test_conv_fwt():
    """Test multiple convolution fwt, for various levels and padding options.

    For every combination of decomposition level (1, 2, 3 and the maximum
    level chosen by ``level=None``), Daubechies wavelet db1-db5 and padding
    mode ("reflect", "zero"), the torch ``wavedec`` coefficients of the first
    batch element are compared against ``pywt.wavedec``. Separately, a
    ``wavedec``/``waverec`` round trip at levels 3 and 4 must reproduce the
    input signal for every wavelet and padding mode.
    """
    generator = MackeyGenerator(batch_size=2,
                                tmax=128,
                                delta_t=1,
                                device="cpu")

    mackey_data_1 = torch.squeeze(generator())
    # Build each wavelet once; it depends only on its name.
    wavelets = {
        name: pywt.Wavelet(name)
        for name in ["db1", "db2", "db3", "db4", "db5"]
    }
    modes = ["reflect", "zero"]

    # Coefficient comparison against pywt for every level/wavelet/mode.
    for level in [1, 2, 3, None]:
        for wavelet_string, wavelet in wavelets.items():
            for mode in modes:
                ptcoeff = wavedec(mackey_data_1,
                                  wavelet,
                                  level=level,
                                  mode=mode)
                # pywt works on 1d input, so compare only batch element 0.
                pycoeff = pywt.wavedec(mackey_data_1[0, :].numpy(),
                                       wavelet,
                                       level=level,
                                       mode=mode)
                cptcoeff = torch.cat(ptcoeff, -1)[0, :]
                cpycoeff = np.concatenate(pycoeff, -1)
                err = np.mean(np.abs(cpycoeff - cptcoeff.numpy()))
                print(
                    f"{wavelet_string} coefficient error level {level}:",
                    err,
                    ["ok" if err < 1e-4 else "failed!"],
                    "mode",
                    mode,
                )
                assert np.allclose(cptcoeff.numpy(), cpycoeff, atol=1e-6)

    # Round-trip reconstruction does not depend on the comparison level
    # above, so it is checked once per wavelet/mode instead of per level.
    for wavelet_string, wavelet in wavelets.items():
        for mode in modes:
            for rec_level in (3, 4):
                res = waverec(
                    wavedec(mackey_data_1, wavelet, level=rec_level,
                            mode=mode),
                    wavelet,
                )
                err = torch.mean(torch.abs(mackey_data_1 - res)).numpy()
                print(
                    f"{wavelet_string} reconstruction error level "
                    f"{rec_level}:",
                    err,
                    ["ok" if err < 1e-4 else "failed!"],
                    "mode",
                    mode,
                )
                assert np.allclose(mackey_data_1.numpy(), res.numpy())
Esempio n. 2
0
def test_conv_fwt_db5_lvl3():
    """Check a level-3 db5 conv-fwt against pywt and verify reconstruction.

    For both "reflect" and "zero" padding, the concatenated level-3
    coefficients of batch element 0 must match ``pywt.wavedec``. The
    ``wavedec``/``waverec`` round trip at levels 3 and 4 must reproduce the
    input.
    """
    mackey_data_1 = torch.squeeze(
        MackeyGenerator(batch_size=2, tmax=128, delta_t=1, device="cpu")())
    wavelet = pywt.Wavelet("db5")
    # (decomposition level, diagnostic label) for the round-trip checks.
    round_trips = (
        (3, "db5 reconstruction error scale 3:"),
        (4, "db5 reconstruction error scale 4:"),
    )

    for mode in ["reflect", "zero"]:
        torch_coeffs = wavedec(mackey_data_1, wavelet, level=3, mode=mode)
        ref_coeffs = pywt.wavedec(mackey_data_1[0, :].numpy(), wavelet,
                                  level=3, mode=mode)
        flat_torch = torch.cat(torch_coeffs, -1)[0, :]
        flat_ref = np.concatenate(ref_coeffs, -1)

        coeff_err = np.mean(np.abs(flat_ref - flat_torch.numpy()))
        coeff_verdict = ["ok" if coeff_err < 1e-4 else "failed!"]
        print("db5 coefficient error scale 3:", coeff_err, coeff_verdict,
              "mode", mode)
        assert np.allclose(flat_ref, flat_torch.numpy(), atol=1e-6)

        for rec_level, label in round_trips:
            reconstruction = waverec(
                wavedec(mackey_data_1, wavelet, level=rec_level, mode=mode),
                wavelet)
            rec_err = torch.mean(
                torch.abs(mackey_data_1 - reconstruction)).numpy()
            rec_verdict = ["ok" if rec_err < 1e-4 else "failed!"]
            print(label, rec_err, rec_verdict, "mode", mode)
            assert np.allclose(mackey_data_1.numpy(), reconstruction.numpy())
Esempio n. 3
0
def test_conv_fwt_db2_lvl1():
    """Test a first level db2 conv-fwt with reflect padding.

    The torch ``wavedec`` coefficients of the ramp 1.0 ... 16.0 must match
    ``pywt.wavedec`` (atol 1e-6), and ``waverec`` must reconstruct the input.
    """
    data = np.arange(1.0, 17.0)  # float64 ramp 1.0, 2.0, ..., 16.0
    # ------------------------- db2 wavelet tests ----------------------------
    wavelet = pywt.Wavelet("db2")
    coeffs = pywt.wavedec(data, wavelet, level=1, mode="reflect")
    coeffs2 = wavedec(torch.from_numpy(data), wavelet, level=1, mode="reflect")
    ccoeffs = np.concatenate(coeffs, -1)
    ccoeffs2 = torch.cat(coeffs2, -1).numpy()
    err = np.mean(np.abs(ccoeffs - ccoeffs2))
    print("db2 coefficient error scale 1:", err,
          ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(ccoeffs, ccoeffs2, atol=1e-6)
    rec = waverec(coeffs2, wavelet)
    err = np.mean(np.abs(data - rec.numpy()))
    print("db2 reconstruction error scale 1:", err,
          ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(data, rec.numpy())
 def wavelet_analysis(self, x):
     """Compute a 1d-analysis transform.

     Args:
         x (torch.tensor): 2d input tensor. A singleton axis is inserted at
             dim 1 before calling ``wavedec`` (presumably turning
             (batch, time) into (batch, 1, time) -- TODO confirm layout).

     Returns:
         [torch.tensor]: The coefficients of all ``self.scales`` levels,
             concatenated along the last axis in the order ``wavedec``
             returns them. Documented as 2d; the exact shape depends on how
             ``wavedec`` treats the inserted axis -- TODO confirm.

     Raises:
         AssertionError: If the per-level coefficient lengths differ from
             ``self.coefficient_len_lst`` reversed (an internal invariant).
     """
     c_lst = wavedec(x.unsqueeze(1), self.wavelet, level=self.scales)
     shape_lst = [c_el.shape[-1] for c_el in c_lst]
     # coefficient_len_lst is stored in the opposite order to wavedec output.
     assert shape_lst == self.coefficient_len_lst[::-1], (
         "Wavelet shape assumptions false. This is a bug.")
     return torch.cat(c_lst, -1)
Esempio n. 5
0
def test_conv_fwt_haar_lvl2():
    """Compare a two-level Haar conv-fwt with pywt and check the inverse.

    The float64 ramp 1.0 ... 16.0 is decomposed with both ``pywt.wavedec``
    and the torch ``wavedec``. The concatenated coefficients must agree, and
    ``waverec`` must reproduce the ramp.
    """
    data = np.arange(1.0, 17.0)
    haar = pywt.Wavelet("haar")

    reference = pywt.wavedec(data, haar, level=2)
    candidate = wavedec(torch.from_numpy(data), haar, level=2)
    assert len(reference) == len(candidate)

    flat_reference = np.concatenate(reference)
    flat_candidate = torch.cat(candidate, -1).squeeze().numpy()
    coeff_err = np.mean(np.abs(flat_reference - flat_candidate))
    verdict = "ok" if coeff_err < 1e-6 else "failed!"
    print("haar coefficient error scale 2", coeff_err, [verdict])
    assert np.allclose(flat_reference, flat_candidate)

    restored = waverec(candidate, haar).squeeze().numpy()
    rec_err = np.mean(np.abs(data - restored))
    verdict = "ok" if rec_err < 1e-6 else "failed!"
    print("haar reconstruction error scale 2", rec_err, [verdict])
    assert np.allclose(data, restored)