def evaluation_all(path, index_file):
    """Score enhanced and noisy audio against clean references and write reports.

    Each non-empty line of ``<path>/<index_file>`` is split on whitespace.
    Field 0 is loaded as the estimated (enhanced) file, field 1 as the clean
    reference and field 2 as the noisy input; all are relative to ``path``.
    NOTE(review): the whitespace-separated ``estm clean noisy`` layout is
    inferred from how the fields are indexed below. Confirm it against the
    index files actually produced.

    Narrow-band PESQ, STOI and frequency-weighted segmental SNR are computed
    for every (clean, estimated) and (clean, noisy) pair. Per-file scores are
    written to ``pesq_score.txt``, ``stoi_score.txt`` and ``fw_seg_snr.txt``
    in ``path``. Averages are then read back with ``read_score_1`` and
    handed to ``write_final_result_1`` for ``final_results.txt``.

    Args:
      path: directory containing the index file and the audio files.
      index_file: name of the index file inside ``path``.
    """
    with open(os.path.join(path, index_file), 'r') as f:
        # Split every line into its path fields. Indexing the raw line
        # string would yield single characters, not file names.
        filelist = [line.split() for line in f if line.strip()]
    pesq_estm = []
    pesq_noisy = []
    stoi_estm = []
    stoi_noisy = []
    fwsnr_estm = []
    fwsnr_noisy = []
    sr = 16000
    n_audio = len(filelist)
    for item in filelist:
        wav_clean = librosa.core.load(os.path.join(path, item[1]), sr=sr)[0]
        wav_noisy = librosa.core.load(os.path.join(path, item[2]), sr=sr)[0]
        # `sr` is an argument of librosa.load, not of os.path.join.
        wav_estm = librosa.core.load(os.path.join(path, item[0]), sr=sr)[0]

        pesq_estm.append("{}: {}\n".format(
            item[0],
            pypesq(sr, wav_clean, wav_estm, 'nb')[0]))
        pesq_noisy.append("{}: {}\n".format(
            item[2],
            pypesq(sr, wav_clean, wav_noisy, 'nb')[0]))
        stoi_estm.append("{}: {}\n".format(item[0],
                                           stoi.stoi(wav_clean, wav_estm, sr)))
        stoi_noisy.append("{}: {}\n".format(
            item[2], stoi.stoi(wav_clean, wav_noisy, sr)))
        fwsnr_estm.append("{}: {}\n".format(
            item[0], fwseg_snr.fwSNRseg(wav_clean, wav_estm, sr)))
        fwsnr_noisy.append("{}: {}\n".format(
            item[2], fwseg_snr.fwSNRseg(wav_clean, wav_noisy, sr)))

    with open(os.path.join(path, "pesq_score.txt"), 'w') as f:
        f.writelines(pesq_estm)
        f.writelines(pesq_noisy)
    with open(os.path.join(path, "stoi_score.txt"), 'w') as f:
        f.writelines(stoi_estm)
        f.writelines(stoi_noisy)
    with open(os.path.join(path, "fw_seg_snr.txt"), 'w') as f:
        f.writelines(fwsnr_estm)
        f.writelines(fwsnr_noisy)

    # Only the averages are used. Never bind the name `stoi` in this
    # function: any assignment makes it local everywhere in the function and
    # the `stoi.stoi(...)` calls above then raise UnboundLocalError.
    _, pesq_avg = read_score_1(os.path.join(path, "pesq_score.txt"),
                               n_audio=n_audio)
    _, stoi_avg = read_score_1(os.path.join(path, "stoi_score.txt"),
                               n_audio=n_audio)
    _, fwsnr_avg = read_score_1(os.path.join(path, "fw_seg_snr.txt"),
                                n_audio=n_audio)

    write_final_result_1(os.path.join(path, "final_results.txt"), pesq_avg,
                         stoi_avg, fwsnr_avg)
Example #2
0
def test():
    """Check pypesq wide- and narrow-band scores on the bundled babble sample.

    The reference and degraded files live in the ``audio`` directory two
    levels above this file. Both modes must reproduce the recorded scores
    exactly.
    """
    audio_dir = Path(__file__).parent.parent / 'audio'
    reference_file = audio_dir / 'speech.wav'
    degraded_file = audio_dir / 'speech_bab_0dB.wav'

    sample_rate, ref = scipy.io.wavfile.read(reference_file)
    sample_rate, deg = scipy.io.wavfile.read(degraded_file)

    # (mode, expected score), checked in this order.
    expectations = (
        ('wb', 1.0832337141036987),
        ('nb', 1.6072081327438354),
    )
    for mode, expected in expectations:
        score = pypesq(ref=ref, deg=deg, fs=sample_rate, mode=mode)
        assert score == expected, score
