예제 #1
0
파일: test_czt.py 프로젝트: PuhlUP/CZT
def test_compare_czt_fft_dft(debug=False):
    print("Compare CZT, FFT and DFT")

    # Create time-domain signal
    t = np.arange(0, 20e-3 + 1e-10, 1e-4)
    x = _signal_model(t)
    dt = t[1] - t[0]
    fs = 1 / dt

    # Frequency sweep
    f = np.fft.fftshift(np.fft.fftfreq(len(t)) * fs)

    # CZT (defaults to FFT settings)
    X_czt = np.fft.fftshift(czt.czt(x))

    # FFT
    X_fft = np.fft.fftshift(np.fft.fft(x))

    # DFT (defaults to FFT settings)
    _, X_dft = czt.dft(t, x)

    # Plot for debugging purposes
    if debug:
        plt.figure()
        plt.title("Imaginary")
        plt.plot(f, X_czt.imag, label='CZT')
        plt.plot(f, X_fft.imag, label='FFT', ls='--')
        plt.plot(f, X_dft.imag, label='DFT', ls='--')
        plt.legend()
        plt.figure()
        plt.title("Real")
        plt.plot(f, X_czt.real, label='CZT')
        plt.plot(f, X_fft.real, label='FFT', ls='--')
        plt.plot(f, X_dft.real, label='DFT', ls='--')
        plt.legend()
        plt.figure()
        plt.title("Absolute")
        plt.plot(f, np.abs(X_czt), label='CZT')
        plt.plot(f, np.abs(X_fft), label='FFT', ls='--')
        plt.plot(f, np.abs(X_dft), label='DFT', ls='--')
        plt.legend()
        plt.show()

    # Compare
    np.testing.assert_almost_equal(X_czt, X_fft, decimal=12)
    np.testing.assert_almost_equal(X_czt, X_dft, decimal=12)
예제 #2
0
파일: test_czt.py 프로젝트: PuhlUP/CZT
def test_czt_to_iczt(debug=False):
    print("Test CZT -> ICZT")

    # Create time-domain signal
    t = np.arange(0, 20e-3, 1e-4)
    x = _signal_model(t)

    # CZT (defaults to FFT)
    X_czt = czt.czt(x)

    # ICZT
    x_iczt1 = czt.iczt(X_czt)
    x_iczt2 = czt.iczt(X_czt, simple=False)

    # Try unsupported t_method
    with pytest.raises(ValueError):
        czt.iczt(X_czt, simple=False, t_method='unsupported_t_method')

    # Try M != N
    with pytest.raises(ValueError):
        czt.iczt(X_czt, simple=False, N=len(X_czt) + 1)

    # Plot for debugging purposes
    if debug:
        plt.figure()
        plt.title("Imaginary")
        plt.plot(t * 1e3, x.imag)
        plt.plot(t * 1e3, x_iczt1.imag)
        plt.plot(t * 1e3, x_iczt2.imag)
        plt.figure()
        plt.title("Real")
        plt.plot(t * 1e3, x.real)
        plt.plot(t * 1e3, x_iczt1.real)
        plt.plot(t * 1e3, x_iczt2.real)
        plt.show()

    # Compare
    np.testing.assert_almost_equal(x, x_iczt1, decimal=12)
    np.testing.assert_almost_equal(x, x_iczt2, decimal=12)
예제 #3
0
파일: test_czt.py 프로젝트: z9876/CZT
def test_compare_czt_to_fft():
    """Compare CZT to FFT."""

    # Time data
    t = np.arange(0, 20e-3, 1e-4)
    dt = t[1] - t[0]
    Fs = 1 / dt
    N = len(t)

    # Signal data
    def model(t):
        output = (1.0 * np.sin(2 * np.pi * 1e3 * t) +
                  0.3 * np.sin(2 * np.pi * 2e3 * t) +
                  0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t)
        return output

    x = model(t)

    # CZT (defaults to FFT)
    X_czt = czt.czt(x)

    # FFT
    X_fft = np.fft.fft(x)

    # # Debug
    # import matplotlib.pyplot as plt
    # plt.figure()
    # plt.plot(np.abs(X_czt), 'k')
    # plt.plot(np.abs(X_fft), 'r--')
    # plt.figure()
    # plt.plot(X_czt.real, 'k')
    # plt.plot(X_fft.real, 'r--')
    # plt.figure()
    # plt.plot(X_czt.imag, 'k')
    # plt.plot(X_fft.imag, 'r--')
    # plt.show()

    # Compare
    np.testing.assert_almost_equal(X_czt, X_fft, decimal=12)
