# Example #1
# 0
def test_spectrogram_correctness_more(data_format, window_name):
    """Compare kapre's STFT, Magnitude and Phase layers with a librosa reference.

    The reference is ``librosa.core.stft`` (``center=False``) of the mono source.
    It is transposed to (time, freq), replicated over ``n_ch`` channels, and moved
    to (ch, time, freq) when ``data_format == 'channels_first'``. Every kapre
    output is taken from the first item of the predicted batch.

    Args:
        data_format: passed as both ``input_data_format`` and
            ``output_data_format`` of the kapre ``STFT`` layer.
        window_name: kapre window name. The librosa window is this name with
            ``'_window'`` removed; a falsy value maps to librosa's ``'hann'``.
    """
    n_fft = 512
    hop_length = 256
    n_ch = 2
    # The window spans the whole FFT frame in this test.
    win_length = n_fft

    src_mono, batch_src, input_shape = get_audio(data_format=data_format, n_ch=n_ch)

    def _build_model(extra_layer=None):
        """Return a Sequential model: kapre STFT, optionally followed by `extra_layer`."""
        model = tensorflow.keras.models.Sequential()
        model.add(
            STFT(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                window_name=window_name,
                pad_end=False,
                input_data_format=data_format,
                output_data_format=data_format,
                input_shape=input_shape,
                name='stft',
            )
        )
        if extra_layer is not None:
            model.add(extra_layer)
        return model

    # Reference spectrogram from librosa, transposed to (time, freq).
    librosa_window = window_name.replace('_window', '') if window_name else 'hann'
    stft_ref = librosa.core.stft(
        src_mono,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        center=False,
        window=librosa_window,
    ).T

    # Add a channel axis and copy the mono reference into every channel:
    # (time, freq, n_ch).
    S_ref = np.tile(np.expand_dims(stft_ref, axis=2), [1, 1, n_ch])
    if data_format == 'channels_first':
        S_ref = np.transpose(S_ref, (2, 0, 1))  # -> (ch, time, freq)

    # Complex STFT output must match librosa.
    complex_model = _build_model()
    S_complex = complex_model.predict(batch_src)[0]
    allclose_complex_numbers(S_ref, S_complex)

    # Magnitude() must match |S_ref|.
    magnitude_model = _build_model(Magnitude())
    S_magnitude = magnitude_model.predict(batch_src)[0]
    np.testing.assert_allclose(np.abs(S_ref), S_magnitude, atol=2e-4)

    # Phase() is checked against the angle of kapre's own complex output.
    phase_model = _build_model(Phase())
    S_phase = phase_model.predict(batch_src)[0]
    allclose_phase(np.angle(S_complex), S_phase)
# Example #2
# 0
def test_spectrogram_tflite_correctness(
    n_fft, hop_length, n_ch, data_format, batch_size, win_length, pad_end
):
    """Check that kapre's TFLite-compatible layers agree with the regular TF layers.

    Each check builds two models with the same STFT settings. The first uses
    ``STFTTflite`` and is run through ``predict_using_tflite``; the second uses
    ``STFT`` and is run with ``Model.predict``. The checks cover:

    1. the raw STFT (TFLite real/imag pair recombined into complex values);
    2. ``MagnitudeTflite`` vs ``Magnitude``;
    3. ``PhaseTflite`` vs ``Phase``, both with ``approx_atan_accuracy=500``;
    4. the TFLite approximate phase vs the exact ``Phase()``.
    """

    def _make_model(next_layer=None, tflite_compatible=False):
        """Build Sequential[STFTTflite or STFT, optionally followed by `next_layer`]."""
        stft_cls = STFTTflite if tflite_compatible else STFT
        model = tensorflow.keras.models.Sequential()
        model.add(
            stft_cls(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                window_name=None,
                pad_end=pad_end,
                input_data_format=data_format,
                output_data_format=data_format,
                input_shape=input_shape,
                name='stft',
            )
        )
        if next_layer is not None:
            model.add(next_layer)
        return model

    src_mono, batch_src, input_shape = get_audio(
        data_format=data_format, n_ch=n_ch, batch_size=batch_size
    )
    # TFLite needs a concrete batch size, so take it from the generated batch.
    batch_size = batch_src.shape[0]

    # Tolerances shared by both approximate-phase comparisons.
    phase_tolerance = dict(atol=1e-2, acceptable_fail_ratio=0.01)

    # --- STFT ---
    stft_model_tflite = _make_model(tflite_compatible=True)
    stft_model = _make_model(tflite_compatible=False)
    raw_lite = predict_using_tflite(stft_model_tflite, batch_src)
    # The TFLite output stores real and imaginary parts on the last axis;
    # combine them into one complex tensor.
    S_complex_tflite = tf.complex(raw_lite[..., 0], raw_lite[..., 1])
    S_complex = stft_model.predict(batch_src)
    allclose_complex_numbers(S_complex, S_complex_tflite)

    # --- Magnitude ---
    mag_model_lite = _make_model(MagnitudeTflite(), tflite_compatible=True)
    mag_model = _make_model(Magnitude(), tflite_compatible=False)
    mag_lite = predict_using_tflite(mag_model_lite, batch_src)
    mag_tf = mag_model.predict(batch_src)
    np.testing.assert_allclose(mag_tf, mag_lite, atol=1e-4)

    # --- Approximate phase: TFLite vs TF with the same approximation ---
    approx_phase_model_lite = _make_model(
        PhaseTflite(approx_atan_accuracy=500), tflite_compatible=True
    )
    approx_phase_model = _make_model(
        Phase(approx_atan_accuracy=500), tflite_compatible=False
    )
    approx_phase_lite = predict_using_tflite(approx_phase_model_lite, batch_src)
    approx_phase_tf = approx_phase_model.predict(batch_src, batch_size=batch_size)
    assert_approx_phase(approx_phase_lite, approx_phase_tf, **phase_tolerance)

    # --- Accuracy of the TFLite approximate phase vs the exact Phase() ---
    exact_phase_model = _make_model(Phase(), tflite_compatible=False)
    exact_phase = exact_phase_model.predict(batch_src, batch_size=batch_size)
    assert_approx_phase(approx_phase_lite, exact_phase, **phase_tolerance)