Пример #1
0
def synthesize(model, text, sentence, prefix=''):
    """Run ``model`` on ``text`` and write audio plus a plot to ``hp.test_path``.

    Files written (all named from ``prefix`` and ``sentence``):

    * ``{prefix}_griffin_lim_{sentence}.wav`` via ``Audio.tools.inv_mel_spec``
    * ``{prefix}_{hp.vocoder}_{sentence}.wav`` when ``hp.vocoder`` is
      ``'melgan'`` or ``'waveglow'`` (no vocoder file for other values)
    * ``{prefix}_{sentence}.png`` via ``utils.plot_data``

    Args:
        model: Callable as ``model(text, src_pos)`` returning five outputs
            (mel, post-net mel, duration, f0, energy). It is moved to
            ``device`` for the call and back to the CPU afterwards.
        text: Input tensor; only ``text.shape[1]`` is inspected here.
            Assumed to be shaped (1, seq_len) and already on ``device``
            -- TODO confirm against the caller.
        sentence: Identifier embedded in the output file names.
        prefix: Optional leading part of the output file names.
    """
    # 1-based position indices with a leading batch axis of size 1.
    src_pos = np.arange(1, text.shape[1] + 1)[np.newaxis, :]
    src_pos = torch.from_numpy(src_pos).to(device).long()

    model.to(device)
    # Pure inference: no autograd graph is needed, so do not build one.
    with torch.no_grad():
        _, mel_postnet, _, f0_output, energy_output = model(text, src_pos)
    model.to('cpu')

    # Batched, channel-first copy for the neural vocoder; the CPU copy of the
    # first item is used for Griffin-Lim inversion and plotting.
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    mel_postnet = mel_postnet[0].cpu().transpose(0, 1).detach()
    f0_output = f0_output[0].detach().cpu().numpy()
    energy_output = energy_output[0].detach().cpu().numpy()

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(hp.test_path, exist_ok=True)

    Audio.tools.inv_mel_spec(
        mel_postnet,
        os.path.join(hp.test_path,
                     '{}_griffin_lim_{}.wav'.format(prefix, sentence)))

    vocoder_wav_path = os.path.join(
        hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence))
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
        utils.melgan_infer(mel_postnet_torch, melgan, vocoder_wav_path)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)
        utils.waveglow_infer(mel_postnet_torch, waveglow, vocoder_wav_path)

    utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output)],
                    ['Synthesized Spectrogram'],
                    filename=os.path.join(hp.test_path,
                                          '{}_{}.png'.format(prefix,
                                                             sentence)))
Пример #2
0
        f_log.write("\n")

    return d_l, f_l, e_l, mel_l, mel_p_l

if __name__ == "__main__":
    # Script entry point: load the FastSpeech2 checkpoint for ``--step``,
    # load the vocoder selected by ``hp.vocoder`` and run ``evaluate``.
    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=int, default=30000)
    args = parser.parse_args()

    # Get model
    model = get_FastSpeech2(args.step).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Load vocoder
    if hp.vocoder == 'melgan':
        vocoder = utils.get_melgan()
    elif hp.vocoder == 'waveglow':
        vocoder = utils.get_waveglow()
    else:
        # Previously this fell through and crashed with a NameError on
        # `vocoder.to(device)`; fail with an explicit message instead.
        raise ValueError(
            "Unsupported hp.vocoder {!r}; expected 'melgan' or "
            "'waveglow'".format(hp.vocoder))
    vocoder.to(device)

    # Init directories (exist_ok avoids the check-then-create race)
    os.makedirs(hp.log_path, exist_ok=True)
    os.makedirs(hp.eval_path, exist_ok=True)

    evaluate(model, args.step, vocoder)
Пример #3
0
def _append_line(path, value):
    """Append ``str(value)`` followed by a newline to the text file at ``path``."""
    with open(path, "a") as handle:
        handle.write(str(value) + "\n")


