Example #1
0
def evaluation(ref, est, mix):
    """
    Wrapper function for evaluating the output of a NN.

    Metrics are PESQ, STOI, eSTOI and SSR. All three files are loaded at their
    native sampling rate and brought to 16 kHz before scoring.

    :param ref: Path to the original (reference point) file.
    :param est: Path to the estimated file.
    :param mix: Path to the mixture file (used for SSR).
    :return: None. Prints PESQ, STOI, eSTOI and SSR values to stdout.
    """
    target_sr = 16000
    reference_sources, sr_r = librosa.load(ref, sr=None)
    estimated_sources, sr_e = librosa.load(est, sr=None)
    mix_sources, sr_m = librosa.load(mix, sr=None)

    if sr_r != target_sr or sr_e != target_sr or sr_m != target_sr:
        print("\nResampling at 16k...")
        # Each signal is resampled from its own native rate. Keyword arguments
        # are used because librosa >= 0.10 makes orig_sr/target_sr keyword-only.
        ref_16k = librosa.resample(reference_sources, orig_sr=sr_r, target_sr=target_sr)
        est_16k = librosa.resample(estimated_sources, orig_sr=sr_e, target_sr=target_sr)
        mix_16k = librosa.resample(mix_sources, orig_sr=sr_m, target_sr=target_sr)
    else:
        ref_16k = reference_sources
        est_16k = estimated_sources
        mix_16k = mix_sources

    # Every signal is at 16 kHz at this point, so all metrics use that rate
    # (not the reference file's original rate).
    pesq_score = round(pesq(ref_16k, est_16k, target_sr), 3)
    stoi_score = round(stoi(ref_16k, est_16k, target_sr, extended=False), 2)
    estoi_score = round(stoi(ref_16k, est_16k, target_sr, extended=True), 2)
    ssr_score = round(SSR(est_16k, mix_16k), 3)

    print("PESQ\t STOI\t eSTOI\t   SSR")
    print(pesq_score,"\t",stoi_score,"\t",estoi_score,"\t",ssr_score)
def eval(ref_name, enh_name, nsy_name, results):
    """Score the noisy and enhanced versions of one utterance against its reference.

    The three files are read with ``audioread`` and trimmed to a common
    length. Wide-band PESQ, STOI and SI-SNR are then computed for
    noisy-vs-reference and enhanced-vs-reference.

    :param ref_name: path of the clean reference; its base name is the utterance id.
    :param enh_name: path of the enhanced file.
    :param nsy_name: path of the noisy file.
    :param results: list to which ``[utt_id, {'pesq': [noisy, enhanced],
        'stoi': [...], 'si_sdr': [...]}]`` is appended.

    On any error the exception is printed and nothing is appended.
    """
    try:
        utt_id = ref_name.split('/')[-1]
        ref, sr = audioread(ref_name)
        enh, sr = audioread(enh_name)
        nsy, sr = audioread(nsy_name)
        # Trim all three signals (the noisy one included) to a common length
        # so every metric compares equally long signals.
        length = min(ref.shape[0], enh.shape[0], nsy.shape[0])
        ref, enh, nsy = ref[:length], enh[:length], nsy[:length]
        ref_score = pesq(sr, ref, nsy, 'wb')
        enh_score = pesq(sr, ref, enh, 'wb')
        ref_stoi = stoi(ref, nsy, sr, extended=False)
        enh_stoi = stoi(ref, enh, sr, extended=False)
        # Stored under the 'si_sdr' key below, but computed with si_snr.
        ref_sisdr = si_snr(nsy, ref)
        enh_sisdr = si_snr(enh, ref)
    except Exception as e:
        # Report and skip this utterance. Falling through would reference
        # scores that were never assigned.
        print(e)
        return

    results.append([
        utt_id, {
            'pesq': [ref_score, enh_score],
            'stoi': [ref_stoi, enh_stoi],
            'si_sdr': [ref_sisdr, enh_sisdr],
        }
    ])
