def test_save_load():
    """Run `save_load_compare` on every spectrogram layer variant.

    Covered variants: STFT, melspectrogram, log-frequency spectrogram,
    STFT magnitude+phase and STFT magnitude. Each one is passed to
    `save_load_compare` together with the audio batch and the function
    used to compare outputs.
    """
    _, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)
    real_close = np.testing.assert_allclose

    # Each entry is (layer factory, comparison function). Factories are used
    # so that a layer is only built right before it is checked.
    cases = (
        # STFT is the only case compared with `allclose_complex_numbers`
        (
            lambda: STFT(input_shape=input_shape, pad_begin=True),
            allclose_complex_numbers,
        ),
        # melspectrogram
        (
            lambda: get_melspectrogram_layer(
                input_shape=input_shape, return_decibel=True
            ),
            real_close,
        ),
        # log frequency spectrogram
        (
            lambda: get_log_frequency_spectrogram_layer(
                input_shape=input_shape, return_decibel=True
            ),
            real_close,
        ),
        # stft magnitude + phase
        (
            lambda: get_stft_mag_phase(
                input_shape=input_shape, return_decibel=True
            ),
            real_close,
        ),
        # stft magnitude
        (
            lambda: get_stft_magnitude_layer(input_shape=input_shape),
            real_close,
        ),
    )

    for build_layer, compare in cases:
        save_load_compare(build_layer(), batch_src, compare)
Beispiel #2
0
def test_mag_phase(data_format):
    """Compare the output of `get_stft_mag_phase` against librosa.

    The layer is wrapped in a Sequential model and applied to the audio
    batch; the first item of the result is compared with a reference built
    from `librosa.stft` + `librosa.magphase`. The shapes must be equal and
    the magnitude channel (index 0 on the channel axis) must agree within
    `atol=2e-4`. The phase channel is not compared here.

    Args:
        data_format: forwarded to `get_audio` and used as both the input and
            output data format of the layer.
    """
    n_ch = 1
    n_fft = 512
    hop_length = 256
    win_length = 512

    src_mono, batch_src, input_shape = get_audio(data_format=data_format, n_ch=n_ch)

    # Output of the layer under test
    model = tensorflow.keras.models.Sequential()
    model.add(
        get_stft_mag_phase(
            input_shape=input_shape,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            input_data_format=data_format,
            output_data_format=data_format,
        )
    )
    batch_out = model(batch_src)
    mag_phase_kapre = batch_out[0]  # first item of the batch, a 2d image shape

    # Channel axis of a single (non-batch) item; anything that is not
    # 'channels_first' is treated as channels-last.
    if data_format == 'channels_first':
        ch_axis = 0
    else:
        ch_axis = 2

    # Reference from librosa; the STFT is transposed before being split into
    # magnitude and phase.
    stft_ref = librosa.stft(
        src_mono,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        center=False,
    )
    mag_and_phase = librosa.magphase(stft_ref.T)
    mag_phase_ref = np.stack(mag_and_phase, axis=ch_axis)

    np.testing.assert_equal(mag_phase_kapre.shape, mag_phase_ref.shape)

    # magnitude test
    mag_kapre = np.take(mag_phase_kapre, [0], axis=ch_axis)
    mag_ref = np.take(mag_phase_ref, [0], axis=ch_axis)
    np.testing.assert_allclose(mag_kapre, mag_ref, atol=2e-4)
Beispiel #3
0
def test_save_load(save_format):
    """Run `save_load_compare` on the spectrogram layers for one save format.

    `STFT` and `ConcatenateFrequencyMap` are checked for every format. The
    layers returned by the `get_*` helpers (melspectrogram, log-frequency
    spectrogram, STFT magnitude+phase, STFT magnitude) are only checked when
    `save_format == 'tf'`.

    Args:
        save_format: forwarded to `save_load_compare`; the value `'tf'`
            enables the additional checks.
    """
    _, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)
    real_close = np.testing.assert_allclose

    # STFT. The layer class is passed as the fifth argument
    # (NOTE(review): presumably used when reloading the model -- confirm
    # against save_load_compare).
    stft_layer = STFT(input_shape=input_shape, pad_begin=True)
    save_load_compare(
        stft_layer, batch_src, allclose_complex_numbers, save_format, STFT
    )

    # ConcatenateFrequencyMap, fed with a random float32 batch
    specs_batch = np.random.randn(2, 3, 5, 4).astype(np.float32)
    concat_layer = ConcatenateFrequencyMap(input_shape=specs_batch.shape[1:])
    save_load_compare(
        concat_layer,
        specs_batch,
        real_close,
        save_format,
        ConcatenateFrequencyMap,
    )

    if save_format != 'tf':
        return

    # Factories are used so that a layer is only built right before it is
    # checked.
    builders = (
        # melspectrogram
        lambda: get_melspectrogram_layer(
            input_shape=input_shape, return_decibel=True
        ),
        # log frequency spectrogram
        lambda: get_log_frequency_spectrogram_layer(
            input_shape=input_shape, return_decibel=True
        ),
        # stft magnitude + phase
        lambda: get_stft_mag_phase(input_shape=input_shape, return_decibel=True),
        # stft magnitude
        lambda: get_stft_magnitude_layer(input_shape=input_shape),
    )
    for build_layer in builders:
        save_load_compare(build_layer(), batch_src, real_close, save_format)