def test_spectrogram_correctness_more(data_format, window_name):
    """Compare the STFT, Magnitude and Phase layers against a librosa reference.

    A reference spectrogram is computed with ``librosa.core.stft`` from the
    mono signal returned by ``get_audio`` and tiled to ``n_ch`` channels. The
    output of a model holding a single ``STFT`` layer is checked against it,
    then ``Magnitude`` is checked against ``np.abs`` of the reference and
    ``Phase`` against ``np.angle`` of the model's own complex output.

    Args:
        data_format: passed to ``get_audio`` and used as both the input and
            output data format of the ``STFT`` layer; the reference is moved
            to (ch, time, freq) when it equals ``'channels_first'``.
        window_name: window name handed to the ``STFT`` layer. For librosa the
            ``'_window'`` substring is removed; a falsy value selects
            ``'hann'``.
    """
    n_fft = 512
    hop_length = 256
    n_ch = 2
    win_length = n_fft

    src_mono, batch_src, input_shape = get_audio(data_format=data_format, n_ch=n_ch)

    def _build_model(extra_layer=None):
        # Model under test: an STFT layer, optionally followed by one more layer.
        layers = [
            STFT(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                window_name=window_name,
                pad_end=False,
                input_data_format=data_format,
                output_data_format=data_format,
                input_shape=input_shape,
                name='stft',
            )
        ]
        if extra_layer is not None:
            layers.append(extra_layer)
        return tensorflow.keras.models.Sequential(layers)

    # Reference spectrogram computed with librosa.
    if window_name:
        librosa_window = window_name.replace('_window', '')
    else:
        librosa_window = 'hann'
    reference = librosa.core.stft(
        src_mono,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        center=False,
        window=librosa_window,
    )
    reference = reference.T  # (time, freq)
    reference = reference[:, :, np.newaxis]  # (time, freq, ch=1)
    reference = np.tile(reference, (1, 1, n_ch))  # (time, freq, ch=n_ch)
    if data_format == 'channels_first':
        reference = reference.transpose(2, 0, 1)  # (ch, time, freq)

    # STFT(): complex output of the first batch item vs. the reference.
    complex_batch = _build_model().predict(batch_src)
    S_complex = complex_batch[0]  # 3d representation
    allclose_complex_numbers(reference, S_complex)

    # Magnitude(): compared with the magnitude of the reference.
    magnitude_batch = _build_model(Magnitude()).predict(batch_src)
    np.testing.assert_allclose(np.abs(reference), magnitude_batch[0], atol=2e-4)

    # Phase(): compared with the angle of the model's own complex output.
    phase_batch = _build_model(Phase()).predict(batch_src)
    allclose_phase(np.angle(S_complex), phase_batch[0])
def test_spectrogram_tflite_correctness(
    n_fft, hop_length, n_ch, data_format, batch_size, win_length, pad_end
):
    """Check that the tflite-compatible layers agree with the regular layers.

    Four comparisons are made on the same batch of audio:

    1. ``STFTTflite`` (run through ``predict_using_tflite``) vs. ``STFT``.
    2. ``MagnitudeTflite`` vs. ``Magnitude``.
    3. ``PhaseTflite(approx_atan_accuracy=500)`` vs.
       ``Phase(approx_atan_accuracy=500)``.
    4. ``PhaseTflite(approx_atan_accuracy=500)`` vs. the default ``Phase()``.

    Args:
        n_fft, hop_length, win_length, pad_end: forwarded to the STFT layers.
        n_ch, data_format, batch_size: forwarded to ``get_audio``;
            ``data_format`` is also used as the input and output data format
            of the STFT layers.
    """

    def _build_model(extra_layer=None, tflite_compatible=False):
        # Both STFT variants take the same keyword arguments, so only the
        # class differs between the tflite and the regular model.
        stft_cls = STFTTflite if tflite_compatible else STFT
        layers = [
            stft_cls(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                window_name=None,
                pad_end=pad_end,
                input_data_format=data_format,
                output_data_format=data_format,
                input_shape=input_shape,
                name='stft',
            )
        ]
        if extra_layer is not None:
            layers.append(extra_layer)
        return tensorflow.keras.models.Sequential(layers)

    src_mono, batch_src, input_shape = get_audio(
        data_format=data_format, n_ch=n_ch, batch_size=batch_size
    )
    # tflite requires a known batch size
    batch_size = batch_src.shape[0]

    # --- STFT() ---
    tflite_stft = _build_model(tflite_compatible=True)
    keras_stft = _build_model(tflite_compatible=False)

    tflite_out = predict_using_tflite(tflite_stft, batch_src)
    # The last axis of the tflite output holds (real, imag); rebuild complex values.
    real_part = tflite_out[..., 0]
    imag_part = tflite_out[..., 1]
    S_complex_tflite = tf.complex(real_part, imag_part)  # (batch, time, freq, chan)
    S_complex = keras_stft.predict(batch_src)
    allclose_complex_numbers(S_complex, S_complex_tflite)

    # --- Magnitude() ---
    tflite_mag = _build_model(MagnitudeTflite(), tflite_compatible=True)
    keras_mag = _build_model(Magnitude(), tflite_compatible=False)
    mag_lite = predict_using_tflite(tflite_mag, batch_src)
    mag_keras = keras_mag.predict(batch_src)
    np.testing.assert_allclose(mag_keras, mag_lite, atol=1e-4)

    # --- approximate Phase(): tflite and non-tflite must agree ---
    tflite_approx_phase = _build_model(
        PhaseTflite(approx_atan_accuracy=500), tflite_compatible=True
    )
    keras_approx_phase = _build_model(
        Phase(approx_atan_accuracy=500), tflite_compatible=False
    )
    approx_phase_lite = predict_using_tflite(tflite_approx_phase, batch_src)
    approx_phase_keras = keras_approx_phase.predict(batch_src, batch_size=batch_size)
    assert_approx_phase(
        approx_phase_lite, approx_phase_keras, atol=1e-2, acceptable_fail_ratio=0.01
    )

    # --- accuracy of the approximate phase vs. the default Phase() ---
    keras_phase = _build_model(Phase(), tflite_compatible=False)
    exact_phase = keras_phase.predict(batch_src, batch_size=batch_size)
    assert_approx_phase(
        approx_phase_lite, exact_phase, atol=1e-2, acceptable_fail_ratio=0.01
    )