예제 #4
0
파일: test_czt.py 프로젝트: z9876/CZT
def test_iczt():
    """Test inverse CZT."""

    # Time data
    t = np.arange(0, 20e-3, 1e-4)
    dt = t[1] - t[0]
    Fs = 1 / dt
    N = len(t)

    # Signal data
    def model(t):
        output = (1.0 * np.sin(2 * np.pi * 1e3 * t) +
                  0.3 * np.sin(2 * np.pi * 2e3 * t) +
                  0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t)
        return output

    x = model(t)

    # CZT (defaults to FFT)
    X_czt = czt.czt(x)

    # ICZT
    x_iczt = czt.iczt(X_czt)

    # # Debug
    # import matplotlib.pyplot as plt
    # plt.figure()
    # plt.plot(x.real)
    # plt.plot(x_iczt.real)
    # plt.figure()
    # plt.plot(x.imag)
    # plt.plot(x_iczt.imag)
    # plt.show()

    # Compare
    np.testing.assert_almost_equal(x, x_iczt, decimal=12)
예제 #5
0
"""Benchmark czt.iczt function."""

import numpy as np
import czt
import perfplot


def model(t):
    output = (1.0 * np.sin(2 * np.pi * 1e3 * t) +
              0.3 * np.sin(2 * np.pi * 2e3 * t) +
              0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t)
    return output


perfplot.show(
    setup=lambda n: czt.czt(model(np.linspace(0, 20e-3, n))),
    kernels=[
        lambda a: czt.iczt(a, simple=True),
        lambda a: czt.iczt(a, simple=False),
    ],
    labels=["simple=True", "simple=False"],
    n_range=[10**k for k in range(1, 8)],
    xlabel="Input length",
    # equality_check=np.allclose,
    equality_check=False,
    target_time_per_measurement=0.1,
)
예제 #6
0
import czt
import perfplot


def model(t):
    """Signal model."""
    output = (1.0 * np.sin(2 * np.pi * 1e3 * t) +
              0.3 * np.sin(2 * np.pi * 2e3 * t) +
              0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t)
    return output


perfplot.show(
    setup=lambda n: model(np.linspace(0, 20e-3, n)),
    kernels=[
        # lambda a: czt.czt(a, simple=True),
        lambda a: czt.czt(a, t_method='ce'),
        lambda a: czt.czt(a, t_method='pd'),
        # lambda a: czt.czt(a, t_method='mm'),
        lambda a: czt.czt(a, t_method='scipy'),
        # lambda a: czt.czt(a, t_method='ce', f_method='recursive'),
        # lambda a: czt.czt(a, t_method='pd', f_method='recursive'),
    ],
    # labels=["simple", "ce", "pd", "mm", "scipy", "ce/recursive", "pd/recursive"],
    labels=["ce", "pd", "scipy"],
    n_range=[10**k for k in range(1, 7)],
    xlabel="Input length",
    equality_check=np.allclose,
    target_time_per_measurement=0.1,
)
예제 #7
0
파일: test_czt.py 프로젝트: PuhlUP/CZT
def test_compare_different_czt_methods(debug=False):
    print("Compare different CZT calculation methods")

    # Create time-domain signal
    t = np.arange(0, 20e-3, 1e-4)
    x = _signal_model(t)

    # Calculate CZT using different methods
    X_czt1 = czt.czt(x, simple=True)
    X_czt2 = czt.czt(x, t_method='ce')
    X_czt3 = czt.czt(x, t_method='pd')
    X_czt4 = czt.czt(x, t_method='mm')
    X_czt5 = czt.czt(x, t_method='scipy')
    X_czt6 = czt.czt(x, t_method='ce', f_method='recursive')
    X_czt7 = czt.czt(x, t_method='pd', f_method='recursive')

    # Try unsupported t_method
    with pytest.raises(ValueError):
        czt.czt(x, t_method='unsupported_t_method')

    # Try unsupported f_method
    with pytest.raises(ValueError):
        czt.czt(x, t_method='ce', f_method='unsupported_f_method')

    # Plot for debugging purposes
    if debug:
        plt.figure()
        plt.title("Imaginary component")
        plt.plot(X_czt1.imag, label="simple")
        plt.plot(X_czt2.imag, label="ce")
        plt.plot(X_czt3.imag, label="pd")
        plt.plot(X_czt4.imag, label="mm")
        plt.plot(X_czt5.imag, label="scipy")
        plt.plot(X_czt6.imag, label="ce / recursive")
        plt.plot(X_czt7.imag, label="pd / recursive")
        plt.legend()
        plt.figure()
        plt.title("Real component")
        plt.plot(X_czt1.real, label="simple")
        plt.plot(X_czt2.real, label="ce")
        plt.plot(X_czt3.real, label="pd")
        plt.plot(X_czt4.real, label="mm")
        plt.plot(X_czt5.real, label="scipy")
        plt.plot(X_czt6.real, label="ce / recursive")
        plt.plot(X_czt7.real, label="pd / recursive")
        plt.legend()
        plt.figure()
        plt.title("Absolute value")
        plt.plot(np.abs(X_czt1), label="simple")
        plt.plot(np.abs(X_czt2), label="ce")
        plt.plot(np.abs(X_czt3), label="pd")
        plt.plot(np.abs(X_czt4), label="mm")
        plt.plot(np.abs(X_czt5), label="scipy")
        plt.plot(np.abs(X_czt6), label="ce / recursive")
        plt.plot(np.abs(X_czt7), label="pd / recursive")
        plt.legend()
        plt.show()

    # Compare Toeplitz matrix multiplication methods
    np.testing.assert_almost_equal(X_czt1, X_czt2, decimal=12)
    np.testing.assert_almost_equal(X_czt1, X_czt3, decimal=12)
    np.testing.assert_almost_equal(X_czt1, X_czt4, decimal=12)
    np.testing.assert_almost_equal(X_czt1, X_czt5, decimal=12)

    # Compare FFT methods
    np.testing.assert_almost_equal(X_czt1, X_czt6, decimal=12)
    np.testing.assert_almost_equal(X_czt1, X_czt7, decimal=12)
