def test_boundary_transform_1d(wavelet_str, data, level, boundary):
    """Check boundary-aware matrix FWT round trips against pywt.

    The signal is decomposed with ``MatrixWavedec`` and rebuilt with
    ``MatrixWaverec`` using the requested ``boundary`` treatment. The result
    must match a zero-padded pywt ``wavedec``/``waverec`` round trip. When
    neither matrix transform reports padding, the product of the sparse
    inverse and forward operators must also equal the identity matrix.
    """
    signal = torch.from_numpy(data.astype(np.float64))
    wavelet = pywt.Wavelet(wavelet_str)

    forward = MatrixWavedec(wavelet, level=level, boundary=boundary)
    coefficients = forward(signal)
    inverse = MatrixWaverec(wavelet, boundary=boundary)
    reconstruction = inverse(coefficients).numpy()

    reference = pywt.waverec(
        pywt.wavedec(signal.numpy(), wavelet, mode="zero"), wavelet
    )
    error = np.abs(reference - reconstruction).sum()
    print(
        f"wavelet: {wavelet_str},",
        f"level: {level},",
        f"shape: {data.shape[-1]},",
        f"error {error:2.2e}",
    )
    assert np.allclose(reconstruction, reference)

    # Skip the operator check if either transform reports that it padded.
    if forward.padded or inverse.padded:
        return

    # With no padding, the sparse inverse times the sparse forward operator
    # is expected to reproduce the identity.
    operator_product = torch.sparse.mm(
        inverse.sparse_ifwt_operator,
        forward.sparse_fwt_operator,
    )
    assert np.allclose(
        operator_product.to_dense().numpy(),
        np.eye(operator_product.shape[0]),
    )
Example #2
0
def test_fwt_ifwt_level_1():
    """Test the single-level Haar case in both directions.

    The forward matrix transform must reproduce the ``pywt.dwt``
    approximation and detail coefficients, and ``MatrixWaverec`` must
    recover the input signal from the matrix coefficients.
    """
    wavelet = pywt.Wavelet("haar")
    data2 = np.arange(1.0, 17.0)  # the samples 1.0, 2.0, ..., 16.0 (float64)

    # level 1
    coeffs = pywt.dwt(data2, wavelet)
    print(coeffs[0], coeffs[1])
    matrix_wavedec = MatrixWavedec(wavelet, 1)
    coeffsmat1 = matrix_wavedec(torch.from_numpy(data2))
    err1 = np.mean(np.abs(coeffs[0] - coeffsmat1[0].squeeze().numpy()))
    err2 = np.mean(np.abs(coeffs[1] - coeffsmat1[1].squeeze().numpy()))
    # Report pass/fail against the same threshold the assertions use
    # (previously the print used 1e-5 while the asserts used 1e-6).
    print(err1 < 1e-6, err2 < 1e-6)
    assert err1 < 1e-6
    assert err2 < 1e-6

    # The test name promises an fwt/ifwt round trip, so check the inverse too.
    matrix_waverec = MatrixWaverec(wavelet)
    reconstructed_data = matrix_waverec(coeffsmat1)
    assert np.allclose(data2, reconstructed_data.numpy())
Example #3
0
def test_fwt_ifwt_mackey_haar():
    """Test the Haar case for a long signal.

    A batch of two 512-step signals from ``MackeyGenerator`` (created on the
    GPU when one is available) is decomposed to level 9. Every coefficient
    array must match pywt, and the matrix inverse must reconstruct the input.
    """
    wavelet = pywt.Wavelet("haar")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = MackeyGenerator(batch_size=2,
                                tmax=512,
                                delta_t=1,
                                device=device)
    pt_data = torch.squeeze(generator())

    coeffs_max = pywt.wavedec(pt_data.cpu().numpy(), wavelet, level=9)

    matrix_wavedec = MatrixWavedec(wavelet, 9)
    coeffs_mat_max = matrix_wavedec(pt_data)

    # A level-9 decomposition yields 10 arrays (one approximation plus nine
    # detail levels). Compare every one of them, not only the first nine, and
    # actually fail on a mismatch instead of just printing the result.
    assert len(coeffs_mat_max) == len(coeffs_max)
    test_lst = [
        np.sum(np.abs(coeff_pywt - coeff_mat.cpu().numpy())) < 0.001
        for coeff_pywt, coeff_mat in zip(coeffs_max, coeffs_mat_max)
    ]
    print(test_lst)
    assert all(test_lst)

    # test the inverse fwt.
    matrix_waverec = MatrixWaverec(wavelet)
    reconstructed_data = matrix_waverec(coeffs_mat_max)
    err1 = torch.mean(torch.abs(pt_data - reconstructed_data))
    print("abs ifwt reconstruction error", err1)
    assert np.allclose(pt_data.cpu().numpy(), reconstructed_data.cpu().numpy())
