Esempio n. 1
0
def synthesize(model, text, sentence, prefix=''):
    """Synthesize one utterance with FastSpeech2 and write audio/plot artifacts.

    Files written into ``hp.test_path`` (created if missing):
      * ``{prefix}_griffin_lim_{sentence}.wav`` - Griffin-Lim inversion of the
        post-net mel output (``Audio.tools.inv_mel_spec``);
      * ``{prefix}_{hp.vocoder}_{sentence}.wav`` - vocoder output, only when
        ``hp.vocoder`` is ``'melgan'`` or ``'waveglow'``;
      * ``{prefix}_{sentence}.png`` - mel / F0 / energy plot.

    Args:
        model: FastSpeech2 module called as ``model(text, src_pos)``. It is
            moved to ``device`` for the forward pass and back to CPU afterwards
            (also when the forward pass raises).
        text: tensor of token ids whose ``shape[1]`` is the sequence length.
            Only the first batch item's outputs are written
            (presumably callers pass batch size 1 - TODO confirm).
        sentence: text embedded in the output file names.
        prefix: optional file-name prefix.
    """
    seq_len = text.shape[1]
    # 1-based positions 1..seq_len with a leading batch axis of size 1.
    src_pos = torch.arange(1, seq_len + 1, device=device).unsqueeze(0)

    model.to(device)
    try:
        # Pure inference: no autograd graph is needed for the outputs.
        with torch.no_grad():
            mel, mel_postnet, _, f0_output, energy_output = model(
                text, src_pos)
    finally:
        # Leave the model on CPU whether or not the forward pass succeeded.
        model.to('cpu')

    # Full-batch tensor with the last two axes swapped, kept on `device`,
    # for the neural vocoder.
    mel_postnet_torch = mel_postnet.transpose(1, 2)
    # First-item copies on CPU for Griffin-Lim and plotting.
    mel_postnet = mel_postnet[0].cpu().transpose(0, 1)
    f0_output = f0_output[0].cpu().numpy()
    energy_output = energy_output[0].cpu().numpy()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(hp.test_path, exist_ok=True)

    Audio.tools.inv_mel_spec(
        mel_postnet,
        os.path.join(hp.test_path,
                     '{}_griffin_lim_{}.wav'.format(prefix, sentence)))

    vocoder_wav_path = os.path.join(
        hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence))
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
        utils.melgan_infer(mel_postnet_torch, melgan, vocoder_wav_path)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)
        utils.waveglow_infer(mel_postnet_torch, waveglow, vocoder_wav_path)

    utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output)],
                    ['Synthesized Spectrogram'],
                    filename=os.path.join(hp.test_path,
                                          '{}_{}.png'.format(prefix,
                                                             sentence)))
Esempio n. 2
0
        f_log.write("\n")

    return d_l, f_l, e_l, mel_l, mel_p_l

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    # Step number forwarded to get_FastSpeech2() and evaluate().
    parser.add_argument('--step', type=int, default=30000)
    args = parser.parse_args()

    # Get model
    model = get_FastSpeech2(args.step).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Load vocoder
    if hp.vocoder == 'melgan':
        vocoder = utils.get_melgan()
    elif hp.vocoder == 'waveglow':
        vocoder = utils.get_waveglow()
    else:
        # Any other value would leave `vocoder` unbound and fail with an
        # opaque NameError below; fail early with an actionable message.
        raise ValueError(
            "Unsupported hp.vocoder {!r}; expected 'melgan' or 'waveglow'"
            .format(hp.vocoder))
    vocoder.to(device)

    # Init directories (exist_ok avoids the exists()/makedirs() race)
    os.makedirs(hp.log_path, exist_ok=True)
    os.makedirs(hp.eval_path, exist_ok=True)

    evaluate(model, args.step, vocoder)
Esempio n. 3
0
def _append_scalar_logs(log_path, values):
    """Append each value as one text line to its file under ``log_path``.

    ``values`` is an iterable of ``(file_name, value)`` pairs; every value is
    written as ``str(value)`` followed by a newline, in append mode.
    """
    for file_name, value in values:
        with open(os.path.join(log_path, file_name), "a") as f:
            f.write(str(value) + "\n")


