Example #1
0
def decode_one_wav(sess, model, wavedata):
    """Enhance a single waveform with the model and resynthesize it.

    The magnitude and phase spectra of ``wavedata`` are computed with
    ``spectrum_tool`` (``PARAM.NFFT`` / ``PARAM.OVERLAP``) and fed to the
    model as a batch of one. The estimated magnitude is turned back into a
    waveform according to ``PARAM.RESTORE_PHASE``:

    * ``'MIXED'``       -- combine it with the phase of the input mixture;
    * ``'GRIFFIN_LIM'`` -- ``spectrum_tool.griffin_lim`` (``wavedata`` is
      passed along to it);
    * ``'ESTIMATE'``    -- combine it with the model's own phase estimate.

    An unknown mode, or 'ESTIMATE' with a model whose
    ``y_theta_estimation`` is None, prints an error and exits the process.

    Args:
      sess: session used to run the model graph.
      model: model with ``x_mag`` / ``x_theta`` / ``lengths`` inputs and
        ``y_mag_estimation`` / ``mask`` / ``y_theta_estimation`` outputs.
      wavedata: waveform to enhance.

    Returns:
      (reY, mask): the reconstructed waveform and the model mask with the
      batch dimension removed.
    """
    x_spec_t = spectrum_tool.magnitude_spectrum_librosa_stft(
        wavedata, PARAM.NFFT, PARAM.OVERLAP)
    x_phase_t = spectrum_tool.phase_spectrum_librosa_stft(
        wavedata, PARAM.NFFT, PARAM.OVERLAP)
    feed_dict = {
        model.x_mag: np.array([x_spec_t], dtype=np.float32),
        model.lengths: np.array([np.shape(x_spec_t)[0]], dtype=np.int32),
        model.x_theta: np.array([x_phase_t], dtype=np.float32),
    }

    # Session.run rejects None fetches, so the phase estimate is only
    # requested when the model provides one; the 'ESTIMATE' branch below
    # can then report the problem instead of sess.run crashing first.
    fetches = [model.y_mag_estimation, model.mask]
    if model.y_theta_estimation is not None:
        fetches.append(model.y_theta_estimation)
    results = sess.run(fetches, feed_dict=feed_dict)

    # [0] drops the batch-of-one dimension from every output.
    y_mag_estimation = np.array(results[0][0])
    mask = np.array(results[1][0])
    y_theta_est = np.array(results[2][0]) if len(results) > 2 else None

    if PARAM.RESTORE_PHASE == 'MIXED':
        # Reuse the mixture phase computed above instead of a second STFT.
        reY = spectrum_tool.librosa_istft(
            y_mag_estimation * np.exp(1j * x_phase_t), PARAM.NFFT,
            PARAM.OVERLAP)
    elif PARAM.RESTORE_PHASE == 'GRIFFIN_LIM':
        reY = spectrum_tool.griffin_lim(y_mag_estimation, PARAM.NFFT,
                                        PARAM.OVERLAP, PARAM.GRIFFIN_ITERNUM,
                                        wavedata)
    elif PARAM.RESTORE_PHASE == 'ESTIMATE':
        if y_theta_est is None:
            print('Model cannot estimate y_theta.')
            exit(-1)
        reY = spectrum_tool.librosa_istft(
            y_mag_estimation * np.exp(1j * y_theta_est), PARAM.NFFT,
            PARAM.OVERLAP)
    else:
        print('RESTORE_PHASE error.')
        exit(-1)

    return reY, mask
