def eval(self, audio_est, audio_ref):
    """Score an estimated audio file against a reference audio file.

    Both files are read with ``sf.read`` and truncated to the shorter of the
    two lengths. The metric is chosen by ``self.metric``:

    * ``'rmse'``  -> ``compute_rmse(x_est, x_ref)``
    * ``'pesq'``  -> ``pesq(x_ref, x_est, fs)``
    * ``'stoi'``  -> classic STOI
    * ``'estoi'`` -> extended STOI
    * ``'all'``   -> tuple ``(rmse, pesq, stoi)``

    Args:
        audio_est: estimated audio (anything ``sf.read`` accepts).
        audio_ref: reference audio (anything ``sf.read`` accepts).

    Raises:
        ValueError: if the two files have different sampling rates, or if
            ``self.metric`` is not a supported name.
    """
    x_est, fs_est = sf.read(audio_est)
    x_ref, fs_ref = sf.read(audio_ref)
    # Reject mismatched rates before touching the samples.
    if fs_est != fs_ref:
        raise ValueError(
            'Sampling rate is different among estimated audio and reference audio'
        )

    # Align: every metric below needs equal-length inputs.
    len_x = min(len(x_est), len(x_ref))
    x_est = x_est[:len_x]
    x_ref = x_ref[:len_x]

    metric = self.metric
    if metric == 'rmse':
        return compute_rmse(x_est, x_ref)
    if metric == 'pesq':
        return pesq(x_ref, x_est, fs_est)
    if metric == 'stoi':
        return stoi(x_ref, x_est, fs_est, extended=False)
    if metric == 'estoi':
        return stoi(x_ref, x_est, fs_est, extended=True)
    if metric == 'all':
        score_rmse = compute_rmse(x_est, x_ref)
        score_pesq = pesq(x_ref, x_est, fs_est)
        score_stoi = stoi(x_ref, x_est, fs_est, extended=False)
        return score_rmse, score_pesq, score_stoi
    raise ValueError(
        'Evaluation only support: rmse, pesq, (e)stoi, all')
def make_line(clean, enh, fs):
    """Build one row of STOI scores for a clean/enhanced signal pair.

    The row begins with the NumPy-based scores (classic STOI, then extended
    STOI). It continues with four PyTorch ``NegSTOILoss`` scores, negated so
    they read as STOI values. These iterate ``use_vad`` over (True, False) in
    the outer loop and ``extended`` over (True, False) in the inner loop.

    Args:
        clean: clean reference signal (NumPy array).
        enh: enhanced signal (NumPy array).
        fs: sampling rate in Hz.

    Returns:
        list of six scores.
    """
    numpy_scores = [stoi(clean, enh, fs), stoi(clean, enh, fs, extended=True)]
    torch_scores = [
        -NegSTOILoss(sample_rate=fs, use_vad=use_vad, extended=extended)(
            torch.from_numpy(enh), torch.from_numpy(clean)
        ).item()
        for use_vad in (True, False)
        for extended in (True, False)
    ]
    return numpy_scores + torch_scores
def stoi_on_batch(y_denoised, ytest, test_phase, sr=16000):
    """Compute classic STOI between a denoised and a reference dB spectrogram.

    Both dB spectrograms are converted to linear amplitude. Each is multiplied
    by ``test_phase``, inverted with ``librosa.istft`` and peak-normalised
    before scoring.

    Args:
        y_denoised: model output in dB. The squeezes below require shape
            ``(1, ..., ..., 1)``; the leading batch axis and trailing channel
            axis are dropped.
        ytest: reference spectrogram in dB (same layout as the squeezed output).
        test_phase: phase term multiplied onto both magnitudes. Presumably the
            complex ``exp(1j * angle)`` of the noisy STFT; TODO confirm.
        sr: sampling rate passed to STOI.

    Returns:
        The STOI score (clean reference first, denoised second).
    """
    # Drop the trailing channel axis, then the leading batch axis.
    y_denoised = np.squeeze(y_denoised, axis=3)
    y_denoised = np.squeeze(y_denoised, axis=0)
    denoised_mag = librosa.db_to_amplitude(y_denoised)
    clean_mag = librosa.db_to_amplitude(ytest)
    denoised = librosa.istft(denoised_mag * test_phase)
    original = librosa.istft(clean_mag * test_phase)
    denoised = librosa.util.normalize(denoised)
    original = librosa.util.normalize(original)
    # pystoi's 4th parameter is ``extended``. The previous positional 'wb' (a
    # PESQ mode string) was truthy and silently selected ESTOI.
    return stoi(original, denoised, sr, extended=False)
def convolution(audio_file, impulse_file, output_file, fs=16000, ir_len=1000,
                stoi_path="stoi_kf-ad_3d.txt"):
    """Convolve audio with the head of an impulse response and report STOI.

    The audio is convolved with the first ``ir_len`` samples of the impulse
    response. ``mode='same'`` keeps the audio's length. The result is passed
    through ``normalize``, and STOI between the original audio and the
    convolved signal is printed and saved with ``np.savetxt``.

    Args:
        audio_file: path of the dry audio (read with ``read_wave``).
        impulse_file: path of the impulse response (read with ``read_wave``).
        output_file: intended destination of the convolved audio.
            NOTE(review): unused; writing the convolved signal was commented
            out in the original. Kept for interface compatibility.
        fs: sampling rate given to STOI (previously hard-coded to 16000).
            Should match the rate of ``audio_file``; this is not checked here.
        ir_len: number of impulse-response samples kept (previously 1000).
        stoi_path: text file the STOI score is written to.

    Returns:
        The STOI score (the original returned ``None``).
    """
    audio = read_wave(audio_file)
    ir = read_wave(impulse_file)[:ir_len]
    convolved = normalize(signal.convolve(audio, ir, mode='same'))
    score = stoi(audio, convolved, fs, extended=False)
    print("STOI is :", score)
    np.savetxt(stoi_path, [score])
    return score
def main(params, date, epoch, gpu):
    """Evaluate a PonderEnhancer checkpoint on the test set, grouped by file tag.

    For each test item, the enhanced prediction is scored with STOI against
    the clean signal, and the ponder loss from ``loss_fn`` is recorded. Results
    are grouped by ``file_db[0][:-4]`` and saved with ``save_dict``.

    Args:
        params: experiment configuration; reads ``'model'``,
            ``'test_data_config'`` (including ``'fs'``) and ``'loss'``.
        date: run identifier used to locate the checkpoint and save results.
        epoch: checkpoint epoch to load.
        gpu: unused here. NOTE(review): the model is moved with ``.cuda()``
            (current default device) regardless of this argument.
    """
    net = PonderEnhancer(params['model'])
    ckpt = get_model_ckpt(params, date, epoch)
    net.load_state_dict(torch.load(ckpt))
    net.cuda()
    net.eval()

    # run test
    test_dset = get_dataset(params['test_data_config'])
    test_dataloader = get_dataloader(params, test_dset, train=False)
    loss_fn = get_loss_fn(params['loss'])
    fs = params['test_data_config']['fs']

    per_db_results = {}
    # Evaluation only: no autograd graph needed.
    with torch.no_grad():
        for clean, _noise, mix, file_db in tqdm.tqdm(test_dataloader):
            clean, mix = clean.cuda(), mix.cuda()
            db = file_db[0][:-4]  # strip a 4-char suffix (e.g. a file extension)
            pred, ponder = net(mix, verbose=False)
            _, loss_ponder = loss_fn(clean, pred, ponder)
            results = per_db_results.setdefault(db, {'enhance': [], 'ponder': []})

            # get the perceptual metrics on flat NumPy signals
            np_clean = clean.detach().cpu().numpy().reshape(-1)
            np_pred = pred.detach().cpu().numpy().reshape(-1)
            # Previously referenced the undefined names ``enhanced`` / ``sr``.
            results['enhance'].append(pystoi.stoi(np_clean, np_pred, fs))
            results['ponder'].append(loss_ponder.item())

    # save it all
    save_dict(per_db_results, 'stoi', epoch, date, params)
