Example #1
def calc_sdr(in_file, out_speech_dir):

    out_names = [
        os.path.join(out_speech_dir, na)
        for na in sorted(os.listdir(out_speech_dir)) if na.endswith(".wav")
    ]
    sdr_list = []
    print("---------------------------------")
    print("\t", "SDR", "\n")

    (x, fs1) = pp.read_audio(in_file)
    top = np.mean(x ** 2)  # mean power of the input signal

    for f in out_names:
        (y, fs2) = pp.read_audio(f)

        if fs1 != fs2:
            print("Error: output and input files have different sampling rates")
            continue

        bottom = np.mean(y ** 2)  # mean power of the output signal
        sdr_list.append(10 * np.log10(top / bottom))

    avg_sdr = np.mean(sdr_list)
    std_sdr = np.std(sdr_list)
    print("AVG SDR\t", avg_sdr)
    print("ST DEV SDR\t", std_sdr)
    print("---------------------------------")
    return avg_sdr, std_sdr
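
A minimal usage sketch for calc_sdr; the paths below are placeholders matching the data_eval layout used in the later examples:

avg_sdr, std_sdr = calc_sdr("data_eval/test_speech/sample.wav", "data_eval/dnn1_out")
print("SDR: %.2f dB (std %.2f)" % (avg_sdr, std_sdr))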
Example #2
def create_room(source_file, noise_file, dist):

    (clean, fs) = pp.read_audio(source_file)
    (noise, _) = pp.read_audio(noise_file)

    for file in os.listdir(os.path.join("data_eval", "dnn1_in")):
        file_path = os.path.join("data_eval", "dnn1_in", file)
        os.remove(file_path)

    for d in dist:

        mixed, noise_new, clean_new, s2nr = set_microphone_at_distance(
            clean, noise, fs, d)

        # s2nr = 1 / (1 + (1 / float(snr)))

        mixed_name = "mix_%s_%s" % (d, os.path.basename(source_file))
        clean_name = "clean_%s_%s" % (d, os.path.basename(source_file))

        mixed_path = os.path.join("data_eval", "dnn1_in", mixed_name)
        clean_path = os.path.join("data_eval", "dnn1_in", clean_name)

        pp.write_audio(mixed_path, mixed, fs)
        pp.write_audio(clean_path, clean_new, fs)  # write the clean reference alongside the mixture
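
A hedged sketch of driving create_room; the file names are placeholders, and dist carries one source-microphone distance per mixture to generate:

create_room("data_eval/test_speech/sample.wav",
            "data_eval/noise/babble.wav",
            dist=[1, 5, 10])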
Example #3
def calc_stoi(in_file, out_speech_dir):

    out_names = [
        os.path.join(out_speech_dir, na)
        for na in sorted(os.listdir(out_speech_dir)) if na.endswith(".wav")
    ]
    stoi_list = []
    print("---------------------------------")
    print("\t", "STOI", "\n")
    (x, fs1) = pp.read_audio(in_file)

    for f in out_names:
        print(f)
        (y, fs2) = pp.read_audio(f)
        if fs1 != fs2:
            print("Error: output and input files have different sampling rates")
            continue

        m = min(len(x), len(y))
        res = stoi(x[0:m], y[0:m], fs1)
        stoi_list.append(res)
        # print(g, "\t",  res)

    avg_stoi = np.mean(stoi_list)
    std_stoi = np.std(stoi_list)
    print("AVG STOI\t", avg_stoi)
    print("ST DEV STOI\t", std_stoi)
    print("---------------------------------")
    return avg_stoi, std_stoi
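
calc_stoi is called like calc_sdr; a sketch with the same placeholder paths (stoi is assumed to come from the pystoi package):

avg_stoi, std_stoi = calc_stoi("data_eval/test_speech/sample.wav", "data_eval/dnn1_out")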
Example #4
def demo(args):
    """Inference all test data, write out recovered wavs to disk.

    Args:
      workspace: str, path of workspace.
      tr_snr: float, training SNR.
      te_snr: float, testing SNR.
      n_concat: int, number of frames to concatenate; should equal n_concat
          in the training stage.
      iter: int, iteration of model to load.
      visualize: bool, plot enhanced spectrogram for debug.
    """
    print(args)
    workspace = args.workspace
    tr_snr = args.tr_snr
    te_snr = args.te_snr
    n_concat = args.n_concat
    iter = args.iteration

    n_window = cfg.n_window
    n_overlap = cfg.n_overlap
    fs = cfg.sample_rate
    scale = True

    # Load model.
    model_path = os.path.join(workspace, "models", "%ddb" % int(tr_snr), "FullyCNN.h5")
    model = load_model(model_path)

    # Load test data.
    if args.online:
        print('recording....')
        recordfile = 'record.wav'
        my_record(recordfile, 16000, 2)
        print('recording end')
        (data, _) = pp_data.read_audio(recordfile, 16000)
    else:
        testfile = 'data_cache/test_speech/1568253725.587787.wav'
        (data, _) = pp_data.read_audio(testfile, 16000)
    mixed_complx_x = pp_data.calc_sp(data, mode='complex')
    mixed_x, mixed_phase = divide_magphase(mixed_complx_x, power=1)

    # Predict.
    pred = model.predict(mixed_x)
    # Recover enhanced wav.
    pred_sp = pred  # np.exp(pred)
    hop_size = n_window - n_overlap
    ham_win = np.sqrt(np.hanning(n_window))
    stft_reconstructed_clean = merge_magphase(pred_sp, mixed_phase)
    stft_reconstructed_clean = stft_reconstructed_clean.T
    signal_reconstructed_clean = librosa.istft(stft_reconstructed_clean, hop_length=hop_size, window=ham_win)
    signal_reconstructed_clean = signal_reconstructed_clean*32768
    s = signal_reconstructed_clean.astype('int16')

    # Write out enhanced wav.
    # out_path = os.path.join(workspace, "enh_wavs", "test", "%ddb" % int(te_snr), "%s.enh.wav" % na)
    # pp_data.create_folder(os.path.dirname(out_path))
    pp_data.write_audio('1568253725.587787ehs.wav', s, fs)
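
demo() reads its options from an argparse namespace; a minimal sketch of the assumed parser (flag names inferred from the attributes accessed above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--workspace", type=str, required=True)
parser.add_argument("--tr_snr", type=float, default=0)
parser.add_argument("--te_snr", type=float, default=0)
parser.add_argument("--n_concat", type=int, default=7)
parser.add_argument("--iteration", type=int, default=10000)
parser.add_argument("--online", action="store_true")
demo(parser.parse_args())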
Example #5
def plot_fig5(data_type, audio_idx):
    workspace = cfg.workspace
    fs = cfg.sample_rate

    # Read audio.
    audio_path = os.path.join(
        workspace, "mixed_audio/n_events=3/%s.mixed_20db.wav" % audio_idx)
    (audio, _) = pp_data.read_stereo_audio(audio_path, fs)
    event_audio = audio[:, 0]
    noise_audio = audio[:, 1]
    mixed_audio = (event_audio + noise_audio) / 2

    event_audio_1 = np.zeros_like(event_audio)
    event_audio_1[0:int(fs * 2.5)] = event_audio[0:int(fs * 2.5)]
    event_audio_2 = np.zeros_like(event_audio)
    event_audio_2[int(fs * 2.5):int(fs * 5.0)] = event_audio[int(fs * 2.5):int(fs * 5.0)]
    event_audio_3 = np.zeros_like(event_audio)
    event_audio_3[int(fs * 5.0):int(fs * 7.5)] = event_audio[int(fs * 5.0):int(fs * 7.5)]

    sep_dir = "/vol/vssp/msos/qk/workspaces/weak_source_separation/dcase2013_task2/sep_audio/tmp01/n_events=3/fold=0/snr=20"
    sep_paths = glob.glob(os.path.join(sep_dir, "%s*" % audio_idx))

    print([os.path.basename(e) for e in sep_paths])
    # Note: glob does not guarantee ordering; the indices below assume a fixed listing.
    (sep_event_audio_1, _) = pp_data.read_audio(sep_paths[3])
    (sep_event_audio_2, _) = pp_data.read_audio(sep_paths[0])
    (sep_event_audio_3, _) = pp_data.read_audio(sep_paths[1])
    (sep_noise_audio, _) = pp_data.read_audio(sep_paths[2])

    fig, axs = plt.subplots(5, 2, sharex=True)
    axs[0, 0].plot(mixed_audio)
    axs[1, 0].plot(event_audio_1)
    axs[2, 0].plot(event_audio_2)
    axs[3, 0].plot(event_audio_3)
    axs[4, 0].plot(noise_audio)

    axs[1, 1].plot(sep_event_audio_1)
    axs[2, 1].plot(sep_event_audio_2)
    axs[3, 1].plot(sep_event_audio_3)
    axs[4, 1].plot(sep_noise_audio)

    T = len(noise_audio)
    for i1 in range(5):
        for i2 in range(2):
            axs[i1, i2].axis([0, T, -1, 1])
            axs[i1, i2].xaxis.set_ticks([])
            axs[i1, i2].yaxis.set_ticks([])
            axs[i1, i2].set_ylabel("Amplitude")
    plt.show()
Example #6
def plot_fig1():
    workspace = cfg.workspace
    fs = cfg.sample_rate

    # Read audio.
    audio_path = os.path.join(workspace,
                              "mixed_audio/n_events=3/00292.mixed_20db.wav")
    (audio, _) = pp_data.read_audio(audio_path, fs)
    audio = audio / np.max(np.abs(audio))

    # Calculate log Mel.
    x = _calc_feat(audio)
    print(x.shape)

    audio_path = os.path.join(workspace,
                              "mixed_audio/n_events=3/00292.mixed_100db.wav")
    (audio_clean, _) = pp_data.read_audio(audio_path, fs)
    x_clean = _calc_feat(audio_clean)

    # Plot.
    fig, axs = plt.subplots(3, 1, sharex=False)
    axs[0].plot(audio)
    axs[0].axis([0, len(audio), -1, 1])
    axs[0].xaxis.set_ticks([])
    axs[0].set_ylabel("Amplitude")
    axs[0].set_title("Waveform")

    axs[1].matshow(x.T, origin='lower', aspect='auto', cmap='jet')
    axs[1].xaxis.set_ticks([])
    axs[1].set_ylabel('Mel freq. bin')
    axs[1].set_title("Log Mel spectrogram")

    tmp = (np.sign(x_clean - (-7.)) + 1) / 2.  # binary mask: 1 where log Mel energy exceeds -7
    axs[2].matshow(tmp.T, origin='lower', aspect='auto', cmap='jet')
    axs[2].xaxis.set_ticks([0, 60, 120, 180, 239])
    axs[2].xaxis.tick_bottom()
    axs[2].xaxis.set_ticklabels(np.arange(0, 10.1, 2.5))
    axs[2].set_xlabel("second")
    axs[2].xaxis.set_label_coords(1.1, -0.05)

    axs[2].yaxis.set_ticks([0, 16, 32, 48, 63])
    axs[2].yaxis.set_ticklabels([0, 16, 32, 48, 63])
    axs[2].set_ylabel('Mel freq. bin')

    axs[2].set_title("T-F segmentation mask")

    plt.tight_layout()
    plt.show()
Example #7
def main(from_dir, to_dir, sr):
    from_paths = wav_paths(from_dir)
    for from_p in tqdm(from_paths, 'Resampling audio'):
        rel_p = PurePath(from_p).relative_to(from_dir)
        to_p = to_dir / rel_p  # to_dir is expected to be a pathlib.Path
        os.makedirs(to_p.parent, exist_ok=True)

        wav, _ = read_audio(from_p, sr)
        write_audio(to_p, wav, sr)
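
A sketch of calling main(); from_dir and to_dir are assumed to be pathlib.Path objects, matching the to_dir / rel_p arithmetic above:

from pathlib import Path

main(Path("wavs/original"), Path("wavs/16k"), sr=16000)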
Example #8
def predict_file(file_path, model, scaler):

    (a, _) = pp.read_audio(file_path)
    mixed_complex = pp.calc_sp(a, 'complex')

    mixed_x = np.abs(mixed_complex)

    # Process data.
    n_pad = (conf1.n_concat - 1) // 2  # integer padding width
    mixed_x = pp.pad_with_border(mixed_x, n_pad)
    mixed_x = pp.log_sp(mixed_x)
    # speech_x = dnn1_train.log_sp(speech_x)


    # Scale data.
    # if scale:
    mixed_x = pp.scale_on_2d(mixed_x, scaler)
    # speech_x = pp.scale_on_2d(speech_x, scaler)

    # Cut input spectrogram to 3D segments with n_concat.
    mixed_x_3d = pp.mat_2d_to_3d(mixed_x, agg_num=conf1.n_concat, hop=1)

    # Predict.
    pred = model.predict(mixed_x_3d)

    if visualize_plot:
        visualize(mixed_x, pred)
    # Inverse scale.
    # if scale:
    mixed_x = pp.inverse_scale_on_2d(mixed_x, scaler)
    # speech_x = dnn1_train.inverse_scale_on_2d(speech_x, scaler)
    pred = pp.inverse_scale_on_2d(pred, scaler)


    # Debug plot.

    # Recover enhanced wav.
    pred_sp = np.exp(pred)
    s = recover_wav(pred_sp, mixed_complex, conf1.n_overlap, np.hamming)
    s *= np.sqrt((np.hamming(conf1.n_window) ** 2).sum())  # scale to compensate the
    # amplitude change introduced by the spectrogram and IFFT.

    # Write out enhanced wav.

    # audio_path = os.path.dirname(file_path)
    # pp.write_audio(audio_path, s, conf1.sample_rate)

    return mixed_complex, pred, s
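
A hedged sketch of enhancing a single file; load_dnn is the loader used in Example #15, and the paths are placeholders:

model1, scaler1 = dnn1.load_dnn()
mixed_complex, pred, s = predict_file("data_eval/dnn1_in/mix_sample.wav", model1, scaler1)
pp.write_audio("data_eval/dnn1_out/enh_sample.wav", s, conf1.sample_rate)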
Example #9
def plot_fig4(data_type, audio_idx):
    workspace = cfg.workspace
    n_window = cfg.n_window
    n_overlap = cfg.n_overlap
    fs = cfg.sample_rate
    events = cfg.events
    te_fold = cfg.te_fold

    # Read audio.
    audio_path = os.path.join(
        workspace, "mixed_audio/n_events=3/%s.mixed_20db.wav" % audio_idx)
    (audio, _) = pp_data.read_audio(audio_path, fs)

    # Calculate log Mel.
    x = _calc_feat(audio)
    sp = _calc_spectrogram(audio)
    print(x.shape)

    # Plot.
    fig, axs = plt.subplots(4, 4, sharex=False)

    # Plot log Mel spectrogram.
    for i2 in range(16):
        axs[i2 // 4, i2 % 4].set_visible(False)

    axs[0, 0].matshow(x.T, origin='lower', aspect='auto', cmap='jet')
    axs[0, 0].xaxis.set_ticks([0, 60, 120, 180, 239])
    axs[0, 0].xaxis.tick_bottom()
    axs[0, 0].xaxis.set_ticklabels(np.arange(0, 10.1, 2.5))
    axs[0, 0].set_xlabel("time (s)")
    # axs[0,0].xaxis.set_label_coords(1.12, -0.05)

    axs[0, 0].yaxis.set_ticks([0, 16, 32, 48, 63])
    axs[0, 0].yaxis.set_ticklabels([0, 16, 32, 48, 63])
    axs[0, 0].set_ylabel('Mel freq. bin')

    axs[0, 0].set_title("Log Mel spectrogram")
    axs[0, 0].set_visible(True)

    # Plot spectrogram.
    axs[0, 2].matshow(np.log(sp.T + 1.),
                      origin='lower',
                      aspect='auto',
                      cmap='jet')
    axs[0, 2].xaxis.set_ticks([0, 60, 120, 180, 239])
    axs[0, 2].xaxis.tick_bottom()
    axs[0, 2].xaxis.set_ticklabels(np.arange(0, 10.1, 2.5))
    axs[0, 2].set_xlabel("time (s)")
    # axs[0,2].xaxis.set_label_coords(1.12, -0.05)

    axs[0, 2].yaxis.set_ticks([0, 128, 256, 384, 512])
    axs[0, 2].yaxis.set_ticklabels([0, 128, 256, 384, 512])
    axs[0, 2].set_ylabel('FFT freq. bin')

    axs[0, 2].set_title("Spectrogram")
    axs[0, 2].set_visible(True)

    # plt.tight_layout()
    plt.show()

    # Load data.
    snr = 20
    n_events = 3
    feature_dir = os.path.join(workspace, "features", "logmel",
                               "n_events=%d" % n_events)
    yaml_dir = os.path.join(workspace, "mixed_audio", "n_events=%d" % n_events)
    (tr_x, tr_at_y, tr_sed_y, tr_na_list, te_x, te_at_y, te_sed_y,
     te_na_list) = pp_data.load_data(feature_dir=feature_dir,
                                     yaml_dir=yaml_dir,
                                     te_fold=te_fold,
                                     snr=snr,
                                     is_scale=is_scale)

    if data_type == "train":
        x = tr_x
        at_y = tr_at_y
        sed_y = tr_sed_y
        na_list = tr_na_list
    elif data_type == "test":
        x = te_x
        at_y = te_at_y
        sed_y = te_sed_y
        na_list = te_na_list

    for (i1, na) in enumerate(na_list):
        if audio_idx in na:
            idx = i1
    print(idx)

    # GT mask
    (stereo_audio, _) = pp_data.read_stereo_audio(audio_path, target_fs=fs)
    event_audio = stereo_audio[:, 0]
    noise_audio = stereo_audio[:, 1]
    mixed_audio = event_audio + noise_audio

    ham_win = np.hamming(n_window)
    mixed_cmplx_sp = pp_data.calc_sp(mixed_audio, fs, ham_win, n_window,
                                     n_overlap)
    mixed_sp = np.abs(mixed_cmplx_sp)
    event_sp = np.abs(
        pp_data.calc_sp(event_audio, fs, ham_win, n_window, n_overlap))
    noise_sp = np.abs(
        pp_data.calc_sp(noise_audio, fs, ham_win, n_window, n_overlap))

    db = -5.
    # Ideal binary mask: 1 where the local event-to-noise ratio exceeds db.
    gt_mask = (np.sign(20 * np.log10(event_sp / noise_sp) - db) + 1.) / 2.  # (n_time, n_freq)
    fig, axs = plt.subplots(4, 4, sharex=True)
    for i2 in range(16):
        ind_gt_mask = gt_mask * sed_y[idx, :, i2][:, None]
        axs[i2 // 4, i2 % 4].matshow(ind_gt_mask.T,
                                     origin='lower',
                                     aspect='auto',
                                     cmap='jet')
        # axs[i2//4, i2%4].set_title(events[i2])
        axs[i2 // 4, i2 % 4].xaxis.set_ticks([])
        axs[i2 // 4, i2 % 4].yaxis.set_ticks([])
        axs[i2 // 4, i2 % 4].set_xlabel('time')
        axs[i2 // 4, i2 % 4].set_ylabel('FFT freq. bin')
    plt.show()

    for filename in ["tmp01", "tmp02", "tmp03"]:
        # Plot up sampled seg masks.
        preds_dir = os.path.join(workspace, "preds", filename,
                                 "n_events=%d" % n_events, "fold=%d" % te_fold,
                                 "snr=%d" % snr)

        at_probs_list, seg_masks_list = [], []
        bgn_iter, fin_iter, interval = 2000, 3001, 200
        for iter in range(bgn_iter, fin_iter, interval):
            seg_masks_path = os.path.join(preds_dir, "md%d_iters" % iter,
                                          "seg_masks.p")
            seg_masks = cPickle.load(open(seg_masks_path, 'rb'))
            seg_masks_list.append(seg_masks)
        seg_masks = np.mean(seg_masks_list,
                            axis=0)  # (n_clips, n_classes, n_time, n_freq)

        print(at_y[idx])

        melW = librosa.filters.mel(sr=fs,
                                   n_fft=cfg.n_window,
                                   n_mels=64,
                                   fmin=0.,
                                   fmax=fs / 2)
        inverse_melW = get_inverse_W(melW)

        spec_masks = np.dot(seg_masks[idx],
                            inverse_melW)  # (n_classes, n_time, 513)

        fig, axs = plt.subplots(4, 4, sharex=True)
        for i2 in range(16):
            axs[i2 // 4, i2 % 4].matshow(spec_masks[i2].T,
                                         origin='lower',
                                         aspect='auto',
                                         vmin=0,
                                         vmax=1,
                                         cmap='jet')
            # axs[i2//4, i2%4].set_title(events[i2])
            axs[i2 // 4, i2 % 4].xaxis.set_ticks([])
            axs[i2 // 4, i2 % 4].yaxis.set_ticks([])
            axs[i2 // 4, i2 % 4].set_xlabel('time')
            axs[i2 // 4, i2 % 4].set_ylabel('FFT freq. bin')
        fig.suptitle(filename)
        plt.show()

        # Plot SED probs.
        sed_probs = np.mean(seg_masks[idx], axis=-1)  # (n_classes, n_time)
        fig, axs = plt.subplots(4, 4, sharex=False)
        for i2 in range(16):
            axs[i2 // 4, i2 % 4].set_visible(False)
        axs[0, 0].matshow(sed_probs,
                          origin='lower',
                          aspect='auto',
                          vmin=0,
                          vmax=1,
                          cmap='jet')
        # axs[0, 0].xaxis.set_ticks([0, 60, 120, 180, 239])
        # axs[0, 0].xaxis.tick_bottom()
        # axs[0, 0].xaxis.set_ticklabels(np.arange(0, 10.1, 2.5))
        axs[0, 0].xaxis.set_ticks([])
        # axs[0, 0].set_xlabel('time (s)')
        axs[0, 0].yaxis.set_ticks(range(len(events)))
        axs[0, 0].yaxis.set_ticklabels(events)
        for tick in axs[0, 0].yaxis.get_major_ticks():
            tick.label.set_fontsize(8)
        axs[0, 0].set_visible(True)

        axs[1, 0].matshow(sed_y[idx].T,
                          origin='lower',
                          aspect='auto',
                          vmin=0,
                          vmax=1,
                          cmap='jet')
        # axs[1, 0].xaxis.set_ticks([])
        axs[1, 0].xaxis.set_ticks([0, 60, 120, 180, 239])
        axs[1, 0].xaxis.tick_bottom()
        axs[1, 0].xaxis.set_ticklabels(np.arange(0, 10.1, 2.5))
        axs[1, 0].set_xlabel('time (s)')
        axs[1, 0].yaxis.set_ticks(range(len(events)))
        axs[1, 0].yaxis.set_ticklabels(events)
        for tick in axs[1, 0].yaxis.get_major_ticks():
            tick.label.set_fontsize(8)
        axs[1, 0].set_visible(True)
        fig.suptitle(filename)
        plt.show()
Example #10
def prepare_database():

    (noise, _) = pp.read_audio(conf1.noise_path)

    with open('dnn1/dnn1_files_list.txt') as f:
        dnn1_data = f.readlines()

    # generate train spectrograms
    mixed_all = []
    clean_all = []

    snr1_list = []
    mixed_avg = []

    for n in range(conf1.training_number):
        current_file = (random.choice(dnn1_data)).rstrip()
        dist = random.uniform(1, 20)
        (clean, _) = pp.read_audio(current_file)

        mixed, noise_new, clean_new, snr = set_microphone_at_distance(
            clean, noise, conf1.fs, dist)

        snr1_list.append(snr)
        mixed_avg.append(np.mean(mixed))

        if n % 10 == 0:
            print(n)

        if conf1.save_single_files and n < conf1.n_files_to_save:

            sr = ''.join(
                random.choice(string.ascii_uppercase + string.digits)
                for _ in range(5))

            path_list = current_file.split(os.sep)
            mixed_name = "mix_%s_%s_%s" % (path_list[2], sr,
                                           os.path.basename(current_file))
            clean_name = "clean_%s_%s_%s" % (path_list[2], sr,
                                             os.path.basename(current_file))

            mixed_path = os.path.join(conf1.train_folder, mixed_name)
            clean_path = os.path.join(conf1.train_folder, clean_name)

            pp.write_audio(mixed_path, mixed, conf1.fs)
            pp.write_audio(clean_path, clean_new, conf1.fs)

        clean_spec = pp.calc_sp(clean_new, mode='magnitude')
        mixed_spec = pp.calc_sp(mixed, mode='complex')

        clean_all.append(clean_spec)
        mixed_all.append(mixed_spec)

    print(len(clean_all), ',', len(mixed_all))
    num_tr = pp.pack_features(mixed_all, clean_all, 'train')

    compute_scaler('train')

    # generate test spectrograms
    mixed_all = []
    clean_all = []

    snr1_list = []
    mixed_avg = []

    for n in range(conf1.test_number):
        current_file = (random.choice(dnn1_data)).rstrip()
        dist = random.uniform(1, 20)
        (clean, _) = pp.read_audio(current_file)

        mixed, noise_new, clean_new, snr = set_microphone_at_distance(
            clean, noise, conf1.fs, dist)

        snr1_list.append(snr)
        mixed_avg.append(np.mean(mixed))

        if n % 10 == 0:
            print(n)

        if conf1.save_single_files and n < conf1.n_files_to_save:

            sr = ''.join(
                random.choice(string.ascii_uppercase + string.digits)
                for _ in range(5))

            path_list = current_file.split(os.sep)
            mixed_name = "mix_%s_%s_%s" % (path_list[2], sr,
                                           os.path.basename(current_file))
            clean_name = "clean_%s_%s_%s" % (path_list[2], sr,
                                             os.path.basename(current_file))

            mixed_path = os.path.join(conf1.test_folder, mixed_name)
            clean_path = os.path.join(conf1.test_folder, clean_name)

            pp.write_audio(mixed_path, mixed, conf1.fs)
            pp.write_audio(clean_path, clean_new, conf1.fs)

        clean_spec = pp.calc_sp(clean_new, mode='magnitude')
        mixed_spec = pp.calc_sp(mixed, mode='complex')

        clean_all.append(clean_spec)
        mixed_all.append(mixed_spec)

    print(len(clean_all), ',', len(mixed_all))

    num_te = pp.pack_features(mixed_all, clean_all, 'test')

    compute_scaler('test')

    return num_tr, num_te
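
prepare_database() takes no arguments and reads everything from conf1; a sketch:

num_tr, num_te = prepare_database()
print("packed %d training and %d test spectrogram sets" % (num_tr, num_te))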
Example #11
def evaluate_separation(args):
    workspace = cfg.workspace
    events = cfg.events
    te_fold = cfg.te_fold
    n_window = cfg.n_window
    n_overlap = cfg.n_overlap
    fs = cfg.sample_rate
    clip_duration = cfg.clip_duration
    n_events = args.n_events
    snr = args.snr

    # Load ground truth data.
    feature_dir = os.path.join(workspace, "features", "logmel",
                               "n_events=%d" % n_events)
    yaml_dir = os.path.join(workspace, "mixed_audio", "n_events=%d" % n_events)
    (tr_x, tr_at_y, tr_sed_y, tr_na_list, te_x, te_at_y, te_sed_y,
     te_na_list) = pp_data.load_data(feature_dir=feature_dir,
                                     yaml_dir=yaml_dir,
                                     te_fold=te_fold,
                                     snr=snr,
                                     is_scale=is_scale)

    at_y = te_at_y
    sed_y = te_sed_y
    na_list = te_na_list

    audio_dir = os.path.join(workspace, "mixed_audio",
                             "n_events=%d" % n_events)

    sep_dir = os.path.join(workspace, "sep_audio",
                           pp_data.get_filename(__file__),
                           "n_events=%d" % n_events, "fold=%d" % te_fold,
                           "snr=%d" % snr)

    sep_stats = {}
    for e in events:
        sep_stats[e] = {'sdr': [], 'sir': [], 'sar': []}

    cnt = 0
    for (i1, na) in enumerate(na_list):
        bare_na = os.path.splitext(na)[0]
        gt_audio_path = os.path.join(audio_dir, "%s.wav" % bare_na)
        (stereo_audio, _) = pp_data.read_stereo_audio(gt_audio_path,
                                                      target_fs=fs)
        gt_event_audio = stereo_audio[:, 0]
        gt_noise_audio = stereo_audio[:, 1]

        print(na)
        for j1 in range(len(events)):
            if at_y[i1][j1] == 1:
                sep_event_audio_path = os.path.join(
                    sep_dir, "%s.%s.wav" % (bare_na, events[j1]))
                (sep_event_audio, _) = pp_data.read_audio(sep_event_audio_path,
                                                          target_fs=fs)
                sep_noise_audio_path = os.path.join(sep_dir,
                                                    "%s.noise.wav" % bare_na)
                (sep_noise_audio, _) = pp_data.read_audio(sep_noise_audio_path,
                                                          target_fs=fs)
                ref_array = np.array((gt_event_audio, gt_noise_audio))
                est_array = np.array((sep_event_audio, sep_noise_audio))
                (sdr, sir, sar) = sdr_sir_sar(ref_array,
                                              est_array,
                                              sed_y[i1, :, j1],
                                              inside_only=True)
                print(sdr, sir, sar)
                sep_stats[events[j1]]['sdr'].append(sdr)
                sep_stats[events[j1]]['sir'].append(sir)
                sep_stats[events[j1]]['sar'].append(sar)

        cnt += 1
        # if cnt == 5: break

    print(sep_stats)
    sep_stat_path = os.path.join(workspace, "sep_stats",
                                 pp_data.get_filename(__file__),
                                 "n_events=%d" % n_events, "fold=%d" % te_fold,
                                 "snr=%d" % snr, "sep_stat.p")
    pp_data.create_folder(os.path.dirname(sep_stat_path))
    cPickle.dump(sep_stats, open(sep_stat_path, 'wb'))
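
evaluate_separation() reads n_events and snr from an argparse namespace; a minimal sketch of the assumed flags:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--n_events", type=int, default=3)
parser.add_argument("--snr", type=int, default=20)
evaluate_separation(parser.parse_args())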
Example #12
def inference_wiener(args):
    workspace = args.workspace
    iter = args.iteration
    stack_num = args.stack_num
    filename = args.filename
    mini_num = args.mini_num
    visualize = args.visualize
    cuda = args.use_cuda and torch.cuda.is_available()
    print("cuda:", cuda)

    sample_rate = cfg.sample_rate
    fft_size = cfg.fft_size
    hop_size = cfg.hop_size
    window_type = cfg.window_type

    if window_type == 'hamming':
        window = np.hamming(fft_size)
    else:
        raise ValueError("Unsupported window type: %s" % window_type)

    # Audio
    audio_dir = "/vol/vssp/msos/qk/workspaces/speech_enhancement/mixed_audios/spectrogram/test/0db"
    # audio_dir = "/user/HS229/qk00006/my_code2015.5-/python/pub_speech_enhancement/mixture2clean_dnn/workspace/mixed_audios/spectrogram/test/0db"
    names = os.listdir(audio_dir)

    # Load model.
    target_type = ['speech', 'noise']
    model_dict = {}
    for e in target_type:
        n_freq = 257
        model = DNN(stack_num, n_freq)
        model_path = os.path.join(workspace, "models", filename, e,
                                  "md_%d_iters.tar" % iter)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['state_dict'])

        # Move model to GPU.
        if cuda:
            model.cuda()
        model.eval()

        model_dict[e] = model

    # Load scalar
    scalar_path = os.path.join(workspace, "scalars", filename, "scalar.p")
    (mean_, std_) = cPickle.load(open(scalar_path, 'rb'))
    mean_ = move_data_to_gpu(mean_, cuda, volatile=True)
    std_ = move_data_to_gpu(std_, cuda, volatile=True)

    if mini_num > 0:
        n_every = len(names) // mini_num
    else:
        n_every = 1

    out_wav_dir = os.path.join(workspace, "enh_wavs", filename)
    pp_data.create_folder(out_wav_dir)

    for (cnt, name) in enumerate(names):
        if cnt % n_every == 0:
            audio_path = os.path.join(audio_dir, name)
            (audio, _) = pp_data.read_audio(audio_path, sample_rate)

            audio = pp_data.normalize(audio)
            cmplx_sp = pp_data.calc_sp(audio, fft_size, hop_size, window)
            x = np.abs(cmplx_sp)

            # Process data.
            n_pad = (stack_num - 1) // 2
            x = pp_data.pad_with_border(x, n_pad)
            x = pp_data.mat_2d_to_3d(x, stack_num, hop=1)

            # Predict.
            pred_dict = {}
            for e in target_type:
                pred = forward(model_dict[e], x, mean_, std_, cuda)
                pred = pred.data.cpu().numpy()
                pred_dict[e] = pred
            print(cnt, name)

            # Wiener filter.
            pred_mag_sp = pred_dict['speech'] / (
                pred_dict['speech'] + pred_dict['noise']) * np.abs(cmplx_sp)

            pred_cmplx_sp = stft.real_to_complex(pred_mag_sp, cmplx_sp)
            frames = stft.istft(pred_cmplx_sp)

            cola_constant = stft.get_cola_constant(hop_size, window)
            seq = stft.overlap_add(frames, hop_size, cola_constant)
            seq = seq[0:len(audio)]

            # Write out wav
            out_wav_path = os.path.join(out_wav_dir, name)
            pp_data.write_audio(out_wav_path, seq, sample_rate)
            print("Write out wav to: %s" % out_wav_path)

            if visualize:
                vmin = -5.
                vmax = 5.
                fig, axs = plt.subplots(3, 1, sharex=True)
                axs[0].matshow(np.log(np.abs(cmplx_sp)).T,
                               origin='lower',
                               aspect='auto',
                               cmap='jet')
                axs[1].matshow(np.log(np.abs(pred_dict['speech'])).T,
                               origin='lower',
                               aspect='auto',
                               cmap='jet')
                axs[2].matshow(np.log(np.abs(pred_dict['noise'])).T,
                               origin='lower',
                               aspect='auto',
                               cmap='jet')
                plt.show()
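
The Wiener step above forms the gain speech / (speech + noise) from the two magnitude estimates and applies it to the mixture magnitude; a self-contained numpy illustration on toy values:

import numpy as np

speech_est = np.array([[4.0, 1.0], [2.0, 2.0]])  # predicted speech magnitudes
noise_est = np.array([[1.0, 3.0], [2.0, 6.0]])   # predicted noise magnitudes
mix_mag = np.array([[5.0, 4.0], [4.0, 8.0]])     # magnitude of the mixture STFT

mask = speech_est / (speech_est + noise_est)     # Wiener-like gain in [0, 1]
enhanced_mag = mask * mix_mag
print(mask)          # [[0.8  0.25], [0.5  0.25]]
print(enhanced_mag)  # [[4. 1.], [2. 2.]]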
Example #13
def inference(args):
    workspace = args.workspace
    model_name = args.model_name
    stack_num = args.stack_num
    filename = args.filename
    mini_num = args.mini_num
    visualize = args.visualize
    cuda = args.use_cuda and torch.cuda.is_available()
    print("cuda:", cuda)

    sample_rate = cfg.sample_rate
    fft_size = cfg.fft_size
    hop_size = cfg.hop_size
    window_type = cfg.window_type

    if window_type == 'hamming':
        window = np.hamming(fft_size)
    else:
        raise ValueError("Unsupported window type: %s" % window_type)

    # Audio
    audio_dir = "/vol/vssp/msos/qk/workspaces/speech_enhancement/mixed_audios/spectrogram/test/0db"
    names = os.listdir(audio_dir)

    # Load model
    model_path = os.path.join(workspace, "models", filename, model_name)
    n_freq = 257
    model = DNN(stack_num, n_freq)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])

    if cuda:
        model.cuda()

    # Load scalar
    scalar_path = os.path.join(workspace, "scalars", filename, "scalar.p")
    (mean_, std_) = cPickle.load(open(scalar_path, 'rb'))
    mean_ = move_data_to_gpu(mean_, cuda, volatile=True)
    std_ = move_data_to_gpu(std_, cuda, volatile=True)

    if mini_num > 0:
        n_every = len(names) // mini_num
    else:
        n_every = 1

    for (cnt, name) in enumerate(names):
        if cnt % n_every == 0:
            audio_path = os.path.join(audio_dir, name)
            (audio, _) = pp_data.read_audio(audio_path, sample_rate)

            audio = pp_data.normalize(audio)
            sp = pp_data.calc_sp(audio, fft_size, hop_size, window)
            x = np.abs(sp)

            # Process data.
            n_pad = (stack_num - 1) // 2
            x = pp_data.pad_with_border(x, n_pad)
            x = pp_data.mat_2d_to_3d(x, stack_num, hop=1)

            output = forward(model, x, mean_, std_, cuda)
            output = output.data.cpu().numpy()

            print(output.shape)
            if visualize:
                fig, axs = plt.subplots(2, 1, sharex=True)
                axs[0].matshow(np.log(np.abs(sp)).T,
                               origin='lower',
                               aspect='auto',
                               cmap='jet')
                axs[1].matshow(np.log(np.abs(output)).T,
                               origin='lower',
                               aspect='auto',
                               cmap='jet')
                plt.show()

            # Debug stop left in the original (it forced a crash via "import crash");
            # stop after the first processed clip instead.
            break
Example #14
def dab_run(snr_list, file_name="dab_out", mode='dab'):

    output_file_folder = os.path.join("data_eval", mode)

    # removing previous enhancements
    for file in os.listdir(os.path.join("data_eval", "dnn1_out")):
        file_path = os.path.join("data_eval", "dnn1_out", file)
        os.remove(file_path)

    dnn1_inputs, dnn1_outputs = dnn1.predict_folder(
        os.path.join("data_eval", "dnn1_in"),
        os.path.join("data_eval", "dnn1_out"))

    names = [
        f for f in sorted(os.listdir(os.path.join("data_eval", "dnn1_out")))
        if f.startswith("enh")
    ]
    dnn1_outputs = []
    for (cnt, na) in enumerate(names):
        # Load feature.
        file_path = os.path.join("data_eval", "dnn1_out", na)
        (a, _) = pp.read_audio(file_path)
        enh_complex = pp.calc_sp(a, 'complex')
        dnn1_outputs.append(enh_complex)

    # s2nrs = dnn2.predict("data_eval/dnn1_in", "data_eval/dnn1_out")

    # snr = np.array([5.62, 1.405, 0.703, 0.281])
    # snr = np.array([5.62, 2.81, 1.875, 1.406])
    # Map each SNR to a speech-to-total ratio: s2nr = snr / (snr + 1).
    s2nrs = snr_list * 1  # shallow copy
    for i in range(len(snr_list)):
        s2nrs[i] = 1 / (1 + 1 / snr_list[i])

    ch_rw_outputs = []
    # calculate channel weights
    if mode == 'dab':
        new_weights = channel_weights(s2nrs)
        print(new_weights)
        # multiply enhanced audio for the corresponding weight
        for i, p in zip(dnn1_outputs, new_weights):
            ch_rw_outputs.append(p * i)

    # cancel reweighting if db mode
    if mode == 'db':
        new_weights = s2nrs
        print(new_weights)
        ch_rw_outputs = dnn1_outputs

    # execute mvdr
    final = mvdr(dnn1_inputs, ch_rw_outputs)

    (init,
     _) = pp.read_audio(os.path.join('data_eval', 'test_speech', file_name))
    init_sp = pp.calc_sp(init, mode='complex')

    visualize(dnn1_colors(np.abs(init_sp)), dnn1_colors(np.abs(final)),
              "source amplitude", "final amplitude")

    # Recover and save enhanced wav
    pp.create_folder(output_file_folder)
    s = recover_wav_complex(final, conf1.n_overlap, np.hamming)
    s *= np.sqrt((np.hamming(conf1.n_window) ** 2).sum())  # scale to compensate the
    # amplitude change introduced by the spectrogram and IFFT.
    audio_path = os.path.join(output_file_folder, file_name)
    pp.write_audio(audio_path, s, conf1.sample_rate)

    print('%s done' % mode)
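
A hedged call sketch for dab_run; the SNR values mirror the commented-out arrays above, and file_name must exist under data_eval/test_speech:

import numpy as np

dab_run(np.array([5.62, 2.81, 1.875, 1.406]), file_name="sample.wav", mode='dab')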
Example #15
def prepare_database():

    (noise, _) = pp.read_audio(conf2.noise_path)

    with open(os.path.join('dnn2', 'dnn2_files_list.txt')) as f:
        dnn2_data = f.readlines()

    (model1, scaler1) = dnn1.load_dnn()

    # generate train mean values

    snr2_list = []
    mixed_avg = []
    clean_avg = []
    enh_avg = []

    for n in range(conf2.training_number):
        current_file = (random.choice(dnn2_data)).rstrip()
        dist = random.uniform(1, 20)
        (clean, _) = pp.read_audio(current_file)

        mixed, noise_new, clean_new, s2nr = set_microphone_at_distance(
            clean, noise, conf2.fs, dist)

        (_, enh, _) = dnn1.predict_file(current_file, model1, scaler1)

        # s2nr = 1 / (1 + (1 / float(snr)))
        snr2_list.append(s2nr)

        mixed_avg.append(np.mean(mixed))
        clean_avg.append(np.mean(clean_new))
        enh_avg.append(np.mean(enh))

        sr = ''.join(
            random.choice(string.ascii_uppercase + string.digits)
            for _ in range(5))
        path_list = current_file.split(os.sep)
        mixed_name = "mix_%s_%s_%s" % (path_list[2], sr,
                                       os.path.basename(current_file))
        clean_name = "clean_%s_%s_%s" % (path_list[2], sr,
                                         os.path.basename(current_file))
        enh_name = "enh_%s_%s_%s" % (path_list[2], sr,
                                     os.path.basename(current_file))

        if n % 10 == 0:
            print(n)

        if conf2.save_single_files and n < conf1.n_files_to_save:

            mixed_path = os.path.join(conf2.train_folder, mixed_name)
            clean_path = os.path.join(conf2.train_folder, clean_name)
            enh_path = os.path.join(conf2.train_folder, enh_name)
            pp.write_audio(mixed_path, mixed, conf2.fs)
            pp.write_audio(clean_path, clean_new, conf2.fs)
            pp.write_audio(enh_path, enh, conf2.fs)

    if len(mixed_avg) != len(enh_avg):
        raise Exception('Number of mixed and enhanced audio must be the same')

    num_tr = len(mixed_avg)

    if os.path.exists(os.path.join(conf2.train_folder, 'train_data.txt')):
        os.remove(os.path.join(conf2.train_folder, 'train_data.txt'))
    with open(os.path.join(conf2.train_folder, 'train_data.txt'), 'w') as f1:
        for line1, line2, line3 in zip(mixed_avg, clean_avg, snr2_list):
            f1.write("%s, %s, %s\n" % (line1, line2, line3))

    print(len(mixed_avg), ',', len(enh_avg))

    # generate test mean values

    snr2_list = []
    mixed_avg = []
    clean_avg = []
    enh_avg = []

    for n in range(conf2.test_number):
        current_file = (random.choice(dnn2_data)).rstrip()
        dist = random.uniform(1, 20)
        (clean, _) = pp.read_audio(current_file)

        mixed, noise_new, clean_new, s2nr = set_microphone_at_distance(
            clean, noise, conf2.fs, dist)

        (_, enh, _) = dnn1.predict_file(current_file, model1, scaler1)

        # s2nr = 1 / (1 + (1 / float(snr)))
        snr2_list.append(s2nr)

        mixed_avg.append(np.mean(mixed))
        clean_avg.append(np.mean(clean_new))
        enh_avg.append(np.mean(enh))

        sr = ''.join(
            random.choice(string.ascii_uppercase + string.digits)
            for _ in range(5))
        path_list = current_file.split(os.sep)
        mixed_name = "mix_%s_%s_%s" % (path_list[2], sr,
                                       os.path.basename(current_file))
        clean_name = "clean_%s_%s_%s" % (path_list[2], sr,
                                         os.path.basename(current_file))
        enh_name = "enh_%s_%s_%s" % (path_list[2], sr,
                                     os.path.basename(current_file))

        if n % 10 == 0:
            print(n)

        if conf2.save_single_files and n < conf1.n_files_to_save:

            mixed_path = os.path.join(conf2.test_folder, mixed_name)
            clean_path = os.path.join(conf2.test_folder, clean_name)
            enh_path = os.path.join(conf2.test_folder, enh_name)
            pp.write_audio(mixed_path, mixed, conf2.fs)
            pp.write_audio(clean_path, clean_new, conf2.fs)
            pp.write_audio(enh_path, enh, conf2.fs)

    print(len(mixed_avg), ',', len(enh_avg))

    if len(mixed_avg) != len(enh_avg):
        raise Exception('Number of mixed and enhanced audio must be the same')

    num_te = len(mixed_avg)

    if os.path.exists(os.path.join(conf2.test_folder, 'test_data.txt')):
        os.remove(os.path.join(conf2.test_folder, 'test_data.txt'))
    with open(os.path.join(conf2.test_folder, 'test_data.txt'), 'w') as f1:
        for line1, line2, line3 in zip(mixed_avg, clean_avg, snr2_list):
            f1.write("%s, %s, %s\n" % (line1, line2, line3))

    return num_tr, num_te
Example #16
def predict_folder(input_file_folder: str, output_file_folder: str):
    # Load model.
    data_type = "test"
    model_path = os.path.join(conf1.model_dir, "md_%diters.h5" % conf1.iterations)
    model = load_model(model_path)

    # Load scaler.
    # if scale:
    scaler_path = os.path.join(conf1.packed_feature_dir, data_type, "scaler.p")
    scaler = pickle.load(open(scaler_path, 'rb'))

    # Load test data.
    # names = os.listdir(input_file_folder)

    names = [f for f in sorted(os.listdir(input_file_folder)) if f.startswith("mix")]

    mixed_all = []
    pred_all = []
    for (cnt, na) in enumerate(names):
        # Load feature.
        file_path = os.path.join(input_file_folder, na)
        (a, _) = pp.read_audio(file_path)
        mixed_complex = pp.calc_sp(a, 'complex')


        mixed_x = np.abs(mixed_complex)

        # Process data.
        n_pad = (conf1.n_concat - 1) // 2
        mixed_x = pp.pad_with_border(mixed_x, n_pad)
        mixed_x = pp.log_sp(mixed_x)
        # speech_x = dnn1_train.log_sp(speech_x)

        # Scale data.
        # if scale:
        mixed_x = pp.scale_on_2d(mixed_x, scaler)

        # Cut input spectrogram to 3D segments with n_concat.
        mixed_x_3d = pp.mat_2d_to_3d(mixed_x, agg_num=conf1.n_concat, hop=1)


        # Predict.
        pred = model.predict(mixed_x_3d)
        print(cnt, na)

        # Inverse scale.
        #if scale:
        mixed_x = pp.inverse_scale_on_2d(mixed_x, scaler)
        # speech_x = dnn1_train.inverse_scale_on_2d(speech_x, scaler)
        pred = pp.inverse_scale_on_2d(pred, scaler)

        # Debug plot.
        if visualize_plot:
            visualize(mixed_x, pred)

        mixed_all.append(mixed_complex)
        pred_all.append(real_to_complex(pred, mixed_complex))


        # Recover enhanced wav.
        pred_sp = np.exp(pred)
        s = recover_wav(pred_sp, mixed_complex, conf1.n_overlap, np.hamming)
        s *= np.sqrt((np.hamming(conf1.n_window) ** 2).sum())  # scale to compensate the
        # amplitude change introduced by the spectrogram and IFFT.

        # Write out enhanced wav.

        pp.create_folder(output_file_folder)
        audio_path = os.path.join(output_file_folder, "enh_%s" % na)
        pp.write_audio(audio_path, s, conf1.sample_rate)

    return mixed_all, pred_all
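
The folder-level call mirrors how dab_run (Example #14) uses it:

mixed_all, pred_all = predict_folder(os.path.join("data_eval", "dnn1_in"),
                                     os.path.join("data_eval", "dnn1_out"))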
Example #17
def calculate_pesq(workspace,
                   speech_dir,
                   model_name,
                   te_snr,
                   library='pypesq',
                   mode='nb',
                   calc_mixed=False,
                   force=False):
    """Calculate PESQ of all enhaced speech.

    Args:
      workspace: str, path of workspace.
      speech_dir: str, path of clean speech.
      te_snr: float, testing SNR.
    """
    assert library in ('pypesq', 'pesq', 'stoi', 'sisdr')
    assert mode in ('wb', 'nb')

    if library == 'pypesq':
        results_file = os.path.join(workspace, 'evaluation',
                                    f'pesq_results_{model_name}.csv')
    elif library == 'pesq':
        results_file = os.path.join(workspace, 'evaluation',
                                    f'pesq2_results_{mode}_{model_name}.csv')
    else:
        results_file = os.path.join(workspace, 'evaluation',
                                    f'{library}_results_{model_name}.csv')

    if os.path.isfile(results_file) and not force:
        df = pd.read_csv(results_file)
        done_snrs = df['snr'].unique()
        left_snrs = [snr for snr in te_snr if snr not in done_snrs]
        if len(left_snrs) == 0:
            print('Score is already calculated')
            return df[df['snr'].isin(te_snr)]
        else:
            te_snr = left_snrs

    else:
        df = pd.DataFrame(columns=['filepath', 'snr', 'pesq'])

    speech_audio_cache = {}

    os.makedirs(os.path.dirname(results_file), exist_ok=True)

    for snr in te_snr:
        print(f'SNR: {snr}')

        # Calculate PESQ of all enhanced speech.
        if calc_mixed:
            enh_speech_dir = os.path.join(workspace, "mixed_audios",
                                          "spectrogram", "test",
                                          "%ddb" % int(snr))
        else:
            enh_speech_dir = os.path.join(workspace, "enh_wavs", "test",
                                          model_name, "%ddb" % int(snr))

        enh_paths = all_file_paths(enh_speech_dir)

        pendings = []
        with ProcessPoolExecutor(10) as pool:
            for (cnt, enh_path) in tqdm(enumerate(enh_paths),
                                        'Calculating PESQ score (submitting)'):
                # enh_path = os.path.join(enh_speech_dir, na)
                na = str(PurePath(enh_path).relative_to(enh_speech_dir))
                #print(cnt, na)

                if calc_mixed:
                    speech_na = '.'.join(na.split('.')[:-2])
                else:
                    speech_na = '.'.join(na.split('.')[:-3])

                speech_path = os.path.join(speech_dir, f"{speech_na}.wav")

                deg, sr = read_audio(enh_path)

                try:
                    ref = speech_audio_cache[speech_path]
                except KeyError:
                    ref, _ = read_audio(speech_path, target_fs=sr)
                    speech_audio_cache[speech_path] = ref

                if len(ref) < len(deg):
                    ref = np.pad(ref, (0, len(deg) - len(ref)))
                elif len(deg) < len(ref):
                    deg = np.pad(deg, (0, len(ref) - len(deg)))

                if library == 'pypesq':
                    pendings.append(
                        pool.submit(_calc_pesq, ref, deg, sr, na, snr))
                elif library == 'pesq':
                    pendings.append(
                        pool.submit(_calc_pesq2, ref, deg, sr, na, snr, mode))
                elif library == 'stoi':
                    pendings.append(
                        pool.submit(_calc_stoi, ref, deg, sr, na, snr))
                elif library == 'sisdr':
                    pendings.append(pool.submit(_calc_sisdr, ref, deg, na,
                                                snr))
                else:
                    raise ValueError(f'Invalid library: {library}')

            for pending in tqdm(pendings, 'Collecting pending jobs'):
                score, na, snr = pending.result()
                df.loc[len(df)] = [na, snr, score]

        df.to_csv(results_file, index=False)

    return df
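
A hedged usage sketch; the paths, model_name, and SNR list below are placeholders:

df = calculate_pesq(workspace="workspace",
                    speech_dir="timit_wavs/subtest",
                    model_name="FullyCNN",
                    te_snr=[0, 5],
                    library='pypesq')
print(df.groupby('snr')['pesq'].mean())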
Example #18
def inference(args):
    workspace = args.workspace
    iter = args.iteration
    stack_num = args.stack_num
    filename = args.filename
    mini_num = args.mini_num
    visualize = args.visualize
    cuda = args.use_cuda and torch.cuda.is_available()
    print("cuda:", cuda)
    audio_type = 'speech'
    
    sample_rate = cfg.sample_rate
    fft_size = cfg.fft_size
    hop_size = cfg.hop_size
    window_type = cfg.window_type

    if window_type == 'hamming':
        window = np.hamming(fft_size)
    else:
        raise ValueError("Unsupported window type: %s" % window_type)

    # Audio
    audio_dir = "/vol/vssp/msos/qk/workspaces/speech_enhancement/mixed_audios/spectrogram/test/0db"
    # audio_dir = "/user/HS229/qk00006/my_code2015.5-/python/pub_speech_enhancement/mixture2clean_dnn/workspace/mixed_audios/spectrogram/test/0db"
    names = os.listdir(audio_dir)
    
    speech_dir = "/vol/vssp/msos/qk/workspaces/speech_enhancement/timit_wavs/subtest"
    
    # Load model
    model_path = os.path.join(workspace, "models", filename, audio_type, "md_%d_iters.tar" % iter)
    n_freq = 257
    model = DNN(stack_num, n_freq)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    
    if cuda:
        model.cuda()
        
    # Load scalar
    scalar_path = os.path.join(workspace, "scalars", filename, "scalar.p")
    (mean_, std_) = cPickle.load(open(scalar_path, 'rb'))
    mean_ = move_data_to_gpu(mean_, cuda, volatile=True)
    std_ = move_data_to_gpu(std_, cuda, volatile=True)
    
    if mini_num > 0:
        n_every = len(names) // mini_num
    else:
        n_every = 1
        
    out_wav_dir = os.path.join(workspace, "enh_wavs", filename)
    pp_data.create_folder(out_wav_dir)
    
    dft = pp_data.DFT(fft_size, cuda)
        
    for (cnt, name) in enumerate(names):
        if cnt % n_every == 0:
            audio_path = os.path.join(audio_dir, name)
            (audio0, _) = pp_data.read_audio(audio_path, sample_rate)
            
            audio = pp_data.normalize(audio0)
            
            # Enframe
            frames = stft.enframe(audio, fft_size, hop_size)
            
            # Process data. 
            n_pad = (stack_num - 1) // 2
            x = pp_data.pad_with_border(frames, n_pad)
            x = pp_data.mat_2d_to_3d(x, stack_num, hop=1)
            
            pred_frames = forward(model, x, mean_, std_, cuda)
            
            pred_frames = pred_frames.data.cpu().numpy()
            
            # cola_constant = 0.5
            # seq = stft.overlap_add(pred_frames, hop_size, cola_constant)
            
            pred_frames *= window
            
            cola_constant = stft.get_cola_constant(hop_size, window)
            seq = stft.overlap_add(pred_frames, hop_size, cola_constant)
            seq = seq[0 : len(audio)]
            
            
            # Write out wav
            out_wav_path = os.path.join(out_wav_dir, name)
            pp_data.write_audio(out_wav_path, seq, sample_rate)
            print("Write out wav to: %s" % out_wav_path)
            
            if visualize:
                
                clean_audio_path = os.path.join(speech_dir, name.split('.')[0] + ".WAV")
                (clean_audio, _) = pp_data.read_audio(clean_audio_path, sample_rate)
                clean_audio = pp_data.normalize(clean_audio)
                clean_frames = stft.enframe(clean_audio, fft_size, hop_size)
                
                mix_sp = np.abs(np.fft.rfft(frames * window, norm='ortho'))
                enh_sp = np.abs(np.fft.rfft(pred_frames * window, norm='ortho'))
                clean_sp = np.abs(np.fft.rfft(clean_frames * window, norm='ortho'))
                
                K = 10
                fig, axs = plt.subplots(K // 2, 2, sharex=True)
                for k in range(K):
                    axs[k // 2, k % 2].plot(frames[k + 100], color='y')
                    axs[k // 2, k % 2].plot(clean_frames[k + 100], color='r')
                    axs[k // 2, k % 2].plot(pred_frames[k + 100], color='b')
                plt.show()
                
                
                vmin = -5.
                vmax = 5.
                fig, axs = plt.subplots(3,1, sharex=True)
                axs[0].matshow(np.log(np.abs(mix_sp)).T, origin='lower', aspect='auto', cmap='jet', vmin=vmin, vmax=vmax)
                axs[1].matshow(np.log(np.abs(clean_sp)).T, origin='lower', aspect='auto', cmap='jet', vmin=vmin, vmax=vmax)
                axs[2].matshow(np.log(np.abs(enh_sp)).T, origin='lower', aspect='auto', cmap='jet', vmin=vmin, vmax=vmax)
                plt.show()