예제 #1
0
def main():
    """Train a Transformer model configured from command-line options.

    Reads the options through ``get_args()``, seeds the Python, NumPy and
    PyTorch random generators with ``opt.random_seed``, loads the data file
    at ``opt.data_path``, builds the model and its ``ScheduledOptim``-wrapped
    Adam optimizer, and hands everything over to ``train``.
    """
    opt = get_args()

    # Honour --no_cuda, and additionally fall back to the CPU when CUDA was
    # requested but is not usable on this host; otherwise the failure would
    # only surface later, when the model is moved with ``.to(device)``.
    use_cuda = not opt.no_cuda
    if use_cuda and not torch.cuda.is_available():
        print("[Warning] CUDA is not available; falling back to CPU.")
        use_cuda = False
    device = torch.device("cuda" if use_cuda else "cpu")

    # Seed every random source used by the pipeline for reproducibility.
    random.seed(opt.random_seed)
    np.random.seed(opt.random_seed)
    torch.manual_seed(opt.random_seed)

    if os.path.exists(opt.save_dir):
        print("[Warning] Save directory already exists.")
    else:
        # exist_ok tolerates the directory being created by another process
        # between the existence check and this call.
        os.makedirs(opt.save_dir, exist_ok=True)

    # NOTE(review): torch.load unpickles the file, so opt.data_path must only
    # ever point at trusted data.
    data = torch.load(opt.data_path)
    (training_data, validation_data, src_vocab_size, trg_vocab_size,
     src_idx2word, trg_idx2word) = prepare_dataloaders(data, opt)
    # Stored on opt so the vocabulary sizes are saved along with the options.
    opt.src_vocab_size = src_vocab_size
    opt.trg_vocab_size = trg_vocab_size

    transformer = Transformer(
        src_vocab_size,
        trg_vocab_size,
        opt.d_model,
        opt.d_inner_hid,
        opt.d_k,
        opt.d_v,
        opt.n_head,
        opt.n_layers,
        opt.dropout,
    ).to(device)

    # The second positional argument (1.0) is presumably a learning-rate
    # multiplier -- TODO confirm against ScheduledOptim's signature.
    optimizer = ScheduledOptim(
        Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-09),
        1.0,
        opt.d_model,
        opt.n_warmup_steps,
    )

    train(transformer, training_data, validation_data, optimizer, device, opt,
          src_idx2word, trg_idx2word)
예제 #2
0
def main(args):
    """Train FastSpeech2 with logging, checkpointing, synthesis and evaluation.

    Args:
        args: Parsed command-line arguments. Only ``args.restore_step`` is
            read here: it selects the checkpoint to resume from, is passed
            to ``ScheduledOptim`` and offsets the global step counter.

    All remaining configuration is read from the ``hp`` module.
    """
    torch.manual_seed(0)
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    # The loader draws hp.batch_size**2 samples at a time; every item it
    # yields is iterated below as a sequence of training batches (presumably
    # hp.batch_size of them, built by dataset.collate_fn -- TODO confirm).
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    if hp.use_spk_embed:
        n_spkers = len(dataset.spk_table.keys())
        model = nn.DataParallel(FastSpeech2(n_spkers=n_spkers)).to(device)
    else:
        model = nn.DataParallel(FastSpeech2()).to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    checkpoint_file = os.path.join(
        checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step))
    try:
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception as err:
        # ``except Exception`` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate. Training still starts from scratch on any
        # load failure, but a checkpoint that exists and cannot be restored
        # is reported instead of being silently ignored.
        if os.path.exists(checkpoint_file):
            print("[Warning] Could not restore {}: {}".format(
                checkpoint_file, err))
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path, exist_ok=True)

    # Load vocoder
    if hp.vocoder == 'melgan':
        # melgan is not moved to ``device`` here (that call was disabled in
        # the original code).
        melgan = utils.get_melgan()
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    for sub_dir in ('train', 'validation'):
        # makedirs also creates log_path itself when it is missing.
        os.makedirs(os.path.join(log_path, sub_dir), exist_ok=True)
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path
    os.makedirs(synth_path, exist_ok=True)

    # Init evaluation directory
    eval_path = hp.eval_path
    os.makedirs(eval_path, exist_ok=True)

    # Define Some Information
    Time = np.array([])  # recent per-step durations, used for the ETA
    Start = time.perf_counter()

    # Training
    model = model.train()
    steps_per_epoch = len(loader) * hp.batch_size
    total_step = hp.epochs * steps_per_epoch
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                # 1-based global step, continuing from the restored step.
                current_step = (i * hp.batch_size + j + args.restore_step +
                                epoch * steps_per_epoch + 1)
                print("step : {}".format(current_step), end='\r', flush=True)

                ### Get Data ###
                if hp.use_spk_embed:
                    spk_ids = torch.tensor(data_of_batch["spk_ids"]).to(
                        torch.int64).to(device)
                else:
                    spk_ids = None
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                ### Forward ###
                # The eighth model output is not used by this loop.
                (mel_output, mel_postnet_output, log_duration_output,
                 f0_output, energy_output, src_mask, mel_mask, _) = model(
                     text, src_len, mel_len, D, f0, energy, max_src_len,
                     max_mel_len, spk_ids)

                ### Cal Loss ###
                # Loss receives the inverted masks (presumably the model's
                # masks flag padding positions -- TODO confirm).
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                total_loss = (mel_loss + mel_postnet_loss + d_loss +
                              0.01 * f_loss + 0.1 * e_loss)

                # Scalar copies for logging, taken before the loss is scaled
                # for gradient accumulation.
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                # Append every loss value to its own text file, one per line.
                for loss_file, loss_value in (
                        ("total_loss.txt", t_l),
                        ("mel_loss.txt", m_l),
                        ("mel_postnet_loss.txt", m_p_l),
                        ("duration_loss.txt", d_l),
                        ("f0_loss.txt", f_l),
                        ("energy_loss.txt", e_l),
                ):
                    with open(os.path.join(log_path, loss_file),
                              "a") as f_loss_log:
                        f_loss_log.write(str(loss_value) + "\n")

                ## Backward
                # Gradients accumulate over hp.acc_steps steps; the optimizer
                # update and all periodic actions below only run on the step
                # that completes an accumulation window.
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                ### Update weights ###
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                ### Print ###
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                ### Save model ###
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                ### Synth ###
                if current_step % hp.synth_step == 0:
                    # Only the first sample of the batch is synthesized and
                    # plotted, trimmed to its own mel length.
                    length = mel_len[0].item()
                    print("step: {} , length {}, {}".format(
                        current_step, length, mel_len))
                    # Separate names are used below so the training tensors
                    # of this step are not overwritten.
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target_cpu = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)

                    # Without speaker embeddings there are no speaker ids to
                    # look up, so a placeholder is used in the file names.
                    if spk_ids is None:
                        spk_id = "none"
                    else:
                        spk_id = dataset.inv_spk_table[int(spk_ids[0])]

                    if hp.vocoder == 'melgan':
                        infer, vocoder_model = vocoder.melgan_infer, melgan
                    elif hp.vocoder == 'waveglow':
                        infer, vocoder_model = vocoder.waveglow_infer, waveglow
                    else:
                        infer, vocoder_model = None, None

                    if infer is not None:
                        # Every pattern takes (step, speaker, vocoder name);
                        # the same three files are written for both vocoders.
                        for mel_input, wav_pattern in (
                                (mel_torch, 'step_{}_spk_{}_{}.wav'),
                                (mel_postnet_torch,
                                 'step_{}_spk_{}_postnet_{}.wav'),
                                (mel_target_torch,
                                 'step_{}_spk_{}_ground-truth_{}.wav'),
                        ):
                            infer(
                                mel_input, vocoder_model,
                                os.path.join(
                                    hp.synth_path,
                                    wav_pattern.format(current_step, spk_id,
                                                       hp.vocoder)))

                    f0_np = f0[0, :length].detach().cpu().numpy()
                    energy_np = energy[0, :length].detach().cpu().numpy()
                    f0_output_np = f0_output[
                        0, :length].detach().cpu().numpy()
                    energy_output_np = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output_np, energy_output_np),
                        (mel_target_cpu.numpy(), f0_np, energy_np)
                    ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                ### Evaluation ###
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)

                    model.train()

                ### Time ###
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the history to its mean so the array stays
                    # bounded while the ETA estimate is preserved.
                    Time = np.array([np.mean(Time)])

    # Flush and release the TensorBoard event files.
    train_logger.close()
    val_logger.close()