def eval_tts_scores(y_clean: ndarray, y_est: ndarray, T_ys: Sequence[int] = (0, ), sampling_rate=22050) -> Dict[str, float]:
    """
    Calculate STOI and wide-band PESQ for the first item of a (possibly batched) pair.

    Args:
        y_clean: real audio, 1-D or batched as ``(batch, time)``.
        y_est: estimated audio, same layout as ``y_clean``.
        T_ys: valid (unpadded) length of each batch item. The default ``(0, )``
            means "use the full length". NOTE: only the first item is scored.
        sampling_rate: The used sampling rate.

    Returns:
        A dictionary mapping scoring systems (string) to numerical scores.
        1st entry: 'STOI'
        2nd entry: 'PESQ'
    """
    from scipy.signal import resample_poly

    if y_clean.ndim == 1:
        y_clean = y_clean[np.newaxis, ...]
        y_est = y_est[np.newaxis, ...]
    # tuple() so that a list such as [0] is also recognised as the default.
    if tuple(T_ys) == (0, ):
        T_ys = (y_clean.shape[1], ) * y_clean.shape[0]
    clean = y_clean[0, :T_ys[0]]
    estimated = y_est[0, :T_ys[0]]

    # stoi() receives the native rate explicitly.
    stoi_score = stoi(clean, estimated, sampling_rate, extended=False)

    # The pesq package only supports 8/16 kHz, and 'wb' mode needs 16 kHz.
    # Previously the audio was passed unchanged but labelled 16 kHz, which
    # mis-scores any other rate (e.g. the 22050 Hz default). Resample instead.
    pesq_fs = 16000
    clean_pesq = np.asarray(clean, dtype=np.float64)
    est_pesq = np.asarray(estimated, dtype=np.float64)
    if int(sampling_rate) != pesq_fs:
        clean_pesq = resample_poly(clean_pesq, pesq_fs, int(sampling_rate))
        est_pesq = resample_poly(est_pesq, pesq_fs, int(sampling_rate))
    pesq_score = pesq(pesq_fs, clean_pesq, est_pesq, 'wb')
    return {'STOI': stoi_score, 'PESQ': pesq_score}
def calc_metrics(loader, actor, device, fs=16000):
    """Average PESQ and STOI of ``actor``'s enhanced outputs over ``loader``.

    Each batch's noisy input is passed through ``actor``. The two outputs are
    transposed on axes 1/2 and combined with ``predict``, then ``inverse``
    turns targets and predictions into per-item signals that are scored.

    Args:
        loader: iterable of dict batches with ``"noisy"``, ``"clean"`` and
            ``"mask"`` tensors.
        actor: model returning a ``(real, imag)`` pair.
        device: torch device for the forward pass.
        fs: sampling rate given to PESQ and STOI (previously hard-coded 16000).

    Returns:
        ``(PESQ, STOI)`` as 0-d torch tensors holding the means.
    """
    pesq_all = []
    stoi_all = []
    # Pure evaluation: do not build an autograd graph.
    with torch.no_grad():
        for batch in loader:
            x = batch["noisy"].unsqueeze(1).to(device)
            t = batch["clean"].unsqueeze(1).to(device)
            m = batch["mask"].to(device)
            out_r, out_i = actor(x)
            out_r = torch.transpose(out_r, 1, 2)
            out_i = torch.transpose(out_i, 1, 2)
            y = predict(x.squeeze(1), (out_r, out_i))
            t = t.squeeze()
            m = m.squeeze()
            targets, preds = inverse(t, y, m, x)
            for target, pred in zip(targets, preds):
                # Convert each signal once and reuse it for both metrics.
                target_np = target.detach().cpu().numpy()
                pred_np = pred.detach().cpu().numpy()
                pesq_all.append(pesq(target_np, pred_np, fs))
                stoi_all.append(stoi(target_np, pred_np, fs))
    PESQ = torch.mean(torch.tensor(pesq_all))
    STOI = torch.mean(torch.tensor(stoi_all))
    return PESQ, STOI
def test_preprocessed_data(net_type):
    """Compute SDR, wide-band PESQ and STOI for pre-processed test signals.

    Every ``.npy`` file in ``preprocessed_test_data_<net_type>/`` holds a
    two-column array: column 0 is the clean speech, column 1 the recovered
    speech. Pairs where either column is all zeros are skipped. The scores of
    the remaining pairs are saved as rows ``[SDR, PESQ, STOI]`` to
    ``<net_type>_metrics.npy``.

    Args:
        net_type: network name used to build the input directory and output file.
    """
    from pystoi import stoi
    import pesq
    from mir_eval.separation import bss_eval_sources

    path = 'preprocessed_test_data_' + net_type + '/'
    if not os.path.isdir(path):
        print("Error: Preprocessed data for the model not found")
        return

    files = [f for f in os.listdir(path) if f.endswith('.npy')]
    total = len(files)
    sdr_scores, pesq_scores, stoi_scores = [], [], []
    processed = 0
    for idx, fname in enumerate(files):
        signals = np.load(path + fname)
        clean_speech = signals[:, 0]
        recovered_speech = signals[:, 1]
        # Metrics are undefined on an all-zero signal, so skip such pairs.
        if np.any(clean_speech) and np.any(recovered_speech):
            pesq_val = pesq.pesq(dsp.audio_fs, clean_speech, recovered_speech, 'wb')
            stoi_val = stoi(clean_speech, recovered_speech, dsp.audio_fs, extended=False)
            sdr_val, _sir, _sar, _perm = bss_eval_sources(clean_speech, recovered_speech)
            sdr_scores.append(sdr_val[0])
            pesq_scores.append(pesq_val)
            stoi_scores.append(stoi_val)
            processed += 1
        line_end = '\r' if idx < total - 1 else '\n'
        print('[Metric computation: {}% complete]'.format(100.0 * (idx + 1) / total), end=line_end)

    metrics = np.array([sdr_scores, pesq_scores, stoi_scores]).T
    np.save(net_type + '_metrics.npy', metrics)
    print("Finished pre-processed testing of net '{}', {} files out of {} were processed into {}_metrics.npy".format(net_type, processed, total, net_type))
def validation_step(self, batch, batch_idx):
    """Run one validation batch: compute the loss and the batch-mean STOI.

    Returns:
        dict with ``"val_loss"``, ``"predicted_audio"``, ``"mel_target"``
        (the ``spec`` output), ``"mel_len"`` and ``"stoi"`` (mean over the batch).
    """
    self.mode = OperationMode.validation
    self.model.mode = OperationMode.validation
    audio, audio_len = batch
    z, logdet, predicted_audio, spec, spec_len = self(audio=audio, audio_len=audio_len)
    loss = self.loss(z=z, logdet=logdet, gt_audio=audio, predicted_audio=predicted_audio, sigma=self.sigma)

    # compute average stoi score for batch
    sr = self._cfg.preprocessor.params.sample_rate
    # pystoi works on NumPy arrays. Convert explicitly instead of handing it
    # torch tensors; detach() also allows tensors that still require grad.
    refs = audio.detach().cpu().numpy()
    ests = predicted_audio.detach().cpu().numpy()
    stoi_score = 0
    for ref_i, est_i in zip(refs, ests):
        stoi_score += stoi(ref_i, est_i, sr)
    stoi_score /= audio.shape[0]

    return {
        "val_loss": loss,
        "predicted_audio": predicted_audio,
        "mel_target": spec,
        "mel_len": spec_len,
        "stoi": stoi_score,
    }
