Exemple #1
0
def test_fwt_ifwt_mackey_haar():
    """Test the Haar case for a long signal."""
    wavelet = pywt.Wavelet("haar")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = MackeyGenerator(batch_size=2,
                                tmax=512,
                                delta_t=1,
                                device=device)
    wavelet = pywt.Wavelet("haar")
    pt_data = torch.squeeze(generator())

    coeffs_max = pywt.wavedec(pt_data.cpu().numpy(), wavelet, level=9)

    matrix_wavedec = MatrixWavedec(wavelet, 9)
    coeffs_mat_max = matrix_wavedec(pt_data)

    test_lst = []
    for test_no in range(9):
        test_lst.append(
            np.sum(
                np.abs(coeffs_max[test_no] -
                       coeffs_mat_max[test_no].cpu().numpy())) < 0.001)
    print(test_lst)

    # test the inverse fwt.
    matrix_waverec = MatrixWaverec(wavelet)
    reconstructed_data = matrix_waverec(coeffs_mat_max)
    err1 = torch.mean(torch.abs(pt_data - reconstructed_data))
    print("abs ifwt reconstruction error", err1)
    assert np.allclose(pt_data.cpu().numpy(), reconstructed_data.cpu().numpy())
Exemple #2
0
def test_conv_fwt():
    """Test multiple convolution fwt, for various levels and padding options."""
    generator = MackeyGenerator(batch_size=2,
                                tmax=128,
                                delta_t=1,
                                device="cpu")

    mackey_data_1 = torch.squeeze(generator())
    for level in [1, 2, 3, None]:
        for wavelet_string in ["db1", "db2", "db3", "db4", "db5"]:
            for mode in ["reflect", "zero"]:
                wavelet = pywt.Wavelet(wavelet_string)
                ptcoeff = wavedec(mackey_data_1,
                                  wavelet,
                                  level=level,
                                  mode=mode)
                pycoeff = pywt.wavedec(mackey_data_1[0, :].numpy(),
                                       wavelet,
                                       level=level,
                                       mode=mode)
                cptcoeff = torch.cat(ptcoeff, -1)[0, :]
                cpycoeff = np.concatenate(pycoeff, -1)
                err = np.mean(np.abs(cpycoeff - cptcoeff.numpy()))
                print(
                    "db5 coefficient error scale 3:",
                    err,
                    ["ok" if err < 1e-4 else "failed!"],
                    "mode",
                    mode,
                )
                assert np.allclose(cptcoeff.numpy(), cpycoeff, atol=1e-6)

                res = waverec(
                    wavedec(mackey_data_1, wavelet, level=3, mode=mode),
                    wavelet,
                )
                err = torch.mean(torch.abs(mackey_data_1 - res)).numpy()
                print(
                    "db5 reconstruction error scale 3:",
                    err,
                    ["ok" if err < 1e-4 else "failed!"],
                    "mode",
                    mode,
                )
                assert np.allclose(mackey_data_1.numpy(), res.numpy())

                res = waverec(
                    wavedec(mackey_data_1, wavelet, level=4, mode=mode),
                    wavelet,
                )
                err = torch.mean(torch.abs(mackey_data_1 - res)).numpy()
                print(
                    "db5 reconstruction error scale 4:",
                    err,
                    ["ok" if err < 1e-4 else "failed!"],
                    "mode",
                    mode,
                )
                assert np.allclose(mackey_data_1.numpy(), res.numpy())