Example #3
0
def pesq(ref, deg):
    """Return the narrow-band PESQ of a one-second 16 kHz clip, or 0 on failure.

    Both signals are reshaped to exactly 16000 samples (one second at the
    hard-coded 16 kHz rate). Inputs that cannot be reshaped that way, and any
    error raised by ``pypesq.pypesq``, yield 0 instead of an exception.
    """
    try:
        ref = np.reshape(ref, (16000, ))
        deg = np.reshape(deg, (16000, ))
        return pypesq.pypesq(16000, ref, deg, 'nb')
    except Exception:
        # Best-effort scoring: swallow scoring failures but, unlike a bare
        # ``except:``, let KeyboardInterrupt / SystemExit propagate.
        return 0
def compute_scores(audiofile_ref, audiofile_deg, sr=16000):
    """Compute PESQ, STOI and fwSegSNR of a degraded file against a reference.

    Both files are loaded (and resampled) at ``sr`` with librosa.

    Args:
      audiofile_ref: path of the clean reference audio.
      audiofile_deg: path of the degraded / processed audio.
      sr: sample rate used for loading and for all metrics.

    Returns:
      ``(pesq_score, stoi_score, fwsnr_score)``. ``pesq_score`` is the first
      element of the narrow-band ``pypesq`` result.
    """
    # librosa >= 0.10 makes `sr` keyword-only, so passing it positionally
    # raises TypeError there. The keyword form works on all versions.
    wav_ref = librosa.core.load(audiofile_ref, sr=sr)[0]
    wav_deg = librosa.core.load(audiofile_deg, sr=sr)[0]

    pesq_score = pypesq(sr, wav_ref, wav_deg, 'nb')[0]
    stoi_score = stoi.stoi(wav_ref, wav_deg, sr)
    fwsnr_score = fwseg_snr.fwSNRseg(wav_ref, wav_deg, sr)

    return pesq_score, stoi_score, fwsnr_score
Example #5
0
def pesq(clean_speech, processed_speech, fs):
    if fs == 8000:
        pesq_mos = pypesq(fs, clean_speech, processed_speech, 'nb')
        pesq_mos = 46607 / 14945 - (2000 * np.log(
            1 /
            (pesq_mos / 4 - 999 / 4000) - 1)) / 2989  #remap to raw pesq score

    elif fs == 16000:
        pesq_mos = pypesq(fs, clean_speech, processed_speech, 'wb')
    elif fs >= 16000:
        numSamples = round(len(clean_speech) / fs * 16000)
        pesq_mos = pypesq(fs, resample(clean_speech, numSamples),
                          resample(processed_speech, numSamples), 'wb')
    else:
        numSamples = round(len(clean_speech) / fs * 8000)
        pesq_mos = pypesq(fs, resample(clean_speech, numSamples),
                          resample(processed_speech, numSamples), 'nb')
        pesq_mos = 46607 / 14945 - (2000 * np.log(
            1 /
            (pesq_mos / 4 - 999 / 4000) - 1)) / 2989  #remap to raw pesq score

    return pesq_mos
Example #6
0
def calculate_pesq(args):
    """Calculate PESQ of all enhanced speech.

    Args:
      args: namespace whose ``data_type`` attribute selects the directory of
        enhanced wavs: one of "DM", "IRM", "CRN", "PHASE", "VOLUME", "NOISE".

    For every file in that directory the clean reference
    ``mini_data/test_speech/<stem>.WAV`` is read. The wide-band PESQ at
    16 kHz is printed and written, tab-separated, to
    ``<data_type>_pesq_results.csv``, which is finally moved into
    ``./pesq_result/``.

    Raises:
      ValueError: if ``args.data_type`` is not one of the known data types.
    """
    import shutil  # local import: only this function needs it

    data_type = args.data_type
    speech_dir = "mini_data/test_speech"
    # Location (below ./workspace) of the enhanced wavs per data type.
    enh_subdirs = {
        "DM": ("enh_wavs", "test", "mixdb"),
        "IRM": ("enh_wavs", "test", "mask_mixdb"),
        "CRN": ("enh_wavs", "test", "crn_mixdb"),
        "PHASE": ("enh_wavs", "test", "phase_spec_clean_mixdb"),
        "VOLUME": ("enh_wavs", "test", "volume_mixdb"),
        "NOISE": ("mixed_audios", "spectrogram", "test", "mixdb"),
    }
    if data_type not in enh_subdirs:
        # Previously this surfaced later as a NameError on an unbound name.
        raise ValueError("Unknown data_type: {!r}".format(data_type))
    enh_speech_dir = os.path.join("workspace", *enh_subdirs[data_type])

    f = "{0:<16} {1:<16} {2:<16}"
    print(f.format("0", "Noise", "PESQ"))
    csv_name = data_type + '_pesq_results.csv'
    # Calculate PESQ of all enhanced speech.
    with open(csv_name, 'w') as f1:
        f1.write("%s\t%s\n" % ("audio_id", "PESQ"))
        for (cnt, na) in enumerate(os.listdir(enh_speech_dir)):
            enh_path = os.path.join(enh_speech_dir, na)
            enh_audio, fs = soundfile.read(enh_path)
            speech_na = na.split('.')[0]
            speech_path = os.path.join(speech_dir, "%s.WAV" % speech_na)
            speech_audio, fs = soundfile.read(speech_path)
            pesq_ = pypesq(16000, speech_audio, enh_audio, 'wb')
            print(f.format(cnt, na, pesq_))
            f1.write("%s\t%f\n" % (na, pesq_))

    # Move only after the file has been closed (flushed). shutil avoids
    # building a shell command from `data_type`, and the destination
    # directory is created instead of letting `mv` fail silently.
    os.makedirs("pesq_result", exist_ok=True)
    shutil.move(csv_name, os.path.join("pesq_result", csv_name))
