Exemplo n.º 1
0
def test_save_load(data_format):
    """Run ``save_load_compare`` on the time-domain, mu-law and MFCC layers.

    Every layer is passed to ``save_load_compare`` together with an input
    batch and ``np.testing.assert_allclose`` as the comparison function.

    Args:
        data_format: channel-layout string. In this test it only selects
            which axes are inserted into the log-melspectrogram that is fed
            to ``LogmelToMFCC``; the audio itself is always requested as
            ``'channels_last'``.
    """
    src_mono, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)
    # test Frame save/load
    save_load_compare(
        Frame(frame_length=128, hop_length=64, input_shape=input_shape),
        batch_src,
        np.testing.assert_allclose,
    )
    # test Energy save/load
    save_load_compare(
        Energy(frame_length=128, hop_length=64, input_shape=input_shape),
        batch_src,
        np.testing.assert_allclose,
    )
    # test mu law layers
    save_load_compare(
        MuLawEncoding(quantization_channels=128),
        batch_src,
        np.testing.assert_allclose,
    )
    save_load_compare(
        MuLawDecoding(quantization_channels=128),
        np.arange(0, 256, 1).reshape((1, 256, 1)),
        np.testing.assert_allclose,
    )
    # test mfcc layer
    # Axis 0 is the added batch axis; the second inserted axis is presumably
    # the channel axis (last for channels_last/default, first otherwise).
    expand_dim = (0, 3) if data_format in (_CH_LAST_STR, _CH_DEFAULT_STR) else (0, 1)
    # The audio is passed as `y=`: librosa >= 0.10 made the arguments of
    # `melspectrogram` keyword-only, so a positional call raises TypeError.
    log_melgram = librosa.power_to_db(librosa.feature.melspectrogram(y=src_mono).T)
    save_load_compare(
        LogmelToMFCC(n_mfccs=10),
        np.expand_dims(log_melgram, expand_dim),
        np.testing.assert_allclose,
    )
Exemplo n.º 2
0
def test_save_load():
    """Run ``save_load_compare`` on the STFT layer and the composed spectrogram layers.

    Covers STFT, melspectrogram, log-frequency spectrogram, STFT
    magnitude+phase and STFT magnitude, all fed with the same mono batch.
    """
    _, batch_src, input_shape = get_audio(n_ch=1, data_format='channels_last')
    assert_close = np.testing.assert_allclose

    # STFT: its output is compared through `allclose_complex_numbers`
    stft_layer = STFT(input_shape=input_shape, pad_begin=True)
    save_load_compare(stft_layer, batch_src, allclose_complex_numbers)

    # melspectrogram (decibel output)
    melgram_layer = get_melspectrogram_layer(
        input_shape=input_shape, return_decibel=True
    )
    save_load_compare(melgram_layer, batch_src, assert_close)

    # log-frequency spectrogram (decibel output)
    log_freq_layer = get_log_frequency_spectrogram_layer(
        input_shape=input_shape, return_decibel=True
    )
    save_load_compare(log_freq_layer, batch_src, assert_close)

    # STFT magnitude + phase (decibel output)
    mag_phase_layer = get_stft_mag_phase(
        input_shape=input_shape, return_decibel=True
    )
    save_load_compare(mag_phase_layer, batch_src, assert_close)

    # STFT magnitude only
    magnitude_layer = get_stft_magnitude_layer(input_shape=input_shape)
    save_load_compare(magnitude_layer, batch_src, assert_close)
Exemplo n.º 3
0
def test_save_load_channel_swap(data_format, save_format):
    """Run ``save_load_compare`` on a ``ChannelSwap`` layer.

    Args:
        data_format: accepted but not used in the body; the audio batch is
            always requested as ``'channels_last'``.
        save_format: forwarded to ``save_load_compare`` as ``save_format``.
    """
    _, batch_src, input_shape = get_audio(n_ch=1, data_format='channels_last')

    swap_layer = ChannelSwap(input_shape=input_shape)
    compare_options = {
        'save_format': save_format,
        'layer_class': ChannelSwap,
        'training': None,
    }
    save_load_compare(
        swap_layer, batch_src, np.testing.assert_allclose, **compare_options
    )
Exemplo n.º 4
0
def test_save_load_spec_augment(data_format, save_format):
    """Run ``save_load_compare`` on a ``SpecAugment`` layer.

    Args:
        data_format: used both to fetch the spectrogram batch and to
            configure the ``SpecAugment`` layer.
        save_format: forwarded to ``save_load_compare`` as ``save_format``.
    """
    batch_src, input_shape = get_spectrogram(data_format=data_format)

    # Constructor arguments kept together so the layer setup reads as one unit.
    augment_config = {
        'input_shape': input_shape,
        'freq_mask_param': 5,
        'time_mask_param': 10,
        'n_freq_masks': 4,
        'n_time_masks': 3,
        'mask_value': 0.0,
        'data_format': data_format,
    }
    augment_layer = SpecAugment(**augment_config)

    save_load_compare(
        augment_layer,
        batch_src,
        np.testing.assert_allclose,
        save_format=save_format,
        layer_class=SpecAugment,
        training=None,
    )
Exemplo n.º 5
0
def test_save_load(save_format):
    """Run ``save_load_compare`` on STFT-related layers for one save format.

    ``STFT`` and ``ConcatenateFrequencyMap`` are checked for every
    ``save_format``; the layers built by the ``get_*`` helpers are only
    checked when ``save_format`` is ``'tf'``.

    Args:
        save_format: forwarded positionally to ``save_load_compare``.
    """
    _, batch_src, input_shape = get_audio(n_ch=1, data_format='channels_last')
    assert_close = np.testing.assert_allclose

    # STFT: compared through `allclose_complex_numbers`, layer class passed along
    stft_layer = STFT(input_shape=input_shape, pad_begin=True)
    save_load_compare(stft_layer, batch_src, allclose_complex_numbers, save_format, STFT)

    # ConcatenateFrequencyMap on a random float32 batch of shape (2, 3, 5, 4)
    specs_batch = np.random.randn(2, 3, 5, 4).astype(np.float32)
    freq_map_layer = ConcatenateFrequencyMap(input_shape=specs_batch.shape[1:])
    save_load_compare(
        freq_map_layer, specs_batch, assert_close, save_format, ConcatenateFrequencyMap
    )

    # The remaining layers are exercised for the 'tf' save format only.
    if save_format != 'tf':
        return

    # (builder, extra keyword arguments), in the order they are tested:
    # melspectrogram, log-frequency spectrogram, STFT mag+phase, STFT magnitude
    composed_layer_builders = (
        (get_melspectrogram_layer, {'return_decibel': True}),
        (get_log_frequency_spectrogram_layer, {'return_decibel': True}),
        (get_stft_mag_phase, {'return_decibel': True}),
        (get_stft_magnitude_layer, {}),
    )
    for build_layer, extra_kwargs in composed_layer_builders:
        # Each layer is built right before its own comparison.
        composed_layer = build_layer(input_shape=input_shape, **extra_kwargs)
        save_load_compare(composed_layer, batch_src, assert_close, save_format)