def test_save_load():
    """Test saving/loading of models that have STFT, melspectrogram, and log-frequency layers.

    Each model is handed to ``save_load_compare`` together with a mono,
    channels-last audio batch and the function used to compare model outputs
    (a complex-number comparator for STFT, ``np.testing.assert_allclose`` for
    the others).
    """
    # Only the batch and its shape are needed; the raw mono signal is unused.
    _, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)

    # test STFT save/load (compared with a complex-number-aware comparator)
    save_load_compare(STFT(input_shape=input_shape, pad_begin=True), batch_src,
                      allclose_complex_numbers)
    # test melspectrogram save/load
    save_load_compare(
        get_melspectrogram_layer(input_shape=input_shape, return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
    )
    # test log frequency spectrogram save/load
    save_load_compare(
        get_log_frequency_spectrogram_layer(input_shape=input_shape,
                                            return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
    )
    # test stft_mag_phase
    save_load_compare(
        get_stft_mag_phase(input_shape=input_shape, return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
    )
    # test stft mag
    save_load_compare(get_stft_magnitude_layer(input_shape=input_shape),
                      batch_src, np.testing.assert_allclose)
def test_log_spectrogram_fail():
    """Build a log-frequency spectrogram layer with ``log_n_bins=200``.

    As the test name indicates, this configuration is expected to fail rather
    than succeed. No exception type is checked here.

    NOTE(review): presumably this test carries an ``xfail`` marker outside the
    visible source. Confirm, otherwise the expected failure fails the suite.
    """
    # Only the input shape is needed to build the layer.
    _, _, input_shape = get_audio(data_format='channels_last', n_ch=1)
    _ = get_log_frequency_spectrogram_layer(input_shape,
                                            return_decibel=True,
                                            log_n_bins=200)
Example #3
0
def test_save_load(save_format):
    """Test saving/loading of models that have STFT, melspectrogram, and log-frequency layers.

    Args:
        save_format: format forwarded to ``save_load_compare``. The composite
            models (melspectrogram, log-frequency, STFT magnitude/phase) are
            only exercised when it equals ``'tf'``.
    """
    # Only the batch and its shape are needed; the raw mono signal is unused.
    _, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)

    # test STFT save/load (compared with a complex-number-aware comparator)
    save_load_compare(
        STFT(input_shape=input_shape, pad_begin=True),
        batch_src,
        allclose_complex_numbers,
        save_format,
        STFT,
    )

    # test ConcatenateFrequencyMap on a random float32 batch of shape (2, 3, 5, 4)
    specs_batch = np.random.randn(2, 3, 5, 4).astype(np.float32)
    save_load_compare(
        ConcatenateFrequencyMap(input_shape=specs_batch.shape[1:]),
        specs_batch,
        np.testing.assert_allclose,
        save_format,
        ConcatenateFrequencyMap,
    )

    # NOTE(review): the composite models below are passed without a layer
    # class, presumably because only the 'tf' format can reload them without
    # one - TODO confirm against save_load_compare.
    if save_format == 'tf':
        # test melspectrogram save/load
        save_load_compare(
            get_melspectrogram_layer(input_shape=input_shape,
                                     return_decibel=True),
            batch_src,
            np.testing.assert_allclose,
            save_format,
        )
        # test log frequency spectrogram save/load
        save_load_compare(
            get_log_frequency_spectrogram_layer(input_shape=input_shape,
                                                return_decibel=True),
            batch_src,
            np.testing.assert_allclose,
            save_format,
        )
        # test stft_mag_phase
        save_load_compare(
            get_stft_mag_phase(input_shape=input_shape, return_decibel=True),
            batch_src,
            np.testing.assert_allclose,
            save_format,
        )
        # test stft mag
        save_load_compare(
            get_stft_magnitude_layer(input_shape=input_shape),
            batch_src,
            np.testing.assert_allclose,
            save_format,
        )
Example #4
0
def test_save_load():
    """Test saving/loading of models that have STFT, melspectrogram, and log-frequency layers.

    Each layer is wrapped in a single-layer ``Sequential`` model. The model is
    saved to a temporary directory and reloaded, and the reloaded model's
    output is compared with the original model's output.
    """

    def _test(layer, input_batch, allclose_func, atol=1e-4):
        """Save/reload a model made of `layer` and compare predictions on `input_batch`.

        The model prediction result is compared using `allclose_func`, which may
        depend on the data type of the model output (e.g., float or complex).

        Args:
            layer: the layer to wrap in a ``Sequential`` model.
            input_batch: input fed to both the original and the reloaded model.
            allclose_func: comparator called as ``allclose_func(ref, new, atol=atol)``.
            atol: absolute tolerance forwarded to `allclose_func` by keyword.

        Returns:
            The original (pre-save) model.
        """
        model = tf.keras.models.Sequential()
        model.add(layer)

        result_ref = model(input_batch)

        # The context manager removes the directory even if saving, loading or
        # predicting raises (a trailing manual cleanup() would be skipped).
        with tempfile.TemporaryDirectory() as model_dir:
            model.save(filepath=model_dir)
            new_model = tf.keras.models.load_model(model_dir)
            result_new = new_model(input_batch)

        # Pass atol by keyword: np.testing.assert_allclose's third positional
        # parameter is rtol, so a positional value would be misused as a
        # relative tolerance.
        # NOTE(review): assumes allclose_complex_numbers also accepts `atol=`.
        outcome = allclose_func(result_ref, result_new, atol=atol)
        # np.testing.assert_allclose raises and returns None. A comparator that
        # instead returns a boolean would otherwise be silently ignored.
        assert outcome is None or outcome, 'reloaded model output differs'

        return model

    # Only the batch and its shape are needed; the raw mono signal is unused.
    _, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)
    # test STFT save/load
    _test(STFT(input_shape=input_shape), batch_src, allclose_complex_numbers)
    # test melspectrogram save/load
    _test(
        get_melspectrogram_layer(input_shape=input_shape, return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
    )
    # test log frequency spectrogram save/load
    _test(
        get_log_frequency_spectrogram_layer(input_shape=input_shape, return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
    )
def test_log_spectrogram_runnable(data_format):
    """Check that a log-frequency spectrogram layer can be constructed.

    For the given `data_format`, the layer is built twice: first with decibel
    output enabled, then with it disabled.
    """
    # get_audio yields (mono source, batch, input shape); only the shape matters here.
    _, _, shape = get_audio(data_format=data_format, n_ch=1)
    for use_decibel in (True, False):
        get_log_frequency_spectrogram_layer(shape, return_decibel=use_decibel)