def test_dft_mft(use_fallback_dft, norm):
    """Check that ``SFA._mft`` agrees with a direct DFT on the same windows.

    The MFT output is compared against either the fallback
    ``_discrete_fourier_transform`` or ``_fast_fourier_transform`` (selected by
    ``use_fallback_dft``). It is compared first for a single window spanning
    the whole first series, then for every sliding window of length 140.

    Parameters
    ----------
    use_fallback_dft : bool
        Passed to ``SFA``; also selects which DFT routine the MFT is
        compared with.
    norm : bool
        Passed to ``SFA`` and, on the fallback path, to
        ``_discrete_fourier_transform``.
    """
    # load training data
    X, y = load_gunpoint(split="train", return_X_y=True)
    X_tab = from_nested_to_2d_array(X, return_numpy=True)
    series = X_tab[0]
    word_length = 6
    alphabet_size = 4

    # Single DFT transformation: one window covering the whole series.
    window_size = np.shape(X_tab)[1]
    p = SFA(
        word_length=word_length,
        alphabet_size=alphabet_size,
        window_size=window_size,
        norm=norm,
        use_fallback_dft=use_fallback_dft,
    ).fit(X, y)

    if use_fallback_dft:
        dft = p._discrete_fourier_transform(series, word_length, norm, 1, True)
    else:
        dft = p._fast_fourier_transform(series)
    mft = p._mft(series)

    # Compare the absolute difference: a one-sided ``mft - dft < tol`` also
    # passes whenever mft is far *below* dft, which is not agreement.
    assert (np.abs(mft - dft) < 0.0001).all()

    # Windowed DFT transformation: every sliding window of length 140.
    window_size = 140
    p = SFA(
        word_length=word_length,
        alphabet_size=alphabet_size,
        window_size=window_size,
        norm=norm,
        use_fallback_dft=use_fallback_dft,
    ).fit(X, y)
    mft = p._mft(series)

    # Check the output shape first, so a wrong window count fails with a clear
    # assertion rather than an IndexError inside the loop below.
    n_windows = len(series) - window_size + 1
    assert len(mft) == n_windows
    assert len(mft[0]) == word_length

    for i in range(n_windows):
        window = series[i : i + window_size]
        if use_fallback_dft:
            dft = p._discrete_fourier_transform(window, word_length, norm, 1, True)
        else:
            dft = p._fast_fourier_transform(window)
        assert (np.abs(mft[i] - dft) < 0.001).all()
# NOTE(review): this reuses the name ``test_dft_mft``, so at module level it
# replaces the parametrized ``test_dft_mft(use_fallback_dft, norm)`` defined
# earlier in this file, and pytest collects only this one. Rename or remove
# one of the two definitions.
def test_dft_mft():
    """Check that ``SFA._mft`` agrees with ``_discrete_fourier_transform``.

    The comparison is done first for one window spanning the whole first
    series. It is then repeated, with ``norm`` both True and False, for every
    sliding window of length 140.
    """
    # load training data
    X, Y = load_gunpoint(split="train", return_X_y=True)
    X_tab = from_nested_to_2d_array(X, return_numpy=True)
    series = X_tab[0]
    word_length = 6
    alphabet_size = 4

    # Single DFT transformation: one window covering the whole series.
    window_size = np.shape(X_tab)[1]
    p = SFA(
        word_length=word_length,
        alphabet_size=alphabet_size,
        window_size=window_size,
        binning_method="equi-depth",
    ).fit(X, Y)
    dft = p._discrete_fourier_transform(series)
    mft = p._mft(series)

    # Compare the absolute difference: a one-sided ``mft - dft < tol`` also
    # passes whenever mft is far *below* dft, which is not agreement.
    assert (np.abs(mft - dft) < 0.0001).all()

    # Windowed DFT transformation, with and without normalisation.
    window_size = 140
    for norm in [True, False]:
        p = SFA(
            word_length=word_length,
            norm=norm,
            alphabet_size=alphabet_size,
            window_size=window_size,
            binning_method="equi-depth",
        ).fit(X, Y)
        mft = p._mft(series)

        # Check the output shape first, so a wrong window count fails with a
        # clear assertion rather than an IndexError inside the loop below.
        n_windows = len(series) - window_size + 1
        assert len(mft) == n_windows
        assert len(mft[0]) == word_length

        for i in range(n_windows):
            dft_transformed = p._discrete_fourier_transform(
                series[i : i + window_size]
            )
            assert (np.abs(mft[i] - dft_transformed) < 0.001).all()