def test_compare_czt_fft_dft(debug=False): print("Compare CZT, FFT and DFT") # Create time-domain signal t = np.arange(0, 20e-3 + 1e-10, 1e-4) x = _signal_model(t) dt = t[1] - t[0] fs = 1 / dt # Frequency sweep f = np.fft.fftshift(np.fft.fftfreq(len(t)) * fs) # CZT (defaults to FFT settings) X_czt = np.fft.fftshift(czt.czt(x)) # FFT X_fft = np.fft.fftshift(np.fft.fft(x)) # DFT (defaults to FFT settings) _, X_dft = czt.dft(t, x) # Plot for debugging purposes if debug: plt.figure() plt.title("Imaginary") plt.plot(f, X_czt.imag, label='CZT') plt.plot(f, X_fft.imag, label='FFT', ls='--') plt.plot(f, X_dft.imag, label='DFT', ls='--') plt.legend() plt.figure() plt.title("Real") plt.plot(f, X_czt.real, label='CZT') plt.plot(f, X_fft.real, label='FFT', ls='--') plt.plot(f, X_dft.real, label='DFT', ls='--') plt.legend() plt.figure() plt.title("Absolute") plt.plot(f, np.abs(X_czt), label='CZT') plt.plot(f, np.abs(X_fft), label='FFT', ls='--') plt.plot(f, np.abs(X_dft), label='DFT', ls='--') plt.legend() plt.show() # Compare np.testing.assert_almost_equal(X_czt, X_fft, decimal=12) np.testing.assert_almost_equal(X_czt, X_dft, decimal=12)
def test_czt_to_iczt(debug=False): print("Test CZT -> ICZT") # Create time-domain signal t = np.arange(0, 20e-3, 1e-4) x = _signal_model(t) # CZT (defaults to FFT) X_czt = czt.czt(x) # ICZT x_iczt1 = czt.iczt(X_czt) x_iczt2 = czt.iczt(X_czt, simple=False) # Try unsupported t_method with pytest.raises(ValueError): czt.iczt(X_czt, simple=False, t_method='unsupported_t_method') # Try M != N with pytest.raises(ValueError): czt.iczt(X_czt, simple=False, N=len(X_czt) + 1) # Plot for debugging purposes if debug: plt.figure() plt.title("Imaginary") plt.plot(t * 1e3, x.imag) plt.plot(t * 1e3, x_iczt1.imag) plt.plot(t * 1e3, x_iczt2.imag) plt.figure() plt.title("Real") plt.plot(t * 1e3, x.real) plt.plot(t * 1e3, x_iczt1.real) plt.plot(t * 1e3, x_iczt2.real) plt.show() # Compare np.testing.assert_almost_equal(x, x_iczt1, decimal=12) np.testing.assert_almost_equal(x, x_iczt2, decimal=12)
def test_compare_czt_to_fft(): """Compare CZT to FFT.""" # Time data t = np.arange(0, 20e-3, 1e-4) dt = t[1] - t[0] Fs = 1 / dt N = len(t) # Signal data def model(t): output = (1.0 * np.sin(2 * np.pi * 1e3 * t) + 0.3 * np.sin(2 * np.pi * 2e3 * t) + 0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t) return output x = model(t) # CZT (defaults to FFT) X_czt = czt.czt(x) # FFT X_fft = np.fft.fft(x) # # Debug # import matplotlib.pyplot as plt # plt.figure() # plt.plot(np.abs(X_czt), 'k') # plt.plot(np.abs(X_fft), 'r--') # plt.figure() # plt.plot(X_czt.real, 'k') # plt.plot(X_fft.real, 'r--') # plt.figure() # plt.plot(X_czt.imag, 'k') # plt.plot(X_fft.imag, 'r--') # plt.show() # Compare np.testing.assert_almost_equal(X_czt, X_fft, decimal=12)
def test_iczt(): """Test inverse CZT.""" # Time data t = np.arange(0, 20e-3, 1e-4) dt = t[1] - t[0] Fs = 1 / dt N = len(t) # Signal data def model(t): output = (1.0 * np.sin(2 * np.pi * 1e3 * t) + 0.3 * np.sin(2 * np.pi * 2e3 * t) + 0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t) return output x = model(t) # CZT (defaults to FFT) X_czt = czt.czt(x) # ICZT x_iczt = czt.iczt(X_czt) # # Debug # import matplotlib.pyplot as plt # plt.figure() # plt.plot(x.real) # plt.plot(x_iczt.real) # plt.figure() # plt.plot(x.imag) # plt.plot(x_iczt.imag) # plt.show() # Compare np.testing.assert_almost_equal(x, x_iczt, decimal=12)
"""Benchmark czt.iczt function.""" import numpy as np import czt import perfplot def model(t): output = (1.0 * np.sin(2 * np.pi * 1e3 * t) + 0.3 * np.sin(2 * np.pi * 2e3 * t) + 0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t) return output perfplot.show( setup=lambda n: czt.czt(model(np.linspace(0, 20e-3, n))), kernels=[ lambda a: czt.iczt(a, simple=True), lambda a: czt.iczt(a, simple=False), ], labels=["simple=True", "simple=False"], n_range=[10**k for k in range(1, 8)], xlabel="Input length", # equality_check=np.allclose, equality_check=False, target_time_per_measurement=0.1, )
import czt import perfplot def model(t): """Signal model.""" output = (1.0 * np.sin(2 * np.pi * 1e3 * t) + 0.3 * np.sin(2 * np.pi * 2e3 * t) + 0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t) return output perfplot.show( setup=lambda n: model(np.linspace(0, 20e-3, n)), kernels=[ # lambda a: czt.czt(a, simple=True), lambda a: czt.czt(a, t_method='ce'), lambda a: czt.czt(a, t_method='pd'), # lambda a: czt.czt(a, t_method='mm'), lambda a: czt.czt(a, t_method='scipy'), # lambda a: czt.czt(a, t_method='ce', f_method='recursive'), # lambda a: czt.czt(a, t_method='pd', f_method='recursive'), ], # labels=["simple", "ce", "pd", "mm", "scipy", "ce/recursive", "pd/recursive"], labels=["ce", "pd", "scipy"], n_range=[10**k for k in range(1, 7)], xlabel="Input length", equality_check=np.allclose, target_time_per_measurement=0.1, )
def test_compare_different_czt_methods(debug=False): print("Compare different CZT calculation methods") # Create time-domain signal t = np.arange(0, 20e-3, 1e-4) x = _signal_model(t) # Calculate CZT using different methods X_czt1 = czt.czt(x, simple=True) X_czt2 = czt.czt(x, t_method='ce') X_czt3 = czt.czt(x, t_method='pd') X_czt4 = czt.czt(x, t_method='mm') X_czt5 = czt.czt(x, t_method='scipy') X_czt6 = czt.czt(x, t_method='ce', f_method='recursive') X_czt7 = czt.czt(x, t_method='pd', f_method='recursive') # Try unsupported t_method with pytest.raises(ValueError): czt.czt(x, t_method='unsupported_t_method') # Try unsupported f_method with pytest.raises(ValueError): czt.czt(x, t_method='ce', f_method='unsupported_f_method') # Plot for debugging purposes if debug: plt.figure() plt.title("Imaginary component") plt.plot(X_czt1.imag, label="simple") plt.plot(X_czt2.imag, label="ce") plt.plot(X_czt3.imag, label="pd") plt.plot(X_czt4.imag, label="mm") plt.plot(X_czt5.imag, label="scipy") plt.plot(X_czt6.imag, label="ce / recursive") plt.plot(X_czt7.imag, label="pd / recursive") plt.legend() plt.figure() plt.title("Real component") plt.plot(X_czt1.real, label="simple") plt.plot(X_czt2.real, label="ce") plt.plot(X_czt3.real, label="pd") plt.plot(X_czt4.real, label="mm") plt.plot(X_czt5.real, label="scipy") plt.plot(X_czt6.real, label="ce / recursive") plt.plot(X_czt7.real, label="pd / recursive") plt.legend() plt.figure() plt.title("Absolute value") plt.plot(np.abs(X_czt1), label="simple") plt.plot(np.abs(X_czt2), label="ce") plt.plot(np.abs(X_czt3), label="pd") plt.plot(np.abs(X_czt4), label="mm") plt.plot(np.abs(X_czt5), label="scipy") plt.plot(np.abs(X_czt6), label="ce / recursive") plt.plot(np.abs(X_czt7), label="pd / recursive") plt.legend() plt.show() # Compare Toeplitz matrix multiplication methods np.testing.assert_almost_equal(X_czt1, X_czt2, decimal=12) np.testing.assert_almost_equal(X_czt1, X_czt3, decimal=12) np.testing.assert_almost_equal(X_czt1, X_czt4, decimal=12) np.testing.assert_almost_equal(X_czt1, X_czt5, decimal=12) # Compare FFT methods np.testing.assert_almost_equal(X_czt1, X_czt6, decimal=12) np.testing.assert_almost_equal(X_czt1, X_czt7, decimal=12)
def model(t): output = (1.0 * np.sin(2 * np.pi * 1e3 * t) + 0.3 * np.sin(2 * np.pi * 2e3 * t) + 0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t) return output # Create time-domain data t = np.arange(0, 20e-3, 1e-4) dt = t[1] - t[0] Fs = 1 / dt N = len(t) x = model(t) # CZT X = czt.czt(x) # Tests def test1(): czt.iczt(X, simple=True) return def test2(): czt.iczt(X, simple=False) return N = 100 setup = "from __main__ import test1 as test"
def test7(): czt.czt(x, t_method='pd', f_method='recursive') return
def test6(): czt.czt(x, t_method='ce', f_method='recursive') return
def test5(): czt.czt(x, t_method='scipy') return
def test4(): czt.czt(x, t_method='mm') return
def test3(): czt.czt(x, t_method='pd') return
def test2(): czt.czt(x, t_method='ce') return
def test1(): czt.czt(x, simple=True) return
def test_compare_czt_methods(): """Compare different CZT calculation methods.""" # Time data t = np.arange(0, 20e-3, 1e-4) dt = t[1] - t[0] Fs = 1 / dt N = len(t) # Signal data def model(t): output = (1.0 * np.sin(2 * np.pi * 1e3 * t) + 0.3 * np.sin(2 * np.pi * 2e3 * t) + 0.1 * np.sin(2 * np.pi * 3e3 * t)) * np.exp(-1e3 * t) return output x = model(t) # Calculate CZT using different methods X_czt1 = czt.czt_simple(x) X_czt2 = czt.czt(x, t_method='ce') X_czt3 = czt.czt(x, t_method='pd') X_czt4 = czt.czt(x, t_method='mm') X_czt5 = czt.czt(x, t_method='ce', f_method='fast') X_czt6 = czt.czt(x, t_method='pd', f_method='fast') X_czt7 = czt.czt(x, t_method='mm', f_method='fast') # # Debug # import matplotlib.pyplot as plt # plt.figure() # plt.plot(np.abs(X_czt1)) # plt.plot(np.abs(X_czt2)) # plt.plot(np.abs(X_czt3)) # plt.plot(np.abs(X_czt4)) # plt.plot(np.abs(X_czt5)) # plt.plot(np.abs(X_czt6)) # plt.plot(np.abs(X_czt7)) # plt.figure() # plt.plot(X_czt1.real) # plt.plot(X_czt2.real) # plt.plot(X_czt3.real) # plt.plot(X_czt4.real) # plt.plot(X_czt5.real) # plt.plot(X_czt6.real) # plt.plot(X_czt7.real) # plt.figure() # plt.plot(X_czt1.imag) # plt.plot(X_czt2.imag) # plt.plot(X_czt3.imag) # plt.plot(X_czt4.imag) # plt.plot(X_czt5.imag) # plt.plot(X_czt6.imag) # plt.plot(X_czt7.imag) # plt.show() # Compare Toeplitz matrix multiplication methods np.testing.assert_almost_equal(X_czt1, X_czt2, decimal=12) np.testing.assert_almost_equal(X_czt1, X_czt3, decimal=12) np.testing.assert_almost_equal(X_czt1, X_czt4, decimal=12) # Compare FFT methods np.testing.assert_almost_equal(X_czt1, X_czt5, decimal=1) np.testing.assert_almost_equal(X_czt1, X_czt6, decimal=1) np.testing.assert_almost_equal(X_czt1, X_czt7, decimal=1)