def get_batch_pesq_improvement(x_wav,y_wav,y_wav_est,batch_num,set_name):
    '''
    Score one batch with PESQ and, when enabled, dump the batch audio to disk.

    inputs:
      x_wav, y_wav, y_wav_est: mixture, reference and enhanced waveforms,
        each [batch,wave]
      batch_num: batch index, used in the saved file names; a value of 1
        clears any previously saved output directory
      set_name: name of the sub-directory the audio is saved under
    return:
      array stacking, in this order, the mixture pesq, the enhanced pesq and
      the pesq improvement (enhanced minus mixture), each [batch]
    '''
    param = FLAGS.PARAM

    # Reference-vs-enhanced scores are computed first, then
    # reference-vs-mixture, one utterance at a time.
    enhanced_scores = []
    for clean_ref, enhanced in zip(y_wav, y_wav_est):
        enhanced_scores.append(calc_pesq(clean_ref, enhanced, param.FS))
    mixture_scores = []
    for clean_ref, mixture in zip(y_wav, x_wav):
        mixture_scores.append(calc_pesq(clean_ref, mixture, param.FS))

    pesq_enhanced = np.array(enhanced_scores)
    pesq_mixture = np.array(mixture_scores)
    pesq_gain = pesq_enhanced - pesq_mixture

    if param.GET_AUDIO_IN_TEST:
        dump_dir = os.path.join(param.SAVE_DIR,
                                'decode_'+param.CHECK_POINT,
                                set_name)
        # The first batch starts from a clean directory.
        if os.path.exists(dump_dir) and batch_num == 1:
            shutil.rmtree(dump_dir)
        if not os.path.exists(dump_dir):
            os.makedirs(dump_dir)

        per_item = zip(y_wav, y_wav_est, x_wav, pesq_gain)
        for idx, (clean_ref, enhanced, mixture, gain) in enumerate(per_item):
            write_audio(
                os.path.join(dump_dir, "%04d_%03d_ref.wav" % (batch_num, idx)),
                clean_ref, param.FS)
            # The enhanced file name carries its PESQ improvement.
            write_audio(
                os.path.join(dump_dir,
                             "%04d_%03d_cleaned_%.2f.wav" % (batch_num, idx, gain)),
                enhanced, param.FS)
            write_audio(
                os.path.join(dump_dir, "%04d_%03d_mixed.wav" % (batch_num, idx)),
                mixture, param.FS)

    return np.array([pesq_mixture, pesq_enhanced, pesq_gain])