def main(args):
    """Train FastSpeech2 on ``train.txt`` with periodic logging, checkpointing,
    synthesis and evaluation.

    Args:
        args: parsed CLI namespace. Only ``args.restore_step`` is read: the
            checkpoint step to resume from (training starts fresh if that
            checkpoint cannot be loaded), which also offsets all step numbers
            and is passed to ``ScheduledOptim``.
    """
    torch.manual_seed(0)
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset. Each loader iteration holds hp.batch_size**2 samples and
    # dataset.collate_fn returns a sequence of sub-batches iterated below
    # (presumably hp.batch_size sub-batches - verify against Dataset).
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    if hp.use_spk_embed:
        n_spkers = len(dataset.spk_table.keys())
        model = nn.DataParallel(FastSpeech2(n_spkers=n_spkers)).to(device)
    else:
        model = nn.DataParallel(FastSpeech2()).to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = hp.checkpoint_path
    try:
        # map_location allows restoring a GPU-saved checkpoint on a CPU host
        # instead of silently falling back to a fresh run.
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)),
            map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception as err:
        # Best effort: any restore failure starts a new run, but a failure
        # other than "no such checkpoint" is reported instead of hidden.
        # (Unlike a bare except, this no longer swallows KeyboardInterrupt.)
        if not isinstance(err, FileNotFoundError):
            print("Could not restore checkpoint ({}: {})".format(
                type(err).__name__, err))
        print("\n---Start New Training---\n")
        os.makedirs(checkpoint_path, exist_ok=True)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        # NOTE(review): moving MelGAN to `device` was disabled by the author;
        # confirm the MelGAN inference helper copes with CPU weights + `device` mels.
        #melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    os.makedirs(os.path.join(log_path, 'train'), exist_ok=True)
    os.makedirs(os.path.join(log_path, 'validation'), exist_ok=True)
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis and evaluation directories
    synth_path = hp.synth_path
    os.makedirs(synth_path, exist_ok=True)
    os.makedirs(hp.eval_path, exist_ok=True)

    # Wall-clock durations of loop iterations that reached the end of the
    # body (i.e. optimizer-update steps); used for the ETA estimate.
    Time = np.array([])
    Start = time.perf_counter()

    # Loop-invariant: each loader iteration contributes hp.batch_size steps.
    total_step = hp.epochs * len(loader) * hp.batch_size

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                # Global 1-based step number, offset by the restored step.
                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1
                print("step : {}".format(current_step), end='\r', flush=True)

                ### Get Data ###
                if hp.use_spk_embed:
                    spk_ids = torch.tensor(data_of_batch["spk_ids"]).to(
                        torch.int64).to(device)
                else:
                    spk_ids = None
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                ### Forward ###
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len, spk_ids)

                ### Cal Loss ###
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                # F0 and energy terms are down-weighted in the training objective.
                total_loss = mel_loss + mel_postnet_loss + d_loss + 0.01 * f_loss + 0.1 * e_loss

                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                _append_scalar_logs(log_path, (
                    ("total_loss.txt", t_l),
                    ("mel_loss.txt", m_l),
                    ("mel_postnet_loss.txt", m_p_l),
                    ("duration_loss.txt", d_l),
                    ("f0_loss.txt", f_l),
                    ("energy_loss.txt", e_l),
                ))

                ## Backward (gradient accumulation over hp.acc_steps steps)
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                ### Update weights ###
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                ### Print ###
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    # np.mean of an empty array is nan (with a warning).
                    mean_step_time = np.mean(Time) if Time.size else 0.0

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                ### Save model ###
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                ### Synth ###
                if current_step % hp.synth_step == 0:
                    # Synthesize the first utterance of the sub-batch, trimmed
                    # to its true mel length.
                    length = mel_len[0].item()
                    print("step: {} , length {}, {}".format(
                        current_step, length, mel_len))
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)

                    # spk_ids is None when speaker embeddings are disabled;
                    # indexing it would raise TypeError.
                    if spk_ids is not None:
                        spk_id = dataset.inv_spk_table[int(spk_ids[0])]
                    else:
                        spk_id = 'single'

                    # NOTE(review): `vocoder` is not defined in this function;
                    # presumably a module-level `vocoder` import - confirm.
                    if hp.vocoder == 'melgan':
                        vocoder.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))

                    elif hp.vocoder == 'waveglow':
                        vocoder.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output, energy_output),
                        (mel_target.numpy(), f0, energy)
                    ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                ### Evaluation ###
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        # NOTE(review): unweighted sum, unlike the training
                        # total (0.01 * F0, 0.1 * energy) - confirm intent.
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)

                    model.train()

                ### Time ###
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the history into its mean to keep it bounded.
                    Time = np.array([np.mean(Time)])