def pesq_score(clean_wavs, reconst_wavs, band='wb'):
    """Print per-file and average PESQ for paired clean/reconstructed wavs.

    Args:
      clean_wavs: sequence of paths to clean reference wav files.
      reconst_wavs: matching reconstructed wav paths. Pairs are formed with
        ``zip``, so surplus entries in the longer sequence are ignored.
      band: PESQ mode passed to ``pypesq`` ('wb' or 'nb').

    Returns:
      0 in every case. The average is only printed, which keeps existing
      callers working.
    """
    scores = []
    total = len(clean_wavs)

    print('PESQ Calculation...')
    for i, (clean_, reconst_) in enumerate(zip(clean_wavs, reconst_wavs),
                                           start=1):
        rate, ref = wavfile.read(clean_)
        rate, deg = wavfile.read(reconst_)
        score = pypesq(rate, ref, deg, band)
        scores.append(score)
        # 1-based counter, so the last file reports N/N rather than (N-1)/N.
        print('Score : {0} ... {1}/{2}'.format(score, i, total))

    if not scores:
        # Averaging an empty array would only print nan with a warning.
        print('  No file pairs to score.')
        return 0

    score = np.average(np.array(scores))

    print('  ---------------------------------------------------')
    print('  Average PESQ score = {0}'.format(score))
    print('  ---------------------------------------------------')

    return 0
Example #8
0
    def pypesq(self):
        """Return one PESQ score per (source, prediction) channel pair.

        The mode follows ``self.sample_rate``: 8000 -> 'nb', 16000 -> 'wb'.
        Any other rate raises KeyError. Sources and predictions must share
        one 2-D shape with fewer than 5 rows (checked with asserts).
        Raises AssertionError if the ``pypesq`` package is not installed.
        """
        # pypesq does not release the GIL. Either release our pesq code or
        # change pypesq to release the GIL and be thread safe
        try:
            import pypesq
        except ImportError:
            raise AssertionError(
                'To use this pesq implementation, install '
                'https://github.com/ludlows/python-pesq .'
            )
        mode = {8000: 'nb', 16000: 'wb'}[self.sample_rate]

        sources = self.speech_source
        predictions = self.speech_prediction_selection
        shapes = (sources.shape, predictions.shape)
        assert sources.shape == predictions.shape, shapes  # NOQA
        assert sources.ndim == 2, shapes  # NOQA
        assert sources.shape[0] < 5, shapes  # NOQA

        scores = []
        for ref, deg in zip(sources, predictions):
            scores.append(
                pypesq.pypesq(ref=ref, deg=deg, fs=self.sample_rate, mode=mode)
            )
        return scores
Example #9
0
def pesq_score(clean_wav, reconst_wav, split_num=100, band='nb'):
    """Print the mean PESQ over ``split_num`` aligned chunks of two wav files.

    Both files are read with ``wavfile.read`` and cut into ``split_num``
    near-equal pieces with ``np.array_split``. PESQ is computed for each
    piece pair, using the sample rate read from ``reconst_wav``. A progress
    line is printed every 10 chunks and the average is printed at the end.

    Returns:
      0 in every case; the score is only printed.
    """
    _, reference = wavfile.read(clean_wav)
    rate, degraded = wavfile.read(reconst_wav)

    ref_chunks = np.array_split(reference, split_num)
    deg_chunks = np.array_split(degraded, split_num)

    chunk_scores = []

    print('PESQ Calculation...')
    for idx, (ref_chunk, deg_chunk) in enumerate(zip(ref_chunks, deg_chunks)):
        if idx % 10 == 0:
            print('  No. {0}/{1}...'.format(idx, split_num))
        chunk_scores.append(pypesq(rate, ref_chunk, deg_chunk, band))

    mean_score = np.average(np.array(chunk_scores))

    print('  ---------------------------------------------------')
    print('  PESQ score = {0}'.format(mean_score))
    print('  ---------------------------------------------------')

    return 0