Example #3
0
def eval(ref_format, enh_name, nsy_format, results):
    """Score one enhanced utterance against references from several array layouts.

    For each kind in ('circular', 'linear_nonuniform', 'linear_uniform'), the
    reference and noisy paths are built with ``ref_format.format(kind)`` and
    ``nsy_format.format(kind)``. Wide-band PESQ, STOI, eSTOI and SI-SNR are
    computed for noisy-vs-reference and enhanced-vs-reference. Each metric is
    tracked independently: the kind with the best enhanced score is kept,
    together with the noisy score from that same kind.

    :param ref_format: reference path template with one ``{}`` placeholder.
    :param enh_name: path of the enhanced file; its base name is the utterance id.
    :param nsy_format: noisy path template with one ``{}`` placeholder.
    :param results: list to which ``[utt_id, {'pesq': [noisy, enhanced],
        'stoi': [...], 'estoi': [...], 'si_sdr': [...]}]`` is appended.

    Errors are printed and the best values found so far are still appended.
    Metrics never computed keep the sentinel value -10000.
    """
    utt_id = enh_name.split('/')[-1]
    # best[metric] = [noisy score, enhanced score] of the best kind so far.
    best = {
        'pesq': [-10000, -10000],
        'stoi': [-10000, -10000],
        'estoi': [-10000, -10000],
        'si_sdr': [-10000, -10000],
    }
    try:
        enh, _ = audioread(enh_name)
        for kind in ['circular', 'linear_nonuniform', 'linear_uniform']:
            ref, sr = audioread(ref_format.format(kind))
            nsy, sr = audioread(nsy_format.format(kind))
            # Trim per kind on local slices, so the enhanced signal is not
            # permanently shortened for the kinds that follow.
            length = min(enh.shape[0], ref.shape[0], nsy.shape[0])
            enh_k, ref_k, nsy_k = enh[:length], ref[:length], nsy[:length]
            scores = {
                'pesq': (pesq(sr, ref_k, nsy_k, 'wb'),
                         pesq(sr, ref_k, enh_k, 'wb')),
                'stoi': (stoi(ref_k, nsy_k, sr, extended=False),
                         stoi(ref_k, enh_k, sr, extended=False)),
                'estoi': (stoi(ref_k, nsy_k, sr, extended=True),
                          stoi(ref_k, enh_k, sr, extended=True)),
                'si_sdr': (si_snr(nsy_k, ref_k), si_snr(enh_k, ref_k)),
            }
            for metric, (noisy_value, enh_value) in scores.items():
                if enh_value > best[metric][1]:
                    best[metric] = [noisy_value, enh_value]
    except Exception as e:
        print(e)

    results.append([utt_id, best])
Example #4
0
def calc_stoi(in_file, out_speech_dir):
    """Compute STOI of every ``.wav`` file in ``out_speech_dir`` against ``in_file``.

    Each output file and the reference are trimmed to their common length
    before scoring. Output files whose sampling rate differs from the
    reference's are reported and skipped; scoring them with the reference
    rate would give a meaningless value.

    :param in_file: path of the reference audio file.
    :param out_speech_dir: directory containing the processed ``.wav`` files.
    :return: ``(mean, std)`` of the STOI scores. Both are NaN if no file was scored.
    """
    out_names = [
        os.path.join(out_speech_dir, na)
        for na in sorted(os.listdir(out_speech_dir)) if na.endswith(".wav")
    ]
    stoi_list = []
    print("---------------------------------")
    print("\t", "STOI", "\n")
    (x, fs1) = pp.read_audio(in_file)

    for f in out_names:
        print(f)
        (y, fs2) = pp.read_audio(f)
        if fs1 != fs2:
            print("Error: output and input files have different sampling rate")
            continue

        m = min(len(x), len(y))
        stoi_list.append(stoi(x[:m], y[:m], fs1))

    if stoi_list:
        avg_stoi = np.mean(stoi_list)
        std_stoi = np.std(stoi_list)
    else:
        # np.mean/np.std of an empty list also give NaN, but emit RuntimeWarnings.
        avg_stoi = std_stoi = float("nan")
    print("AVG STOI\t", avg_stoi)
    print("ST DEV STOI\t", std_stoi)
    print("---------------------------------")
    return avg_stoi, std_stoi
Example #5
0
def test_estoi_good_fs():
    """Check extended STOI against the Octave reference at the native 10 kHz rate FS."""
    n_samples = 2 * FS
    clean, degraded = np.random.randn(n_samples), np.random.randn(n_samples)
    python_score = stoi(clean, degraded, FS, extended=True)
    octave_score = octave.feval('octave/estoi.m', clean, degraded, float(FS))
    assert_allclose(python_score, octave_score, atol=ATOL, rtol=RTOL)
