def test_save_load(data_format):
    """Round-trip the kapre signal layers through save/load and compare outputs.

    Covers ``Frame``, ``Energy``, ``MuLawEncoding``, ``MuLawDecoding`` and
    ``LogmelToMFCC``.

    Args:
        data_format: layout of the log-mel batch fed to ``LogmelToMFCC``. For
            ``_CH_LAST_STR`` / ``_CH_DEFAULT_STR`` the channel axis is appended
            last; for anything else it is inserted right after the batch axis.
            The Frame/Energy/mu-law checks always use a channels-last mono batch.
    """
    src_mono, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)

    # test Frame save/load
    save_load_compare(
        Frame(frame_length=128, hop_length=64, input_shape=input_shape),
        batch_src,
        np.testing.assert_allclose,
    )

    # test Energy save/load
    save_load_compare(
        Energy(frame_length=128, hop_length=64, input_shape=input_shape),
        batch_src,
        np.testing.assert_allclose,
    )

    # test mu law layers
    save_load_compare(
        MuLawEncoding(quantization_channels=128),
        batch_src,
        np.testing.assert_allclose,
    )
    # The decoder gets a single (1, 256, 1) ramp of values 0..255 as input.
    save_load_compare(
        MuLawDecoding(quantization_channels=128),
        np.arange(0, 256, 1).reshape((1, 256, 1)),
        np.testing.assert_allclose,
    )

    # test mfcc layer
    # Batch axis first; channel axis placed where the requested data format expects it.
    expand_dim = (0, 3) if data_format in (_CH_LAST_STR, _CH_DEFAULT_STR) else (0, 1)
    # `y=` is required: librosa >= 0.10 makes every melspectrogram argument keyword-only.
    # The transpose swaps librosa's axes; assumes src_mono is 1-D so the result is
    # (n_mels, time) -> (time, n_mels) — TODO confirm against get_audio.
    log_melgram = librosa.power_to_db(librosa.feature.melspectrogram(y=src_mono).T)
    save_load_compare(
        # Forward data_format so the layer reads the axes the same way the input was built.
        LogmelToMFCC(n_mfccs=10, data_format=data_format),
        np.expand_dims(log_melgram, expand_dim),
        np.testing.assert_allclose,
    )
def test_save_load():
    """Check save/load round-trips of the STFT-family layers.

    Layers covered: STFT, melspectrogram, log-frequency spectrogram,
    STFT magnitude + phase, and STFT magnitude, all built for a mono
    channels-last audio batch.
    """
    _, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)

    # (layer factory, comparison function) pairs. Factories keep each layer's
    # construction immediately before its own save/load check, in this order.
    cases = (
        (
            lambda: STFT(input_shape=input_shape, pad_begin=True),
            allclose_complex_numbers,
        ),
        (
            lambda: get_melspectrogram_layer(input_shape=input_shape, return_decibel=True),
            np.testing.assert_allclose,
        ),
        (
            lambda: get_log_frequency_spectrogram_layer(
                input_shape=input_shape, return_decibel=True
            ),
            np.testing.assert_allclose,
        ),
        (
            lambda: get_stft_mag_phase(input_shape=input_shape, return_decibel=True),
            np.testing.assert_allclose,
        ),
        (
            lambda: get_stft_magnitude_layer(input_shape=input_shape),
            np.testing.assert_allclose,
        ),
    )
    for build_layer, compare in cases:
        save_load_compare(build_layer(), batch_src, compare)
def test_save_load_channel_swap(data_format, save_format):
    """Round-trip a ChannelSwap layer through save/load in ``save_format``.

    Args:
        data_format: accepted but not used by this body; the audio batch is
            always requested as channels-last mono.
        save_format: serialization format forwarded to ``save_load_compare``.
    """
    _, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)
    swap_layer = ChannelSwap(input_shape=input_shape)

    save_load_compare(
        swap_layer,
        batch_src,
        np.testing.assert_allclose,
        save_format=save_format,
        layer_class=ChannelSwap,
        training=None,
    )
def test_save_load_spec_augment(data_format, save_format):
    """Round-trip a SpecAugment layer through save/load in ``save_format``.

    Args:
        data_format: spectrogram layout, forwarded to both ``get_spectrogram``
            and the ``SpecAugment`` layer.
        save_format: serialization format forwarded to ``save_load_compare``.
    """
    batch_src, input_shape = get_spectrogram(data_format=data_format)

    # Masking hyper-parameters of the layer under test.
    masking = dict(
        freq_mask_param=5,
        time_mask_param=10,
        n_freq_masks=4,
        n_time_masks=3,
        mask_value=0.0,
    )
    augment_layer = SpecAugment(input_shape=input_shape, data_format=data_format, **masking)

    # NOTE(review): training=None presumably keeps the random masking out of the
    # comparison; confirm against save_load_compare.
    save_load_compare(
        augment_layer,
        batch_src,
        np.testing.assert_allclose,
        save_format=save_format,
        layer_class=SpecAugment,
        training=None,
    )
def test_save_load(save_format):
    """Check save/load round-trips of STFT-based layers for ``save_format``.

    ``STFT`` and ``ConcatenateFrequencyMap`` are checked for every format, with
    their class passed to ``save_load_compare``. The composite melspectrogram,
    log-frequency spectrogram, STFT magnitude/phase and STFT magnitude layers
    are only checked when ``save_format == 'tf'``.
    """
    _, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)

    # STFT is compared with allclose_complex_numbers (presumably because its output is complex).
    save_load_compare(
        STFT(input_shape=input_shape, pad_begin=True),
        batch_src,
        allclose_complex_numbers,
        save_format,
        STFT,
    )

    # ConcatenateFrequencyMap on a random float32 batch of shape (2, 3, 5, 4).
    specs_batch = np.random.randn(2, 3, 5, 4).astype(np.float32)
    save_load_compare(
        ConcatenateFrequencyMap(input_shape=specs_batch.shape[1:]),
        specs_batch,
        np.testing.assert_allclose,
        save_format,
        ConcatenateFrequencyMap,
    )

    # The remaining composite layers are exercised only for the 'tf' format.
    if save_format != 'tf':
        return

    # Factories keep each layer's construction right before its own check.
    composite_factories = (
        lambda: get_melspectrogram_layer(input_shape=input_shape, return_decibel=True),
        lambda: get_log_frequency_spectrogram_layer(input_shape=input_shape, return_decibel=True),
        lambda: get_stft_mag_phase(input_shape=input_shape, return_decibel=True),
        lambda: get_stft_magnitude_layer(input_shape=input_shape),
    )
    for build_layer in composite_factories:
        save_load_compare(build_layer(), batch_src, np.testing.assert_allclose, save_format)