def main(args):
    """Train FastSpeech2, periodically logging, checkpointing and evaluating.

    The loop reads its configuration from the module-level ``hp`` object and
    writes per-step loss text files, a ``log.txt`` summary and TensorBoard
    scalars under ``hp.log_path``, checkpoints under ``hp.checkpoint_path``,
    and synthesized audio/plots under ``hp.synth_path``.

    Args:
        args: Namespace providing ``restore_step`` (int). It selects the
            checkpoint file ``checkpoint_{restore_step}.pth.tar`` to resume
            from and offsets the global step counter. When that file does
            not exist, training starts from scratch.

    Raises:
        Any error raised while reading an *existing* checkpoint (corrupt
        file, mismatching state dict, ...) is propagated instead of being
        silently treated as "start new training".
    """
    torch.manual_seed(0)
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    # Each item yielded by the loader is iterated as a sequence of sub-batches
    # below; presumably collate_fn splits hp.batch_size**2 samples into
    # hp.batch_size sub-batches -- TODO confirm against Dataset.collate_fn.
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    if hp.use_spk_embed:
        n_spkers = len(dataset.spk_table.keys())
        model = nn.DataParallel(FastSpeech2(n_spkers=n_spkers)).to(device)
    else:
        model = nn.DataParallel(FastSpeech2()).to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    criterion = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    checkpoint_file = os.path.join(
        checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step))
    try:
        # map_location lets a checkpoint saved on GPU be restored on CPU.
        checkpoint = torch.load(checkpoint_file, map_location=device)
    except FileNotFoundError:
        # Only a missing checkpoint means "fresh run"; other failures surface.
        print("\n---Start New Training---\n")
        os.makedirs(checkpoint_path, exist_ok=True)
    else:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))

    # Load vocoder
    vocoder_net = None
    if hp.vocoder == 'melgan':
        # MelGAN is left on whatever device utils.get_melgan() returns it on.
        vocoder_net = utils.get_melgan()
    elif hp.vocoder == 'waveglow':
        vocoder_net = utils.get_waveglow()
        vocoder_net.to(device)

    # Init logger, synthesis and evaluation directories
    log_path = hp.log_path
    train_log_dir = os.path.join(log_path, 'train')
    val_log_dir = os.path.join(log_path, 'validation')
    synth_path = hp.synth_path
    for directory in (log_path, train_log_dir, val_log_dir, synth_path,
                      hp.eval_path):
        os.makedirs(directory, exist_ok=True)
    train_logger = SummaryWriter(train_log_dir)
    val_logger = SummaryWriter(val_log_dir)

    # Define Some Information
    step_times = np.array([])
    train_start = time.perf_counter()

    # One optimizer-visible step per sub-batch.
    steps_per_epoch = len(loader) * hp.batch_size
    total_step = hp.epochs * steps_per_epoch

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = (i * hp.batch_size + j + args.restore_step +
                                epoch * steps_per_epoch + 1)
                print("step : {}".format(current_step), end='\r', flush=True)

                ### Get Data ###
                if hp.use_spk_embed:
                    spk_ids = torch.tensor(data_of_batch["spk_ids"]).to(
                        torch.int64).to(device)
                else:
                    spk_ids = None
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                ### Forward ###
                (mel_output, mel_postnet_output, log_duration_output,
                 f0_output, energy_output, src_mask, mel_mask, _) = model(
                     text, src_len, mel_len, D, f0, energy, max_src_len,
                     max_mel_len, spk_ids)

                ### Cal Loss ###
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = criterion(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                total_loss = (mel_loss + mel_postnet_loss + d_loss +
                              0.01 * f_loss + 0.1 * e_loss)

                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                for file_name, value in (("total_loss.txt", t_l),
                                         ("mel_loss.txt", m_l),
                                         ("mel_postnet_loss.txt", m_p_l),
                                         ("duration_loss.txt", d_l),
                                         ("f0_loss.txt", f_l),
                                         ("energy_loss.txt", e_l)):
                    _append_line(os.path.join(log_path, file_name), value)

                ## Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                # Gradient accumulation: everything below (update, logging,
                # saving, synthesis, evaluation, timing) only happens on
                # every hp.acc_steps-th step.
                if current_step % hp.acc_steps != 0:
                    continue

                ### Update weights ###
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                ### Print ###
                if current_step % hp.log_step == 0:
                    now = time.perf_counter()
                    # No timing samples yet -> report nan without triggering
                    # numpy's "mean of empty slice" warning.
                    if step_times.size:
                        mean_step_time = np.mean(step_times)
                    else:
                        mean_step_time = float('nan')

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (now - train_start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    for tag, value in (('Loss/total_loss', t_l),
                                       ('Loss/mel_loss', m_l),
                                       ('Loss/mel_postnet_loss', m_p_l),
                                       ('Loss/duration_loss', d_l),
                                       ('Loss/F0_loss', f_l),
                                       ('Loss/energy_loss', e_l)):
                        train_logger.add_scalar(tag, value, current_step)

                ### Save model ###
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                ### Synth ###
                if current_step % hp.synth_step == 0:
                    # Only the first utterance of the sub-batch is rendered.
                    length = mel_len[0].item()
                    print("step: {} , length {}, {}".format(
                        current_step, length, mel_len))
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target_cpu = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet_cpu = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)

                    # spk_ids is None when speaker embeddings are disabled;
                    # indexing it used to raise TypeError here.
                    if spk_ids is not None:
                        spk_id = dataset.inv_spk_table[int(spk_ids[0])]
                    else:
                        spk_id = 'none'

                    # NOTE(review): `vocoder` is not defined inside this
                    # function; presumably it is a module-level import that
                    # provides the *_infer helpers -- confirm.
                    if hp.vocoder == 'melgan':
                        infer = vocoder.melgan_infer
                    elif hp.vocoder == 'waveglow':
                        infer = vocoder.waveglow_infer
                    else:
                        infer = None

                    if infer is not None:
                        # Same three renderings for every vocoder. The
                        # waveglow branch previously passed two arguments to
                        # a three-placeholder format string (IndexError).
                        for variant, mel_input in (
                                ('', mel_torch),
                                ('postnet_', mel_postnet_torch),
                                ('ground-truth_', mel_target_torch)):
                            infer(
                                mel_input, vocoder_net,
                                os.path.join(
                                    hp.synth_path,
                                    'step_{}_spk_{}_{}{}.wav'.format(
                                        current_step, spk_id, variant,
                                        hp.vocoder)))

                    f0_np = f0[0, :length].detach().cpu().numpy()
                    energy_np = energy[0, :length].detach().cpu().numpy()
                    f0_output_np = f0_output[0, :length].detach().cpu().numpy()
                    energy_output_np = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data([
                        (mel_postnet_cpu.numpy(), f0_output_np,
                         energy_output_np),
                        (mel_target_cpu.numpy(), f0_np, energy_np)
                    ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                ### Evaluation ###
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        val_d, val_f, val_e, val_m, val_m_p = evaluate(
                            model, current_step)
                        # NOTE(review): this validation total is an
                        # unweighted sum, whereas the training total weights
                        # the F0 and energy terms by 0.01 and 0.1.
                        val_total = val_d + val_f + val_e + val_m + val_m_p

                        for tag, value in (('Loss/total_loss', val_total),
                                           ('Loss/mel_loss', val_m),
                                           ('Loss/mel_postnet_loss', val_m_p),
                                           ('Loss/duration_loss', val_d),
                                           ('Loss/F0_loss', val_f),
                                           ('Loss/energy_loss', val_e)):
                            val_logger.add_scalar(tag, value, current_step)

                    model.train()

                ### Time ###
                end_time = time.perf_counter()
                step_times = np.append(step_times, end_time - start_time)
                if len(step_times) == hp.clear_Time:
                    # Collapse the history to its mean so the buffer stays
                    # bounded while the ETA estimate keeps its value.
                    step_times = np.array([np.mean(step_times)])
Пример #4
0
def main(args):
    torch.manual_seed(0)

    # Get device
    # device = torch.device('cuda'if torch.cuda.is_available()else 'cpu')
    device = 'cuda'

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoaderX(dataset,
                         batch_size=hp.batch_size * 4,
                         shuffle=True,
                         collate_fn=dataset.collate_fn,
                         drop_last=True,
                         num_workers=8)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    #     model = FastSpeech2().to(device)

    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    train_logger = SummaryWriter(log_dir='log/train/' + current_time)
    val_logger = SummaryWriter(log_dir='log/validation/' + current_time)
    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()
    current_step0 = 0
    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * len(batchs) + j + args.restore_step + \
                    epoch * len(loader)*len(batchs) + 1
                # Get Data
                condition = torch.from_numpy(
                    data_of_batch["condition"]).long().to(device)
                mel_refer = torch.from_numpy(
                    data_of_batch["mel_refer"]).float().to(device)
                if hp.vocoder == 'WORLD':
                    ap_target = torch.from_numpy(
                        data_of_batch["ap_target"]).float().to(device)
                    sp_target = torch.from_numpy(
                        data_of_batch["sp_target"]).float().to(device)
                else:
                    mel_target = torch.from_numpy(
                        data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                #print(D,log_D)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                if hp.vocoder == 'WORLD':
                    #                     print(condition.shape,mel_refer.shape, src_len.shape, mel_len.shape, D.shape, f0.shape, energy.shape, max_src_len.shape, max_mel_len.shape)
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model(
                        condition, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        ap_output=ap_output,
                        sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output,
                        ap_target=ap_target,
                        sp_target=sp_target,
                        src_mask=src_mask,
                        ap_mask=ap_mask,
                        sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        log_D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target,
                        src_mask=~src_mask,
                        mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l = ap_loss.item()
                    sp_l = sp_loss.item()
                    sp_p_l = sp_postnet_loss.item()
                else:
                    m_l = mel_loss.item()
                    m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                #                 with open(os.path.join(log_path, "total_loss.txt"), "a") as f_total_loss:
                #                     f_total_loss.write(str(t_l)+"\n")
                #                 with open(os.path.join(log_path, "mel_loss.txt"), "a") as f_mel_loss:
                #                     f_mel_loss.write(str(m_l)+"\n")
                #                 with open(os.path.join(log_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                #                     f_mel_postnet_loss.write(str(m_p_l)+"\n")
                #                 with open(os.path.join(log_path, "duration_loss.txt"), "a") as f_d_loss:
                #                     f_d_loss.write(str(d_l)+"\n")
                #                 with open(os.path.join(log_path, "f0_loss.txt"), "a") as f_f_loss:
                #                     f_f_loss.write(str(f_l)+"\n")
                #                 with open(os.path.join(log_path, "energy_loss.txt"), "a") as f_e_loss:
                #                     f_e_loss.write(str(e_l)+"\n")

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}],Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    if hp.vocoder == 'WORLD':
                        str2 = "Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l)
                    else:
                        str2 = "Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "T:{:.1f}s,ETA:{:.1f}s.".format(
                        (Now - Start) / (current_step - current_step0),
                        (total_step - current_step) * np.mean(Time))

                    print("" + str1 + str2 + str3 + '')

                    #                     with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                    #                         f_log.write(str1 + "\n")
                    #                         f_log.write(str2 + "\n")
                    #                         f_log.write(str3 + "\n")
                    #                         f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    if hp.vocoder == 'WORLD':
                        train_logger.add_scalar('Loss/ap_loss', ap_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_loss', sp_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_postnet_loss', sp_p_l,
                                                current_step)
                    else:
                        train_logger.add_scalar('Loss/mel_loss', m_l,
                                                current_step)
                        train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                                current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                if current_step % hp.save_step == 0 or current_step == 20:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0 or current_step == 5:
                    length = mel_len[0].item()

                    if hp.vocoder == 'WORLD':
                        ap_target_torch = ap_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        ap_torch = ap_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_target_torch = sp_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_torch = sp_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_postnet_torch = sp_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                    else:
                        mel_target_torch = mel_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_target = mel_target[
                            0, :length].detach().cpu().transpose(0, 1)
                        mel_torch = mel_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        mel = mel_output[0, :length].detach().cpu().transpose(
                            0, 1)
                        mel_postnet_torch = mel_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_postnet = mel_postnet_output[
                            0, :length].detach().cpu().transpose(0, 1)
#                     Audio.tools.inv_mel_spec(mel, os.path.join(synth_path, "step_{}_griffin_lim.wav".format(current_step)))
#                     Audio.tools.inv_mel_spec(mel_postnet, os.path.join(synth_path, "step_{}_postnet_griffin_lim.wav".format(current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'WORLD':
                        #     ap=np.swapaxes(ap,0,1)
                        #     sp=np.swapaxes(sp,0,1)
                        wav = utils.world_infer(
                            np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_postnet_torch[0].cpu().numpy(), 0,
                                        1), f0_output)
                        sf.write(
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)
                        wav = utils.world_infer(
                            np.swapaxes(ap_target_torch[0].cpu().numpy(), 0,
                                        1),
                            np.swapaxes(sp_target_torch[0].cpu().numpy(), 0,
                                        1), f0)
                        sf.write(
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)

                    utils.plot_data([
                        (sp_postnet_torch[0].cpu().numpy(), f0_output,
                         energy_output),
                        (sp_target_torch[0].cpu().numpy(), f0, energy)
                    ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                    plt.matshow(sp_postnet_torch[0].cpu().numpy())
                    plt.savefig(
                        os.path.join(synth_path,
                                     'sp_postnet_{}.png'.format(current_step)))
                    plt.matshow(ap_torch[0].cpu().numpy())
                    plt.savefig(
                        os.path.join(synth_path,
                                     'ap_{}.png'.format(current_step)))
                    plt.matshow(
                        variance_adaptor_output[0].detach().cpu().numpy())
                    #                     plt.savefig(os.path.join(synth_path, 'va_{}.png'.format(current_step)))
                    #                     plt.matshow(decoder_output[0].detach().cpu().numpy())
                    #                     plt.savefig(os.path.join(synth_path, 'encoder_{}.png'.format(current_step)))

                    plt.cla()
                    fout = open(
                        os.path.join(synth_path,
                                     'D_{}.txt'.format(current_step)), 'w')
                    fout.write(
                        str(log_duration_output[0].detach().cpu().numpy()) +
                        '\n')
                    fout.write(str(D[0].detach().cpu().numpy()) + '\n')
                    fout.write(
                        str(condition[0, :, 2].detach().cpu().numpy()) + '\n')
                    fout.close()


#                 if current_step % hp.eval_step == 0 or current_step==20:
#                     model.eval()
#                     with torch.no_grad():

#                         if hp.vocoder=='WORLD':
#                             d_l, f_l, e_l, ap_l, sp_l, sp_p_l = evaluate(model, current_step)
#                             t_l = d_l + f_l + e_l + ap_l + sp_l + sp_p_l

#                             val_logger.add_scalar('valLoss/total_loss', t_l, current_step)
#                             val_logger.add_scalar('valLoss/ap_loss', ap_l, current_step)
#                             val_logger.add_scalar('valLoss/sp_loss', sp_l, current_step)
#                             val_logger.add_scalar('valLoss/sp_postnet_loss', sp_p_l, current_step)
#                             val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
#                             val_logger.add_scalar('valLoss/F0_loss', f_l, current_step)
#                             val_logger.add_scalar('valLoss/energy_loss', e_l, current_step)
#                         else:
#                             d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step)
#                             t_l = d_l + f_l + e_l + m_l + m_p_l

#                             val_logger.add_scalar('valLoss/total_loss', t_l, current_step)
#                             val_logger.add_scalar('valLoss/mel_loss', m_l, current_step)
#                             val_logger.add_scalar('valLoss/mel_postnet_loss', m_p_l, current_step)
#                             val_logger.add_scalar('valLoss/duration_loss', d_l, current_step)
#                             val_logger.add_scalar('valLoss/F0_loss', f_l, current_step)
#                             val_logger.add_scalar('valLoss/energy_loss', e_l, current_step)

#                     model.train()
#                 if current_step%10==0:
#                     print(energy_output[0],energy[0])
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
Пример #5
0
    # Forwarded unchanged to synthesize() as its last argument (see the call
    # at the bottom of this script).
    parser.add_argument('--energy_control', type=float, default=1.0)
    args = parser.parse_args()

    # Demo sentences; each list entry is preprocessed and synthesized on its
    # own by the loop below.
    sentences = [
        "Advanced text to speech models such as Fast Speech can synthesize speech significantly faster than previous auto regressive models with comparable quality. The training of Fast Speech model relies on an auto regressive teacher model for duration prediction and knowledge distillation, which can ease the one to many mapping problem in T T S. However, Fast Speech has several disadvantages, 1, the teacher student distillation pipeline is complicated, 2, the duration extracted from the teacher model is not accurate enough, and the target mel spectrograms distilled from teacher model suffer from information loss due to data simplification, both of which limit the voice quality.",
        "Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition",
        "in being comparatively modern.",
        "For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process",
        "produced the block books, which were the immediate predecessors of the true printed book,",
        "the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.",
        "And it is worth mention in passing that, as an example of fine typography,",
        "the earliest book printed with movable types, the Gutenberg, or \"forty-two line Bible\" of about 1455,",
        "has never been surpassed.",
        "Printing, then, for our purpose, may be considered as the art of making books by means of movable types.",
        "Now, as all books not primarily intended as picture-books consist principally of types composed to form letterpress,"
    ]

    # Presumably loads the FastSpeech2 checkpoint for step `args.step` --
    # confirm in get_FastSpeech2; the result is moved to `device`.
    model = get_FastSpeech2(args.step).to(device)
    # Only the vocoder selected by hp.vocoder is loaded; the other name stays
    # None, and both are passed to synthesize().
    melgan = waveglow = None
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()

    # Inference only, so gradient tracking is disabled for the whole loop.
    with torch.no_grad():
        for sentence in sentences:
            text = preprocess(sentence)
            # The 'step_<N>' string is passed as the prefix argument.
            synthesize(model, waveglow, melgan, text, sentence,
                       'step_{}'.format(args.step), args.duration_control,
                       args.pitch_control, args.energy_control)
Пример #6
0
def main(args, device):
    """Train a FastSpeech model, logging losses and saving checkpoints.

    The output locations are derived from ``hp.root_path`` and
    ``args.name_task`` and written back onto the ``hp`` module
    (``checkpoint_path``, ``synth_path``, ``eval_path``, ``log_path`` and
    ``test_path``) before training starts.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``name_task``, ``restore_path``, ``restore_step``,
            ``frozen_learning_rate`` and ``learning_rate_frozen``.
        device: Device on which the model, the loss module and every batch
            tensor are placed via ``.to(device)``.
    """
    hp.checkpoint_path = os.path.join(hp.root_path, args.name_task, "ckpt",
                                      hp.dataset)
    hp.synth_path = os.path.join(hp.root_path, args.name_task, "synth",
                                 hp.dataset)
    hp.eval_path = os.path.join(hp.root_path, args.name_task, "eval",
                                hp.dataset)
    hp.log_path = os.path.join(hp.root_path, args.name_task, "log", hp.dataset)
    hp.test_path = os.path.join(hp.root_path, args.name_task, 'results')

    # Define model
    print("Use FastSpeech")
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of TTS Parameters:', num_param)

    # Get buffer
    print("Load data to buffer")
    buffer = get_data_to_buffer('train.txt')

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer, hp1.decoder_dim,
                                     hp1.n_warm_up_step, args.restore_step)
    fastspeech_loss = DNNLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists. Restoring is best-effort: any ordinary
    # failure falls back to a fresh run, but KeyboardInterrupt / SystemExit
    # are no longer swallowed and unexpected failures are reported.
    restore_dir = os.path.join(args.restore_path)
    try:
        checkpoint = torch.load(
            os.path.join(restore_dir,
                         'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception as err:
        print("\n---Start New Training---\n")
        if not isinstance(err, FileNotFoundError):
            # A missing file is the normal "nothing to restore" case; anything
            # else (corrupt file, mismatched state dict, ...) deserves a hint.
            print("Could not restore checkpoint: {!r}".format(err))

    # Checkpoints are always written to hp.checkpoint_path, which may differ
    # from args.restore_path, so make sure it exists in both cases.
    os.makedirs(hp.checkpoint_path, exist_ok=True)

    # Init logger
    os.makedirs(hp.log_path, exist_ok=True)

    # The vocoder is only referenced by the validation code that is
    # currently commented out near the end of the training loop.
    waveglow = None
    if hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()

    # Get dataset
    dataset = BufferDataset(buffer)

    # Get Training Loader. Each loader item is iterated below as a sequence
    # of sub-batches (presumably hp.batch_expand_size of them -- confirm in
    # collate_fn_tensor).
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size *
                                 hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=0)
    steps_per_epoch = len(training_loader) * hp.batch_expand_size
    total_step = hp.epochs * steps_per_epoch

    # Per-step wall-clock durations, used for the remaining-time estimate.
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):
            # real batch start here
            for j, db in enumerate(batchs):
                print(len(batchs), len(db))
                start_time = time.perf_counter()

                # Global 1-based step, offset by the restored step count.
                current_step = (i * hp.batch_expand_size + j +
                                args.restore_step + epoch * steps_per_epoch +
                                1)

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = db["text"].long().to(device)
                mel_target = db["mel_target"].float().to(device)
                duration = db["duration"].int().to(device)
                mel_pos = db["mel_pos"].long().to(device)
                src_pos = db["src_pos"].long().to(device)
                max_mel_len = db["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=duration)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output, mel_postnet_output, duration_predictor_output,
                    mel_target, duration)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                # One value per line is appended to each loss history file.
                for log_name, loss_value in (("total_loss.txt", t_l),
                                             ("mel_loss.txt", m_l),
                                             ("mel_postnet_loss.txt", m_p_l),
                                             ("duration_loss.txt", d_l)):
                    with open(os.path.join(hp.log_path, log_name),
                              "a") as f_loss:
                        f_loss.write(str(loss_value) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp1.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    # No step has been timed yet when the very first step is
                    # a logging step; report nan without a NumPy warning.
                    mean_step_time = (np.mean(Time)
                                      if Time.size else float('nan'))

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join(hp.log_path, "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                # if current_step % 10 == 0:
                #     model.eval()
                #     with torch.no_grad():
                #         t_l, d_l, mel_l, mel_p_l = evaluate(
                #             model, current_step, vocoder=waveglow)

                #         str0 = 'Validating'
                #         str1 = "\tEpoch [{}/{}], Step [{}/{}]:".format(
                #             epoch + 1, hp.epochs, current_step, total_step)
                #         str2 = "\tTotal Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                #             t_l, m_l, m_p_l, d_l)
                #         print(str0)
                #         print(str1)
                #         print(str2)
                #     model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the history to its mean so the array stays
                    # bounded while the estimate keeps its running average.
                    Time = np.array([np.mean(Time)])
def main(args):
    """Train FastSpeech2 with gradient accumulation, logging and synthesis.

    Reads its configuration from the ``hp`` module and resumes from
    ``hp.checkpoint_path/checkpoint_<args.restore_step>.pth.tar`` when that
    file can be loaded. During training it periodically writes text and
    TensorBoard logs, saves checkpoints, renders sample audio for the first
    utterance of the current batch, and runs ``evaluate``.

    Args:
        args: Parsed command-line namespace; only ``restore_step`` is read.
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset. Each loader item is iterated below as a sequence of
    # sub-batches (presumably hp.batch_size of them -- confirm in
    # Dataset.collate_fn).
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    # model = nn.DataParallel(FastSpeech2()).to(device)
    model = FastSpeech2().to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists. Restoring is best-effort: any ordinary
    # failure falls back to a fresh run, but KeyboardInterrupt / SystemExit
    # are no longer swallowed and unexpected failures are reported.
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception as err:
        print("\n---Start New Training---\n")
        if not isinstance(err, FileNotFoundError):
            # A missing file is the normal "nothing to restore" case; anything
            # else (corrupt file, mismatched state dict, ...) deserves a hint.
            print("Could not restore checkpoint: {!r}".format(err))
    os.makedirs(checkpoint_path, exist_ok=True)

    # Load vocoder. Both names are always bound so that an unexpected
    # hp.vocoder value cannot lead to an unbound local later on.
    melgan = waveglow = None
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger. The sub-directories are created even when log_path itself
    # already exists.
    log_path = hp.log_path
    train_log_dir = os.path.join(log_path, 'train')
    val_log_dir = os.path.join(log_path, 'validation')
    os.makedirs(train_log_dir, exist_ok=True)
    os.makedirs(val_log_dir, exist_ok=True)
    train_logger = SummaryWriter(train_log_dir)
    val_logger = SummaryWriter(val_log_dir)

    # Init synthesis directory
    synth_path = hp.synth_path
    os.makedirs(synth_path, exist_ok=True)

    # Per-step wall-clock durations, used for the remaining-time estimate.
    Time = np.array([])
    Start = time.perf_counter()

    # Loop-invariant step bookkeeping.
    steps_per_epoch = len(loader) * hp.batch_size
    total_step = hp.epochs * steps_per_epoch

    # Training
    model = model.train()
    # pdb.set_trace()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                # Global 1-based step, offset by the restored step count.
                current_step = (i * hp.batch_size + j + args.restore_step +
                                epoch * steps_per_epoch + 1)

                # Get Data (only the fields the model and the loss consume).
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_output, mel_postnet_output, duration_output, src_mask, pred_mel_mask, enc_attns, dec_attns, W = model(
                    text, src_len, mel_len, max_src_len, max_mel_len)
                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss = Loss(
                    duration_output, mel_len, mel_output, mel_postnet_output,
                    mel_target, src_mask, pred_mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()

                # Backward. Gradients are accumulated over hp.acc_steps
                # steps; steps that do not end an accumulation window skip
                # everything below (update, logging, saving, synthesis,
                # evaluation and timing).
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    # Time is empty until the first update step has been
                    # timed; report nan without a NumPy warning.
                    mean_step_time = (np.mean(Time)
                                      if Time.size else float('nan'))

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)

                    plot_attn(train_logger, enc_attns, dec_attns, current_step,
                              hp)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    # Render the first utterance of the batch, trimmed to its
                    # own length along the axis indexed by mel_len.
                    length = mel_len[0].item()
                    # Vocoder inputs: leading batch axis of 1, last two axes
                    # swapped relative to the model output.
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    # Griffin-Lim inputs: CPU tensors with the two axes
                    # swapped and no batch axis.
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    Audio.tools.inv_mel_spec(
                        mel,
                        os.path.join(
                            synth_path,
                            "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(
                        mel_postnet,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_griffin_lim.wav".format(
                                current_step)))

                    # Both neural vocoders share the same call pattern, so
                    # pick the inference function once and reuse it.
                    if hp.vocoder == 'melgan':
                        vocoder_infer, vocoder = utils.melgan_infer, melgan
                    elif hp.vocoder == 'waveglow':
                        vocoder_infer, vocoder = utils.waveglow_infer, waveglow
                    else:
                        vocoder_infer, vocoder = None, None

                    if vocoder_infer is not None:
                        for tag, spectrogram in (
                                ('', mel_torch),
                                ('_postnet', mel_postnet_torch),
                                ('_ground-truth', mel_target_torch)):
                            vocoder_infer(
                                spectrogram, vocoder,
                                os.path.join(
                                    synth_path, 'step_{}{}_{}.wav'.format(
                                        current_step, tag, hp.vocoder)))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        val_d_l, val_m_l, val_m_p_l = evaluate(
                            model, current_step)
                        val_t_l = val_d_l + val_m_l + val_m_p_l

                        val_logger.add_scalar('Loss/total_loss', val_t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', val_m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss',
                                              val_m_p_l, current_step)
                        val_logger.add_scalar('Loss/duration_loss', val_d_l,
                                              current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the history to its mean so the array stays
                    # bounded while the estimate keeps its running average.
                    Time = np.array([np.mean(Time)])
Пример #8
0
def main(args):
    """Synthesize the test set with a trained FastSpeech2 checkpoint.

    For every utterance produced by the ``test.txt`` dataset, three
    spectrograms (ground truth, model output, post-net output) are inverted
    with Griffin-Lim and written to ``hp.test_path``. When ``hp.vocoder`` is
    ``'melgan'`` or ``'waveglow'`` the same three spectrograms are also
    rendered with that vocoder. Every written file name is passed to
    ``_write_index_line`` together with the open ``index.tsv`` handle.

    Args:
        args: Parsed command-line namespace. ``args.step`` and
            ``args.model_fs`` are always read; ``args.model_melgan`` is read
            only when ``hp.vocoder == 'melgan'``.
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("test.txt")
    # NOTE(review): batch_size**2 presumably matches a collate_fn that splits
    # each loaded batch into several sub-batches (the loop below iterates
    # over `batchs`) -- confirm against Dataset.collate_fn.
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=False,
                        collate_fn=dataset.collate_fn,
                        drop_last=False,
                        num_workers=0)

    model = get_FastSpeech2(args.step, full_path=args.model_fs).to(device)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan(full_path=args.model_melgan)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()

    # Init logger
    log_path = hp.log_path
    os.makedirs(os.path.join(log_path, 'test'), exist_ok=True)
    test_logger = SummaryWriter(os.path.join(log_path, 'test'))

    # Init synthesis directory
    test_path = hp.test_path
    os.makedirs(test_path, exist_ok=True)

    current_step = args.step
    # NOTE(review): this handle is not closed in the code shown here; make
    # sure it is closed once all index lines have been written.
    findex = open(os.path.join(test_path, "index.tsv"), "w")

    # Testing
    print("Generate test audio")
    prefix = ""
    for batchs in loader:
        for j, data_of_batch in enumerate(batchs):
            print("Start batch", j)
            fids = data_of_batch["id"]

            # Get Data
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).long().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            # Pure inference: skip autograd bookkeeping for the forward pass.
            with torch.no_grad():
                (mel_output, mel_postnet_output, log_duration_output,
                 f0_output, energy_output, src_mask, mel_mask,
                 _) = model(text, src_len)

            for i in range(len(mel_len)):
                fid = fids[i]
                print("Generate audio for:", fid)
                # Number of frames of this utterance, taken from mel_len;
                # used to slice all three spectrograms below.
                length = mel_len[i].item()

                # The "*_torch" variants keep the tensor on `device` with a
                # leading batch axis for the neural vocoders; the plain
                # variants are CPU copies for Griffin-Lim.
                _mel_target_torch = mel_target[i, :length].detach().unsqueeze(
                    0).transpose(1, 2)

                _mel_target = mel_target[i, :length].detach().cpu().transpose(
                    0, 1)

                _mel_torch = mel_output[i, :length].detach().unsqueeze(
                    0).transpose(1, 2)

                _mel = mel_output[i, :length].detach().cpu().transpose(0, 1)

                _mel_postnet_torch = mel_postnet_output[
                    i, :length].detach().unsqueeze(0).transpose(1, 2)

                _mel_postnet = mel_postnet_output[
                    i, :length].detach().cpu().transpose(0, 1)

                fname = "{}{}_step_{}_gt_griffin_lim.wav".format(
                    prefix, fid, current_step)
                Audio.tools.inv_mel_spec(_mel_target,
                                         os.path.join(hp.test_path, fname))
                _write_index_line(findex, "Griffin Lim", "vocoder", fname, "",
                                  fid)

                fname = "{}{}_step_{}_griffin_lim.wav".format(
                    prefix, fid, current_step)
                Audio.tools.inv_mel_spec(_mel,
                                         os.path.join(hp.test_path, fname))
                _write_index_line(findex, "FastSpeech2 + GL", "tts", fname, "",
                                  fid)

                fname = "{}{}_step_{}_postnet_griffin_lim.wav".format(
                    prefix, fid, current_step)
                Audio.tools.inv_mel_spec(_mel_postnet,
                                         os.path.join(hp.test_path, fname))
                _write_index_line(findex, "FastSpeech2 + PN + GL", "tts",
                                  fname, "", fid)

                if hp.vocoder == 'melgan':
                    fname = '{}{}_step_{}_ground-truth_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.melgan_infer(_mel_target_torch, melgan,
                                       os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "Melgan", "vocoder", fname, "",
                                      fid)

                    fname = '{}{}_step_{}_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.melgan_infer(_mel_torch, melgan,
                                       os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "FastSpeech2 + Melgan", "tts",
                                      fname, "", fid)

                    fname = '{}{}_step_{}_postnet_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.melgan_infer(_mel_postnet_torch, melgan,
                                       os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "FastSpeech2 + PN + Melgan",
                                      "tts", fname, "", fid)

                elif hp.vocoder == 'waveglow':
                    # File names carry prefix and fid (same scheme as the
                    # melgan branch) so that each utterance gets its own
                    # files instead of overwriting the previous one, and
                    # every file is listed in the index.
                    fname = '{}{}_step_{}_ground-truth_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.waveglow_infer(_mel_target_torch, waveglow,
                                         os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "Waveglow", "vocoder", fname,
                                      "", fid)

                    fname = '{}{}_step_{}_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.waveglow_infer(_mel_torch, waveglow,
                                         os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "FastSpeech2 + Waveglow", "tts",
                                      fname, "", fid)

                    fname = '{}{}_step_{}_postnet_{}.wav'.format(
                        prefix, fid, current_step, hp.vocoder)
                    utils.waveglow_infer(_mel_postnet_torch, waveglow,
                                         os.path.join(hp.test_path, fname))
                    _write_index_line(findex, "FastSpeech2 + PN + Waveglow",
                                      "tts", fname, "", fid)