def data_process_results_testing(index, voice_true, bg_true, voice_predicted,
                                 window_size, mix, mix_magnitude, mix_phase,
                                 hop, context_length, output_file_name=None):
    """Calculates SDR and SIR and creates the resulting audio files.

    :param index: The index of the current source/track.
    :type index: int
    :param voice_true: The true voice.
    :type voice_true: numpy.core.multiarray.ndarray
    :param bg_true: The true background music.
    :type bg_true: numpy.core.multiarray.ndarray
    :param voice_predicted: The predicted voice.
    :type voice_predicted: numpy.core.multiarray.ndarray
    :param window_size: The window size in samples.
    :type window_size: int
    :param mix: The mixture.
    :type mix: numpy.core.multiarray.ndarray
    :param mix_magnitude: The mixture magnitude.
    :type mix_magnitude: numpy.core.multiarray.ndarray
    :param mix_phase: The mixture phase.
    :type mix_phase: numpy.core.multiarray.ndarray
    :param hop: The hop size in samples.
    :type hop: int
    :param context_length: The context length in frames.
    :type context_length: int
    :param output_file_name: The output file name for the predicted voice\
                             and background music. If this argument is not None,
                             then the function just synthesizes the voice and the
                             background music, and saves them.
    :type output_file_name: list[str] | None
    :return: The values of SDR and SIR for each of the frames in\
             the current track, for both voice and background music.
    :rtype: (list[numpy.core.multiarray.ndarray], list[numpy.core.multiarray.ndarray])
    """
    voice_predicted.shape = (voice_predicted.shape[0] * voice_predicted.shape[1],
                             window_size)

    mix_magnitude, mix_phase = _context_based_reshaping(
        mix_magnitude, mix_phase, context_length, window_size)

    voice_hat = i_stft(voice_predicted, mix_phase, window_size, hop)

    # Remove the samples for which no estimation exists
    mix = mix[context_length * hop:]

    if output_file_name is None:
        voice_true = voice_true[context_length * hop:]
        bg_true = bg_true[context_length * hop:]
        min_len = min(len(voice_true), len(voice_hat))
        example_index = index + 1
    else:
        voice_true = None
        bg_true = None
        example_index = None
        min_len = min(len(mix), len(voice_hat))

    # Background music estimation
    bg_hat = mix[:min_len] - voice_hat[:min_len]

    if output_file_name is None:
        voice_hat_path = output_audio_paths['voice_predicted'].format(p=example_index)
        bg_hat_path = output_audio_paths['bg_predicted'].format(p=example_index)

        wav_write(
            voice_true,
            file_name=output_audio_paths['voice_true'].format(p=example_index),
            **wav_quality)
        wav_write(
            bg_true,
            file_name=output_audio_paths['bg_true'].format(p=example_index),
            **wav_quality)
        wav_write(
            mix,
            file_name=output_audio_paths['mix'].format(p=example_index),
            **wav_quality)

        # Metrics calculation
        sdr, sir = _get_me_the_metrics(bss_eval.bss_eval_images_framewise(
            [voice_true[:min_len], bg_true[:min_len]],
            [voice_hat[:min_len], bg_hat[:min_len]]))
    else:
        voice_hat_path = output_file_name[0]
        bg_hat_path = output_file_name[1]
        sdr = None
        sir = None

    wav_write(voice_hat, file_name=voice_hat_path, **wav_quality)
    wav_write(bg_hat, file_name=bg_hat_path, **wav_quality)

    return sdr, sir
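# Illustrative sketch (not part of the original code): a minimal example of the
# trimming and mixture-subtraction steps performed by
# data_process_results_testing() above. The first `context_length * hop` samples
# carry no estimate and are discarded, and the background is recovered as
# mixture minus estimated voice. All signals and settings below are synthetic
# placeholders, chosen only for illustration.
def _example_background_from_mixture():
    import numpy as np

    hop, context_length = 256, 10                           # hypothetical analysis settings
    mix = np.random.randn(44100).astype(np.float32)         # placeholder mixture signal
    voice_hat = np.random.randn(42000).astype(np.float32)   # placeholder voice estimate

    mix = mix[context_length * hop:]                        # drop samples with no estimate
    min_len = min(len(mix), len(voice_hat))
    bg_hat = mix[:min_len] - voice_hat[:min_len]            # background = mixture - voice
    return bg_hat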
def test_eval(nnet, B, T, N, L, wsz, hop):
    """
    Method to test the model on the test data. Writes the outcomes in ".wav" format
    and stores them under the defined results path. Optionally, it performs BSS Eval
    using the MIREval Python toolbox (used only for comparison with the BSS Eval
    Matlab implementation). The evaluation results are stored under the defined
    save path.

    Args:
        nnet : (list)  A list containing the PyTorch modules of the skip-filtering model.
        B    : (int)   Batch size.
        T    : (int)   Length of the time-sequence.
        N    : (int)   The FFT size.
        L    : (int)   Number of context frames from the time-sequence.
        wsz  : (int)   Window size in samples.
        hop  : (int)   Hop size in samples.
    """
    nnet[0].eval()
    nnet[1].eval()
    nnet[2].eval()
    nnet[3].eval()

    def my_res(mx, vx, L, wsz):
        """ A helper function to reshape data according to the context frame. """
        mx = np.ascontiguousarray(mx[:, L:-L, :], dtype=np.float32)
        mx.shape = (mx.shape[0] * mx.shape[1], wsz)
        vx = np.ascontiguousarray(vx[:, L:-L, :], dtype=np.float32)
        vx.shape = (vx.shape[0] * vx.shape[1], wsz)
        return mx, vx

    # Paths for loading and storing the test-set
    # Generate full paths for test
    test_sources_list = sorted(os.listdir(sources_path + foldersList[1]))
    test_sources_list = [sources_path + foldersList[1] + '/' + i
                         for i in test_sources_list]

    # Initializing the containers of the metrics
    sdr = []
    sir = []
    sar = []

    for indx in xrange(len(test_sources_list)):
        print('Reading: ' + test_sources_list[indx])

        # Reading
        bass, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[0]), mono=False)
        drums, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[1]), mono=False)
        oth, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[2]), mono=False)
        vox, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[3]), mono=False)

        bk_true = np.sum(bass + drums + oth, axis=-1) * 0.5
        mix = np.sum(bass + drums + oth + vox, axis=-1) * 0.5
        sv_true = np.sum(vox, axis=-1) * 0.5

        # STFT analysis
        mx, px = tf.TimeFrequencyDecomposition.STFT(mix, tf.hamming(wsz, True), N, hop)

        # Data reshaping (magnitude and phase)
        mx, px, _ = prepare_overlap_sequences(mx, px, px, T, 2 * L, B)

        # The actual "denoising" part
        vx_hat = np.zeros((mx.shape[0], T - L * 2, wsz), dtype=np.float32)

        for batch in xrange(mx.shape[0] // B):
            H_enc = nnet[0](mx[batch * B:(batch + 1) * B, :, :])
            H_j_dec = it_infer.iterative_recurrent_inference(nnet[1], H_enc,
                                                             criterion=None,
                                                             tol=1e-3, max_iter=10)
            vs_hat, mask = nnet[2](H_j_dec, mx[batch * B:(batch + 1) * B, :, :])
            y_out = nnet[3](vs_hat)
            vx_hat[batch * B:(batch + 1) * B, :, :] = y_out.data.cpu().numpy()

        # Final reshaping
        vx_hat.shape = (vx_hat.shape[0] * vx_hat.shape[1], wsz)
        mx, px = my_res(mx, px, L, wsz)

        # Time-domain recovery
        # Iterative G-L algorithm
        for GLiter in range(10):
            sv_hat = tf.TimeFrequencyDecomposition.iSTFT(vx_hat, px, wsz, hop, True)
            _, px = tf.TimeFrequencyDecomposition.STFT(sv_hat, tf.hamming(wsz, True), N, hop)

        # Remove the samples for which no estimation exists
        mix = mix[L * hop:]
        sv_true = sv_true[L * hop:]
        bk_true = bk_true[L * hop:]

        # Background music estimation
        if len(sv_true) > len(sv_hat):
            bk_hat = mix[:len(sv_hat)] - sv_hat
        else:
            bk_hat = mix - sv_hat[:len(mix)]

        # Disk writing for external BSS Eval using DSD100-tools (used in our paper)
        Io.wavWrite(sv_true, 44100, 16, os.path.join(save_path, 'tf_true_sv_' + str(indx) + '.wav'))
        Io.wavWrite(bk_true, 44100, 16, os.path.join(save_path, 'tf_true_bk_' + str(indx) + '.wav'))
        Io.wavWrite(sv_hat, 44100, 16, os.path.join(save_path, 'tf_hat_sv_' + str(indx) + '.wav'))
        Io.wavWrite(bk_hat, 44100, 16, os.path.join(save_path, 'tf_hat_bk_' + str(indx) + '.wav'))
        Io.wavWrite(mix, 44100, 16, os.path.join(save_path, 'tf_mix_' + str(indx) + '.wav'))

        # Internal BSS Eval using mir_eval (just for comparison)
        if len(sv_true) > len(sv_hat):
            c_sdr, _, c_sir, c_sar, _ = bss_eval.bss_eval_images_framewise(
                [sv_true[:len(sv_hat)], bk_true[:len(sv_hat)]],
                [sv_hat, bk_hat])
        else:
            c_sdr, _, c_sir, c_sar, _ = bss_eval.bss_eval_images_framewise(
                [sv_true, bk_true],
                [sv_hat[:len(sv_true)], bk_hat[:len(sv_true)]])

        sdr.append(c_sdr)
        sir.append(c_sir)
        sar.append(c_sar)

        # Storing the results iteratively
        pickle.dump(sdr, open(os.path.join(save_path, 'SDR.p'), 'wb'))
        pickle.dump(sir, open(os.path.join(save_path, 'SIR.p'), 'wb'))
        pickle.dump(sar, open(os.path.join(save_path, 'SAR.p'), 'wb'))

    return None
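# Illustrative sketch (not part of the original code): the 10-iteration loop in
# test_eval() above is a Griffin-Lim style phase refinement. The estimated
# magnitude is kept fixed while the phase is re-estimated by alternating iSTFT
# and STFT. This sketch uses scipy.signal instead of the project's tf module,
# and a synthetic magnitude spectrogram, purely for illustration.
def _example_griffin_lim(wsz=1024, hop=512, n_iter=10, n_frames=64):
    import numpy as np
    from scipy.signal import stft, istft

    # Placeholder magnitude spectrogram (frequency bins x frames)
    magnitude = np.abs(np.random.randn(wsz // 2 + 1, n_frames))
    phase = np.exp(1j * np.random.uniform(-np.pi, np.pi, magnitude.shape))

    for _ in range(n_iter):
        # Back to the time domain with the current phase estimate...
        _, signal = istft(magnitude * phase, nperseg=wsz, noverlap=wsz - hop)
        # ...then keep only the phase from the re-analysis of that signal.
        _, _, spec = stft(signal, nperseg=wsz, noverlap=wsz - hop)
        n = min(spec.shape[1], magnitude.shape[1])
        phase[:, :n] = np.exp(1j * np.angle(spec[:, :n]))

    _, signal = istft(magnitude * phase, nperseg=wsz, noverlap=wsz - hop)
    return signal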
IO.AudioIO.audioWrite(svhat, fs, 16,
                      os.path.join(savepath, 'svhat_' + str(fileIndx) + '.m4a'),
                      'm4a')  # Use wavWrite and '.wav' for Matlab-based evaluation
IO.AudioIO.audioWrite(bkhat, fs, 16,
                      os.path.join(savepath, 'bkhat_' + str(fileIndx) + '.m4a'),
                      'm4a')  # Use wavWrite and '.wav' for Matlab-based evaluation
IO.AudioIO.audioWrite(xsv, fs, 16,
                      os.path.join(savepath, 'svtrue_' + str(fileIndx) + '.m4a'),
                      'm4a')  # Use wavWrite and '.wav' for Matlab-based evaluation
IO.AudioIO.audioWrite(xbk, fs, 16,
                      os.path.join(savepath, 'bktrue_' + str(fileIndx) + '.m4a'),
                      'm4a')  # Use wavWrite and '.wav' for Matlab-based evaluation

# In case evaluation takes place in Python (Matlab BSS Eval images was used for the paper)
print('Evaluating')
cSDR, cISR, cSIR, cSAR, _ = bssEval.bss_eval_images_framewise([xsv, xbk],
                                                              [svhat, bkhat])

SDR.append(cSDR)
ISR.append(cISR)
SIR.append(cSIR)
SAR.append(cSAR)

# Saving Results
pickle.dump(SDR, open(os.path.join(savepath, 'SDR.p'), 'wb'))
pickle.dump(ISR, open(os.path.join(savepath, 'ISR.p'), 'wb'))
pickle.dump(SIR, open(os.path.join(savepath, 'SIR.p'), 'wb'))
pickle.dump(SAR, open(os.path.join(savepath, 'SAR.p'), 'wb'))
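# Illustrative sketch (not part of the original code): the metrics pickled above
# can be read back and summarised per track. `results_path` is an assumed
# placeholder for the same directory used as `savepath`/`save_path` when writing,
# and the per-track median is just one possible way to aggregate framewise values.
def _example_load_metrics(results_path):
    import os
    import pickle
    import numpy as np

    summary = {}
    for name in ('SDR', 'ISR', 'SIR', 'SAR'):
        with open(os.path.join(results_path, name + '.p'), 'rb') as f:
            per_track = pickle.load(f)  # list with one array of framewise values per track
        summary[name] = [float(np.median(np.asarray(track))) for track in per_track]
    return summary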