Esempio n. 4
0
def main(args):
    torch.manual_seed(0)

    # Get device
    # device = torch.device('cuda'if torch.cuda.is_available()else 'cpu')
    device = 'cuda'

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoaderX(dataset,
                         batch_size=hp.batch_size * 4,
                         shuffle=True,
                         collate_fn=dataset.collate_fn,
                         drop_last=True,
                         num_workers=8)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    #     model = FastSpeech2().to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    train_logger = SummaryWriter(log_dir='log/train/' + current_time)
    val_logger = SummaryWriter(log_dir='log/validation/' + current_time)
    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()
    current_step0 = 0
    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * len(batchs) + j + args.restore_step + \
                    epoch * len(loader)*len(batchs) + 1
                # Get Data
                condition = torch.from_numpy(
                    data_of_batch["condition"]).long().to(device)
                mel_refer = torch.from_numpy(
                    data_of_batch["mel_refer"]).float().to(device)
                if hp.vocoder == 'WORLD':
                    ap_target = torch.from_numpy(
                        data_of_batch["ap_target"]).float().to(device)
                    sp_target = torch.from_numpy(
                        data_of_batch["sp_target"]).float().to(device)
                else:
                    mel_target = torch.from_numpy(
                        data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                #print(D,log_D)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                if hp.vocoder == 'WORLD':
                    #                     print(condition.shape,mel_refer.shape, src_len.shape, mel_len.shape, D.shape, f0.shape, energy.shape, max_src_len.shape, max_mel_len.shape)
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model(
                        condition, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        ap_output=ap_output,
                        sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output,
                        ap_target=ap_target,
                        sp_target=sp_target,
                        src_mask=src_mask,
                        ap_mask=ap_mask,
                        sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        log_D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target,
                        src_mask=~src_mask,
                        mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l = ap_loss.item()
                    sp_l = sp_loss.item()
                    sp_p_l = sp_postnet_loss.item()
                else:
                    m_l = mel_loss.item()
                    m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                #                 with open(os.path.join(log_path, "total_loss.txt"), "a") as f_total_loss:
                #                     f_total_loss.write(str(t_l)+"\n")
                #                 with open(os.path.join(log_path, "mel_loss.txt"), "a") as f_mel_loss:
                #                     f_mel_loss.write(str(m_l)+"\n")
                #                 with open(os.path.join(log_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                #                     f_mel_postnet_loss.write(str(m_p_l)+"\n")
                #                 with open(os.path.join(log_path, "duration_loss.txt"), "a") as f_d_loss:
                #                     f_d_loss.write(str(d_l)+"\n")
                #                 with open(os.path.join(log_path, "f0_loss.txt"), "a") as f_f_loss:
                #                     f_f_loss.write(str(f_l)+"\n")
                #                 with open(os.path.join(log_path, "energy_loss.txt"), "a") as f_e_loss:
                #                     f_e_loss.write(str(e_l)+"\n")

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}],Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    if hp.vocoder == 'WORLD':
                        str2 = "Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l)
                    else:
                        str2 = "Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "T:{:.1f}s,ETA:{:.1f}s.".format(
                        (Now - Start) / (current_step - current_step0),
                        (total_step - current_step) * np.mean(Time))

                    print("" + str1 + str2 + str3 + '')

                    #                     with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                    #                         f_log.write(str1 + "\n")
                    #                         f_log.write(str2 + "\n")
                    #                         f_log.write(str3 + "\n")
                    #                         f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    if hp.vocoder == 'WORLD':
                        train_logger.add_scalar('Loss/ap_loss', ap_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_loss', sp_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_postnet_loss', sp_p_l,
                                                current_step)
                    else:
                        train_logger.add_scalar('Loss/mel_loss', m_l,
                                                current_step)
                        train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                                current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                if current_step % hp.save_step == 0 or current_step == 20:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0 or current_step == 5:
                    length = mel_len[0].item()

                    if hp.vocoder == 'WORLD':
                        ap_target_torch = ap_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        ap_torch = ap_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_target_torch = sp_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_torch = sp_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_postnet_torch = sp_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                    else:
                        mel_target_torch = mel_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_target = mel_target[
                            0, :length].detach().cpu().transpose(0, 1)
                        mel_torch = mel_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        mel = mel_output[0, :length].detach().cpu().transpose(
                            0, 1)
                        mel_postnet_torch = mel_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_postnet = mel_postnet_output[
                            0, :length].detach().cpu().transpose(0, 1)