def main():
    """Score a two-source separation result and write ``eval.txt``.

    Reads ``audio1_separated.wav``, ``audio2_separated.wav``, the ground
    truths ``audio1.wav`` / ``audio2.wav`` and ``audio_mixed.wav`` from
    ``--results_dir``. It writes one line with: SDR, SDR of the unprocessed
    mixture, SDR improvement, SIR, SAR, mean PESQ and mean STOI over the two
    sources.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--results_dir', type=str, required=True)
    parser.add_argument('--audio_sampling_rate', type=int, default=16000)
    args = parser.parse_args()
    sr = args.audio_sampling_rate

    def load(name):
        # Load one wav from the results directory, resampled to ``sr``.
        audio, _ = librosa.load(os.path.join(args.results_dir, name), sr=sr)
        return audio

    audio1 = load('audio1_separated.wav')
    audio2 = load('audio2_separated.wav')
    audio1_gt = load('audio1.wav')
    audio2_gt = load('audio2.wav')
    audio_mix = load('audio_mixed.wav')

    # SDR, SIR, SAR
    sdr, sir, sar = getSeparationMetrics(audio1, audio2, audio1_gt, audio2_gt)
    sdr_mixed, _, _ = getSeparationMetrics(audio_mix, audio_mix, audio1_gt, audio2_gt)

    # PESQ: reference (ground truth) first, degraded second, like the STOI
    # calls below. The arguments were previously swapped.
    # NOTE(review): assumes ``pesq`` has the pypesq signature (ref, deg, fs).
    pesq_score1 = pesq(audio1_gt, audio1, sr)
    pesq_score2 = pesq(audio2_gt, audio2, sr)
    pesq_score = (pesq_score1 + pesq_score2) / 2

    # STOI
    stoi_score1 = stoi(audio1_gt, audio1, sr, extended=False)
    stoi_score2 = stoi(audio2_gt, audio2, sr, extended=False)
    stoi_score = (stoi_score1 + stoi_score2) / 2

    with open(os.path.join(args.results_dir, 'eval.txt'), 'w') as output_file:
        output_file.write(
            "%3f %3f %3f %3f %3f %3f %3f" %
            (sdr, sdr_mixed, sdr - sdr_mixed, sir, sar, pesq_score, stoi_score))
def inference(clean_path, noisy_path, model_path, out_path):
    """Enhance the test set with the trained actor, score it and save the outputs.

    Each enhanced signal is rescaled to an L2 norm of 10 and scored with PESQ
    and STOI against its target. It is written to ``out_path`` under the
    matching name from ``os.listdir(noisy_path)``. Mean scores are printed and
    written to ``<model_path>/test_scores.txt``.

    Args:
        clean_path: directory of clean test audio.
        noisy_path: directory of noisy test audio.
        model_path: directory prefix containing ``actor_best.pth``.
        out_path: directory where the enhanced wavs are written.
    """
    device = torch.device("cuda:1")
    model = Actor()
    model = nn.DataParallel(model, device_ids=[1, 2])
    model.load_state_dict(torch.load(model_path + 'actor_best.pth'))
    model = model.to(device)
    # nn.Module starts in training mode. Switch to inference behaviour
    # (dropout off, batch-norm running stats) before evaluating.
    model.eval()

    dataset = Data(clean_path, noisy_path, mode='Test')
    loader = data.DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_custom)
    # NOTE(review): output names are assigned by position in os.listdir order.
    # This assumes ``Data`` iterates noisy_path in the same order; confirm.
    fnames = os.listdir(noisy_path)
    print("Num files:", len(fnames))

    pesq_all = []
    stoi_all = []
    fcount = 0
    with torch.no_grad():
        for batch in tqdm(loader):
            x = batch["noisy"].unsqueeze(1).to(device)
            t = batch["clean"].unsqueeze(1).to(device)
            m = batch["mask"].to(device)
            out_r, out_i = model(x)
            out_r = torch.transpose(out_r, 1, 2)
            out_i = torch.transpose(out_i, 1, 2)
            y = predict(x.squeeze(1), (out_r, out_i))
            t = t.squeeze()
            m = m.squeeze()
            x = x.squeeze()
            _source, targets, preds = inverse(t, y, m, x)
            for target, pred in zip(targets, preds):
                t_j = target.detach().cpu().numpy()
                p_j = pred.detach().cpu().numpy()
                # Rescale to an L2 norm of 10. An all-zero output is left as is,
                # because dividing by its zero norm would produce NaNs.
                norm = np.linalg.norm(p_j)
                if norm > 0:
                    p_j = 10 * (p_j / norm)
                pesq_all.append(pesq(t_j, p_j, 16000))
                stoi_all.append(stoi(t_j, p_j, 16000))
                try:
                    sf.write(os.path.join(out_path, fnames[fcount]), p_j, 16000)
                except IndexError:
                    print("Fcount:", fcount, len(fnames))
                fcount += 1

    PESQ = torch.mean(torch.tensor(pesq_all))
    STOI = torch.mean(torch.tensor(stoi_all))
    print("PESQ: ", PESQ, "STOI: ", STOI)
    with open(os.path.join(model_path, 'test_scores.txt'), 'w') as fo:
        fo.write("Avg PESQ: " + str(float(PESQ)) + " Avg STOI: " + str(float(STOI)))
def __call__(self, a, b):
    """
    Compute STOI between two time-domain signals at ``self.sr``.

    :param a: time-domain signal (1-D), used as the reference
    :param b: time-domain signal, same length as ``a``
    :return: the STOI score
    """
    # Only single 1-D signals of equal length are accepted.
    assert len(a.shape) == 1
    assert len(a) == len(b)
    return stoi(a, b, self.sr)
def forward(self, clean, enhanced):
    """Per-item STOI (or ESTOI when ``self.extended``) for a batch of signals.

    Args:
        clean: iterable of reference signal tensors.
        enhanced: iterable of processed signal tensors, same length as ``clean``.

    Returns:
        1-D tensor of scores on ``self._device``, one per item.
    """
    assert len(clean) == len(enhanced)
    # pystoi works on NumPy arrays: convert each item explicitly rather than
    # passing CPU torch tensors straight into it.
    scores = [
        stoi(c.detach().cpu().numpy(), e.detach().cpu().numpy(), self.sr,
             extended=self.extended)
        for c, e in zip(clean, enhanced)
    ]
    return torch.tensor(scores, device=self._device)
def test(device, net_type, model_path, dataset):
    """Evaluate a trained enhancement network on the test split of ``dataset``.

    Each test utterance's network output is turned into a complex STFT,
    inverted with ``dsp.recover_from_stft_spectrogram`` and scored against the
    clean speech with SDR (mir_eval), wide-band PESQ and STOI. Scores are
    saved as an ``(N, 3)`` array ``[SDR, PESQ, STOI]`` to
    ``<net_type>_metrics.npy``.

    Args:
        device: torch device the network runs on.
        net_type: identifier passed to ``create_net_of_type``. The
            ``'phasen_1strm'`` / ``'phasen_baseline'`` variants output a
            complex ratio mask; the others output the enhanced spectrum.
        model_path: path of the saved ``state_dict``.
        dataset: dataset name passed to ``create_dataset_for``.
    """
    from pystoi import stoi
    from pesq import pesq
    from mir_eval.separation import bss_eval_sources

    net = create_net_of_type(net_type, device, 1)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net = net.to(device)
    net.eval()

    dataset = create_dataset_for(dataset, 'test')
    # No shuffling, so row i of the saved metrics is deterministic.
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    n_total = len(dataset)
    metrics = np.zeros((n_total, 3))
    # Inference only: do not build autograd graphs.
    with torch.no_grad():
        for i, (clean, _noisy, _st, sm) in enumerate(loader):
            clean_speech = clean.numpy().flatten()
            sm = sm.float().to(device)

            if net_type == 'phasen_1strm' or net_type == 'phasen_baseline':
                # Apply the predicted complex mask to the mixture by complex
                # multiplication (channel 0 = real part, channel 1 = imaginary).
                cIRM_est = net(sm).squeeze(0)
                sm = sm.squeeze(0)
                _, T, F = cIRM_est.shape
                sout_c = torch.zeros(T, F, dtype=torch.cfloat)
                sout_c.real = cIRM_est[0, :, :] * sm[0, :, :] - cIRM_est[1, :, :] * sm[1, :, :]
                sout_c.imag = cIRM_est[0, :, :] * sm[1, :, :] + cIRM_est[1, :, :] * sm[0, :, :]
            else:
                # s_out holds the enhanced spectrum as (1, 2, T, F) real/imag channels.
                s_out, _, _ = net(sm)
                s_out = s_out.squeeze(0)
                _, T, F = s_out.shape
                sout_c = torch.zeros(T, F, dtype=torch.cfloat)
                sout_c.real = s_out[0, :, :]
                sout_c.imag = s_out[1, :, :]
            # (T, F) -> complex NumPy array of shape (F, T)
            sout_c = sout_c.T.detach().cpu().numpy()

            # Recover time domain signal
            _, recovered_speech = dsp.recover_from_stft_spectrogram(sout_c, dsp.audio_fs)
            PESQ = pesq(dsp.audio_fs, clean_speech, recovered_speech, 'wb')
            STOI = stoi(clean_speech, recovered_speech, dsp.audio_fs, extended=False)
            SDR, _sir, _sar, _perm = bss_eval_sources(clean_speech, recovered_speech)
            metrics[i, 0] = SDR[0]
            metrics[i, 1] = PESQ
            metrics[i, 2] = STOI

            # Only the final sample ends the progress line (was one sample early).
            done = i + 1
            print('[Sample {}% Complete]'.format(100 * done / n_total),
                  end='\r' if done < n_total else '\n')

    np.save(net_type + '_metrics.npy', metrics)
    print("Finished testing of net '{}', metrics saved in {}_metrics.npy".format(net_type, net_type))
def CompositeEval(ref_wav, deg_wav, log_all=False):
    """Composite objective quality measures for a degraded vs. reference signal.

    Both signals are truncated to their common length and scored at 16 kHz.
    WSS and LLR frame distances are reduced to the mean of their smallest 95%.
    These are combined with wide-band PESQ and segmental SNR in fixed linear
    combinations, and each composite is clipped to [1, 5].

    Args:
        ref_wav: reference (clean) waveform.
        deg_wav: degraded / processed waveform.
        log_all: also return raw PESQ, segmental SNR and STOI (scaled by 100).

    Returns:
        ``(Csig, Cbak, Covl)``, or
        ``(Csig, Cbak, Covl, pesq, segSNR, 100 * stoi)`` when ``log_all``.
    """
    keep_fraction = 0.95  # share of the smallest frame distances that is averaged
    common = min(ref_wav.shape[0], deg_wav.shape[0])
    ref_wav = ref_wav[:common]
    deg_wav = deg_wav[:common]

    def lower_mean(distances):
        # Mean of the smallest ``keep_fraction`` share of the distances.
        ranked = sorted(distances, reverse=False)
        return np.mean(ranked[:int(round(len(ranked) * keep_fraction))])

    wss_dist = lower_mean(wss(ref_wav, deg_wav, 16000))
    llr_mean = lower_mean(llr(ref_wav, deg_wav, 16000))
    _, segsnr_frames = SSNR(ref_wav, deg_wav, 16000)
    segSNR = np.mean(segsnr_frames)
    pesq_raw = pesq(ref=ref_wav, deg=deg_wav, fs=16000, mode='wb')
    stoi_val = stoi(ref_wav, deg_wav, 16000, extended=False)

    def clip_mos(value):
        # Keep each composite score inside [1, 5].
        return min(max(value, 1), 5)

    Csig = clip_mos(3.093 - 1.029 * llr_mean + 0.603 * pesq_raw - 0.009 * wss_dist)
    Cbak = clip_mos(1.634 + 0.478 * pesq_raw - 0.007 * wss_dist + 0.063 * segSNR)
    Covl = clip_mos(1.594 + 0.805 * pesq_raw - 0.512 * llr_mean - 0.007 * wss_dist)

    if log_all:
        return Csig, Cbak, Covl, pesq_raw, segSNR, stoi_val * 100
    return Csig, Cbak, Covl
def sp_enhance_evals(est_source, clean_source, noisy_source, fs):
    """Score one enhanced signal: PESQ, STOI, SI-SDR and SI-SDR improvement.

    Args:
        est_source: enhanced signal tensor.
        clean_source: clean reference tensor.
        noisy_source: unprocessed noisy tensor (baseline for the improvement).
        fs: sampling rate; wide-band PESQ is used.

    Returns:
        ``(pesq, stoi, si_sdr, si_sdr(est) - si_sdr(noisy))``.
    """
    def to_numpy(tensor):
        # detach() so tensors that still require grad can be converted;
        # clone() keeps the array independent of the tensor's storage.
        return tensor.detach().cpu().clone().numpy()

    est_source = to_numpy(est_source)
    clean_source = to_numpy(clean_source)
    noisy_source = to_numpy(noisy_source)
    pesq_val = pesq(fs, clean_source, est_source, 'wb')
    stoi_val = stoi(clean_source, est_source, fs, extended=False)
    si_sdr_val = si_sdr(est_source, clean_source)
    noisy_si_sdr_val = si_sdr(noisy_source, clean_source)
    si_sdr_improvement = si_sdr_val - noisy_si_sdr_val
    return pesq_val, stoi_val, si_sdr_val, si_sdr_improvement
def Q(mode, clean, denoised, sr):
    """Return a speech-quality score for ``denoised`` against ``clean``.

    Args:
        mode: ``"pesq"`` or ``"stoi"``.
        clean: reference waveform.
        denoised: processed waveform.
        sr: sampling rate of both waveforms.

    Returns:
        For ``"pesq"``: wide-band PESQ passed through ``(pesq + 0.5) / 5``,
        which maps -0.5 -> 0 and 4.5 -> 1. For ``"stoi"``: classic STOI.

    Raises:
        ValueError: for any other ``mode``. Previously ``None`` was returned
            silently.
    """
    if mode == "pesq":
        # 'wb' = wide-band mode of the pesq package.
        return (pesq(sr, clean, denoised, 'wb') + 0.5) / 5.0
    if mode == "stoi":
        # Use the caller's rate. It was hard-coded to 16000, which
        # mis-scored signals at any other sampling rate.
        return stoi(clean, denoised, sr, extended=False)
    raise ValueError(f"Unsupported mode: {mode!r}")
def get_stoi(ref_sig, out_sig, sr):
    """Sum of per-item STOI over a batch.

    Args:
        ref_sig: numpy.ndarray, [B, T] clean references.
        out_sig: numpy.ndarray, [B, T] processed signals.
        sr: sampling rate.

    Returns:
        The SUM (not the mean) of classic STOI over the B items; 0 for an
        empty batch.
    """
    total = 0
    for idx, ref in enumerate(ref_sig):
        total += stoi(ref, out_sig[idx], sr, extended=False)
    return total
def _eval(batch, metrics, including='output', sample_rate=8000, use_pypesq=False):
    """Compute enhancement metrics for one batch.

    Most metrics come from ``get_metrics``. The following are removed from
    the list handed to it and computed here instead:

    * ``'pesq'`` when ``use_pypesq``: ``pesq`` on the flattened batch;
    * ``'estoi'``: ``stoi(..., extended=True)`` on the flattened batch;
    * ``'wer'``: ``jiwer.wer`` between ``clean_text`` and ``transcription``.

    With ``including == 'input'`` the ``input_<m>`` results are renamed to
    ``<m>``. The first ``snr`` value of the batch is stored under ``'snr'``.
    """
    excluded = []
    if use_pypesq:
        excluded.append('pesq')
    has_estoi = 'estoi' in metrics
    if has_estoi:
        excluded.append('estoi')
    has_wer = 'wer' in metrics
    if has_wer:
        excluded.append('wer')
    if excluded:
        metrics = [m for m in metrics if m not in excluded]

    mix = batch['mix']
    clean = batch['clean']
    estimate = batch['enh']
    snr = batch['snr']
    res = get_metrics(mix.numpy(), clean.numpy(), estimate.numpy(),
                      sample_rate=sample_rate,
                      metrics_list=metrics,
                      including=including)

    if use_pypesq:
        res['pesq'] = pesq(clean.flatten(), estimate.flatten(), sample_rate)
    if has_estoi:
        res['estoi'] = stoi(clean.flatten(), estimate.flatten(), sample_rate, extended=True)
    if has_wer:
        res['wer'] = jiwer.wer(batch['clean_text'], batch['transcription'],
                               truth_transform=_wer_trans,
                               hypothesis_transform=_wer_trans)
    if including == 'input':
        for m in metrics:
            res[m] = res.pop('input_' + m)
    res['snr'] = snr[0].item()
    return res
def task1_metric(clean_speech, denoised_speech, sr=16000):
    '''
    Compute the task-1 evaluation metric for a single data point.

    metric = (STOI + (1 - WER)) / 2, with STOI and WER each clipped to [0, 1].

    Returns:
        ``(metric, WER, STOI)``. When ``wer`` returns ``None`` (per the
        original comment: no speech in the segment), all three are ``None``.
    '''
    WER = wer(clean_speech, denoised_speech)
    if WER is None:
        # Nothing to score for this segment.
        return None, WER, None
    STOI = stoi(clean_speech, denoised_speech, sr, extended=False)
    WER = np.clip(WER, 0., 1.)
    STOI = np.clip(STOI, 0., 1.)
    return (STOI + (1. - WER)) / 2., WER, STOI
def stoi_score(y_true: np.ndarray, y_pred: np.ndarray, samplerate=16000, extended=False) -> float:
    """Computes the Short-Time Objective Intelligibility between `y_true` and `y_pred`.

    Args:
        y_true (np.ndarray): The clean reference signal, shape (samplerate*length).
        y_pred (np.ndarray): The predicted signal, same shape as `y_true`.
        samplerate (int, optional): Sampling rate of both signals in Hz. Defaults to 16000.
        extended (bool, optional): Whether to compute extended STOI (ESTOI) instead.
            Defaults to False.

    Returns:
        float: The STOI (or ESTOI) score between `y_true` and `y_pred`.
    """
    # pystoi's rate parameter is ``fs_sig`` (as used elsewhere in this file).
    # Passing ``fs=`` raises TypeError.
    return stoi(y_true, y_pred, fs_sig=samplerate, extended=extended)
def main():
    """Print mean PESQ and STOI of enhanced files against their clean references.

    For every file in ``--test_clean_folder``, the enhanced counterpart is
    ``enhanced_<name with 'clean' replaced by 'mix'>`` in ``--enhanced_folder``.
    Both are loaded at 16 kHz and truncated to their common length.
    ``--test_mix_folder`` is still accepted for CLI compatibility; the mixture
    audio itself is not scored.
    """
    parser = argparse.ArgumentParser(description='Calculate performance index')
    parser.add_argument('--test_mix_folder', default='../test-mix-2-babble', type=str, help='test-set-mix')
    parser.add_argument('--test_clean_folder', default='../test-clean-2-babble', type=str, help='test-set-clean')
    parser.add_argument('--enhanced_folder', default='../test-result', type=str, help='test-set-enhanced')
    opt = parser.parse_args()
    CLEAN_FOLDER = opt.test_clean_folder
    ENHANCED_FOLDER = opt.enhanced_folder

    pesqs = []
    stois = []
    for clean_name in sorted(os.listdir(CLEAN_FOLDER)):
        mix_name = clean_name.replace('clean', 'mix')
        clean_file = os.path.join(CLEAN_FOLDER, clean_name)
        enhanced_file = os.path.join(ENHANCED_FOLDER, 'enhanced_' + mix_name)
        # ``sr`` must be passed by keyword: recent librosa versions make it
        # keyword-only, so the positional form raises TypeError.
        ref, _ = librosa.load(clean_file, sr=16000)
        deg_enh, _ = librosa.load(enhanced_file, sr=16000)
        # Truncate both to the common length. Trimming only the enhanced
        # signal failed in stoi when it was the shorter one.
        n = min(len(ref), len(deg_enh))
        ref, deg_enh = ref[:n], deg_enh[:n]
        pesqs.append(pesq.pesq(ref, deg_enh))
        stois.append(stoi(ref, deg_enh, fs_sig=16000))
    print('Epesq:', np.mean(pesqs), "Estoi:", np.mean(stois))
def cal_STOI(ref_sig, out_sig, sr):
    """Mean classic STOI over all (batch, channel) signals.

    Args:
        ref_sig: numpy.ndarray, [B, C, T] references.
        out_sig: numpy.ndarray, [B, C, T] processed signals.
        sr: sampling rate.

    Returns:
        Mean STOI over the B*C signals. If scoring fails, 0 is returned
        (best-effort fallback) and a warning states the reason.
    """
    import warnings

    B, C, T = ref_sig.shape
    ref_sig = ref_sig.reshape(B * C, T)
    out_sig = out_sig.reshape(B * C, T)
    try:
        stoi_val = 0
        for ref, out in zip(ref_sig, out_sig):
            stoi_val += stoi(ref, out, sr, extended=False)
        return stoi_val / (B * C)
    except Exception as err:
        # Keep the best-effort 0, but report why. Unlike the former bare
        # ``except:``, KeyboardInterrupt/SystemExit are no longer swallowed.
        warnings.warn(f"cal_STOI failed, returning 0: {err!r}")
        return 0
def stoi_from_fft(noisy, phase_noisy, clean, phase_clean):
    """
    Calculate classic STOI from dB-magnitude spectrograms plus phase.

    Each dB magnitude is converted to linear amplitude, multiplied by its phase
    term and inverted with ``librosa.istft``. STOI is computed at a fixed
    16 kHz with the clean signal as reference and the noisy/processed signal
    as the degraded input.
    """
    sr = 16000
    noisy = librosa.db_to_amplitude(noisy) * np.array(phase_noisy)
    noisy = librosa.istft(noisy)
    clean = librosa.db_to_amplitude(np.array(clean)) * np.array(phase_clean)
    clean = librosa.istft(clean)
    # pystoi.stoi(x, y, fs_sig, extended): x is the clean reference. The
    # arguments were previously swapped, and 'wb' (a PESQ mode string) was
    # passed as the truthy ``extended`` flag, which silently selected ESTOI.
    return stoi(clean, noisy, sr, extended=False)
def train(self):
    """Run the train / validate loop for ``self.config['epochs']`` epochs.

    Each epoch:
      * trains ``self.model`` on STFT magnitudes from ``self.train_dataloader``
        (Hamming window, ``self.n_fft`` / ``self.hop_len`` / ``self.win_len``),
        minimising ``self.loss_function.calculate_loss``;
      * validates on ``self.eval_dataloader``. The enhanced magnitude is
        combined with the mixture phase, inverted with ``torch.istft`` and
        scored with STOI, PESQ and SI-SDR; utterances with a NaN score are
        skipped;
      * saves a checkpoint flagged by ``self._is_best_score``.
    """
    m_print(
        f"The amount of parameters in the project is {print_networks([self.model]) / 1e6} million."
    )
    # Lower bound on the mixture magnitude when re-using its phase. 0/0 on
    # silent (e.g. zero-padded) bins gave NaNs that could spread through
    # istft and make the whole utterance's scores NaN, so it was skipped.
    eps = 1e-8

    def stft_parts(wav):
        # Hamming-window STFT of ``wav`` on the GPU, permuted so the last axis
        # holds (real, imag). Returns the real part, imaginary part and magnitude.
        # NOTE(review): relies on the legacy real-valued torch.stft output
        # (return_complex unset); newer torch versions require return_complex.
        spec = torch.stft(wav,
                          n_fft=self.n_fft,
                          hop_length=self.hop_len,
                          win_length=self.win_len,
                          window=torch.hamming_window(
                              self.win_len)).permute(0, 2, 1, 3).cuda()
        real = spec[:, :, :, 0]
        imag = spec[:, :, :, 1]
        return real, imag, torch.sqrt(real**2 + imag**2)

    for epoch in range(self.config['epochs']):
        total_loss = 0
        self.model.train()
        # --- training ---
        for mixs_wav, cleans_wav, lengths, _ in tqdm(self.train_dataloader):
            self.optimizer.zero_grad()
            _, _, mixs_mag = stft_parts(mixs_wav)
            _, _, cleans_mag = stft_parts(cleans_wav)
            enhances_mag = self.model(mixs_mag)
            # Frame count per utterance, as the loss function expects it.
            frames = [(length - self.win_len) // self.hop_len + 3 for length in lengths]
            loss = self.loss_function.calculate_loss(enhances_mag, cleans_mag, frames)
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            gc.collect()
        tqdm.write(f"\nepoch: {epoch}, total loss: {total_loss}")
        end_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        m_print(f"epoch: {epoch} end logging time:\t{end_time}")

        self.model.eval()
        # --- validating ---
        with torch.no_grad():
            stoi_sum = 0
            pesq_sum = 0
            si_sdr_sum = 0
            count_number = 0
            istft_window = torch.hamming_window(self.win_len).cuda()
            for mixs_wav, cleans_wav, lengths, _ in tqdm(self.eval_dataloader):
                mixs_real, mixs_imag, mixs_mag = stft_parts(mixs_wav)
                enhances_mag = self.model(mixs_mag)
                # Re-use the mixture phase; clamp avoids 0/0 on silent bins.
                safe_mag = mixs_mag.clamp_min(eps)
                enhances_real = enhances_mag * mixs_real / safe_mag
                enhances_imag = enhances_mag * mixs_imag / safe_mag
                enhances = torch.stack([enhances_real, enhances_imag], 3)
                enhances = enhances.permute(0, 2, 1, 3)
                enhances_wav = torch.istft(enhances,
                                           n_fft=self.n_fft,
                                           hop_length=self.hop_len,
                                           win_length=self.win_len,
                                           window=istft_window,
                                           length=max(lengths))
                cleans_wav = cleans_wav.cpu().numpy()
                enhances_wav = enhances_wav.cpu().numpy()
                for clean, enhance, length in zip(cleans_wav, enhances_wav, lengths):
                    clean = clean[:length]
                    enhance = enhance[:length]
                    stoi_transform = pystoi.stoi(clean, enhance, 16000)
                    pesq_transform = pypesq.pesq(clean, enhance, 16000)
                    si_sdr_transform = SI_SDR(clean, enhance)
                    # Skip utterances on which any metric is undefined.
                    if (np.isnan(stoi_transform) or np.isnan(pesq_transform)
                            or np.isnan(si_sdr_transform)):
                        continue
                    stoi_sum += stoi_transform
                    pesq_sum += pesq_transform
                    si_sdr_sum += si_sdr_transform
                    count_number += 1

        score = self._calculate_score(stoi=stoi_sum, pesq=pesq_sum)
        is_best = bool(self._is_best_score(score))
        self._save_checkpoints(epoch, is_best)
        print(count_number)
        if count_number == 0:
            # Nothing was scored (empty loader or all NaN): avoid ZeroDivisionError.
            m_print(f"no valid validation utterances, is best: {is_best}")
        else:
            m_print(f"stoi score: "
                    f"{stoi_sum / count_number},"
                    f"pesq score: "
                    f"{pesq_sum / count_number},"
                    f"si_sdr score: "
                    f"{si_sdr_sum / count_number},"
                    f"is best: {is_best}")
def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
):
    """Compute STOI, SI-SNR, SDR, SAR and SIR per key and write them to ``output_dir``.

    For each key in ``key_file``, every speaker's reference and inferred audio
    is read from the scp files. The best speaker permutation is found with
    ``bss_eval_sources(compute_permutation=True)``. Per-speaker scores and the
    permuted ``wav_spk*`` entries are then written with ``DatadirWriter``.

    Args:
        output_dir: directory written by ``DatadirWriter``.
        dtype: sample dtype passed to ``SoundScpReader``.
        log_level: logging level.
        key_file: text file whose first whitespace-separated field per line is a key.
        ref_scp: one reference scp file per speaker.
        inf_scp: one inferred scp file per speaker (same order as ``ref_scp``).
        ref_channel: channel index used when the audio is multi-channel.

    Raises:
        ValueError: if the output is multi-channel but the reference is not.
    """
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    assert len(ref_scp) == len(inf_scp), ref_scp
    num_spk = len(ref_scp)

    # Read keys with a context manager so the file handle is closed.
    with open(key_file, encoding="utf-8") as key_fp:
        keys = [line.rstrip().split(maxsplit=1)[0] for line in key_fp]

    ref_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp]
    inf_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp]

    # get sample rate
    sample_rate, _ = ref_readers[0][keys[0]]

    # check keys
    for inf_reader, ref_reader in zip(inf_readers, ref_readers):
        assert inf_reader.keys() == ref_reader.keys()

    with DatadirWriter(output_dir) as writer:
        for key in keys:
            ref = np.array([ref_reader[key][1] for ref_reader in ref_readers])
            inf = np.array([inf_reader[key][1] for inf_reader in inf_readers])
            if ref.ndim > inf.ndim:
                # multi-channel reference and single-channel output
                ref = ref[..., ref_channel]
                assert ref.shape == inf.shape, (ref.shape, inf.shape)
            elif ref.ndim < inf.ndim:
                # single-channel reference and multi-channel output
                raise ValueError("Reference must be multi-channel when the "
                                 "network output is multi-channel.")
            elif ref.ndim == inf.ndim == 3:
                # multi-channel reference and output
                ref = ref[..., ref_channel]
                inf = inf[..., ref_channel]

            sdr, sir, sar, perm = bss_eval_sources(ref, inf, compute_permutation=True)

            for i in range(num_spk):
                p = int(perm[i])  # inferred source matched to reference speaker i
                stoi_score = stoi(ref[i], inf[p], fs_sig=sample_rate)
                si_snr_score = -float(
                    si_snr_loss(
                        torch.from_numpy(ref[i][None, ...]),
                        torch.from_numpy(inf[p][None, ...]),
                    ))
                writer[f"STOI_spk{i + 1}"][key] = str(stoi_score)
                writer[f"SI_SNR_spk{i + 1}"][key] = str(si_snr_score)
                writer[f"SDR_spk{i + 1}"][key] = str(sdr[i])
                writer[f"SAR_spk{i + 1}"][key] = str(sar[i])
                writer[f"SIR_spk{i + 1}"][key] = str(sir[i])
                # save permutation assigned script file
                writer[f"wav_spk{i + 1}"][key] = inf_readers[p].data[key]
def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
    metrics: List[str],
    frame_size: int = 512,
    frame_hop: int = 256,
):
    """Compute the requested ``metrics`` per key and write them to ``output_dir``.

    Supported metrics: STOI, ESTOI, SNR, SI_SNR, SDR, SAR, SIR and
    framewise-SNR (computed on STFT magnitudes with ``frame_size`` /
    ``frame_hop``). ``bss_eval_sources`` runs (and supplies the speaker
    permutation) when SDR/SAR/SIR is requested or there is more than one
    speaker; otherwise the identity permutation is used. Results and the
    permuted ``wav_spk*`` entries are written with ``DatadirWriter``.

    Args:
        output_dir: directory written by ``DatadirWriter``.
        dtype: sample dtype passed to ``SoundScpReader``.
        log_level: logging level.
        key_file: text file whose first whitespace-separated field per line is a key.
        ref_scp: one reference scp file per speaker.
        inf_scp: one inferred scp file per speaker (same order as ``ref_scp``).
        ref_channel: channel index used when the audio is multi-channel.
        metrics: names of the metrics to compute.
        frame_size: STFT size for framewise-SNR.
        frame_hop: STFT hop for framewise-SNR.

    Raises:
        ValueError: if the output is multi-channel but the reference is not.
    """
    assert check_argument_types()
    for metric in metrics:
        assert metric in (
            "STOI",
            "ESTOI",
            "SNR",
            "SI_SNR",
            "SDR",
            "SAR",
            "SIR",
            "framewise-SNR",
        ), metric
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    assert len(ref_scp) == len(inf_scp), ref_scp
    num_spk = len(ref_scp)

    # Read keys with a context manager so the file handle is closed.
    with open(key_file, encoding="utf-8") as key_fp:
        keys = [line.rstrip().split(maxsplit=1)[0] for line in key_fp]

    ref_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp]
    inf_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp]

    # get sample rate
    fs, _ = ref_readers[0][keys[0]]

    # check keys
    for inf_reader, ref_reader in zip(inf_readers, ref_readers):
        assert inf_reader.keys() == ref_reader.keys()

    do_bss_eval = "SDR" in metrics or "SAR" in metrics or "SIR" in metrics
    # The STFT is only needed for framewise-SNR, so skip it otherwise.
    do_framewise = "framewise-SNR" in metrics
    stft = STFTEncoder(n_fft=frame_size, hop_length=frame_hop) if do_framewise else None

    with DatadirWriter(output_dir) as writer:
        for key in keys:
            ref = np.array([ref_reader[key][1] for ref_reader in ref_readers])
            inf = np.array([inf_reader[key][1] for inf_reader in inf_readers])
            if ref.ndim > inf.ndim:
                # multi-channel reference and single-channel output
                ref = ref[..., ref_channel]
                assert ref.shape == inf.shape, (ref.shape, inf.shape)
            elif ref.ndim < inf.ndim:
                # single-channel reference and multi-channel output
                raise ValueError("Reference must be multi-channel when the "
                                 "network output is multi-channel.")
            elif ref.ndim == inf.ndim == 3:
                # multi-channel reference and output
                ref = ref[..., ref_channel]
                inf = inf[..., ref_channel]

            if do_bss_eval or num_spk > 1:
                sdr, sir, sar, perm = bss_eval_sources(
                    ref, inf, compute_permutation=True)
            else:
                perm = [0]

            if do_framewise:
                ilens = torch.LongTensor([ref.shape[1]])
                # (num_spk, T, F)
                ref_spec, _ = stft(torch.from_numpy(ref), ilens)
                inf_spec, _ = stft(torch.from_numpy(inf), ilens)

            for i in range(num_spk):
                p = int(perm[i])  # inferred source matched to reference speaker i
                for metric in metrics:
                    name = f"{metric}_spk{i + 1}"
                    if metric == "STOI":
                        writer[name][key] = str(
                            stoi(ref[i], inf[p], fs_sig=fs, extended=False))
                    elif metric == "ESTOI":
                        writer[name][key] = str(
                            stoi(ref[i], inf[p], fs_sig=fs, extended=True))
                    elif metric == "SNR":
                        snr_score = -float(
                            ESPnetEnhancementModel.snr_loss(
                                torch.from_numpy(ref[i][None, ...]),
                                torch.from_numpy(inf[p][None, ...]),
                            ))
                        writer[name][key] = str(snr_score)
                    elif metric == "SI_SNR":
                        si_snr_score = -float(
                            ESPnetEnhancementModel.si_snr_loss(
                                torch.from_numpy(ref[i][None, ...]),
                                torch.from_numpy(inf[p][None, ...]),
                            ))
                        writer[name][key] = str(si_snr_score)
                    elif metric == "SDR":
                        writer[name][key] = str(sdr[i])
                    elif metric == "SAR":
                        writer[name][key] = str(sar[i])
                    elif metric == "SIR":
                        writer[name][key] = str(sir[i])
                    elif metric == "framewise-SNR":
                        framewise_snr = -ESPnetEnhancementModel.snr_loss(
                            ref_spec[i].abs(), inf_spec[i].abs())
                        writer[name][key] = " ".join(
                            map(str, framewise_snr.tolist()))
                    else:
                        raise ValueError("Unsupported metric: %s" % metric)
                # save permutation assigned script file
                writer[f"wav_spk{i + 1}"][key] = inf_readers[p].data[key]
def estoi_eval(pred_wav, target_wav, fs=16000):
    """Extended STOI of ``pred_wav`` against the reference ``target_wav``.

    Args:
        pred_wav: processed waveform (torch tensor).
        target_wav: clean reference waveform (torch tensor).
        fs: sampling rate in Hz (previously hard-coded to 16000).

    Returns:
        The ESTOI score.
    """
    # detach().cpu() lets tensors that require grad or live on the GPU be
    # converted; a plain .numpy() raises for both.
    return stoi(x=target_wav.detach().cpu().numpy(),
                y=pred_wav.detach().cpu().numpy(),
                fs_sig=fs,
                extended=True)
def compute_metrics_utt(args):
    """Compute metrics for one utterance and save a comparison figure.

    Reads the clean speech ``_s.wav``, noise ``_n.wav`` and mixture ``_x.wav``
    from ``processed_data_dir``, and the estimate ``_s_est.wav`` from
    ``model_data_dir``. It computes SI-SDR / SI-SIR / SI-SAR, ESTOI
    (``stoi(..., extended=True)``) and wide-band PESQ for the estimate, then
    saves a mixture / clean / estimate figure to ``<model_data_dir><stem>_fig.png``.

    Args:
        args: ``(file_path, snr_db)`` -- relative file path and input SNR in dB.

    Returns:
        ``[si_sdr, si_sir, si_sar, stoi, pesq]``.

    Raises:
        ValueError: if any wav file is not at the module-level ``fs`` used for
            all metrics and STFTs (the file rates were previously ignored).
    """
    # Separate args
    file_path, snr_db = args[0], args[1]
    stem = os.path.splitext(file_path)[0]

    # Read files
    s_t, fs_s = sf.read(processed_data_dir + stem + '_s.wav')  # clean speech
    n_t, fs_n = sf.read(processed_data_dir + stem + '_n.wav')  # noise
    x_t, fs_x = sf.read(processed_data_dir + stem + '_x.wav')  # mixture
    s_hat_t, fs_s_hat = sf.read(model_data_dir + stem + '_s_est.wav')  # est. speech
    for rate in (fs_s, fs_n, fs_x, fs_s_hat):
        if rate != fs:
            raise ValueError(
                "{}: expected sampling rate {}, got {}".format(file_path, fs, rate))

    # compute metrics
    ## SI-SDR, SI-SAR, SI-SNR
    si_sdr, si_sir, si_sar = energy_ratios(s_hat=s_hat_t, s=s_t, n=n_t)
    ## ESTOI (extended=True)
    stoi_s_hat = stoi(s_t, s_hat_t, fs, extended=True)
    ## PESQ
    pesq_s_hat = pesq(fs, s_t, s_hat_t, 'wb')  # wb = wideband

    # TF representations, shape = (freq_bins, frames)
    s_tf = stft(s_t, fs=fs, wlen_sec=wlen_sec, win=win, hop_percent=hop_percent, dtype=dtype)
    x_tf = stft(x_t, fs=fs, wlen_sec=wlen_sec, win=win, hop_percent=hop_percent, dtype=dtype)
    s_hat_tf = stft(s_hat_t, fs=fs, wlen_sec=wlen_sec, win=win, hop_percent=hop_percent, dtype=dtype)

    # Rows: mixture, clean speech, estimated speech (waveform, TF signal, no mask)
    signal_list = [
        [x_t, x_tf, None],
        [s_t, s_tf, None],
        [s_hat_t, s_hat_tf, None]
    ]
    fig = display_multiple_signals(signal_list,
                                   fs=fs, vmin=vmin, vmax=vmax,
                                   wlen_sec=wlen_sec, hop_percent=hop_percent,
                                   xticks_sec=xticks_sec, fontsize=fontsize)
    # put all metrics in the title of the figure
    title = "Input SNR = {:.1f} dB \n" \
        "SI-SDR = {:.1f} dB, " \
        "SI-SIR = {:.1f} dB, " \
        "SI-SAR = {:.1f} dB \n" \
        "STOI = {:.2f}, " \
        "PESQ = {:.2f} \n" \
        "".format(snr_db, si_sdr, si_sir, si_sar, stoi_s_hat, pesq_s_hat)
    fig.suptitle(title, fontsize=40)

    # Save figure
    fig.savefig(model_data_dir + stem + '_fig.png')
    # Close this figure explicitly (not just "the current" one) so figures do
    # not accumulate across calls.
    plt.close(fig)

    metrics = [si_sdr, si_sir, si_sar, stoi_s_hat, pesq_s_hat]
    return metrics
def main():
    """Score the unprocessed mixtures against the clean speech and save statistics.

    For every utterance of ``dataset_type``, ESTOI (``stoi(..., extended=True)``)
    and wide-band PESQ are computed for the mixture ``_x.wav`` against the clean
    speech ``_s.wav``. The mean and confidence interval from
    ``mean_confidence_interval`` of SNR / STOI / PESQ are printed and saved as
    JSON: overall (``stats.json``) and per input-SNR value
    (``stats_<snr>.json``).

    Relies on module-level configuration: ``processed_data_dir``,
    ``dataset_type``, ``input_speech_dir``, ``fs`` and ``confidence``.

    Raises:
        ValueError: if no files are found, or a wav file is not at ``fs``.
    """
    # Load input SNR
    all_snr_db = np.array(read_dataset(processed_data_dir, dataset_type, 'snr_db'))

    # Create file list
    file_paths = speech_list(input_speech_dir=input_speech_dir, dataset_type=dataset_type)
    if len(file_paths) == 0:
        # The output paths below are derived from a file path.
        raise ValueError("No files found for dataset_type {!r}".format(dataset_type))

    # 1 list per metric
    all_stoi = []
    all_pesq = []

    for file_path in tqdm(file_paths):
        stem = os.path.splitext(file_path)[0]
        s_t, fs_s = sf.read(processed_data_dir + stem + '_s.wav')  # clean speech
        x_t, fs_x = sf.read(processed_data_dir + stem + '_x.wav')  # mixture
        for rate in (fs_s, fs_x):
            if rate != fs:
                raise ValueError(
                    "{}: expected sampling rate {}, got {}".format(file_path, fs, rate))
        # The STFTs previously computed here fed only commented-out plotting
        # code, so they were dropped.

        ## ESTOI (extended=True)
        all_stoi.append(stoi(s_t, x_t, fs, extended=True))
        ## PESQ
        all_pesq.append(pesq(fs, s_t, x_t, 'wb'))  # wb = wideband

    # Output files go next to the grandparent directory of the last file path.
    out_prefix = processed_data_dir + os.path.dirname(os.path.dirname(file_paths[-1]))

    # Confidence interval
    metrics = {'SNR': all_snr_db, 'STOI': all_stoi, 'PESQ': all_pesq}
    stats = {}
    # Print the names of the columns.
    print("{:<10} {:<10} {:<10}".format('METRIC', 'AVERAGE', 'CONF. INT.'))
    for key, metric in metrics.items():
        m, h = mean_confidence_interval(metric, confidence=confidence)
        stats[key] = {'avg': m, '+/-': h}
        print("{:<10} {:<10} {:<10}".format(key, m, h))
    print('\n')

    # Save stats (si_sdr, si_sar, etc. )
    with open(out_prefix + 'stats.json', 'w') as f:
        json.dump(stats, f)

    # Metrics by input SNR
    for snr_db in np.unique(all_snr_db):
        stats = {}
        print('Input SNR = {:.2f}'.format(snr_db))
        # Print the names of the columns.
        print("{:<10} {:<10} {:<10}".format('METRIC', 'AVERAGE', 'CONF. INT.'))
        for key, metric in metrics.items():
            subset_metric = np.array(metric)[np.where(all_snr_db == snr_db)]
            m, h = mean_confidence_interval(subset_metric, confidence=confidence)
            stats[key] = {'avg': m, '+/-': h}
            print("{:<10} {:<10} {:<10}".format(key, m, h))
        print('\n')

        # Save stats (si_sdr, si_sar, etc. )
        with open(out_prefix + 'stats_{:g}.json'.format(snr_db), 'w') as f:
            json.dump(stats, f)