Code Example #1
File: synthesize.py Project: stallboy/FastSpeech2
def synthesize(model, waveglow, melgan, text, sentence, prefix=''):
    sentence = sentence[:200]  # a long sentence in the filename will raise an OSError

    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)

    with torch.no_grad():
        mel, mel_postnet, log_duration_output, f0_output, energy_output, _, _, mel_len = model(text, src_len)

    mel_torch = mel.transpose(1, 2).detach()
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    mel = mel[0].cpu().transpose(0, 1).detach()
    mel_postnet = mel_postnet[0].cpu().transpose(0, 1).detach()
    f0_output = f0_output[0].detach().cpu().numpy()
    energy_output = energy_output[0].detach().cpu().numpy()

    if not os.path.exists(hp.test_path):
        os.makedirs(hp.test_path)

    Audio.tools.inv_mel_spec(mel_postnet, os.path.join(hp.test_path, '{}_griffin_lim_{}.wav'.format(prefix, sentence)))
    if waveglow is not None:
        utils.waveglow_infer(mel_postnet_torch, waveglow,
                             os.path.join(hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))
    if melgan is not None:
        utils.melgan_infer(mel_postnet_torch, melgan,
                           os.path.join(hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))

    utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output)], ['Synthesized Spectrogram'],
                    filename=os.path.join(hp.test_path, '{}_{}.png'.format(prefix, sentence)))
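A minimal usage sketch for this function follows; `preprocess_text` is a hypothetical stand-in for the repository's text-to-phoneme-ID preprocessing (not shown here), and `model`, `waveglow`, `melgan`, and `device` are assumed to be loaded as in the rest of the project:

import numpy as np
import torch

def preprocess_text(sentence):
    # Hypothetical helper: the real project maps a sentence to a batch of
    # phoneme IDs of shape (1, seq_len), which is what synthesize() expects.
    ids = np.array([[12, 7, 33, 5, 21]])  # placeholder ID sequence
    return torch.from_numpy(ids).long().to(device)

sentence = "Hello world"
synthesize(model, waveglow, melgan, preprocess_text(sentence), sentence, prefix='demo')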
Code Example #2
def synthesize(model, waveglow, melgan, text, sentence, prefix=''):
    sentence = sentence[:10]  # a long sentence in the filename will raise an OSError

    mean_mel, std_mel = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "mel_stat.npy")),
                                     dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "f0_stat.npy")),
                                   dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "energy_stat.npy")),
                                           dtype=torch.float).to(device)

    mean_mel, std_mel = mean_mel.reshape(1, -1), std_mel.reshape(1, -1)
    mean_f0, std_f0 = mean_f0.reshape(1, -1), std_f0.reshape(1, -1)
    mean_energy, std_energy = mean_energy.reshape(1, -1), std_energy.reshape(
        1, -1)

    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)

    mel, mel_postnet, log_duration_output, f0_output, energy_output, _, _, mel_len = model(
        text, src_len)

    mel_torch = mel.transpose(1, 2).detach()
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    f0_output = f0_output[0]
    energy_output = energy_output[0]

    mel_torch = utils.de_norm(mel_torch.transpose(1, 2), mean_mel, std_mel)
    mel_postnet_torch = utils.de_norm(mel_postnet_torch.transpose(1, 2),
                                      mean_mel, std_mel).transpose(1, 2)
    f0_output = utils.de_norm(f0_output, mean_f0,
                              std_f0).squeeze().detach().cpu().numpy()
    energy_output = utils.de_norm(energy_output, mean_energy,
                                  std_energy).squeeze().detach().cpu().numpy()

    if not os.path.exists(hp.test_path):
        os.makedirs(hp.test_path)

    Audio.tools.inv_mel_spec(
        mel_postnet_torch[0],
        os.path.join(hp.test_path,
                     '{}_griffin_lim_{}.wav'.format(prefix, sentence)))
    if waveglow is not None:
        utils.waveglow_infer(
            mel_postnet_torch, waveglow,
            os.path.join(hp.test_path,
                         '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))
    if melgan is not None:
        utils.melgan_infer(
            mel_postnet_torch, melgan,
            os.path.join(hp.test_path,
                         '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))

    utils.plot_data([
        (mel_postnet_torch[0].detach().cpu().numpy(), f0_output, energy_output)
    ], ['Synthesized Spectrogram'],
                    filename=os.path.join(hp.test_path,
                                          '{}_{}.png'.format(prefix,
                                                             sentence)))
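This variant assumes the acoustic features were z-score normalized during preprocessing, so predictions are de-normalized before vocoding and plotting. `utils.de_norm` is not shown; given the mean/std pairs loaded from `mel_stat.npy`, `f0_stat.npy`, and `energy_stat.npy` above, a minimal sketch of such an inverse (an assumption, not the repository's actual implementation) would be:

import torch

def de_norm(x, mean, std):
    # Inverse of z-score normalization x_norm = (x - mean) / std; mean and
    # std broadcast over the time axis, matching the reshape(1, -1) applied
    # to the loaded statistics.
    return x * std + mean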
Code Example #3
def synthesize(model, waveglow, py_text_seq, cn_text_seq, duration_control=1.0, prefix=''):

    src_len = torch.from_numpy(np.array([py_text_seq.shape[1]])).to(device)
    
    mel, mel_postnet, log_duration_output, _, _, mel_len = model(
        py_text_seq, src_len, hz_seq=cn_text_seq,d_control=duration_control)
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    dst_name = os.path.join(
        '/dev/shm/', '{}-out.wav'.format(prefix))
    utils.waveglow_infer(mel_postnet_torch+hp.mel_mean, waveglow, dst_name)
    return dst_name
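Note that this variant adds `hp.mel_mean` back onto the prediction before vocoding, suggesting the model is trained on mean-subtracted spectrograms (the same offset appears as `mel_target - hp.mel_mean` in Example #8). Writing the WAV to `/dev/shm/` keeps it in RAM, a reasonable choice when the function serves requests and returns the output path to the caller.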
Code Example #4
def synthesize(model, text, sentence, prefix=''):
    src_pos = np.array([i + 1 for i in range(text.shape[1])])
    src_pos = np.stack([src_pos])
    src_pos = torch.from_numpy(src_pos).to(device).long()

    model.to(device)
    mel, mel_postnet, duration_output, f0_output, energy_output = model(
        text, src_pos)
    model.to('cpu')

    mel_torch = mel.transpose(1, 2).detach()
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    mel = mel[0].cpu().transpose(0, 1).detach()
    mel_postnet = mel_postnet[0].cpu().transpose(0, 1).detach()
    f0_output = f0_output[0].detach().cpu().numpy()
    energy_output = energy_output[0].detach().cpu().numpy()

    if not os.path.exists(hp.test_path):
        os.makedirs(hp.test_path)

    Audio.tools.inv_mel_spec(
        mel_postnet,
        os.path.join(hp.test_path,
                     '{}_griffin_lim_{}.wav'.format(prefix, sentence)))

    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
        utils.melgan_infer(
            mel_postnet_torch, melgan,
            os.path.join(hp.test_path,
                         '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)
        utils.waveglow_infer(
            mel_postnet_torch, waveglow,
            os.path.join(hp.test_path,
                         '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))

    utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output)],
                    ['Synthesized Spectrogram'],
                    filename=os.path.join(hp.test_path,
                                          '{}_{}.png'.format(prefix,
                                                             sentence)))
Code Example #5
File: synthesize.py Project: cadia-lvl/FastSpeech2
def synthesize(model,
               waveglow,
               melgan,
               text,
               sentence,
               prefix='',
               duration_control=1.0,
               pitch_control=1.0,
               energy_control=1.0,
               output_dir=None):
    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)

    mel, mel_postnet, log_duration_output, f0_output, energy_output, _, _, mel_len = model(
        text,
        src_len,
        d_control=duration_control,
        p_control=pitch_control,
        e_control=energy_control)

    mel_torch = mel.transpose(1, 2).detach()
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    mel = mel[0].cpu().transpose(0, 1).detach()
    mel_postnet = mel_postnet[0].cpu().transpose(0, 1).detach()
    f0_output = f0_output[0].detach().cpu().numpy()
    energy_output = energy_output[0].detach().cpu().numpy()

    if not output_dir:
        output_dir = hp.test_path

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    gl_fname = '{}_griffin_lim.wav'.format(prefix)
    Audio.tools.inv_mel_spec(mel_postnet, os.path.join(output_dir, gl_fname))

    vocoder_fname = '{}_{}.wav'.format(prefix, hp.vocoder)
    if waveglow is not None:
        utils.waveglow_infer(mel_postnet_torch, waveglow,
                             os.path.join(output_dir, vocoder_fname))
    if melgan is not None:
        utils.melgan_infer(mel_postnet_torch, melgan,
                           os.path.join(output_dir, vocoder_fname))
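Because this variant exposes duration, pitch, and energy factors, a natural use is sweeping one control at a time and comparing the resulting WAVs. A sketch, assuming `model`, the vocoders, and a preprocessed `text` tensor are already in scope as in the earlier examples:

# Sweep the duration control while holding pitch and energy at 1.0;
# the prefix encodes the factor so the output files stay distinguishable.
for d in (0.8, 1.0, 1.2):
    synthesize(model, waveglow, melgan, text, "demo",
               prefix='dur_{:.1f}'.format(d),
               duration_control=d,
               output_dir='control_sweep')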
Code Example #6
def synthesize(model,
               waveglow,
               text,
               idx,
               prefix='',
               duration_control=1.0,
               pitch_control=1.0,
               energy_control=1.0):
    t = time()
    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)
    mel, mel_postnet, log_duration_output, f0_output, energy_output, _, _, mel_len = model(
        text,
        src_len,
        d_control=duration_control,
        p_control=pitch_control,
        e_control=energy_control)

    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()

    if not os.path.exists(args.test_path):
        os.makedirs(args.test_path)
    t1 = time() - t
    if waveglow is not None:
        utils.waveglow_infer(
            mel_postnet_torch, waveglow,
            os.path.join(args.test_path,
                         '{}_{}_{}.wav'.format(prefix, hp.vocoder, idx)))
    t2 = time() - t
    print('{}: time FS: {} (s) time {}: {}'.format(idx, t1, hp.vocoder,
                                                   t2 - t1))
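One caveat about the timing here: CUDA kernels launch asynchronously, so on a GPU the `time()` readings can under-report the forward and vocoder stages; calling `torch.cuda.synchronize()` before each reading gives more faithful per-stage numbers.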
Code Example #7
File: evaluate.py Project: xushengyuan/FastSing2
def evaluate(model, step, vocoder=None):
    torch.manual_seed(0)
    
    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(dataset, batch_size=hp.batch_size*4, shuffle=False, collate_fn=dataset.collate_fn, drop_last=False, num_workers=0, )
    
    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    if hp.vocoder == 'WORLD':
        ap_l = []
        sp_l = []
        sp_p_l = []
    else:
        mel_l = []
        mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            id_ = data_of_batch["id"]
            condition = torch.from_numpy(data_of_batch["condition"]).long().to(device)
            mel_refer = torch.from_numpy(data_of_batch["mel_refer"]).float().to(device)
            if hp.vocoder == 'WORLD':
                ap_target = torch.from_numpy(data_of_batch["ap_target"]).float().to(device)
                sp_target = torch.from_numpy(data_of_batch["sp_target"]).float().to(device)
            else:
                mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).long().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)
        
            with torch.no_grad():
                # Forward
                if hp.vocoder == 'WORLD':
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model(
                        condition, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)

                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, D, f0_output, f0, energy_output, energy, ap_output=ap_output,
                        sp_output=sp_output, sp_postnet_output=sp_postnet_output, ap_target=ap_target,
                        sp_target=sp_target, src_mask=src_mask, ap_mask=ap_mask, sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, out_mel_len = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)

                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output, mel_target=mel_target, src_mask=~src_mask, mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l.append(ap_loss.item())
                    sp_l.append(sp_loss.item())
                    sp_p_l.append(sp_postnet_loss.item())
                else:
                    mel_l.append(mel_loss.item())
                    mel_p_l.append(mel_postnet_loss.item())
                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                
    
                # Run vocoding and plotting only when a vocoder is passed in;
                # the mel-based loop below does not cover the WORLD (ap/sp) outputs.
                if vocoder is not None and hp.vocoder != 'WORLD':
                    for k in range(len(mel_target)):
                        basename = id_[k]
                        gt_length = mel_len[k]
                        out_length = out_mel_len[k]

                        mel_target_torch = mel_target[k:k+1, :gt_length].transpose(1, 2).detach()
                        mel_target_ = mel_target[k, :gt_length].cpu().transpose(0, 1).detach()
                        mel_postnet_torch = mel_postnet_output[k:k+1, :out_length].transpose(1, 2).detach()
                        mel_postnet = mel_postnet_output[k, :out_length].cpu().transpose(0, 1).detach()

                        if hp.vocoder == 'melgan':
                            utils.melgan_infer(mel_target_torch, vocoder, os.path.join(hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                            utils.melgan_infer(mel_postnet_torch, vocoder, os.path.join(hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
                        elif hp.vocoder == 'waveglow':
                            utils.waveglow_infer(mel_target_torch, vocoder, os.path.join(hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                            utils.waveglow_infer(mel_postnet_torch, vocoder, os.path.join(hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
                        np.save(os.path.join(hp.eval_path, 'eval_{}_mel.npy'.format(basename)), mel_postnet.numpy())

                        f0_ = f0[k, :gt_length].detach().cpu().numpy()
                        energy_ = energy[k, :gt_length].detach().cpu().numpy()
                        f0_output_ = f0_output[k, :out_length].detach().cpu().numpy()
                        energy_output_ = energy_output[k, :out_length].detach().cpu().numpy()

                        utils.plot_data([(mel_postnet.numpy(), f0_output_, energy_output_), (mel_target_.numpy(), f0_, energy_)],
                            ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(basename)))
                        idx += 1
                
            current_step += 1            

    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)

    if hp.vocoder == 'WORLD':
        ap_l = sum(ap_l) / len(ap_l)
        sp_l = sum(sp_l) / len(sp_l)
        sp_p_l = sum(sp_p_l) / len(sp_p_l)
    else:
        mel_l = sum(mel_l) / len(mel_l)
        mel_p_l = sum(mel_p_l) / len(mel_p_l)
                    
    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    if hp.vocoder == 'WORLD':
        return d_l, f_l, e_l, ap_l, sp_l, sp_p_l
    return d_l, f_l, e_l, mel_l, mel_p_l
Code Example #8
def evaluate(model, step, vocoder=None):
    print('evaluating..')

    # Get dataset
    if hp.with_hanzi:
        dataset = Dataset(filename_py="val_pinyin.txt", vocab_file_py='vocab_pinyin.txt',
                          filename_hz="val_hanzi.txt",
                          vocab_file_hz='vocab_hanzi.txt')
        py_vocab_size = len(dataset.py_vocab)
        hz_vocab_size = len(dataset.hz_vocab)
    else:
        dataset = Dataset(filename_py="val_pinyin.txt", vocab_file_py='vocab_pinyin.txt',
                          filename_hz=None,
                          vocab_file_hz=None)
        py_vocab_size = len(dataset.py_vocab)
        hz_vocab_size = None

    
    loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=False,
                        collate_fn=dataset.collate_fn, drop_last=False, num_workers=0)



    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    bar = tqdm.tqdm_notebook(total=len(dataset)//hp.batch_size)

    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            bar.update(1)
            
            # Get Data
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            if hp.with_hanzi:
                hz_text = torch.from_numpy(data_of_batch["hz_text"]).long().to(device)
            else:
                hz_text = None
            
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).int().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                mel_output, mel_postnet_output, log_duration_output, src_mask, mel_mask, out_mel_len = model(
                    src_seq=text, src_len=src_len, hz_seq=hz_text, mel_len=mel_len,
                    d_target=D, max_src_len=max_src_len, max_mel_len=max_mel_len)
                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss = Loss(
                    log_duration_output, log_D, mel_output, mel_postnet_output, mel_target - hp.mel_mean, ~src_mask, ~mel_mask)

                d_l.append(d_loss.item())
               # f_l.append(f_loss.item())
               # e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if vocoder is not None:
                    # Run vocoding and plotting spectrogram only when the vocoder is defined
                    for k in range(len(mel_target)):
                        basename = id_[k]
                        gt_length = mel_len[k]
                        out_length = out_mel_len[k]

                        mel_target_torch = mel_target[k:k+1, :gt_length].transpose(1, 2).detach()
                        mel_target_ = mel_target[k, :gt_length].cpu().transpose(0, 1).detach()

                        mel_postnet_torch = mel_postnet_output[k:k+1, :out_length].transpose(1, 2).detach()
                        mel_postnet = mel_postnet_output[k, :out_length].cpu().transpose(0, 1).detach()

                        if hp.vocoder == 'melgan':
                            utils.melgan_infer(mel_target_torch, vocoder, os.path.join(
                                hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                            utils.melgan_infer(mel_postnet_torch+hp.mel_mean, vocoder, os.path.join(
                                hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
                        elif hp.vocoder == 'waveglow':
                            utils.waveglow_infer(mel_target_torch, vocoder, os.path.join(
                                hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                            utils.waveglow_infer(mel_postnet_torch+hp.mel_mean, vocoder, os.path.join(
                                hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))

                        utils.plot_data([mel_postnet.numpy() + hp.mel_mean, mel_target_.numpy()],
                                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                        filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(basename)))
                        idx += 1

            current_step += 1

    d_l = sum(d_l) / len(d_l)
   # f_l = sum(f_l) / len(f_l)
   # e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

 
    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    #str3 = "F0 Loss: {}".format(f_l)
    #  str4 = "Energy Loss: {}".format(e_l)
    str4 = "Mel Loss: {}".format(mel_l)
    str5 = "Mel Postnet Loss: {}".format(mel_p_l)
    str6 = "total Loss: {}".format(mel_p_l+mel_l+d_l)

    print("\n" + str1)
    print(str2)
    # print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write("\n")
    return d_l,  mel_l, mel_p_l
Code Example #9
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoaderX(dataset,
                         batch_size=hp.batch_size * 4,
                         shuffle=True,
                         collate_fn=dataset.collate_fn,
                         drop_last=True,
                         num_workers=8)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    train_logger = SummaryWriter(log_dir='log/train/' + current_time)
    val_logger = SummaryWriter(log_dir='log/validation/' + current_time)
    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()
    current_step0 = 0
    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * len(batchs) + j + args.restore_step + \
                    epoch * len(loader)*len(batchs) + 1
                # Get Data
                condition = torch.from_numpy(
                    data_of_batch["condition"]).long().to(device)
                mel_refer = torch.from_numpy(
                    data_of_batch["mel_refer"]).float().to(device)
                if hp.vocoder == 'WORLD':
                    ap_target = torch.from_numpy(
                        data_of_batch["ap_target"]).float().to(device)
                    sp_target = torch.from_numpy(
                        data_of_batch["sp_target"]).float().to(device)
                else:
                    mel_target = torch.from_numpy(
                        data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                if hp.vocoder == 'WORLD':
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model(
                        condition, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        ap_output=ap_output,
                        sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output,
                        ap_target=ap_target,
                        sp_target=sp_target,
                        src_mask=src_mask,
                        ap_mask=ap_mask,
                        sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        log_D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target,
                        src_mask=~src_mask,
                        mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l = ap_loss.item()
                    sp_l = sp_loss.item()
                    sp_p_l = sp_postnet_loss.item()
                else:
                    m_l = mel_loss.item()
                    m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                # (per-loss text-file logging was disabled here in favor of the
                # TensorBoard logging below)

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}],Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    if hp.vocoder == 'WORLD':
                        str2 = "Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l)
                    else:
                        str2 = "Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "T:{:.1f}s,ETA:{:.1f}s.".format(
                        (Now - Start) / (current_step - current_step0),
                        (total_step - current_step) * np.mean(Time))

                    print("" + str1 + str2 + str3 + '')


                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    if hp.vocoder == 'WORLD':
                        train_logger.add_scalar('Loss/ap_loss', ap_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_loss', sp_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_postnet_loss', sp_p_l,
                                                current_step)
                    else:
                        train_logger.add_scalar('Loss/mel_loss', m_l,
                                                current_step)
                        train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                                current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                if current_step % hp.save_step == 0 or current_step == 20:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0 or current_step == 5:
                    length = mel_len[0].item()

                    if hp.vocoder == 'WORLD':
                        ap_target_torch = ap_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        ap_torch = ap_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_target_torch = sp_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_torch = sp_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_postnet_torch = sp_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                    else:
                        mel_target_torch = mel_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_target = mel_target[
                            0, :length].detach().cpu().transpose(0, 1)
                        mel_torch = mel_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        mel = mel_output[0, :length].detach().cpu().transpose(
                            0, 1)
                        mel_postnet_torch = mel_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_postnet = mel_postnet_output[
                            0, :length].detach().cpu().transpose(0, 1)

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'WORLD':
                        wav = utils.world_infer(
                            np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_postnet_torch[0].cpu().numpy(), 0,
                                        1), f0_output)
                        sf.write(
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)
                        wav = utils.world_infer(
                            np.swapaxes(ap_target_torch[0].cpu().numpy(), 0,
                                        1),
                            np.swapaxes(sp_target_torch[0].cpu().numpy(), 0,
                                        1), f0)
                        sf.write(
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)

                    # The spectrogram/aperiodicity plots below only exist for the
                    # WORLD branch, so guard them to avoid a NameError with
                    # melgan/waveglow.
                    if hp.vocoder == 'WORLD':
                        utils.plot_data([
                            (sp_postnet_torch[0].cpu().numpy(), f0_output,
                             energy_output),
                            (sp_target_torch[0].cpu().numpy(), f0, energy)
                        ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                        filename=os.path.join(
                                            synth_path,
                                            'step_{}.png'.format(current_step)))

                        plt.matshow(sp_postnet_torch[0].cpu().numpy())
                        plt.savefig(
                            os.path.join(synth_path,
                                         'sp_postnet_{}.png'.format(current_step)))
                        plt.matshow(ap_torch[0].cpu().numpy())
                        plt.savefig(
                            os.path.join(synth_path,
                                         'ap_{}.png'.format(current_step)))
                        plt.cla()

                    with open(os.path.join(synth_path,
                                           'D_{}.txt'.format(current_step)),
                              'w') as fout:
                        fout.write(
                            str(log_duration_output[0].detach().cpu().numpy()) +
                            '\n')
                        fout.write(str(D[0].detach().cpu().numpy()) + '\n')
                        fout.write(
                            str(condition[0, :, 2].detach().cpu().numpy()) + '\n')


                # (a periodic evaluation pass on the validation set, logging
                #  losses to val_logger, was disabled here)
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the timing window to its running mean.
                    Time = np.array([np.mean(Time)])
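This training loop implements gradient accumulation: the loss is scaled by `1 / hp.acc_steps` and the optimizer only steps when `current_step % hp.acc_steps == 0`. Note that the `continue` on non-update steps also skips the logging, checkpointing, and timing code below it, which is worth keeping in mind when reading its logs.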
Code Example #10
File: evaluate.py Project: stallboy/FastSpeech2
def evaluate(model, step, vocoder=None):
    torch.manual_seed(0)

    if not os.path.exists(hp.eval_path):
        os.makedirs(hp.eval_path)

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(dataset, batch_size=hp.batch_size ** 2, shuffle=False, collate_fn=dataset.collate_fn,
                        drop_last=False, num_workers=0)

    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):

        for j, data_of_batch in enumerate(batchs):
            if j == 1:
                break
            # Get Data
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                # Forward
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, out_mel_len = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output,
                    mel_target, ~src_mask, ~mel_mask)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if vocoder is not None:
                    # Run vocoding and plotting spectrogram only when the vocoder is defined
                    for k in range(len(mel_target)):
                        basename = id_[k]
                        gt_phone_length = src_len[k]
                        gt_length = mel_len[k]
                        out_length = out_mel_len[k]

                        mel_target_torch = mel_target[k:k + 1, :gt_length].transpose(1, 2).detach()
                        mel_target_ = mel_target[k, :gt_length].cpu().transpose(0, 1).detach()

                        mel_postnet_torch = mel_postnet_output[k:k + 1, :out_length].transpose(1, 2).detach()
                        mel_postnet = mel_postnet_output[k, :out_length].cpu().transpose(0, 1).detach()

                        if hp.vocoder == 'melgan':
                            utils.melgan_infer(mel_target_torch, vocoder, os.path.join(hp.eval_path,
                                                                                       '{}_ground-truth_{}.wav'.format(
                                                                                           basename, hp.vocoder)))
                            utils.melgan_infer(mel_postnet_torch, vocoder, os.path.join(hp.eval_path,
                                                                                        '{}_eval_{}.wav'.format(
                                                                                            basename, hp.vocoder)))
                        elif hp.vocoder == 'waveglow':
                            utils.waveglow_infer(mel_target_torch, vocoder, os.path.join(hp.eval_path,
                                                                                         'ground-truth_{}_{}.wav'.format(
                                                                                             basename, hp.vocoder)))
                            utils.waveglow_infer(mel_postnet_torch, vocoder, os.path.join(hp.eval_path,
                                                                                          'eval_{}_{}.wav'.format(
                                                                                              basename, hp.vocoder)))

                        np.save(os.path.join(hp.eval_path, '{}_eval_mel.npy'.format(basename)), mel_postnet.numpy())

                        f0_ = f0[k, :gt_length].detach().cpu().numpy()
                        energy_ = energy[k, :gt_length].detach().cpu().numpy()
                        f0_output_ = f0_output[k, :out_length].detach().cpu().numpy()
                        energy_output_ = energy_output[k, :out_length].detach().cpu().numpy()

                        d_ = D[k, :gt_phone_length].detach().cpu().numpy()
                        log_d_output_ = log_duration_output[k, :gt_phone_length].detach().cpu().numpy()
                        d_output_ = np.exp(log_d_output_) - hp.log_offset

                        utils.plot_data(
                            [(mel_postnet.numpy(), f0_output_, energy_output_), (mel_target_.numpy(), f0_, energy_)],
                            ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                            filename=os.path.join(hp.eval_path, '{}_eval.png'.format(basename)),
                            )
                        utils.plot_duration(
                            [d_output_,  d_],
                            ["Synthesized Duration", "Ground-Truth"],
                            filename=os.path.join(hp.eval_path, '{}_eval_dur.png'.format(basename))
                        )
                        idx += 1

            current_step += 1

    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    return d_l, f_l, e_l, mel_l, mel_p_l
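This example also recovers linear-scale durations from the log-domain prediction via `np.exp(log_d_output_) - hp.log_offset`, the inverse of the `log(d + log_offset)` transform presumably applied when `log_D` was prepared during preprocessing.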
Code Example #11
File: train.py Project: zhangsong427/FastSpeech2
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Get dataset
    dataset = Dataset("train.txt") 
    loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=True, 
        collate_fn=dataset.collate_fn, drop_last=True, num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), betas=hp.betas, eps=hp.eps, weight_decay = hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden, hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device) 
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(os.path.join(
            checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()
    
    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i*hp.batch_size + j + args.restore_step + epoch*len(loader)*hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                text = torch.from_numpy(data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
                mel_pos = torch.from_numpy(data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(data_of_batch["src_pos"]).long().to(device)
                src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
                max_len = max(data_of_batch["mel_len"]).astype(np.int16)
                
                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D, f0, energy)
                
                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        duration_output, D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, src_len, mel_len)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss
                 
                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                losses = {"total_loss": t_l, "mel_loss": m_l,
                          "mel_postnet_loss": m_p_l, "duration_loss": d_l,
                          "f0_loss": f_l, "energy_loss": e_l}
                for name, value in losses.items():
                    with open(os.path.join(log_path, "{}.txt".format(name)), "a") as f_loss:
                        f_loss.write(str(value) + "\n")
                 
                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                
                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch+1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now-Start), (total_step-current_step)*np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    
                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    logger.add_scalars('Loss/total_loss', {'training': t_l}, current_step)
                    logger.add_scalars('Loss/mel_loss', {'training': m_l}, current_step)
                    logger.add_scalars('Loss/mel_postnet_loss', {'training': m_p_l}, current_step)
                    logger.add_scalars('Loss/duration_loss', {'training': d_l}, current_step)
                    logger.add_scalars('Loss/F0_loss', {'training': f_l}, current_step)
                    logger.add_scalars('Loss/energy_loss', {'training': e_l}, current_step)
                
                if current_step % hp.save_step == 0:
                    torch.save(
                        {'model': model.state_dict(),
                         'optimizer': optimizer.state_dict()},
                        os.path.join(checkpoint_path, 'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[0, :length].detach().cpu().transpose(0, 1)
                    Audio.tools.inv_mel_spec(mel, os.path.join(synth_path, "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(mel_postnet, os.path.join(synth_path, "step_{}_postnet_griffin_lim.wav".format(current_step)))
                    
                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(mel_torch, melgan, os.path.join(hp.test_path, 'step_{}_{}.wav'.format(current_step, hp.vocoder)))
                        utils.melgan_infer(mel_postnet_torch, melgan, os.path.join(hp.test_path, 'step_{}_postnet_{}.wav'.format(current_step, hp.vocoder)))
                        utils.melgan_infer(mel_target_torch, melgan, os.path.join(hp.test_path, 'step_{}_ground-truth_{}.wav'.format(current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(mel_torch, waveglow, os.path.join(hp.test_path, 'step_{}_{}.wav'.format(current_step, hp.vocoder)))
                        utils.waveglow_infer(mel_postnet_torch, waveglow, os.path.join(hp.test_path, 'step_{}_postnet_{}.wav'.format(current_step, hp.vocoder)))
                        utils.waveglow_infer(mel_target_torch, waveglow, os.path.join(hp.test_path, 'step_{}_ground-truth_{}.wav'.format(current_step, hp.vocoder)))
                    
                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[0, :length].detach().cpu().numpy()
                    
                    utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output), (mel_target.numpy(), f0, energy)],
                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(synth_path, 'step_{}.png'.format(current_step)))
                
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l
                        
                        logger.add_scalars('Loss/total_loss', {'validation': t_l}, current_step)
                        logger.add_scalars('Loss/mel_loss', {'validation': m_l}, current_step)
                        logger.add_scalars('Loss/mel_postnet_loss', {'validation': m_p_l}, current_step)
                        logger.add_scalars('Loss/duration_loss', {'validation': d_l}, current_step)
                        logger.add_scalars('Loss/F0_loss', {'validation': f_l}, current_step)
                        logger.add_scalars('Loss/energy_loss', {'validation': e_l}, current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the timing history to its mean to keep the array small
                    Time = np.array([np.mean(Time)])
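
A note on scheduled_optim.step_and_update_lr() in the loop above: FastSpeech2 implementations typically wrap the optimizer in a Noam-style warmup schedule. The sketch below is an assumption about what that wrapper does, not the repository's actual class; names such as d_model and n_warmup_steps are illustrative.

# Hedged sketch of the ScheduledOptim wrapper assumed by the loop above;
# the repository's own implementation may differ in details.
import numpy as np

class ScheduledOptim:
    def __init__(self, optimizer, d_model, n_warmup_steps, current_steps=0):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = current_steps
        self.init_lr = np.power(d_model, -0.5)

    def zero_grad(self):
        self._optimizer.zero_grad()

    def step_and_update_lr(self):
        # Noam schedule: linear warmup, then inverse-square-root decay
        self.n_current_steps += 1
        lr = self.init_lr * min(
            np.power(self.n_current_steps, -0.5),
            self.n_current_steps * np.power(self.n_warmup_steps, -1.5))
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr
        self._optimizer.step()
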
Code example #12
def synthesize(model, waveglow, melgan, text, sentence, prefix=''):
    sentence = sentence[:150]  # long filename will result in OS Error
    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)

    # create dir
    if not os.path.exists(os.path.join(hp.test_path, hp.dataset)):
        os.makedirs(os.path.join(hp.test_path, hp.dataset))

    # generate wav
    if hp.use_spk_embed:
        hp.batch_size = 3
        # select speakers
        # TODO
        spk_ids = torch.tensor(
            list(inv_spk_table.keys())[5:5 + hp.batch_size]).to(
                torch.int64).to(device)
        text = text.repeat(hp.batch_size, 1)
        src_len = src_len.repeat(hp.batch_size)
        mel, mel_postnet, log_duration_output, f0_output, energy_output, _, _, mel_len = model(
            text, src_len, speaker_ids=spk_ids)

        mel_mask = get_mask_from_lengths(mel_len, None)
        mel_mask = mel_mask.unsqueeze(-1).expand(mel_postnet.size())
        silence = (torch.ones(mel_postnet.size()) * -5).to(device)
        mel = torch.where(~mel_mask, mel, silence)
        mel_postnet = torch.where(~mel_mask, mel_postnet, silence)

        mel_torch = mel.transpose(1, 2).detach()
        mel_postnet_torch = mel_postnet.transpose(1, 2).detach()

        if waveglow is not None:
            wavs = utils.waveglow_infer_batch(mel_postnet_torch, waveglow)
        if melgan is not None:
            wavs = utils.melgan_infer_batch(mel_postnet_torch, melgan)

        for i, spk_id in enumerate(spk_ids):
            spker = inv_spk_table[int(spk_id)]
            mel_postnet_i = mel_postnet[i].cpu().transpose(0, 1).detach()
            f0_i = f0_output[i].detach().cpu().numpy()
            energy_i = energy_output[i].detach().cpu().numpy()
            mel_mask_i = mel_mask[i]
            wav_i = wavs[i]

            # output
            base_dir_i = os.path.join(hp.test_path, hp.dataset,
                                      "step {}".format(args.step), spker)
            os.makedirs(base_dir_i, exist_ok=True)
            path_i = os.path.join(
                base_dir_i, '{}_{}_{}.wav'.format(prefix, hp.vocoder,
                                                  sentence))
            soundfile.write(path_i, wav_i, hp.sampling_rate)
            utils.plot_data([(mel_postnet_i.numpy(), f0_i, energy_i)],
                            ['Synthesized Spectrogram'],
                            filename=os.path.join(
                                base_dir_i,
                                '{}_{}.png'.format(prefix, sentence)))

    else:
        spk_ids = None
        mel, mel_postnet, log_duration_output, f0_output, energy_output, _, _, mel_len = model(
            text, src_len, speaker_ids=spk_ids)
        mel_torch = mel.transpose(1, 2).detach()
        mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
        mel = mel[0].cpu().transpose(0, 1).detach()
        mel_postnet = mel_postnet[0].cpu().transpose(0, 1).detach()
        f0_output = f0_output[0].detach().cpu().numpy()
        energy_output = energy_output[0].detach().cpu().numpy()

        Audio.tools.inv_mel_spec(
            mel_postnet,
            os.path.join(hp.test_path,
                         '{}_griffin_lim_{}.wav'.format(prefix, sentence)))
        # spker is only defined in the multi-speaker branch above, so the
        # single-speaker filenames omit the speaker tag
        if waveglow is not None:
            utils.waveglow_infer(
                mel_postnet_torch, waveglow,
                os.path.join(
                    hp.test_path, hp.dataset,
                    '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))
        if melgan is not None:
            utils.melgan_infer(
                mel_postnet_torch, melgan,
                os.path.join(
                    hp.test_path, hp.dataset,
                    '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))

        utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output)],
                        ['Synthesized Spectrogram'],
                        filename=os.path.join(
                            hp.test_path, '{}_{}.png'.format(prefix,
                                                             sentence)))
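
Code example #12 relies on a get_mask_from_lengths helper to blank out padded frames. Its body is not shown in the snippet; the following is a plausible minimal version (an assumption based on how the mask is used above), returning True on padding positions:

import torch

def get_mask_from_lengths(lengths, max_len=None):
    # lengths: LongTensor of shape (batch,); returns a (batch, max_len)
    # boolean mask that is True wherever a frame is padding
    if max_len is None:
        max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0)
    return ids >= lengths.unsqueeze(1)
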
Code example #13
                if hp.vocoder == 'melgan':
                    utils.melgan_infer(
                        mel_postnet_torch, melgan,
                        os.path.join(
                            hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                current_step, hp.vocoder)))
                    utils.melgan_infer(
                        mel_target_torch, melgan,
                        os.path.join(
                            hp.synth_path,
                            'step_{}_ground-truth_{}.wav'.format(
                                current_step, hp.vocoder)))
                elif hp.vocoder == 'waveglow':
                    # utils.waveglow_infer(mel_torch, waveglow, os.path.join(
                    # hp.synth_path, 'step_{}_{}.wav'.format(current_step, hp.vocoder)))
                    utils.waveglow_infer(
                        mel_postnet_torch + hp.mel_mean, waveglow,
                        os.path.join(
                            hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                current_step, hp.vocoder)))
                    utils.waveglow_infer(
                        mel_target_torch, waveglow,
                        os.path.join(
                            hp.synth_path,
                            'step_{}_ground-truth_{}.wav'.format(
                                current_step, hp.vocoder)))

                utils.plot_data(
                    [mel_postnet.numpy() + hp.mel_mean,
                     mel_target.numpy()],
                    ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                    filename=os.path.join(synth_path,
                                          'step_{}.png'.format(current_step)))
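
The "+ hp.mel_mean" terms in example #13 suggest this variant trains on mean-subtracted mel spectrograms, so the mean has to be restored before vocoding or plotting. A hypothetical pair of helpers illustrating that convention (hp.mel_mean is assumed to broadcast over the mel, e.g. a scalar or per-bin array):

# Hypothetical helpers; the snippet above applies the shift inline instead.
def normalize_mel(mel, mel_mean):
    # applied at preprocessing time, so the model predicts zero-mean mels
    return mel - mel_mean

def denormalize_mel(mel, mel_mean):
    # applied before the vocoder or plotting, as in example #13
    return mel + mel_mean
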
Code example #14
def evaluate(model, step, vocoder=None):
    torch.manual_seed(0)

    # Get dataset
    print("Load data to buffer")
    buffer = get_data_to_buffer('val.txt')
    dataset = BufferDataset(buffer)

    # Get validation loader
    validating_loader = DataLoader(dataset,
                                   batch_size=hp.batch_expand_size *
                                   hp.batch_size,
                                   shuffle=True,
                                   collate_fn=collate_fn_tensor,
                                   drop_last=False,
                                   num_workers=0)

    # Get Loss
    fastspeech_loss = DNNLoss().to(device)

    t_l = []
    d_l = []
    mel_l = []
    mel_p_l = []
    idx = 0
    current_step = 0
    print("Number of validation batches:", len(validating_loader))
    for i, batchs in enumerate(validating_loader):
        # real batch start here
        for j, db in enumerate(batchs):
            # Get ids and lengths
            id_ = db["name"]
            src_len = torch.from_numpy(db["src_len"]).long().to(device)
            mel_len = torch.from_numpy(db["mel_len"]).long().to(device)

            # Get Data
            character = db["text"].long().to(device)
            mel_target = db["mel_target"].float().to(device)
            duration = db["duration"].int().to(device)
            mel_pos = db["mel_pos"].long().to(device)
            src_pos = db["src_pos"].long().to(device)
            max_mel_len = db["mel_max_len"]
            # Forward
            mel_output, mel_postnet_output, duration_predictor_output = model(
                character,
                src_pos,
                mel_pos=mel_pos,
                mel_max_length=max_mel_len,
                length_target=duration)

            # Cal Loss
            mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                mel_output, mel_postnet_output, duration_predictor_output,
                mel_target, duration)
            total_loss = mel_loss + mel_postnet_loss + duration_loss

            t_l.append(total_loss.item())
            d_l.append(duration_loss.item())
            mel_l.append(mel_loss.item())
            mel_p_l.append(mel_postnet_loss.item())

            if vocoder is not None:
                # Run vocoding and plotting spectrogram only when the vocoder is defined
                for k in range(len(mel_target)):
                    basename = id_[k]
                    gt_length = mel_len[k]
                    # out_length = out_mel_len[k]
                    mel_target_torch = mel_target[
                        k:k + 1, :gt_length].transpose(1, 2).detach()
                    # mel_target_ = mel_target[k, :gt_length].cpu(
                    # ).transpose(0, 1).detach()

                    mel_postnet_torch = mel_postnet_output[
                        k:k + 1, :].transpose(1, 2).detach()
                    mel_postnet = mel_postnet_output[
                        k, :].cpu().transpose(0, 1).detach()

                    if hp.vocoder == 'waveglow':
                        utils.waveglow_infer(
                            mel_target_torch, vocoder,
                            os.path.join(
                                hp.eval_path, 'ground-truth_{}_{}.wav'.format(
                                    basename, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_postnet_torch, vocoder,
                            os.path.join(
                                hp.eval_path,
                                'eval_{}_{}.wav'.format(basename, hp.vocoder)))

                    np.save(
                        os.path.join(hp.eval_path,
                                     'eval_{}_mel.npy'.format(basename)),
                        mel_postnet.numpy())
                    idx += 1

            current_step += 1

    t_l = sum(t_l) / len(t_l)
    d_l = sum(d_l) / len(d_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Total Loss {},".format(t_l)
    str3 = "Duration Loss: {}".format(d_l)
    str4 = "Mel Loss: {}".format(mel_l)
    str5 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write("\n")

    return t_l, d_l, mel_l, mel_p_l
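
evaluate in example #14 assumes a DNNLoss with the FastSpeech-style three-term objective (mel, postnet mel, duration). A minimal sketch under that assumption; the actual class may mask padding or use different loss types:

import torch.nn as nn

class DNNLoss(nn.Module):
    # Hedged sketch: MSE on both mel outputs plus an L1 duration term.
    def __init__(self):
        super(DNNLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

    def forward(self, mel, mel_postnet, duration_predicted,
                mel_target, duration_target):
        mel_target = mel_target.detach()  # no gradient through targets
        duration_target = duration_target.detach().float()
        mel_loss = self.mse_loss(mel, mel_target)
        mel_postnet_loss = self.mse_loss(mel_postnet, mel_target)
        duration_loss = self.l1_loss(duration_predicted, duration_target)
        return mel_loss, mel_postnet_loss, duration_loss
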
Code example #15
def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("test.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=False,
                        collate_fn=dataset.collate_fn,
                        drop_last=False,
                        num_workers=0)

    model = get_FastSpeech2(args.step, full_path=args.model_fs).to(device)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan(full_path=args.model_melgan)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()

    # Init logger
    log_path = hp.log_path
    # exist_ok also covers the case where log_path exists but 'test' does not
    os.makedirs(os.path.join(log_path, 'test'), exist_ok=True)
    test_logger = SummaryWriter(os.path.join(log_path, 'test'))

    # Init synthesis directory
    test_path = hp.test_path
    if not os.path.exists(test_path):
        os.makedirs(test_path)

    current_step = args.step
    findex = open(os.path.join(test_path, "index.tsv"), "w")

    # Testing
    print("Generate test audio")
    prefix = ""
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            print("Start batch", j)
            fids = data_of_batch["id"]

            # Get Data
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).long().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                text, src_len)

            for k in range(len(mel_len)):  # 'k' avoids shadowing the outer batch index 'i'
                fid = fids[k]
                print("Generate audio for:", fid)
                length = mel_len[k].item()

                _mel_target_torch = mel_target[k, :length].detach().unsqueeze(
                    0).transpose(1, 2)

                _mel_target = mel_target[k, :length].detach().cpu().transpose(
                    0, 1)

                _mel_torch = mel_output[k, :length].detach().unsqueeze(
                    0).transpose(1, 2)

                _mel = mel_output[k, :length].detach().cpu().transpose(0, 1)

                _mel_postnet_torch = mel_postnet_output[
                    k, :length].detach().unsqueeze(0).transpose(1, 2)

                _mel_postnet = mel_postnet_output[
                    k, :length].detach().cpu().transpose(0, 1)

                fname = "{}{}_step_{}_gt_griffin_lim.wav".format(
                    prefix, fid, current_step)
                Audio.tools.inv_mel_spec(_mel_target,
                                         os.path.join(hp.test_path, fname))
                _write_index_line(findex, "Griffin Lim", "vocoder", fname, "",
                                  fid)

                fname = "{}{}_step_{}_griffin_lim.wav".format(
                    prefix, fid, current_step)
                Audio.tools.inv_mel_spec(_mel,
                                         os.path.join(hp.test_path, fname))
                _write_index_line(findex, "FastSpeech2 + GL", "tts", fname, "",
                                  fid)

                fname = "{}{}_step_{}_postnet_griffin_lim.wav".format(
                    prefix, fid, current_step)
                Audio.tools.inv_mel_spec(_mel_postnet,
                                         os.path.join(hp.test_path, fname))
                _write_index_line(findex, "FastSpeech2 + PN + GL", "tts",
                                  fname, "", fid)

                if hp.vocoder == 'melgan':
                    fname = '{}{}_step_{}_ground-truth_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.melgan_infer(_mel_target_torch, melgan,
                                       os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "Melgan", "vocoder", fname, "",
                                      fid)

                    fname = '{}{}_step_{}_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.melgan_infer(_mel_torch, melgan,
                                       os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "FastSpeech2 + Melgan", "tts",
                                      fname, "", fid)

                    fname = '{}{}_step_{}_postnet_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.melgan_infer(_mel_postnet_torch, melgan,
                                       os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "FastSpeech2 + PN + Melgan",
                                      "tts", fname, "", fid)

                elif hp.vocoder == 'waveglow':
                    utils.waveglow_infer(
                        _mel_torch, waveglow,
                        os.path.join(
                            hp.test_path,
                            'step_{}_{}.wav'.format(current_step, hp.vocoder)))
                    utils.waveglow_infer(
                        _mel_postnet_torch, waveglow,
                        os.path.join(
                            hp.test_path, 'step_{}_postnet_{}.wav'.format(
                                current_step, hp.vocoder)))
                    utils.waveglow_infer(
                        _mel_target_torch, waveglow,
                        os.path.join(
                            hp.test_path, 'step_{}_ground-truth_{}.wav'.format(
                                current_step, hp.vocoder)))
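
_write_index_line is called throughout example #15 but not defined in the snippet. From the call sites, _write_index_line(findex, system, kind, fname, text, fid), a plausible TSV writer looks like this (column order is an assumption):

def _write_index_line(f, system, kind, fname, text, fid):
    # one TSV row per generated file: utterance id, entry kind ("tts" or
    # "vocoder"), system label, output filename, optional transcript text
    f.write("\t".join([fid, kind, system, fname, text]) + "\n")
    f.flush()
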