def log_validation(self, reduced_loss, model, y, y_pred, iteration):
    """Log validation loss, parameter histograms and plots for one sample.

    Args:
        reduced_loss: Scalar written under the ``validation.loss`` tag.
        model: Object exposing ``named_parameters()``; every parameter is
            written as a histogram, with ``.`` in its name replaced by ``/``.
        y: Pair ``(mel_targets, gate_targets)``.
        y_pred: 4-item sequence; the first item is ignored, the remaining
            ones are mel outputs, gate outputs and alignments.
        iteration: Global step attached to every logged record.
    """
    self.add_scalar("validation.loss", reduced_loss, iteration)
    _, mel_outputs, gate_outputs, alignments = y_pred
    mel_targets, gate_targets = y

    def _np(tensor):
        # Detached host-side numpy view of a tensor.
        return tensor.data.cpu().numpy()

    # Distribution of every model parameter.
    for param_name, param in model.named_parameters():
        self.add_histogram(param_name.replace('.', '/'), _np(param), iteration)

    # A single batch item, chosen at random, is visualised per call.
    sample = random.randint(0, alignments.size(0) - 1)

    alignment_img = plot_alignment_to_numpy(_np(alignments[sample]).T)
    self.add_image("alignment", alignment_img, iteration, dataformats='HWC')

    target_img = plot_spectrogram_to_numpy(_np(mel_targets[sample]))
    self.add_image("mel_target", target_img, iteration, dataformats='HWC')

    predicted_img = plot_spectrogram_to_numpy(_np(mel_outputs[sample]))
    self.add_image("mel_predicted", predicted_img, iteration, dataformats='HWC')

    # Gate outputs are passed through a sigmoid before plotting against targets.
    gate_img = plot_gate_outputs_to_numpy(
        _np(gate_targets[sample]),
        _np(torch.sigmoid(gate_outputs[sample])))
    self.add_image("gate", gate_img, iteration, dataformats='HWC')
def log(self, y, y_pred, idx, iteration):
    """Log alignment, mel and gate plots for batch item ``idx``.

    Args:
        y: Pair ``(mel_targets, gate_targets)``.
        y_pred: 4-item sequence; the first item is ignored, the remaining
            ones are mel outputs, gate outputs and alignments.
        idx: Index of the batch item to visualise.
        iteration: Global step attached to every logged image.
    """
    _, mel_outputs, gate_outputs, alignments = y_pred
    mel_targets, gate_targets = y

    def host(tensor):
        # Detached host-side numpy view of a tensor.
        return tensor.data.cpu().numpy()

    # Each renderer is evaluated lazily, right before its image is written,
    # so plotting and logging stay interleaved tag by tag.
    renderers = (
        ("alignment",
         lambda: plot_alignment_to_numpy(host(alignments[idx]).T)),
        ("mel_target",
         lambda: plot_spectrogram_to_numpy(host(mel_targets[idx]))),
        ("mel_predicted",
         lambda: plot_spectrogram_to_numpy(host(mel_outputs[idx]))),
        ("gate",
         lambda: plot_gate_outputs_to_numpy(
             host(gate_targets[idx]),
             host(torch.sigmoid(gate_outputs[idx])))),
    )
    for tag, render in renderers:
        self.add_image(tag, render(), iteration, dataformats='HWC')
def sample_training(self, output, iteration):
    """Log plots and reconstructed audio for the first item of ``output``.

    Args:
        output: Indexable model output; ``output[0][0]`` is used as the mel
            output, ``output[1][0]`` as the post-net mel output and
            ``output[2][0]`` (transposed) as the alignment.
        iteration: Global step attached to every logged record.

    Audio logging is best-effort: if inversion or logging fails, a warning
    with the traceback is emitted and the method returns normally.
    """
    import logging

    mel_outputs = to_arr(output[0][0])
    mel_outputs_postnet = to_arr(output[1][0])
    alignments = to_arr(output[2][0]).T

    # plot alignment, mel and postnet output
    self.add_image("alignment", plot_alignment_to_numpy(alignments), iteration)
    self.add_image("mel_outputs", plot_spectrogram_to_numpy(mel_outputs), iteration)
    self.add_image("mel_outputs_postnet", plot_spectrogram_to_numpy(mel_outputs_postnet), iteration)

    # save audio
    try:
        wav = inv_melspectrogram(mel_outputs)
        # Peak-normalise; the 0.01 floor avoids blowing up near-silent output.
        wav /= max(0.01, np.max(np.abs(wav)))
        wav_postnet = inv_melspectrogram(mel_outputs_postnet)
        wav_postnet /= max(0.01, np.max(np.abs(wav_postnet)))
        self.add_audio('pred', wav, iteration, hps.sample_rate)
        self.add_audio('pred_postnet', wav_postnet, iteration, hps.sample_rate)
    except Exception:
        # Deliberately non-fatal, but no longer silent, and no longer
        # swallowing KeyboardInterrupt / SystemExit like a bare ``except``.
        logging.getLogger(__name__).warning(
            "Skipping audio samples at iteration %s", iteration, exc_info=True)
