def test_save_load():
    """Test saving/loading of models that contain kapre layers.

    Round-trips models built from the STFT, melspectrogram, log-frequency
    spectrogram, STFT magnitude+phase and STFT magnitude layers through
    ``save_load_compare``. Each model is built for a single-channel,
    ``'channels_last'`` input from ``get_audio``. It is paired with the
    assertion function ``save_load_compare`` should use to compare outputs.
    """
    # NOTE(review): if this module also defines a parametrized
    # ``test_save_load(save_format)``, that later definition shadows this one
    # and pytest will not collect it -- confirm which version is intended.
    _, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)

    # STFT is compared with ``allclose_complex_numbers`` rather than
    # ``np.testing.assert_allclose``, because its output is complex-valued.
    save_load_compare(
        STFT(input_shape=input_shape, pad_begin=True),
        batch_src,
        allclose_complex_numbers,
    )

    # melspectrogram (decibel-scaled)
    save_load_compare(
        get_melspectrogram_layer(input_shape=input_shape, return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
    )

    # log-frequency spectrogram (decibel-scaled)
    save_load_compare(
        get_log_frequency_spectrogram_layer(input_shape=input_shape, return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
    )

    # STFT magnitude + phase (decibel-scaled)
    save_load_compare(
        get_stft_mag_phase(input_shape=input_shape, return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
    )

    # STFT magnitude only
    save_load_compare(
        get_stft_magnitude_layer(input_shape=input_shape),
        batch_src,
        np.testing.assert_allclose,
    )
def test_mag_phase(data_format):
    """Compare kapre's STFT magnitude/phase layer with a librosa reference.

    The layer output for the first batch item must have the same shape as a
    reference built by stacking librosa's (magnitude, phase) pair along the
    channel axis. Only the magnitude channel (index 0) is compared
    numerically, with an absolute tolerance of 2e-4.

    Args:
        data_format: ``'channels_first'`` or any other value (treated as
            channels-last); used for both the input and output data format.
    """
    n_fft = 512
    win_length = 512
    hop_length = 256
    src_mono, batch_src, input_shape = get_audio(data_format=data_format, n_ch=1)

    layer = get_stft_mag_phase(
        input_shape=input_shape,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        input_data_format=data_format,
        output_data_format=data_format,
    )
    model = tensorflow.keras.models.Sequential()
    model.add(layer)
    # Drop the batch dimension: a single 2-D "image" with channels.
    kapre_out = model(batch_src)[0]

    # Channel axis of a single (non-batch) example.
    if data_format == 'channels_first':
        ch_axis = 0
    else:
        ch_axis = 2

    stft_ref = librosa.stft(
        src_mono,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        center=False,
    )
    mag_ref, phase_ref = librosa.magphase(stft_ref.T)
    reference = np.stack([mag_ref, phase_ref], axis=ch_axis)

    np.testing.assert_equal(kapre_out.shape, reference.shape)

    def first_channel(arr):
        # Keep the channel axis (index list, not scalar) so ranks match.
        return np.take(arr, [0], axis=ch_axis)

    # magnitude test
    np.testing.assert_allclose(first_channel(kapre_out), first_channel(reference), atol=2e-4)
def test_save_load(save_format):
    """Test saving/loading of models that contain kapre layers.

    STFT and ConcatenateFrequencyMap are round-tripped through
    ``save_load_compare`` for every ``save_format``. The composite layers
    (melspectrogram, log-frequency spectrogram, STFT magnitude+phase, STFT
    magnitude) are only checked when ``save_format == 'tf'``.

    Args:
        save_format: format identifier forwarded to ``save_load_compare``;
            this test branches on the value ``'tf'``.
    """
    _, batch_src, input_shape = get_audio(data_format='channels_last', n_ch=1)

    # STFT output is complex-valued, hence the complex-aware comparison. The
    # layer class is also passed -- presumably so the loader can resolve it as
    # a custom object; confirm in save_load_compare.
    save_load_compare(
        STFT(input_shape=input_shape, pad_begin=True),
        batch_src,
        allclose_complex_numbers,
        save_format,
        STFT,
    )

    # ConcatenateFrequencyMap on a synthetic 4-D batch. The values are
    # arbitrary; a fixed seed keeps any failure reproducible across runs.
    rng = np.random.default_rng(0)
    specs_batch = rng.standard_normal((2, 3, 5, 4)).astype(np.float32)
    save_load_compare(
        ConcatenateFrequencyMap(input_shape=specs_batch.shape[1:]),
        specs_batch,
        np.testing.assert_allclose,
        save_format,
        ConcatenateFrequencyMap,
    )

    # NOTE(review): the composite layers below are only exercised with the
    # 'tf' format; presumably they do not round-trip through the other
    # format(s) -- confirm before widening.
    if save_format != 'tf':
        return

    # melspectrogram (decibel-scaled)
    save_load_compare(
        get_melspectrogram_layer(input_shape=input_shape, return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
        save_format,
    )
    # log-frequency spectrogram (decibel-scaled)
    save_load_compare(
        get_log_frequency_spectrogram_layer(input_shape=input_shape, return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
        save_format,
    )
    # STFT magnitude + phase (decibel-scaled)
    save_load_compare(
        get_stft_mag_phase(input_shape=input_shape, return_decibel=True),
        batch_src,
        np.testing.assert_allclose,
        save_format,
    )
    # STFT magnitude only
    save_load_compare(
        get_stft_magnitude_layer(input_shape=input_shape),
        batch_src,
        np.testing.assert_allclose,
        save_format,
    )