def decode_one_wav(sess, model, wavedata):
    """Enhance one waveform with a model that outputs a cleaned magnitude spectrum.

    The magnitude spectrum of ``wavedata`` (``PARAM.NFFT`` / ``PARAM.OVERLAP``)
    is fed to ``model.inputs`` as a batch of one. The model's ``cleaned``
    output is resynthesized according to ``PARAM.RESTORE_PHASE``:

    * ``'MIXED'``       -- combine it with the phase of the input mixture;
    * ``'GRIFFIN_LIM'`` -- ``spectrum_tool.griffin_lim`` (``wavedata`` is
      passed along to it).

    Any other mode prints an error and exits the process.

    Args:
      sess: session used to run the model graph.
      model: model with ``inputs`` / ``lengths`` placeholders and a
        ``cleaned`` output.
      wavedata: waveform to enhance.

    Returns:
      The reconstructed waveform.
    """
    x_spec_t = spectrum_tool.magnitude_spectrum_librosa_stft(
        wavedata, PARAM.NFFT, PARAM.OVERLAP)
    feed_dict = {
        model.inputs: np.array([x_spec_t], dtype=np.float32),
        model.lengths: np.array([np.shape(x_spec_t)[0]], dtype=np.int32),
    }
    cleaned = np.array(sess.run(model.cleaned, feed_dict=feed_dict))
    if cleaned.ndim == 3:
        # assumes the output is (batch=1, time, freq) -- TODO confirm; drop
        # the batch axis so it lines up with the 2-D phase spectrum.
        cleaned = cleaned[0]

    if PARAM.RESTORE_PHASE == 'MIXED':
        x_phase_t = spectrum_tool.phase_spectrum_librosa_stft(
            wavedata, PARAM.NFFT, PARAM.OVERLAP)
        reY = spectrum_tool.librosa_istft(cleaned * np.exp(1j * x_phase_t),
                                          PARAM.NFFT, PARAM.OVERLAP)
    elif PARAM.RESTORE_PHASE == 'GRIFFIN_LIM':
        reY = spectrum_tool.griffin_lim(cleaned, PARAM.NFFT, PARAM.OVERLAP,
                                        PARAM.GRIFFIN_ITERNUM, wavedata)
    else:
        # Without this branch reY would be unbound at the return below.
        print('RESTORE_PHASE error.')
        exit(-1)

    return reY
def decode_one_wav(sess, model, wavedata):
    """Enhance one waveform from normalized log-magnitude input features.

    Features come from ``wav_tool._extract_norm_log_mag_spec`` and are fed
    to ``model.inputs`` as a batch of one. The model's ``cleaned`` output is
    passed through ``rmNormalization``, unless FEATURE_TYPE is 'LOG_MAG' and
    MASK_ON_MAG_EVEN_LOGMAG is set. The result is then given to
    ``spectrum_tool.griffin_lim``, and that result is inverse-STFT'd.

    When ``NNET_PARAM.decode_output_speaker_volume_amp`` is set, the waveform
    is peak-normalized to 32767. An all-zero or empty waveform is left
    unchanged rather than being divided by zero.

    Args:
      sess: session used to run the model graph.
      model: model with ``inputs`` / ``lengths`` placeholders and
        ``cleaned`` / ``mask`` outputs.
      wavedata: waveform to enhance.

    Returns:
      (reY, mask): the restored waveform as a numpy array and the model mask
      with its batch dimension removed.
    """
    x_spec_t = wav_tool._extract_norm_log_mag_spec(wavedata)
    x_spec = np.array([x_spec_t], dtype=np.float32)
    lengths = np.array([np.shape(x_spec_t)[0]], dtype=np.int32)
    cleaned, mask = sess.run([model.cleaned, model.mask],
                             feed_dict={
                                 model.inputs: x_spec,
                                 model.lengths: lengths,
                             })

    # [0] drops the batch-of-one dimension.
    y_mag_estimation = np.array(cleaned[0])
    mask = np.array(mask[0])
    if not (MIXED_AISHELL_PARAM.FEATURE_TYPE == 'LOG_MAG'
            and MIXED_AISHELL_PARAM.MASK_ON_MAG_EVEN_LOGMAG):
        y_mag_estimation = np.array(rmNormalization(y_mag_estimation))

    cleaned_spec = spectrum_tool.griffin_lim(y_mag_estimation, wavedata,
                                             MIXED_AISHELL_PARAM.NFFT,
                                             MIXED_AISHELL_PARAM.OVERLAP,
                                             NNET_PARAM.GRIFFIN_ITERNUM)

    # write restore wave
    reY = spectrum_tool.librosa_istft(cleaned_spec, MIXED_AISHELL_PARAM.NFFT,
                                      MIXED_AISHELL_PARAM.OVERLAP)
    if NNET_PARAM.decode_output_speaker_volume_amp:  # norm restored wave
        peak = np.max(np.abs(reY)) if np.size(reY) else 0
        # A silent output would otherwise become all-NaN (0 / 0).
        if peak > 0:
            reY = reY / peak * 32767

    return np.array(reY), mask
