Example #1
File: test.py  Project: ZhihaoDU/du2020dan
def calc_func(noisy_dir_path):
    with torch.no_grad():
        debug_model = args.debug_model
        _method = method
        model_opts = json.load(
            open(os.path.join("configs/%s.json" % args.model_config), 'r'))
        gen_model = model_opts['gen_model_name']
        calc_target = get_target(args.target_type)

        device = torch.device("cuda")
        print_with_time("Loading model...")
        Generator, _ = get_model(gen_model, None)
        model = Generator(model_opts['gen_model_opts']).to(device)

        checkpoint = torch.load("Checkpoints/%s/checkpoint_%09d.pth" %
                                (_method, args.global_step))
        model.load_state_dict(checkpoint["generator"])
        # model.load_state_dict(checkpoint["enhancer"])
        model.eval()
        melbank = get_fft_mel_mat(512, 16000, 40)

        _method = "_".join([_method, str(args.global_step)])
        if debug_model:
            os.system('mkdir -p debug/%s' % _method)
        print_with_time("Start to enhance wav file in %s with method %s\n" %
                        (noisy_dir_path, _method))
        udir_path = "%s_%s" % (noisy_dir_path, _method)
        if not os.path.exists(udir_path):
            os.mkdir(udir_path)
        wav_scp = read_path_list(os.path.join(noisy_dir_path, "wav.scp"))
        if not debug_model:
            ark_file = open(os.path.join(udir_path, "feats.ark"), 'wb')
            scp_file = open(os.path.join(udir_path, "feats.scp"), 'w')
            key_len = wav_scp[0].find(' ')
            kaldi_holder = KaldiFeatHolder(key_len, 3000, 40)
            offset = key_len + 1
        enhanced_number = 0
        for it, one_wav in enumerate(wav_scp):
            wav_id, wav_path = one_wav.split(' ')
            sr, noisy_speech = wavfile.read(wav_path)
            if len(noisy_speech.shape) > 1:
                noisy_speech = np.mean(noisy_speech, 1)

            early50_path = wav_path.replace('.wav', '_early50.wav')
            sr, early50 = wavfile.read(early50_path)
            if len(early50.shape) > 1:
                early50 = np.mean(early50, 1)
            # As in the training data, use "power_norm" to normalize the waveform so it matches the model's input scale.
            # c = np.sqrt(np.mean(np.square(noisy_speech)))
            c = calc_rescale_c(noisy_speech, args.rescale_method)
            noisy_speech = noisy_speech / c
            early50 = early50 / c

            noisy_fbank, noisy_mag = log_fbank(noisy_speech, False, True, True,
                                               None)
            early50_fbank, early50_mag = log_fbank(early50, False, True, True,
                                                   None)
            noise_fbank, noise_mag = log_fbank(noisy_speech - early50, False,
                                               True, True, None)
            if args.feature_domain == "mel":
                feat = torch.Tensor(noisy_fbank.T).unsqueeze(0).to(device)
                label = torch.Tensor(early50_fbank.T).unsqueeze(0).to(device)
                noise = torch.Tensor(noise_fbank.T).unsqueeze(0).to(device)
            else:
                feat = torch.Tensor(
                    np.square(noisy_mag).T).unsqueeze(0).to(device)
                label = torch.Tensor(
                    np.square(early50_mag).T).unsqueeze(0).to(device)
                noise = torch.Tensor(
                    np.square(noise_mag).T).unsqueeze(0).to(device)

            if args.target_type.lower() == "mapping_mag":
                predict = model.forward(feat.sqrt())
            else:
                predict = model.forward(torch.log(feat + opts['eps']))

            results = calc_target(feat, label, noise, predict, opts)
            enhanced = results["enhanced"]
            predict = results["predict"]
            target = results["target"]

            if args.feature_domain == "mel":
                enhanced_pow = 0
                enhanced_fbank = enhanced[0, :, :].cpu().numpy()
            else:
                enhanced_pow = enhanced[0, :, :].cpu().numpy()
                enhanced_fbank = np.matmul(enhanced_pow, melbank.T)

            log_enhanced_fbank = np.log(enhanced_fbank * (c**2.) + opts['eps'])

            if debug_model:
                sio.savemat(
                    "debug/%s/%s_%s" % (_method, wav_id, wav_path.split('/')[-5]),
                    {
                        'noisy_mag': noisy_mag,
                        'noisy_fbank': noisy_fbank,
                        'enhanced_mag': np.sqrt(enhanced_pow).T,
                        'enhanced_fbank': enhanced_fbank.T,
                        'early50_mag': early50_mag,
                        'early50_fbank': early50_fbank,
                        'predict': predict[0, :, :].cpu().numpy().T,
                        'target': target[0, :, :].cpu().numpy().T,
                        'log_enhanced_fbank': log_enhanced_fbank.T,
                        'log_early50_fbank': np.log(early50_fbank * (c**2.) + opts['eps']),
                        'c': c
                    })
                if it >= 0:
                    return
            else:
                kaldi_holder.set_key(wav_id)
                kaldi_holder.set_value(log_enhanced_fbank)
                kaldi_holder.write_to(ark_file)
                scp_file.write("%s %s/feats.ark:%d\n" %
                               (wav_id, udir_path, offset))
                offset += kaldi_holder.get_real_len()

            enhanced_number += 1
            if enhanced_number % 40 == 0:
                print_with_time(
                    "Enhanced %5d(%6.2f%%) utterance" %
                    (enhanced_number, 100. * enhanced_number / len(wav_scp)))
        print_with_time("Enhanced %d utterance" % enhanced_number)
        ark_file.close()
        scp_file.close()
        post_process(noisy_dir_path, udir_path)
        print_with_time("Done %s." % _method)
Example #2
def calc_func(noisy_dir_path):
    debug_model = True
    # nn.Module.dump_patches = True
    melbank = get_fft_mel_mat(512, 16000, 40)
    method = "Tan2018CRN_mag_early50"
    if debug_model:
        os.system('mkdir -p debug/%s' % method)
    device = torch.device("cuda")
    print_with_time("Loading model...")
    model = Generator(64, 256).to(device)
    checkpoint = torch.load(
        "Checkpoints/Tan2018CRN_mag_early50/checkpoint_000096336.pth")
    model.load_state_dict(checkpoint["generator"])
    model.eval()
    print_with_time("Start to enhance wav file in %s with method %s\n" %
                    (noisy_dir_path, method))
    udir_path = "%s_%s" % (noisy_dir_path, method)
    if not os.path.exists(udir_path):
        os.mkdir(udir_path)
    wav_scp = read_path_list(os.path.join(noisy_dir_path, "wav.scp"))
    if not debug_model:
        ark_file = open(os.path.join(udir_path, "feats.ark"), 'wb')
        scp_file = open(os.path.join(udir_path, "feats.scp"), 'w')
        key_len = wav_scp[0].find(' ')
        kaldi_holder = KaldiFeatHolder(key_len, 2000, 40)
        offset = key_len + 1
    enhanced_number = 0
    left_frame = 0
    right_frame = 0
    for it, one_wav in enumerate(wav_scp):
        wav_id, wav_path = one_wav.split(' ')
        sr, noisy_speech = wavfile.read(wav_path)
        # process binaural waves.
        if len(noisy_speech.shape) > 1:
            noisy_speech = np.mean(noisy_speech, 1)

        c = np.sqrt(np.mean(np.square(noisy_speech)))
        noisy_speech = noisy_speech / c

        n_noisy_feat, n_noisy_mag = log_fbank(noisy_speech, False, True, True,
                                              None)
        # n_log_noisy_power = np.log(n_noisy_mag ** 2 + 1e-12)
        # feat = torch.Tensor(n_log_noisy_power).to(device)
        feat = torch.Tensor(n_noisy_mag.T).unsqueeze(0).to(device)
        n, t, d = feat.size()
        if left_frame > 0 or right_frame > 0:
            pad_feats = F.pad(feat.unsqueeze(1),
                              (0, 0, left_frame, right_frame)).squeeze(1)
            ex_list = [
                pad_feats[:, i:i + t, :]
                for i in range(left_frame + 1 + right_frame)
            ]
            feat = torch.cat(ex_list, 2)

        with torch.no_grad():
            # mask = torch.sigmoid(model.forward(feat))
            # enhanced = mask * feat.pow(2.)
            enhanced_mag = model.forward(feat)
            enhanced = enhanced_mag.pow(2.)
            enhanced = (enhanced[0, :, :] * c**2.).cpu().numpy()
        enhanced_fbank = np.matmul(melbank, enhanced.T)
        log_enhanced_fbank = np.log(enhanced_fbank + 1e-12)

        if debug_model:
            early50_path = wav_path.replace('.wav', '_early50.wav')
            sr, early50 = wavfile.read(early50_path)
            if len(early50.shape) > 1:
                early50 = np.mean(early50, 1)
            early50 = early50 / c
            early50_feat, early50_mag = log_fbank(early50, False, True, True,
                                                  None)
            sio.savemat(
                "debug/%s/%s_%s" % (method, wav_id, wav_path.split('/')[-5]), {
                    'noisy_mag': n_noisy_mag,
                    'noisy_feat': n_noisy_feat,
                    'enhanced_mag': enhanced_mag[0, :, :].cpu().numpy().T,
                    'enhanced_feat': enhanced_fbank,
                    'early50_mag': early50_mag,
                    'early50_feat': early50_feat,
                })
            if it >= 5:
                return
        else:
            kaldi_holder.set_key(wav_id)
            kaldi_holder.set_value(log_enhanced_fbank.T)
            kaldi_holder.write_to(ark_file)
            scp_file.write("%s %s/feats.ark:%d\n" %
                           (wav_id, udir_path, offset))
            offset += kaldi_holder.get_real_len()

        enhanced_number += 1
        if enhanced_number % 40 == 0:
            print_with_time(
                "Enhanced %5d(%6.2f%%) utterance" %
                (enhanced_number, 100. * enhanced_number / len(wav_scp)))
    print_with_time("Enhanced %d utterance" % enhanced_number)
    ark_file.close()
    scp_file.close()
    post_process(noisy_dir_path, udir_path)
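Example #2 keeps a left_frame/right_frame splicing branch that pads the feature sequence along time and concatenates shifted copies along the feature axis (a no-op here, since both are 0). The following is a small sketch of that splicing with the same (batch, time, dim) layout; splice_frames is an illustrative name, not part of the project.

import torch
import torch.nn.functional as F

def splice_frames(feat, left, right):
    # feat: (batch, time, dim) -> (batch, time, dim * (left + 1 + right)).
    n, t, d = feat.size()
    # Pad the time axis, then take (left + 1 + right) shifted views and
    # concatenate them along the feature axis.
    padded = F.pad(feat.unsqueeze(1), (0, 0, left, right)).squeeze(1)
    views = [padded[:, i:i + t, :] for i in range(left + 1 + right)]
    return torch.cat(views, 2)

# e.g. splice_frames(torch.randn(1, 100, 40), 2, 2).shape == (1, 100, 200)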
Example #3
def calc_func(noisy_dir_path):
    # nn.Module.dump_patches = True
    melbank = get_fft_mel_mat(512, 16000, 40)
    method = "Tan2018CRN_mag_early50"
    device = torch.device("cuda")
    print_with_time("Loading model...")
    model = Generator(64, 256).to(device)
    checkpoint = torch.load("Checkpoints/Tan2018CRN_mag_early50/checkpoint_000096336.pth")
    model.load_state_dict(checkpoint["generator"])
    model.eval()
    print_with_time("Start to enhance wav file in %s with method %s\n" % (noisy_dir_path, method))
    udir_path = "%s_%s" % (noisy_dir_path, method)
    if not os.path.exists(udir_path):
        os.mkdir(udir_path)
    wav_scp = read_path_list(os.path.join(noisy_dir_path, "wav.scp"))
    ark_file = open(os.path.join(udir_path, "feats.ark"), 'wb')
    scp_file = open(os.path.join(udir_path, "feats.scp"), 'w')
    key_len = wav_scp[0].find(' ')
    kaldi_holder = KaldiFeatHolder(key_len, 2000, 40)
    offset = key_len + 1
    enhanced_number = 0
    left_frame = 0
    right_frame = 0
    for one_wav in wav_scp:
        wav_id, wav_path = one_wav.split(' ')
        # wav_path = one_wav
        # if "dB" in wav_path:
        #     wav_path = wav_path.replace(".wav", "_early50.wav")
        sr, noisy_speech = wavfile.read(wav_path)
        # The early-reflection reference is needed below for the masks and the .mat dump.
        sr, early50_speech = wavfile.read(wav_path.replace(".wav", "_early50.wav"))
        # sr, direct_speech = wavfile.read(wav_path.replace(".wav", "_direct_sound.wav"))
        # process binaural waves.
        if len(noisy_speech.shape) > 1:
            noisy_speech = np.mean(noisy_speech, 1)
        if len(early50_speech.shape) > 1:
            early50_speech = np.mean(early50_speech, 1)
        # if len(direct_speech.shape) > 1:
        #     direct_speech = np.mean(direct_speech, 1)
        # noisy_speech = noisy_speech.astype(np.int16)
        # librosa.output.write_wav("/home/duzhihao/440c0201_mono.wav", noisy_speech.astype(np.float32), 16000, True)

        c = np.sqrt(np.mean(np.square(noisy_speech)))
        noisy_speech = noisy_speech / c
        early50_speech = early50_speech / c

        # Oracle quantities built from the early-reflection reference: a
        # mel-domain IRM and a log-power mask, both used by the debug dump below.
        n_noisy_feat, n_noisy_mag = log_fbank(noisy_speech, False, True, True, None)
        n_noise_feat, n_noise_mag = log_fbank(noisy_speech - early50_speech, False, True, True, None)
        n_early50_feat, n_early50_mag = log_fbank(early50_speech, False, True, True, None)
        n_irm = n_early50_feat[0] / (n_early50_feat[0] + n_noise_feat[0])
        log_noisy_power = np.log(n_noisy_mag ** 2 + 1e-12)
        log_early50_power = np.log(n_early50_mag ** 2 + 1e-12)
        log_mask = np.clip((log_early50_power + 10) / (log_noisy_power + 10), 0, 1)
        log_enhanced_power = (log_noisy_power + 10) * log_mask - 10
        enhanced_power = np.exp(log_enhanced_power)
        # n_direct_feat = log_fbank(direct_speech, False, True, False, None)
        # n_noisy_feat = log_fbank_for_wu(noisy_speech, False, True, False, None)
        # log_noisy_feat = np.log(n_noisy_feat[0].T)
        # log_noisy_feat[np.isnan(log_noisy_feat)] = 0.
        # log_noisy_feat[np.isinf(log_noisy_feat)] = 0.
        # log_noisy_feat = log_noisy_feat[np.newaxis, :, :]
        # log_noisy_feat = torch.Tensor(log_noisy_feat).to(device)
        # n, t, d = log_noisy_feat.size()
        # if left_frame > 0 or right_frame > 0:
        #     pad_feats = F.pad(log_noisy_feat.unsqueeze(1), (0, 0, left_frame, right_frame)).squeeze(1)
        #     ex_list = [pad_feats[:, i:i+t, :] for i in range(left_frame+1+right_frame)]
        #     log_noisy_feat = torch.cat(ex_list, 2)

        # with torch.no_grad():
        #     feat_ex_list = [log_noisy_feat[:, :, i * 40:(i + 1) * 40].unsqueeze(1) for i in range(left_frame + 1 + right_frame)]
        #     nn_input = torch.cat(feat_ex_list, 1)
        #     mask = torch.sigmoid(model.forward(log_noisy_feat))
        #     mask = model.forward(log_noisy_feat)
        #     enhanced_feat = mask[0, :, :].cpu().numpy() * n_noisy_feat[0].T
        #     log_enhanced_feat = np.log(enhanced_feat)
        #     log_enhanced_feat = log_noisy_feat[:, :, 40*left_frame:40*(left_frame+1)] # * mask
        #     log_enhanced_feat = model.forward(log_noisy_feat)
        #     log_enhanced_feat = log_enhanced_feat * (log_pow_max - log_pow_min) + log_pow_min
        #     log_enhanced_feat = log_enhanced_feat.cpu().numpy()

        # Debug dump: save the first utterance's spectra and masks to a .mat
        # file and stop; the Kaldi feature-writing code below is not reached.
        sio.savemat(method + "_chime2", {'noisy_mag': n_noisy_mag,
                                         'irm': log_mask,
                                         'early50_mag': n_early50_mag,
                                         'enhanced_mag': np.sqrt(enhanced_power),
                                         # 'direct_feat': n_direct_feat[0]
                                         })
        return

        kaldi_holder.set_key(wav_id)
        # kaldi_holder.set_value(log_enhanced_feat[0, :, :].cpu().numpy())
        # kaldi_holder.set_value(np.log((n_early50_feat[0]).T + 1e-8))
        kaldi_holder.set_value(np.log((n_noisy_feat[0] * n_irm).T + 1e-8))
        kaldi_holder.write_to(ark_file)
        scp_file.write("%s %s/feats.ark:%d\n" % (wav_id, udir_path, offset))
        offset += kaldi_holder.get_real_len()
        enhanced_number += 1
        if enhanced_number % 40 == 0:
            print_with_time(
                "Enhanced %5d(%6.2f%%) utterance" % (enhanced_number, 100. * enhanced_number / len(wav_scp)))
    print_with_time("Enhanced %d utterance" % enhanced_number)
    ark_file.close()
    scp_file.close()
    post_process(noisy_dir_path, udir_path)
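The commented-out lines in Example #3 sketch two oracle targets: a mel-domain ideal ratio mask (n_irm) and a log-power-domain mask (log_mask). Below is a minimal reconstruction of the IRM those comments describe; ideal_ratio_mask is an illustrative helper, not code from the repository.

import numpy as np

def ideal_ratio_mask(target_pow, noise_pow, eps=1e-12):
    # Ratio of target energy to target-plus-noise energy per bin, as in
    # n_irm = n_early50_feat / (n_early50_feat + n_noise_feat).
    return target_pow / (target_pow + noise_pow + eps)

# The masked feature would then be written as np.log(noisy_fbank * mask + 1e-8),
# matching the kaldi_holder.set_value call in the unreachable block above.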