Example #6
0
def test_stoi_good_fs():
    """Check STOI against the Octave reference at the native 10 kHz rate FS."""
    n_samples = 2 * FS
    clean, degraded = np.random.randn(n_samples), np.random.randn(n_samples)
    python_score = stoi(clean, degraded, FS)
    octave_score = octave.feval('octave/stoi.m', clean, degraded, float(FS))
    assert_allclose(python_score, octave_score, atol=ATOL, rtol=RTOL)
Example #7
0
def cal_stoi_from_file(enhance, clean, sr=SAMPLE_RATE):
    """
    A wrapped pystoi func to get STOI from two audio files.

    Both files are loaded with librosa at ``sr``, trimmed to their common
    length and scored with (non-extended) STOI at that same rate.

    Examples (need data files, so they are skipped as doctests):

    >>> cal_stoi_from_file('test', 'test')  # doctest: +SKIP
    Traceback (most recent call last):
    ...
    Exception: can not find files, please check path

    >>> cal_stoi_from_file('mix/fjcs0_sx319.wav_-1.wav', 'speech/test/fjcs0_sx319.wav')  # doctest: +SKIP
    0.76689639728190695

    :param enhance: path of the enhanced (processed) audio file.
    :param clean: path of the clean reference audio file.
    :param sr: sampling rate both files are loaded at; also passed to stoi.
    :return: the STOI score.
    :raises Exception: if either path is missing or is a directory.
    """
    try:
        clean_signal, _ = librosa.load(clean, sr=sr)
        enhance_signal, _ = librosa.load(enhance, sr=sr)
    except (FileNotFoundError, IsADirectoryError) as err:
        raise Exception('can not find files, please check path') from err

    # The signals are not always the same length, so align both to the shorter one.
    length = min(len(clean_signal), len(enhance_signal))
    # The signals were loaded at `sr`, so STOI must be computed at `sr` too.
    return stoi(clean_signal[:length], enhance_signal[:length], sr, extended=False)
Example #8
0
def calculate_stoi(args):
    """Compute STOI for every enhanced test utterance and write them to IRM_stoi.txt.

    Each file in workspace/enh_wavs/test/mixdb is matched with the clean file
    ``mini_data/test_speech/<stem>.WAV``, where ``<stem>`` is the part of the
    name before the first '.'. Both are read at 16 kHz and trimmed to the
    shorter length. One ``name<TAB>stoi`` line per file is written after a
    header line.

    :param args: not used by this function.
    """
    workspace = "workspace"
    speech_dir = "mini_data/test_speech"
    enh_speech_dir = os.path.join(workspace, "enh_wavs", "test", "mixdb")
    names = os.listdir(enh_speech_dir)
    # `with` guarantees the report is closed even if reading one file fails.
    with open("IRM_stoi.txt", "w") as f:
        f.write("%s\t%s\n" % ("speech_id", "stoi"))
        f.flush()
        for (cnt, na) in enumerate(names):
            print(cnt, na)
            enh_path = os.path.join(enh_speech_dir, na)
            speech_na = na.split('.')[0]
            speech_path = os.path.join(speech_dir, "%s.WAV" % speech_na)
            speech_audio, fs = read_audio(speech_path, 16000)
            enhance_audio, fs = read_audio(enh_path, 16000)
            if len(speech_audio) > len(enhance_audio):
                speech_audio = speech_audio[:len(enhance_audio)]
            else:
                enhance_audio = enhance_audio[:len(speech_audio)]
            stoi_value = stoi(speech_audio, enhance_audio, fs, extended=False)
            f.write("%s\t%f\n" % (na, stoi_value))
            # Flush per line so partial results survive an interrupted run.
            f.flush()
Example #9
0
def test_stoi_good_fs():
    """Compare pystoi's STOI with the MATLAB reference at the native rate FS."""
    n_samples = 2 * FS
    clean, degraded = np.random.randn(n_samples), np.random.randn(n_samples)
    python_score = stoi(clean, degraded, FS)
    matlab_score = eng.stoi(matlab.double(list(clean)),
                            matlab.double(list(degraded)),
                            float(FS))
    assert_allclose(python_score, matlab_score, atol=ATOL, rtol=RTOL)