#                     Audio.tools.inv_mel_spec(mel, os.path.join(synth_path, "step_{}_griffin_lim.wav".format(current_step)))
#                     Audio.tools.inv_mel_spec(mel_postnet, os.path.join(synth_path, "step_{}_postnet_griffin_lim.wav".format(current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'WORLD':
                        #     ap=np.swapaxes(ap,0,1)
                        #     sp=np.swapaxes(sp,0,1)
                        wav = utils.world_infer(
                            np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_postnet_torch[0].cpu().numpy(), 0,
                                        1), f0_output)
                        sf.write(
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)
                        wav = utils.world_infer(
                            np.swapaxes(ap_target_torch[0].cpu().numpy(), 0,
                                        1),
                            np.swapaxes(sp_target_torch[0].cpu().numpy(), 0,
                                        1), f0)
                        sf.write(
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)

                    utils.plot_data([
                        (sp_postnet_torch[0].cpu().numpy(), f0_output,
                         energy_output),
                        (sp_target_torch[0].cpu().numpy(), f0, energy)
                    ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                    plt.matshow(sp_postnet_torch[0].cpu().numpy())
                    plt.savefig(
                        os.path.join(synth_path,
                                     'sp_postnet_{}.png'.format(current_step)))
                    plt.matshow(ap_torch[0].cpu().numpy())
                    plt.savefig(
                        os.path.join(synth_path,
                                     'ap_{}.png'.format(current_step)))
                    plt.matshow(
                        variance_adaptor_output[0].detach().cpu().numpy())
                    #                     plt.savefig(os.path.join(synth_path, 'va_{}.png'.format(current_step)))
                    #                     plt.matshow(decoder_output[0].detach().cpu().numpy())
                    #                     plt.savefig(os.path.join(synth_path, 'encoder_{}.png'.format(current_step)))

                    plt.cla()
                    fout = open(
                        os.path.join(synth_path,
                                     'D_{}.txt'.format(current_step)), 'w')
                    fout.write(
                        str(log_duration_output[0].detach().cpu().numpy()) +
                        '\n')
                    fout.write(str(D[0].detach().cpu().numpy()) + '\n')
                    fout.write(
                        str(condition[0, :, 2].detach().cpu().numpy()) + '\n')
                    fout.close()


#                 if current_step % hp.eval_step == 0 or current_step==20:
#                     model.eval()
#                     with torch.no_grad():

#                         if hp.vocoder=='WORLD':
#                             d_l, f_l, e_l, ap_l, sp_l, sp_p_l = evaluate(model, current_step)
#                             t_l = d_l + f_l + e_l + ap_l + sp_l + sp_p_l

#                             val_logger.add_scalar('valLoss/total_loss', t_l, current_step)
#                             val_logger.add_scalar('valLoss/ap_loss', ap_l, current_step)
#                             val_logger.add_scalar('valLoss/sp_loss', sp_l, current_step)
#                             val_logger.add_scalar('valLoss/sp_postnet_loss', sp_p_l, current_step)
#                             val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
#                             val_logger.add_scalar('valLoss/F0_loss', f_l, current_step)
#                             val_logger.add_scalar('valLoss/energy_loss', e_l, current_step)
#                         else:
#                             d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step)
#                             t_l = d_l + f_l + e_l + m_l + m_p_l

#                             val_logger.add_scalar('valLoss/total_loss', t_l, current_step)
#                             val_logger.add_scalar('valLoss/mel_loss', m_l, current_step)
#                             val_logger.add_scalar('valLoss/mel_postnet_loss', m_p_l, current_step)
#                             val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
#                             val_logger.add_scalar('valLoss/F0_loss', f_l, current_step)
#                             val_logger.add_scalar('valLoss/energy_loss', e_l, current_step)

#                     model.train()
#                 if current_step%10==0:
#                     print(energy_output[0],energy[0])
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
Esempio n. 5
0
    parser.add_argument('--energy_control', type=float, default=1.0)
    args = parser.parse_args()

    sentences = [
        "Advanced text to speech models such as Fast Speech can synthesize speech significantly faster than previous auto regressive models with comparable quality. The training of Fast Speech model relies on an auto regressive teacher model for duration prediction and knowledge distillation, which can ease the one to many mapping problem in T T S. However, Fast Speech has several disadvantages, 1, the teacher student distillation pipeline is complicated, 2, the duration extracted from the teacher model is not accurate enough, and the target mel spectrograms distilled from teacher model suffer from information loss due to data simplification, both of which limit the voice quality.",
        "Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition",
        "in being comparatively modern.",
        "For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process",
        "produced the block books, which were the immediate predecessors of the true printed book,",
        "the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.",
        "And it is worth mention in passing that, as an example of fine typography,",
        "the earliest book printed with movable types, the Gutenberg, or \"forty-two line Bible\" of about 1455,",
        "has never been surpassed.",
        "Printing, then, for our purpose, may be considered as the art of making books by means of movable types.",
        "Now, as all books not primarily intended as picture-books consist principally of types composed to form letterpress,"
    ]

    model = get_FastSpeech2(args.step).to(device)
    melgan = waveglow = None
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()

    with torch.no_grad():
        for sentence in sentences:
            text = preprocess(sentence)
            synthesize(model, waveglow, melgan, text, sentence,
                       'step_{}'.format(args.step), args.duration_control,
                       args.pitch_control, args.energy_control)
Esempio n. 6
0
def main(args, device):
    """Train the FastSpeech model on ``train.txt``.

    Derives the checkpoint/synth/eval/log/results directories from
    ``hp.root_path``, ``args.name_task`` and ``hp.dataset`` and stores them on
    ``hp``. Per-step losses are appended to text files under ``hp.log_path``;
    every ``hp.log_step`` steps a progress summary is printed and appended to
    ``logger.txt``, and every ``hp.save_step`` steps a checkpoint is written to
    ``hp.checkpoint_path``.

    Args:
        args: Parsed command-line arguments. Reads ``name_task``,
            ``restore_path``, ``restore_step``, ``frozen_learning_rate`` and
            ``learning_rate_frozen``.
        device: Device the model, loss module and batch tensors are moved to.
    """
    hp.checkpoint_path = os.path.join(hp.root_path, args.name_task, "ckpt",
                                      hp.dataset)
    hp.synth_path = os.path.join(hp.root_path, args.name_task, "synth",
                                 hp.dataset)
    hp.eval_path = os.path.join(hp.root_path, args.name_task, "eval",
                                hp.dataset)
    hp.log_path = os.path.join(hp.root_path, args.name_task, "log", hp.dataset)
    hp.test_path = os.path.join(hp.root_path, args.name_task, 'results')

    # Define model
    print("Use FastSpeech")
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of TTS Parameters:', num_param)

    # Get buffer
    print("Load data to buffer")
    buffer = get_data_to_buffer('train.txt')

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer, hp1.decoder_dim,
                                     hp1.n_warm_up_step, args.restore_step)
    fastspeech_loss = DNNLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists. Checkpoints are READ from args.restore_path
    # but WRITTEN to hp.checkpoint_path, so the latter must exist whether or
    # not the restore succeeds (previously it was only created on failure).
    checkpoint_path = os.path.join(args.restore_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        # No usable checkpoint: start from scratch. Unlike a bare `except:`
        # this lets KeyboardInterrupt / SystemExit propagate.
        print("\n---Start New Training---\n")
    os.makedirs(hp.checkpoint_path, exist_ok=True)

    # Init logger
    os.makedirs(hp.log_path, exist_ok=True)

    # NOTE: `waveglow` is only referenced by the disabled evaluation code
    # further down in this function.
    waveglow = None
    if hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()

    # Get dataset
    dataset = BufferDataset(buffer)

    # Get Training Loader
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size *
                                 hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=0)
    steps_per_epoch = len(training_loader) * hp.batch_expand_size
    total_step = hp.epochs * steps_per_epoch

    # Per-file loss logs appended on every step.
    loss_log_files = ("total_loss.txt", "mel_loss.txt",
                      "mel_postnet_loss.txt", "duration_loss.txt")

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):
            # real batch start here: each loader item is iterated as a
            # sequence of sub-batches (presumably hp.batch_expand_size of
            # them, built by collate_fn_tensor -- confirm).
            for j, db in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = (i * hp.batch_expand_size + j +
                                args.restore_step + epoch * steps_per_epoch +
                                1)

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = db["text"].long().to(device)
                mel_target = db["mel_target"].float().to(device)
                duration = db["duration"].int().to(device)
                mel_pos = db["mel_pos"].long().to(device)
                src_pos = db["src_pos"].long().to(device)
                max_mel_len = db["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=duration)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output, mel_postnet_output, duration_predictor_output,
                    mel_target, duration)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                for file_name, value in zip(loss_log_files,
                                            (t_l, m_l, m_p_l, d_l)):
                    with open(os.path.join(hp.log_path, file_name),
                              "a") as f_loss:
                        f_loss.write(str(value) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp1.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    # Time is still empty if this is the first step of the
                    # run; avoid np.mean's nan + "Mean of empty slice" warning.
                    mean_step_time = np.mean(Time) if Time.size else 0.0

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join(hp.log_path, "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                # if current_step % 10 == 0:
                #     model.eval()
                #     with torch.no_grad():
                #         t_l, d_l, mel_l, mel_p_l = evaluate(
                #             model, current_step, vocoder=waveglow)

                #         str0 = 'Validating'
                #         str1 = "\tEpoch [{}/{}], Step [{}/{}]:".format(
                #             epoch + 1, hp.epochs, current_step, total_step)
                #         str2 = "\tTotal Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                #             t_l, m_l, m_p_l, d_l)
                #         print(str0)
                #         print(str1)
                #         print(str2)
                #     model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                # Collapse the timing history into a single entry holding its
                # mean once it reaches hp.clear_Time samples.
                if len(Time) == hp.clear_Time:
                    Time = np.array([np.mean(Time)])
def main(args):
    """Train FastSpeech2 on ``train.txt`` with gradient accumulation.

    Gradients are accumulated over ``hp.acc_steps`` sub-batches before each
    optimizer update. On update steps, at the configured intervals, the
    function logs losses (stdout, ``log.txt`` and TensorBoard), saves a
    checkpoint, synthesizes audio for the first utterance of the sub-batch,
    and runs validation via ``evaluate``.

    Args:
        args: Parsed command-line arguments; ``restore_step`` selects the
            checkpoint to resume from and offsets the step counter.
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = FastSpeech2().to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        # No usable checkpoint: start from scratch. Unlike a bare `except:`
        # this lets KeyboardInterrupt / SystemExit propagate.
        print("\n---Start New Training---\n")
    os.makedirs(checkpoint_path, exist_ok=True)

    # Load vocoder; only the one selected by hp.vocoder is used below.
    melgan = waveglow = None
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    os.makedirs(os.path.join(log_path, 'train'), exist_ok=True)
    os.makedirs(os.path.join(log_path, 'validation'), exist_ok=True)
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path
    os.makedirs(synth_path, exist_ok=True)

    # Define Some Information (loop invariants hoisted out of the epoch loop)
    steps_per_epoch = len(loader) * hp.batch_size
    total_step = hp.epochs * steps_per_epoch
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = (i * hp.batch_size + j + args.restore_step +
                                epoch * steps_per_epoch + 1)

                # Get Data (only the fields the model and loss consume)
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                (mel_output, mel_postnet_output, duration_output, src_mask,
                 pred_mel_mask, enc_attns, dec_attns, _) = model(
                     text, src_len, mel_len, max_src_len, max_mel_len)
                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss = Loss(
                    duration_output, mel_len, mel_output, mel_postnet_output,
                    mel_target, src_mask, pred_mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()

                # Backward. Gradients accumulate over hp.acc_steps steps; on
                # non-update steps everything below (logging, saving,
                # synthesis, validation, timing) is skipped.
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    # Time is empty until the first update step finishes;
                    # avoid np.mean's nan + "Mean of empty slice" warning.
                    mean_step_time = np.mean(Time) if Time.size else 0.0

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)

                    plot_attn(train_logger, enc_attns, dec_attns, current_step,
                              hp)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    # Synthesize the first utterance of the sub-batch, trimmed
                    # to its length mel_len[0].
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    Audio.tools.inv_mel_spec(
                        mel,
                        os.path.join(
                            synth_path,
                            "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(
                        mel_postnet,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_griffin_lim.wav".format(
                                current_step)))

                    # (input mel, output file pattern), in vocoding order.
                    vocoder_jobs = (
                        (mel_torch, 'step_{}_{}.wav'),
                        (mel_postnet_torch, 'step_{}_postnet_{}.wav'),
                        (mel_target_torch, 'step_{}_ground-truth_{}.wav'),
                    )
                    for mel_in, pattern in vocoder_jobs:
                        wav_path = os.path.join(
                            synth_path,
                            pattern.format(current_step, hp.vocoder))
                        if hp.vocoder == 'melgan':
                            utils.melgan_infer(mel_in, melgan, wav_path)
                        elif hp.vocoder == 'waveglow':
                            utils.waveglow_infer(mel_in, waveglow, wav_path)

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, m_l, m_p_l = evaluate(model, current_step)
                        t_l = d_l + m_l + m_p_l

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                # Collapse the timing history into a single entry holding its
                # mean once it reaches hp.clear_Time samples.
                if len(Time) == hp.clear_Time:
                    Time = np.array([np.mean(Time)])
Esempio n. 8
0
def main(args):
    """Synthesize audio for every utterance listed in ``test.txt``.

    For each utterance this writes Griffin-Lim reconstructions of the
    ground-truth mel, the FastSpeech2 mel and the post-net mel, and the same
    three mels rendered with the neural vocoder selected by ``hp.vocoder``
    (``'melgan'`` or ``'waveglow'``). Every generated wav is recorded as a
    row of ``<hp.test_path>/index.tsv`` via ``_write_index_line``.

    Args:
        args: parsed command-line arguments. Reads ``args.step`` (checkpoint
            step, also embedded in output file names), ``args.model_fs``
            (FastSpeech2 checkpoint path passed to ``get_FastSpeech2``) and
            ``args.model_melgan`` (passed to ``utils.get_melgan`` only when
            ``hp.vocoder == 'melgan'``).
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset. Each loader step yields an iterable of batch dicts (see the
    # inner loop below); the squared batch size presumably lets collate_fn
    # regroup items into hp.batch_size sub-batches -- confirm in Dataset.
    dataset = Dataset("test.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=False,
                        collate_fn=dataset.collate_fn,
                        drop_last=False,
                        num_workers=0)

    model = get_FastSpeech2(args.step, full_path=args.model_fs).to(device)
    # Inference only: put modules such as dropout into evaluation mode
    # (harmless if get_FastSpeech2 already did this).
    model.eval()

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan(full_path=args.model_melgan)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()

    # Init logger. exist_ok also covers the case where log_path exists but
    # its 'test' sub-directory does not.
    log_path = hp.log_path
    os.makedirs(os.path.join(log_path, 'test'), exist_ok=True)
    test_logger = SummaryWriter(os.path.join(log_path, 'test'))

    # Init synthesis directory
    test_path = hp.test_path
    os.makedirs(test_path, exist_ok=True)

    current_step = args.step
    # NOTE(review): make sure this handle is closed once synthesis finishes so
    # that every index.tsv row is flushed to disk.
    findex = open(os.path.join(test_path, "index.tsv"), "w")

    # Testing
    print("Generate test audio")
    prefix = ""
    for batchs in loader:
        for j, data_of_batch in enumerate(batchs):
            print("Start batch", j)
            fids = data_of_batch["id"]

            # Get Data
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).long().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            # The outputs are only vocoded and saved, never back-propagated,
            # so skip building the autograd graph for the forward pass.
            with torch.no_grad():
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                    text, src_len)

            for i in range(len(mel_len)):
                fid = fids[i]
                print("Generate audio for:", fid)
                # NOTE(review): this is the ground-truth length, yet it is also
                # used to slice the *predicted* mels below; predicted durations
                # can yield a different length. Consider deriving the predicted
                # length (e.g. from mel_mask) -- confirm mask semantics first.
                length = mel_len[i].item()

                # *_torch tensors: batch of one with the frame axis moved last,
                # kept on `device` for the neural vocoder. The plain versions
                # are CPU tensors with the frame axis last, for Griffin-Lim.
                _mel_target_torch = mel_target[i, :length].detach().unsqueeze(
                    0).transpose(1, 2)

                _mel_target = mel_target[i, :length].detach().cpu().transpose(
                    0, 1)

                _mel_torch = mel_output[i, :length].detach().unsqueeze(
                    0).transpose(1, 2)

                _mel = mel_output[i, :length].detach().cpu().transpose(0, 1)

                _mel_postnet_torch = mel_postnet_output[
                    i, :length].detach().unsqueeze(0).transpose(1, 2)

                _mel_postnet = mel_postnet_output[
                    i, :length].detach().cpu().transpose(0, 1)

                fname = "{}{}_step_{}_gt_griffin_lim.wav".format(
                    prefix, fid, current_step)
                Audio.tools.inv_mel_spec(_mel_target,
                                         os.path.join(hp.test_path, fname))
                _write_index_line(findex, "Griffin Lim", "vocoder", fname, "",
                                  fid)

                fname = "{}{}_step_{}_griffin_lim.wav".format(
                    prefix, fid, current_step)
                Audio.tools.inv_mel_spec(_mel,
                                         os.path.join(hp.test_path, fname))
                _write_index_line(findex, "FastSpeech2 + GL", "tts", fname, "",
                                  fid)

                fname = "{}{}_step_{}_postnet_griffin_lim.wav".format(
                    prefix, fid, current_step)
                Audio.tools.inv_mel_spec(_mel_postnet,
                                         os.path.join(hp.test_path, fname))
                _write_index_line(findex, "FastSpeech2 + PN + GL", "tts",
                                  fname, "", fid)

                if hp.vocoder == 'melgan':
                    fname = '{}{}_step_{}_ground-truth_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.melgan_infer(_mel_target_torch, melgan,
                                       os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "Melgan", "vocoder", fname, "",
                                      fid)

                    fname = '{}{}_step_{}_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.melgan_infer(_mel_torch, melgan,
                                       os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "FastSpeech2 + Melgan", "tts",
                                      fname, "", fid)

                    fname = '{}{}_step_{}_postnet_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.melgan_infer(_mel_postnet_torch, melgan,
                                       os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "FastSpeech2 + PN + Melgan",
                                      "tts", fname, "", fid)

                elif hp.vocoder == 'waveglow':
                    # Per-utterance names (same scheme as the MelGAN branch):
                    # the previous step-only names made every utterance
                    # overwrite the previous one's files, none of which were
                    # recorded in index.tsv.
                    fname = '{}{}_step_{}_ground-truth_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.waveglow_infer(_mel_target_torch, waveglow,
                                         os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "Waveglow", "vocoder", fname, "",
                                      fid)

                    fname = '{}{}_step_{}_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.waveglow_infer(_mel_torch, waveglow,
                                         os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "FastSpeech2 + Waveglow", "tts",
                                      fname, "", fid)

                    fname = '{}{}_step_{}_postnet_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.waveglow_infer(_mel_postnet_torch, waveglow,
                                         os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "FastSpeech2 + PN + Waveglow",
                                      "tts", fname, "", fid)