예제 #3
0
파일: train.py 프로젝트: ssumin6/fastspeech
def main(args):
    """Train FastSpeech with periodic logging and checkpointing.

    Args:
        args: Parsed command-line arguments. This function reads
            ``args.restore_step`` (checkpoint to resume from and offset of
            the step counter), ``args.frozen_learning_rate`` and
            ``args.learning_rate_frozen`` (selects the fixed-rate optimizer
            step instead of the scheduled one).

    All remaining configuration is read from the ``hp`` module.
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    model = nn.DataParallel(FastSpeech()).to(device)
    # NOTE(review): loading the Tacotron2 teacher
    # (``tacotron2 = get_tacotron2()``) is disabled here, yet the
    # ``not hp.pre_target`` branch below still references ``tacotron2``.
    # Unless that name is a module-level global, the branch raises NameError
    # -- confirm before training without precomputed alignments.
    print("FastSpeech and Tacotron2 Have Been Defined")
    num_param = sum(param.numel() for param in model.parameters())
    print('Number of FastSpeech Parameters:', num_param)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    optimizer = torch.optim.Adam(
        model.parameters(), betas=(0.9, 0.98), eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer,
                                     hp.word_vec_dim,
                                     hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Get training loader
    print("Get Training Loader")
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn,
                                 drop_last=True,
                                 num_workers=cpu_count())

    checkpoint_file = os.path.join(
        hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step)
    try:
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n------Model Restored at Step %d------\n" % args.restore_step)
    except Exception as err:
        # ``except Exception`` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate. Training still starts from scratch on any
        # load failure, but a checkpoint that exists and cannot be restored
        # is reported instead of being silently ignored.
        if os.path.exists(checkpoint_file):
            print("[Warning] Could not restore %s: %s" % (checkpoint_file, err))
        print("\n------Start New Training------\n")
        if not os.path.exists(hp.checkpoint_path):
            # makedirs (unlike mkdir) also handles nested checkpoint paths.
            os.makedirs(hp.checkpoint_path, exist_ok=True)

    # Init logger
    # Every text log below is written into this directory; the original code
    # created hp.logger_path but wrote to a hard-coded "logger" folder.
    logger_dir = hp.logger_path
    if not os.path.exists(logger_dir):
        os.makedirs(logger_dir, exist_ok=True)

    # Training
    model = model.train()

    total_step = hp.epochs * len(training_loader)
    Time = np.array(list())  # recent per-step durations, used for the ETA
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # replacement and measures elapsed wall-clock time.
    Start = time.perf_counter()

    summary = SummaryWriter()

    for epoch in range(hp.epochs):
        for i, data_of_batch in enumerate(training_loader):
            start_time = time.perf_counter()

            # 1-based global step, continuing from the restored step.
            current_step = i + args.restore_step + \
                epoch * len(training_loader) + 1

            # Init
            scheduled_optim.zero_grad()

            # Prepare Data
            src_seq = torch.from_numpy(
                data_of_batch["texts"]).long().to(device)
            src_pos = torch.from_numpy(
                data_of_batch["pos"]).long().to(device)
            mel_tgt = torch.from_numpy(
                data_of_batch["mels"]).float().to(device)
            if hp.pre_target:
                # Alignments come precomputed with the batch.
                alignment_target = torch.from_numpy(
                    data_of_batch["alignment"]).float().to(device)
            else:
                # Alignments are computed on the fly from the teacher model.
                alignment_target = get_alignment(
                    src_seq, tacotron2).float().to(device)
            # For Data Parallel
            mel_max_len = mel_tgt.size(1)

            # Forward
            mel_output, mel_output_postnet, duration_predictor_output = model(
                src_seq, src_pos,
                mel_max_length=mel_max_len,
                length_target=alignment_target)

            # Cal Loss
            mel_loss, mel_postnet_loss, duration_predictor_loss = fastspeech_loss(
                mel_output, mel_output_postnet, duration_predictor_output,
                mel_tgt, alignment_target)
            total_loss = mel_loss + mel_postnet_loss + duration_predictor_loss

            # Logger
            t_l = total_loss.item()
            m_l = mel_loss.item()
            m_p_l = mel_postnet_loss.item()
            d_p_l = duration_predictor_loss.item()

            # Append every loss value to its own text file, one per line.
            for loss_file, loss_value in (
                    ("total_loss.txt", t_l),
                    ("mel_loss.txt", m_l),
                    ("mel_postnet_loss.txt", m_p_l),
                    ("duration_predictor_loss.txt", d_p_l),
            ):
                with open(os.path.join(logger_dir, loss_file),
                          "a") as f_loss_log:
                    f_loss_log.write(str(loss_value) + "\n")

            # Backward
            total_loss.backward()

            # Clipping gradients to avoid gradient explosion
            nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)

            # Update weights
            if args.frozen_learning_rate:
                scheduled_optim.step_and_update_lr_frozen(
                    args.learning_rate_frozen)
            else:
                scheduled_optim.step_and_update_lr()

            # Print
            if current_step % hp.log_step == 0:
                Now = time.perf_counter()

                str1 = "Epoch [{}/{}], Step [{}/{}], Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f};".format(
                    epoch+1, hp.epochs, current_step, total_step, m_l, m_p_l)
                str2 = "Duration Predictor Loss: {:.4f}, Total Loss: {:.4f}.".format(
                    d_p_l, t_l)
                str3 = "Current Learning Rate is {:.6f}.".format(
                    scheduled_optim.get_learning_rate())
                str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                    (Now-Start), (total_step-current_step)*np.mean(Time))

                print("\n" + str1)
                print(str2)
                print(str3)
                print(str4)

                with open(os.path.join(logger_dir, "logger.txt"), "a") as f_logger:
                    f_logger.write(str1 + "\n")
                    f_logger.write(str2 + "\n")
                    f_logger.write(str3 + "\n")
                    f_logger.write(str4 + "\n")
                    f_logger.write("\n")

                summary.add_scalar('loss', t_l, current_step)

            if current_step % hp.save_step == 0:
                torch.save(
                    {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    os.path.join(hp.checkpoint_path,
                                 'checkpoint_%d.pth.tar' % current_step))
                print("save model at step %d ..." % current_step)

            end_time = time.perf_counter()
            Time = np.append(Time, end_time - start_time)
            if len(Time) == hp.clear_Time:
                # Collapse the history to its mean so the array stays bounded
                # while the ETA estimate is preserved.
                Time = np.array([np.mean(Time)])

    # Flush and release the TensorBoard event file.
    summary.close()
예제 #4
0
def main(args):
    torch.manual_seed(0)

    # Get device
    # device = torch.device('cuda'if torch.cuda.is_available()else 'cpu')
    device = 'cuda'

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoaderX(dataset,
                         batch_size=hp.batch_size * 4,
                         shuffle=True,
                         collate_fn=dataset.collate_fn,
                         drop_last=True,
                         num_workers=8)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    #     model = FastSpeech2().to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    train_logger = SummaryWriter(log_dir='log/train/' + current_time)
    val_logger = SummaryWriter(log_dir='log/validation/' + current_time)
    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()
    current_step0 = 0
    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * len(batchs) + j + args.restore_step + \
                    epoch * len(loader)*len(batchs) + 1
                # Get Data
                condition = torch.from_numpy(
                    data_of_batch["condition"]).long().to(device)
                mel_refer = torch.from_numpy(
                    data_of_batch["mel_refer"]).float().to(device)
                if hp.vocoder == 'WORLD':
                    ap_target = torch.from_numpy(
                        data_of_batch["ap_target"]).float().to(device)
                    sp_target = torch.from_numpy(
                        data_of_batch["sp_target"]).float().to(device)
                else:
                    mel_target = torch.from_numpy(
                        data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                #print(D,log_D)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                if hp.vocoder == 'WORLD':
                    #                     print(condition.shape,mel_refer.shape, src_len.shape, mel_len.shape, D.shape, f0.shape, energy.shape, max_src_len.shape, max_mel_len.shape)
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model(
                        condition, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        ap_output=ap_output,
                        sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output,
                        ap_target=ap_target,
                        sp_target=sp_target,
                        src_mask=src_mask,
                        ap_mask=ap_mask,
                        sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        log_D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target,
                        src_mask=~src_mask,
                        mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l = ap_loss.item()
                    sp_l = sp_loss.item()
                    sp_p_l = sp_postnet_loss.item()
                else:
                    m_l = mel_loss.item()
                    m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                #                 with open(os.path.join(log_path, "total_loss.txt"), "a") as f_total_loss:
                #                     f_total_loss.write(str(t_l)+"\n")
                #                 with open(os.path.join(log_path, "mel_loss.txt"), "a") as f_mel_loss:
                #                     f_mel_loss.write(str(m_l)+"\n")
                #                 with open(os.path.join(log_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                #                     f_mel_postnet_loss.write(str(m_p_l)+"\n")
                #                 with open(os.path.join(log_path, "duration_loss.txt"), "a") as f_d_loss:
                #                     f_d_loss.write(str(d_l)+"\n")
                #                 with open(os.path.join(log_path, "f0_loss.txt"), "a") as f_f_loss:
                #                     f_f_loss.write(str(f_l)+"\n")
                #                 with open(os.path.join(log_path, "energy_loss.txt"), "a") as f_e_loss:
                #                     f_e_loss.write(str(e_l)+"\n")

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}],Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    if hp.vocoder == 'WORLD':
                        str2 = "Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l)
                    else:
                        str2 = "Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "T:{:.1f}s,ETA:{:.1f}s.".format(
                        (Now - Start) / (current_step - current_step0),
                        (total_step - current_step) * np.mean(Time))

                    print("" + str1 + str2 + str3 + '')

                    #                     with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                    #                         f_log.write(str1 + "\n")
                    #                         f_log.write(str2 + "\n")
                    #                         f_log.write(str3 + "\n")
                    #                         f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    if hp.vocoder == 'WORLD':
                        train_logger.add_scalar('Loss/ap_loss', ap_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_loss', sp_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_postnet_loss', sp_p_l,
                                                current_step)
                    else:
                        train_logger.add_scalar('Loss/mel_loss', m_l,
                                                current_step)
                        train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                                current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                if current_step % hp.save_step == 0 or current_step == 20:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0 or current_step == 5:
                    length = mel_len[0].item()

                    if hp.vocoder == 'WORLD':
                        ap_target_torch = ap_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        ap_torch = ap_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_target_torch = sp_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_torch = sp_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_postnet_torch = sp_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                    else:
                        mel_target_torch = mel_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_target = mel_target[
                            0, :length].detach().cpu().transpose(0, 1)
                        mel_torch = mel_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        mel = mel_output[0, :length].detach().cpu().transpose(
                            0, 1)
                        mel_postnet_torch = mel_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_postnet = mel_postnet_output[
                            0, :length].detach().cpu().transpose(0, 1)
#                     Audio.tools.inv_mel_spec(mel, os.path.join(synth_path, "step_{}_griffin_lim.wav".format(current_step)))
#                     Audio.tools.inv_mel_spec(mel_postnet, os.path.join(synth_path, "step_{}_postnet_griffin_lim.wav".format(current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'WORLD':
                        #     ap=np.swapaxes(ap,0,1)
                        #     sp=np.swapaxes(sp,0,1)
                        wav = utils.world_infer(
                            np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_postnet_torch[0].cpu().numpy(), 0,
                                        1), f0_output)
                        sf.write(
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)
                        wav = utils.world_infer(
                            np.swapaxes(ap_target_torch[0].cpu().numpy(), 0,
                                        1),
                            np.swapaxes(sp_target_torch[0].cpu().numpy(), 0,
                                        1), f0)
                        sf.write(
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)

                    utils.plot_data([
                        (sp_postnet_torch[0].cpu().numpy(), f0_output,
                         energy_output),
                        (sp_target_torch[0].cpu().numpy(), f0, energy)
                    ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                    plt.matshow(sp_postnet_torch[0].cpu().numpy())
                    plt.savefig(
                        os.path.join(synth_path,
                                     'sp_postnet_{}.png'.format(current_step)))
                    plt.matshow(ap_torch[0].cpu().numpy())
                    plt.savefig(
                        os.path.join(synth_path,
                                     'ap_{}.png'.format(current_step)))
                    plt.matshow(
                        variance_adaptor_output[0].detach().cpu().numpy())
                    #                     plt.savefig(os.path.join(synth_path, 'va_{}.png'.format(current_step)))
                    #                     plt.matshow(decoder_output[0].detach().cpu().numpy())
                    #                     plt.savefig(os.path.join(synth_path, 'encoder_{}.png'.format(current_step)))

                    plt.cla()
                    fout = open(
                        os.path.join(synth_path,
                                     'D_{}.txt'.format(current_step)), 'w')
                    fout.write(
                        str(log_duration_output[0].detach().cpu().numpy()) +
                        '\n')
                    fout.write(str(D[0].detach().cpu().numpy()) + '\n')
                    fout.write(
                        str(condition[0, :, 2].detach().cpu().numpy()) + '\n')
                    fout.close()


#                 if current_step % hp.eval_step == 0 or current_step==20:
#                     model.eval()
#                     with torch.no_grad():

#                         if hp.vocoder=='WORLD':
#                             d_l, f_l, e_l, ap_l, sp_l, sp_p_l = evaluate(model, current_step)
#                             t_l = d_l + f_l + e_l + ap_l + sp_l + sp_p_l

#                             val_logger.add_scalar('valLoss/total_loss', t_l, current_step)
#                             val_logger.add_scalar('valLoss/ap_loss', ap_l, current_step)
#                             val_logger.add_scalar('valLoss/sp_loss', sp_l, current_step)
#                             val_logger.add_scalar('valLoss/sp_postnet_loss', sp_p_l, current_step)
#                             val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
#                             val_logger.add_scalar('valLoss/F0_loss', f_l, current_step)
#                             val_logger.add_scalar('valLoss/energy_loss', e_l, current_step)
#                         else:
#                             d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step)
#                             t_l = d_l + f_l + e_l + m_l + m_p_l

#                             val_logger.add_scalar('valLoss/total_loss', t_l, current_step)
#                             val_logger.add_scalar('valLoss/mel_loss', m_l, current_step)
#                             val_logger.add_scalar('valLoss/mel_postnet_loss', m_p_l, current_step)
#                             val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
#                             val_logger.add_scalar('valLoss/F0_loss', f_l, current_step)
#                             val_logger.add_scalar('valLoss/energy_loss', e_l, current_step)

#                     model.train()
#                 if current_step%10==0:
#                     print(energy_output[0],energy[0])
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
예제 #5
0
def main(args):
    """Train a Tacotron2 model, resuming from a checkpoint when one is found.

    Builds the model, Adam optimizer, learning-rate scheduler and loss, then
    tries to load ``checkpoint_<args.restore_step>.pth.tar`` from
    ``hp.checkpoint_path``. If that fails, training starts from scratch.
    Every step appends the individual losses to text files under
    ``hp.logger_path``; every ``hp.log_step`` steps a summary is printed and
    appended to ``logger.txt``; every ``hp.save_step`` steps a checkpoint
    holding the model and optimizer state dicts is written.

    Args:
        args: Parsed command-line options. This function reads
            ``restore_step``, ``frozen_learning_rate`` and
            ``learning_rate_frozen``.
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    print("Use Tacotron2")
    model = nn.DataParallel(Tacotron2(hp)).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of TTS Parameters:', num_param)
    # Get buffer
    print("Load data to buffer")
    buffer = get_data_to_buffer()

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_rnn_dim,
                                     hp.n_warm_up_step, args.restore_step)
    tts_loss = DNNLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists. Restoring is best-effort: any failure falls
    # back to a fresh run, but only a missing file is treated as the silent,
    # expected case. Other errors are reported so that a corrupt or
    # mismatching checkpoint does not go unnoticed. KeyboardInterrupt and
    # SystemExit are no longer swallowed.
    try:
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except FileNotFoundError:
        print("\n---Start New Training---\n")
        os.makedirs(hp.checkpoint_path, exist_ok=True)
    except Exception as err:
        print("[Warning] Could not restore checkpoint: %r" % (err,))
        print("\n---Start New Training---\n")
        os.makedirs(hp.checkpoint_path, exist_ok=True)

    # Init logger. All text logs below are written into this directory.
    log_dir = hp.logger_path
    os.makedirs(log_dir, exist_ok=True)

    # Get dataset
    dataset = BufferDataset(buffer)

    # Get Training Loader
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size *
                                 hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=0)
    # Each loader item is iterated below as a sequence of real batches, so
    # the step count is scaled by hp.batch_expand_size. Note that this total
    # does not include args.restore_step, whereas current_step does.
    total_step = hp.epochs * len(training_loader) * hp.batch_expand_size

    # Define Some Information
    Time = np.array([])  # per-step wall-clock durations, used for the ETA
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):
            # real batch start here
            for j, db in enumerate(batchs):
                start_time = time.perf_counter()

                # Global 1-based step index, offset by the restored step.
                current_step = i * hp.batch_expand_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_expand_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = db["text"].long().to(device)
                mel_target = db["mel_target"].float().to(device)
                mel_pos = db["mel_pos"].long().to(device)
                src_pos = db["src_pos"].long().to(device)
                max_mel_len = db["mel_max_len"]

                # Swap axes 1 and 2 of the target spectrogram.
                mel_target = mel_target.contiguous().transpose(1, 2)
                # Largest position index per sample; presumably equals the
                # sequence length (1-based positions) -- TODO confirm.
                src_length = torch.max(src_pos, -1)[0]
                mel_length = torch.max(mel_pos, -1)[0]

                # Gate target: 1.0 wherever mel_pos is 0, shifted one frame to
                # the left along the last axis, with 1.0 appended at the end.
                gate_target = mel_pos.eq(0).float()
                gate_target = gate_target[:, 1:]
                gate_target = F.pad(gate_target, (0, 1, 0, 0), value=1.)

                # Forward
                inputs = character, src_length, mel_target, max_mel_len, mel_length
                mel_output, mel_output_postnet, gate_output = model(inputs)

                # Cal Loss
                mel_loss, mel_postnet_loss, gate_loss \
                    = tts_loss(mel_output, mel_output_postnet, gate_output,
                               mel_target, gate_target)
                total_loss = mel_loss + mel_postnet_loss + gate_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                g_l = gate_loss.item()

                with open(os.path.join(log_dir, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                with open(os.path.join(log_dir, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")

                with open(os.path.join(log_dir, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")

                with open(os.path.join(log_dir, "gate_loss.txt"),
                          "a") as f_g_loss:
                    f_g_loss.write(str(g_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    # No step has been timed yet on the very first iteration;
                    # use NaN directly instead of letting np.mean warn about
                    # an empty slice (the printed value is "nan" either way).
                    if Time.size:
                        mean_step_time = np.mean(Time, dtype=np.float32)
                    else:
                        mean_step_time = float("nan")

                    str1 = "Epoch [{}/{}], Step [{}/{}]:"\
                        .format(epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Gate Loss: {:.4f};".format(
                        m_l, m_p_l, g_l)
                    str3 = "Current Learning Rate is {:.6f}."\
                        .format(scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s."\
                        .format((Now-Start), (total_step-current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join(log_dir, "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                # Once hp.clear_Time samples have accumulated, collapse the
                # history into its mean so the array stays bounded.
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.array([temp_value])
예제 #6
0
def main(args):
    """Train a FastSpeech2 model, resuming from a checkpoint when one is found.

    Builds the dataset loader, model, Adam optimizer, learning-rate scheduler
    and loss, then tries to load ``checkpoint_<args.restore_step>.pth.tar``
    from ``hp.checkpoint_path``. If that fails, training starts from scratch.
    Losses are appended to text files under ``hp.log_path`` on every step and
    sent to TensorBoard writers on optimizer-update steps. Gradients are
    accumulated over ``hp.acc_steps`` steps before each optimizer update.
    Checkpoints are saved every ``hp.save_step`` steps and ``evaluate`` is run
    every ``hp.eval_step`` steps.

    Args:
        args: Parsed command-line options. This function reads
            ``restore_step``.
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    # Each loader item is iterated below as a sequence of sub-batches;
    # presumably collate_fn splits hp.batch_size**2 samples into
    # hp.batch_size batches -- TODO confirm against Dataset.collate_fn.
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists. Restoring is best-effort: any failure falls
    # back to a fresh run, but only a missing file is treated as the silent,
    # expected case. Other errors are reported so that a corrupt or
    # mismatching checkpoint does not go unnoticed. KeyboardInterrupt and
    # SystemExit are no longer swallowed.
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except FileNotFoundError:
        print("\n---Start New Training---\n")
        os.makedirs(checkpoint_path, exist_ok=True)
    except Exception as err:
        print("[Warning] Could not restore checkpoint: {!r}".format(err))
        print("\n---Start New Training---\n")
        os.makedirs(checkpoint_path, exist_ok=True)

    # read params
    # NOTE: these statistics are loaded and reshaped here but are not used
    # again inside this function. Loading is kept so that missing stat files
    # are still detected at start-up.
    mean_mel, std_mel = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "mel_stat.npy")),
                                     dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "f0_stat.npy")),
                                   dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "energy_stat.npy")),
                                           dtype=torch.float).to(device)

    mean_mel, std_mel = mean_mel.reshape(1, -1), std_mel.reshape(1, -1)
    mean_f0, std_f0 = mean_f0.reshape(1, -1), std_f0.reshape(1, -1)
    mean_energy, std_energy = mean_energy.reshape(1, -1), std_energy.reshape(
        1, -1)

    # Load vocoder
    if hp.vocoder == 'vocgan':
        vocoder = utils.get_vocgan(ckpt_path=hp.vocoder_pretrained_model_path)
        vocoder.to(device)
    else:
        vocoder = None

    # Init logger. Create the sub-directories even when log_path itself
    # already exists.
    log_path = hp.log_path
    os.makedirs(os.path.join(log_path, 'train'), exist_ok=True)
    os.makedirs(os.path.join(log_path, 'validation'), exist_ok=True)
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Define Some Information
    Time = np.array([])  # per-step wall-clock durations, used for the ETA
    Start = time.perf_counter()

    # The total does not change between epochs, so compute it once. Note that
    # it does not include args.restore_step, whereas current_step does.
    total_step = hp.epochs * len(loader) * hp.batch_size

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                # Global 1-based step index, offset by the restored step.
                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len)

                # Cal Loss (the loss receives the inverted masks)
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger (values are taken before the accumulation scaling)
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                # Gradient accumulation: on non-update steps everything below
                # (optimizer step, printing, TensorBoard, checkpointing,
                # evaluation and step timing) is skipped.
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    # No step has been timed yet on the first logged update;
                    # use NaN directly instead of letting np.mean warn about
                    # an empty slice (the printed value is "nan" either way).
                    mean_step_time = np.mean(Time) if Time.size else float(
                        "nan")

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                train_logger.add_scalar('Loss/total_loss', t_l, current_step)
                train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                        current_step)
                train_logger.add_scalar('Loss/duration_loss', d_l,
                                        current_step)
                train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                train_logger.add_scalar('Loss/energy_loss', e_l, current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        # Separate names keep the validation results from
                        # overwriting the training-loss values above.
                        val_d_l, val_f_l, val_e_l, val_m_l, val_m_p_l = evaluate(
                            model, current_step, vocoder)
                        val_t_l = (val_d_l + val_f_l + val_e_l + val_m_l +
                                   val_m_p_l)

                        val_logger.add_scalar('Loss/total_loss', val_t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', val_m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss',
                                              val_m_p_l, current_step)
                        val_logger.add_scalar('Loss/duration_loss', val_d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', val_f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', val_e_l,
                                              current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                # Once hp.clear_Time samples have accumulated, collapse the
                # history into its mean so the array stays bounded.
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.array([temp_value])
예제 #7
0
def main(args):
    """Train the FastSpeech model, logging losses and saving checkpoints.

    Builds the model, optimizer and loss, tries to resume from the
    checkpoint numbered ``args.restore_step``, then runs ``hp.epochs``
    epochs of training.  Per-step losses are appended to text files in
    ``hp.logger_path`` and a checkpoint is written every ``hp.save_step``
    steps.

    Args:
        args: Parsed command-line options.  This function reads
            ``args.restore_step`` (checkpoint number to resume from; also
            the offset added to the step counter),
            ``args.frozen_learning_rate`` (truthy to update with a fixed
            learning rate) and ``args.learning_rate_frozen`` (that rate).
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer, hp.d_model, hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists
    try:
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        # Best-effort restore: any load failure falls back to a fresh run.
        # `Exception` (rather than a bare `except`) lets Ctrl-C and
        # SystemExit propagate instead of being mistaken for "no checkpoint".
        print("\n---Start New Training---\n")
        os.makedirs(hp.checkpoint_path, exist_ok=True)

    # Init logger
    # All log files below are written into this same directory; the
    # original created `hp.logger_path` but wrote to a literal "logger".
    logger_dir = hp.logger_path
    os.makedirs(logger_dir, exist_ok=True)

    # Define Some Information
    Time = np.array([])  # rolling window of per-step durations (seconds)
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # documented replacement for measuring elapsed intervals.
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        # Get Training Loader
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        # Each loader item is itself iterated as a sequence of sub-batches;
        # the step formula below assumes hp.batch_size of them per item
        # (presumably produced by collate_fn -- TODO confirm).
        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                max_mel_len = data_of_batch["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=D)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output, mel_postnet_output, duration_predictor_output,
                    mel_target, D)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                # One value per line, appended every step.
                for log_name, log_value in (("total_loss.txt", t_l),
                                            ("mel_loss.txt", m_l),
                                            ("mel_postnet_loss.txt", m_p_l),
                                            ("duration_loss.txt", d_l)):
                    with open(os.path.join(logger_dir, log_name),
                              "a") as f_loss:
                        f_loss.write(str(log_value) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    # No step has been timed yet on the very first step;
                    # report nan without triggering numpy's empty-mean
                    # RuntimeWarning.
                    mean_step_time = np.mean(Time) if Time.size else float(
                        'nan')

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join(logger_dir, "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the full window into its mean so the array
                    # stays bounded while the running estimate is kept.
                    Time = np.array([np.mean(Time)])
예제 #8
0
                    batch_size=hp.batch_size**2,
                    shuffle=True,
                    collate_fn=dataset.collate_fn,
                    drop_last=False,
                    num_workers=8)
# Define model
model = FastSpeech2(py_vocab_size, hz_vocab_size).to(device)
num_param = utils.get_param_num(model)

# Optimizer and loss
optimizer = torch.optim.Adam(model.parameters(),
                             lr=hp.start_lr,
                             betas=hp.betas,
                             eps=hp.eps,
                             weight_decay=0)
scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                 hp.n_warm_up_step, args.restore_step)
Loss = FastSpeech2Loss().to(device)
print("Optimizer and Loss Function Defined.")
# Load checkpoint if exists
checkpoint_path = os.path.join(hp.checkpoint_path)
try:
    checkpoint = torch.load(
        os.path.join(checkpoint_path,
                     'checkpoint_{}.pth.tar'.format(args.restore_step)))

    #temp = nn.DataParallel(model)
    model.load_state_dict(checkpoint['model'])
    #model.load_state_dict(temp.module.state_dict())
    #del temp
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("\n---Model Restored at Step {}---\n".format(args.restore_step))
예제 #9
0
파일: train.py 프로젝트: qqwwaass11/STYLER
def main(args):
    """Train the STYLER model, with logging, synthesis and validation.

    Builds the data loader, model, optimizer and the two loss modules,
    tries to resume from the checkpoint numbered ``args.restore_step``,
    then runs ``hp.epochs`` epochs.  Gradients are accumulated over
    ``hp.acc_steps`` steps before each optimizer update.  On update steps
    the function periodically logs scalars to TensorBoard, saves
    checkpoints, synthesizes audio/plots for the first item of the current
    batch, and runs ``evaluate``.

    Args:
        args: Parsed command-line options.  Only ``args.restore_step`` is
            read here (checkpoint number to resume from; also the offset
            added to the step counter).
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(STYLER()).to(device)
    print("Model Has Been Defined")

    # Short aliases for sub-modules that are addressed repeatedly below.
    style_modeling = model.module.style_modeling
    style_encoder = style_modeling.style_encoder

    # Parameters
    num_param = utils.get_param_num(model)
    text_encoder = utils.get_param_num(style_encoder.text_encoder)
    audio_encoder = utils.get_param_num(style_encoder.audio_encoder)
    predictors = utils.get_param_num(style_modeling.duration_predictor)\
        + utils.get_param_num(style_modeling.pitch_predictor)\
        + utils.get_param_num(style_modeling.energy_predictor)
    decoder = utils.get_param_num(model.module.decoder)
    print('Number of Model Parameters          :', num_param)
    print('Number of Text Encoder Parameters   :', text_encoder)
    print('Number of Audio Encoder Parameters  :', audio_encoder)
    print('Number of Predictor Parameters      :', predictors)
    print('Number of Decoder Parameters        :', decoder)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = STYLERLoss().to(device)
    DATLoss = DomainAdversarialTrainingLoss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path())
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        # Best-effort restore: any load failure falls back to a fresh run.
        # `Exception` (rather than a bare `except`) lets Ctrl-C and
        # SystemExit propagate instead of being mistaken for "no checkpoint".
        print("\n---Start New Training---\n")
        os.makedirs(checkpoint_path, exist_ok=True)

    # Load vocoder
    vocoder = utils.get_vocoder()

    # Init logger
    log_path = hp.log_path()
    train_log_dir = os.path.join(log_path, 'train')
    val_log_dir = os.path.join(log_path, 'validation')
    # exist_ok makes this safe when the log root exists but a
    # sub-directory does not (e.g. a partially cleaned run directory).
    os.makedirs(train_log_dir, exist_ok=True)
    os.makedirs(val_log_dir, exist_ok=True)
    train_logger = SummaryWriter(train_log_dir)
    val_logger = SummaryWriter(val_log_dir)

    # Init synthesis directory
    synth_path = hp.synth_path()
    os.makedirs(synth_path, exist_ok=True)

    # Define Some Information
    Time = np.array([])  # rolling window of per-step durations (seconds)
    Start = time.perf_counter()

    # The loader is built once, so the total step count is loop-invariant.
    total_step = hp.epochs * len(loader) * hp.batch_size

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Each loader item is itself iterated as a sequence of sub-batches;
        # the step formula below assumes hp.batch_size of them per item
        # (presumably produced by dataset.collate_fn -- TODO confirm).
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i*hp.batch_size + j + args.restore_step + \
                    epoch*len(loader)*hp.batch_size + 1

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                mel_aug = torch.from_numpy(
                    data_of_batch["mel_aug"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                f0_norm = torch.from_numpy(
                    data_of_batch["f0_norm"]).float().to(device)
                f0_norm_aug = torch.from_numpy(
                    data_of_batch["f0_norm_aug"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                energy_input = torch.from_numpy(
                    data_of_batch["energy_input"]).float().to(device)
                energy_input_aug = torch.from_numpy(
                    data_of_batch["energy_input_aug"]).float().to(device)
                speaker_embed = torch.from_numpy(
                    data_of_batch["speaker_embed"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_outputs, mel_postnet_outputs, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _, aug_posteriors = model(
                    text,
                    mel_target,
                    mel_aug,
                    f0_norm,
                    energy_input,
                    src_len,
                    mel_len,
                    D,
                    f0,
                    energy,
                    max_src_len,
                    max_mel_len,
                    speaker_embed=speaker_embed)

                # Cal Loss Clean (index 0 of the model outputs; the
                # classifier target for this pass is all zeros)
                mel_output, mel_postnet_output = mel_outputs[
                    0], mel_postnet_outputs[0]
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss, classifier_loss_a = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask, src_len, mel_len, aug_posteriors,
                    torch.zeros(mel_target.size(0)).long().to(device))

                # Cal Loss Noisy (index 1 of the model outputs, compared
                # against the augmented mel)
                mel_output_noisy, mel_postnet_output_noisy = mel_outputs[
                    1], mel_postnet_outputs[1]
                mel_noisy_loss, mel_postnet_noisy_loss = Loss.cal_mel_loss(
                    mel_output_noisy, mel_postnet_output_noisy, mel_aug,
                    ~mel_mask)

                # Forward DAT: run the augmented features through the audio
                # encoder and the three augmentation classifiers.
                enc_cat = style_encoder.encoder_input_cat(
                    mel_aug, f0_norm_aug, energy_input_aug, mel_aug)
                duration_encoding, pitch_encoding, energy_encoding, _ = style_encoder.audio_encoder(
                    enc_cat, mel_len, src_len, mask=None)
                aug_posterior_d = style_modeling.augmentation_classifier_d(
                    duration_encoding)
                aug_posterior_p = style_modeling.augmentation_classifier_p(
                    pitch_encoding)
                aug_posterior_e = style_modeling.augmentation_classifier_e(
                    energy_encoding)

                # Cal Loss DAT (classifier target for this pass is all ones)
                classifier_loss_a_dat = DATLoss(
                    (aug_posterior_d, aug_posterior_p, aug_posterior_e),
                    torch.ones(mel_target.size(0)).long().to(device))

                # Total loss
                total_loss = mel_loss + mel_postnet_loss + mel_noisy_loss + mel_postnet_noisy_loss + d_loss + f_loss + e_loss\
                    + hp.dat_weight*(classifier_loss_a + classifier_loss_a_dat)

                # Logger (values are taken before the accumulation scaling)
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                m_n_l = mel_noisy_loss.item()
                m_p_n_l = mel_postnet_noisy_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                cl_a = classifier_loss_a.item()
                cl_a_dat = classifier_loss_a_dat.item()

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                # Gradient accumulation: only every hp.acc_steps-th step
                # proceeds to the update, logging, saving and timing below.
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step == 1 or current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    # No step has been timed yet when step 1 is logged;
                    # report nan without triggering numpy's empty-mean
                    # RuntimeWarning.
                    mean_step_time = np.mean(Time) if Time.size else float(
                        'nan')

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_noisy_loss', m_n_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_postnet_noisy_loss',
                                            m_p_n_l, current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)
                    train_logger.add_scalar('Loss/dat_clean_loss', cl_a,
                                            current_step)
                    train_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat,
                                            current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step == 1 or current_step % hp.synth_step == 0:
                    # Everything below uses only the first item of the
                    # batch, cut to its own mel length.
                    length = mel_len[0].item()

                    # `*_torch` tensors stay on the device and are fed to
                    # the vocoder; the plain names are CPU copies used for
                    # plotting.  The device slices must be taken before the
                    # source names are rebound to their CPU versions.
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_aug_torch = mel_aug[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_aug = mel_aug[0, :length].detach().cpu().transpose(
                        0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_noisy_torch = mel_output_noisy[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet_noisy_torch = mel_postnet_output_noisy[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_noisy = mel_postnet_output_noisy[
                        0, :length].detach().cpu().transpose(0, 1)

                    # File names: "c" marks the clean pass, "n" the noisy
                    # (augmented) pass.
                    wav_mel = utils.vocoder_infer(
                        mel_torch, vocoder,
                        os.path.join(
                            synth_path,
                            'step_{}_{}_{}.wav'.format(current_step, "c",
                                                       hp.vocoder)))
                    wav_mel_postnet = utils.vocoder_infer(
                        mel_postnet_torch, vocoder,
                        os.path.join(
                            synth_path,
                            'step_{}_{}_postnet_{}.wav'.format(
                                current_step, "c", hp.vocoder)))
                    wav_ground_truth = utils.vocoder_infer(
                        mel_target_torch, vocoder,
                        os.path.join(
                            synth_path,
                            'step_{}_{}_ground-truth_{}.wav'.format(
                                current_step, "c", hp.vocoder)))
                    wav_mel_noisy = utils.vocoder_infer(
                        mel_noisy_torch, vocoder,
                        os.path.join(
                            synth_path,
                            'step_{}_{}_{}.wav'.format(current_step, "n",
                                                       hp.vocoder)))
                    wav_mel_postnet_noisy = utils.vocoder_infer(
                        mel_postnet_noisy_torch, vocoder,
                        os.path.join(
                            synth_path,
                            'step_{}_{}_postnet_{}.wav'.format(
                                current_step, "n", hp.vocoder)))
                    wav_aug = utils.vocoder_infer(
                        mel_aug_torch, vocoder,
                        os.path.join(
                            synth_path,
                            'step_{}_{}_ground-truth_{}.wav'.format(
                                current_step, "n", hp.vocoder)))

                    # Model duration prediction
                    log_duration_output = log_duration_output[
                        0, :src_len[0].item()].detach().cpu()  # [seg_len]
                    log_duration_output = torch.clamp(torch.round(
                        torch.exp(log_duration_output) - hp.log_offset),
                                                      min=0).int()
                    model_duration = utils.get_alignment_2D(
                        log_duration_output).T  # [seg_len, mel_len]
                    model_duration = utils.plot_alignment([model_duration])

                    # Model mel prediction
                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()
                    mel_predicted = utils.plot_data(
                        [(mel_postnet.numpy(), f0_output, energy_output),
                         (mel_target.numpy(), f0, energy)], [
                             'Synthetized Spectrogram Clean',
                             'Ground-Truth Spectrogram'
                         ],
                        filename=os.path.join(
                            synth_path,
                            'step_{}_{}.png'.format(current_step, "c")))
                    mel_noisy_predicted = utils.plot_data(
                        [(mel_postnet_noisy.numpy(), f0_output, energy_output),
                         (mel_aug.numpy(), f0, energy)],
                        ['Synthetized Spectrogram Noisy', 'Aug Spectrogram'],
                        filename=os.path.join(
                            synth_path,
                            'step_{}_{}.png'.format(current_step, "n")))

                    # Normalize audio for tensorboard logger. See https://github.com/lanpa/tensorboardX/issues/511#issuecomment-537600045
                    # NOTE(review): divides by the signed maximum; an
                    # all-zero or all-negative waveform would misbehave
                    # here -- confirm what utils.vocoder_infer returns.
                    wav_ground_truth = wav_ground_truth / max(wav_ground_truth)
                    wav_mel = wav_mel / max(wav_mel)
                    wav_mel_postnet = wav_mel_postnet / max(wav_mel_postnet)
                    wav_aug = wav_aug / max(wav_aug)
                    wav_mel_noisy = wav_mel_noisy / max(wav_mel_noisy)
                    wav_mel_postnet_noisy = wav_mel_postnet_noisy / max(
                        wav_mel_postnet_noisy)

                    train_logger.add_image("model_duration",
                                           model_duration,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_image("mel_predicted/Clean",
                                           mel_predicted,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_image("mel_predicted/Noisy",
                                           mel_noisy_predicted,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_audio("Clean/wav_ground_truth",
                                           wav_ground_truth,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Clean/wav_mel",
                                           wav_mel,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Clean/wav_mel_postnet",
                                           wav_mel_postnet,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_aug",
                                           wav_aug,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_mel_noisy",
                                           wav_mel_noisy,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_mel_postnet_noisy",
                                           wav_mel_postnet_noisy,
                                           current_step,
                                           sample_rate=hp.sampling_rate)

                if current_step == 1 or current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        # The training-loss scalars are overwritten here
                        # with validation values; they are recomputed at
                        # the start of the next training step.
                        d_l, f_l, e_l, cl_a, cl_a_dat, m_l, m_p_l, m_n_l, m_p_n_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l + m_n_l + m_p_n_l\
                            + hp.dat_weight*(cl_a + cl_a_dat)

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_noisy_loss', m_n_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_noisy_loss',
                                              m_p_n_l, current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)
                        val_logger.add_scalar('Loss/dat_clean_loss', cl_a,
                                              current_step)
                        val_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat,
                                              current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the full window into its mean so the array
                    # stays bounded while the running estimate is kept.
                    Time = np.array([np.mean(Time)])
예제 #10
0
def main(args):
    """Train FastSpeech2 with periodic logging, checkpointing, synthesis and validation.

    Args:
        args: Parsed command-line arguments. Only ``args.restore_step`` is
            read: it selects the checkpoint file to resume from, is passed to
            ``ScheduledOptim`` and offsets the global step counter.

    Raises:
        Exception: Any checkpoint error other than a missing file (unreadable
            archive, missing keys, state-dict mismatch) propagates instead of
            silently restarting training from scratch.
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    # Each item produced by the loader is iterated again in the training loop,
    # so collate_fn presumably splits the hp.batch_size**2 samples into
    # sub-batches -- TODO confirm against Dataset.collate_fn.
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    checkpoint_file = os.path.join(
        checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step))
    try:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only host.
        checkpoint = torch.load(checkpoint_file, map_location=device)
    except FileNotFoundError:
        # Only a missing checkpoint means "fresh run". Every other failure is
        # a real error and must not be hidden, because the scheduler and the
        # step counter already assume args.restore_step.
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)
    else:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])  # per-step wall-clock durations, used for the ETA
    Start = time.perf_counter()

    # Training
    model = model.train()
    # Loop-invariant. Note it does not include args.restore_step, so the
    # "Step [x/total]" display and the ETA are offset after a resume.
    total_step = hp.epochs * len(loader) * hp.batch_size
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                # 1-based global step, continuing from the restored step.
                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_len = max(data_of_batch["mel_len"]).astype(np.int16)

                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D, f0, energy)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target, src_len,
                    mel_len)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                # Append every step's losses to one plain-text file per loss.
                for file_name, value in (
                    ("total_loss.txt", t_l),
                    ("mel_loss.txt", m_l),
                    ("mel_postnet_loss.txt", m_p_l),
                    ("duration_loss.txt", d_l),
                    ("f0_loss.txt", f_l),
                    ("energy_loss.txt", e_l),
                ):
                    with open(os.path.join(log_path, file_name),
                              "a") as f_loss_log:
                        f_loss_log.write(str(value) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    for tag, value in (
                        ('Loss/total_loss', t_l),
                        ('Loss/mel_loss', m_l),
                        ('Loss/mel_postnet_loss', m_p_l),
                        ('Loss/duration_loss', d_l),
                        ('Loss/F0_loss', f_l),
                        ('Loss/energy_loss', e_l),
                    ):
                        logger.add_scalars(tag, {'training': value},
                                           current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    # Render the first sample of this sub-batch, trimmed to
                    # its own mel length. The *_torch variants stay on the
                    # device for WaveGlow; the others are moved to CPU for
                    # inv_mel_spec and plotting.
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    Audio.tools.inv_mel_spec(
                        mel,
                        os.path.join(
                            synth_path,
                            "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(
                        mel_postnet,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_griffin_lim.wav".format(
                                current_step)))
                    waveglow.inference.inference(
                        mel_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_waveglow.wav".format(current_step)))
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(
                            synth_path, "step_{}_postnet_waveglow.wav".format(
                                current_step)))
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_ground-truth_waveglow.wav".format(
                                current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output, energy_output),
                        (mel_target.numpy(), f0, energy)
                    ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        for tag, value in (
                            ('Loss/total_loss', t_l),
                            ('Loss/mel_loss', m_l),
                            ('Loss/mel_postnet_loss', m_p_l),
                            ('Loss/duration_loss', d_l),
                            ('Loss/F0_loss', f_l),
                            ('Loss/energy_loss', e_l),
                        ):
                            logger.add_scalars(tag, {'validation': value},
                                               current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the history into its mean so the buffer stays
                    # bounded while the ETA estimate is kept.
                    Time = np.array([np.mean(Time)])
예제 #11
0
def main(args):
    """Train FastSpeech with a per-sample cross-entropy loss, logging and checkpointing as it goes.

    Args:
        args: Parsed command-line arguments. Reads ``args.frozen_learning_rate``
            and ``args.learning_rate_frozen`` to pick the optimizer update, and
            ``args.restore_step`` for the scheduler and the step counter.
            ``args.restore_step`` is overwritten with the step stored in
            ``checkpoint.txt`` when a checkpoint is successfully loaded.

    Raises:
        Exception: Any checkpoint error other than a missing file (malformed
            ``checkpoint.txt``, unreadable archive, state-dict mismatch)
            propagates instead of silently restarting training from scratch.
    """
    # Get device
    # NOTE: the device is hard-coded; there is no CPU fallback here.
    device = 'cuda'
    # Define model
    model = FastSpeech().to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    writer = SummaryWriter(log_dir='log/' + current_time)

    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)

    # Load checkpoint if exists
    try:
        # checkpoint.txt holds the step number of the latest saved checkpoint.
        with open(os.path.join(hp.checkpoint_path, 'checkpoint.txt'),
                  'r') as checkpoint_in:
            restore_step = int(checkpoint_in.readline().strip())
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%08d.pth' % restore_step))
    except FileNotFoundError:
        # Only a missing pointer/checkpoint file means "fresh run"; other
        # failures are real errors and are allowed to surface.
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)
    else:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Commit the restored step only once the weights are really loaded, so
        # the scheduler and step counter never run ahead of the model.
        args.restore_step = restore_step
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    scheduled_optim = ScheduledOptim(optimizer, hp.d_model, hp.n_warm_up_step,
                                     args.restore_step)
    # Built but not used by the loss computation below, which relies on
    # CrossEntropyLoss; kept so construction-time behavior is unchanged.
    fastspeech_loss = FastSpeechLoss().to(device)
    # Stateless criterion, created once instead of on every step.
    criterion = torch.nn.CrossEntropyLoss()
    print("Defined Optimizer and Loss Function.")

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Define Some Information
    Time = np.array([])  # per-step wall-clock durations, used for the ETA
    Start = time.perf_counter()

    # Training
    model = model.train()
    t_l = 0.0  # loss accumulated since the last log step
    for epoch in range(hp.epochs):
        # Get Training Loader
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                # 1-based global step, continuing from the restored step.
                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader)*hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                condition1 = torch.from_numpy(
                    data_of_batch["condition1"]).long().to(device)
                condition2 = torch.from_numpy(
                    data_of_batch["condition2"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).long().to(device)
                norm_f0 = torch.from_numpy(
                    data_of_batch["norm_f0"]).long().to(device)
                mel_in = torch.from_numpy(
                    data_of_batch["mel_in"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                lens = data_of_batch["lens"]
                max_mel_len = data_of_batch["mel_max_len"]

                # Forward
                mel_output = model(src_seq1=condition1,
                                   src_seq2=condition2,
                                   mel_in=mel_in,
                                   src_pos=src_pos,
                                   mel_pos=mel_pos,
                                   mel_max_length=max_mel_len,
                                   length_target=D)

                # Cal Loss
                # The class target is the integer mean of the two label
                # tracks. Assumes mel_output is (batch, time, classes) so that
                # predict is (batch, classes, time) -- TODO confirm.
                predict = mel_output.transpose(1, 2)
                target1 = mel_target.long().squeeze()
                target2 = norm_f0.long().squeeze()
                target = ((target1 + target2) / 2).long().squeeze()

                # One loss per sample, restricted to its first lens[index]
                # positions, then averaged over the batch.
                losses = [
                    criterion(
                        predict[index, :, :lens[index]].transpose(0, 1),
                        target[index, :lens[index]]).unsqueeze(0)
                    for index in range(predict.shape[0])
                ]
                total_loss = torch.cat(losses).mean()
                step_loss = total_loss.item()
                t_l += step_loss

                # Logs live in hp.logger_path, the directory created above.
                with open(os.path.join(hp.logger_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                # Backward
                # Test this step's loss, not the running sum: one NaN would
                # otherwise keep the sum NaN and skip backward() on every
                # step until the next log-step reset.
                if not np.isnan(step_loss):
                    total_loss.backward()
                else:
                    print(condition1, condition2, D)

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights (still called on a NaN step, as before)
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}] Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Loss:{:.4f} ".format(t_l / hp.log_step)

                    str3 = "LR:{:.6f}".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "T: {:.1f}s ETR:{:.1f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print('\r' + str1 + ' ' + str2 + ' ' + str3 + ' ' + str4,
                          end='')
                    writer.add_scalar('loss', t_l / hp.log_step, current_step)
                    writer.add_scalar('lreaning rate',
                                      scheduled_optim.get_learning_rate(),
                                      current_step)

                    if hp.gpu_log_step != -1 and current_step % hp.gpu_log_step == 0:
                        os.system('nvidia-smi')

                    with open(os.path.join(hp.logger_path, "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                    # Start a new averaging window.
                    t_l = 0.0

                if current_step % hp.fig_step == 0 or current_step == 20:
                    # NOTE(review): plt.matshow without fignum draws on a new
                    # figure of its own, so `f` looks empty when it is logged
                    # under 'predict' -- confirm which figure was intended.
                    f = plt.figure()
                    plt.matshow(mel_output[0].cpu().detach().numpy())
                    plt.savefig('out_predicted.png')
                    plt.matshow(
                        F.softmax(predict, dim=1).transpose(
                            1, 2)[0].cpu().detach().numpy())
                    plt.savefig('out_predicted_softmax.png')
                    writer.add_figure('predict', f, current_step)
                    plt.cla()

                    f = plt.figure(figsize=(8, 6))
                    # Draw one class index per position from the softmax
                    # distribution of the first sample.
                    # 200 is presumably the number of output classes -- TODO
                    # confirm against the model definition.
                    sample = []
                    p = F.softmax(predict, dim=1).transpose(
                        1, 2)[0].detach().cpu().numpy()
                    for index in range(p.shape[0]):
                        sample.append(np.random.choice(200, 1, p=p[index]))
                    sample = np.array(sample)
                    plt.plot(np.arange(sample.shape[0]),
                             sample,
                             color='grey',
                             linewidth='1')
                    # One horizontal blue segment per entry of D[0] with a
                    # non-zero condition2 value: it starts at the cumulative
                    # sum of the preceding D entries and spans D[0][index]
                    # positions.
                    for index in range(D.shape[1]):
                        x = np.arange(D[0][index].cpu().numpy()
                                      ) + D[0][:index].cpu().numpy().sum()
                        y = np.arange(D[0][index].detach().cpu().numpy())
                        if condition2[0][index].cpu().numpy() != 0:
                            y.fill(
                                (condition2[0][index].cpu().numpy() - 40.0) *
                                5)
                            plt.plot(x, y, color='blue')
                    plt.plot(np.arange(target.shape[1]),
                             target[0].squeeze().detach().cpu().numpy(),
                             color='red',
                             linewidth='1')
                    plt.savefig('out_target.png', dpi=300)
                    writer.add_figure('target', f, current_step)
                    plt.cla()

                    plt.close("all")

                if current_step % (hp.save_step) == 0:
                    print("save model at step %d ..." % current_step, end='')
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%08d.pth' % current_step))
                    # Record the latest step so the next run can find it.
                    with open(
                            os.path.join(hp.checkpoint_path, 'checkpoint.txt'),
                            'w') as checkpoint_out:
                        checkpoint_out.write(str(current_step))

                    print('save completed')

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the history into its mean so the buffer stays
                    # bounded while the ETA estimate is kept.
                    Time = np.array([np.mean(Time)])