Example #1
File: train.py  Project: ssumin6/fastspeech
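These examples are pasted as bare train-loop functions; they assume imports along the following lines. The project-local module paths (hparams, model, dataset, loss, optimizer, alignment) are a guess from the identifiers used, not verbatim from the repo:

import os
import time
from multiprocessing import cpu_count

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import hparams as hp                      # assumed hyperparameter module
from model import FastSpeech              # assumed project-local modules
from dataset import FastSpeechDataset, collate_fn
from loss import FastSpeechLoss
from optimizer import ScheduledOptim
from alignment import get_alignment, get_tacotron2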
def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    model = nn.DataParallel(FastSpeech()).to(device)
    tacotron2 = get_tacotron2()  # used by get_alignment() when hp.pre_target is False
    print("FastSpeech and Tacotron2 Have Been Defined")
    num_param = sum(param.numel() for param in model.parameters())
    print('Number of FastSpeech Parameters:', num_param)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    optimizer = torch.optim.Adam(
        model.parameters(), betas=(0.9, 0.98), eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer,
                                     hp.word_vec_dim,
                                     hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Get training loader
    print("Get Training Loader")
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn,
                                 drop_last=True,
                                 num_workers=cpu_count())

    # Load checkpoint if exists
    try:
        checkpoint = torch.load(os.path.join(
            hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n------Model Restored at Step %d------\n" % args.restore_step)

    except FileNotFoundError:
        print("\n------Start New Training------\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Training
    model = model.train()

    total_step = hp.epochs * len(training_loader)
    Time = np.array([])
    Start = time.perf_counter()

    summary = SummaryWriter()

    for epoch in range(hp.epochs):
        for i, data_of_batch in enumerate(training_loader):
            start_time = time.perf_counter()

            current_step = i + args.restore_step + \
                epoch * len(training_loader) + 1

            # Init
            scheduled_optim.zero_grad()

            if not hp.pre_target:
                # Prepare Data
                src_seq = data_of_batch["texts"]
                src_pos = data_of_batch["pos"]
                mel_tgt = data_of_batch["mels"]

                src_seq = torch.from_numpy(src_seq).long().to(device)
                src_pos = torch.from_numpy(src_pos).long().to(device)
                mel_tgt = torch.from_numpy(mel_tgt).float().to(device)
                alignment_target = get_alignment(
                    src_seq, tacotron2).float().to(device)
                # For Data Parallel
                mel_max_len = mel_tgt.size(1)
            else:
                # Prepare Data
                src_seq = data_of_batch["texts"]
                src_pos = data_of_batch["pos"]
                mel_tgt = data_of_batch["mels"]
                alignment_target = data_of_batch["alignment"]
                src_seq = torch.from_numpy(src_seq).long().to(device)
                src_pos = torch.from_numpy(src_pos).long().to(device)
                mel_tgt = torch.from_numpy(mel_tgt).float().to(device)
                alignment_target = torch.from_numpy(
                    alignment_target).float().to(device)
                # For Data Parallel
                mel_max_len = mel_tgt.size(1)
            # Forward
            mel_output, mel_output_postnet, duration_predictor_output = model(
                src_seq, src_pos,
                mel_max_length=mel_max_len,
                length_target=alignment_target)

            # Cal Loss
            mel_loss, mel_postnet_loss, duration_predictor_loss = fastspeech_loss(
                mel_output, mel_output_postnet, duration_predictor_output, mel_tgt, alignment_target)
            total_loss = mel_loss + mel_postnet_loss + duration_predictor_loss

            # Logger
            t_l = total_loss.item()
            m_l = mel_loss.item()
            m_p_l = mel_postnet_loss.item()
            d_p_l = duration_predictor_loss.item()

            with open(os.path.join("logger", "total_loss.txt"), "a") as f_total_loss:
                f_total_loss.write(str(t_l)+"\n")

            with open(os.path.join("logger", "mel_loss.txt"), "a") as f_mel_loss:
                f_mel_loss.write(str(m_l)+"\n")

            with open(os.path.join("logger", "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                f_mel_postnet_loss.write(str(m_p_l)+"\n")

            with open(os.path.join("logger", "duration_predictor_loss.txt"), "a") as f_d_p_loss:
                f_d_p_loss.write(str(d_p_l)+"\n")

            # Backward
            total_loss.backward()

            # Clipping gradients to avoid gradient explosion
            nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)

            # Update weights
            if args.frozen_learning_rate:
                scheduled_optim.step_and_update_lr_frozen(
                    args.learning_rate_frozen)
            else:
                scheduled_optim.step_and_update_lr()

            # Print
            if current_step % hp.log_step == 0:
                Now = time.perf_counter()

                str1 = "Epoch [{}/{}], Step [{}/{}], Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f};".format(
                    epoch+1, hp.epochs, current_step, total_step, mel_loss.item(), mel_postnet_loss.item())
                str2 = "Duration Predictor Loss: {:.4f}, Total Loss: {:.4f}.".format(
                    duration_predictor_loss.item(), total_loss.item())
                str3 = "Current Learning Rate is {:.6f}.".format(
                    scheduled_optim.get_learning_rate())
                str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                    (Now-Start), (total_step-current_step)*np.mean(Time))

                print("\n" + str1)
                print(str2)
                print(str3)
                print(str4)

                with open(os.path.join("logger", "logger.txt"), "a") as f_logger:
                    f_logger.write(str1 + "\n")
                    f_logger.write(str2 + "\n")
                    f_logger.write(str3 + "\n")
                    f_logger.write(str4 + "\n")
                    f_logger.write("\n")

                summary.add_scalar('loss', total_loss.item(), current_step)

            if current_step % hp.save_step == 0:
                torch.save({'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()},
                           os.path.join(hp.checkpoint_path,
                                        'checkpoint_%d.pth.tar' % current_step))
                print("save model at step %d ..." % current_step)

            end_time = time.perf_counter()
            Time = np.append(Time, end_time - start_time)
            if len(Time) == hp.clear_Time:
                # collapse the timing history to its mean to keep the array small
                Time = np.array([np.mean(Time)])
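ScheduledOptim in these examples wraps Adam with the Transformer warm-up schedule (the learning rate grows linearly for n_warm_up_step steps, then decays as step^-0.5, scaled by d_model^-0.5). A minimal sketch of what the class is assumed to do, modeled on the common attention-is-all-you-need-pytorch implementation rather than taken from this repo:

class ScheduledOptim:
    """Optimizer wrapper implementing the Noam warm-up LR schedule."""

    def __init__(self, optimizer, d_model, n_warmup_steps, current_step=0):
        self._optimizer = optimizer
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = current_step

    def get_learning_rate(self):
        step = max(self.n_current_steps, 1)  # avoid 0 ** -0.5 before the first step
        return (self.d_model ** -0.5) * min(step ** -0.5,
                                            step * self.n_warmup_steps ** -1.5)

    def step_and_update_lr(self):
        self.n_current_steps += 1
        lr = self.get_learning_rate()
        for group in self._optimizer.param_groups:
            group['lr'] = lr
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()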
Example #2
def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer, hp.d_model, hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists
    try:
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except FileNotFoundError:
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        # Get Training Loader
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
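        # Note: batch_size above is hp.batch_size**2 on purpose. collate_fn is
        # expected to split this oversized batch into hp.batch_size real batches,
        # which is why the loop below iterates over `batchs` a second time.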
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader)*hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                max_mel_len = data_of_batch["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=D)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output, mel_postnet_output, duration_predictor_output,
                    mel_target, D)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                with open(os.path.join("logger", "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                with open(os.path.join("logger", "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")

                with open(os.path.join("logger", "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")

                with open(os.path.join("logger", "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join("logger", "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # collapse the timing history to its mean to keep the array small
                    Time = np.array([np.mean(Time)])
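The nested loop above works because collate_fn returns a list of sub-batches rather than a single batch. A minimal sketch of such a "batch of batches" collate function, assuming each dataset item is a dict with a "text" field used for length-sorting (the real collate_fn would also pad and stack the arrays):

import numpy as np

def collate_fn(big_batch, real_batch_count=8):
    """Split one oversized batch into length-sorted sub-batches.

    Sorting by text length first keeps the padding inside each sub-batch
    small; real_batch_count plays the role of hp.batch_size here.
    """
    order = np.argsort([len(item["text"]) for item in big_batch])[::-1]
    sub_size = len(big_batch) // real_batch_count
    return [[big_batch[idx] for idx in order[k * sub_size:(k + 1) * sub_size]]
            for k in range(real_batch_count)]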
Example #3
def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    print("Use Tacotron2")
    model = nn.DataParallel(Tacotron2(hp)).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of TTS Parameters:', num_param)
    # Get buffer
    print("Load data to buffer")
    buffer = get_data_to_buffer()

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_rnn_dim,
                                     hp.n_warm_up_step, args.restore_step)
    tts_loss = DNNLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists
    try:
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except FileNotFoundError:
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Get dataset
    dataset = BufferDataset(buffer)

    # Get Training Loader
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size *
                                 hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=0)
    total_step = hp.epochs * len(training_loader) * hp.batch_expand_size

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):
            # real batch start here
            for j, db in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_expand_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_expand_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = db["text"].long().to(device)
                mel_target = db["mel_target"].float().to(device)
                mel_pos = db["mel_pos"].long().to(device)
                src_pos = db["src_pos"].long().to(device)
                max_mel_len = db["mel_max_len"]

                mel_target = mel_target.contiguous().transpose(1, 2)
                src_length = torch.max(src_pos, -1)[0]
                mel_length = torch.max(mel_pos, -1)[0]

                gate_target = mel_pos.eq(0).float()
                gate_target = gate_target[:, 1:]
                gate_target = F.pad(gate_target, (0, 1, 0, 0), value=1.)
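                # gate_target marks stop frames: mel_pos == 0 flags padding, and
                # the one-frame left shift plus a trailing 1 makes the target fire
                # from the last real frame onward, which is what the stop-token
                # (gate) predictor is trained to emit.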

                # Forward
                inputs = character, src_length, mel_target, max_mel_len, mel_length
                mel_output, mel_output_postnet, gate_output = model(inputs)

                # Cal Loss
                mel_loss, mel_postnet_loss, gate_loss \
                    = tts_loss(mel_output, mel_output_postnet, gate_output,
                               mel_target, gate_target)
                total_loss = mel_loss + mel_postnet_loss + gate_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                g_l = gate_loss.item()

                with open(os.path.join("logger", "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                with open(os.path.join("logger", "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")

                with open(os.path.join("logger", "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")

                with open(os.path.join("logger", "gate_loss.txt"),
                          "a") as f_g_loss:
                    f_g_loss.write(str(g_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:"\
                        .format(epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Gate Loss: {:.4f};".format(
                        m_l, m_p_l, g_l)
                    str3 = "Current Learning Rate is {:.6f}."\
                        .format(scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s."\
                        .format((Now-Start), (total_step-current_step) * np.mean(Time, dtype=np.float32))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join("logger", "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # collapse the timing history to its mean to keep the array small
                    Time = np.array([np.mean(Time)])
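The gate-target construction above is easy to verify in isolation; for a 5-frame row with 3 real frames, the target fires from the last real frame onward:

import torch
import torch.nn.functional as F

mel_pos = torch.tensor([[1, 2, 3, 0, 0]])    # 3 real frames, 2 padded
gate_target = mel_pos.eq(0).float()          # [[0., 0., 0., 1., 1.]]
gate_target = gate_target[:, 1:]             # shift left: [[0., 0., 1., 1.]]
gate_target = F.pad(gate_target, (0, 1, 0, 0), value=1.)
print(gate_target)                           # [[0., 0., 1., 1., 1.]]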
Example #4
def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Define model
    model = FastSpeech().to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    writer = SummaryWriter(log_dir='log/' + current_time)

    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)

    # Load checkpoint if exists
    try:
        with open(os.path.join(hp.checkpoint_path,
                               'checkpoint.txt'), 'r') as checkpoint_in:
            args.restore_step = int(checkpoint_in.readline().strip())
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%08d.pth' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except (FileNotFoundError, ValueError):
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)
    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss

    scheduled_optim = ScheduledOptim(optimizer, hp.d_model, hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    t_l = 0.0
    for epoch in range(hp.epochs):
        # Get Training Loader
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader)*hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                condition1 = torch.from_numpy(
                    data_of_batch["condition1"]).long().to(device)
                condition2 = torch.from_numpy(
                    data_of_batch["condition2"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).long().to(device)
                norm_f0 = torch.from_numpy(
                    data_of_batch["norm_f0"]).long().to(device)
                mel_in = torch.from_numpy(
                    data_of_batch["mel_in"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                lens = data_of_batch["lens"]
                max_mel_len = data_of_batch["mel_max_len"]
                # Forward
                mel_output = model(src_seq1=condition1,
                                   src_seq2=condition2,
                                   mel_in=mel_in,
                                   src_pos=src_pos,
                                   mel_pos=mel_pos,
                                   mel_max_length=max_mel_len,
                                   length_target=D)

                # Cal Loss
                criterion = torch.nn.CrossEntropyLoss()
                predict = mel_output.transpose(1, 2)
                target1 = mel_target.long().squeeze()
                target2 = norm_f0.long().squeeze()
                # average the two integer class targets into one index target
                target = ((target1 + target2) / 2).long().squeeze()

                losses = []
                for index in range(predict.shape[0]):
                    # per-utterance loss over the unpadded frames only
                    losses.append(
                        criterion(predict[index, :, :lens[index]].transpose(0, 1),
                                  target[index, :lens[index]]).unsqueeze(0))
                total_loss = torch.cat(losses).mean()
                t_l += total_loss.item()

                with open(os.path.join("logger", "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                # Backward (skip it and dump the offending batch if loss is NaN)
                if not np.isnan(t_l):
                    total_loss.backward()
                else:
                    print(condition1, condition2, D)

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}] Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Loss:{:.4f} ".format(t_l / hp.log_step)

                    str3 = "LR:{:.6f}".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "T: {:.1f}s ETR:{:.1f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print('\r' + str1 + ' ' + str2 + ' ' + str3 + ' ' + str4,
                          end='')
                    writer.add_scalar('loss', t_l / hp.log_step, current_step)
                    writer.add_scalar('learning rate',
                                      scheduled_optim.get_learning_rate(),
                                      current_step)

                    if hp.gpu_log_step != -1 and current_step % hp.gpu_log_step == 0:
                        os.system('nvidia-smi')

                    with open(os.path.join("logger", "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                    t_l = 0.0

                if current_step % hp.fig_step == 0 or current_step == 20:
                    f = plt.figure()
                    plt.matshow(mel_output[0].cpu().detach().numpy())
                    plt.savefig('out_predicted.png')
                    plt.matshow(
                        F.softmax(predict, dim=1).transpose(
                            1, 2)[0].cpu().detach().numpy())
                    plt.savefig('out_predicted_softmax.png')
                    writer.add_figure('predict', f, current_step)
                    plt.cla()

                    f = plt.figure(figsize=(8, 6))
                    sample = []
                    p = F.softmax(predict, dim=1).transpose(
                        1, 2)[0].detach().cpu().numpy()
                    for index in range(p.shape[0]):
                        sample.append(np.random.choice(200, 1, p=p[index]))
                    sample = np.array(sample)
                    plt.plot(np.arange(sample.shape[0]),
                             sample,
                             color='grey',
                             linewidth=1)
                    for index in range(D.shape[1]):
                        x = np.arange(D[0][index].cpu().numpy()
                                      ) + D[0][:index].cpu().numpy().sum()
                        y = np.arange(D[0][index].detach().cpu().numpy())
                        if condition2[0][index].cpu().numpy() != 0:
                            y.fill(
                                (condition2[0][index].cpu().numpy() - 40.0) *
                                5)
                            plt.plot(x, y, color='blue')
                    plt.plot(np.arange(target.shape[1]),
                             target[0].squeeze().detach().cpu().numpy(),
                             color='red',
                             linewidth=1)
                    plt.savefig('out_target.png', dpi=300)
                    writer.add_figure('target', f, current_step)
                    plt.cla()

                    plt.close("all")

                if current_step % hp.save_step == 0:
                    print("save model at step %d ..." % current_step, end='')
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%08d.pth' % current_step))
                    with open(os.path.join(hp.checkpoint_path,
                                           'checkpoint.txt'), 'w') as checkpoint_out:
                        checkpoint_out.write(str(current_step))

                    print('save completed')

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # collapse the timing history to its mean to keep the array small
                    Time = np.array([np.mean(Time)])
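Unlike the earlier examples, this one also writes the latest step to checkpoint.txt so a restart can resume without a --restore_step flag. The same round-trip as a small standalone helper (the directory name is illustrative):

import os
import torch

def save_checkpoint(model, optimizer, step, ckpt_dir="./model_new"):
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict()},
               os.path.join(ckpt_dir, 'checkpoint_%08d.pth' % step))
    with open(os.path.join(ckpt_dir, 'checkpoint.txt'), 'w') as f:
        f.write(str(step))  # pointer to the newest checkpoint

def latest_step(ckpt_dir="./model_new"):
    try:
        with open(os.path.join(ckpt_dir, 'checkpoint.txt')) as f:
            return int(f.read().strip())
    except (OSError, ValueError):
        return 0  # no checkpoint yet: start fresh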