def sample_training(self, output, target, iteration):
    """Log test-time plots and audio for the first item of a batch.

    Args:
        output: Indexable model output; ``output[0][0]`` is used as the mel
            output, ``output[1][0]`` as the post-net mel output and
            ``output[2][0]`` (transposed) as the alignment.
        target: Indexable targets; ``target[0][0]`` is the reference mel.
        iteration: Global step attached to every logged record.
    """
    mel_outputs = to_arr(output[0][0])
    mel_target = to_arr(target[0][0])
    mel_outputs_postnet = to_arr(output[1][0])
    alignments = to_arr(output[2][0]).T

    # Alignment first, then one spectrogram image per mel.
    self.add_image("alignment_test", plot_alignment_to_numpy(alignments), iteration)
    spectrogram_tags = (
        ("mel_outputs_test", mel_outputs),
        ("mel_outputs_postnet_test", mel_outputs_postnet),
        ("mel_target_test", mel_target),
    )
    for tag, mel in spectrogram_tags:
        self.add_image(tag, plot_spectrogram_to_numpy(mel), iteration)

    # All waveforms are inverted before any of them is logged. They are
    # written exactly as returned by inv_mel_spectrogram; no peak
    # normalisation is applied here.
    audio_tags = (
        ("pred_test", mel_outputs),
        ("pred_postnet_test", mel_outputs_postnet),
        ("target_test", mel_target),
    )
    waveforms = [(tag, inv_mel_spectrogram(mel, hps)) for tag, mel in audio_tags]
    for tag, waveform in waveforms:
        self.add_audio(tag, waveform, iteration, hps.sample_rate)
def log_spec(self, specs, olens, iteration, name=None, num=None):
    """Log one spectrogram figure per entry of every dict in ``specs``.

    For each value, dims 1 and 2 are swapped, the first item is taken and
    its last axis is truncated to ``olens[0]`` entries before plotting.

    Args:
        specs: Iterable of dicts mapping a key to a tensor.
        olens: Indexable; only ``olens[0]`` is used, as the truncation length.
        iteration: Global step attached to every logged figure.
        name: Optional tag prefix. When ``None`` the dict key alone is the tag.
        num: Optional identifier inserted between ``name`` and the key. Any
            value that can be formatted as text is accepted; when ``None``
            the tag is ``name_key``.
    """
    for spec in specs:
        for key, value in spec.items():
            image = value.transpose(1, 2)[0][:, :olens[0]].cpu().data
            if name is None:
                tag = key
            elif num is None:
                # The declared default used to raise TypeError on str + None.
                tag = '{}_{}'.format(name, key)
            else:
                # Formatting (not ``+``) so integer ``num`` values work too.
                tag = '{}_{}_{}'.format(name, num, key)
            self.add_figure(tag, plot_spectrogram_to_numpy(image), iteration)