Example #10
0
def cal_score(clean, enhanced):
    """Return ``(PESQ, STOI)`` of ``enhanced`` vs ``clean`` at 16 kHz, rounded to 5 decimals.

    Both signals are peak-normalised first. An all-zero signal is left
    unscaled, because dividing by a zero peak would turn it into NaNs.
    """
    clean_peak = abs(clean).max()
    if clean_peak > 0:
        clean = clean / clean_peak
    enhanced_peak = abs(enhanced).max()
    if enhanced_peak > 0:
        enhanced = enhanced / enhanced_peak

    s_stoi = stoi(clean, enhanced, 16000)
    s_pesq = pesq(clean, enhanced, 16000)

    return round(s_pesq, 5), round(s_stoi, 5)
Example #11
0
def test_estoi_good_fs():
    """Compare pystoi's extended STOI with the MATLAB reference at the native rate FS."""
    n_samples = 2 * FS
    clean, degraded = np.random.randn(n_samples), np.random.randn(n_samples)
    python_score = stoi(clean, degraded, FS, extended=True)
    matlab_score = eng.estoi(matlab.double(list(clean)),
                             matlab.double(list(degraded)),
                             float(FS))
    assert_allclose(python_score, matlab_score, atol=ATOL, rtol=RTOL)
Example #12
0
def test_estoi_good_fs():
    """Check extended STOI against MATLAB's estoi at the native 10 kHz rate FS."""
    length = 2 * FS
    reference = np.random.randn(length)
    processed = np.random.randn(length)
    ours = stoi(reference, processed, FS, extended=True)
    ref_m, proc_m = matlab.double(list(reference)), matlab.double(list(processed))
    theirs = eng.estoi(ref_m, proc_m, float(FS))
    assert_allclose(ours, theirs, atol=ATOL, rtol=RTOL)
Example #13
0
def test_stoi_upsample():
    """ Test STOI at sampling frequencies below 10 kHz (upsampled to 10 kHz internally). """
    # Load Octave's signal package once, before the loop, instead of on every
    # iteration. The reference octave/stoi.m presumably needs it for resampling.
    octave.eval('pkg load signal')
    for fs in [8000]:
        x = np.random.randn(2 * fs)
        y = np.random.randn(2 * fs)
        stoi_out = stoi(x, y, fs)
        stoi_out_m = octave.feval('octave/stoi.m', x, y, float(fs))
        assert_allclose(stoi_out, stoi_out_m, atol=ATOL, rtol=RTOL)
Example #14
0
def test_stoi_downsample():
    """ Test STOI at sampling frequencies above 10 kHz (downsampled to 10 kHz internally). """
    # Load Octave's signal package once, before the loop, instead of on every
    # iteration. The reference octave/stoi.m presumably needs it for resampling.
    octave.eval('pkg load signal')
    for fs in [11025, 16000, 22050, 32000, 44100, 48000]:
        x = np.random.randn(2 * fs)
        y = np.random.randn(2 * fs)
        stoi_out = stoi(x, y, fs)
        stoi_out_m = octave.feval('octave/stoi.m', x, y, float(fs))
        assert_allclose(stoi_out, stoi_out_m, atol=ATOL, rtol=RTOL)
Example #15
0
def read_STOI(clean_root, enhanced_file):
    """Compute STOI (16 kHz, non-extended) of an enhanced file against its clean source.

    The clean file name is rebuilt from the enhanced file's base name. It
    takes the last two '_'-separated fields, drops everything from '@'
    onward, and appends '.wav' under ``clean_root``. Both signals are trimmed
    to their common length, because stoi requires equal-length inputs.
    """
    f = enhanced_file.split('/')[-1]
    parts = f.split('_')
    wave_name = parts[-2] + '_' + parts[-1].split('@')[0]

    clean_wav, _ = wavread(clean_root + wave_name + '.wav')
    enhanced_wav, _ = wavread(enhanced_file)

    length = min(len(clean_wav), len(enhanced_wav))
    stoi_score = stoi(clean_wav[:length], enhanced_wav[:length], 16000, extended=False)
    return stoi_score