def decode_and_getMeature(mixed_file_list, ref_list, sess, model, decode_ans_file, save_audio, ans_file):
    '''
    Enhance every mixed recording and report PESQ / SDR / STOI against its
    reference.

    Example arguments:
      (mixed_dir,ref_dir,sess,model,'decode_nnet_C001_8_2',False,'xxxans.txt')

    inputs:
      mixed_file_list: paths of the noisy recordings to enhance
      ref_list: paths of the clean references; entry i is compared with
        mixed_file_list[i]. Mixed files without a matching entry are
        enhanced but not scored.
      sess, model: passed through unchanged to decode_one_wav
      decode_ans_file: output directory for the report (and for the audio
        and spectrum pictures when save_audio is true); created if missing
      save_audio: when true, save the enhanced audio and a picture of the mask
      ans_file: report file name inside decode_ans_file; an existing report
        is deleted first
    return:
      None. Per-file and average scores are printed and appended to the
      report. Averages are taken over the files that were actually scored.
    '''
    report_path = os.path.join(decode_ans_file, ans_file)
    if os.path.exists(report_path):
        os.remove(report_path)
    # Create the output directory up front so a missing directory does not
    # surface only after the first recording has already been decoded.
    if decode_ans_file:
        os.makedirs(decode_ans_file, exist_ok=True)

    pesq_raw_sum = 0
    pesq_en_sum = 0
    stoi_raw_sum = 0
    stoi_en_sum = 0
    sdr_raw_sum = 0
    sdr_en_sum = 0
    num_scored = 0

    for i, mixed_dir in enumerate(mixed_file_list):
        print('\n',i+1,mixed_dir)
        waveData, sr = utils.audio_tool.read_audio(mixed_dir)
        reY, mask = decode_one_wav(sess,model,waveData)

        # Clamp the enhanced signal to the symmetric range of the configured
        # sample width.
        abs_max = (2 ** (PARAM.AUDIO_BITS - 1) - 1)
        reY = np.where(reY > abs_max, abs_max, reY)
        reY = np.where(reY < -abs_max, -abs_max, reY)

        # basename/splitext handle paths without an extension and paths
        # whose only dot is in a directory component.
        base_name = os.path.basename(mixed_dir)
        file_name = os.path.splitext(base_name)[0]

        if save_audio:
            # NOTE(review): `ckpt` is not defined in this function; it is
            # presumably a module-level name - confirm.
            tag = ckpt+'_%03d_' % (i+1)
            utils.audio_tool.write_audio(os.path.join(decode_ans_file, tag+base_name),
                                         reY, sr)
            spectrum_tool.picture_spec(mask,
                                       os.path.join(decode_ans_file, tag+file_name))

        if i >= len(ref_list):
            # No reference for this recording: nothing to score.
            continue

        ref, sr = utils.audio_tool.read_audio(ref_list[i])
        print(' refer: ',ref_list[i])

        # Trim all three signals to their common length before comparing.
        len_small = min(len(ref),len(waveData),len(reY))
        ref = np.array(ref[:len_small])
        waveData = np.array(waveData[:len_small])
        reY = np.array(reY[:len_small])

        # sdr
        # Both signals get the same leading axis as the reference. The
        # enhanced signal was previously passed without it, unlike the raw
        # one. Assumes cal_SDR expects inputs of matching rank - TODO confirm.
        sdr_raw = audio_tool.cal_SDR(np.array([ref]), np.array([waveData]))
        sdr_en = audio_tool.cal_SDR(np.array([ref]), np.array([reY]))
        sdr_raw_sum += sdr_raw
        sdr_en_sum += sdr_en

        # pesq
        pesq_raw = pesqexe.calc_pesq(ref,waveData,sr)
        pesq_en = pesqexe.calc_pesq(ref,reY,sr)
        pesq_raw_sum += pesq_raw
        pesq_en_sum += pesq_en

        # stoi
        stoi_raw = stoi.stoi(ref,waveData,sr)
        stoi_en = stoi.stoi(ref,reY,sr)
        stoi_raw_sum += stoi_raw
        stoi_en_sum += stoi_en

        num_scored += 1

        # The same three lines go to the console and, prefixed, to the report.
        score_lines = [
            "PESQ_raw: %.3f, PESQ_en: %.3f, PESQimp: %.3f. " % (pesq_raw,pesq_en,pesq_en-pesq_raw),
            "SDR_raw: %.3f, SDR_en: %.3f, SDRimp: %.3f. " % (sdr_raw,sdr_en,sdr_en-sdr_raw),
            "STOI_raw: %.3f, STOI_en: %.3f, STOIimp: %.3f. " % (stoi_raw,stoi_en,stoi_en-stoi_raw),
        ]
        print("SR = %d" % sr)
        for line in score_lines:
            print(line)
        sys.stdout.flush()

        with open(report_path,'a+') as f:
            f.write(file_name+'\r\n')
            for line in score_lines:
                f.write(" |-" + line + "\r\n")

    if num_scored == 0:
        # Averaging over zero files would divide by zero.
        print('No recording had a matching reference; no average scores to report.')
        sys.stdout.flush()
        return

    # Average over the files that were scored, not over len(ref_list), which
    # can be larger than the number of mixed files processed.
    summary_lines = [
        'PESQ_raw:%.3f, PESQ_en:%.3f, PESQi_avg:%.3f. \r\n' % (
            pesq_raw_sum/num_scored, pesq_en_sum/num_scored,
            (pesq_en_sum-pesq_raw_sum)/num_scored),
        'SDR_raw:%.3f, SDR_en:%.3f, SDRi_avg:%.3f. \r\n' % (
            sdr_raw_sum/num_scored, sdr_en_sum/num_scored,
            (sdr_en_sum-sdr_raw_sum)/num_scored),
        'STOI_raw:%.3f, STOI_en:%.3f, STOIi_avg:%.3f. \r\n' % (
            stoi_raw_sum/num_scored, stoi_en_sum/num_scored,
            (stoi_en_sum-stoi_raw_sum)/num_scored),
    ]
    with open(report_path,'a+') as f:
        f.writelines(summary_lines)

    print('\n\n\n-----------------------------------------')
    for line in summary_lines:
        print(line)
    sys.stdout.flush()