def _addnoise_and_decoder_one_batch(i_p, speaker_id, sub_process_speaker_num,
                                    waves_dir, noise_dir, sess, model):
    """Add noise to every wav in ``waves_dir``, enhance each one and save both.

    For each file in ``waves_dir``, a random file from ``noise_dir`` is
    repeated to the speech length and mixed in with
    ``audio_tool._mix_wav_by_randomSNR``. The noisy mixture (x_wav) is saved
    with the first 'wav' in its path replaced by ``addnoise_dir_name`` and
    ``root_dir`` replaced by ``new_root_dir``. The mixture is then enhanced
    by the model (one file per ``sess.run``) and resynthesized with the
    mixture phase. The result (y_wav_est) is saved the same way, using
    ``enhanced_dir_name``. A progress summary is printed at the end.

    Args:
      i_p: zero-based worker index (printed as ``i_p + 1``).
      speaker_id: progress numerator that is printed.
      sub_process_speaker_num: progress denominator that is printed.
      waves_dir: directory of clean wavs.
      noise_dir: directory of noise wavs.
      sess: session used to run the model graph.
      model: model with ``x_mag`` / ``x_theta`` / ``lengths`` inputs and a
        ``y_mag_estimation`` output.

    Raises:
      ValueError: if ``PARAM.RESTORE_PHASE`` is not ``'MIXED'``.
    """
    # Only mixture-phase reconstruction is implemented here. Check before
    # any file is written rather than after the first one.
    if PARAM.RESTORE_PHASE != 'MIXED':
        raise ValueError('Please set PARAM.RESTORE_PHASE=MIXED.')

    s_time = time.time()
    noise_dir_list = [
        os.path.join(noise_dir, _dir) for _dir in os.listdir(noise_dir)
    ]
    n_noise = len(noise_dir_list)
    wave_dir_list = [
        os.path.join(waves_dir, _dir) for _dir in os.listdir(waves_dir)
    ]

    batch_size = 0  # number of files processed by this call
    max_len = 0  # longest spectrogram (in frames) over all processed files
    for wav_dir in wave_dir_list:
        batch_size += 1
        y_wave, sr_y = audio_tool.read_audio(wav_dir)
        if y_wave.ndim != 1:  # aishell has 2 channel wav
            y_wave = y_wave.T[0] + y_wave.T[1]
        noise_id = np.random.randint(n_noise)
        noise_wave, sr_n = audio_tool.read_audio(noise_dir_list[noise_id])
        noise_wave = audio_tool.repeat_to_len(noise_wave, len(y_wave))
        x_wave, _ = audio_tool._mix_wav_by_randomSNR(y_wave, noise_wave)

        assert sr_y == sr_n and sr_y == 16000, 'sr error sr_y:%d, sr_n %d' % (
            sr_y, sr_n)
        x_wav_dir = wav_dir.replace('wav', addnoise_dir_name, 1)
        x_wav_dir = x_wav_dir.replace(root_dir, new_root_dir, 1)
        # exist_ok: several workers may create the same directory at once,
        # which makes a separate exists() check racy.
        os.makedirs(os.path.dirname(x_wav_dir), exist_ok=True)
        audio_tool.write_audio(x_wav_dir, x_wave, sr_y)

        x_spec_t = spectrum_tool.magnitude_spectrum_librosa_stft(
            x_wave,  # [time, 257]
            PARAM.NFFT,
            PARAM.OVERLAP)
        x_phase_t = spectrum_tool.phase_spectrum_librosa_stft(
            x_wave, PARAM.NFFT, PARAM.OVERLAP)
        length = np.shape(x_spec_t)[0]
        max_len = max(max_len, length)

        # enhance (batch of one)
        y_mag_est = sess.run(model.y_mag_estimation,
                             feed_dict={
                                 model.x_mag: np.array([x_spec_t],
                                                       dtype=np.float32),
                                 model.x_theta: np.array([x_phase_t],
                                                         dtype=np.float32),
                                 model.lengths: np.array([length],
                                                         dtype=np.int32),
                             })

        # istft with the mixture phase && save. [0] drops the batch axis so
        # the magnitude matches the 2-D phase spectrum.
        y_spec_est = y_mag_est[0] * np.exp(1j * x_phase_t)
        reY = spectrum_tool.librosa_istft(y_spec_est, PARAM.NFFT,
                                          PARAM.OVERLAP)
        y_wav_dir = wav_dir.replace('wav', enhanced_dir_name, 1)
        y_wav_dir = y_wav_dir.replace(root_dir, new_root_dir, 1)
        os.makedirs(os.path.dirname(y_wav_dir), exist_ok=True)
        audio_tool.write_audio(y_wav_dir, reY, PARAM.FS)

    e_time = time.time()
    print("\n----------------\n"
          "%d workers\n"
          "%s\n"
          "Worker_id %03d, rate of progress: %d/%d\n"
          "time_step_max_len: %d\n"
          "batch_size: %d\n"
          'batch_cost_time: %ds\n' %
          (num_process, time.ctime(), i_p + 1, speaker_id,
           sub_process_speaker_num, max_len, batch_size, e_time - s_time),
          flush=True)
