Example #1
def evaluate(model, step, vocoder=None):

    print('evaluating..')

    # Get dataset
    if hp.with_hanzi:
        dataset = Dataset(filename_py="val_pinyin.txt", vocab_file_py='vocab_pinyin.txt',
                          filename_hz="val_hanzi.txt",
                          vocab_file_hz='vocab_hanzi.txt')
        py_vocab_size = len(dataset.py_vocab)
        hz_vocab_size = len(dataset.hz_vocab)


    else:
        dataset = Dataset(filename_py="val_pinyin.txt", vocab_file_py='vocab_pinyin.txt',
                          filename_hz=None,
                          vocab_file_hz=None)
        py_vocab_size = len(dataset.py_vocab)
        hz_vocab_size = None

    
    loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=False,
                        collate_fn=dataset.collate_fn, drop_last=False, num_workers=0, )



    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    bar = tqdm.tqdm_notebook(total=len(dataset)//hp.batch_size)

    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            bar.update(1)
            
            # Get Data
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            if hp.with_hanzi:
                hz_text = torch.from_numpy(
                    data_of_batch["hz_text"]).long().to(device)
            else:
                hz_text = None
                
            
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).int().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                mel_output, mel_postnet_output, log_duration_output, src_mask, mel_mask, out_mel_len = model(
                    src_seq=text, src_len=src_len, hz_seq=hz_text, mel_len=mel_len,
                    d_target=D, max_src_len=max_src_len, max_mel_len=max_mel_len)
                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss = Loss(
                    log_duration_output, log_D, mel_output, mel_postnet_output, mel_target-hp.mel_mean, ~src_mask, ~mel_mask)

                d_l.append(d_loss.item())
               # f_l.append(f_loss.item())
               # e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if vocoder is not None:
                    # Run vocoding and plotting spectrogram only when the vocoder is defined
                    for k in range(len(mel_target)):
                        basename = id_[k]
                        gt_length = mel_len[k]
                        out_length = out_mel_len[k]

                        mel_target_torch = mel_target[k:k+1, :gt_length].transpose(1, 2).detach()
                        mel_target_ = mel_target[k, :gt_length].cpu().transpose(0, 1).detach()

                        mel_postnet_torch = mel_postnet_output[k:k+1, :out_length].transpose(1, 2).detach()
                        mel_postnet = mel_postnet_output[k, :out_length].cpu().transpose(0, 1).detach()

                        if hp.vocoder == 'melgan':
                            utils.melgan_infer(mel_target_torch, vocoder, os.path.join(
                                hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                            utils.melgan_infer(mel_postnet_torch+hp.mel_mean, vocoder, os.path.join(
                                hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
                        elif hp.vocoder == 'waveglow':
                            utils.waveglow_infer(mel_target_torch, vocoder, os.path.join(
                                hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                            utils.waveglow_infer(mel_postnet_torch+hp.mel_mean, vocoder, os.path.join(
                                hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))

                       # np.save(os.path.join(hp.eval_path, 'eval_{}_mel.npy'.format(
                           # basename)), mel_postnet.numpy()+hp.mel_mean)

#                         f0_ = f0[k, :gt_length].detach().cpu().numpy()
#                         energy_ = energy[k, :gt_length].detach().cpu().numpy()
#                         f0_output_ = f0_output[k,
#                                                :out_length].detach().cpu().numpy()
#                         energy_output_ = energy_output[k, :out_length].detach(
#                         ).cpu().numpy()

                        utils.plot_data([mel_postnet.numpy()+hp.mel_mean,mel_target_.numpy()],
                                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(basename)))
                        idx += 1

            current_step += 1

    d_l = sum(d_l) / len(d_l)
   # f_l = sum(f_l) / len(f_l)
   # e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

 
    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    #str3 = "F0 Loss: {}".format(f_l)
    #  str4 = "Energy Loss: {}".format(e_l)
    str4 = "Mel Loss: {}".format(mel_l)
    str5 = "Mel Postnet Loss: {}".format(mel_p_l)
    str6 = "total Loss: {}".format(mel_p_l+mel_l+d_l)

    print("\n" + str1)
    print(str2)
    # print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
       # f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")
    return d_l,  mel_l, mel_p_l
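A minimal usage sketch (not part of the original listing) of how this evaluate() could be driven from a training loop, mirroring the evaluation block in Example #4 further below. The names model, current_step, vocoder and val_logger are assumed to exist as in the surrounding examples, and run_validation is a hypothetical helper name.

def run_validation(model, current_step, vocoder, val_logger):
    # Hypothetical helper: switch to eval mode, call the evaluate() above,
    # log the returned losses to TensorBoard, then resume training mode.
    model.eval()
    with torch.no_grad():
        d_l, mel_l, mel_p_l = evaluate(model, current_step, vocoder=vocoder)
    val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
    val_logger.add_scalar('valLoss/mel_loss', mel_l, current_step)
    val_logger.add_scalar('valLoss/mel_postnet_loss', mel_p_l, current_step)
    val_logger.add_scalar('valLoss/total_loss', d_l + mel_l + mel_p_l, current_step)
    model.train()
    return d_l + mel_l + mel_p_l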
Example #2
def evaluate(model, step, vocoder=None):
    torch.manual_seed(0)
    
    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(dataset, batch_size=hp.batch_size*4, shuffle=False, collate_fn=dataset.collate_fn, drop_last=False, num_workers=0, )
    
    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    if hp.vocoder=='WORLD':
        ap_l = []
        sp_l = []
        sp_p_l = []
    else:
        mel_l = []
        mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            id_ = data_of_batch["id"]
            condition = torch.from_numpy(data_of_batch["condition"]).long().to(device)
            mel_refer = torch.from_numpy(data_of_batch["mel_refer"]).float().to(device)
            if hp.vocoder=='WORLD':
                ap_target = torch.from_numpy(data_of_batch["ap_target"]).float().to(device)
                sp_target = torch.from_numpy(data_of_batch["sp_target"]).float().to(device)
            else:
                mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).long().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
            #print(D,log_D)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)
        
            with torch.no_grad():
                # Forward
                if hp.vocoder=='WORLD':
#                     print(condition.shape,mel_refer.shape, src_len.shape, mel_len.shape, D.shape, f0.shape, energy.shape, max_src_len.shape, max_mel_len.shape)
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output,energy_output, src_mask, ap_mask,sp_mask ,variance_adaptor_output,decoder_output= model(
                    condition, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)
                
                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, D, f0_output, f0, energy_output, energy, ap_output=ap_output, 
                        sp_output=sp_output, sp_postnet_output=sp_postnet_output, ap_target=ap_target, 
                        sp_target=sp_target,src_mask=src_mask, ap_mask=ap_mask,sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output,energy_output, src_mask, mel_mask, _ = model(
                    condition,mel_refer, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)
                    
                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output, mel_target=mel_target, src_mask=~src_mask, mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss
                     
                t_l = total_loss.item()
                if hp.vocoder=='WORLD':
                    ap_l.append(ap_loss.item())
                    sp_l.append(sp_loss.item())
                    sp_p_l.append(sp_postnet_loss.item())
                else:
                    mel_l.append(mel_loss.item())
                    mel_p_l.append(mel_postnet_loss.item())
                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                
    
                # Run vocoding and plotting spectrogram only when the vocoder is defined
                for k in range(len(mel_target)):
                    basename = id_[k]
                    gt_length = mel_len[k]
                    out_length = mel_len[k]

                    mel_target_torch = mel_target[k:k+1, :gt_length].transpose(1, 2).detach()
                    mel_target_ = mel_target[k, :gt_length].cpu().transpose(0, 1).detach()
                    mel_postnet_torch = mel_postnet_output[k:k+1, :out_length].transpose(1, 2).detach()
                    mel_postnet = mel_postnet_output[k, :out_length].cpu().transpose(0, 1).detach()

                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(mel_target_torch, vocoder, os.path.join(hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                        utils.melgan_infer(mel_postnet_torch, vocoder, os.path.join(hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(mel_target_torch, vocoder, os.path.join(hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                        utils.waveglow_infer(mel_postnet_torch, vocoder, os.path.join(hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
                    elif hp.vocoder=='WORLD':
                        utils.world_infer(mel_postnet_torch.numpy(),f0_output, os.path.join(hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
                        utils.world_infer(mel_target_torch.numpy(),f0,  os.path.join(hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                    np.save(os.path.join(hp.eval_path, 'eval_{}_mel.npy'.format(basename)), mel_postnet.numpy())

                    f0_ = f0[k, :gt_length].detach().cpu().numpy()
                    energy_ = energy[k, :gt_length].detach().cpu().numpy()
                    f0_output_ = f0_output[k, :out_length].detach().cpu().numpy()
                    energy_output_ = energy_output[k, :out_length].detach().cpu().numpy()

                    utils.plot_data([(mel_postnet.numpy(), f0_output_, energy_output_), (mel_target_.numpy(), f0_, energy_)], 
                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(basename)))
                    idx += 1
                
            current_step += 1            

    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    
    if hp.vocoder=='WORLD':
        ap_l = sum(ap_l) / len(ap_l)
        sp_l = sum(sp_l) / len(sp_l)
        sp_p_l = sum(sp_p_l) / len(sp_p_l) 
    else:
        mel_l = sum(mel_l) / len(mel_l)
        mel_p_l = sum(mel_p_l) / len(mel_p_l) 
                    
    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    return d_l, f_l, e_l, mel_l, mel_p_l
Example #3
def main(args):
    torch.manual_seed(0)

    # Get device
    # device = torch.device('cuda'if torch.cuda.is_available()else 'cpu')
    device = 'cuda'

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoaderX(dataset,
                         batch_size=hp.batch_size * 4,
                         shuffle=True,
                         collate_fn=dataset.collate_fn,
                         drop_last=True,
                         num_workers=8)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    #     model = FastSpeech2().to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    train_logger = SummaryWriter(log_dir='log/train/' + current_time)
    val_logger = SummaryWriter(log_dir='log/validation/' + current_time)
    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()
    current_step0 = 0
    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * len(batchs) + j + args.restore_step + \
                    epoch * len(loader)*len(batchs) + 1
                # Get Data
                condition = torch.from_numpy(
                    data_of_batch["condition"]).long().to(device)
                mel_refer = torch.from_numpy(
                    data_of_batch["mel_refer"]).float().to(device)
                if hp.vocoder == 'WORLD':
                    ap_target = torch.from_numpy(
                        data_of_batch["ap_target"]).float().to(device)
                    sp_target = torch.from_numpy(
                        data_of_batch["sp_target"]).float().to(device)
                else:
                    mel_target = torch.from_numpy(
                        data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                #print(D,log_D)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                if hp.vocoder == 'WORLD':
                    #                     print(condition.shape,mel_refer.shape, src_len.shape, mel_len.shape, D.shape, f0.shape, energy.shape, max_src_len.shape, max_mel_len.shape)
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model(
                        condition, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        ap_output=ap_output,
                        sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output,
                        ap_target=ap_target,
                        sp_target=sp_target,
                        src_mask=src_mask,
                        ap_mask=ap_mask,
                        sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        log_D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target,
                        src_mask=~src_mask,
                        mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l = ap_loss.item()
                    sp_l = sp_loss.item()
                    sp_p_l = sp_postnet_loss.item()
                else:
                    m_l = mel_loss.item()
                    m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                #                 with open(os.path.join(log_path, "total_loss.txt"), "a") as f_total_loss:
                #                     f_total_loss.write(str(t_l)+"\n")
                #                 with open(os.path.join(log_path, "mel_loss.txt"), "a") as f_mel_loss:
                #                     f_mel_loss.write(str(m_l)+"\n")
                #                 with open(os.path.join(log_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                #                     f_mel_postnet_loss.write(str(m_p_l)+"\n")
                #                 with open(os.path.join(log_path, "duration_loss.txt"), "a") as f_d_loss:
                #                     f_d_loss.write(str(d_l)+"\n")
                #                 with open(os.path.join(log_path, "f0_loss.txt"), "a") as f_f_loss:
                #                     f_f_loss.write(str(f_l)+"\n")
                #                 with open(os.path.join(log_path, "energy_loss.txt"), "a") as f_e_loss:
                #                     f_e_loss.write(str(e_l)+"\n")

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}],Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    if hp.vocoder == 'WORLD':
                        str2 = "Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l)
                    else:
                        str2 = "Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "T:{:.1f}s,ETA:{:.1f}s.".format(
                        (Now - Start) / (current_step - current_step0),
                        (total_step - current_step) * np.mean(Time))

                    print("" + str1 + str2 + str3 + '')

                    #                     with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                    #                         f_log.write(str1 + "\n")
                    #                         f_log.write(str2 + "\n")
                    #                         f_log.write(str3 + "\n")
                    #                         f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    if hp.vocoder == 'WORLD':
                        train_logger.add_scalar('Loss/ap_loss', ap_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_loss', sp_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_postnet_loss', sp_p_l,
                                                current_step)
                    else:
                        train_logger.add_scalar('Loss/mel_loss', m_l,
                                                current_step)
                        train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                                current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                if current_step % hp.save_step == 0 or current_step == 20:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0 or current_step == 5:
                    length = mel_len[0].item()

                    if hp.vocoder == 'WORLD':
                        ap_target_torch = ap_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        ap_torch = ap_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_target_torch = sp_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_torch = sp_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_postnet_torch = sp_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                    else:
                        mel_target_torch = mel_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_target = mel_target[
                            0, :length].detach().cpu().transpose(0, 1)
                        mel_torch = mel_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        mel = mel_output[0, :length].detach().cpu().transpose(
                            0, 1)
                        mel_postnet_torch = mel_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_postnet = mel_postnet_output[
                            0, :length].detach().cpu().transpose(0, 1)
#                     Audio.tools.inv_mel_spec(mel, os.path.join(synth_path, "step_{}_griffin_lim.wav".format(current_step)))
#                     Audio.tools.inv_mel_spec(mel_postnet, os.path.join(synth_path, "step_{}_postnet_griffin_lim.wav".format(current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'WORLD':
                        #     ap=np.swapaxes(ap,0,1)
                        #     sp=np.swapaxes(sp,0,1)
                        wav = utils.world_infer(
                            np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_postnet_torch[0].cpu().numpy(), 0,
                                        1), f0_output)
                        sf.write(
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)
                        wav = utils.world_infer(
                            np.swapaxes(ap_target_torch[0].cpu().numpy(), 0,
                                        1),
                            np.swapaxes(sp_target_torch[0].cpu().numpy(), 0,
                                        1), f0)
                        sf.write(
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)

                    utils.plot_data([
                        (sp_postnet_torch[0].cpu().numpy(), f0_output,
                         energy_output),
                        (sp_target_torch[0].cpu().numpy(), f0, energy)
                    ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                    plt.matshow(sp_postnet_torch[0].cpu().numpy())
                    plt.savefig(
                        os.path.join(synth_path,
                                     'sp_postnet_{}.png'.format(current_step)))
                    plt.matshow(ap_torch[0].cpu().numpy())
                    plt.savefig(
                        os.path.join(synth_path,
                                     'ap_{}.png'.format(current_step)))
                    plt.matshow(
                        variance_adaptor_output[0].detach().cpu().numpy())
                    #                     plt.savefig(os.path.join(synth_path, 'va_{}.png'.format(current_step)))
                    #                     plt.matshow(decoder_output[0].detach().cpu().numpy())
                    #                     plt.savefig(os.path.join(synth_path, 'encoder_{}.png'.format(current_step)))

                    plt.cla()
                    fout = open(
                        os.path.join(synth_path,
                                     'D_{}.txt'.format(current_step)), 'w')
                    fout.write(
                        str(log_duration_output[0].detach().cpu().numpy()) +
                        '\n')
                    fout.write(str(D[0].detach().cpu().numpy()) + '\n')
                    fout.write(
                        str(condition[0, :, 2].detach().cpu().numpy()) + '\n')
                    fout.close()


#                 if current_step % hp.eval_step == 0 or current_step==20:
#                     model.eval()
#                     with torch.no_grad():

#                         if hp.vocoder=='WORLD':
#                             d_l, f_l, e_l, ap_l, sp_l, sp_p_l = evaluate(model, current_step)
#                             t_l = d_l + f_l + e_l + ap_l + sp_l + sp_p_l

#                             val_logger.add_scalar('valLoss/total_loss', t_l, current_step)
#                             val_logger.add_scalar('valLoss/ap_loss', ap_l, current_step)
#                             val_logger.add_scalar('valLoss/sp_loss', sp_l, current_step)
#                             val_logger.add_scalar('valLoss/sp_postnet_loss', sp_p_l, current_step)
#                             val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
#                             val_logger.add_scalar('valLoss/F0_loss', f_l, current_step)
#                             val_logger.add_scalar('valLoss/energy_loss', e_l, current_step)
#                         else:
#                             d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step)
#                             t_l = d_l + f_l + e_l + m_l + m_p_l

#                             val_logger.add_scalar('valLoss/total_loss', t_l, current_step)
#                             val_logger.add_scalar('valLoss/mel_loss', m_l, current_step)
#                             val_logger.add_scalar('valLoss/mel_postnet_loss', m_p_l, current_step)
#                             val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
#                             val_logger.add_scalar('valLoss/F0_loss', f_l, current_step)
#                             val_logger.add_scalar('valLoss/energy_loss', e_l, current_step)

#                     model.train()
#                 if current_step%10==0:
#                     print(energy_output[0],energy[0])
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
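Examples #3 and #4 both rely on gradient accumulation: each batch's loss is divided by hp.acc_steps, backward() is called on every batch, and the optimizer only steps (and gradients are only cleared) once every acc_steps batches. Below is a stripped-down sketch of that pattern with placeholder names (loss_fn, optimizer, loader, clip), not the exact classes used in these examples.

import torch
import torch.nn as nn

def train_with_accumulation(model, loader, loss_fn, optimizer, acc_steps=4, clip=1.0):
    # Illustrative gradient-accumulation loop; acc_steps plays the role of
    # hp.acc_steps and clip the role of hp.grad_clip_thresh above.
    optimizer.zero_grad()
    for step, batch in enumerate(loader, start=1):
        loss = loss_fn(model(batch)) / acc_steps   # scale so the summed gradients average out
        loss.backward()                            # gradients accumulate across sub-batches
        if step % acc_steps != 0:
            continue                               # same early-continue trick as in the examples
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        optimizer.zero_grad()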
Example #4
def main(args):
    torch.manual_seed(0)
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    if hp.use_spk_embed:
        n_spkers = len(dataset.spk_table.keys())
        model = nn.DataParallel(FastSpeech2(n_spkers=n_spkers)).to(device)
    else:
        model = nn.DataParallel(FastSpeech2()).to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        #melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Init evaluation directory
    eval_path = hp.eval_path
    if not os.path.exists(eval_path):
        os.makedirs(eval_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1
                print("step : {}".format(current_step), end='\r', flush=True)

                ### Get Data ###
                if hp.use_spk_embed:
                    spk_ids = torch.tensor(data_of_batch["spk_ids"]).to(
                        torch.int64).to(device)
                else:
                    spk_ids = None
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                ### Forward ###
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len, spk_ids)

                ### Cal Loss ###
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss + 0.01 * f_loss + 0.1 * e_loss

                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                ## Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                ### Update weights ###
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                ### Print ###
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                ### Save model ###
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                ### Synth ###
                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    print("step: {} , length {}, {}".format(
                        current_step, length, mel_len))
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)

                    spk_id = dataset.inv_spk_table[int(spk_ids[0])]
                    if hp.vocoder == 'melgan':
                        vocoder.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))

                    elif hp.vocoder == 'waveglow':
                        vocoder.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output, energy_output),
                        (mel_target.numpy(), f0, energy)
                    ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                ### Evaluation ###
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)

                    model.train()

                ### Time ###
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
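ScheduledOptim itself is not shown in this listing. In public FastSpeech2 implementations it is typically a thin wrapper applying the Transformer "Noam" warmup schedule, with the learning rate proportional to d_model^-0.5 * min(step^-0.5, step * warmup^-1.5). The class below is only a hedged sketch of that assumed behaviour, not the exact ScheduledOptim(optimizer, hp.decoder_hidden, hp.n_warm_up_step, args.restore_step) used above.

import numpy as np

class NoamScheduledOptim:
    # Hedged sketch of a warmup LR wrapper; the real ScheduledOptim may differ.
    def __init__(self, optimizer, d_model, n_warmup_steps, current_steps=0):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = current_steps
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        self.n_current_steps += 1
        scale = min(np.power(self.n_current_steps, -0.5),
                    np.power(self.n_warmup_steps, -1.5) * self.n_current_steps)
        for group in self._optimizer.param_groups:
            group['lr'] = self.init_lr * scale   # warmup then inverse-sqrt decay
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()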
Example #5
def main(args):
    torch.manual_seed(0)

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    # Get model
    model = get_FastSpeech2(args.step).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Init directories
    if not os.path.exists(hp.logger_path):
        os.makedirs(hp.logger_path)
    if not os.path.exists(hp.eval_path):
        os.makedirs(hp.eval_path)

    # Get loss function
    Loss = FastSpeech2Loss().to(device)
    print("Loss Function Defined.")

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            mel_pos = torch.from_numpy(
                data_of_batch["mel_pos"]).long().to(device)
            src_pos = torch.from_numpy(
                data_of_batch["src_pos"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_len = max(data_of_batch["mel_len"]).astype(np.int16)

            with torch.no_grad():
                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                for k in range(len(mel_target)):
                    length = mel_len[k]

                    mel_target_torch = mel_target[k:k + 1, :length].transpose(
                        1, 2).detach()
                    mel_target_ = mel_target[k, :length].cpu().transpose(
                        0, 1).detach()
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(
                            hp.eval_path,
                            'ground-truth_{}_waveglow.wav'.format(idx)))

                    mel_postnet_torch = mel_postnet_output[
                        k:k + 1, :length].transpose(1, 2).detach()
                    mel_postnet = mel_postnet_output[
                        k, :length].cpu().transpose(0, 1).detach()
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(hp.eval_path,
                                     'eval_{}_waveglow.wav'.format(idx)))

                    utils.plot_data([
                        (mel_postnet.numpy(), None, None),
                        (mel_target_.numpy(), None, None)
                    ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        hp.eval_path,
                                        'eval_{}.png'.format(idx)))
                    idx += 1

            current_step += 1

    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(args.step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.logger_path, "eval.txt"), "a") as f_logger:
        f_logger.write(str1 + "\n")
        f_logger.write(str2 + "\n")
        f_logger.write(str3 + "\n")
        f_logger.write(str4 + "\n")
        f_logger.write(str5 + "\n")
        f_logger.write(str6 + "\n")
        f_logger.write("\n")
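Every example iterates the loader with a nested loop (for i, batchs in enumerate(loader): for j, data_of_batch in enumerate(batchs)) and builds the DataLoader with batch_size=hp.batch_size**2 or hp.batch_size*4, which suggests the collate_fn cuts one large fetched group into several real batches. The actual Dataset.collate_fn is not shown in this listing; the sketch below only illustrates that assumed slicing behaviour with a reduced set of fields.

import numpy as np

def collate_fn_sketch(samples, real_batch_size=16):
    # Hypothetical collate_fn: sort the large fetched group by text length,
    # then cut it into real batches so the caller can loop over them.
    samples = sorted(samples, key=lambda s: len(s["text"]), reverse=True)
    batchs = []
    for start in range(0, len(samples), real_batch_size):
        chunk = samples[start:start + real_batch_size]
        max_len = max(len(s["text"]) for s in chunk)
        text = np.stack([np.pad(s["text"], (0, max_len - len(s["text"])))
                         for s in chunk])
        batchs.append({
            "id": [s["id"] for s in chunk],
            "text": text,
            "src_len": np.array([len(s["text"]) for s in chunk]),
        })
    return batchs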
Example #6
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # read params
    mean_mel, std_mel = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "mel_stat.npy")),
                                     dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "f0_stat.npy")),
                                   dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "energy_stat.npy")),
                                           dtype=torch.float).to(device)

    mean_mel, std_mel = mean_mel.reshape(1, -1), std_mel.reshape(1, -1)
    mean_f0, std_f0 = mean_f0.reshape(1, -1), std_f0.reshape(1, -1)
    mean_energy, std_energy = mean_energy.reshape(1, -1), std_energy.reshape(
        1, -1)

    # Load vocoder
    if hp.vocoder == 'vocgan':
        vocoder = utils.get_vocgan(ckpt_path=hp.vocoder_pretrained_model_path)
        vocoder.to(device)
    else:
        vocoder = None

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                train_logger.add_scalar('Loss/total_loss', t_l, current_step)
                train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                        current_step)
                train_logger.add_scalar('Loss/duration_loss', d_l,
                                        current_step)
                train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                train_logger.add_scalar('Loss/energy_loss', e_l, current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step, vocoder)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
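
The training step above accumulates gradients over hp.acc_steps micro-batches: the loss is scaled by 1/hp.acc_steps and backward() runs every iteration, but gradients are clipped and the optimizer steps only when current_step is divisible by hp.acc_steps. A self-contained sketch of the same pattern with a toy model (acc_steps stands in for hp.acc_steps; the model, data and clip value are placeholders):

import torch
import torch.nn as nn

# Toy gradient-accumulation loop mirroring the pattern used above.
model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(8)]
acc_steps = 4

optimizer.zero_grad()
for step, (x, y) in enumerate(loader, start=1):
    loss = nn.functional.mse_loss(model(x), y)
    (loss / acc_steps).backward()   # scale so the accumulated gradient matches one large batch
    if step % acc_steps == 0:
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
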
def evaluate(model, step, vocoder=None):
    model.eval()
    torch.manual_seed(0)

    mean_mel, std_mel = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "mel_stat.npy")),
                                     dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "f0_stat.npy")),
                                   dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "energy_stat.npy")),
                                           dtype=torch.float).to(device)

    eval_path = hp.eval_path
    if not os.path.exists(eval_path):
        os.makedirs(eval_path)

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                # Forward
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, out_mel_len = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if idx == 0 and vocoder is not None:
                    # Run vocoding and plotting spectrogram only when the vocoder is defined
                    for k in range(1):
                        basename = id_[k]
                        gt_length = mel_len[k]
                        out_length = out_mel_len[k]

                        mel_target_torch = mel_target[k:k + 1, :gt_length]
                        mel_target_ = mel_target[k, :gt_length]
                        mel_postnet_torch = mel_postnet_output[k:k +
                                                               1, :out_length]
                        mel_postnet = mel_postnet_output[k, :out_length]

                        mel_target_torch = utils.de_norm(
                            mel_target_torch, mean_mel,
                            std_mel).transpose(1, 2).detach()
                        mel_target_ = utils.de_norm(mel_target_, mean_mel,
                                                    std_mel).cpu().transpose(
                                                        0, 1).detach()
                        mel_postnet_torch = utils.de_norm(
                            mel_postnet_torch, mean_mel,
                            std_mel).transpose(1, 2).detach()
                        mel_postnet = utils.de_norm(mel_postnet, mean_mel,
                                                    std_mel).cpu().transpose(
                                                        0, 1).detach()

                        if hp.vocoder == "vocgan":
                            utils.vocgan_infer(
                                mel_target_torch,
                                vocoder,
                                path=os.path.join(
                                    hp.eval_path,
                                    'eval_groundtruth_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                            utils.vocgan_infer(mel_postnet_torch,
                                               vocoder,
                                               path=os.path.join(
                                                   hp.eval_path,
                                                   'eval_{}_{}_{}.wav'.format(
                                                       step, basename,
                                                       hp.vocoder)))
                        np.save(
                            os.path.join(
                                hp.eval_path, 'eval_step_{}_{}_mel.npy'.format(
                                    step, basename)), mel_postnet.numpy())

                        f0_ = f0[k, :gt_length]
                        energy_ = energy[k, :gt_length]
                        f0_output_ = f0_output[k, :out_length]
                        energy_output_ = energy_output[k, :out_length]

                        f0_ = utils.de_norm(f0_, mean_f0,
                                            std_f0).detach().cpu().numpy()
                        f0_output_ = utils.de_norm(
                            f0_output_, mean_f0,
                            std_f0).detach().cpu().numpy()
                        energy_ = utils.de_norm(
                            energy_, mean_energy,
                            std_energy).detach().cpu().numpy()
                        energy_output_ = utils.de_norm(
                            energy_output_, mean_energy,
                            std_energy).detach().cpu().numpy()

                        utils.plot_data(
                            [(mel_postnet.numpy(), f0_output_, energy_output_),
                             (mel_target_.numpy(), f0_, energy_)], [
                                 'Synthesized Spectrogram',
                                 'Ground-Truth Spectrogram'
                             ],
                            filename=os.path.join(
                                hp.eval_path,
                                'eval_step_{}_{}.png'.format(step, basename)))
                        idx += 1
                    print("done")
            current_step += 1

    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")
    model.train()

    return d_l, f_l, e_l, mel_l, mel_p_l
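
The evaluation above de-normalizes mels, F0 and energy with utils.de_norm and the statistics loaded from mel_stat.npy, f0_stat.npy and energy_stat.npy. de_norm itself is not shown in these snippets; assuming features were normalized as (x - mean) / std during preprocessing, the inverse it is used for would be:

# Sketch of a de-normalization helper consistent with how utils.de_norm is
# called above; the actual implementation may differ.
def de_norm(x, mean, std):
    return x * std + mean
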
Example #8
                    collate_fn=dataset.collate_fn,
                    drop_last=False,
                    num_workers=8)
# Define model
model = FastSpeech2(py_vocab_size, hz_vocab_size).to(device)
num_param = utils.get_param_num(model)

# Optimizer and loss
optimizer = torch.optim.Adam(model.parameters(),
                             lr=hp.start_lr,
                             betas=hp.betas,
                             eps=hp.eps,
                             weight_decay=0)
scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                 hp.n_warm_up_step, args.restore_step)
Loss = FastSpeech2Loss().to(device)
print("Optimizer and Loss Function Defined.")
# Load checkpoint if exists
checkpoint_path = os.path.join(hp.checkpoint_path)
try:
    checkpoint = torch.load(
        os.path.join(checkpoint_path,
                     'checkpoint_{}.pth.tar'.format(args.restore_step)))

    #temp = nn.DataParallel(model)
    model.load_state_dict(checkpoint['model'])
    #model.load_state_dict(temp.module.state_dict())
    #del temp
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("\n---Model Restored at Step {}---\n".format(args.restore_step))
except Exception:
    # No checkpoint found: start training from scratch.
    print("\n---Start New Training---\n")

def evaluate(model, step, vocoder=None):
    torch.manual_seed(0)
    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                # Forward
                mel_output, mel_postnet_output, duration_output, src_mask, pred_mel_mask, enc_attns, dec_attns, W = model(
                    text, src_len, mel_len, max_src_len, max_mel_len)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss = Loss(
                    duration_output, mel_len, mel_output, mel_postnet_output,
                    mel_target, src_mask, pred_mel_mask)

                d_l.append(d_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if vocoder is not None:
                    # Run vocoding and plotting spectrogram only when the vocoder is defined
                    for k in range(len(mel_target)):
                        basename = id_[k]
                        gt_length = mel_len[k]
                        # This model variant does not return predicted mel
                        # lengths, so fall back to the target length.
                        out_length = gt_length

                        mel_target_torch = mel_target[k:k +
                                                      1, :gt_length].transpose(
                                                          1, 2).detach()
                        mel_target_ = mel_target[
                            k, :gt_length].cpu().transpose(0, 1).detach()

                        mel_postnet_torch = mel_postnet_output[
                            k:k + 1, :out_length].transpose(1, 2).detach()
                        mel_postnet = mel_postnet_output[
                            k, :out_length].cpu().transpose(0, 1).detach()

                        if hp.vocoder == 'melgan':
                            utils.melgan_infer(
                                mel_target_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    'ground-truth_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                            utils.melgan_infer(
                                mel_postnet_torch, vocoder,
                                os.path.join(
                                    hp.eval_path, 'eval_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                        elif hp.vocoder == 'waveglow':
                            utils.waveglow_infer(
                                mel_target_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    'ground-truth_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                            utils.waveglow_infer(
                                mel_postnet_torch, vocoder,
                                os.path.join(
                                    hp.eval_path, 'eval_{}_{}.wav'.format(
                                        basename, hp.vocoder)))

                        np.save(
                            os.path.join(hp.eval_path,
                                         'eval_{}_mel.npy'.format(basename)),
                            mel_postnet.numpy())

                        f0_ = f0[k, :gt_length].detach().cpu().numpy()
                        energy_ = energy[k, :gt_length].detach().cpu().numpy()

                        # This model variant does not predict F0/energy, so the
                        # synthesized panel is plotted without those contours.
                        utils.plot_data(
                            [(mel_postnet.numpy(), None, None),
                             (mel_target_.numpy(), f0_, energy_)], [
                                 'Synthesized Spectrogram',
                                 'Ground-Truth Spectrogram'
                             ],
                            filename=os.path.join(
                                hp.eval_path, 'eval_{}.png'.format(basename)))
                        idx += 1

            current_step += 1

    d_l = sum(d_l) / len(d_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    return d_l, mel_l, mel_p_l
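
Example #8 branches on hp.vocoder for every sample when writing audio. The same dispatch can be factored into a small helper that reuses the utils.* calls already present in the snippet (vocode is an illustrative name, not part of the original code):

# Dispatch to the configured vocoder, mirroring the melgan/waveglow branch above.
def vocode(mel, vocoder, out_path):
    if hp.vocoder == 'melgan':
        utils.melgan_infer(mel, vocoder, out_path)
    elif hp.vocoder == 'waveglow':
        utils.waveglow_infer(mel, vocoder, out_path)
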
Example #10
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D, f0, energy)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target, src_len,
                    mel_len)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    logger.add_scalars('Loss/total_loss', {'training': t_l},
                                       current_step)
                    logger.add_scalars('Loss/mel_loss', {'training': m_l},
                                       current_step)
                    logger.add_scalars('Loss/mel_postnet_loss',
                                       {'training': m_p_l}, current_step)
                    logger.add_scalars('Loss/duration_loss', {'training': d_l},
                                       current_step)
                    logger.add_scalars('Loss/F0_loss', {'training': f_l},
                                       current_step)
                    logger.add_scalars('Loss/energy_loss', {'training': e_l},
                                       current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    Audio.tools.inv_mel_spec(
                        mel,
                        os.path.join(
                            synth_path,
                            "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(
                        mel_postnet,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_griffin_lim.wav".format(
                                current_step)))
                    waveglow.inference.inference(
                        mel_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_waveglow.wav".format(current_step)))
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(
                            synth_path, "step_{}_postnet_waveglow.wav".format(
                                current_step)))
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_ground-truth_waveglow.wav".format(
                                current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output, energy_output),
                        (mel_target.numpy(), f0, energy)
                    ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        logger.add_scalars('Loss/total_loss',
                                           {'validation': t_l}, current_step)
                        logger.add_scalars('Loss/mel_loss',
                                           {'validation': m_l}, current_step)
                        logger.add_scalars('Loss/mel_postnet_loss',
                                           {'validation': m_p_l}, current_step)
                        logger.add_scalars('Loss/duration_loss',
                                           {'validation': d_l}, current_step)
                        logger.add_scalars('Loss/F0_loss', {'validation': f_l},
                                           current_step)
                        logger.add_scalars('Loss/energy_loss',
                                           {'validation': e_l}, current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
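
Both training loops wrap Adam in ScheduledOptim, constructed from hp.decoder_hidden and hp.n_warm_up_step. The wrapper is not included in these snippets; if it follows the usual Transformer warm-up schedule, the learning rate it applies would look roughly like the sketch below (the real ScheduledOptim may scale or anneal differently):

# Noam-style warm-up sketch: the rate grows linearly for warmup_steps steps
# and then decays as 1/sqrt(step); d_model plays the role of hp.decoder_hidden.
def noam_lr(step, d_model, warmup_steps):
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
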
Example #11
def evaluate(model, step, wave_glow=None):
    torch.manual_seed(0)

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            mel_pos = torch.from_numpy(
                data_of_batch["mel_pos"]).long().to(device)
            src_pos = torch.from_numpy(
                data_of_batch["src_pos"]).long().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target, src_len,
                    mel_len)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if wave_glow is not None:
                    # Run vocoding and plotting spectrogram only when the vocoder is defined
                    for k in range(len(mel_target)):
                        length = mel_len[k]

                        mel_target_torch = mel_target[k:k +
                                                      1, :length].transpose(
                                                          1, 2).detach()
                        mel_target_ = mel_target[k, :length].cpu().transpose(
                            0, 1).detach()
                        waveglow.inference.inference(
                            mel_target_torch, wave_glow,
                            os.path.join(
                                hp.eval_path,
                                'ground-truth_{}_waveglow.wav'.format(idx)))

                        mel_postnet_torch = mel_postnet_output[
                            k:k + 1, :length].transpose(1, 2).detach()
                        mel_postnet = mel_postnet_output[
                            k, :length].cpu().transpose(0, 1).detach()
                        waveglow.inference.inference(
                            mel_postnet_torch, wave_glow,
                            os.path.join(hp.eval_path,
                                         'eval_{}_waveglow.wav'.format(idx)))

                        f0_ = f0[k, :length].detach().cpu().numpy()
                        energy_ = energy[k, :length].detach().cpu().numpy()
                        f0_output_ = f0_output[
                            k, :length].detach().cpu().numpy()
                        energy_output_ = energy_output[
                            k, :length].detach().cpu().numpy()

                        utils.plot_data(
                            [(mel_postnet.numpy(), f0_output_, energy_output_),
                             (mel_target_.numpy(), f0_, energy_)], [
                                 'Synthesized Spectrogram',
                                 'Ground-Truth Spectrogram'
                             ],
                            filename=os.path.join(hp.eval_path,
                                                  'eval_{}.png'.format(idx)))
                        idx += 1

            current_step += 1

    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    return d_l, f_l, e_l, mel_l, mel_p_l
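
The training loop in Example #10 calls evaluate(model, current_step) without passing the vocoder it loaded, so no audio is written during evaluation. If evaluation-time audio is wanted, the loaded wave_glow can be handed straight to the evaluate from Example #11:

# Inside the eval branch of the training loop (usage sketch).
model.eval()
with torch.no_grad():
    d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step, wave_glow)
model.train()
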