Example #16
0
def read_STOI(clean_root, enhanced_file):
    """Compute STOI (16 kHz, non-extended) of an enhanced file against its clean source.

    The clean file is ``clean_root + 'Train_' + <id> + '.wav'``, where
    ``<id>`` is the last '_'-separated field of the enhanced file's base
    name, cut at '@'. Both files are loaded at 16 kHz with librosa and
    trimmed to their common length, because stoi requires equal-length inputs.
    """
    f = enhanced_file.split('/')[-1]
    wave_name = f.split('_')[-1].split('@')[0]

    clean_wav, _ = librosa.load(clean_root + 'Train_' + wave_name + '.wav', sr=16000)
    enhanced_wav, _ = librosa.load(enhanced_file, sr=16000)

    length = min(len(clean_wav), len(enhanced_wav))
    stoi_score = stoi(clean_wav[:length], enhanced_wav[:length], 16000, extended=False)
    return stoi_score
Example #17
0
def compute_STOI(clean_signal, noisy_signal, sr=16000):
    """Compute (non-extended) STOI.

    Args:
        clean_signal: clean speech signal.
        noisy_signal: noisy speech signal.
        sr: sampling rate in Hz.

    Returns:
        The score returned by ``stoi``.
    """
    score = stoi(clean_signal, noisy_signal, sr, extended=False)
    return score
Example #18
0
def eval_metrics(ori_sig, dec_sig, _rand_model_id):
    """Compute objective metrics between an original and a decoded signal.

    Side effect: writes 'ori_sig_<id>.wav' and 'dec_sig_<id>.wav' (16 kHz,
    PCM_16) to the current directory. PESQ is computed from those two files.

    Returns:
        (min_len, si_snr, snr, stoi, pesq, correlation). min_len and snr come
        from ``snr``; correlation is np.corrcoef of the signals ``snr`` returns.
    """
    min_len, snr_value, ori_out, dec_out = snr(ori_sig, dec_sig)
    si_snr_value = si_snr(dec_out, ori_out)
    stoi_value = stoi(ori_out, dec_out, 16000, extended=False)

    ori_path = 'ori_sig_' + _rand_model_id + '.wav'
    dec_path = 'dec_sig_' + _rand_model_id + '.wav'
    sf.write(ori_path, ori_sig, 16000, 'PCM_16')
    sf.write(dec_path, dec_sig, 16000, 'PCM_16')
    pesq_value = pesq(ori_path, dec_path, 16000)

    correlation = np.corrcoef(ori_out, dec_out)[0][1]
    return (min_len, si_snr_value, snr_value, float(stoi_value),
            float(pesq_value), correlation)
Example #19
0
def test_stoi_upsample():
    """ FAILING BECAUSE OF RESAMPLING

    Compares pystoi with the MATLAB reference for fs = 8000 Hz.
    """
    for rate in [8000]:
        length = 2 * rate
        clean, degraded = np.random.randn(length), np.random.randn(length)
        python_score = stoi(clean, degraded, rate)
        matlab_score = eng.stoi(matlab.double(list(clean)),
                                matlab.double(list(degraded)),
                                float(rate))
        assert_allclose(python_score, matlab_score, atol=ATOL, rtol=RTOL)
Example #20
0
def read_STOI(clean_root, enhanced_file):
    """Compute STOI (16 kHz, non-extended) of an enhanced file against its clean source.

    The clean file is ``clean_root + <name> + '.wav'``, where ``<name>`` is
    the enhanced file's base name cut at '@'. Samples are converted to float
    and divided by the module-level ``maxv``. Both signals are trimmed to
    their common length, because stoi requires equal-length inputs.
    """
    f = enhanced_file.split('/')[-1]
    wave_name = f.split('@')[0]

    # [-1] takes the sample array; assumes wavfile.read returns (rate, data)
    # like scipy.io.wavfile.read — confirm.
    clean_wav = wavfile.read(clean_root + wave_name +
                             '.wav')[-1].astype(float) / maxv
    enhanced_wav = wavfile.read(enhanced_file)[-1].astype(float) / maxv

    length = min(len(clean_wav), len(enhanced_wav))
    stoi_score = stoi(clean_wav[:length], enhanced_wav[:length], 16000, extended=False)
    return stoi_score
