Example #1
# Imports assumed by these snippets (not shown in the original listing):
import os

import numpy as np
import soundfile as sf
from mir_eval.separation import bss_eval_sources  # assumed source of bss_eval_sources


def cal(path, aim_mix_number):
    # Count distinct mixtures from the "<idx>_..." prefix of the wav filenames.
    mix_number = len(
        set([l.split('_')[0] for l in os.listdir(path) if l[-3:] == 'wav']))
    print('num of mixed:', mix_number)

    SDR_sum = np.array([])
    for idx in range(mix_number):
        pre_speech_channel = []
        aim_speech_channel = []
        mix_speech = []
        # aim_list=[l for l in os.listdir(path) if l[-3:]=='wav' and l.split('_')[0]==str(idx)]
        for l in sorted(os.listdir(path)):
            if l[-3:] != 'wav':
                continue
            if l.split('_')[0] == str(idx):
                # print l
                if 'True_mix' in l:
                    mix_speech.append(sf.read(path + l)[0])
                # if 'genTrue' in l:
                #     aim_speech_channel.append(sf.read(path+l)[0])
                if 'realTrue' in l:
                    aim_speech_channel.append(sf.read(path + l)[0])
                if 'pre' in l:
                    pre_speech_channel.append(sf.read(path + l)[0])

        # assert len(aim_speech_channel)==len(pre_speech_channel)
        # if len(aim_speech_channel)!=aim_mix_number:
        #     continue
        aim_speech_channel = np.array(aim_speech_channel)
        pre_speech_channel = np.array(pre_speech_channel)
        # print aim_speech_channel.shape
        # print pre_speech_channel.shape
        # result=bss_eval_sources(aim_speech_channel,pre_speech_channel,False)
        # permu=bss_eval_sources(aim_speech_channel,pre_speech_channel)[-1]
        # print permu
        # result=bss_eval_sources(aim_speech_channel[[1,0]],pre_speech_channel)
        if add_slience_channel:  # module-level flag defined elsewhere in the original script
            # Pad the reference with near-silent extra channels so that both
            # inputs to bss_eval_sources have the same number of channels.
            num_add_cha = pre_speech_channel.shape[0] - aim_speech_channel.shape[0]
            aim_speech_channel_add = np.concatenate(
                (aim_speech_channel,
                 np.zeros([num_add_cha, pre_speech_channel.shape[1]]) + 1e-5))
            # Find the best permutation against the padded reference, keep only
            # the channels that correspond to real sources, then re-evaluate.
            permu = bss_eval_sources(aim_speech_channel_add,
                                     pre_speech_channel)[-1][:aim_mix_number]
            result = bss_eval_sources(aim_speech_channel,
                                      pre_speech_channel[permu])
        else:
            # If only one channel was estimated for a two-speaker mixture,
            # duplicate it so the channel counts match.
            if pre_speech_channel.shape[0] == 1 and aim_speech_channel.shape[0] == 2:
                pre_speech_channel = pre_speech_channel.repeat(2, 0)
            result = bss_eval_sources(aim_speech_channel, pre_speech_channel)
            # result=bss_eval_sources(aim_speech_channel,aim_speech_channel)

        print(result)
        SDR_sum = np.append(SDR_sum, result[0])
    print('SDR here:', SDR_sum.mean())
    return SDR_sum
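
A minimal usage sketch (not part of the original listing): the path and speaker count are illustrative, the filename convention `<idx>_..._True_mix.wav` / `<idx>_..._realTrue.wav` / `<idx>_..._pre.wav` is assumed, and `path` must end with a separator because it is concatenated directly with each filename.

add_slience_channel = False  # module-level flag read inside cal(); False skips the padding branch
sdr_per_mixture = cal('eval_wavs/', aim_mix_number=2)  # 'eval_wavs/' is a hypothetical results folder
print('overall mean SDR:', sdr_per_mixture.mean())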
Example #2
def cal(path):
    mix_number = len(
        set([l.split('_')[0] for l in os.listdir(path) if l[-3:] == 'wav']))
    print('num of mixed:', mix_number)
    SDR_sum = np.array([])
    for idx in range(mix_number):
        pre_speech_channel = []
        aim_speech_channel = []
        mix_speech = []
        for l in sorted(os.listdir(path)):
            if l[-3:] != 'wav':
                continue
            if l.split('_')[0] == str(idx):
                if 'True_mix' in l:
                    mix_speech.append(sf.read(path + l)[0])
                if 'real' in l and 'noise' not in l:
                    aim_speech_channel.append(sf.read(path + l)[0])
                if 'pre' in l:
                    pre_speech_channel.append(sf.read(path + l)[0])

        assert len(aim_speech_channel) == len(pre_speech_channel)
        aim_speech_channel = np.array(aim_speech_channel)
        pre_speech_channel = np.array(pre_speech_channel)
        # print aim_speech_channel.shape
        # print pre_speech_channel.shape

        result = bss_eval_sources(aim_speech_channel, pre_speech_channel)
        # result=bss_eval_sources(aim_speech_channel,aim_speech_channel)
        print(result)

        SDR_sum = np.append(SDR_sum, result[0])
    print('SDR_Aver for this batch:', SDR_sum.mean())
    return SDR_sum.mean()
Example #3
def cal_SDRi(src_ref, src_est, mix):
    """Calculate Source-to-Distortion Ratio improvement (SDRi).
    NOTE: bss_eval_sources is very slow.
    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SDRi
    """
    src_anchor = np.stack([mix, mix], axis=0)
    sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
    sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
    avg_SDRi = ((sdr[0] - sdr0[0]) + (sdr[1] - sdr0[1])) / 2
    # print("SDRi1: {0:.2f}, SDRi2: {1:.2f}".format(sdr[0]-sdr0[0], sdr[1]-sdr0[1]))
    return avg_SDRi
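
A small self-contained check of this routine (a sketch, not from the original source): the synthetic signals and the 0.1 noise scale are illustrative, and `np` / `bss_eval_sources` are assumed to be imported as in Example #1.

rng = np.random.RandomState(0)
clean = rng.randn(2, 8000)                    # two reference sources, shape [C, T]
estimate = clean + 0.1 * rng.randn(2, 8000)   # imperfect estimate of each source
mixture = clean.sum(axis=0)                   # single-channel mixture, shape [T]
print('average SDRi:', cal_SDRi(clean, estimate, mixture))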
Example #4
def cal_SDRi(src_ref, src_est, mix):
    """Calculate Source-to-Distortion Ratio improvement (SDRi).
    NOTE: bss_eval_sources is very slow.
    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SDR, average_SDRi
    """
    src_anchor = np.stack([mix, mix], axis=0)
    if src_ref.shape[0] == 1:
        # Single-source case: use the raw mixture itself as the anchor.
        src_anchor = src_anchor[0]
    sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
    sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
    # The averaging below assumes two sources (sdr and sdr0 have two entries).
    avg_SDR = (sdr[0] + sdr[1]) / 2
    avg_SDRi = ((sdr[0] - sdr0[0]) + (sdr[1] - sdr0[1])) / 2
    return avg_SDR, avg_SDRi
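
Compared with Example #3, this variant also returns the plain average SDR alongside the improvement, and special-cases a single reference source by using the raw mixture as the anchor; note that the final averaging still indexes two entries of `sdr`, so it effectively assumes two estimated sources.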
Example #5
def cal(path, tmp=None):
    mix_number = len(
        set([l.split('_')[0] for l in os.listdir(path) if l[-3:] == 'wav']))
    print('num of mixed:', mix_number)
    SDR_sum = np.array([])
    SDRi_sum = np.array([])
    for idx in range(mix_number):
        pre_speech_channel = []
        aim_speech_channel = []
        mix_speech = []
        for l in sorted(os.listdir(path)):
            if l[-3:] != 'wav':
                continue
            if l.split('_')[0] == str(idx):
                if 'True_mix' in l:
                    mix_speech.append(sf.read(path + l)[0])
                if 'real' in l and 'noise' not in l:
                    aim_speech_channel.append(sf.read(path + l)[0])
                if 'pre' in l:
                    pre_speech_channel.append(sf.read(path + l)[0])

        assert len(aim_speech_channel) == len(pre_speech_channel)
        aim_speech_channel = np.array(aim_speech_channel)
        pre_speech_channel = np.array(pre_speech_channel)
        mix_speech = np.array(mix_speech)
        assert mix_speech.shape[0] == 1
        mix_speech = mix_speech[0]

        # print aim_speech_channel.shape
        # print pre_speech_channel.shape

        # print('aim SDR:',aim_speech_channel[:,16000:16005])
        # print('pre SDR:',pre_speech_channel[:,16000:16005])
        result = bss_eval_sources(aim_speech_channel, pre_speech_channel)
        print('SDR', result)
        SDR_sum = np.append(SDR_sum, result[0])

        # result=bss_eval_sources(aim_speech_channel,aim_speech_channel)
        # result_sdri=cal_SDRi(aim_speech_channel,pre_speech_channel,mix_speech)
        # print 'SDRi:',result_sdri
        result_sdri = cal_SISNRi(aim_speech_channel,
                                 pre_speech_channel[result[-1]], mix_speech)
        print('SI-SNR', result_sdri)
        # for ii in range(aim_speech_channel.shape[0]):
        #     result=cal_SISNRi(aim_speech_channel[ii],pre_speech_channel[ii],mix_speech[ii])
        #     print('SI-SNR',result)
        SDRi_sum = np.append(SDRi_sum, result_sdri)

    print('SDR_Aver for this batch:', SDR_sum.mean())
    # print 'SDRi_Aver for this batch:',SDRi_sum.mean()
    return SDR_sum.mean(), SDRi_sum.mean()
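
`cal_SISNRi` is not included in this listing; the sketch below shows a common per-channel definition of scale-invariant SNR and its improvement over the mixture, which is presumably what that helper computes over the separated channels (an assumption, not the original implementation).

def si_snr(reference, estimate, eps=1e-8):
    # Zero-mean both signals, project the estimate onto the reference, and
    # compare the projected (target) energy with the residual energy.
    reference = reference - reference.mean()
    estimate = estimate - estimate.mean()
    target = np.dot(estimate, reference) / (np.dot(reference, reference) + eps) * reference
    noise = estimate - target
    return 10 * np.log10((np.sum(target ** 2) + eps) / (np.sum(noise ** 2) + eps))

def si_snr_improvement(reference, estimate, mixture):
    # SI-SNRi: gain of the estimate over simply using the mixture as the estimate.
    return si_snr(reference, estimate) - si_snr(reference, mixture)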
Example #6
def infer_a_sample(sample_i, round):
    target_path = target_list[sample_i]
    mix_flag = random.random()
    if mix_flag < 0.33:
        interf_id = random.randint(0, len(target_list) - 1)
        interf_path = target_list[interf_id]
        interf = target_dict[interf_path]
    elif 0.33 <= mix_flag < 0.66:
        interf_id = random.randint(0, len(interf_audio_list) - 1)
        interf_path = interf_audio_list[interf_id]
        interf = interf_dict[interf_path]
    else:
        interf_id = random.randint(0, len(interf_esc_list) - 1)
        interf_path = interf_esc_list[interf_id]
        interf = interf_dict[interf_path]
    target = target_dict[target_path]
    chooseIndexNormalised = None
    batchLen = opt.seqLen
    saveMixtureName = None
    interf_scale = SNR_db_to_scale(opt.db)
    print('Mixing scale is {}'.format(interf_scale))
    (batchInput1, batchTarget, batchMixphase, batchTargetphase,
     chooseIndex) = ExtractFeatureFromOneSignal_fromMemory(
         target, interf, interf_scale, chooseIndexNormalised, batchLen,
         add_hole, saveMixtureName, opt.dBscale, hparams.sample_rate,
         opt.normalize)

    phase_dim = halfFFT + 1
    L = chooseIndex[-1] + opt.seqLen

    train_model.eval()

    with torch.no_grad():
        batchInput1 = torch.from_numpy(batchInput1)

        if opt.cuda:
            batchInput1 = batchInput1.cuda()

        pred = train_model(batchInput1)

    pred_spec = np.zeros((L, halfFFT))
    gt_spec = np.zeros((L, halfFFT))
    mix_spec = np.zeros((L, halfFFT))
    mix_phase = np.zeros((L, phase_dim))
    target_phase = np.zeros((L, phase_dim))

    try:  # in case the wav is shorter than seqLen frames
        for n, i in enumerate(chooseIndex):
            # the mixture
            mix_spec[i:i + opt.seqLen] = batchInput1[n].data.cpu().numpy()
            # the gt
            gt_spec[i:i + opt.seqLen] = batchTarget[n]
            # the prediction
            pred_spec[i:i + opt.seqLen] = pred[n].data.cpu().numpy()

            mix_phase[i:i + opt.seqLen] = batchMixphase[n]
            target_phase[i:i + opt.seqLen] = batchTargetphase[n]

        # Forgot to add options!!
        mix_wav = postpro_and_gen(mix_spec,
                                  mix_phase,
                                  dBscale=opt.dBscale,
                                  denormalize=opt.normalize)
        target_wav = postpro_and_gen(gt_spec,
                                     target_phase,
                                     dBscale=opt.dBscale,
                                     denormalize=opt.normalize)
        pred_wav = postpro_and_gen(pred_spec,
                                   mix_phase,
                                   dBscale=opt.dBscale,
                                   denormalize=opt.normalize)

        mixtureSaveName = 'r_{}_samp_{}_input.wav'.format(round, sample_i)
        targetSaveName = 'r_{}_samp_{}_target.wav'.format(round, sample_i)
        predSaveName = 'r_{}_samp_{}_pred.wav'.format(round, sample_i)
        figName = 'r_{}_samp_{}_fig.jpg'.format(round, sample_i)

        # save the audios locally
        sf.write(os.path.join(wav_dir, mixtureSaveName), mix_wav,
                 hparams.sample_rate)
        sf.write(os.path.join(wav_dir, targetSaveName), target_wav,
                 hparams.sample_rate)
        sf.write(os.path.join(wav_dir, predSaveName), pred_wav,
                 hparams.sample_rate)
        print('Saving to r_{}_samp_{}'.format(round, sample_i))

        # compute scores
        sdr_this_sample = bss_eval_sources(target_wav, pred_wav)[0]
        pesq_this_sample = pypesq(hparams.sample_rate, target_wav, pred_wav,
                                  'nb')
        stoi_this_sample = stoi(target_wav,
                                pred_wav,
                                hparams.sample_rate,
                                extended=False)

        # plot fig
        plot_and_compare(mix_wav, target_wav, pred_wav,
                         os.path.join(wav_dir, figName), hparams.sample_rate)

        print('SDR, PESQ, STOI this sample is {}, {}, {}.'.format(
            sdr_this_sample, pesq_this_sample, stoi_this_sample))

        if opt.compute_orig:
            sdr_orig_this_sample = bss_eval_sources(target_wav, mix_wav)[0]
            pesq_orig_this_sample = pypesq(hparams.sample_rate, target_wav,
                                           mix_wav, 'nb')
            stoi_orig_this_sample = stoi(target_wav,
                                         mix_wav,
                                         hparams.sample_rate,
                                         extended=False)
            print('SDR, PESQ, STOI this sample is {}, {}, {} orig.'.format(
                sdr_orig_this_sample, pesq_orig_this_sample,
                stoi_orig_this_sample))

    except Exception as e:
        sdr_this_sample = None
        pesq_this_sample = None
        stoi_this_sample = None
        sdr_orig_this_sample = None
        pesq_orig_this_sample = None
        stoi_orig_this_sample = None
        print(e)

    if opt.compute_orig:
        return sdr_this_sample, pesq_this_sample, stoi_this_sample, \
               sdr_orig_this_sample, pesq_orig_this_sample, stoi_orig_this_sample
    else:
        return sdr_this_sample, pesq_this_sample, stoi_this_sample
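
A hedged sketch of how the per-sample scores from this function might be aggregated over a test round; `num_test_samples` is hypothetical, and the skip logic relies on the `None` returns produced by the except branch above.

scores = []
for sample_i in range(num_test_samples):   # num_test_samples is a hypothetical count
    out = infer_a_sample(sample_i, round=0)
    if out[0] is None:                      # sample failed inside the try block
        continue
    sdr, pesq_score, stoi_score = out[:3]
    scores.append((float(np.mean(sdr)), pesq_score, stoi_score))
mean_sdr, mean_pesq, mean_stoi = np.mean(scores, axis=0)
print('mean SDR {:.2f}, PESQ {:.2f}, STOI {:.3f}'.format(mean_sdr, mean_pesq, mean_stoi))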