Example #5
0
def addnoise_and_decoder_one_batch(waves_dir, noise_dir, sess, model):
  """Add noise to every wav in `waves_dir`, enhance them as one batch, save both.

  For each file in ``waves_dir``, a random file from ``noise_dir`` is
  repeated to the speech length and mixed in with
  ``audio_tool._mix_wav_by_randomSNR``. The noisy mixture (x_wav) is saved
  with the first 'wav' in its path replaced by ``addnoise_dir_name``.

  All mixtures are zero-padded along time to the longest one and enhanced
  in a single ``sess.run`` call. Each estimate is then trimmed back to its
  own length, resynthesized with the mixture phase, and saved with
  ``enhanced_dir_name`` in place of the first 'wav' (y_wav_est).

  Increments the module-level progress counter ``speaker_n``. An empty
  ``waves_dir`` is reported and skipped.

  Raises:
    ValueError: if ``PARAM.RESTORE_PHASE`` is not ``'MIXED'``.
  """
  global speaker_n
  # Only mixture-phase reconstruction is implemented. Fail before any file
  # is written or the model is run.
  if PARAM.RESTORE_PHASE != 'MIXED':
    raise ValueError('Please set PARAM.RESTORE_PHASE=MIXED.')

  s_time = time.time()
  speaker_n += 1
  print("\n----------------\n","%d/%d"%(speaker_n,all_speaker))
  sys.stdout.flush()
  noise_dir_list = [os.path.join(noise_dir, _dir) for _dir in os.listdir(noise_dir)]
  n_noise = len(noise_dir_list)
  wave_dir_list = [os.path.join(waves_dir, _dir) for _dir in os.listdir(waves_dir)]
  if not wave_dir_list:
    # Nothing to enhance; np.max over an empty length list would raise.
    print('No wav files in %s, skipped.' % waves_dir, flush=True)
    return

  # mix && get input
  x_batch = []  # per file: [time, 257]
  x_theta_batch = []  # per file: [time, 257]
  x_lengths = []  # per file: number of frames
  for wav_dir in wave_dir_list:
    y_wave, sr_y = audio_tool.read_audio(wav_dir)
    if y_wave.ndim != 1:  # aishell has 2 channel wav
      y_wave = y_wave.T[0] + y_wave.T[1]
    noise_id = np.random.randint(n_noise)
    noise_wave, sr_n = audio_tool.read_audio(noise_dir_list[noise_id])
    noise_wave = audio_tool.repeat_to_len(noise_wave, len(y_wave))
    x_wave, _ = audio_tool._mix_wav_by_randomSNR(y_wave, noise_wave)

    assert sr_y == sr_n and sr_y == 16000, 'sr error sr_y:%d, sr_n %d' % (sr_y, sr_n)
    x_wav_dir = wav_dir.replace('wav', addnoise_dir_name, 1)
    os.makedirs(os.path.dirname(x_wav_dir), exist_ok=True)
    audio_tool.write_audio(x_wav_dir, x_wave, sr_y)

    x_spec_t = spectrum_tool.magnitude_spectrum_librosa_stft(x_wave,
                                                             PARAM.NFFT,
                                                             PARAM.OVERLAP)
    x_phase_t = spectrum_tool.phase_spectrum_librosa_stft(x_wave,
                                                          PARAM.NFFT,
                                                          PARAM.OVERLAP)
    x_batch.append(x_spec_t)
    x_theta_batch.append(x_phase_t)
    x_lengths.append(np.shape(x_spec_t)[0])

  max_len = np.max(x_lengths)
  print("time_step_max_len:",max_len)
  sys.stdout.flush()

  # Zero-pad every item along the time axis up to max_len.
  x_batch = np.array(
      [np.pad(x_spec, ((0, max_len - length), (0, 0)), 'constant')
       for x_spec, length in zip(x_batch, x_lengths)],
      dtype=np.float32)
  x_theta_batch = np.array(
      [np.pad(x_theta, ((0, max_len - length), (0, 0)), 'constant')
       for x_theta, length in zip(x_theta_batch, x_lengths)],
      dtype=np.float32)
  x_lengths = np.array(x_lengths, dtype=np.int32)

  # enhance
  y_mag_est_batch = sess.run(
      model.y_mag_estimation,
      feed_dict={
          model.x_mag: x_batch,
          model.x_theta: x_theta_batch,
          model.lengths: x_lengths,
      })

  # istft && save
  print(np.shape(y_mag_est_batch), np.shape(x_theta_batch), np.shape(x_lengths))
  sys.stdout.flush()
  for y_mag_est, x_theta, length, wav_dir in zip(y_mag_est_batch, x_theta_batch,
                                                 x_lengths, wave_dir_list):
    # Cut the padding, then recombine with the mixture phase.
    y_spec_est = y_mag_est[:length, :] * np.exp(1j * x_theta[:length, :])
    reY = spectrum_tool.librosa_istft(y_spec_est, PARAM.NFFT, PARAM.OVERLAP)
    y_wav_dir = wav_dir.replace('wav', enhanced_dir_name, 1)
    os.makedirs(os.path.dirname(y_wav_dir), exist_ok=True)
    audio_tool.write_audio(y_wav_dir, reY, PARAM.FS)

  e_time = time.time()
  print('batch_cost_time: %ds' % (e_time-s_time), flush=True)