Example #21
0
def test_stoi_downsample():
    """ Test STOI at sampling frequencies above 10 kHz (downsampled to 10 kHz internally).
        FAILING BECAUSE OF RESAMPLING """
    for fs in [11025, 16000, 22050, 32000, 44100, 48000]:
        x = np.random.randn(2*fs, )
        y = np.random.randn(2*fs, )
        stoi_out = stoi(x, y, fs)
        x_m = matlab.double(list(x))
        y_m = matlab.double(list(y))
        stoi_out_m = eng.stoi(x_m, y_m, float(fs))
        assert_allclose(stoi_out, stoi_out_m, atol=ATOL, rtol=RTOL)
Example #22
0
def test_stoi_upsample():
    """ Test STOI at sampling frequencies below 10 kHz (upsampled to 10 kHz internally).
        PASSES FOR :  RTOL = 1e-3  /  ATOL = 1e-3. """
    for fs in [8000]:
        x = np.random.randn(2 * fs, )
        y = np.random.randn(2 * fs, )
        stoi_out = stoi(x, y, fs)
        x_m = matlab.double(list(x))
        y_m = matlab.double(list(y))
        stoi_out_m = eng.stoi(x_m, y_m, float(fs))
        assert_allclose(stoi_out, stoi_out_m, atol=ATOL, rtol=RTOL)
Example #23
0
def eval(ref_name, enh_name, nsy_name, results):
    """Score the noisy and enhanced versions of one utterance against its reference.

    The three files are read with ``audioread`` and trimmed to a common
    length. PESQ, STOI and SI-SDR are then computed for noisy-vs-reference
    and enhanced-vs-reference.

    :param ref_name: path of the clean reference; its base name is the utterance id.
    :param enh_name: path of the enhanced file.
    :param nsy_name: path of the noisy file.
    :param results: list to which ``[utt_id, {'pesq': [noisy, enhanced],
        'stoi': [...], 'si_sdr': [...]}]`` is appended.

    On any error the exception is printed and nothing is appended.
    """
    try:
        utt_id = ref_name.split('/')[-1]
        ref, sr = audioread(ref_name)
        enh, sr = audioread(enh_name)
        nsy, sr = audioread(nsy_name)
        # stoi needs equal-length inputs, so align all three signals.
        length = min(ref.shape[0], enh.shape[0], nsy.shape[0])
        ref, enh, nsy = ref[:length], enh[:length], nsy[:length]
        ref_score = pesq(ref, nsy, sr)
        enh_score = pesq(ref, enh, sr)
        ref_stoi = stoi(ref, nsy, sr, extended=False)
        enh_stoi = stoi(ref, enh, sr, extended=False)
        ref_sdr = si_sdr(nsy, ref)
        enh_sdr = si_sdr(enh, ref)
    except Exception as e:
        # Report and skip this utterance. Falling through would reference
        # scores that were never assigned.
        print(e)
        return

    results.append([utt_id,
                    {'pesq': [ref_score, enh_score],
                     'stoi': [ref_stoi, enh_stoi],
                     'si_sdr': [ref_sdr, enh_sdr]
                    }])
def get_batch_stoi_improvement(x_wav, y_wav, y_wav_est):
    """Per-utterance STOI of the mixture, of the enhanced signal, and their difference.

    Inputs (each [batch, wave]):
      x_wav: mixtures.
      y_wav: clean references.
      y_wav_est: enhanced estimates.
    All waveforms are divided by AMP_MAX before scoring at FLAGS.PARAM.FS.

    Returns:
      np.ndarray of shape [3, batch]. The rows are mixture STOI, enhanced
      STOI, and STOI improvement (enhanced minus mixture).
    """
    def _scores_against_reference(candidates):
        # STOI of each candidate waveform against its matching clean reference.
        return np.array([
            stoi(clean / AMP_MAX, other / AMP_MAX, FLAGS.PARAM.FS)
            for clean, other in zip(y_wav, candidates)
        ])

    enhanced_stoi = _scores_against_reference(y_wav_est)
    mixture_stoi = _scores_against_reference(x_wav)
    improvement = enhanced_stoi - mixture_stoi
    return np.array([mixture_stoi, enhanced_stoi, improvement])