Exemple #3
0
def test_conv_fwt_db5_lvl3():
    """Test a third level db5 conv-fwt."""
    generator = MackeyGenerator(batch_size=2,
                                tmax=128,
                                delta_t=1,
                                device="cpu")

    mackey_data_1 = torch.squeeze(generator())
    wavelet = pywt.Wavelet("db5")
    for mode in ["reflect", "zero"]:
        ptcoeff = wavedec(mackey_data_1, wavelet, level=3, mode=mode)
        pycoeff = pywt.wavedec(mackey_data_1[0, :].numpy(),
                               wavelet,
                               level=3,
                               mode=mode)
        cptcoeff = torch.cat(ptcoeff, -1)[0, :]
        cpycoeff = np.concatenate(pycoeff, -1)
        err = np.mean(np.abs(cpycoeff - cptcoeff.numpy()))
        print(
            "db5 coefficient error scale 3:",
            err,
            ["ok" if err < 1e-4 else "failed!"],
            "mode",
            mode,
        )
        assert np.allclose(cpycoeff, cptcoeff.numpy(), atol=1e-6)

        res = waverec(wavedec(mackey_data_1, wavelet, level=3, mode=mode),
                      wavelet)
        err = torch.mean(torch.abs(mackey_data_1 - res)).numpy()
        print(
            "db5 reconstruction error scale 3:",
            err,
            ["ok" if err < 1e-4 else "failed!"],
            "mode",
            mode,
        )
        assert np.allclose(mackey_data_1.numpy(), res.numpy())
        res = waverec(wavedec(mackey_data_1, wavelet, level=4, mode=mode),
                      wavelet)
        err = torch.mean(torch.abs(mackey_data_1 - res)).numpy()
        print(
            "db5 reconstruction error scale 4:",
            err,
            ["ok" if err < 1e-4 else "failed!"],
            "mode",
            mode,
        )
        assert np.allclose(mackey_data_1.numpy(), res.numpy())
Exemple #4
0
def test_fwt_ifwt_mackey_db2():
    """Test the db2 case for a long signal."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    wavelet = pywt.Wavelet("db2")
    generator = MackeyGenerator(batch_size=2,
                                tmax=512,
                                delta_t=1,
                                device=device)
    pt_data = torch.squeeze(generator()).cpu()
    matrix_wavedec = MatrixWavedec(wavelet, 4)
    coeffs_mat_max = matrix_wavedec(pt_data)
    matrix_waverec = MatrixWaverec(wavelet)
    reconstructed_data = matrix_waverec(coeffs_mat_max)
    err = torch.mean(torch.abs(pt_data - reconstructed_data))
    print("reconstruction error:", err)
    assert err < 1e-6
Exemple #5
0
def test_orth_wavelet():
    """Test an orthogonal wavelet fwt."""
    generator = MackeyGenerator(batch_size=2, tmax=64, delta_t=1, device="cpu")

    mackey_data_1 = torch.squeeze(generator())
    # orthogonal wavelet object test
    wavelet = pywt.Wavelet("db5")
    orthwave = SoftOrthogonalWavelet(
        torch.tensor(wavelet.rec_lo),
        torch.tensor(wavelet.rec_hi),
        torch.tensor(wavelet.dec_lo),
        torch.tensor(wavelet.dec_hi),
    )
    res = waverec(wavedec(mackey_data_1, orthwave), orthwave)
    err = torch.mean(torch.abs(mackey_data_1 - res.detach())).numpy()
    print("orth reconstruction error scale 4:", err,
          ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(res.detach().numpy(), mackey_data_1.numpy())
Exemple #6
0
def test_conv_fwt_haar_lvl4():
    """Test a fourth level Haar wavelet conv-fwt."""
    generator = MackeyGenerator(batch_size=2, tmax=64, delta_t=1, device="cpu")
    mackey_data_1 = torch.squeeze(generator())
    wavelet = pywt.Wavelet("haar")
    ptcoeff = wavedec(mackey_data_1, wavelet, level=4)
    pycoeff = pywt.wavedec(mackey_data_1[0, :].numpy(), wavelet, level=4)
    ptwt_coeff = torch.cat(ptcoeff, -1)[0, :].numpy()
    pywt_coeff = np.concatenate(pycoeff)
    err = np.mean(np.abs(pywt_coeff - ptwt_coeff))
    print("haar coefficient error scale 4:", err,
          ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(pywt_coeff, ptwt_coeff, atol=1e-06)

    reconstruction = waverec(wavedec(mackey_data_1, wavelet), wavelet)
    err = torch.mean(torch.abs(mackey_data_1 - reconstruction)).numpy()
    print("haar reconstruction error scale 4:", err,
          ["ok" if err < 1e-4 else "failed!"])
    assert np.allclose(reconstruction.numpy(), mackey_data_1.numpy())