def log_training_vid(self, output, target, reduced_loss, grad_norm, learning_rate, iteration):
    """Log training scalars, plots and audio for the first item of a batch.

    Args:
        output: Indexable model output; ``output[0][0]`` is used as the mel
            output, ``output[1][0]`` as the post-net mel output and
            ``output[3][0]`` (transposed) as the alignment.
        target: Indexable targets; ``target[0][0]`` is the reference mel.
        reduced_loss: 4-item sequence unpacked as
            ``(mel_loss, mel_loss_post, l1_loss, gate_loss)``.
        grad_norm: Value written under the ``grad.norm`` tag.
        learning_rate: Value written under the ``learning.rate`` tag.
        iteration: Global step attached to every logged record.
    """
    mel_loss, mel_loss_post, l1_loss, gate_loss = reduced_loss
    scalars = (
        ("training.mel_loss", mel_loss),
        ("training.mel_loss_post", mel_loss_post),
        ("training.l1_loss", l1_loss),
        ("training.gate_loss", gate_loss),
        ("grad.norm", grad_norm),
        ("learning.rate", learning_rate),
    )
    for tag, value in scalars:
        self.add_scalar(tag, value, iteration)

    mel_outputs = to_arr(output[0][0])
    mel_target = to_arr(target[0][0])
    mel_outputs_postnet = to_arr(output[1][0])
    alignments = to_arr(output[3][0]).T

    # Alignment first, then one spectrogram image per mel.
    self.add_image("alignment", plot_alignment_to_numpy(alignments), iteration)
    spectrogram_tags = (
        ("mel_outputs", mel_outputs),
        ("mel_outputs_postnet", mel_outputs_postnet),
        ("mel_target", mel_target),
    )
    for tag, mel in spectrogram_tags:
        self.add_image(tag, plot_spectrogram_to_numpy(mel), iteration)

    def _peak_scaled(signal):
        # Scales in place so the peak magnitude becomes 32767; the 0.01 floor
        # keeps near-silent signals from being amplified without bound.
        # NOTE(review): confirm add_audio expects this range rather than
        # [-1, 1] -- its implementation is not visible from this block.
        signal *= 32767 / max(0.01, np.max(np.abs(signal)))
        return signal

    # All waveforms are inverted and scaled before any of them is logged.
    audio_tags = (
        ("pred", mel_outputs),
        ("pred_postnet", mel_outputs_postnet),
        ("target", mel_target),
    )
    clips = [(tag, _peak_scaled(inv_mel_spectrogram(mel, hps)))
             for tag, mel in audio_tags]
    for tag, clip in clips:
        self.add_audio(tag, clip, iteration, hps.sample_rate)
def _log_report_scalars(writer, prefix, report, step):
    """Write each non-None key/value of every dict in ``report`` as ``prefix/key``."""
    for entry in report:
        for key, value in entry.items():
            if key is None or value is None:
                continue
            # Objects whose type name mentions cupy are converted via .get().
            if "cupy" in str(type(value)):
                value = value.get()
            if "cupy" in str(type(key)):
                key = key.get()
            writer.add_scalar("{}/{}".format(prefix, key), value, step)


def _to_device(device, *tensors):
    """Return ``tensors`` moved to ``device``, in the same order."""
    return tuple(tensor.to(device) for tensor in tensors)


