Example #1
0
    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        """Write validation summaries for one evaluation pass.

        Logs the validation loss, a histogram of every model parameter, and
        alignment / mel / gate images for one randomly chosen batch item.

        Args:
            reduced_loss: scalar loss, logged as "validation.loss".
            model: module whose ``named_parameters()`` are histogrammed.
            y: ``(mel_targets, gate_targets)``.
            y_pred: 4-tuple whose last three entries are
                ``(mel_outputs, gate_outputs, alignments)``.
            iteration: global step attached to every summary.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        def host(tensor):
            # Drop autograd tracking and copy to CPU as a numpy array.
            return tensor.data.cpu().numpy()

        # Parameter distributions; '/' in the tag lets TensorBoard group them.
        for param_name, param in model.named_parameters():
            self.add_histogram(param_name.replace('.', '/'), host(param),
                               iteration)

        # One random batch item is visualised.
        pick = random.randint(0, alignments.size(0) - 1)
        # Panels are rendered lazily so each plot is produced right before
        # it is logged.
        panels = (
            ("alignment",
             lambda: plot_alignment_to_numpy(host(alignments[pick]).T)),
            ("mel_target",
             lambda: plot_spectrogram_to_numpy(host(mel_targets[pick]))),
            ("mel_predicted",
             lambda: plot_spectrogram_to_numpy(host(mel_outputs[pick]))),
            ("gate",
             lambda: plot_gate_outputs_to_numpy(
                 host(gate_targets[pick]),
                 host(torch.sigmoid(gate_outputs[pick])))),
        )
        for tag, render in panels:
            self.add_image(tag, render(), iteration, dataformats='HWC')
Example #2
0
    def log(self, y, y_pred, idx, iteration):
        """Add alignment, mel and gate images for batch item ``idx``.

        Args:
            y: ``(mel_targets, gate_targets)``.
            y_pred: 4-tuple whose last three entries are
                ``(mel_outputs, gate_outputs, alignments)``.
            idx: index of the batch item to plot.
            iteration: global step attached to every image.
        """
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        def host(tensor):
            # Drop autograd tracking and copy to CPU as a numpy array.
            return tensor.data.cpu().numpy()

        # Rendered lazily so each plot is built immediately before logging.
        panels = [
            ("alignment",
             lambda: plot_alignment_to_numpy(host(alignments[idx]).T)),
            ("mel_target",
             lambda: plot_spectrogram_to_numpy(host(mel_targets[idx]))),
            ("mel_predicted",
             lambda: plot_spectrogram_to_numpy(host(mel_outputs[idx]))),
            ("gate",
             lambda: plot_gate_outputs_to_numpy(
                 host(gate_targets[idx]),
                 host(torch.sigmoid(gate_outputs[idx])))),
        ]
        for tag, render in panels:
            self.add_image(tag, render(), iteration, dataformats='HWC')
    def sample_training(self, output, iteration):
        """Log images and audio for the first item of a training batch.

        Args:
            output: model outputs; item 0 of ``output[0]`` (mel),
                ``output[1]`` (postnet mel) and ``output[2]`` (alignments)
                is converted with ``to_arr`` and logged.
            iteration: global step attached to every summary.
        """
        mel_outputs = to_arr(output[0][0])
        mel_outputs_postnet = to_arr(output[1][0])
        alignments = to_arr(output[2][0]).T

        # plot alignment, mel and postnet output
        self.add_image("alignment", plot_alignment_to_numpy(alignments),
                       iteration)
        self.add_image("mel_outputs", plot_spectrogram_to_numpy(mel_outputs),
                       iteration)
        self.add_image("mel_outputs_postnet",
                       plot_spectrogram_to_numpy(mel_outputs_postnet),
                       iteration)

        # Audio reconstruction occasionally raises, so this stays
        # best-effort. Only ordinary exceptions are swallowed (a bare
        # ``except`` would also trap KeyboardInterrupt / SystemExit), and
        # the failure is reported instead of silently ignored.
        try:
            wav = inv_melspectrogram(mel_outputs)
            # Peak-normalise into [-1, 1]; the 0.01 floor avoids blowing up
            # near-silent output.
            wav /= max(0.01, np.max(np.abs(wav)))
            wav_postnet = inv_melspectrogram(mel_outputs_postnet)
            wav_postnet /= max(0.01, np.max(np.abs(wav_postnet)))
            self.add_audio('pred', wav, iteration, hps.sample_rate)
            self.add_audio('pred_postnet', wav_postnet, iteration,
                           hps.sample_rate)
        except Exception as exc:  # best-effort: never abort training
            import logging
            logging.getLogger(__name__).warning(
                "sample_training: audio logging failed at iteration %s: %s",
                iteration, exc)
Example #4
0
    def sample_training(self, output, target, iteration):
        """Log test-set images and audio for the first item of a batch.

        Args:
            output: model outputs; item 0 of ``output[0]`` (mel),
                ``output[1]`` (postnet mel) and ``output[2]`` (alignments)
                is converted with ``to_arr``.
            target: targets; ``target[0][0]`` is the ground-truth mel.
            iteration: global step attached to every summary.

        Waveforms come straight from ``inv_mel_spectrogram`` with no
        rescaling before ``add_audio``.
        """
        mel_outputs = to_arr(output[0][0])
        mel_target = to_arr(target[0][0])
        mel_outputs_postnet = to_arr(output[1][0])
        alignment_map = to_arr(output[2][0]).T

        # Alignment first, then the three spectrograms.
        self.add_image("alignment_test", plot_alignment_to_numpy(alignment_map),
                       iteration)
        spectrograms = (
            ("mel_outputs_test", mel_outputs),
            ("mel_outputs_postnet_test", mel_outputs_postnet),
            ("mel_target_test", mel_target),
        )
        for tag, mel in spectrograms:
            self.add_image(tag, plot_spectrogram_to_numpy(mel), iteration)

        # Synthesise all three waveforms before logging any of them.
        sources = (
            ('pred_test', mel_outputs),
            ('pred_postnet_test', mel_outputs_postnet),
            ('target_test', mel_target),
        )
        waveforms = [(tag, inv_mel_spectrogram(mel, hps))
                     for tag, mel in sources]
        for tag, waveform in waveforms:
            self.add_audio(tag, waveform, iteration, hps.sample_rate)
Example #5
0
 def log_spec(self, specs, olens, iteration, name=None, num=None):
     """Log the first batch item of each spectrogram in ``specs``.

     Args:
         specs: iterable of dicts mapping a tag suffix ``k`` to a tensor.
             Each tensor is transposed on dims 1 and 2, item 0 is taken and
             its second axis is trimmed to ``olens[0]``.
         olens: sequence of lengths; only ``olens[0]`` is used.
         iteration: global step attached to every figure.
         name: optional tag prefix. When given, the tag is
             ``"<name>_<num>_<k>"``, or ``"<name>_<k>"`` if ``num`` is None.
             Otherwise the tag is just ``k``.
         num: optional tag component; converted with ``str`` so ints work.
     """
     for spec in specs:
         for k, value in spec.items():
             s = value.transpose(1, 2)[0][:, :olens[0]].cpu().data
             if name is not None:
                 # Plain ``+`` concatenation raised TypeError when ``num``
                 # was left at its default None or passed as an int.
                 parts = [name] if num is None else [name, str(num)]
                 tag = '_'.join(parts + [k])
             else:
                 tag = k
             # NOTE(review): add_figure expects a matplotlib Figure; confirm
             # plot_spectrogram_to_numpy returns one here (the name suggests
             # an ndarray, which add_figure would reject).
             self.add_figure(tag, plot_spectrogram_to_numpy(s), iteration)
Example #6
0
    def log_training_vid(self, output, target, reduced_loss, grad_norm,
                         learning_rate, iteration):
        """Log training scalars, images and audio for one step.

        Args:
            output: model outputs; item 0 of ``output[0]`` (mel),
                ``output[1]`` (postnet mel) and ``output[3]`` (alignments)
                is converted with ``to_arr``.
            target: targets; ``target[0][0]`` is the ground-truth mel.
            reduced_loss: 4-tuple
                ``(mel_loss, mel_loss_post, l1_loss, gate_loss)``.
            grad_norm: logged as "grad.norm".
            learning_rate: logged as "learning.rate".
            iteration: global step attached to every summary.
        """
        mel_loss, mel_loss_post, l1_loss, gate_loss = reduced_loss
        self.add_scalar("training.mel_loss", mel_loss, iteration)
        self.add_scalar("training.mel_loss_post", mel_loss_post, iteration)
        self.add_scalar("training.l1_loss", l1_loss, iteration)
        self.add_scalar("training.gate_loss", gate_loss, iteration)
        self.add_scalar("grad.norm", grad_norm, iteration)
        self.add_scalar("learning.rate", learning_rate, iteration)

        mel_outputs = to_arr(output[0][0])
        mel_target = to_arr(target[0][0])
        mel_outputs_postnet = to_arr(output[1][0])
        alignments = to_arr(output[3][0]).T

        # plot alignment, mel and postnet output
        self.add_image("alignment", plot_alignment_to_numpy(alignments),
                       iteration)
        self.add_image("mel_outputs", plot_spectrogram_to_numpy(mel_outputs),
                       iteration)
        self.add_image("mel_outputs_postnet",
                       plot_spectrogram_to_numpy(mel_outputs_postnet),
                       iteration)
        self.add_image("mel_target", plot_spectrogram_to_numpy(mel_target),
                       iteration)

        def to_unit_peak(wav):
            # add_audio expects samples in [-1, 1] and clips anything outside
            # that range, so the former int16-style 32767 scaling clipped
            # nearly every sample. Peak-normalise instead; the 0.01 floor
            # keeps near-silent output from being amplified into noise.
            return wav / max(0.01, np.max(np.abs(wav)))

        # save audio
        wav = to_unit_peak(inv_mel_spectrogram(mel_outputs, hps))
        wav_postnet = to_unit_peak(
            inv_mel_spectrogram(mel_outputs_postnet, hps))
        wav_target = to_unit_peak(inv_mel_spectrogram(mel_target, hps))
        self.add_audio('pred', wav, iteration, hps.sample_rate)
        self.add_audio('pred_postnet', wav_postnet, iteration, hps.sample_rate)
        self.add_audio('target', wav_target, iteration, hps.sample_rate)
def _build_optimizer(model, hp):
    """Create the warm-up optimizer for ``model`` from ``hp.model`` settings."""
    return get_std_opt(
        model,
        hp.model.adim,
        hp.model.transformer_warmup_steps,
        hp.model.transformer_lr,
    )


def _log_report(writer, prefix, report_dict, step):
    """Write every non-None (key, value) of ``report_dict`` as a scalar.

    ``report_dict`` is iterated as a sequence of mappings. Objects whose type
    name mentions cupy are converted with ``.get()`` before logging. Tags are
    ``"<prefix>/<key>"``.
    """
    for report in report_dict:
        for key, value in report.items():
            if key is None or value is None:
                continue
            if "cupy" in str(type(value)):
                value = value.get()
            if "cupy" in str(type(key)):
                key = key.get()
            writer.add_scalar("{}/{}".format(prefix, key), value, step)


def _normalize_audio(audio):
    """Peak-normalise a numpy waveform into [-1, 1].

    The 0.01 floor avoids dividing by zero on silent output. Dividing by the
    peak-to-peak range does not guarantee [-1, 1] for offset signals.
    """
    return audio / max(0.01, float(abs(audio).max()))


def _validate(model, validloader, writer, vocoder, hp, device, step):
    """Run every validation item and log losses, mel images and audio."""
    model.eval()
    try:
        for valid in validloader:
            x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
            with torch.no_grad():
                _, report_dict_ = model(
                    x_.to(device),
                    input_length_.to(device),
                    y_.to(device),
                    out_length_.to(device),
                    dur_.to(device),
                    e_.to(device),
                    p_.to(device),
                )
                # Free-running inference on the last item of the batch.
                mels_ = model.inference(x_[-1].to(device))  # [T, num_mel]

            _log_report(writer, "validation", report_dict_, step)

            utt_id = ids_[-1]
            mels_ = mels_.T  # Out: [num_mels, T]
            writer.add_image(
                "melspectrogram_target_{}".format(utt_id),
                plot_spectrogram_to_numpy(
                    y_[-1].T.data.cpu().numpy()[:, :out_length_[-1]]),
                step,
                dataformats="HWC",
            )
            writer.add_image(
                "melspectrogram_prediction_{}".format(utt_id),
                plot_spectrogram_to_numpy(mels_.data.cpu().numpy()),
                step,
                dataformats="HWC",
            )

            # Vocode the same (last) item whose mel was plotted above.
            audio = generate_audio(mels_.unsqueeze(0), vocoder)
            audio = _normalize_audio(audio.cpu().float().numpy())
            writer.add_audio(
                tag="generated_audio_{}".format(utt_id),
                snd_tensor=torch.Tensor(audio),
                global_step=step,
                sample_rate=hp.audio.sample_rate,
            )

            # os.path.join works whether or not wav_dir ends with a slash.
            _, target = read_wav_np(
                os.path.join(hp.data.wav_dir, f"{utt_id}.wav"),
                sample_rate=hp.audio.sample_rate,
            )
            writer.add_audio(
                tag="target_audio_{}".format(utt_id),
                snd_tensor=torch.Tensor(target),
                global_step=step,
                sample_rate=hp.audio.sample_rate,
            )
    finally:
        model.train()


def train(args, hp, hp_str, logger, vocoder):
    """Train a FastSpeech ``FeedForwardTransformer``.

    Logs to TensorBoard, runs validation periodically and saves checkpoints.

    Args:
        args: needs ``name``, ``outdir`` and ``checkpoint_path``
            (None starts a new run).
        hp: hyper-parameter object (``hp.train``, ``hp.data``, ``hp.model``,
            ``hp.audio`` sections are read).
        hp_str: serialized hparams, stored in checkpoints and compared on
            resume.
        logger: logger used for checkpoint / warning messages.
        vocoder: passed to ``generate_audio`` during validation.

    Returns:
        None. Returns early if ``args.checkpoint_path`` is set but missing.
    """
    os.makedirs(os.path.join(hp.train.chkpt_dir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name, "assets"), exist_ok=True)
    # All tensors are moved to this device. Hard-coded .cuda() calls would
    # break CPU runs with ngpu == 0.
    device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu")

    dataloader = loader.get_tts_dataset(hp.data.data_dir, hp.train.batch_size,
                                        hp)
    validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True)

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = fastspeech.FeedForwardTransformer(idim, odim, hp).to(device)
    print("Model is loaded ...")
    githash = get_commit_hash()

    if args.checkpoint_path is not None:
        if not os.path.exists(args.checkpoint_path):
            print("Checkpoint does not exist")
            return None
        logger.info("Resuming from checkpoint: %s" % args.checkpoint_path)
        # map_location lets a GPU-saved checkpoint be resumed on CPU.
        checkpoint = torch.load(args.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model"])
        optimizer = _build_optimizer(model, hp)
        optimizer.load_state_dict(checkpoint["optim"])
        global_step = checkpoint["step"]

        if hp_str != checkpoint["hp_str"]:
            logger.warning(
                "New hparams is different from checkpoint. Will use new.")

        if githash != checkpoint["githash"]:
            logger.warning("Code might be different: git hash is different.")
            logger.warning("%s -> %s" % (checkpoint["githash"], githash))
    else:
        print("New Training")
        global_step = 0
        optimizer = _build_optimizer(model, hp)

    print("Batch Size :", hp.train.batch_size)

    num_params(model)

    os.makedirs(os.path.join(hp.train.log_dir, args.name), exist_ok=True)
    writer = SummaryWriter(os.path.join(hp.train.log_dir, args.name))
    model.train()
    forward_count = 0
    for epoch in range(hp.train.epochs):
        start = time.time()
        running_loss = 0
        batches_seen = 0

        pbar = tqdm.tqdm(dataloader, desc="Loading train data")
        for data in pbar:
            global_step += 1
            x, input_length, y, _, out_length, _, dur, e, p = data
            # x : [batch , num_char], input_length : [batch], y : [batch, T_in, num_mel]

            loss, report_dict = model(
                x.to(device),
                input_length.to(device),
                y.to(device),
                out_length.to(device),
                dur.to(device),
                e.to(device),
                p.to(device),
            )
            # Scale so gradients accumulated over accum_grad batches average.
            loss = loss.mean() / hp.train.accum_grad
            running_loss += loss.item()

            loss.backward()

            # Only update parameters once every accum_grad forward passes.
            forward_count += 1
            batches_seen += 1
            if forward_count != hp.train.accum_grad:
                continue
            forward_count = 0
            step = global_step

            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       hp.train.grad_clip)
            logging.debug("grad norm={}".format(grad_norm))
            if math.isnan(grad_norm):
                logging.warning("grad norm is nan. Do not update model.")
            else:
                optimizer.step()
            optimizer.zero_grad()

            if step % hp.train.summary_interval == 0:
                pbar.set_description(
                    "Average Loss %.04f Loss %.04f | step %d" %
                    (running_loss / batches_seen, loss.item(), step))
                _log_report(writer, "main", report_dict, step)

            if step % hp.train.validation_step == 0:
                _validate(model, validloader, writer, vocoder, hp, device,
                          step)

            if step % hp.train.save_interval == 0:
                avg_p, avg_e, avg_d = evaluate(hp, validloader, model)
                writer.add_scalar("evaluation/Pitch Loss", avg_p, step)
                writer.add_scalar("evaluation/Energy Loss", avg_e, step)
                writer.add_scalar("evaluation/Dur Loss", avg_d, step)
                save_path = os.path.join(
                    hp.train.chkpt_dir,
                    args.name,
                    "{}_fastspeech_{}_{}k_steps.pyt".format(
                        args.name, githash, step // 1000),
                )

                torch.save(
                    {
                        "model": model.state_dict(),
                        "optim": optimizer.state_dict(),
                        "step": step,
                        "hp_str": hp_str,
                        "githash": githash,
                    },
                    save_path,
                )
                logger.info("Saved checkpoint to: %s" % save_path)
        print("Time taken for epoch {} is {} sec\n".format(
            epoch + 1, int(time.time() - start)))

    # Flush pending events to disk.
    writer.close()