Example #25
0
 def end2end_final_eval(self, ori_sig, dec_sig):
     """Calculate the final end-to-end evaluation metrics.

     Side effect: writes 'ori_sig_that_one<id>.wav' and
     'dec_sig_that_one<id>.wav' (PCM_16 at ``sample_rate``), where <id> is
     ``self._rand_model_id``.

     Returns:
         (min_len, snr, stoi, pesq, correlation). min_len and snr come from
         ``snr``. pesq is always 0.0 because it is disabled for audio coding.
         correlation is np.corrcoef of the signals ``snr`` returns.
     """
     min_len, snr_value, ori_out, dec_out = snr(ori_sig, dec_sig)
     stoi_value = stoi(ori_out, dec_out, sample_rate, extended=False)
     for prefix, signal in (('ori_sig_that_one', ori_sig),
                            ('dec_sig_that_one', dec_sig)):
         sf.write(prefix + self._rand_model_id + '.wav', signal,
                  sample_rate, 'PCM_16')
     pesq_value = 0.0  # disabled for audio coding.
     correlation = np.corrcoef(ori_out, dec_out)[0][1]
     return min_len, snr_value, float(stoi_value), float(pesq_value), correlation
Example #26
0
def run_test(config):
    """Evaluate a trained AECNN generator on the 'et' split and print mean scores.

    The generator is built from ``config`` (gchan, gkernel, gblocksize, gdrop)
    and its weights are loaded from ``config.gcheckpoints``. The function then
    prints the average STOI, eSTOI and SI-SDR of the enhanced output against
    the clean signal over all examples.
    """
    generator = AECNN(
        channel_counts=config.gchan,
        kernel_size=config.gkernel,
        block_size=config.gblocksize,
        dropout=config.gdrop,
    ).cuda()

    generator.load_state_dict(torch.load(config.gcheckpoints))
    # Inference mode: disables dropout (configured via config.gdrop) while testing.
    # Assumes AECNN is a torch.nn.Module — confirm.
    generator.eval()

    # Initialize datasets
    ev_dataset = wav_dataset(config, 'et', 4)

    count = 0
    score = {'stoi': 0, 'estoi': 0, 'sdr': 0}
    # Gradients are not needed for evaluation; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for example in ev_dataset:
            data = np.squeeze(
                generator(example['noisy'].cuda()).cpu().detach().numpy())
            clean = np.squeeze(example['clean'].numpy())
            score['stoi'] += stoi(clean, data, 16000, extended=False)
            score['estoi'] += stoi(clean, data, 16000, extended=True)
            score['sdr'] += si_sdr(data, clean)
            count += 1

    if count == 0:
        print('no examples in evaluation set')
        return

    print('stoi: %f' % (score['stoi'] / count))
    print('estoi: %f' % (score['estoi'] / count))
    print('sdr: %f' % (score['sdr'] / count))
Example #27
0
def cal_STOIi(src_ref, src_est, mix, fs=8000):
    """Calculate the STOI improvement (STOIi) of separated sources over the mixture.

    For each source, STOI of the estimate and STOI of the unprocessed mixture
    are computed against the reference, and all values are averaged over the
    sources.

    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
        fs: sampling rate of all signals in Hz (default 8000).
    Returns:
        (average STOI of the estimates, average STOI of the mixture,
         average STOIi)
    Raises:
        ZeroDivisionError: if there are no sources.
    """
    num = 0
    new_stoi = 0
    orig_stoi = 0
    avg_STOIi = 0
    for ref, est in zip(src_ref, src_est):
        num += 1
        new_stoi_out = stoi(ref, est, fs)
        orig_stoi_out = stoi(ref, mix, fs)
        new_stoi += new_stoi_out
        orig_stoi += orig_stoi_out
        avg_STOIi += new_stoi_out - orig_stoi_out

    return new_stoi / num, orig_stoi / num, avg_STOIi / num