Example #4
0
def test_fwt_ifwt_mackey_db2():
    """Round-trip a long batched signal through a level-4 db2 matrix FWT.

    The input comes from ``MackeyGenerator`` (on the GPU if available) and is
    moved to the CPU before transforming. The mean absolute reconstruction
    error must stay below 1e-6.
    """
    compute_device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    db2 = pywt.Wavelet("db2")
    generator = MackeyGenerator(
        batch_size=2,
        tmax=512,
        delta_t=1,
        device=compute_device,
    )
    signal = torch.squeeze(generator()).cpu()

    coefficients = MatrixWavedec(db2, 4)(signal)
    restored = MatrixWaverec(db2)(coefficients)

    mean_abs_err = torch.abs(signal - restored).mean()
    print("reconstruction error:", mean_abs_err)
    assert mean_abs_err < 1e-6
Example #5
0
def test_fwt_ifwt_level_3():
    """Test the Haar level 3 case.

    Each of the four coefficient arrays from ``MatrixWavedec`` must match
    ``pywt.wavedec`` to within a mean absolute error of 1e-6, and
    ``MatrixWaverec`` must reconstruct the 16-sample input.
    """
    wavelet = pywt.Wavelet("haar")
    data2 = np.arange(1.0, 17.0)  # the samples 1.0, 2.0, ..., 16.0 (float64)
    tolerance = 1e-6

    reference = pywt.wavedec(data2, wavelet, level=3)
    computed = MatrixWavedec(wavelet, level=3)(torch.from_numpy(data2))

    # One error per coefficient array: approximation first, then details.
    errors = [
        np.mean(np.abs(ref_coeff - computed[pos].squeeze().numpy()))
        for pos, ref_coeff in enumerate(reference)
    ]
    print(*(err < tolerance for err in errors))

    for err in errors:
        assert err < tolerance

    rebuilt = MatrixWaverec(wavelet)(computed)
    mean_abs_err = torch.abs(torch.from_numpy(data2) - rebuilt).mean()
    print("abs ifwt 3 reconstruction error", mean_abs_err)
    assert np.allclose(data2, rebuilt.numpy())
Example #6
0
def test_fwt_ifwt_level_2():
    """Test the Haar level two case against zero-padded pywt coefficients.

    With 18 input samples, the first Haar level yields 9 approximation
    coefficients. This presumably exercises how ``MatrixWavedec`` handles
    odd intermediate lengths. Each of the three coefficient arrays must match
    to within a mean absolute error of 1e-6.
    """
    wavelet = pywt.Wavelet("haar")
    data2 = np.arange(1.0, 19.0)  # the samples 1.0, 2.0, ..., 18.0 (float64)

    reference = pywt.wavedec(data2, wavelet, level=2, mode="zero")
    computed = MatrixWavedec(wavelet, 2)(torch.from_numpy(data2))

    # Compute each comparison once and reuse it for both output and asserts.
    errors = [
        np.mean(np.abs(ref_coeff - computed[pos].squeeze().numpy()))
        for pos, ref_coeff in enumerate(reference)
    ]
    print(*(err < 1e-6 for err in errors))

    for err in errors:
        assert err < 1e-6