Example #6
0
def _eval_istft_batch(mag_batch, theta_batch):
    """Inverse-STFT each ``mag * exp(1j * theta)`` pair of a batch into a waveform list."""
    return [
        spectrum_tool.librosa_istft(mag_t * np.exp(1j * theta_t), PARAM.NFFT,
                                    PARAM.OVERLAP)
        for mag_t, theta_t in zip(mag_batch, theta_batch)
    ]


def _eval_restore_wavs(y_mag_est, x_theta, x_wav, y_theta_est):
    """Resynthesize enhanced waveforms according to ``PARAM.RESTORE_PHASE``.

    'MIXED' uses the mixture phase ``x_theta``. 'GRIFFIN_LIM' runs
    ``spectrum_tool.griffin_lim``, passing each mixture waveform along.
    'ESTIMATE' uses ``y_theta_est``. An unknown mode, or 'ESTIMATE' with
    ``y_theta_est`` None, prints an error and exits the process.
    """
    if PARAM.RESTORE_PHASE == 'MIXED':
        return _eval_istft_batch(y_mag_est, x_theta)
    if PARAM.RESTORE_PHASE == 'GRIFFIN_LIM':
        return [
            spectrum_tool.griffin_lim(y_mag_est_t, PARAM.NFFT, PARAM.OVERLAP,
                                      PARAM.GRIFFIN_ITERNUM, x_wav_t)
            for y_mag_est_t, x_wav_t in zip(y_mag_est, x_wav)
        ]
    if PARAM.RESTORE_PHASE == 'ESTIMATE':
        if y_theta_est is None:
            print('Model cannot estimate y_theta.')
            exit(-1)
        return _eval_istft_batch(y_mag_est, y_theta_est)
    print('RESTORE_PHASE error.')
    exit(-1)


def _eval_clip_wavs(wavs):
    """Clip samples to +/-(2**(PARAM.AUDIO_BITS - 1) - 1); PESQ crashes on overflow."""
    abs_max = (2**(PARAM.AUDIO_BITS - 1) - 1)
    wavs = np.array(wavs)
    wavs = np.where(wavs > abs_max, abs_max, wavs)
    return np.where(wavs < -abs_max, -abs_max, wavs)


def _eval_batch_metric(name, metric_fn, *args):
    """Compute one metric for a batch, print its averages and timing, return the matrix.

    ``metric_fn(*args)`` must return an array whose mean over the last axis
    has three entries. These are printed as the mix-ref, enhance-ref and
    improved values of ``name``.
    """
    time_save = time.time()
    print('  |-Calculating %s...' % name)
    sys.stdout.flush()
    mat_t = metric_fn(*args)
    ans_t = np.mean(mat_t, axis=-1)
    print('      |-Batch average mix-ref     %s :' % name, ans_t[0])
    print('      |-Batch average enhance-ref %s :' % name, ans_t[1])
    print('      |-Batch average improved    %s :' % name, ans_t[2])
    print('      |-Calculate %s cost time:' % name, (time.time() - time_save))
    sys.stdout.flush()
    return mat_t


