Example #1
    def eval(self, audio_est, audio_ref):

        x_est, fs_est = sf.read(audio_est)
        x_ref, fs_ref = sf.read(audio_ref)
        # align
        len_x = np.min([len(x_est), len(x_ref)])
        x_est = x_est[:len_x]
        x_ref = x_ref[:len_x]

        # x_ref = x_ref / np.max(np.abs(x_ref))

        if fs_est != fs_ref:
            raise ValueError(
                'Sampling rate differs between estimated and reference audio'
            )

        if self.metric == 'rmse':
            return compute_rmse(x_est, x_ref)
        elif self.metric == 'pesq':
            return pesq(x_ref, x_est, fs_est)
        elif self.metric == 'stoi':
            return stoi(x_ref, x_est, fs_est, extended=False)
        elif self.metric == 'estoi':
            return stoi(x_ref, x_est, fs_est, extended=True)
        elif self.metric == 'all':
            score_rmse = compute_rmse(x_est, x_ref)
            score_pesq = pesq(x_ref, x_est, fs_est)
            score_stoi = stoi(x_ref, x_est, fs_est, extended=False)
            return score_rmse, score_pesq, score_stoi
        else:
            raise ValueError(
                'Evaluation only supports: rmse, pesq, (e)stoi, all')
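For reference, a minimal standalone sketch of the same three metric calls, assuming the pystoi and pesq packages are installed; the file names are placeholders:

import numpy as np
import soundfile as sf
from pystoi import stoi
from pesq import pesq

x_ref, fs = sf.read('clean.wav')       # placeholder paths
x_est, _ = sf.read('enhanced.wav')
n = min(len(x_ref), len(x_est))        # align lengths before scoring
x_ref, x_est = x_ref[:n], x_est[:n]

rmse = np.sqrt(np.mean((x_est - x_ref) ** 2))
stoi_score = stoi(x_ref, x_est, fs, extended=False)  # reference signal first
pesq_score = pesq(fs, x_ref, x_est, 'wb')            # fs must be 8000 or 16000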
Example #2
def make_line(clean, enh, fs):
    # Compute in NumPy
    line = [stoi(clean, enh, fs), stoi(clean, enh, fs, extended=True)]
    # Compute in PyTorch
    for use_vad in [True, False]:
        for extended in [True, False]:
            loss = NegSTOILoss(sample_rate=fs,
                               use_vad=use_vad,
                               extended=extended)
            line.append(
                -loss(torch.from_numpy(enh), torch.from_numpy(clean)).item())
    return line
Example #3
def stoi_on_batch(y_denoised, ytest, test_phase, sr=16000):
    # Drop the trailing channel axis and the batch axis: (1, F, T, 1) -> (F, T)
    y_denoised = np.squeeze(y_denoised, axis=3)
    y_denoised = np.squeeze(y_denoised, axis=0)

    # Convert the dB spectrograms back to amplitude
    y_denoised = librosa.db_to_amplitude(y_denoised)
    ytest = librosa.db_to_amplitude(ytest)

    # Re-attach the phase and invert the STFTs
    denoised = librosa.istft(y_denoised * test_phase)
    original = librosa.istft(ytest * test_phase)

    denoised = librosa.util.normalize(denoised)
    original = librosa.util.normalize(original)

    # pystoi's signature is stoi(reference, estimate, fs, extended=False);
    # 'wb' is a PESQ mode and, being truthy, would silently enable extended
    # STOI here, so the flag is spelled out
    stoivalue = stoi(original, denoised, sr, extended=False)

    return stoivalue
Example #4
def convolution(audio_file, impulse_file, output_file):
    audio = read_wave(audio_file)

    # Keep only the first 1000 samples of the impulse response
    ir = read_wave(impulse_file)[:1000]

    # mode='same' keeps the convolution the same length as `audio`
    convolution = normalize(signal.convolve(audio, ir, mode='same'))

    # write_wav(convolution, output_file)
    STOI = stoi(audio, convolution, 16000, extended=False)
    print("STOI is:", STOI)
    np.savetxt("stoi_kf-ad_3d.txt", [STOI])
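read_wave and normalize are not shown in this example; a plausible minimal version of these helpers (an assumption, not the original code) could be:

import numpy as np
import soundfile as sf

def read_wave(path):
    # Read a waveform as a float array; the sample rate is assumed 16 kHz here
    x, _ = sf.read(path)
    return x

def normalize(x):
    # Peak-normalize to [-1, 1] so the convolved signal does not clip
    return x / np.max(np.abs(x))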
Example #5
def main(params, date, epoch, gpu):
    net = PonderEnhancer(params['model'])
    ckpt = get_model_ckpt(params, date, epoch)
    net.load_state_dict(torch.load(ckpt))

    net.cuda()
    net.eval()

    # run test
    test_dset = get_dataset(params['test_data_config'])
    test_dataloader = get_dataloader(params, test_dset, train=False)
    loss_fn = get_loss_fn(params['loss'])
    fs = params['test_data_config']['fs']

    per_db_results = {}
    for (clean, noise, mix, file_db) in tqdm.tqdm(test_dataloader):
        clean, mix = clean.cuda(), mix.cuda()
        db = file_db[0][:-4]

        # Forward pass through the network
        pred, ponder = net(mix, verbose=False)
        _, loss_ponder = loss_fn(clean, pred, ponder)

        if db not in per_db_results:
            per_db_results[db] = {'enhance': [], 'ponder': []}

        # get the perceptual metrics
        np_clean = clean.detach().cpu().numpy().reshape(-1)
        np_pred = pred.detach().cpu().numpy().reshape(-1)
        
        per_db_results[db]['enhance'].append(pystoi.stoi(np_clean, np_pred, fs))
        per_db_results[db]['ponder'].append(loss_ponder.item())

    # save it all
    save_dict(per_db_results, 'stoi', epoch, date, params)
Example #6
File: helpers.py Project: manneh/NeMo
def eval_tts_scores(y_clean: ndarray,
                    y_est: ndarray,
                    T_ys: Sequence[int] = (0, ),
                    sampling_rate=22050) -> Dict[str, float]:
    """
    calculate metric using EvalModule. y can be a batch.
    Args:
        y_clean: real audio
        y_est: estimated audio
        T_ys: length of the non-zero parts of the histograms
        sampling_rate: The sampling rate used.

    Returns:
        A dictionary mapping scoring systems (string) to numerical scores.
        1st entry: 'STOI'
        2nd entry: 'PESQ'
    """

    if y_clean.ndim == 1:
        y_clean = y_clean[np.newaxis, ...]
        y_est = y_est[np.newaxis, ...]
    if T_ys == (0, ):
        T_ys = (y_clean.shape[1], ) * y_clean.shape[0]

    clean = y_clean[0, :T_ys[0]]
    estimated = y_est[0, :T_ys[0]]
    stoi_score = stoi(clean, estimated, sampling_rate, extended=False)
    pesq_score = pesq(16000, np.asarray(clean), estimated, 'wb')
    # fs is fixed at 16000, as the pesq lib doesn't currently support a flexible fs.

    return {'STOI': stoi_score, 'PESQ': pesq_score}
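A hedged usage sketch with real recordings (note that PESQ is evaluated at a fixed 16 kHz inside the function, independent of sampling_rate; the file names are placeholders):

import soundfile as sf

y_clean, sr = sf.read('reference.wav')
y_est, _ = sf.read('synthesized.wav')
n = min(len(y_clean), len(y_est))      # align lengths before scoring
scores = eval_tts_scores(y_clean[:n], y_est[:n], sampling_rate=sr)
print(scores['STOI'], scores['PESQ'])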
Example #7
def calc_metrics(loader, actor, device):
    pesq_all = []
    stoi_all = []
    for batch in loader:
        x = batch["noisy"].unsqueeze(1).to(device)
        t = batch["clean"].unsqueeze(1).to(device)
        m = batch["mask"].to(device)
        out_r, out_i = actor(x)
        out_r = torch.transpose(out_r, 1, 2)
        out_i = torch.transpose(out_i, 1, 2)
        y = predict(x.squeeze(1), (out_r, out_i))
        t = t.squeeze()
        m = m.squeeze()
        #print("Y:", y.shape)
        #source, targets, preds = inverse(t, y, m, x)
        targets, preds = inverse(t, y, m, x)

        for j in range(len(targets)):
            curr_pesq = pesq(targets[j].detach().cpu().numpy(),
                             preds[j].detach().cpu().numpy(), 16000)
            curr_stoi = stoi(targets[j].detach().cpu().numpy(),
                             preds[j].detach().cpu().numpy(), 16000)
            pesq_all.append(curr_pesq)
            stoi_all.append(curr_stoi)
    PESQ = torch.mean(torch.tensor(pesq_all))
    STOI = torch.mean(torch.tensor(stoi_all))
    return PESQ, STOI
Example #8
def test_preprocessed_data(net_type):
    from pystoi import stoi
    import pesq
    from mir_eval.separation import bss_eval_sources
    path = 'preprocessed_test_data_' + net_type + '/'
    if os.path.isdir(path):
        files = [f for f in os.listdir(path) if f.endswith('.npy')]
        sdr_a = []
        pesq_a = []
        stoi_a = []
        processed = 0
        for i, f in enumerate(files):
            signals = np.load(path + f)
            clean_speech = signals[:,0]
            recovered_speech = signals[:,1]
            if np.any(clean_speech) and np.any(recovered_speech):
                PESQ = pesq.pesq(dsp.audio_fs, clean_speech, recovered_speech, 'wb')
                STOI = stoi(clean_speech, recovered_speech, dsp.audio_fs, extended=False)
                SDR, sir, sar, perm = bss_eval_sources(clean_speech, recovered_speech)
                sdr_a.append(SDR[0])
                pesq_a.append(PESQ)
                stoi_a.append(STOI)
                processed += 1
                if i < len(files)-1:
                    print('[Metric computation: {}% complete]'.format(100.0*(i+1)/len(files)), end='\r')
                else:
                    print('[Metric computation: {}% complete]'.format(100.0*(i+1)/len(files)), end='\n')
        metrics = np.array([sdr_a, pesq_a, stoi_a]).T
        np.save(net_type + '_metrics.npy', metrics)
        print("Finished pre-processed testing of net '{}', {} files out of {} were processed into {}_metrics.npy".format(net_type, processed, len(files), net_type))
    else:
        print("Error: Preprocessed data for the model not found")
Example #9
    def validation_step(self, batch, batch_idx):
        self.mode = OperationMode.validation
        self.model.mode = OperationMode.validation
        audio, audio_len = batch
        z, logdet, predicted_audio, spec, spec_len = self(audio=audio,
                                                          audio_len=audio_len)
        loss = self.loss(z=z,
                         logdet=logdet,
                         gt_audio=audio,
                         predicted_audio=predicted_audio,
                         sigma=self.sigma)

        # compute average stoi score for batch
        stoi_score = 0
        sr = self._cfg.preprocessor.params.sample_rate
        for audio_i, audio_recon_i in zip(audio.cpu(), predicted_audio.cpu()):
            stoi_score += stoi(audio_i, audio_recon_i, sr)
        stoi_score /= audio.shape[0]

        return {
            "val_loss": loss,
            "predicted_audio": predicted_audio,
            "mel_target": spec,
            "mel_len": spec_len,
            "stoi": stoi_score,
        }
Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--results_dir', type=str, required=True)
    parser.add_argument('--audio_sampling_rate', type=int, default=16000)
    args = parser.parse_args()

    audio1, _ = librosa.load(os.path.join(args.results_dir,
                                          'audio1_separated.wav'),
                             sr=args.audio_sampling_rate)
    audio2, _ = librosa.load(os.path.join(args.results_dir,
                                          'audio2_separated.wav'),
                             sr=args.audio_sampling_rate)
    audio1_gt, _ = librosa.load(os.path.join(args.results_dir, 'audio1.wav'),
                                sr=args.audio_sampling_rate)
    audio2_gt, _ = librosa.load(os.path.join(args.results_dir, 'audio2.wav'),
                                sr=args.audio_sampling_rate)
    audio_mix, _ = librosa.load(os.path.join(args.results_dir,
                                             'audio_mixed.wav'),
                                sr=args.audio_sampling_rate)

    # SDR, SIR, SAR
    sdr, sir, sar = getSeparationMetrics(audio1, audio2, audio1_gt, audio2_gt)
    sdr_mixed, _, _ = getSeparationMetrics(audio_mix, audio_mix, audio1_gt,
                                           audio2_gt)

    # PESQ
    pesq_score1 = pesq(audio1, audio1_gt, args.audio_sampling_rate)
    pesq_score2 = pesq(audio2, audio2_gt, args.audio_sampling_rate)
    pesq_score = (pesq_score1 + pesq_score2) / 2

    # STOI
    stoi_score1 = stoi(audio1_gt,
                       audio1,
                       args.audio_sampling_rate,
                       extended=False)
    stoi_score2 = stoi(audio2_gt,
                       audio2,
                       args.audio_sampling_rate,
                       extended=False)
    stoi_score = (stoi_score1 + stoi_score2) / 2

    output_file = open(os.path.join(args.results_dir, 'eval.txt'), 'w')
    output_file.write(
        "%.3f %.3f %.3f %.3f %.3f %.3f %.3f" %
        (sdr, sdr_mixed, sdr - sdr_mixed, sir, sar, pesq_score, stoi_score))
    output_file.close()
Example #11
def inference(clean_path, noisy_path, model_path, out_path):
    device = torch.device("cuda:1")
    model = Actor()
    model = nn.DataParallel(model, device_ids=[1, 2])
    model.load_state_dict(torch.load(model_path + 'actor_best.pth'))
    model = model.to(device)

    dataset = Data(clean_path, noisy_path, mode='Test')
    loader = data.DataLoader(dataset,
                             batch_size=32,
                             shuffle=False,
                             collate_fn=collate_custom)

    fnames = os.listdir(noisy_path)

    print("Num files:", len(fnames))

    pesq_all = []
    stoi_all = []
    fcount = 0

    for batch in tqdm(loader):
        x = batch["noisy"].unsqueeze(1).to(device)
        t = batch["clean"].unsqueeze(1).to(device)
        m = batch["mask"].to(device)
        out_r, out_i = model(x)
        out_r = torch.transpose(out_r, 1, 2)
        out_i = torch.transpose(out_i, 1, 2)
        y = predict(x.squeeze(1), (out_r, out_i))
        t = t.squeeze()
        m = m.squeeze()
        x = x.squeeze()
        source, targets, preds = inverse(t, y, m, x)

        for j in range(len(targets)):
            t_j = targets[j].detach().cpu().numpy()
            p_j = preds[j].detach().cpu().numpy()
            p_j = 10 * (p_j / np.linalg.norm(p_j))
            curr_pesq = pesq(t_j, p_j, 16000)
            curr_stoi = stoi(t_j, p_j, 16000)
            pesq_all.append(curr_pesq)
            stoi_all.append(curr_stoi)
            try:
                sf.write(os.path.join(out_path, fnames[fcount]), p_j, 16000)
            except IndexError:
                print("Fcount:", fcount, len(fnames))
            fcount += 1

    PESQ = torch.mean(torch.tensor(pesq_all))
    STOI = torch.mean(torch.tensor(stoi_all))

    print("PESQ: ", PESQ, "STOI: ", STOI)

    with open(os.path.join(model_path, 'test_scores.txt'), 'w') as fo:
        fo.write("Avg PESQ: " + str(float(PESQ)) + " Avg STOI: " +
                 str(float(STOI)))
Example #12
 def __call__(self, a, b):
     """
     :param a: time-domain signal (reference)
     :param b: time-domain signal (estimate)
     :return: STOI score
     """
     assert len(a.shape) == 1
     assert len(a) == len(b)
     score = stoi(a, b, self.sr)
     return score
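A sketch of the assumed enclosing class and its use (the class name STOIMetric is hypothetical; only self.sr is required by the method above):

import numpy as np
from pystoi import stoi

class STOIMetric:
    def __init__(self, sr=16000):
        self.sr = sr  # sampling rate of the time-domain inputs

    def __call__(self, a, b):
        assert len(a.shape) == 1
        assert len(a) == len(b)
        return stoi(a, b, self.sr)

rng = np.random.default_rng(0)
a = rng.standard_normal(32000)              # 2 s of reference audio at 16 kHz
b = a + 0.05 * rng.standard_normal(32000)   # slightly degraded copy
print(STOIMetric()(a, b))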
Example #13
 def forward(self, clean, enhanced):
     assert len(clean) == len(enhanced)
     scores = []
     for c, e in zip(clean, enhanced):
         q = stoi(c.detach().cpu(),
                  e.detach().cpu(),
                  self.sr,
                  extended=self.extended)
         scores.append(q)
     return torch.tensor(scores, device=self._device)
Example #14
def test(device, net_type, model_path, dataset):
    from pystoi import stoi
    from pesq import pesq
    from mir_eval.separation import bss_eval_sources
    net = create_net_of_type(net_type, device, 1)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net = net.to(device)
    net.eval()
    dataset = create_dataset_for(dataset, 'test')
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=1)

    metrics = np.zeros((len(dataset), 3))
    i=0
    for clean, noisy, st, sm in loader:
        clean_speech = clean.numpy().flatten()
        noisy_speech = noisy.numpy().flatten()
        sm = sm.float().to(device)
        if net_type == 'phasen_1strm' or net_type == 'phasen_baseline':
            cIRM_est = net(sm)
            cIRM_est = cIRM_est.squeeze(0)
            sm = sm.squeeze(0)
            #C, T, F = decompressed_cIRM.shape
            C, T, F = cIRM_est.shape
            sout_c = torch.zeros(T, F, dtype=torch.cfloat)
            sout_c.real = cIRM_est[0,:,:]*sm[0,:,:] - cIRM_est[1,:,:]*sm[1,:,:]
            sout_c.imag = cIRM_est[0,:,:]*sm[1,:,:] + cIRM_est[1,:,:]*sm[0,:,:]
            sout_c = sout_c.T.detach().cpu().numpy()
        else:
            s_out, M, Phi = net(sm)
            # Convert s_out from (1, 2, T, F) to a complex array of shape (F, T)
            s_out = s_out.squeeze(0)
            C, T, F = s_out.shape
            sout_c = torch.zeros(T, F, dtype=torch.cfloat)
            sout_c.real = s_out[0,:,:]
            sout_c.imag = s_out[1,:,:]
            sout_c = sout_c.T.detach().cpu().numpy()

        # Recover time domain signal
        t, recovered_speech = dsp.recover_from_stft_spectrogram(sout_c, dsp.audio_fs)
        PESQ = pesq(dsp.audio_fs, clean_speech, recovered_speech, 'wb')
        STOI = stoi(clean_speech, recovered_speech, dsp.audio_fs, extended=False)
        SDR, sir, sar, perm = bss_eval_sources(clean_speech, recovered_speech)
        metrics[i,0] = SDR[0]
        metrics[i,1] = PESQ
        metrics[i,2] = STOI
        i += 1
        if i < len(dataset)-1:
            print('[Sample {}% Complete]'.format(100*i/len(dataset)), end='\r')
        else:
            print('[Sample {}% Complete]'.format(100*i/len(dataset)), end='\n')
    np.save(net_type + '_metrics.npy', metrics)
    print("Finished testing of net '{}', metrics saved in {}_metrics.npy".format(net_type, net_type))
Example #15
def CompositeEval(ref_wav, deg_wav, log_all=False):
    # returns [sig, bak, ovl]
    alpha = 0.95
    len_ = min(ref_wav.shape[0], deg_wav.shape[0])
    ref_wav = ref_wav[:len_]
    ref_len = ref_wav.shape[0]
    deg_wav = deg_wav[:len_]

    # Compute WSS measure
    wss_dist_vec = wss(ref_wav, deg_wav, 16000)
    wss_dist_vec = sorted(wss_dist_vec, reverse=False)
    wss_dist = np.mean(wss_dist_vec[:int(round(len(wss_dist_vec) * alpha))])

    # Compute LLR measure
    LLR_dist = llr(ref_wav, deg_wav, 16000)
    LLR_dist = sorted(LLR_dist, reverse=False)
    LLRs = LLR_dist
    LLR_len = round(len(LLR_dist) * alpha)
    llr_mean = np.mean(LLRs[:LLR_len])

    # Compute the SSNR
    snr_mean, segsnr_mean = SSNR(ref_wav, deg_wav, 16000)
    segSNR = np.mean(segsnr_mean)
    # Compute the PESQ (wideband) and STOI
    pesq_raw = pesq(ref=ref_wav, deg=deg_wav, fs=16000, mode='wb')
    stoi_val = stoi(ref_wav, deg_wav, 16000, extended=False)

    def trim_mos(val):
        return min(max(val, 1), 5)

    Csig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_raw - 0.009 * wss_dist
    Csig = trim_mos(Csig)
    Cbak = 1.634 + 0.478 * pesq_raw - 0.007 * wss_dist + 0.063 * segSNR
    Cbak = trim_mos(Cbak)
    Covl = 1.594 + 0.805 * pesq_raw - 0.512 * llr_mean - 0.007 * wss_dist
    Covl = trim_mos(Covl)
    if log_all:
        return Csig, Cbak, Covl, pesq_raw, segSNR, stoi_val * 100
    else:
        return Csig, Cbak, Covl
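A hedged usage sketch (wss, llr and SSNR are the standard composite-measure helpers assumed to be defined in the same module; the file names are placeholders):

import soundfile as sf

ref_wav, _ = sf.read('clean.wav')      # 16 kHz mono assumed
deg_wav, _ = sf.read('enhanced.wav')

csig, cbak, covl = CompositeEval(ref_wav, deg_wav)
csig, cbak, covl, pesq_raw, seg_snr, stoi_pct = CompositeEval(
    ref_wav, deg_wav, log_all=True)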
Example #16
def sp_enhance_evals(est_source, clean_source, noisy_source, fs):
    est_source = est_source.cpu().clone().numpy()
    clean_source = clean_source.cpu().clone().numpy()
    noisy_source = noisy_source.cpu().clone().numpy()

    pesq_val = pesq(fs, clean_source, est_source, 'wb')
    stoi_val = stoi(clean_source, est_source, fs, extended=False)
    si_sdr_val = si_sdr(est_source, clean_source)
    noisy_si_sdr_val = si_sdr(noisy_source, clean_source)
    si_sdr_improvement = si_sdr_val - noisy_si_sdr_val
    return pesq_val, stoi_val, si_sdr_val, si_sdr_improvement
Example #17
def Q(mode, clean, denoised, sr):
    if mode == "pesq":
        # 'wb' selects the wideband PESQ mode ('nb' would be narrowband);
        # the score is shifted and scaled into roughly [0, 1]
        return (pesq(sr, clean, denoised, 'wb') + 0.5) / 5.0
    elif mode == "stoi":
        # fs is hardcoded to 16 kHz here
        return stoi(clean, denoised, 16000, extended=False)
Example #18
def get_stoi(ref_sig, out_sig, sr):
    """Calculate STOI.
    Args:
        ref_sig: numpy.ndarray, [B, T]
        out_sig: numpy.ndarray, [B, T]
    Returns:
        STOI
    """
    stoi_val = 0
    for i in range(len(ref_sig)):
        stoi_val += stoi(ref_sig[i], out_sig[i], sr, extended=False)
    return stoi_val
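Note that get_stoi returns the sum over the batch; a small hedged wrapper for the batch mean:

def get_mean_stoi(ref_sig, out_sig, sr):
    # Average the summed per-utterance STOI scores over the batch
    return get_stoi(ref_sig, out_sig, sr) / len(ref_sig)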
Example #19
def _eval(batch,
          metrics,
          including='output',
          sample_rate=8000,
          use_pypesq=False):
    if use_pypesq:
        metrics = [m for m in metrics if m != 'pesq']

    has_estoi = False
    if 'estoi' in metrics:
        metrics = [m for m in metrics if m != 'estoi']
        has_estoi = True

    has_wer = False
    if 'wer' in metrics:
        metrics = [m for m in metrics if m != 'wer']
        has_wer = True

    mix = batch['mix']
    clean = batch['clean']
    estimate = batch['enh']
    snr = batch['snr']

    res = get_metrics(mix.numpy(),
                      clean.numpy(),
                      estimate.numpy(),
                      sample_rate=sample_rate,
                      metrics_list=metrics,
                      including=including)

    if use_pypesq:
        res['pesq'] = pesq(clean.flatten(), estimate.flatten(), sample_rate)

    if has_estoi:
        res['estoi'] = stoi(clean.flatten(),
                            estimate.flatten(),
                            sample_rate,
                            extended=True)

    if has_wer:
        res['wer'] = jiwer.wer(batch['clean_text'],
                               batch['transcription'],
                               truth_transform=_wer_trans,
                               hypothesis_transform=_wer_trans)

    if including == 'input':
        for m in metrics:
            res[m] = res['input_' + m]
            del res['input_' + m]

    res['snr'] = snr[0].item()
    return res
Example #20
def task1_metric(clean_speech, denoised_speech, sr=16000):
    '''
    Compute the evaluation metric for task 1 as (stoi + (1 - word error rate)) / 2.
    This function computes the measure for a single datapoint.
    '''
    WER = wer(clean_speech, denoised_speech)
    if WER is not None:  # WER is None when there is no speech in the segment
        STOI = stoi(clean_speech, denoised_speech, sr, extended=False)
        WER = np.clip(WER, 0., 1.)
        STOI = np.clip(STOI, 0., 1.)
        metric = (STOI + (1. - WER)) / 2.
    else:
        metric = None
        STOI = None
    return metric, WER, STOI
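For example, STOI = 0.8 and WER = 0.2 give metric = (0.8 + (1 - 0.2)) / 2 = 0.8.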
Example #21
def stoi_score(y_true: np.array,
               y_pred: np.array,
               samplerate=16000,
               extended=False) -> float:
    """Computes the Short Term Objective Intelligibility metric between `y_true` and `y_pred`.

    Args:
        y_true (np.array): The original audio signal with shape (samplerate*length).
        y_pred (np.array): The predicted audio signal with shape (samplerate*length).
        samplerate (int, optional): Either 8000 or 16000. Defaults to 16000.
        extended (boolean, optional): Whether to use the extended STOI metric instead.

    Returns:
        float: The stoi score between `y_true` and `y_pred`.
    """
    return stoi(y_true, y_pred, fs=samplerate, extended=extended)
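A short usage sketch with synthetic signals (2 s at 16 kHz):

import numpy as np

rng = np.random.default_rng(0)
y_true = rng.standard_normal(16000 * 2)
y_pred = y_true + 0.1 * rng.standard_normal(y_true.shape)

print(stoi_score(y_true, y_pred))                 # STOI
print(stoi_score(y_true, y_pred, extended=True))  # ESTOI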
Example #22
def main():
    parser = argparse.ArgumentParser(description='Calculate performance index')
    parser.add_argument('--test_mix_folder',
                        default='../test-mix-2-babble',
                        type=str,
                        help='test-set-mix')
    parser.add_argument('--test_clean_folder',
                        default='../test-clean-2-babble',
                        type=str,
                        help='test-set-clean')
    parser.add_argument('--enhanced_folder',
                        default='../test-result',
                        type=str,
                        help='test-set-enhanced')

    opt = parser.parse_args()
    MIX_FOLDER = opt.test_mix_folder
    CLEAN_FOLDER = opt.test_clean_folder
    ENHANCED_FOLDER = opt.enhanced_folder

    pesqs = []
    stois = []

    for cleanfile in os.listdir(CLEAN_FOLDER):
        mixfile = cleanfile.replace('clean', 'mix')
        enhancedfile = 'enhanced_' + mixfile

        cleanfile = os.path.join(CLEAN_FOLDER, cleanfile)
        mixfile = os.path.join(MIX_FOLDER, mixfile)
        enhancedfile = os.path.join(ENHANCED_FOLDER, enhancedfile)

        ref, sr1 = librosa.load(cleanfile, sr=16000)
        #deg_mix, sr2 = librosa.load(mixfile, sr=16000)
        deg_enh, sr3 = librosa.load(enhancedfile, sr=16000)

        #pesq1 = pesq.pesq(ref, deg_mix)
        pesq2 = pesq.pesq(ref, deg_enh[:len(ref)], 16000)
        #print("pesq:", pesq1, " --> ", pesq2)

        pesqs.append(pesq2)

        #stoi1 = stoi(ref, deg_mix, fs_sig=16000)
        stoi2 = stoi(ref, deg_enh[:len(ref)], fs_sig=16000)
        #print("stoi:", stoi1, " --> ", stoi2)
        stois.append(stoi2)

    print('Epesq:', np.mean(pesqs), "Estoi:", np.mean(stois))
Example #23
def cal_STOI(ref_sig, out_sig, sr):
    """Calculate STOI.
    Args:
        ref_sig: numpy.ndarray, [B, C, T]
        out_sig: numpy.ndarray, [B, C, T]
    Returns:
        STOI
    """
    B, C, T = ref_sig.shape
    ref_sig = ref_sig.reshape(B * C, T)
    out_sig = out_sig.reshape(B * C, T)
    try:
        stoi_val = 0
        for i in range(len(ref_sig)):
            stoi_val += stoi(ref_sig[i], out_sig[i], sr, extended=False)
        return stoi_val / (B * C)
    except Exception:
        # Fall back to 0 if STOI fails (e.g. on silent signals)
        return 0
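A hedged usage sketch for batched multi-channel input of shape [B, C, T]:

import numpy as np

rng = np.random.default_rng(0)
ref = rng.standard_normal((2, 1, 32000))           # B=2, C=1, 2 s at 16 kHz
out = ref + 0.1 * rng.standard_normal(ref.shape)
print(cal_STOI(ref, out, 16000))                   # mean STOI over B * C signals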
Example #24
def stoi_from_fft(noisy, phase_noisy, clean, phase_clean):
    """
    Calculate the STOI metric on an STFT batch.
    """
    # Convert the dB spectrograms back to amplitude, re-attach the phase and
    # invert the STFTs
    phase_noisy = np.array(phase_noisy)
    noisy = librosa.db_to_amplitude(noisy)
    noisy = librosa.istft(noisy * phase_noisy)

    clean = np.array(clean)
    phase_clean = np.array(phase_clean)
    clean = librosa.db_to_amplitude(clean)
    clean = librosa.istft(clean * phase_clean)

    sr = 16000

    # pystoi's signature is stoi(reference, estimate, fs, extended=False):
    # the clean signal goes first, and 'wb' (a PESQ mode) would wrongly
    # enable extended STOI if passed as the fourth argument
    stoivalue = stoi(clean, noisy, sr, extended=False)
    return stoivalue
Example #25
    def train(self):
        m_print(
            f"The amount of parameters in the project is {print_networks([self.model]) / 1e6} million."
        )

        for epoch in range(self.config['epochs']):
            total_loss = 0
            self.model.train()
            '''training'''
            for mixs_wav, cleans_wav, lengths, _ in tqdm(
                    self.train_dataloader):

                self.optimizer.zero_grad()

                mixs = torch.stft(mixs_wav,
                                  n_fft=self.n_fft,
                                  hop_length=self.hop_len,
                                  win_length=self.win_len,
                                  window=torch.hamming_window(
                                      self.win_len)).permute(0, 2, 1,
                                                             3).cuda()

                # mixs = self.stft.transform(mixs_wav.cuda())
                mixs_real = mixs[:, :, :, 0]
                mixs_imag = mixs[:, :, :, 1]
                mixs_mag = torch.sqrt(mixs_real**2 + mixs_imag**2)

                cleans = torch.stft(cleans_wav,
                                    n_fft=self.n_fft,
                                    hop_length=self.hop_len,
                                    win_length=self.win_len,
                                    window=torch.hamming_window(
                                        self.win_len)).permute(0, 2, 1,
                                                               3).cuda()

                cleans_real = cleans[:, :, :, 0]
                cleans_imag = cleans[:, :, :, 1]
                cleans_mag = torch.sqrt(cleans_real**2 + cleans_imag**2)

                # z_score
                # mixs_mag, _, _ = z_score(mixs_mag)
                # cleans_mag, _, _ = z_score(cleans_mag)

                enhances_mag = self.model(mixs_mag)

                frames = []
                for length in lengths:
                    frame = (length - self.win_len) // self.hop_len + 3
                    frames.append(frame)
                loss = self.loss_function.calculate_loss(
                    enhances_mag, cleans_mag, frames)
                total_loss += loss.item()

                loss.backward()
                self.optimizer.step()
                # break
                # print(loss.item())
                gc.collect()

            # break
            tqdm.write(f"\nepoch: {epoch}, total loss: {total_loss}")
            end_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            m_print(f"epoch: {epoch} end logging time:\t{end_time}")

            self.model.eval()
            '''validating'''
            with torch.no_grad():
                stoi_sum = 0
                pesq_sum = 0
                si_sdr_sum = 0
                count_number = 0
                for mixs_wav, cleans_wav, lengths, _ in tqdm(
                        self.eval_dataloader):
                    mixs = torch.stft(mixs_wav,
                                      n_fft=self.n_fft,
                                      hop_length=self.hop_len,
                                      win_length=self.win_len,
                                      window=torch.hamming_window(
                                          self.win_len)).permute(0, 2, 1,
                                                                 3).cuda()
                    # mixs = self.stft.transform(mixs_wav.cuda())
                    mixs_real = mixs[:, :, :, 0]
                    mixs_imag = mixs[:, :, :, 1]
                    mixs_mag = torch.sqrt(mixs_real**2 + mixs_imag**2)

                    # z_score
                    # mixs_mag, mixture_mean, mixture_std = z_score(mixs_mag)

                    enhances_mag = self.model(mixs_mag)

                    # z_score
                    # enhances_mag = reverse_z_score(enhances_mag, mixture_mean, mixture_std)
                    '''eval'''
                    enhances_real = enhances_mag * mixs_real / mixs_mag
                    enhances_imag = enhances_mag * mixs_imag / mixs_mag
                    enhances = torch.stack([enhances_real, enhances_imag], 3)
                    enhances = enhances.permute(0, 2, 1, 3)

                    enhances_wav = torch.istft(enhances,
                                               n_fft=self.n_fft,
                                               hop_length=self.hop_len,
                                               win_length=self.win_len,
                                               window=torch.hamming_window(
                                                   self.win_len).cuda(),
                                               length=max(lengths))
                    # enhances_wav = self.stft.inverse(enhances)

                    frames = []
                    # len_list = []
                    for length in lengths:
                        frame = (length - self.win_len) // self.hop_len + 3
                        frames.append(frame)
                        # len_list.append((frame - 1) * 160 + 320)

                    cleans_wav = cleans_wav.cpu().numpy()
                    enhances_wav = enhances_wav.cpu().numpy()
                    for clean, enhance, length in zip(cleans_wav, enhances_wav,
                                                      lengths):
                        clean = clean[:length]
                        enhance = enhance[:length]
                        stoi_transform = pystoi.stoi(clean, enhance, 16000)
                        pesq_transform = pypesq.pesq(clean, enhance, 16000)
                        si_sdr_transform = SI_SDR(clean, enhance)
                        if np.isnan(stoi_transform) or np.isnan(
                                pesq_transform) or np.isnan(si_sdr_transform):
                            continue
                        stoi_sum += stoi_transform
                        pesq_sum += pesq_transform
                        si_sdr_sum += si_sdr_transform
                        count_number += 1

                score = self._calculate_score(stoi=stoi_sum, pesq=pesq_sum)
                if self._is_best_score(score):
                    is_best = True
                else:
                    is_best = False
                self._save_checkpoints(epoch, is_best)

                print(count_number)
                m_print(f"stoi score: "
                        f"{stoi_sum / count_number},"
                        f"pesq score: "
                        f"{pesq_sum / count_number},"
                        f"si_sdr score: "
                        f"{si_sdr_sum / count_number},"
                        f"is best: {is_best}")
Example #26
def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
):
    assert check_argument_types()

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    assert len(ref_scp) == len(inf_scp), ref_scp
    num_spk = len(ref_scp)

    keys = [
        line.rstrip().split(maxsplit=1)[0]
        for line in open(key_file, encoding="utf-8")
    ]

    ref_readers = [
        SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp
    ]
    inf_readers = [
        SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp
    ]

    # get sample rate
    sample_rate, _ = ref_readers[0][keys[0]]

    # check keys
    for inf_reader, ref_reader in zip(inf_readers, ref_readers):
        assert inf_reader.keys() == ref_reader.keys()

    with DatadirWriter(output_dir) as writer:
        for key in keys:
            ref_audios = [ref_reader[key][1] for ref_reader in ref_readers]
            inf_audios = [inf_reader[key][1] for inf_reader in inf_readers]
            ref = np.array(ref_audios)
            inf = np.array(inf_audios)
            if ref.ndim > inf.ndim:
                # multi-channel reference and single-channel output
                ref = ref[..., ref_channel]
                assert ref.shape == inf.shape, (ref.shape, inf.shape)
            elif ref.ndim < inf.ndim:
                # single-channel reference and multi-channel output
                raise ValueError("Reference must be multi-channel when the \
                    network output is multi-channel.")
            elif ref.ndim == inf.ndim == 3:
                # multi-channel reference and output
                ref = ref[..., ref_channel]
                inf = inf[..., ref_channel]

            sdr, sir, sar, perm = bss_eval_sources(ref,
                                                   inf,
                                                   compute_permutation=True)

            for i in range(num_spk):
                stoi_score = stoi(ref[i],
                                  inf[int(perm[i])],
                                  fs_sig=sample_rate)
                si_snr_score = -float(
                    si_snr_loss(
                        torch.from_numpy(ref[i][None, ...]),
                        torch.from_numpy(inf[int(perm[i])][None, ...]),
                    ))
                writer[f"STOI_spk{i + 1}"][key] = str(stoi_score)
                writer[f"SI_SNR_spk{i + 1}"][key] = str(si_snr_score)
                writer[f"SDR_spk{i + 1}"][key] = str(sdr[i])
                writer[f"SAR_spk{i + 1}"][key] = str(sar[i])
                writer[f"SIR_spk{i + 1}"][key] = str(sir[i])
                # save permutation assigned script file
                writer[f"wav_spk{i + 1}"][key] = inf_readers[perm[i]].data[key]
Example #27
def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
    metrics: List[str],
    frame_size: int = 512,
    frame_hop: int = 256,
):
    assert check_argument_types()
    for metric in metrics:
        assert metric in (
            "STOI",
            "ESTOI",
            "SNR",
            "SI_SNR",
            "SDR",
            "SAR",
            "SIR",
            "framewise-SNR",
        ), metric

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    assert len(ref_scp) == len(inf_scp), ref_scp
    num_spk = len(ref_scp)

    keys = [
        line.rstrip().split(maxsplit=1)[0]
        for line in open(key_file, encoding="utf-8")
    ]

    ref_readers = [
        SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp
    ]
    inf_readers = [
        SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp
    ]

    # get sample rate
    fs, _ = ref_readers[0][keys[0]]

    # check keys
    for inf_reader, ref_reader in zip(inf_readers, ref_readers):
        assert inf_reader.keys() == ref_reader.keys()

    stft = STFTEncoder(n_fft=frame_size, hop_length=frame_hop)

    do_bss_eval = "SDR" in metrics or "SAR" in metrics or "SIR" in metrics
    with DatadirWriter(output_dir) as writer:
        for key in keys:
            ref_audios = [ref_reader[key][1] for ref_reader in ref_readers]
            inf_audios = [inf_reader[key][1] for inf_reader in inf_readers]
            ref = np.array(ref_audios)
            inf = np.array(inf_audios)
            if ref.ndim > inf.ndim:
                # multi-channel reference and single-channel output
                ref = ref[..., ref_channel]
                assert ref.shape == inf.shape, (ref.shape, inf.shape)
            elif ref.ndim < inf.ndim:
                # single-channel reference and multi-channel output
                raise ValueError("Reference must be multi-channel when the "
                                 "network output is multi-channel.")
            elif ref.ndim == inf.ndim == 3:
                # multi-channel reference and output
                ref = ref[..., ref_channel]
                inf = inf[..., ref_channel]

            if do_bss_eval or num_spk > 1:
                sdr, sir, sar, perm = bss_eval_sources(
                    ref, inf, compute_permutation=True)
            else:
                perm = [0]

            ilens = torch.LongTensor([ref.shape[1]])
            # (num_spk, T, F)
            ref_spec, flens = stft(torch.from_numpy(ref), ilens)
            inf_spec, _ = stft(torch.from_numpy(inf), ilens)

            for i in range(num_spk):
                p = int(perm[i])
                for metric in metrics:
                    name = f"{metric}_spk{i + 1}"
                    if metric == "STOI":
                        writer[name][key] = str(
                            stoi(ref[i], inf[p], fs_sig=fs, extended=False))
                    elif metric == "ESTOI":
                        writer[name][key] = str(
                            stoi(ref[i], inf[p], fs_sig=fs, extended=True))
                    elif metric == "SNR":
                        si_snr_score = -float(
                            ESPnetEnhancementModel.snr_loss(
                                torch.from_numpy(ref[i][None, ...]),
                                torch.from_numpy(inf[p][None, ...]),
                            ))
                        writer[name][key] = str(si_snr_score)
                    elif metric == "SI_SNR":
                        si_snr_score = -float(
                            ESPnetEnhancementModel.si_snr_loss(
                                torch.from_numpy(ref[i][None, ...]),
                                torch.from_numpy(inf[p][None, ...]),
                            ))
                        writer[name][key] = str(si_snr_score)
                    elif metric == "SDR":
                        writer[name][key] = str(sdr[i])
                    elif metric == "SAR":
                        writer[name][key] = str(sar[i])
                    elif metric == "SIR":
                        writer[name][key] = str(sir[i])
                    elif metric == "framewise-SNR":
                        framewise_snr = -ESPnetEnhancementModel.snr_loss(
                            ref_spec[i].abs(), inf_spec[i].abs())
                        writer[name][key] = " ".join(
                            map(str, framewise_snr.tolist()))
                    else:
                        raise ValueError("Unsupported metric: %s" % metric)
                    # save permutation assigned script file
                    writer[f"wav_spk{i + 1}"][key] = inf_readers[
                        perm[i]].data[key]
Example #28
def estoi_eval(pred_wav, target_wav):
    return stoi(x=target_wav.numpy(),
                y=pred_wav.numpy(),
                fs_sig=16000,
                extended=True)
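A short usage sketch; the inputs are 1-D CPU torch tensors:

import torch

torch.manual_seed(0)
target = torch.randn(32000)                 # 2 s of reference audio at 16 kHz
pred = target + 0.1 * torch.randn(32000)
print(estoi_eval(pred, target))             # extended STOI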
Example #29
def compute_metrics_utt(args):
    # Separate args
    file_path, snr_db = args[0], args[1]

    # print(file_path)
    # Read files
    s_t, fs_s = sf.read(processed_data_dir + os.path.splitext(file_path)[0] +
                        '_s.wav')  # clean speech
    n_t, fs_n = sf.read(processed_data_dir + os.path.splitext(file_path)[0] +
                        '_n.wav')  # noise
    x_t, fs_x = sf.read(processed_data_dir + os.path.splitext(file_path)[0] +
                        '_x.wav')  # mixture
    s_hat_t, fs_s_hat = sf.read(model_data_dir +
                                os.path.splitext(file_path)[0] +
                                '_s_est.wav')  # est. speech

    # compute metrics
    ## SI-SDR, SI-SAR, SI-SNR
    si_sdr, si_sir, si_sar = energy_ratios(s_hat=s_hat_t, s=s_t, n=n_t)

    ## ESTOI (extended STOI)
    stoi_s_hat = stoi(s_t, s_hat_t, fs, extended=True)
    # all_stoi.append(stoi_s_hat)

    ## PESQ
    pesq_s_hat = pesq(fs, s_t, s_hat_t, 'wb')  # wb = wideband
    # all_pesq.append(pesq_s_hat)

    ## POLQA
    # polqa_s_hat = polqa(s, s_t, fs)
    # all_polqa.append(polqa_s_hat)

    # TF representation
    s_tf = stft(s_t,
                fs=fs,
                wlen_sec=wlen_sec,
                win=win,
                hop_percent=hop_percent,
                dtype=dtype)  # shape = (freq_bins, frames)

    # plots of target / estimation
    # TF representation
    x_tf = stft(x_t,
                fs=fs,
                wlen_sec=wlen_sec,
                win=win,
                hop_percent=hop_percent,
                dtype=dtype)  # shape = (freq_bins, frames)

    s_hat_tf = stft(s_hat_t,
                    fs=fs,
                    wlen_sec=wlen_sec,
                    win=win,
                    hop_percent=hop_percent,
                    dtype=dtype)  # shape = (freq_bins, frames)

    ## mixture signal (wav + spectro)
    ## target signal (wav + spectro + mask)
    ## estimated signal (wav + spectro + mask)
    signal_list = [
        [x_t, x_tf, None],  # mixture: (waveform, tf_signal, no mask)
        [s_t, s_tf, None],  # clean speech
        [s_hat_t, s_hat_tf, None]
    ]
    fig = display_multiple_signals(signal_list,
                                   fs=fs,
                                   vmin=vmin,
                                   vmax=vmax,
                                   wlen_sec=wlen_sec,
                                   hop_percent=hop_percent,
                                   xticks_sec=xticks_sec,
                                   fontsize=fontsize)

    # put all metrics in the title of the figure
    title = "Input SNR = {:.1f} dB \n" \
        "SI-SDR = {:.1f} dB, " \
        "SI-SIR = {:.1f} dB, " \
        "SI-SAR = {:.1f} dB \n" \
        "STOI = {:.2f}, " \
        "PESQ = {:.2f} \n" \
        "".format(snr_db, si_sdr, si_sir, si_sar, stoi_s_hat, pesq_s_hat)

    fig.suptitle(title, fontsize=40)

    # Save figure
    fig.savefig(model_data_dir + os.path.splitext(file_path)[0] + '_fig.png')

    # Clear figure
    plt.close()

    metrics = [si_sdr, si_sir, si_sar, stoi_s_hat, pesq_s_hat]
    return metrics
Example #30
def main():
    # Load input SNR
    all_snr_db = read_dataset(processed_data_dir, dataset_type, 'snr_db')
    all_snr_db = np.array(all_snr_db)

    # Create file list
    file_paths = speech_list(input_speech_dir=input_speech_dir,
                             dataset_type=dataset_type)

    # 1 list per metric
    all_stoi = []
    all_pesq = []
    all_polqa = []
    all_f1score = []

    for i, file_path in tqdm(enumerate(file_paths)):

        # Read files
        s_t, fs_s = sf.read(processed_data_dir +
                            os.path.splitext(file_path)[0] +
                            '_s.wav')  # clean speech
        n_t, fs_n = sf.read(processed_data_dir +
                            os.path.splitext(file_path)[0] + '_n.wav')  # noise
        x_t, fs_x = sf.read(processed_data_dir +
                            os.path.splitext(file_path)[0] +
                            '_x.wav')  # mixture

        # compute metrics

        ## ESTOI (extended STOI)
        stoi_s_hat = stoi(s_t, x_t, fs, extended=True)
        all_stoi.append(stoi_s_hat)

        ## PESQ
        pesq_s_hat = pesq(fs, s_t, x_t, 'wb')  # wb = wideband
        all_pesq.append(pesq_s_hat)

        ## POLQA
        # polqa_s_hat = polqa(s, s_t, fs)
        # all_polqa.append(polqa_s_hat)

        # TF representation
        n_tf = stft(n_t,
                    fs=fs,
                    wlen_sec=wlen_sec,
                    win=win,
                    hop_percent=hop_percent,
                    dtype=dtype)  # shape = (freq_bins, frames)

        s_tf = stft(s_t,
                    fs=fs,
                    wlen_sec=wlen_sec,
                    win=win,
                    hop_percent=hop_percent,
                    dtype=dtype)  # shape = (freq_bins, frames)

        # plots of target / estimation
        # TF representation
        x_tf = stft(x_t,
                    fs=fs,
                    wlen_sec=wlen_sec,
                    win=win,
                    hop_percent=hop_percent,
                    dtype=dtype)  # shape = (freq_bins, frames)

        # ## mixture signal (wav + spectro)
        # ## target signal (wav + spectro + mask)
        # ## estimated signal (wav + spectro + mask)
        # signal_list = [
        #     [x_t, x_tf, None], # mixture: (waveform, tf_signal, no mask)
        #     [s_t, s_tf, None], # clean speech
        #     [n_t, n_tf, None]
        # ]
        # fig = display_multiple_signals(signal_list,
        #                     fs=fs, vmin=vmin, vmax=vmax,
        #                     wlen_sec=wlen_sec, hop_percent=hop_percent,
        #                     xticks_sec=xticks_sec, fontsize=fontsize)

        # # put all metrics in the title of the figure
        # title = "Input SNR = {:.1f} dB \n" \
        #     "STOI = {:.2f}, " \
        #     "PESQ = {:.2f} \n" \
        #     "".format(all_snr_db[i], stoi_s_hat, pesq_s_hat)

        # fig.suptitle(title, fontsize=40)

        # # Save figure
        # fig.savefig(processed_data_dir + os.path.splitext(file_path)[0] + '_fig.png')

        # # Clear figure
        # plt.close()

    # Confidence interval
    metrics = {'SNR': all_snr_db, 'STOI': all_stoi, 'PESQ': all_pesq}

    stats = {}

    # Print the names of the columns.
    print("{:<10} {:<10} {:<10}".format('METRIC', 'AVERAGE', 'CONF. INT.'))
    for key, metric in metrics.items():
        m, h = mean_confidence_interval(metric, confidence=confidence)
        stats[key] = {'avg': m, '+/-': h}
        print("{:<10} {:<10} {:<10}".format(key, m, h))
    print('\n')

    # Save stats (si_sdr, si_sar, etc. )
    with open(
            processed_data_dir + os.path.dirname(os.path.dirname(file_path)) +
            'stats.json', 'w') as f:
        json.dump(stats, f)

    # Metrics by input SNR
    for snr_db in np.unique(all_snr_db):
        stats = {}

        print('Input SNR = {:.2f}'.format(snr_db))
        # Print the names of the columns.
        print("{:<10} {:<10} {:<10}".format('METRIC', 'AVERAGE', 'CONF. INT.'))
        for key, metric in metrics.items():
            subset_metric = np.array(metric)[np.where(all_snr_db == snr_db)]
            m, h = mean_confidence_interval(subset_metric,
                                            confidence=confidence)
            stats[key] = {'avg': m, '+/-': h}
            print("{:<10} {:<10} {:<10}".format(key, m, h))
        print('\n')

        # Save stats (si_sdr, si_sar, etc. )
        with open(
                processed_data_dir +
                os.path.dirname(os.path.dirname(file_path)) +
                'stats_{:g}.json'.format(snr_db), 'w') as f:
            json.dump(stats, f)