def read_STOI_DRC(clean_root, noise_root, enhanced_file):
    """Compute extended STOI of (enhanced + noise) against the clean signal.

    The clean and noise files have the same base name as ``enhanced_file``
    under ``clean_root`` / ``noise_root``. All three are loaded with librosa
    at the module-level ``fs`` and trimmed to their common length.
    """
    wave_name = enhanced_file.split('/')[-1]
    clean_wav, _ = librosa.load(clean_root + wave_name, sr=fs)
    enhanced_wav, _ = librosa.load(enhanced_file, sr=fs)
    noise_wav, _ = librosa.load(noise_root + wave_name, sr=fs)
    # The noise is added to the enhanced signal below, so it must be part of
    # the common length too; otherwise the addition fails to broadcast.
    min_len = min(len(clean_wav), len(enhanced_wav), len(noise_wav))
    clean_wav = clean_wav[:min_len]
    noise_wav = noise_wav[:min_len]
    enhanced_wav = enhanced_wav[:min_len]

    # NOTE(review): the eSTOI value is doubled here; the intent is not
    # visible in this file — confirm before relying on the scale.
    stoi_score = stoi(clean_wav, enhanced_wav + noise_wav, fs,
                      extended=True) * 2
    return stoi_score
Example #29
0
def eval_STOI(ref, y, fs, extended=False, compute_permutation=True):
    """Calculate STOI

    Reference:
        A short-time objective intelligibility measure
            for time-frequency weighted noisy speech
        https://ieeexplore.ieee.org/document/5495701

    Note(kamo):
        STOI is defined on the signal at 10kHz
        and the input at the other sampling rate will be resampled.
        Thus, the result differs depending on the implementation of resampling.
        Especially, pystoi cannot reproduce matlab's resampling now.

    :param ref (np.ndarray): Reference (Nsrc, Nframe, Nmic)
    :param y (np.ndarray): Enhanced (Nsrc, Nframe, Nmic)
    :param fs (int): Sample frequency
    :param extended (bool): stoi or estoi
    :param compute_permutation (bool): if True, try every source permutation
        and keep the one with the highest total STOI; otherwise use the
        identity pairing.
    :return: value, perm. value[i] is the mic-averaged STOI of reference i
        against estimate perm[i].
    :rtype: Tuple[Tuple[float, ...], Tuple[int, ...]]
    :raises ValueError: if the shapes differ or the inputs are not 3-D.
    """
    if ref.shape != y.shape:
        raise ValueError(
            "ref and y should have the same shape: {} != {}".format(
                ref.shape, y.shape))
    if ref.ndim != 3:
        # str.format, not format_map: format_map rejects the positional '{}'
        # field, so the intended message was never produced.
        raise ValueError("Input must have 3 dims: {}".format(ref.ndim))
    n_src = ref.shape[0]
    n_mic = ref.shape[2]

    if compute_permutation:
        index_list = list(itertools.permutations(range(n_src)))
    else:
        index_list = [list(range(n_src))]

    # values[k][i]: mic-averaged STOI of ref[i] vs y[index_list[k][i]].
    values = [[
        sum(
            stoi(ref[i, :, ch], y[j, :, ch], fs, extended)
            for ch in range(n_mic)) / n_mic for i, j in enumerate(indices)
    ] for indices in index_list]

    # Highest total STOI wins; sorted() is stable, so on ties the last such
    # permutation is chosen.
    value, perm = sorted(zip(values, index_list),
                         key=lambda pair: sum(pair[0]))[-1]
    return tuple(value), tuple(perm)
Example #30
0
def cal_score(clean, enhanced):
    """Return ``(PESQ, STOI)`` of ``enhanced`` vs ``clean`` at 16 kHz, rounded to 5 decimals.

    Both signals are peak-normalised first. An all-zero signal is left
    unscaled, because dividing by a zero peak would turn it into NaNs.
    """
    clean_peak = abs(clean).max()
    if clean_peak > 0:
        clean = clean / clean_peak
    enhanced_peak = abs(enhanced).max()
    if enhanced_peak > 0:
        enhanced = enhanced / enhanced_peak

    s_stoi = stoi(clean, enhanced, 16000)
    s_pesq = pesq(clean, enhanced, 16000)

    return round(s_pesq, 5), round(s_stoi, 5)

    #def get_filepaths(directory,folders='BabyCry.wav,cafeteria_babble.wav',ftype='.wav'):
    #    file_paths = []
    #    folders = folders.split(',')
    #    with open(directory, 'r') as f:
    #        for line in f:
    #            if str(line.split('/')[-3]) in folders:
    #                file_paths.append(line[:-1])
    #    return file_paths
    '''