def train(args, hp, hp_str, logger, vocoder):
    """Train the FastSpeech feed-forward transformer, with periodic validation.

    Args:
        args: Namespace providing ``name``, ``outdir`` and ``checkpoint_path``.
            When ``checkpoint_path`` is not ``None`` training resumes from it.
        hp: Hyper-parameter object; the ``train``, ``data``, ``audio`` and
            ``model`` sections are read.
        hp_str: Serialised hyper-parameters, stored in every checkpoint and
            compared against the one found in a resumed checkpoint.
        logger: Logger used for checkpoint-related messages.
        vocoder: Passed to ``generate_audio`` to synthesise validation audio.

    Returns:
        None. Returns early (after printing a message) when a checkpoint
        path is given but does not exist.
    """
    os.makedirs(os.path.join(hp.train.chkpt_dir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name, "assets"), exist_ok=True)
    device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu")

    dataloader = loader.get_tts_dataset(hp.data.data_dir, hp.train.batch_size, hp)
    validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True)

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = fastspeech.FeedForwardTransformer(idim, odim, hp)
    # set torch device
    model = model.to(device)
    print("Model is loaded ...")
    githash = get_commit_hash()

    resuming = args.checkpoint_path is not None
    if resuming and not os.path.exists(args.checkpoint_path):
        print("Checkpoint does not exixts")
        return None

    optimizer = get_std_opt(
        model,
        hp.model.adim,
        hp.model.transformer_warmup_steps,
        hp.model.transformer_lr,
    )
    if resuming:
        logger.info("Resuming from checkpoint: %s", args.checkpoint_path)
        # map_location lets a checkpoint saved on GPU be resumed on a CPU host.
        checkpoint = torch.load(args.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optim"])
        global_step = checkpoint["step"]

        if hp_str != checkpoint["hp_str"]:
            logger.warning(
                "New hparams is different from checkpoint. Will use new.")
        if githash != checkpoint["githash"]:
            logger.warning("Code might be different: git hash is different.")
            logger.warning("%s -> %s", checkpoint["githash"], githash)
    else:
        print("New Training")
        global_step = 0

    print("Batch Size :", hp.train.batch_size)
    num_params(model)

    os.makedirs(os.path.join(hp.train.log_dir, args.name), exist_ok=True)
    writer = SummaryWriter(os.path.join(hp.train.log_dir, args.name))
    model.train()
    forward_count = 0

    try:
        for epoch in range(hp.train.epochs):
            start = time.time()
            running_loss = 0
            j = 0

            pbar = tqdm.tqdm(dataloader, desc="Loading train data")
            for data in pbar:
                global_step += 1
                x, input_length, y, _, out_length, _, dur, e, p = data
                # Tensors follow the configured device instead of hard-coded
                # .cuda(), so a CPU-only run (ngpu == 0) works.
                loss, report_dict = model(
                    *_to_device(device, x, input_length, y, out_length, dur, e, p))
                loss = loss.mean() / hp.train.accum_grad
                running_loss += loss.item()

                loss.backward()

                # Gradient accumulation: only update every accum_grad batches.
                forward_count += 1
                j = j + 1
                if forward_count != hp.train.accum_grad:
                    continue
                forward_count = 0
                step = global_step

                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.train.grad_clip)
                logging.debug("grad norm=%s", grad_norm)
                if math.isnan(grad_norm):
                    logging.warning("grad norm is nan. Do not update model.")
                else:
                    optimizer.step()
                # Gradients are cleared whether or not the update was applied.
                optimizer.zero_grad()

                if step % hp.train.summary_interval == 0:
                    pbar.set_description(
                        "Average Loss %.04f Loss %.04f | step %d"
                        % (running_loss / j, loss.item(), step))
                    _log_report_scalars(writer, "main", report_dict, step)

                if step % hp.train.validation_step == 0:
                    for valid in validloader:
                        x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
                        model.eval()
                        with torch.no_grad():
                            _, report_dict_ = model(
                                *_to_device(device, x_, input_length_, y_,
                                            out_length_, dur_, e_, p_))
                            mels_ = model.inference(x_[-1].to(device))  # [T, num_mel]
                        model.train()

                        _log_report_scalars(writer, "validation", report_dict_, step)

                        mels_ = mels_.T  # Out: [num_mels, T]
                        writer.add_image(
                            "melspectrogram_target_{}".format(ids_[-1]),
                            plot_spectrogram_to_numpy(
                                y_[-1].T.data.cpu().numpy()[:, :out_length_[-1]]),
                            step,
                            dataformats="HWC",
                        )
                        writer.add_image(
                            "melspectrogram_prediction_{}".format(ids_[-1]),
                            plot_spectrogram_to_numpy(mels_.data.cpu().numpy()),
                            step,
                            dataformats="HWC",
                        )

                        # The last item of the batch is vocoded, matching the
                        # mel generated above.
                        audio = generate_audio(mels_.unsqueeze(0), vocoder)
                        audio = audio.cpu().float().numpy()
                        # Scale by the peak-to-peak range; a constant signal
                        # is left untouched to avoid dividing by zero.
                        span = audio.max() - audio.min()
                        if span > 0:
                            audio = audio / span
                        writer.add_audio(
                            tag="generated_audio_{}".format(ids_[-1]),
                            snd_tensor=torch.Tensor(audio),
                            global_step=step,
                            sample_rate=hp.audio.sample_rate,
                        )

                        # os.path.join tolerates a wav_dir without a trailing
                        # separator; with one, the path is unchanged.
                        _, target = read_wav_np(
                            os.path.join(hp.data.wav_dir, f"{ids_[-1]}.wav"),
                            sample_rate=hp.audio.sample_rate,
                        )
                        writer.add_audio(
                            tag=" target_audio_{}".format(ids_[-1]),
                            snd_tensor=torch.Tensor(target),
                            global_step=step,
                            sample_rate=hp.audio.sample_rate,
                        )

                if step % hp.train.save_interval == 0:
                    avg_p, avg_e, avg_d = evaluate(hp, validloader, model)
                    writer.add_scalar("evaluation/Pitch Loss", avg_p, step)
                    writer.add_scalar("evaluation/Energy Loss", avg_e, step)
                    writer.add_scalar("evaluation/Dur Loss", avg_d, step)
                    save_path = os.path.join(
                        hp.train.chkpt_dir,
                        args.name,
                        "{}_fastspeech_{}_{}k_steps.pyt".format(
                            args.name, githash, step // 1000),
                    )
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "optim": optimizer.state_dict(),
                            "step": step,
                            "hp_str": hp_str,
                            "githash": githash,
                        },
                        save_path,
                    )
                    logger.info("Saved checkpoint to: %s", save_path)

            print("Time taken for epoch {} is {} sec\n".format(
                epoch + 1, int(time.time() - start)))
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()