Example #10
0
# Compare wide-band PESQ and extended STOI of the car-noise mixtures and of
# their RNNoise-processed versions against the original clean recordings.
# For every clean file one CSV-like line is printed:
#   car,<filename>,<pesq mix>,<pesq rnnoise>,<estoi mix>,<estoi rnnoise>
# NOTE(review): `org_folder` (folder of clean originals) is not defined in
# this snippet; it must be set before the loop runs.
noise_folder = 'asr_cli_sounds/final-mix-16k-wav'
rnnoise_folder = 'asr_cli_sounds/final-rnnoise-16k-wav'

for filepath in glob.iglob(org_folder + '/*.wav'):
    #print(filepath)
    filename = os.path.basename(filepath)

    #original
    # Each file is read twice: the wavfile copy (ref/deg) feeds pypesq and
    # the soundfile copy (clean/denoised) feeds stoi.
    rate, ref = wavfile.read(filepath)
    clean, fs = sf.read(filepath)

    #mix with noise degraded
    # Note: `denoised` here holds the noisy car mixture, not a denoised file.
    rate, deg = wavfile.read(noise_folder + '/car_' + filename)
    denoised, fs = sf.read(noise_folder + '/car_' + filename)
    estoi_deg = "%.3f" % stoi(clean, denoised, fs, extended=True)
    pesq_deg = "%.3f" % pypesq(rate, ref, deg, 'wb')

    #rnnoise processed
    rate, deg = wavfile.read(rnnoise_folder + '/rnnoise_16k_car_' + filename)
    denoised, fs = sf.read(rnnoise_folder + '/rnnoise_16k_car_' + filename)
    estoi_rnn = "%.3f" % stoi(clean, denoised, fs, extended=True)
    pesq_rnn = "%.3f" % pypesq(rate, ref, deg, 'wb')

    #print("=========================Car Noise==================================")
    # The scores are already formatted strings ("%.3f").
    print("car," + filename + ',' + str(pesq_deg) + ',' + str(pesq_rnn) + ',' +
          str(estoi_deg) + ',' + str(estoi_rnn))
    #print("pseq score",pypesq(rate, ref, deg, 'wb'))
    #print("stoi score",d)
    #print("estoi score",ed)

    # NOTE(review): this trailing section header suggests the original
    # script continued here (other noise types?); it looks truncated.
    #mix with noise degraded