def get_PESQ_STOI_SDR(test_set_tfrecords_dir, ckpt_dir, set_name):
    '''Evaluate PESQ, STOI and SDR over a test set and return their averages.

    Naming:
      x_mag : magnitude spectrum of mixed audio.
      x_theta : angle of mixed audio's complex spectrum.
      y_xxx : clean (label) audio's xxx.
      y_xxx_est : estimated audio's xxx.

    Batches are pulled from the tf.data iterator until it is exhausted. For
    each batch, the mixture, clean and enhanced waveforms are reconstructed
    (enhanced per ``PARAM.RESTORE_PHASE``) and clipped to the
    ``PARAM.AUDIO_BITS`` sample range. The three metrics are then computed
    with ``audio_tool`` and their batch averages printed.

    Returns:
      dict with keys 'pesq', 'stoi' and 'sdr'. Each value is a list of the
      three averages (mix-ref, enhance-ref, improvement) over all evaluated
      items.

    Raises:
      ValueError: if the iterator yields no batches.
    '''
    sess, model, iter_test = _build_model_use_tfdata(test_set_tfrecords_dir,
                                                     ckpt_dir)
    sess.run(iter_test.initializer)
    all_batch = math.ceil(PARAM.DATASET_SIZES[-1] / PARAM.batch_size)

    # Session.run rejects None fetches, so the phase estimate is requested
    # only when the model provides one. 'ESTIMATE' mode then reports its
    # absence instead of sess.run failing first.
    fetches = [model.x_mag, model.x_theta, model.y_mag, model.y_theta,
               model.y_mag_estimation]
    if model.y_theta_estimation is not None:
        fetches.append(model.y_theta_estimation)

    pesq_mats, stoi_mats, sdr_mats = [], [], []
    i = 0
    while True:
        i += 1
        if i <= all_batch:
            print("-Testing batch %03d/%03d: " % (i, all_batch))
            print('  |-Decoding...')
        time_save = time.time()
        sys.stdout.flush()
        try:
            batch = sess.run(fetches)
        except tf.errors.OutOfRangeError:
            break  # test iterator exhausted
        x_mag, x_theta, y_mag, y_theta, y_mag_est = batch[:5]
        y_theta_est = batch[5] if len(batch) > 5 else None

        x_wav = _eval_istft_batch(x_mag, x_theta)
        y_wav = _eval_istft_batch(y_mag, y_theta)
        y_wav_est = _eval_restore_wavs(y_mag_est, x_theta, x_wav, y_theta_est)

        # Prevent overflow (else PESQ crashed)
        x_wav = _eval_clip_wavs(x_wav)
        y_wav = _eval_clip_wavs(y_wav)
        y_wav_est = _eval_clip_wavs(y_wav_est)

        print('      |-Decode cost time:', (time.time() - time_save))
        pesq_mats.append(
            _eval_batch_metric('PESQ', audio_tool.get_batch_pesq_improvement,
                               x_wav, y_wav, y_wav_est, i, set_name))
        stoi_mats.append(
            _eval_batch_metric('STOI', audio_tool.get_batch_stoi_improvement,
                               x_wav, y_wav, y_wav_est))
        sdr_mats.append(
            _eval_batch_metric('SDR', audio_tool.get_batch_sdr_improvement,
                               x_wav, y_wav, y_wav_est))

    if not pesq_mats:
        raise ValueError('No test batches were evaluated.')
    pesq_ans = np.mean(np.concatenate(pesq_mats, axis=-1), axis=-1)
    stoi_ans = np.mean(np.concatenate(stoi_mats, axis=-1), axis=-1)
    sdr_ans = np.mean(np.concatenate(sdr_mats, axis=-1), axis=-1)
    print('avg_pesq      raw:', pesq_ans[0])
    print('avg_pesq enhanced:', pesq_ans[1])
    print('avg_pesq      imp:', pesq_ans[2])
    print('avg_stoi      raw:', stoi_ans[0])
    print('avg_stoi enhanced:', stoi_ans[1])
    print('avg_stoi      imp:', stoi_ans[2])
    print('avg_sdr      raw:', sdr_ans[0])
    print('avg_sdr enhanced:', sdr_ans[1])
    print('avg_sdr      imp:', sdr_ans[2])
    return {
        'pesq': list(pesq_ans),
        'stoi': list(stoi_ans),
        'sdr': list(sdr_ans)
    }