Example #1
def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    print("Use Tacotron2")
    model = nn.DataParallel(Tacotron2(hp)).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of TTS Parameters:', num_param)
    # Get buffer
    print("Load data to buffer")
    buffer = get_data_to_buffer()

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_rnn_dim,
                                     hp.n_warm_up_step, args.restore_step)
    tts_loss = DNNLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists
    try:
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:  # no usable checkpoint found; start fresh
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Get dataset
    dataset = BufferDataset(buffer)

    # Get Training Loader
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size *
                                 hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=0)
    total_step = hp.epochs * len(training_loader) * hp.batch_expand_size

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):
            # real batch start here
            for j, db in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_expand_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_expand_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = db["text"].long().to(device)
                mel_target = db["mel_target"].float().to(device)
                mel_pos = db["mel_pos"].long().to(device)
                src_pos = db["src_pos"].long().to(device)
                max_mel_len = db["mel_max_len"]

                mel_target = mel_target.contiguous().transpose(1, 2)
                src_length = torch.max(src_pos, -1)[0]
                mel_length = torch.max(mel_pos, -1)[0]

                # mel_pos is zero only at padded frames, so eq(0) flags
                # padding; shifting left one frame and padding a trailing 1
                # places the first "stop" flag on the last valid frame.
                gate_target = mel_pos.eq(0).float()
                gate_target = gate_target[:, 1:]
                gate_target = F.pad(gate_target, (0, 1, 0, 0), value=1.)

                # Forward
                inputs = character, src_length, mel_target, max_mel_len, mel_length
                mel_output, mel_output_postnet, gate_output = model(inputs)

                # Cal Loss
                mel_loss, mel_postnet_loss, gate_loss \
                    = tts_loss(mel_output, mel_output_postnet, gate_output,
                               mel_target, gate_target)
                total_loss = mel_loss + mel_postnet_loss + gate_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                g_l = gate_loss.item()

                with open(os.path.join("logger", "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                with open(os.path.join("logger", "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")

                with open(os.path.join("logger", "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")

                with open(os.path.join("logger", "gate_loss.txt"),
                          "a") as f_g_loss:
                    f_g_loss.write(str(g_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:"\
                        .format(epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Gate Loss: {:.4f};".format(
                        m_l, m_p_l, g_l)
                    str3 = "Current Learning Rate is {:.6f}."\
                        .format(scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s."\
                        .format((Now-Start), (total_step-current_step) * np.mean(Time, dtype=np.float32))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join("logger", "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the step-time history to its mean to bound memory.
                    Time = np.array([np.mean(Time)])
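
Examples 1, 2, and 6 drive the optimizer through a ScheduledOptim wrapper that the snippets never define. A minimal sketch consistent with how it is called here, assuming the Noam-style warmup used by the FastSpeech-family repositories (the class body is an assumption, not this repo's verbatim code):

import numpy as np

class ScheduledOptim:
    """Warmup wrapper around a torch optimizer (sketch; schedule assumed)."""

    def __init__(self, optimizer, d_model, n_warmup_steps, current_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = current_steps
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        self._update_learning_rate()
        self._optimizer.step()

    def step_and_update_lr_frozen(self, learning_rate_frozen):
        # Keep a fixed learning rate (the --frozen_learning_rate path).
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = learning_rate_frozen
        self._optimizer.step()

    def get_learning_rate(self):
        return self._optimizer.param_groups[0]['lr']

    def zero_grad(self):
        self._optimizer.zero_grad()

    def _update_learning_rate(self):
        # Noam schedule: lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
        self.n_current_steps += 1
        scale = min(np.power(self.n_current_steps, -0.5),
                    np.power(self.n_warmup_steps, -1.5) * self.n_current_steps)
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = self.init_lr * scale
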
Example #2
def main(args):
    torch.manual_seed(0)

    # Get device
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = 'cuda'

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoaderX(dataset,
                         batch_size=hp.batch_size * 4,
                         shuffle=True,
                         collate_fn=dataset.collate_fn,
                         drop_last=True,
                         num_workers=8)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    #     model = FastSpeech2().to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = hp.checkpoint_path
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:  # no usable checkpoint found; start fresh
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    train_logger = SummaryWriter(log_dir='log/train/' + current_time)
    val_logger = SummaryWriter(log_dir='log/validation/' + current_time)
    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()
    current_step0 = 0
    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Total number of optimizer steps over the whole run
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * len(batchs) + j + args.restore_step + \
                    epoch * len(loader)*len(batchs) + 1
                # Get Data
                condition = torch.from_numpy(
                    data_of_batch["condition"]).long().to(device)
                mel_refer = torch.from_numpy(
                    data_of_batch["mel_refer"]).float().to(device)
                if hp.vocoder == 'WORLD':
                    ap_target = torch.from_numpy(
                        data_of_batch["ap_target"]).float().to(device)
                    sp_target = torch.from_numpy(
                        data_of_batch["sp_target"]).float().to(device)
                else:
                    mel_target = torch.from_numpy(
                        data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                #print(D,log_D)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                if hp.vocoder == 'WORLD':
                    #                     print(condition.shape,mel_refer.shape, src_len.shape, mel_len.shape, D.shape, f0.shape, energy.shape, max_src_len.shape, max_mel_len.shape)
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model(
                        condition, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        ap_output=ap_output,
                        sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output,
                        ap_target=ap_target,
                        sp_target=sp_target,
                        src_mask=src_mask,
                        ap_mask=ap_mask,
                        sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        log_D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target,
                        src_mask=~src_mask,
                        mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l = ap_loss.item()
                    sp_l = sp_loss.item()
                    sp_p_l = sp_postnet_loss.item()
                else:
                    m_l = mel_loss.item()
                    m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                #                 with open(os.path.join(log_path, "total_loss.txt"), "a") as f_total_loss:
                #                     f_total_loss.write(str(t_l)+"\n")
                #                 with open(os.path.join(log_path, "mel_loss.txt"), "a") as f_mel_loss:
                #                     f_mel_loss.write(str(m_l)+"\n")
                #                 with open(os.path.join(log_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                #                     f_mel_postnet_loss.write(str(m_p_l)+"\n")
                #                 with open(os.path.join(log_path, "duration_loss.txt"), "a") as f_d_loss:
                #                     f_d_loss.write(str(d_l)+"\n")
                #                 with open(os.path.join(log_path, "f0_loss.txt"), "a") as f_f_loss:
                #                     f_f_loss.write(str(f_l)+"\n")
                #                 with open(os.path.join(log_path, "energy_loss.txt"), "a") as f_e_loss:
                #                     f_e_loss.write(str(e_l)+"\n")

                # Backward
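                # Gradient accumulation: the loss is scaled down by acc_steps
                # and the clip/step below runs only every acc_steps
                # mini-batches; the other iterations just accumulate grads.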
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}],Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    if hp.vocoder == 'WORLD':
                        str2 = "Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l)
                    else:
                        str2 = "Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "T:{:.1f}s,ETA:{:.1f}s.".format(
                        (Now - Start) / (current_step - current_step0),
                        (total_step - current_step) * np.mean(Time))

                    print("" + str1 + str2 + str3 + '')

                    #                     with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                    #                         f_log.write(str1 + "\n")
                    #                         f_log.write(str2 + "\n")
                    #                         f_log.write(str3 + "\n")
                    #                         f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    if hp.vocoder == 'WORLD':
                        train_logger.add_scalar('Loss/ap_loss', ap_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_loss', sp_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_postnet_loss', sp_p_l,
                                                current_step)
                    else:
                        train_logger.add_scalar('Loss/mel_loss', m_l,
                                                current_step)
                        train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                                current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                if current_step % hp.save_step == 0 or current_step == 20:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0 or current_step == 5:
                    length = mel_len[0].item()

                    if hp.vocoder == 'WORLD':
                        ap_target_torch = ap_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        ap_torch = ap_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_target_torch = sp_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_torch = sp_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_postnet_torch = sp_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                    else:
                        mel_target_torch = mel_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_target = mel_target[
                            0, :length].detach().cpu().transpose(0, 1)
                        mel_torch = mel_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        mel = mel_output[0, :length].detach().cpu().transpose(
                            0, 1)
                        mel_postnet_torch = mel_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_postnet = mel_postnet_output[
                            0, :length].detach().cpu().transpose(0, 1)
#                     Audio.tools.inv_mel_spec(mel, os.path.join(synth_path, "step_{}_griffin_lim.wav".format(current_step)))
#                     Audio.tools.inv_mel_spec(mel_postnet, os.path.join(synth_path, "step_{}_postnet_griffin_lim.wav".format(current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'WORLD':
                        #     ap=np.swapaxes(ap,0,1)
                        #     sp=np.swapaxes(sp,0,1)
                        wav = utils.world_infer(
                            np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_postnet_torch[0].cpu().numpy(), 0,
                                        1), f0_output)
                        sf.write(
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)
                        wav = utils.world_infer(
                            np.swapaxes(ap_target_torch[0].cpu().numpy(), 0,
                                        1),
                            np.swapaxes(sp_target_torch[0].cpu().numpy(), 0,
                                        1), f0)
                        sf.write(
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)

                    utils.plot_data([
                        (sp_postnet_torch[0].cpu().numpy(), f0_output,
                         energy_output),
                        (sp_target_torch[0].cpu().numpy(), f0, energy)
                    ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                    plt.matshow(sp_postnet_torch[0].cpu().numpy())
                    plt.savefig(
                        os.path.join(synth_path,
                                     'sp_postnet_{}.png'.format(current_step)))
                    plt.matshow(ap_torch[0].cpu().numpy())
                    plt.savefig(
                        os.path.join(synth_path,
                                     'ap_{}.png'.format(current_step)))
                    plt.matshow(
                        variance_adaptor_output[0].detach().cpu().numpy())
                    #                     plt.savefig(os.path.join(synth_path, 'va_{}.png'.format(current_step)))
                    #                     plt.matshow(decoder_output[0].detach().cpu().numpy())
                    #                     plt.savefig(os.path.join(synth_path, 'encoder_{}.png'.format(current_step)))

                    plt.cla()
                    fout = open(
                        os.path.join(synth_path,
                                     'D_{}.txt'.format(current_step)), 'w')
                    fout.write(
                        str(log_duration_output[0].detach().cpu().numpy()) +
                        '\n')
                    fout.write(str(D[0].detach().cpu().numpy()) + '\n')
                    fout.write(
                        str(condition[0, :, 2].detach().cpu().numpy()) + '\n')
                    fout.close()


#                 if current_step % hp.eval_step == 0 or current_step==20:
#                     model.eval()
#                     with torch.no_grad():

#                         if hp.vocoder=='WORLD':
#                             d_l, f_l, e_l, ap_l, sp_l, sp_p_l = evaluate(model, current_step)
#                             t_l = d_l + f_l + e_l + ap_l + sp_l + sp_p_l

#                             val_logger.add_scalar('valLoss/total_loss', t_l, current_step)
#                             val_logger.add_scalar('valLoss/ap_loss', ap_l, current_step)
#                             val_logger.add_scalar('valLoss/sp_loss', sp_l, current_step)
#                             val_logger.add_scalar('valLoss/sp_postnet_loss', sp_p_l, current_step)
#                             val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
#                             val_logger.add_scalar('valLoss/F0_loss', f_l, current_step)
#                             val_logger.add_scalar('valLoss/energy_loss', e_l, current_step)
#                         else:
#                             d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step)
#                             t_l = d_l + f_l + e_l + m_l + m_p_l

#                             val_logger.add_scalar('valLoss/total_loss', t_l, current_step)
#                             val_logger.add_scalar('valLoss/mel_loss', m_l, current_step)
#                             val_logger.add_scalar('valLoss/mel_postnet_loss', m_p_l, current_step)
#                             val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
#                             val_logger.add_scalar('valLoss/F0_loss', f_l, current_step)
#                             val_logger.add_scalar('valLoss/energy_loss', e_l, current_step)

#                     model.train()
#                 if current_step%10==0:
#                     print(energy_output[0],energy[0])
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the step-time history to its mean to bound memory.
                    Time = np.array([np.mean(Time)])
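
DataLoaderX in Example 2 is also undefined in the snippet. The usual implementation of that name wraps DataLoader with prefetch_generator's BackgroundGenerator so upcoming batches are collated in a background thread; a sketch assuming that package:

from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader

class DataLoaderX(DataLoader):
    def __iter__(self):
        # Collate the next batches in a background thread while the GPU works.
        return BackgroundGenerator(super().__iter__())
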
Example #3
def main(args):
    debug = (args.debug == 'True')
    print(args)
    np.random.seed(args.seed)
    with tf.Graph().as_default():
        train_dataset, num_train_file = DateSet(args.file_list, args, debug)
        test_dataset, num_test_file = DateSet(args.test_list, args, debug)
        list_ops = {}

        batch_train_dataset = train_dataset.batch(args.batch_size).repeat()
        train_iterator = batch_train_dataset.make_one_shot_iterator()
        train_next_element = train_iterator.get_next()

        batch_test_dataset = test_dataset.batch(args.batch_size).repeat()
        test_iterator = batch_test_dataset.make_one_shot_iterator()
        test_next_element = test_iterator.get_next()

        list_ops['num_train_file'] = num_train_file
        list_ops['num_test_file'] = num_test_file

        model_dir = args.model_dir

        print('Total number of examples: {}'.format(num_train_file))
        print('Test number of examples: {}'.format(num_test_file))
        print('Model dir: {}'.format(model_dir))

        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False)

        list_ops['global_step'] = global_step
        list_ops['train_dataset'] = train_dataset
        list_ops['test_dataset'] = test_dataset
        list_ops['train_next_element'] = train_next_element
        list_ops['test_next_element'] = test_next_element

        epoch_size = num_train_file // args.batch_size
        print('Number of batches per epoch: {}'.format(epoch_size))

        image_batch = tf.placeholder(tf.float32,
                                     shape=(None, args.image_size,
                                            args.image_size, 3),
                                     name='image_batch')
        landmark_batch = tf.placeholder(tf.float32,
                                        shape=(None, 196),
                                        name='landmark_batch')

        list_ops['image_batch'] = image_batch
        list_ops['landmark_batch'] = landmark_batch

        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
        list_ops['phase_train_placeholder'] = phase_train_placeholder

        print('Building training graph.')

        landmarks_pre, landmarks_loss = create_model(image_batch,
                                                     landmark_batch,
                                                     phase_train_placeholder,
                                                     args)
        get_param_num()

        L2_loss = tf.add_n(tf.losses.get_regularization_losses())

        loss_sum = tf.reduce_sum(tf.square(landmark_batch - landmarks_pre),
                                 axis=1)
        loss_sum = tf.reduce_mean(loss_sum)
        loss_sum += L2_loss

        train_op, lr_op = train_model(loss_sum, global_step, num_train_file,
                                      args)

        list_ops['landmarks'] = landmarks_pre
        list_ops['L2_loss'] = L2_loss
        list_ops['loss'] = loss_sum
        list_ops['train_op'] = train_op
        list_ops['lr_op'] = lr_op

        test_mean_error = tf.Variable(tf.constant(0.0),
                                      dtype=tf.float32,
                                      name='ME')
        test_failure_rate = tf.Variable(tf.constant(0.0),
                                        dtype=tf.float32,
                                        name='FR')
        test_10_loss = tf.Variable(tf.constant(0.0),
                                   dtype=tf.float32,
                                   name='TestLoss')
        train_loss = tf.Variable(tf.constant(0.0),
                                 dtype=tf.float32,
                                 name='TrainLoss')
        train_loss_l2 = tf.Variable(tf.constant(0.0),
                                    dtype=tf.float32,
                                    name='TrainLoss2')
        tf.summary.scalar('test_mean_error', test_mean_error)
        tf.summary.scalar('test_failure_rate', test_failure_rate)
        tf.summary.scalar('test_10_loss', test_10_loss)
        tf.summary.scalar('train_loss', train_loss)
        tf.summary.scalar('train_loss_l2', train_loss_l2)

        save_params = tf.trainable_variables()
        saver = tf.train.Saver(save_params, max_to_keep=None)
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)

        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=False,
                                                log_device_placement=False))
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        with sess.as_default():
            epoch_start = 0
            if args.pretrained_model:
                pretrained_model = args.pretrained_model
                if not os.path.isdir(pretrained_model):
                    print('Restoring pretrained model: {}'.format(
                        pretrained_model))
                    saver.restore(sess, args.pretrained_model)
                else:
                    print('Model directory: {}'.format(pretrained_model))
                    ckpt = tf.train.get_checkpoint_state(pretrained_model)
                    model_path = ckpt.model_checkpoint_path
                    assert (ckpt and model_path)
                    epoch_start = int(
                        model_path[model_path.find('model.ckpt-') + 11:]) + 1
                    print('Checkpoint file: {}'.format(model_path))
                    saver.restore(sess, model_path)

            # if args.save_image_example:
            #     save_image_example(sess, list_ops, args)

            print('Running train.')

            merged = tf.summary.merge_all()
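            # NOTE: log_dir is not defined anywhere in this snippet; it is
            # presumably set at module scope (e.g., a subdirectory of model_dir).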
            train_write = tf.summary.FileWriter(log_dir, sess.graph)
            for epoch in range(epoch_start, args.max_epoch):
                start = time.time()
                train_L, train_L2 = train(sess, epoch_size, epoch, list_ops)
                print("train time: {}".format(time.time() - start))

                checkpoint_path = os.path.join(model_dir, 'model.ckpt')
                metagraph_path = os.path.join(model_dir, 'model.meta')
                saver.save(sess,
                           checkpoint_path,
                           global_step=epoch,
                           write_meta_graph=False)
                if not os.path.exists(metagraph_path):
                    saver.export_meta_graph(metagraph_path)

                start = time.time()
                test_ME, test_FR, test_loss = test(sess, list_ops, args)
                print("test time: {}".format(time.time() - start))

                summary, _, _, _, _, _ = sess.run([
                    merged,
                    test_mean_error.assign(test_ME),
                    test_failure_rate.assign(test_FR),
                    test_10_loss.assign(test_loss),
                    train_loss.assign(train_L),
                    train_loss_l2.assign(train_L2)
                ])
                train_write.add_summary(summary, epoch)
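
The train() helper invoked once per epoch above is not part of the snippet. A plausible sketch of the feed-dict loop it presumably runs, assuming each dataset element starts with an image batch and a flattened landmark batch (the element layout, and averaging the losses over the epoch, are assumptions):

def train(sess, epoch_size, epoch, list_ops):
    total_loss, total_l2 = 0.0, 0.0
    for _ in range(epoch_size):
        # Element layout assumed: (images, landmarks, ...); extras ignored.
        element = sess.run(list_ops['train_next_element'])
        images, landmarks = element[0], element[1]
        feed_dict = {
            list_ops['image_batch']: images,
            list_ops['landmark_batch']: landmarks,
            list_ops['phase_train_placeholder']: True,
        }
        _, loss, l2 = sess.run(
            [list_ops['train_op'], list_ops['loss'], list_ops['L2_loss']],
            feed_dict=feed_dict)
        total_loss += loss
        total_l2 += l2
    return total_loss / epoch_size, total_l2 / epoch_size
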
Example #4
        # return [mel, mel_postnet_1, mel_postnet_2], predicted, cemb
        return mel, predicted, cemb

    def inference(self, character, alpha=1.0):
        x = self.embeddings(character)

        self.pre_gru.flatten_parameters()
        x, _ = self.pre_gru(x)

        x = self.pre_linear(x)
        x = self.LR(x, alpha=alpha)

        self.post_gru.flatten_parameters()
        x, _ = self.post_gru(x)
        mel = self.post_linear(x)
        # mel_postnet_1, mel_postnet_2 = self.postnet.inference(mel)

        # return mel, mel_postnet_1, mel_postnet_2
        return mel


if __name__ == "__main__":
    # Test
    num_1 = utils.get_param_num(LightSpeech())
    print(num_1)

    model = utils.get_Tacotron2()
    num_2 = utils.get_param_num(model)
    print(num_2 / num_1)
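
Several examples report a parameter count via utils.get_param_num. A minimal sketch of such a helper (the repo's version may differ):

def get_param_num(model):
    # Total number of elements across all parameter tensors.
    return sum(param.numel() for param in model.parameters())
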
Example #5
def main(args, configs):
    preprocess_cfg, model_cfg, train_cfg = configs

    # dataset
    print("Loading dataset...")
    dataset = MyDataset("train.txt", configs)

    batch_size = train_cfg["batch_size"]

    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=True,
                        collate_fn=dataset.collate_fn)

    # model, optimizer
    print("Loading model...")
    model, optimizer = get_model(args, configs, device, train=True)
    model = nn.DataParallel(model)
    print("Number of model parameters:", get_param_num(model))

    # loss
    Loss = MyLoss().to(device)

    # output
    ckpt_dir = train_cfg["path"]["ckpt_dir"]
    log_dir = train_cfg["path"]["log_dir"]
    log_path = os.path.join(log_dir, "log.txt")
    val_path = os.path.join(log_dir, "log_val.txt")
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # training
    step = train_cfg["optimizer"]["restore_step"] + 1
    epoch = 1

    grad_clip_thresh = train_cfg["optimizer"]["grad_clip_thresh"]

    total_step = train_cfg["step"]["total_step"]
    val_step = train_cfg["step"]["val_step"]
    log_step = train_cfg["step"]["log_step"]
    synth_step = train_cfg["step"]["synth_step"]
    save_step = train_cfg["step"]["save_step"]

    print("Training...")
    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = train_cfg["optimizer"]["restore_step"]
    outer_bar.update()

    while True:
        for batch in tqdm(loader, desc="Epoch {}".format(epoch), position=1):
            batch = to_device(batch, device)

            # Forward
            output = model(*batch)

            # Cal loss
            loss = Loss(output, batch)
            total_loss = loss[0]

            # Backward
            total_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
            optimizer.update_lr_and_step()
            optimizer.zero_grad()

            # Log
            if step % log_step == 0:
                message1 = "Step {}/{}, ".format(step, total_step)
                message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Gate Loss: {:.4f}".format(
                    *loss)
                with open(log_path, "a") as f:
                    f.write(message1 + message2 + "\n")
                outer_bar.write(message1 + message2)

            # Eval
            if step % val_step == 0:
                model.eval()

                message = evaluate(model, step, configs)
                with open(val_path, "a") as f:
                    f.write(message + "\n")
                outer_bar.write(message)

                model.train()

            # Synth
            if step % synth_step == 0:
                with torch.no_grad():
                    mel = output[1][0].detach()
                    mel = mel.cpu().numpy().astype(np.float32)
                    plt.imshow(mel)
                    plt.ylim(0, mel.shape[0])
                    plt.colorbar()
                    plt.savefig(
                        os.path.join(log_dir,
                                     "{}.png".format(str(step) + "-mel")))
                    plt.close()

                    mel_truth = batch[2][0].detach()
                    mel_truth = mel_truth.cpu().numpy().astype(np.float32)
                    plt.imshow(mel_truth)
                    plt.ylim(0, mel_truth.shape[0])
                    plt.colorbar()
                    plt.savefig(
                        os.path.join(log_dir,
                                     "{}.png".format(str(step) +
                                                     "-mel_truth")))
                    plt.close()

                    alignment = output[-1][0].detach()
                    alignment = alignment.cpu().numpy().astype(np.float32).T
                    plt.imshow(alignment)
                    plt.ylim(0, alignment.shape[0])
                    plt.colorbar()
                    plt.savefig(
                        os.path.join(log_dir,
                                     "{}.png".format(str(step) +
                                                     "-alignment")))
                    plt.close()

            # Save
            if step % save_step == 0:
                torch.save(
                    {
                        "model": model.module.state_dict(),
                        "optimizer": optimizer.optimizer.state_dict(),
                    },
                    os.path.join(
                        train_cfg["path"]["ckpt_dir"],
                        "{}.pth.tar".format(step),
                    ),
                )

            # Quit
            if step == total_step:
                quit()

            step += 1
            outer_bar.update(1)

        epoch += 1
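
Example 5 moves each collated batch to the GPU with a to_device helper that the snippet does not define. A sketch of one common shape for it, recursively moving tensors inside tuples, lists, and dicts (an assumption about what the helper does):

import torch

def to_device(batch, device):
    # Recursively move every tensor in a nested batch onto the device.
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(to_device(item, device) for item in batch)
    if isinstance(batch, dict):
        return {key: to_device(value, device) for key, value in batch.items()}
    return batch
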
Example #6
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = hp.checkpoint_path
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:  # no usable checkpoint found; start fresh
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # read params
    mean_mel, std_mel = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "mel_stat.npy")),
                                     dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "f0_stat.npy")),
                                   dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "energy_stat.npy")),
                                           dtype=torch.float).to(device)

    mean_mel, std_mel = mean_mel.reshape(1, -1), std_mel.reshape(1, -1)
    mean_f0, std_f0 = mean_f0.reshape(1, -1), std_f0.reshape(1, -1)
    mean_energy, std_energy = mean_energy.reshape(1, -1), std_energy.reshape(
        1, -1)

    # Load vocoder
    if hp.vocoder == 'vocgan':
        vocoder = utils.get_vocgan(ckpt_path=hp.vocoder_pretrained_model_path)
        vocoder.to(device)
    else:
        vocoder = None

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Total number of optimizer steps over the whole run
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                train_logger.add_scalar('Loss/total_loss', t_l, current_step)
                train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                        current_step)
                train_logger.add_scalar('Loss/duration_loss', d_l,
                                        current_step)
                train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                train_logger.add_scalar('Loss/energy_loss', e_l, current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step, vocoder)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the step-time history to its mean to bound memory.
                    Time = np.array([np.mean(Time)])
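
The tuple unpacking mean_mel, std_mel = torch.tensor(np.load(...)) in Example 6 only works if each *_stat.npy stores a two-row array, mean stacked over std. A sketch of that assumed layout and the matching denormalization:

import numpy as np
import torch

# Assumed layout: row 0 holds the mean, row 1 the std (dummy data here).
frames = np.random.randn(1000, 80).astype(np.float32)
np.save("mel_stat.npy", np.stack([frames.mean(0), frames.std(0)]))

# Unpacking a (2, 80) tensor along dim 0 yields the two rows.
mean_mel, std_mel = torch.tensor(np.load("mel_stat.npy"), dtype=torch.float)
denormalized = torch.randn(200, 80) * std_mel + mean_mel  # undo normalization
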
Example #7
def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    model = LightSpeech().to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of LightSpeech Parameters:', num_param)

    # Get dataset
    dataset = LightSpeechDataset()

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hp.learning_rate,
                                 weight_decay=hp.weight_decay)
    actor_optimizer = torch.optim.Adam([{
        "params": model.embeddings.parameters()
    }, {
        "params": model.pre_gru.parameters()
    }, {
        "params": model.pre_linear.parameters()
    }, {
        "params": model.LR.parameters()
    }],
                                       lr=hp.learning_rate,
                                       weight_decay=hp.weight_decay)

    # Criterion
    criterion = LigthSpeechLoss()

    # Load checkpoint if exists
    try:
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:  # no usable checkpoint found; start fresh
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()  # time.clock() was removed in Python 3.8

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        # Get Training Loader
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=cpu_count())
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader)*hp.batch_size + 1

                # Get Data
                character = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_gt_target = torch.from_numpy(
                    data_of_batch["mel_gt_target"]).float().to(device)
                # mel_tac2_target = torch.from_numpy(
                #     data_of_batch["mel_tac2_target"]).float().to(device)

                # D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                # cemb = torch.from_numpy(
                #     data_of_batch["cemb"]).float().to(device)

                input_lengths = torch.from_numpy(
                    data_of_batch["length_text"]).long().to(device)
                output_lengths = torch.from_numpy(
                    data_of_batch["length_mel"]).long().to(device)

                max_c_len = max(input_lengths).item()
                max_mel_len = max(output_lengths).item()

                # Forward
                mel, P, predicted_length, history = model(
                    character, input_lengths, max_c_len, max_mel_len)

                # print(predicted_length)

                # Cal Loss
                mel_loss, len_loss, pg_loss = criterion(
                    mel, predicted_length, mel_gt_target, output_lengths, P,
                    history)
                # model_loss = mel_loss + len_loss
                model_loss = mel_loss
                actor_loss = pg_loss

                # print(mel_loss, len_loss, pg_loss)

                # Logger
                m_l = mel_loss.item()
                l_l = len_loss.item()
                p_l = pg_loss.item()

                with open(os.path.join("logger", "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")

                with open(os.path.join("logger", "len_loss.txt"),
                          "a") as f_l_loss:
                    f_l_loss.write(str(l_l) + "\n")

                with open(os.path.join("logger", "pg_loss.txt"),
                          "a") as f_p_loss:
                    f_p_loss.write(str(p_l) + "\n")

                # Backward
                model_loss.backward(retain_graph=True)
                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(), 1.)
                # Update weights
                optimizer.step()
                # Init
                optimizer.zero_grad()

                actor_loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.)
                actor_optimizer.step()
                actor_optimizer.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Length Loss: {:.4f}, Policy Loss: {:.4f};".format(
                        m_l, l_l, p_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join("logger", "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the history to its running mean so the ETA
                    # estimate stays cheap to maintain.
                    Time = np.array([np.mean(Time)])
Example #8
        return mel_1, mel_2

    def forward(self, mels, length_mel, max_mel_len):
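        # flatten_parameters() compacts the GRU weights into one contiguous
        # block, avoiding per-call copies (and a warning) under DataParallel.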
        self.gru_1.flatten_parameters()
        x, _ = self.gru_1(mels)

        mel_postnet_1 = mels + x
        self.gru_2.flatten_parameters()
        y, _ = self.gru_2(mel_postnet_1)

        mel_postnet_2 = mel_postnet_1 + x + y
        mel_postnet_1, mel_postnet_2 = self.mask(mel_postnet_1, mel_postnet_2,
                                                 length_mel, max_mel_len)

        return mel_postnet_1, mel_postnet_2

    def inference(self, mels):
        x, _ = self.gru_1(mels)
        mel_postnet_1 = mels + x
        y, _ = self.gru_2(mel_postnet_1)
        mel_postnet_2 = mel_postnet_1 + x + y

        return mel_postnet_1, mel_postnet_2


if __name__ == "__main__":
    # Test
    test_dp = DurationPredictor()
    print(utils.get_param_num(test_dp))
Example #9
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(STYLER()).to(device)
    print("Model Has Been Defined")

    # Parameters
    num_param = utils.get_param_num(model)
    text_encoder = utils.get_param_num(
        model.module.style_modeling.style_encoder.text_encoder)
    audio_encoder = utils.get_param_num(
        model.module.style_modeling.style_encoder.audio_encoder)
    predictors = utils.get_param_num(model.module.style_modeling.duration_predictor) \
        + utils.get_param_num(model.module.style_modeling.pitch_predictor) \
        + utils.get_param_num(model.module.style_modeling.energy_predictor)
    decoder = utils.get_param_num(model.module.decoder)
    print('Number of Model Parameters          :', num_param)
    print('Number of Text Encoder Parameters   :', text_encoder)
    print('Number of Audio Encoder Parameters  :', audio_encoder)
    print('Number of Predictor Parameters      :', predictors)
    print('Number of Decoder Parameters        :', decoder)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
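    # ScheduledOptim presumably implements the Transformer-style warm-up
    # schedule, lr = decoder_hidden**-0.5 * min(step**-0.5,
    # step * n_warm_up_step**-1.5), judging by its constructor arguments.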
    Loss = STYLERLoss().to(device)
    DATLoss = DomainAdversarialTrainingLoss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = hp.checkpoint_path()
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    vocoder = utils.get_vocoder()

    # Init logger
    log_path = hp.log_path()
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path()
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i*hp.batch_size + j + args.restore_step + \
                    epoch*len(loader)*hp.batch_size + 1

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                mel_aug = torch.from_numpy(
                    data_of_batch["mel_aug"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                f0_norm = torch.from_numpy(
                    data_of_batch["f0_norm"]).float().to(device)
                f0_norm_aug = torch.from_numpy(
                    data_of_batch["f0_norm_aug"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                energy_input = torch.from_numpy(
                    data_of_batch["energy_input"]).float().to(device)
                energy_input_aug = torch.from_numpy(
                    data_of_batch["energy_input_aug"]).float().to(device)
                speaker_embed = torch.from_numpy(
                    data_of_batch["speaker_embed"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_outputs, mel_postnet_outputs, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _, aug_posteriors = model(
                    text,
                    mel_target,
                    mel_aug,
                    f0_norm,
                    energy_input,
                    src_len,
                    mel_len,
                    D,
                    f0,
                    energy,
                    max_src_len,
                    max_mel_len,
                    speaker_embed=speaker_embed)

                # Cal Loss Clean
                mel_output, mel_postnet_output = mel_outputs[
                    0], mel_postnet_outputs[0]
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss, classifier_loss_a = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask, src_len, mel_len, aug_posteriors,
                    torch.zeros(mel_target.size(0)).long().to(device))

                # Cal Loss Noisy
                mel_output_noisy, mel_postnet_output_noisy = mel_outputs[
                    1], mel_postnet_outputs[1]
                mel_noisy_loss, mel_postnet_noisy_loss = Loss.cal_mel_loss(
                    mel_output_noisy, mel_postnet_output_noisy, mel_aug,
                    ~mel_mask)

                # Forward DAT
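                # Domain-adversarial training: the clean pass above was labeled
                # 0, and this augmented pass is labeled 1, encouraging
                # augmentation-invariant style encodings.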
                enc_cat = model.module.style_modeling.style_encoder.encoder_input_cat(
                    mel_aug, f0_norm_aug, energy_input_aug, mel_aug)
                duration_encoding, pitch_encoding, energy_encoding, _ = model.module.style_modeling.style_encoder.audio_encoder(
                    enc_cat, mel_len, src_len, mask=None)
                aug_posterior_d = model.module.style_modeling.augmentation_classifier_d(
                    duration_encoding)
                aug_posterior_p = model.module.style_modeling.augmentation_classifier_p(
                    pitch_encoding)
                aug_posterior_e = model.module.style_modeling.augmentation_classifier_e(
                    energy_encoding)

                # Cal Loss DAT
                classifier_loss_a_dat = DATLoss(
                    (aug_posterior_d, aug_posterior_p, aug_posterior_e),
                    torch.ones(mel_target.size(0)).long().to(device))

                # Total loss
                total_loss = mel_loss + mel_postnet_loss + mel_noisy_loss + mel_postnet_noisy_loss + d_loss + f_loss + e_loss\
                    + hp.dat_weight*(classifier_loss_a + classifier_loss_a_dat)

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                m_n_l = mel_noisy_loss.item()
                m_p_n_l = mel_postnet_noisy_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                cl_a = classifier_loss_a.item()
                cl_a_dat = classifier_loss_a_dat.item()

                # Backward
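                # Gradient accumulation: the loss is scaled by 1/hp.acc_steps and
                # the optimizer only steps once every hp.acc_steps micro-batches.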
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step == 1 or current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_noisy_loss', m_n_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_postnet_noisy_loss',
                                            m_p_n_l, current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)
                    train_logger.add_scalar('Loss/dat_clean_loss', cl_a,
                                            current_step)
                    train_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat,
                                            current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step == 1 or current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_aug_torch = mel_aug[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_aug = mel_aug[0, :length].detach().cpu().transpose(
                        0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_noisy_torch = mel_output_noisy[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_noisy = mel_output_noisy[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet_noisy_torch = mel_postnet_output_noisy[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_noisy = mel_postnet_output_noisy[
                        0, :length].detach().cpu().transpose(0, 1)
                    # Audio.tools.inv_mel_spec(mel, os.path.join(
                    #     synth_path, "step_{}_{}_griffin_lim.wav".format(current_step, "c")))
                    # Audio.tools.inv_mel_spec(mel_postnet, os.path.join(
                    #     synth_path, "step_{}_{}_postnet_griffin_lim.wav".format(current_step, "c")))
                    # Audio.tools.inv_mel_spec(mel_noisy, os.path.join(
                    #     synth_path, "step_{}_{}_griffin_lim.wav".format(current_step, "n")))
                    # Audio.tools.inv_mel_spec(mel_postnet_noisy, os.path.join(
                    #     synth_path, "step_{}_{}_postnet_griffin_lim.wav".format(current_step, "n")))

                    wav_mel = utils.vocoder_infer(
                        mel_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_{}.wav'.format(current_step, "c",
                                                       hp.vocoder)))
                    wav_mel_postnet = utils.vocoder_infer(
                        mel_postnet_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_postnet_{}.wav'.format(
                                current_step, "c", hp.vocoder)))
                    wav_ground_truth = utils.vocoder_infer(
                        mel_target_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_ground-truth_{}.wav'.format(
                                current_step, "c", hp.vocoder)))
                    wav_mel_noisy = utils.vocoder_infer(
                        mel_noisy_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_{}.wav'.format(current_step, "n",
                                                       hp.vocoder)))
                    wav_mel_postnet_noisy = utils.vocoder_infer(
                        mel_postnet_noisy_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_postnet_{}.wav'.format(
                                current_step, "n", hp.vocoder)))
                    wav_aug = utils.vocoder_infer(
                        mel_aug_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_ground-truth_{}.wav'.format(
                                current_step, "n", hp.vocoder)))

                    # Model duration prediction
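                    # Invert the log-domain duration: exp() undoes the log,
                    # hp.log_offset compensates the training-time offset, and
                    # clamp(min=0) guards against negative durations.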
                    log_duration_output = log_duration_output[
                        0, :src_len[0].item()].detach().cpu()  # [seg_len]
                    log_duration_output = torch.clamp(torch.round(
                        torch.exp(log_duration_output) - hp.log_offset),
                                                      min=0).int()
                    model_duration = utils.get_alignment_2D(
                        log_duration_output).T  # [seg_len, mel_len]
                    model_duration = utils.plot_alignment([model_duration])

                    # Model mel prediction
                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()
                    mel_predicted = utils.plot_data(
                        [(mel_postnet.numpy(), f0_output, energy_output),
                         (mel_target.numpy(), f0, energy)], [
                             'Synthesized Spectrogram Clean',
                             'Ground-Truth Spectrogram'
                         ],
                        filename=os.path.join(
                            synth_path,
                            'step_{}_{}.png'.format(current_step, "c")))
                    mel_noisy_predicted = utils.plot_data(
                        [(mel_postnet_noisy.numpy(), f0_output, energy_output),
                         (mel_aug.numpy(), f0, energy)],
                        ['Synthesized Spectrogram Noisy', 'Aug Spectrogram'],
                        filename=os.path.join(
                            synth_path,
                            'step_{}_{}.png'.format(current_step, "n")))

                    # Normalize audio for tensorboard logger. See https://github.com/lanpa/tensorboardX/issues/511#issuecomment-537600045
                    wav_ground_truth = wav_ground_truth / max(wav_ground_truth)
                    wav_mel = wav_mel / max(wav_mel)
                    wav_mel_postnet = wav_mel_postnet / max(wav_mel_postnet)
                    wav_aug = wav_aug / max(wav_aug)
                    wav_mel_noisy = wav_mel_noisy / max(wav_mel_noisy)
                    wav_mel_postnet_noisy = wav_mel_postnet_noisy / max(
                        wav_mel_postnet_noisy)

                    train_logger.add_image("model_duration",
                                           model_duration,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_image("mel_predicted/Clean",
                                           mel_predicted,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_image("mel_predicted/Noisy",
                                           mel_noisy_predicted,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_audio("Clean/wav_ground_truth",
                                           wav_ground_truth,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Clean/wav_mel",
                                           wav_mel,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Clean/wav_mel_postnet",
                                           wav_mel_postnet,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_aug",
                                           wav_aug,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_mel_noisy",
                                           wav_mel_noisy,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_mel_postnet_noisy",
                                           wav_mel_postnet_noisy,
                                           current_step,
                                           sample_rate=hp.sampling_rate)

                if current_step == 1 or current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, cl_a, cl_a_dat, m_l, m_p_l, m_n_l, m_p_n_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l + m_n_l + m_p_n_l\
                            + hp.dat_weight*(cl_a + cl_a_dat)

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_noisy_loss', m_n_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_noisy_loss',
                                              m_p_n_l, current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)
                        val_logger.add_scalar('Loss/dat_clean_loss', cl_a,
                                              current_step)
                        val_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat,
                                              current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    Time = np.array([np.mean(Time)])
Example #10
def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     torch.distributed.init_process_group(backend='nccl')
    
    # Define model
    model = nn.DataParallel(FastSpeech())
    model = model.cuda()
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    writer = SummaryWriter(log_dir='log/' + current_time)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=hp.learning_rate)

    # Load checkpoint if exists
    try:
        with open(os.path.join(hp.checkpoint_path, 'checkpoint.txt'),
                  'r') as checkpoint_in:
            args.restore_step = int(checkpoint_in.readline().strip())
        checkpoint = torch.load(os.path.join(
            hp.checkpoint_path, 'checkpoint_%08d.pth' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)
    # Get dataset
    dataset = FastSpeechDataset()

    # Loss (the optimizer was defined above)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    
#     model = torch.nn.parallel.DistributedDataParallel(model) # device_ids will include all GPU devices by default
    print('Start')
#     model = model.train()
    
    for epoch in range(hp.epochs):
        # Get Training Loader
        print('Start Epoch %d'%epoch)
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
        total_step = hp.epochs * len(training_loader) * hp.batch_size
        
        m_l = 0.0
        m_p_l = 0.0
        t_l = 0.0
        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader)*hp.batch_size + 1

                model.zero_grad()

                # Get Data
                condition1 = torch.from_numpy(
                    data_of_batch["condition1"]).long().to(device)#.fill_(1)
                condition2 = torch.from_numpy(
                    data_of_batch["condition2"]).long().to(device)#.fill_(1)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                max_mel_len = data_of_batch["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output = model(src_seq1=condition1,
                                                       src_seq2=condition2,
                                                       src_pos=src_pos,
                                                       mel_pos=mel_pos,
                                                       mel_max_length=max_mel_len,
                                                       length_target=D)

#                 print(mel_target.size())
#                 print(mel_output)
#                 print(mel_postnet_output)

                # Cal Loss
                mel_loss, mel_postnet_loss = fastspeech_loss(mel_output,
                                                             mel_postnet_output,
                                                             mel_target)
                total_loss = mel_loss + mel_postnet_loss

                # Logger
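                # Losses are accumulated in the log domain and averaged over
                # hp.log_step iterations before being printed and logged.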
                t_l += np.log(total_loss.item())
                m_l += np.log(mel_loss.item())
                m_p_l += np.log(mel_postnet_loss.item())
                
#                 assert np.isnan(t_l)==False

                with open(os.path.join("logger", "total_loss.txt"), "a") as f_total_loss:
                    f_total_loss.write(str(t_l)+"\n")

                with open(os.path.join("logger", "mel_loss.txt"), "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l)+"\n")

                with open(os.path.join("logger", "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l)+"\n")

                # Backward
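                # Skip the update when the accumulated log-loss has gone NaN and
                # dump the offending batch tensors for inspection.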
                if not np.isnan(t_l):
                    total_loss.backward()
                    optimizer.step()
                else:
                    print(condition1, condition2, D)

                # Clipping gradients to avoid gradient explosion
#                 nn.utils.clip_grad_norm_(
#                     model.parameters(), hp.grad_clip_thresh)

                # Update weights
#                 if args.frozen_learning_rate:
#                     scheduled_optim.step_and_update_lr_frozen(
#                         args.learning_rate_frozen)
#                 else:
#                     scheduled_optim.step_and_update_lr()


                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    
                    str1 = "Epoch[{}/{}] Step[{}/{}]:".format(
                        epoch+1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss:{:.4f} MelPostNet Loss:{:.4f}".format(
                        m_l/hp.log_step, m_p_l/hp.log_step)
                    str3 = "LR:{:.6f}".format(
                        hp.learning_rate)
                    str4 = "T: {:.1f}s ETR:{:.1f}s.".format(
                        (Now-Start), (total_step-current_step)*np.mean(Time))
                    
                    writer.add_scalar('Mel Loss', m_l / hp.log_step,
                                      current_step)
                    writer.add_scalar('MelPostNet Loss', m_p_l / hp.log_step,
                                      current_step)
                    writer.add_scalar('Loss', t_l / hp.log_step, current_step)
                    writer.add_scalar('learning rate', hp.learning_rate,
                                      current_step)

                    print('\r' + str1 + ' ' + str2 + ' ' + str3 + ' ' + str4,
                          end='')

                    if hp.gpu_log_step != -1 and current_step % hp.gpu_log_step == 0:
                        os.system('nvidia-smi')

                    with open(os.path.join("logger", "logger.txt"), "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")
                    m_l = 0.0
                    m_p_l = 0.0
                    t_l = 0.0
                if current_step % hp.fig_step == 0:
                    f = plt.figure()
                    plt.matshow(mel_postnet_output[0].cpu().detach().numpy())
                    plt.savefig('out_predicted_postnet.png')
                    writer.add_figure('predict', f, current_step)
                    plt.cla()
                    f = plt.figure()
                    plt.matshow(mel_target[0].cpu().detach().numpy())
                    plt.savefig('out_target.png')
                    writer.add_figure('target', f, current_step)
                    plt.cla()
                    plt.close("all")

                if current_step % hp.save_step == 0:
                    print("save model at step %d ..." % current_step, end='')
                    torch.save({'model': model.state_dict(),
                                'optimizer': optimizer.state_dict()},
                               os.path.join(hp.checkpoint_path,
                                            'checkpoint_%08d.pth' % current_step))
                    with open(os.path.join(hp.checkpoint_path, 'checkpoint.txt'),
                              'w') as checkpoint_out:
                        checkpoint_out.write(str(current_step))

#                     os.system('python savefig.py')

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    Time = np.array([np.mean(Time)])
Example #11
    def inference(self, inputs, speaker_id):
        embedded_inputs = self.embedding(inputs)
        speaker_embeddings = self.speaker_embedding(speaker_id)
        speaker_embeddings = speaker_embeddings.unsqueeze(1).expand(
            embedded_inputs.size(0), embedded_inputs.size(
                1), speaker_embeddings.size(1)
        )
        embedded_inputs = torch.cat([embedded_inputs, speaker_embeddings], 2).transpose(
            1, 2
        )

        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]
        )

        return outputs


if __name__ == "__main__":
    # Test
    model = Tacotron2(hparams)
    print(model)
    print(get_param_num(model))
Example #12
def main(args, device):
    hp.checkpoint_path = os.path.join(hp.root_path, args.name_task, "ckpt",
                                      hp.dataset)
    hp.synth_path = os.path.join(hp.root_path, args.name_task, "synth",
                                 hp.dataset)
    hp.eval_path = os.path.join(hp.root_path, args.name_task, "eval",
                                hp.dataset)
    hp.log_path = os.path.join(hp.root_path, args.name_task, "log", hp.dataset)
    hp.test_path = os.path.join(hp.root_path, args.name_task, 'results')
    # Define model
    print("Use FastSpeech")
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of TTS Parameters:', num_param)
    # Get buffer
    print("Load data to buffer")
    buffer = get_data_to_buffer('train.txt')

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer, hp1.decoder_dim,
                                     hp1.n_warm_up_step, args.restore_step)
    fastspeech_loss = DNNLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists
    checkpoint_path = args.restore_path
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        print("\n---Start New Training---\n")
        checkpoint_path = hp.checkpoint_path
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Init logger
    if not os.path.exists(hp.log_path):
        os.makedirs(hp.log_path)

    waveglow = None
    if hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()

    # Get dataset
    dataset = BufferDataset(buffer)

    # Get Training Loader
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size *
                                 hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=0)
    total_step = hp.epochs * len(training_loader) * hp.batch_expand_size

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):
            # real batch start here
            for j, db in enumerate(batchs):
                print(len(batchs), len(db))
                start_time = time.perf_counter()

                current_step = i * hp.batch_expand_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_expand_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = db["text"].long().to(device)
                mel_target = db["mel_target"].float().to(device)
                duration = db["duration"].int().to(device)
                mel_pos = db["mel_pos"].long().to(device)
                src_pos = db["src_pos"].long().to(device)
                max_mel_len = db["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=duration)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output, mel_postnet_output, duration_predictor_output,
                    mel_target, duration)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                with open(os.path.join(hp.log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                with open(os.path.join(hp.log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")

                with open(os.path.join(hp.log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")

                with open(os.path.join(hp.log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp1.grad_clip_thresh)

                # Update weights
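                # Either hold the learning rate fixed at the user-supplied value
                # or follow the warm-up schedule.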
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join(hp.log_path, "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                # if current_step % 10 == 0:
                #     model.eval()
                #     with torch.no_grad():
                #         t_l, d_l, mel_l, mel_p_l = evaluate(
                #             model, current_step, vocoder=waveglow)

                #         str0 = 'Validating'
                #         str1 = "\tEpoch [{}/{}], Step [{}/{}]:".format(
                #             epoch + 1, hp.epochs, current_step, total_step)
                #         str2 = "\tTotal Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                #             t_l, m_l, m_p_l, d_l)
                #         print(str0)
                #         print(str1)
                #         print(str2)
                #     model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    Time = np.array([np.mean(Time)])
Example #13
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = hp.checkpoint_path
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_len = max(data_of_batch["mel_len"]).astype(np.int16)

                # Forward
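                # Ground-truth duration (D), pitch (f0) and energy are passed in
                # for teacher forcing; the variance predictors are trained
                # against these targets.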
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D, f0, energy)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target, src_len,
                    mel_len)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    logger.add_scalars('Loss/total_loss', {'training': t_l},
                                       current_step)
                    logger.add_scalars('Loss/mel_loss', {'training': m_l},
                                       current_step)
                    logger.add_scalars('Loss/mel_postnet_loss',
                                       {'training': m_p_l}, current_step)
                    logger.add_scalars('Loss/duration_loss', {'training': d_l},
                                       current_step)
                    logger.add_scalars('Loss/F0_loss', {'training': f_l},
                                       current_step)
                    logger.add_scalars('Loss/energy_loss', {'training': e_l},
                                       current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
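                    # Synthesize with Griffin-Lim (directly from the mel) and
                    # with WaveGlow (pre-/post-postnet, plus the ground truth)
                    # for comparison.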
                    Audio.tools.inv_mel_spec(
                        mel,
                        os.path.join(
                            synth_path,
                            "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(
                        mel_postnet,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_griffin_lim.wav".format(
                                current_step)))
                    waveglow.inference.inference(
                        mel_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_waveglow.wav".format(current_step)))
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(
                            synth_path, "step_{}_postnet_waveglow.wav".format(
                                current_step)))
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_ground-truth_waveglow.wav".format(
                                current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output, energy_output),
                        (mel_target.numpy(), f0, energy)
                    ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        logger.add_scalars('Loss/total_loss',
                                           {'validation': t_l}, current_step)
                        logger.add_scalars('Loss/mel_loss',
                                           {'validation': m_l}, current_step)
                        logger.add_scalars('Loss/mel_postnet_loss',
                                           {'validation': m_p_l}, current_step)
                        logger.add_scalars('Loss/duration_loss',
                                           {'validation': d_l}, current_step)
                        logger.add_scalars('Loss/F0_loss', {'validation': f_l},
                                           current_step)
                        logger.add_scalars('Loss/energy_loss',
                                           {'validation': e_l}, current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    Time = np.array([np.mean(Time)])
Example #14
        for i, ind in enumerate(sorted_length_index):
            x[ind] = x_sorted[i]
        x = torch.stack(x).to(device)

        x = self.post_linear(x)
        mel, predicted, _ = self.mask(x, predicted, cemb, length_c, length_mel,
                                      max_c_len, max_mel_len)

        return mel, predicted, length_mel, history

    def inference(self, character, alpha=1.0):
        x = self.embeddings(character)

        self.pre_gru.flatten_parameters()
        x, _ = self.pre_gru(x)

        x = self.pre_linear(x)
        x = self.LR(x, alpha=alpha)

        self.post_gru.flatten_parameters()
        x, _ = self.post_gru(x)
        mel = self.post_linear(x)

        return mel


if __name__ == "__main__":
    # Test
    num_param = utils.get_param_num(LightSpeech())
    print(num_param)
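
A hedged usage sketch for the inference path above: alpha is forwarded to LR (length regulation), so values above 1.0 stretch the predicted durations and slow the synthesized speech, while values below 1.0 speed it up. The dummy phoneme ids below are an assumption for illustration only.

import torch

model = LightSpeech()
model.eval()
phonemes = torch.tensor([[12, 43, 7, 25, 9]])  # dummy phoneme-id sequence
with torch.no_grad():
    mel_normal = model.inference(phonemes, alpha=1.0)
    mel_slow = model.inference(phonemes, alpha=1.3)  # stretched durations
print(mel_normal.shape, mel_slow.shape)
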
Example #15
def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer, hp.d_model, hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists
    try:
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()  # time.clock() was removed in Python 3.8

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        # Get Training Loader
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader)*hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                max_mel_len = data_of_batch["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=D)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output, mel_postnet_output, duration_predictor_output,
                    mel_target, D)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                with open(os.path.join("logger", "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                with open(os.path.join("logger", "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")

                with open(os.path.join("logger", "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")

                with open(os.path.join("logger", "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join("logger", "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Keep only the running mean so the buffer stays small
                    Time = np.array([np.mean(Time)])
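
ScheduledOptim itself is not shown in this snippet; a common implementation behind step_and_update_lr() and get_learning_rate() is the Transformer "Noam" schedule, sketched here under that assumption: the rate rises linearly for n_warm_up_step steps and then decays proportionally to 1/sqrt(step).

def noam_lr(step, d_model, n_warm_up_step):
    """Warm-up then inverse-square-root decay, peaking at step == n_warm_up_step."""
    return (d_model ** -0.5) * min(step ** -0.5,
                                   step * n_warm_up_step ** -1.5)

# e.g. noam_lr(4000, 256, 4000) is about 9.9e-4 at the warm-up peak
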
Example #16
def main(args):
    torch.manual_seed(0)
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    if hp.use_spk_embed:
        n_spkers = len(dataset.spk_table.keys())
        model = nn.DataParallel(FastSpeech2(n_spkers=n_spkers)).to(device)
    else:
        model = nn.DataParallel(FastSpeech2()).to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = hp.checkpoint_path
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        #melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Init evaluation directory
    eval_path = hp.eval_path
    if not os.path.exists(eval_path):
        os.makedirs(eval_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Total optimization steps across all epochs (for logging/ETA)
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1
                print("step : {}".format(current_step), end='\r', flush=True)

                ### Get Data ###
                if hp.use_spk_embed:
                    spk_ids = torch.tensor(data_of_batch["spk_ids"]).to(
                        torch.int64).to(device)
                else:
                    spk_ids = None
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                ### Forward ###
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len, spk_ids)

                ### Cal Loss ###
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss + 0.01 * f_loss + 0.1 * e_loss

                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                ### Backward ###
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                ### Update weights ###
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                ### Print ###
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                ### Save model ###
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                ### Synth ###
                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    print("step: {} , length {}, {}".format(
                        current_step, length, mel_len))
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)

                    spk_id = dataset.inv_spk_table[int(spk_ids[0])]
                    if hp.vocoder == 'melgan':
                        vocoder.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))

                    elif hp.vocoder == 'waveglow':
                        vocoder.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output, energy_output),
                        (mel_target.numpy(), f0, energy)
                    ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                ### Evaluation ###
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)

                    model.train()

                ### Time ###
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Keep only the running mean so the buffer stays small
                    Time = np.array([np.mean(Time)])
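
The hp.acc_steps logic above is gradient accumulation: each mini-batch contributes loss / acc_steps to the gradients, and the optimizer only steps (and clears gradients) every acc_steps iterations, emulating a batch acc_steps times larger. A self-contained sketch of the same pattern with a made-up model and random data:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters())
acc_steps = 4

for step in range(1, 9):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = F.mse_loss(model(x), y) / acc_steps
    loss.backward()  # gradients accumulate across backward() calls
    if step % acc_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
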
Example #17
def main(args):
    torch.manual_seed(0)

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    # Get model
    model = get_FastSpeech2(args.step).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Init directories
    if not os.path.exists(hp.logger_path):
        os.makedirs(hp.logger_path)
    if not os.path.exists(hp.eval_path):
        os.makedirs(hp.eval_path)

    # Get loss function
    Loss = FastSpeech2Loss().to(device)
    print("Loss Function Defined.")

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            mel_pos = torch.from_numpy(
                data_of_batch["mel_pos"]).long().to(device)
            src_pos = torch.from_numpy(
                data_of_batch["src_pos"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_len = max(data_of_batch["mel_len"]).astype(np.int32)  # int16 risked overflow on long utterances

            with torch.no_grad():
                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                for k in range(len(mel_target)):
                    length = mel_len[k]

                    mel_target_torch = mel_target[k:k + 1, :length].transpose(
                        1, 2).detach()
                    mel_target_ = mel_target[k, :length].cpu().transpose(
                        0, 1).detach()
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(
                            hp.eval_path,
                            'ground-truth_{}_waveglow.wav'.format(idx)))

                    mel_postnet_torch = mel_postnet_output[
                        k:k + 1, :length].transpose(1, 2).detach()
                    mel_postnet = mel_postnet_output[
                        k, :length].cpu().transpose(0, 1).detach()
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(hp.eval_path,
                                     'eval_{}_waveglow.wav'.format(idx)))

                    utils.plot_data([
                        (mel_postnet.numpy(), None, None),
                        (mel_target_.numpy(), None, None)
                    ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        hp.eval_path,
                                        'eval_{}.png'.format(idx)))
                    idx += 1

            current_step += 1

    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(args.step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.logger_path, "eval.txt"), "a") as f_logger:
        f_logger.write(str1 + "\n")
        f_logger.write(str2 + "\n")
        f_logger.write(str3 + "\n")
        f_logger.write(str4 + "\n")
        f_logger.write(str5 + "\n")
        f_logger.write(str6 + "\n")
        f_logger.write("\n")
Example #18
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    return d_l, f_l, e_l, mel_l, mel_p_l

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=int, default=30000)
    args = parser.parse_args()
    
    # Get model
    model = get_FastSpeech2(args.step).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)
    
    # Load vocoder
    if hp.vocoder == 'melgan':
        vocoder = utils.get_melgan()
    elif hp.vocoder == 'waveglow':
        vocoder = utils.get_waveglow()
    vocoder.to(device)
        
    # Init directories
    if not os.path.exists(hp.log_path):
        os.makedirs(hp.log_path)
    if not os.path.exists(hp.eval_path):
        os.makedirs(hp.eval_path)
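
Run from the shell, this evaluation entry point would be invoked roughly as follows (the script name is an assumption):

    python eval.py --step 30000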
    
Example #19
def main(args):
    # Get device
    # Hard-coded to GPU; use torch.device('cuda' if torch.cuda.is_available() else 'cpu') to fall back to CPU
    device = 'cuda'
    # Define model
    model = FastSpeech().to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    writer = SummaryWriter(log_dir='log/' + current_time)

    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)

    # Load checkpoint if exists
    try:
        checkpoint_in = open(
            os.path.join(hp.checkpoint_path, 'checkpoint.txt'), 'r')
        args.restore_step = int(checkpoint_in.readline().strip())
        checkpoint_in.close()
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%08d.pth' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)
    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss

    scheduled_optim = ScheduledOptim(optimizer, hp.d_model, hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    t_l = 0.0
    for epoch in range(hp.epochs):
        # Get Training Loader
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader)*hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                condition1 = torch.from_numpy(
                    data_of_batch["condition1"]).long().to(device)
                condition2 = torch.from_numpy(
                    data_of_batch["condition2"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).long().to(device)
                norm_f0 = torch.from_numpy(
                    data_of_batch["norm_f0"]).long().to(device)
                mel_in = torch.from_numpy(
                    data_of_batch["mel_in"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                lens = data_of_batch["lens"]
                max_mel_len = data_of_batch["mel_max_len"]
                # Forward
                mel_output = model(src_seq1=condition1,
                                   src_seq2=condition2,
                                   mel_in=mel_in,
                                   src_pos=src_pos,
                                   mel_pos=mel_pos,
                                   mel_max_length=max_mel_len,
                                   length_target=D)

                # Cal Loss
                Loss = torch.nn.CrossEntropyLoss()
                predict = mel_output.transpose(1, 2)
                target1 = mel_target.long().squeeze()
                target2 = norm_f0.long().squeeze()
                target = ((target1 + target2) / 2).long().squeeze()

                losses = []
                # Cross-entropy per utterance, restricted to the unpadded frames
                for index in range(predict.shape[0]):
                    losses.append(
                        Loss(predict[index, :, :lens[index]].transpose(0, 1),
                             target[index, :lens[index]]).unsqueeze(0))


                total_loss = torch.cat(losses).mean()
                t_l += total_loss.item()


                with open(os.path.join("logger", "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                # Backward
                if not np.isnan(t_l):
                    total_loss.backward()
                else:
                    print(condition1, condition2, D)

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}] Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Loss:{:.4f} ".format(t_l / hp.log_step)

                    str3 = "LR:{:.6f}".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "T: {:.1f}s ETR:{:.1f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print('\r' + str1 + ' ' + str2 + ' ' + str3 + ' ' + str4,
                          end='')
                    writer.add_scalar('loss', t_l / hp.log_step, current_step)
                    writer.add_scalar('learning rate',
                                      scheduled_optim.get_learning_rate(),
                                      current_step)

                    if hp.gpu_log_step != -1 and current_step % hp.gpu_log_step == 0:
                        os.system('nvidia-smi')

                    with open(os.path.join("logger", "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                    t_l = 0.0

                if current_step % hp.fig_step == 0 or current_step == 20:
                    f = plt.figure()
                    plt.matshow(mel_output[0].cpu().detach().numpy())
                    plt.savefig('out_predicted.png')
                    plt.matshow(
                        F.softmax(predict, dim=1).transpose(
                            1, 2)[0].cpu().detach().numpy())
                    plt.savefig('out_predicted_softmax.png')
                    writer.add_figure('predict', f, current_step)
                    plt.cla()

                    f = plt.figure(figsize=(8, 6))
                    sample = []
                    p = F.softmax(predict, dim=1).transpose(
                        1, 2)[0].detach().cpu().numpy()
                    for index in range(p.shape[0]):
                        sample.append(np.random.choice(200, 1, p=p[index]))
                    sample = np.array(sample)
                    plt.plot(np.arange(sample.shape[0]),
                             sample,
                             color='grey',
                             linewidth=1)
                    for index in range(D.shape[1]):
                        x = np.arange(D[0][index].cpu().numpy()
                                      ) + D[0][:index].cpu().numpy().sum()
                        y = np.arange(D[0][index].detach().cpu().numpy())
                        if condition2[0][index].cpu().numpy() != 0:
                            y.fill(
                                (condition2[0][index].cpu().numpy() - 40.0) *
                                5)
                            plt.plot(x, y, color='blue')
                    plt.plot(np.arange(target.shape[1]),
                             target[0].squeeze().detach().cpu().numpy(),
                             color='red',
                             linewidth=1)
                    plt.savefig('out_target.png', dpi=300)
                    writer.add_figure('target', f, current_step)
                    plt.cla()

                    plt.close("all")

                if current_step % (hp.save_step) == 0:
                    print("save model at step %d ..." % current_step, end='')
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%08d.pth' % current_step))
                    checkpoint_out = open(
                        os.path.join(hp.checkpoint_path, 'checkpoint.txt'),
                        'w')
                    checkpoint_out.write(str(current_step))
                    checkpoint_out.close()

                    print('save completed')

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Keep only the running mean so the buffer stays small
                    Time = np.array([np.mean(Time)])
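
The figure code above draws one of 200 quantized bins per frame from the softmax over the network output. A compact restatement of that sampling step, assuming predict has shape [batch, n_bins, T]:

import numpy as np
import torch
import torch.nn.functional as F

def sample_track(predict, n_bins=200):
    """Sample one bin index per frame from softmax(predict) for the first item."""
    p = F.softmax(predict, dim=1).transpose(1, 2)[0].detach().cpu().numpy()
    p = p.astype(np.float64)
    p /= p.sum(axis=1, keepdims=True)  # guard against float32 rounding
    return np.array([np.random.choice(n_bins, p=frame) for frame in p])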