Example #11
0
def infer_a_sample(sample_i, round):
    """Mix one target utterance with a random interferer, enhance it and score it.

    The interferer is drawn at random from ``target_list`` (probability
    0.33; this may pick the target itself), ``interf_audio_list`` (0.33) or
    ``interf_esc_list`` (0.34). It is mixed at the scale
    ``SNR_db_to_scale(opt.db)``. The model output is resynthesised with
    ``postpro_and_gen`` using the *mixture* phase. The mixture, target and
    prediction wavs are written to ``wav_dir`` together with a comparison
    figure, and SDR (first element of ``bss_eval_sources``), narrow-band
    PESQ and STOI are computed.

    Args:
      sample_i: index into ``target_list`` of the target utterance.
      round: evaluation round; only used in output file names (it shadows
        the builtin ``round`` inside this function).

    Returns:
      ``(sdr, pesq, stoi)``. When ``opt.compute_orig`` is set, the result is
      ``(sdr, pesq, stoi, sdr_orig, pesq_orig, stoi_orig)``, where the
      ``_orig`` scores compare the unprocessed mixture with the target. All
      entries are ``None`` if anything inside the ``try`` block raised; the
      exception is printed, not re-raised.
    """
    target_path = target_list[sample_i]
    # Choose where the interferer comes from (see docstring for odds).
    mix_flag = random.random()
    if mix_flag < 0.33:
        interf_id = random.randint(0, len(target_list) - 1)
        interf_path = target_list[interf_id]
        interf = target_dict[interf_path]
    elif 0.33 <= mix_flag < 0.66:
        interf_id = random.randint(0, len(interf_audio_list) - 1)
        interf_path = interf_audio_list[interf_id]
        interf = interf_dict[interf_path]
    else:
        interf_id = random.randint(0, len(interf_esc_list) - 1)
        interf_path = interf_esc_list[interf_id]
        interf = interf_dict[interf_path]
    target = target_dict[target_path]
    chooseIndexNormalised = None
    batchLen = opt.seqLen
    saveMixtureName = None
    interf_scale = SNR_db_to_scale(opt.db)
    print('Mixing scale is {}'.format(interf_scale))
    # Cut mixture / target features and phases into segments of
    # `opt.seqLen` frames; `chooseIndex` holds each segment's start frame.
    (batchInput1, batchTarget, batchMixphase, batchTargetphase,
     chooseIndex) = ExtractFeatureFromOneSignal_fromMemory(
         target, interf, interf_scale, chooseIndexNormalised, batchLen,
         add_hole, saveMixtureName, opt.dBscale, hparams.sample_rate,
         opt.normalize)

    phase_dim = halfFFT + 1
    # Number of frames spanned by all segments. Presumably chooseIndex is
    # ascending, so its last entry is the final segment start - confirm.
    L = chooseIndex[-1] + opt.seqLen

    train_model.eval()

    with torch.no_grad():
        batchInput1 = torch.from_numpy(batchInput1)

        if opt.cuda:
            batchInput1 = batchInput1.cuda()

        pred = train_model(batchInput1)

    # Full-length buffers that the per-segment results are stitched into.
    pred_spec = np.zeros((L, halfFFT))
    gt_spec = np.zeros((L, halfFFT))
    mix_spec = np.zeros((L, halfFFT))
    mix_phase = np.zeros((L, phase_dim))
    target_phase = np.zeros((L, phase_dim))

    # NOTE(review): only the code below is guarded. Feature extraction,
    # `chooseIndex[-1]` and the model call above still propagate exceptions.
    try:  # in case the wav is shorter than seqLen frames
        for n, i in enumerate(chooseIndex):
            # the mixture
            mix_spec[i:i + opt.seqLen] = batchInput1[n].data.cpu().numpy()
            # the gt
            gt_spec[i:i + opt.seqLen] = batchTarget[n]
            # the prediction
            pred_spec[i:i + opt.seqLen] = pred[n].data.cpu().numpy()

            mix_phase[i:i + opt.seqLen] = batchMixphase[n]
            target_phase[i:i + opt.seqLen] = batchTargetphase[n]

        # Forgot to add options!!
        mix_wav = postpro_and_gen(mix_spec,
                                  mix_phase,
                                  dBscale=opt.dBscale,
                                  denormalize=opt.normalize)
        target_wav = postpro_and_gen(gt_spec,
                                     target_phase,
                                     dBscale=opt.dBscale,
                                     denormalize=opt.normalize)
        # The prediction is resynthesised with the mixture phase.
        pred_wav = postpro_and_gen(pred_spec,
                                   mix_phase,
                                   dBscale=opt.dBscale,
                                   denormalize=opt.normalize)

        mixtureSaveName = 'r_{}_samp_{}_input.wav'.format(round, sample_i)
        targetSaveName = 'r_{}_samp_{}_target.wav'.format(round, sample_i)
        predSaveName = 'r_{}_samp_{}_pred.wav'.format(round, sample_i)
        figName = 'r_{}_samp_{}_fig.jpg'.format(round, sample_i)

        # save the audios locally
        sf.write(os.path.join(wav_dir, mixtureSaveName), mix_wav,
                 hparams.sample_rate)
        sf.write(os.path.join(wav_dir, targetSaveName), target_wav,
                 hparams.sample_rate)
        sf.write(os.path.join(wav_dir, predSaveName), pred_wav,
                 hparams.sample_rate)
        print('Saving to r_{}_samp_{}'.format(round, sample_i))

        # compute scores
        sdr_this_sample = bss_eval_sources(target_wav, pred_wav)[0]
        pesq_this_sample = pypesq(hparams.sample_rate, target_wav, pred_wav,
                                  'nb')
        stoi_this_sample = stoi(target_wav,
                                pred_wav,
                                hparams.sample_rate,
                                extended=False)

        # plot fig
        plot_and_compare(mix_wav, target_wav, pred_wav,
                         os.path.join(wav_dir, figName), hparams.sample_rate)

        print('SDR, PESQ, STOI this sample is {}, {}, {}.'.format(
            sdr_this_sample, pesq_this_sample, stoi_this_sample))

        # Baseline: score the unprocessed mixture against the target.
        if opt.compute_orig:
            sdr_orig_this_sample = bss_eval_sources(target_wav, mix_wav)[0]
            pesq_orig_this_sample = pypesq(hparams.sample_rate, target_wav,
                                           mix_wav, 'nb')
            stoi_orig_this_sample = stoi(target_wav,
                                         mix_wav,
                                         hparams.sample_rate,
                                         extended=False)
            print('SDR, PESQ, STOI this sample is {}, {}, {} orig.'.format(
                sdr_orig_this_sample, pesq_orig_this_sample,
                stoi_orig_this_sample))

    # Any failure above yields None for every score; the error is printed.
    except Exception as e:
        sdr_this_sample = None
        pesq_this_sample = None
        stoi_this_sample = None
        sdr_orig_this_sample = None
        pesq_orig_this_sample = None
        stoi_orig_this_sample = None
        print(e)

    # The tuple length depends on opt.compute_orig (3 or 6 entries).
    if opt.compute_orig:
        return sdr_this_sample, pesq_this_sample, stoi_this_sample, \
               sdr_orig_this_sample, pesq_orig_this_sample, stoi_orig_this_sample
    else:
        return sdr_this_sample, pesq_this_sample, stoi_this_sample