예제 #8
0
def model(t):
    output = (1.0 * np.sin(2 * np.pi * 1e3 * t) +
              0.3 * np.sin(2 * np.pi * 2e3 * t) +
              0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t)
    return output


# Create time-domain data
t = np.arange(0, 20e-3, 1e-4)
dt = t[1] - t[0]
Fs = 1 / dt
N = len(t)
x = model(t)

# CZT
X = czt.czt(x)


# Tests
def test1():
    czt.iczt(X, simple=True)
    return


def test2():
    czt.iczt(X, simple=False)
    return


N = 100
setup = "from __main__ import test1 as test"
예제 #9
0
파일: time_czt.py 프로젝트: PuhlUP/CZT
def test7():
    czt.czt(x, t_method='pd', f_method='recursive')
    return
예제 #10
0
파일: time_czt.py 프로젝트: PuhlUP/CZT
def test6():
    czt.czt(x, t_method='ce', f_method='recursive')
    return
예제 #11
0
파일: time_czt.py 프로젝트: PuhlUP/CZT
def test5():
    czt.czt(x, t_method='scipy')
    return
예제 #12
0
파일: time_czt.py 프로젝트: PuhlUP/CZT
def test4():
    czt.czt(x, t_method='mm')
    return
예제 #13
0
파일: time_czt.py 프로젝트: PuhlUP/CZT
def test3():
    czt.czt(x, t_method='pd')
    return
예제 #14
0
파일: time_czt.py 프로젝트: PuhlUP/CZT
def test2():
    czt.czt(x, t_method='ce')
    return
예제 #15
0
파일: time_czt.py 프로젝트: PuhlUP/CZT
def test1():
    czt.czt(x, simple=True)
    return
예제 #16
0
파일: test_czt.py 프로젝트: z9876/CZT
def test_compare_czt_methods():
    """Compare different CZT calculation methods."""

    # Time data
    t = np.arange(0, 20e-3, 1e-4)
    dt = t[1] - t[0]
    Fs = 1 / dt
    N = len(t)

    # Signal data
    def model(t):
        output = (1.0 * np.sin(2 * np.pi * 1e3 * t) +
                  0.3 * np.sin(2 * np.pi * 2e3 * t) +
                  0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t)
        return output

    x = model(t)

    # Calculate CZT using different methods
    X_czt1 = czt.czt_simple(x)
    X_czt2 = czt.czt(x, t_method='ce')
    X_czt3 = czt.czt(x, t_method='pd')
    X_czt4 = czt.czt(x, t_method='mm')
    X_czt5 = czt.czt(x, t_method='ce', f_method='fast')
    X_czt6 = czt.czt(x, t_method='pd', f_method='fast')
    X_czt7 = czt.czt(x, t_method='mm', f_method='fast')

    # # Debug
    # import matplotlib.pyplot as plt
    # plt.figure()
    # plt.plot(np.abs(X_czt1))
    # plt.plot(np.abs(X_czt2))
    # plt.plot(np.abs(X_czt3))
    # plt.plot(np.abs(X_czt4))
    # plt.plot(np.abs(X_czt5))
    # plt.plot(np.abs(X_czt6))
    # plt.plot(np.abs(X_czt7))
    # plt.figure()
    # plt.plot(X_czt1.real)
    # plt.plot(X_czt2.real)
    # plt.plot(X_czt3.real)
    # plt.plot(X_czt4.real)
    # plt.plot(X_czt5.real)
    # plt.plot(X_czt6.real)
    # plt.plot(X_czt7.real)
    # plt.figure()
    # plt.plot(X_czt1.imag)
    # plt.plot(X_czt2.imag)
    # plt.plot(X_czt3.imag)
    # plt.plot(X_czt4.imag)
    # plt.plot(X_czt5.imag)
    # plt.plot(X_czt6.imag)
    # plt.plot(X_czt7.imag)
    # plt.show()

    # Compare Toeplitz matrix multiplication methods
    np.testing.assert_almost_equal(X_czt1, X_czt2, decimal=12)
    np.testing.assert_almost_equal(X_czt1, X_czt3, decimal=12)
    np.testing.assert_almost_equal(X_czt1, X_czt4, decimal=12)

    # Compare FFT methods
    np.testing.assert_almost_equal(X_czt1, X_czt5, decimal=1)
    np.testing.assert_almost_equal(X_czt1, X_czt6, decimal=1)
    np.testing.assert_almost_equal(X_czt1, X_czt7, decimal=1)