コード例 #1
0
ファイル: inference.py プロジェクト: yangyongjx/infocom2021
def main(args, hp):
    """Run one inference pass and write the resulting waveform to disk.

    Loads the VoiceFilter model and the speech embedder from their
    checkpoints, computes an embedding from the reference recording, runs
    the model on the mixed recording's magnitude spectrogram, combines the
    model output with the mixed magnitude, and writes the reconstructed
    audio to ``<args.out_dir>/result.wav``.

    Args:
        args: Object exposing ``checkpoint_path``, ``embedder_path``,
            ``reference_file``, ``mixed_file`` and ``out_dir``.
        hp: Hyper-parameter object forwarded unchanged to ``VoiceFilter``,
            ``SpeechEmbedder`` and ``Audio``.

    Note:
        Every model and tensor is moved with ``.cuda()``, so a CUDA device
        is required.
    """
    sample_rate = 16000  # both input files are loaded, and the output written, at this rate

    with torch.no_grad():
        model = VoiceFilter(hp).cuda()
        # The model checkpoint stores the weights under the 'model' key.
        chkpt_model = torch.load(args.checkpoint_path)['model']
        model.load_state_dict(chkpt_model)
        model.eval()

        embedder = SpeechEmbedder(hp).cuda()
        # The embedder checkpoint is loaded directly as a state dict.
        chkpt_embed = torch.load(args.embedder_path)
        embedder.load_state_dict(chkpt_embed)
        embedder.eval()

        audio = Audio(hp)

        # Reference recording -> mel features -> embedding with a leading
        # batch axis.
        ref_wav, _ = librosa.load(args.reference_file, sr=sample_rate)
        ref_mel = audio.get_mel(ref_wav)
        ref_mel = torch.from_numpy(ref_mel).float().cuda()
        dvec = embedder(ref_mel)
        dvec = dvec.unsqueeze(0)

        # Mixed recording -> magnitude / phase. Only the magnitude goes
        # through the model; the phase is kept for reconstruction below.
        mixed_wav, _ = librosa.load(args.mixed_file, sr=sample_rate)
        mixed_mag, mixed_phase = audio.wav2spec(mixed_wav)
        mixed_mag = torch.from_numpy(mixed_mag).float().cuda()
        mixed_mag = mixed_mag.unsqueeze(0)

        shadow_mag = model(mixed_mag, dvec)

        # Sum while both operands are still batched tensors on the same
        # device; only afterwards drop the batch axis and move to NumPy.
        recorded_mag = tensor_normalize(mixed_mag + shadow_mag)
        recorded_mag = recorded_mag[0].cpu().detach().numpy()

        # Reconstruct with the phase returned by wav2spec, not the
        # magnitude tensor.
        recorded_wav = audio.spec2wav(recorded_mag, mixed_phase)

        os.makedirs(args.out_dir, exist_ok=True)
        out_path = os.path.join(args.out_dir, 'result.wav')
        # NOTE(review): librosa.output.write_wav was removed in librosa
        # 0.8.0, so this call needs an older librosa -- confirm the pinned
        # version.
        librosa.output.write_wav(out_path, recorded_wav, sr=sample_rate)
コード例 #2
0
            dvec_wav, _ = librosa.load(dvec_path, sr=16000)
            ref_mel = audio.get_mel(dvec_wav)
            ref_mel = torch.from_numpy(ref_mel).float().cuda()
            dvec = embedder(ref_mel)
            dvec = dvec.unsqueeze(0)  # (1, 256)

            mixed_wav, _ = librosa.load(mixed_wav_path, sr=16000)
            mixed_mag, mixed_phase = audio.wav2spec(mixed_wav)
            mixed_mag = torch.from_numpy(mixed_mag).float().cuda()

            mixed_mag = mixed_mag.unsqueeze(0)

            shadow_mag = model(mixed_mag, dvec)
            # shadow_mag.size() = [1, 301, 601]

            recorded_mag = tensor_normalize(mixed_mag + shadow_mag)
            recorded_mag = recorded_mag[0].cpu().detach().numpy()
            mixed_mag = mixed_mag[0].cpu().detach().numpy()

            shadow_mag = shadow_mag[0].cpu().detach().numpy()
            shadow_wav = audio.spec2wav(shadow_mag, mixed_phase)

            # scale is frequency pass to time domain, used on wav signal normalization
            recorded_wav1 = audio.spec2wav(recorded_mag, mixed_phase)  # path 1

            # mixed_Wav_path = '/data/our_dataset/test/13/babble/000001-mixed.wav'
            hide1 = mixed_wav_path[:-9] + 'hide1.wav'
            hide2 = mixed_wav_path[:-9] + 'hide2.wav'
            # purified3 = os.path.join(args.out_dir, 'result3.wav')

            # The original mixed wav and expected_focused wav are not PCM-encoded, so Google Cloud cannot read them.