Example #12
0
def ratevoice():
    """Print the narrow-band PESQ of ``voicerd`` against the reference ``voicewave``.

    Both module-level paths are read with ``wavfile.read``. The sample rate
    of the degraded file is the one passed to ``pypesq``.
    """
    _, reference = wavfile.read(voicewave)
    sample_rate, degraded = wavfile.read(voicerd)
    score = pypesq(sample_rate, reference, degraded, 'nb')
    # NOTE(review): the stated 0..4.5 range may not match the scale pypesq
    # returns for 'nb' (MOS-LQO is usually ~1..4.5) - confirm.
    print("The voice quality is %.2f (0=worst, 4.5=best)" % score)
def _plot_orig_vs_pred(df, group_col, metric, out_name):
    """Save a bar chart of per-group means of '<metric> orig.' vs '<metric> pred.'.

    The chart is written to ``FOLDER/out_name`` at 600 dpi and its figure is
    closed afterwards.
    """
    orig_col, pred_col = metric + ' orig.', metric + ' pred.'
    # Select the metric columns before averaging. pandas >= 2.0 no longer
    # silently drops non-numeric columns (e.g. 'Speech') in GroupBy.mean()
    # and would raise TypeError.
    means = df.groupby(group_col)[[orig_col, pred_col]].mean()
    fig, ax = plt.subplots()
    means[orig_col].plot(kind='bar', ax=ax, position=1, width=0.3, color='C0')
    means[pred_col].plot(kind='bar', ax=ax, position=0, width=0.3, color='C1')
    plt.legend()
    plt.savefig(FOLDER + '/' + out_name, dpi=600)  # , bbox_inches='tight')
    plt.close(fig)


