Example #1
def synthesize(model, text, sentence, prefix=''):
    # Build 1-based source position indices for a single utterance (batch of one).
    src_pos = np.array([i+1 for i in range(text.shape[1])])
    src_pos = np.stack([src_pos])
    src_pos = torch.from_numpy(src_pos).to(device).long()

    # Run the acoustic model on the device, then move it back to CPU.
    model.to(device)
    mel, mel_postnet, duration_output, f0_output, energy_output = model(text, src_pos)
    model.to('cpu')
    
    # Keep batched (1, n_mels, T) tensors for the vocoder and CPU copies for Griffin-Lim and plotting.
    mel_torch = mel.transpose(1, 2).detach()
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    mel = mel[0].cpu().transpose(0, 1).detach()
    mel_postnet = mel_postnet[0].cpu().transpose(0, 1).detach()
    f0_output = f0_output[0].detach().cpu().numpy()
    energy_output = energy_output[0].detach().cpu().numpy()

    if not os.path.exists(hp.test_path):
        os.makedirs(hp.test_path)

    Audio.tools.inv_mel_spec(mel_postnet, os.path.join(hp.test_path, '{}_griffin_lim_{}.wav'.format(prefix, sentence)))
    wave_glow = utils.get_WaveGlow()
    waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join(
        hp.test_path, '{}_waveglow_{}.wav'.format(prefix, sentence)))

    utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output)], ['Synthesized Spectrogram'], filename=os.path.join(hp.test_path, '{}_{}.png'.format(prefix, sentence)))
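A minimal driver sketch for the helper above, assuming the `text.text_to_sequence`, `hp`, `device`, and `get_FastSpeech2` utilities from the same FastSpeech2 project; the wrapper name and example sentence are illustrative, not part of the original code.

def synthesize_from_text(model, sentence, prefix=''):
    # Convert raw text to an id sequence and add a batch dimension of one.
    sequence = text.text_to_sequence(sentence, hp.text_cleaners)
    sequence = torch.from_numpy(np.stack([np.array(sequence)])).long().to(device)
    synthesize(model, sequence, sentence, prefix=prefix)

# Example usage (checkpoint loader assumed to come from the same project):
# model = get_FastSpeech2(args.step)
# synthesize_from_text(model, "Advanced text to speech systems.", prefix='step_300000')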
Example #2
def main(args):
    torch.manual_seed(0)

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    # Get model
    model = get_FastSpeech2(args.step).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Init directories
    if not os.path.exists(hp.logger_path):
        os.makedirs(hp.logger_path)
    if not os.path.exists(hp.eval_path):
        os.makedirs(hp.eval_path)

    # Get loss function
    Loss = FastSpeech2Loss().to(device)
    print("Loss Function Defined.")

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Evaluation
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            mel_pos = torch.from_numpy(
                data_of_batch["mel_pos"]).long().to(device)
            src_pos = torch.from_numpy(
                data_of_batch["src_pos"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_len = max(data_of_batch["mel_len"]).astype(np.int16)

            with torch.no_grad():
                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                # Vocode and plot each utterance in the batch: ground truth vs. postnet prediction.
                for k in range(len(mel_target)):
                    length = mel_len[k]

                    mel_target_torch = mel_target[k:k + 1, :length].transpose(
                        1, 2).detach()
                    mel_target_ = mel_target[k, :length].cpu().transpose(
                        0, 1).detach()
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(
                            hp.eval_path,
                            'ground-truth_{}_waveglow.wav'.format(idx)))

                    mel_postnet_torch = mel_postnet_output[
                        k:k + 1, :length].transpose(1, 2).detach()
                    mel_postnet = mel_postnet_output[
                        k, :length].cpu().transpose(0, 1).detach()
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(hp.eval_path,
                                     'eval_{}_waveglow.wav'.format(idx)))

                    utils.plot_data([
                        (mel_postnet.numpy(), None, None),
                        (mel_target_.numpy(), None, None)
                    ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        hp.eval_path,
                                        'eval_{}.png'.format(idx)))
                    idx += 1

            current_step += 1

    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(args.step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.logger_path, "eval.txt"), "a") as f_logger:
        f_logger.write(str1 + "\n")
        f_logger.write(str2 + "\n")
        f_logger.write(str3 + "\n")
        f_logger.write(str4 + "\n")
        f_logger.write(str5 + "\n")
        f_logger.write(str6 + "\n")
        f_logger.write("\n")
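For completeness, a plausible command-line entry point for this evaluation script; only `--step` is actually read by `main()`, and the default shown here is illustrative.

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=int, default=300000,
                        help='checkpoint step to evaluate (illustrative default)')
    args = parser.parse_args()
    main(args)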
Example #3
File: eval.py Project: jingxu10/FastSpeech
    test4 = "I remove attention module in decoder and use average pooling to implement predicting r frames at once"
    test5 = "You can not improve your past, but you can improve your future. Once time is wasted, life is wasted."
    test6 = "Death comes to all, but great achievements raise a monument which shall endure until the sun grows old."
    data_list = list()
    data_list.append(text.text_to_sequence(test1, hp.text_cleaners))
    data_list.append(text.text_to_sequence(test2, hp.text_cleaners))
    data_list.append(text.text_to_sequence(test3, hp.text_cleaners))
    data_list.append(text.text_to_sequence(test4, hp.text_cleaners))
    data_list.append(text.text_to_sequence(test5, hp.text_cleaners))
    data_list.append(text.text_to_sequence(test6, hp.text_cleaners))
    return data_list


if __name__ == "__main__":
    # Test
    WaveGlow = utils.get_WaveGlow()
    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=int, default=0)
    parser.add_argument("--alpha", type=float, default=1.0)
    args = parser.parse_args()

    print("use griffin-lim and waveglow")
    model = get_DNN(args.step)
    if ipex_enabled:
        model = model.to(ipex.DEVICE)
    data_list = get_data()
    for i, phn in enumerate(data_list):
        mel, mel_cuda = synthesis(model, phn, args.alpha)
        if not os.path.exists("results"):
            os.mkdir("results")
        audio.tools.inv_mel_spec(
Example #4
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:  # fall back to a fresh run if no checkpoint can be restored
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_len = max(data_of_batch["mel_len"]).astype(np.int16)

                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D, f0, energy)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target, src_len,
                    mel_len)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    logger.add_scalars('Loss/total_loss', {'training': t_l},
                                       current_step)
                    logger.add_scalars('Loss/mel_loss', {'training': m_l},
                                       current_step)
                    logger.add_scalars('Loss/mel_postnet_loss',
                                       {'training': m_p_l}, current_step)
                    logger.add_scalars('Loss/duration_loss', {'training': d_l},
                                       current_step)
                    logger.add_scalars('Loss/F0_loss', {'training': f_l},
                                       current_step)
                    logger.add_scalars('Loss/energy_loss', {'training': e_l},
                                       current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    Audio.tools.inv_mel_spec(
                        mel,
                        os.path.join(
                            synth_path,
                            "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(
                        mel_postnet,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_griffin_lim.wav".format(
                                current_step)))
                    waveglow.inference.inference(
                        mel_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_waveglow.wav".format(current_step)))
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(
                            synth_path, "step_{}_postnet_waveglow.wav".format(
                                current_step)))
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_ground-truth_waveglow.wav".format(
                                current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output, energy_output),
                        (mel_target.numpy(), f0, energy)
                    ], ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        logger.add_scalars('Loss/total_loss',
                                           {'validation': t_l}, current_step)
                        logger.add_scalars('Loss/mel_loss',
                                           {'validation': m_l}, current_step)
                        logger.add_scalars('Loss/mel_postnet_loss',
                                           {'validation': m_p_l}, current_step)
                        logger.add_scalars('Loss/duration_loss',
                                           {'validation': d_l}, current_step)
                        logger.add_scalars('Loss/F0_loss', {'validation': f_l},
                                           current_step)
                        logger.add_scalars('Loss/energy_loss',
                                           {'validation': e_l}, current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                # Periodically collapse the timing buffer to its mean to keep it small.
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
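The `ScheduledOptim` wrapper used above is not shown in this excerpt. Below is a minimal sketch of the Transformer-style warmup schedule such a wrapper typically implements (linear warmup followed by inverse square-root decay); this is an assumption about its behavior, not code copied from the repository.

import numpy as np


class ScheduledOptimSketch:
    """Illustrative stand-in: Noam warmup schedule around a wrapped optimizer."""

    def __init__(self, optimizer, d_model, n_warmup_steps, current_step=0):
        self._optimizer = optimizer
        self.init_lr = np.power(d_model, -0.5)
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = current_step

    def _lr_scale(self):
        # Linear ramp for n_warmup_steps, then decay proportional to 1/sqrt(step).
        step = self.n_current_steps
        return min(np.power(step, -0.5),
                   np.power(self.n_warmup_steps, -1.5) * step)

    def step_and_update_lr(self):
        # Advance the step counter, refresh the learning rate, then step the optimizer.
        self.n_current_steps += 1
        lr = self.init_lr * self._lr_scale()
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()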