def evaluate(folder_audio):
    """Compute STOI, eSTOI and PESQ for every sample triple in ``folder_audio``.

    Per sample, the folder holds three wav files sharing a prefix and ending
    in ``clean.wav``, ``noisy.wav`` and ``predi.wav``. The clean file is the
    reference; the noisy file ("orig.") and the prediction ("pred.") are
    scored against it with wide-band PESQ. Speech name, noise name and SNR
    are parsed from the file-name prefix.

    Per-sample results are appended to a ';'-separated CSV in ``FOLDER``
    (``results.csv``, or ``results<folder name>.csv`` if that one already
    exists). The summary is appended to ``FOLDER/results_total.txt``, and
    six bar charts (by noise type and by SNR) are saved in ``FOLDER``.

    Returns:
      The summary string (``mean_std`` of each metric for the original and
      predicted audio), which is also written to disk and printed.
    """
    results_file = os.path.join(FOLDER, 'results.csv')
    if os.path.exists(results_file):
        results_file = os.path.join(
            FOLDER, 'results' + os.path.split(folder_audio)[1] + '.csv')
    fieldnames = [
        'Sample', 'Speech', 'Noise', 'SNR', 'STOI orig.', 'STOI pred.',
        'eSTOI orig.', 'eSTOI pred.', 'PESQ orig.', 'PESQ pred.'
    ]

    class excel_semicolon(csv.excel):
        delimiter = ';'

    pred_stois, orig_stois = [], []
    pred_estois, orig_estois = [], []
    pred_pesqs, orig_pesqs = [], []
    speech_names, noise_names = [], []
    snrs = []
    n = get_count_of_audiofiles(folder_audio) // 3
    # The listing does not change while scoring: read and sort it once
    # instead of once per sample.
    list_audio = sorted(
        k for k in get_list_of_files(folder_audio) if '.wav' in k)
    assert len(list_audio) % 3 == 0

    with open(results_file, mode='a', newline='') as csv_file:
        writer = csv.DictWriter(csv_file,
                                fieldnames=fieldnames,
                                dialect=excel_semicolon,
                                extrasaction='ignore')
        writer.writeheader()
        sleep(0.1)  # for tqdm
        for i in tqdm(range(n), total=n, desc='Calculating STOI & PESQ'):
            # In sorted order '<prefix>clean.wav' opens each triple; strip
            # its 9-character suffix to get the shared prefix.
            filename = list_audio[3 * i][:-9]
            fsx, x = read_audio(filename + 'noisy.wav')
            fsy, y = read_audio(filename + 'clean.wav')
            fsyh, y_hat = read_audio(filename + 'predi.wav')
            x, y = x[:len(y_hat)], y[:len(y_hat)]
            assert fsx == fsy == fsyh == target_fs
            assert len(x) == len(y) == len(y_hat)

            # filenames
            _, f = os.path.split(filename)
            # Compare characters with ==; `is` on str literals relies on
            # CPython interning and is a SyntaxWarning since Python 3.8.
            speech_noise_name = f[:-5] if f[-4] == '-' else f[:-4]
            sn = [part.strip() for part in speech_noise_name.split('_')
                  if part.strip()]
            speech_name = sn[0]
            noise_name = sn[1]
            speech_names.append(speech_name)
            noise_names.append(noise_name)
            # snr
            snr_string = f[-5:-3]
            snr = int(
                snr_string[1]) if snr_string[0] == '_' else int(snr_string)
            snrs.append(snr)
            # STOI
            pred_stoi = np.round(stoi(y, y_hat, target_fs), 3)
            orig_stoi = np.round(stoi(y, x, target_fs), 3)
            # eSTOI
            pred_estoi = np.round(stoi(y, y_hat, target_fs, extended=True), 3)
            orig_estoi = np.round(stoi(y, x, target_fs, extended=True), 3)
            # PESQ
            pred_pesq = np.round(
                pypesq(fs=target_fs, ref=y, deg=y_hat, mode='wb'), 3)
            orig_pesq = np.round(pypesq(fs=target_fs, ref=y, deg=x, mode='wb'),
                                 3)
            # Results
            pred_stois.append(pred_stoi)
            pred_estois.append(pred_estoi)
            pred_pesqs.append(pred_pesq)
            orig_stois.append(orig_stoi)
            orig_estois.append(orig_estoi)
            orig_pesqs.append(orig_pesq)
            writer.writerow({
                'Sample': i,
                'Speech': speech_name,
                'Noise': noise_name,
                'SNR': snr,
                'STOI orig.': orig_stoi,
                'STOI pred.': pred_stoi,
                'eSTOI orig.': orig_estoi,
                'eSTOI pred.': pred_estoi,
                'PESQ orig.': orig_pesq,
                'PESQ pred.': pred_pesq
            })
        sleep(0.15)  # for tqdm

    # Results analysis (the CSV is closed and flushed at this point).
    total_metrics = 'Orig. STOI: %s - eSTOI: %s - PESQ: %s \nPred. STOI: %s - eSTOI: %s - PESQ: %s' % \
                    (mean_std(np.array(orig_stois)), mean_std(np.array(orig_estois)), mean_std(np.array(orig_pesqs)),
                     mean_std(np.array(pred_stois)), mean_std(np.array(pred_estois)), mean_std(np.array(pred_pesqs)))
    with open(os.path.join(FOLDER, 'results_total.txt'), 'a') as file:
        file.write(total_metrics)

    df = pd.read_csv(results_file, sep=';')
    # Produces metrics_{1stoi,2estoi,3pesq}.png (by noise type), then
    # metrics_snr_{1stoi,2estoi,3pesq}.png (by SNR).
    for group_col, prefix in (('Noise', 'metrics_'), ('SNR', 'metrics_snr_')):
        for metric, suffix in (('STOI', '1stoi'), ('eSTOI', '2estoi'),
                               ('PESQ', '3pesq')):
            _plot_orig_vs_pred(df, group_col, metric, prefix + suffix + '.png')

    print(
        '__________________________________________________________________________________________________'
    )
    print('Evaluation Results: (%d files)\n' % (n))
    print(total_metrics)
    print(
        '__________________________________________________________________________________________________'
    )

    return total_metrics
Example #14
0
def pypesq_eval(src, tar, sr=16000):
    """Return the PESQ score computed by ``pypesq(tar, src, sr)``.

    Args:
      src: 1-D signal passed as the second argument of ``pypesq``
        (presumably the degraded/estimated signal - confirm).
      tar: 1-D signal passed as the first argument (presumably the clean
        reference - confirm).
      sr: sample rate forwarded to ``pypesq``.

    Raises:
      AssertionError: if either signal is not 1-D, or is silent (every
        sample within 1e-6 of zero).
    """
    assert src.ndim == 1 and tar.ndim == 1
    # Reject silent inputs sample by sample. Testing only the *sum* wrongly
    # rejected zero-mean audio such as a pure tone over whole periods.
    assert not np.allclose(src, 0.0, atol=1e-6) and not np.allclose(
        tar, 0.0, atol=1e-6)
    raw_pesq = pypesq(tar, src, sr)
    return raw_pesq
def validate(meta_dir: str,
             model_name: str,
             pretrained_path: str,
             out_dir: str = '',
             batch_size: int = 64,
             num_workers: int = 16,
             sr: int = 22050):
    """
    Evaluate a model on the validation set and print its mean wide-band PESQ.
    Clean targets and predictions are resampled from ``sr`` to 16 kHz before
    scoring. Pass ``out_dir`` to also write the noisy input, clean target and
    prediction of every validation sample as wav files.
    :param meta_dir: voice bank meta directory
    :param model_name: model name
    :param pretrained_path: pretrained checkpoint file path
    :param out_dir: output directory
    :param batch_size: batch size for evaluating datasets
    :param num_workers: workers of data loader
    :param sr: training sample rate
    """

    preemp = PreEmphasis().cuda()

    # load model
    model = __load_model(model_name, pretrained_path)

    # load validation data loader
    _, valid_loader = voice_bank.get_datasets(meta_dir,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              fix_len=0,
                                              audio_mask=True)

    # loop all
    print('Process Validation Dataset (with PESQ) ...')
    pesq_score = 0.
    count = 0

    if out_dir:
        noise_all = []
        clean_all = []
        results = []

    for noise, clean, *others in tqdm(valid_loader, desc='validate'):
        noise = noise.cuda()
        # The model consumes pre-emphasised input; its output is undone with
        # inv_preemphasis before scoring and writing.
        noise = preemp(noise.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            clean_hat = model(noise)

        clean = clean.cpu().numpy()
        clean_hat = clean_hat.cpu().numpy()

        # calculate
        for clean_sample, clean_hat_sample in zip(clean, clean_hat):
            # resample to 16 kHz for wide-band PESQ. The rates are passed by
            # keyword: librosa >= 0.10 rejects them positionally.
            clean_sample = librosa.core.resample(clean_sample,
                                                 orig_sr=sr,
                                                 target_sr=16000)
            clean_hat_sample = librosa.core.resample(clean_hat_sample,
                                                     orig_sr=sr,
                                                     target_sr=16000)

            item_score = pypesq(
                16000, clean_sample,
                inv_preemphasis(clean_hat_sample).clip(-1., 1.), 'wb')
            pesq_score += item_score
            count += 1

        if out_dir:
            noise_all.append(noise.cpu().numpy())
            clean_all.append(clean)
            results.append(clean_hat)

    if count:
        print(f'PESQ Score : {pesq_score / count}')
    else:
        # An empty loader used to crash here with ZeroDivisionError.
        print('PESQ Score : no validation samples')

    if out_dir:
        # mkdir
        os.makedirs(out_dir, exist_ok=True)
        # write all
        print('Write all result into {} ...'.format(out_dir))
        for idx, (batch_clean_hat, batch_noise, batch_clean) in tqdm(
                enumerate(zip(results, noise_all, clean_all))):
            for in_idx, (clean_hat, noise, clean) in enumerate(
                    zip(batch_clean_hat, batch_noise, batch_clean)):
                noise_out_path = os.path.join(
                    out_dir, '{}_noise.wav'.format(idx * batch_size + in_idx))
                pred_out_path = os.path.join(
                    out_dir, '{}_pred.wav'.format(idx * batch_size + in_idx))
                clean_out_path = os.path.join(
                    out_dir, '{}_clean.wav'.format(idx * batch_size + in_idx))

                # NOTE(review): librosa.output.write_wav was removed in
                # librosa 0.8; switch to soundfile.write on newer versions.
                librosa.output.write_wav(clean_out_path, clean,
                                         settings.SAMPLE_RATE)
                librosa.output.write_wav(noise_out_path,
                                         inv_preemphasis(noise),
                                         settings.SAMPLE_RATE)
                librosa.output.write_wav(
                    pred_out_path,
                    inv_preemphasis(clean_hat).clip(-1., 1.),
                    settings.SAMPLE_RATE)

        print('Finish writing files.')