def validation_epoch_end(self, outputs):
    """Aggregate per-batch validation outputs, log samples to TensorBoard, and return epoch metrics.

    Averages 'loss', 'l_mle', 'l_length' and 'logdet' over all validation
    batches, and — when a TensorBoard experiment is reachable — logs the
    phoneme sequence, real/generated spectrograms, both alignment maps and
    Griffin-Lim audio for the first sample of the first batch.
    """
    def _epoch_mean(key):
        # Stack the per-batch scalars for `key` and reduce to a single mean.
        return torch.stack([out[key] for out in outputs]).mean()

    avg_loss = _epoch_mean('loss')
    tensorboard_logs = {
        'val_loss': avg_loss,
        'val_mle': _epoch_mean('l_mle'),
        'val_length_loss': _epoch_mean('l_length'),
        'val_logdet': _epoch_mean('logdet'),
    }

    if self.logger is not None and self.logger.experiment is not None:
        tb_logger = self.logger.experiment
        if isinstance(self.logger, LoggerCollection):
            # Prefer the first TensorBoard logger in the collection, if any.
            tb_logger = next(
                (lg.experiment for lg in self.logger if isinstance(lg, TensorBoardLogger)),
                tb_logger,
            )

        first = outputs[0]
        separated_phonemes = "|".join(self.parser.symbols[c] for c in first['x'][0])
        tb_logger.add_text("separated phonemes", separated_phonemes, self.global_step)

        # (tag, output key, plotting function) for each image panel.
        image_specs = (
            ("real_spectrogram", 'y', plot_spectrogram_to_numpy),
            ("generated_spectrogram", 'y_gen', plot_spectrogram_to_numpy),
            ("alignment_for_real_sp", 'attn', plot_alignment_to_numpy),
            ("alignment_for_generated_sp", 'attn_gen', plot_alignment_to_numpy),
        )
        for tag, key, to_image in image_specs:
            tb_logger.add_image(
                tag,
                to_image(first[key][0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )

        log_audio_to_tb(tb_logger, first['y'][0], "true_audio_gf", self.global_step)
        log_audio_to_tb(tb_logger, first['y_gen'][0], "generated_audio_gf", self.global_step)

    self.log('val_loss', avg_loss)
    return {'val_loss': avg_loss, 'log': tensorboard_logs}
def validation_step(self, batch, batch_idx):
    """Run one aligner validation batch; log losses and (once per epoch) alignment images.

    Args:
        batch: tuple of (audio, audio_len, text, text_len, attn_prior).
        batch_idx: index of the batch within the epoch; plots are emitted
            only for batch 0 to limit logging volume.
    """
    audio, audio_len, text, text_len, attn_prior = batch
    spec, spec_len = self.preprocessor(audio, audio_len)
    attn_soft, attn_logprob = self(spec=spec,
                                   spec_len=spec_len,
                                   text=text,
                                   text_len=text_len,
                                   attn_prior=attn_prior)
    loss, forward_sum_loss, bin_loss, attn_hard = self._metrics(
        attn_soft, attn_logprob, spec_len, text_len)

    # Plot alignments once per epoch (first batch only), and only when a
    # W&B logger is attached and wandb is importable.
    if batch_idx == 0 and isinstance(self.logger, WandbLogger) and HAVE_WANDB:
        if attn_hard is None:
            attn_hard = binarize_attention(attn_soft, text_len, spec_len)

        def _attn_image(matrix, caption):
            # Rotate/flip so the alignment is rendered text-axis-horizontal
            # before handing it to wandb. Plain strings here — the original
            # used f-strings with no placeholders.
            return wandb.Image(
                plot_alignment_to_numpy(
                    np.fliplr(np.rot90(matrix.data.cpu().numpy()))),
                caption=caption,
            )

        attn_matrices = []
        # Cap at 5 samples to keep the log lightweight.
        for i in range(min(5, audio.shape[0])):
            soft = attn_soft[i, 0, :spec_len[i], :text_len[i]]
            hard = attn_hard[i, 0, :spec_len[i], :text_len[i]]
            attn_matrices.append(_attn_image(soft, "attn soft"))
            attn_matrices.append(_attn_image(hard, "attn hard"))
        self.logger.experiment.log({"attn_matrices": attn_matrices})

    val_log = {
        'val_loss': loss,
        'val_forward_sum_loss': forward_sum_loss,
        'val_bin_loss': bin_loss
    }
    self.log_dict(val_log,
                  prog_bar=False,
                  on_epoch=True,
                  logger=True,
                  sync_dist=True)
def training_step(self, batch, batch_idx):
    """Compute the composite training loss for one batch and log scalars/images.

    The loss is mel-reconstruction + duration (+ CTC forward-sum and
    binarization losses when alignment is learned) + pitch loss.

    Returns:
        The total training loss tensor.
    """
    attn_prior, durs, speaker = None, None, None
    if self.learn_alignment:
        if self.ds_class_name == "TTSDataset":
            # Batch layout depends on whether speaker IDs are supplied.
            if SpeakerID in self._train_dl.dataset.sup_data_types_set:
                audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
            else:
                audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
        else:
            # Bug fix: the check is on the dataset class, so the error must
            # report the dataset class (the old message named the vocab class).
            raise ValueError(f"Unknown dataset class: {self.ds_class_name}")
    else:
        audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

    mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens)

    # NOTE: the forward pass re-emits `pitch` (presumably the aligned/processed
    # pitch target) and it deliberately shadows the batch value — confirm with
    # the model's forward() contract.
    mels_pred, _, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
        text=text,
        durs=durs,
        pitch=pitch,
        speaker=speaker,
        pace=1.0,
        spec=mels if self.learn_alignment else None,
        attn_prior=attn_prior,
        mel_lens=spec_len,
        input_lens=text_lens,
    )
    if durs is None:
        # Fall back to durations extracted from the hard attention.
        durs = attn_hard_dur

    mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
    dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                  durs_tgt=durs,
                                  len=text_lens)
    loss = mel_loss + dur_loss
    if self.learn_alignment:
        ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                         in_lens=text_lens,
                                         out_lens=spec_len)
        # Linear warm-up of the binarization loss over the first
        # `bin_loss_warmup_epochs` epochs, capped at 1.0.
        bin_loss_weight = min(self.current_epoch / self.bin_loss_warmup_epochs, 1.0)
        bin_loss = self.bin_loss(
            hard_attention=attn_hard,
            soft_attention=attn_soft) * bin_loss_weight
        loss += ctc_loss + bin_loss

    pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                 pitch_tgt=pitch,
                                 len=text_lens)
    loss += pitch_loss

    self.log("t_loss", loss)
    self.log("t_mel_loss", mel_loss)
    self.log("t_dur_loss", dur_loss)
    self.log("t_pitch_loss", pitch_loss)
    if self.learn_alignment:
        self.log("t_ctc_loss", ctc_loss)
        self.log("t_bin_loss", bin_loss)

    # Log images to tensorboard (only once per flag reset).
    if self.log_train_images and isinstance(self.logger, TensorBoardLogger):
        self.log_train_images = False
        self.tb_logger.add_image(
            "train_mel_target",
            plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()),
            self.global_step,
            dataformats="HWC",
        )
        spec_predict = mels_pred[0].data.cpu().float().numpy()
        self.tb_logger.add_image(
            "train_mel_predicted",
            plot_spectrogram_to_numpy(spec_predict),
            self.global_step,
            dataformats="HWC",
        )
        if self.learn_alignment:
            attn = attn_hard[0].data.cpu().float().numpy().squeeze()
            self.tb_logger.add_image(
                "train_attn",
                plot_alignment_to_numpy(attn.T),
                self.global_step,
                dataformats="HWC",
            )
            soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
            self.tb_logger.add_image(
                "train_soft_attn",
                plot_alignment_to_numpy(soft_attn.T),
                self.global_step,
                dataformats="HWC",
            )

    return loss
def training_step(self, batch, batch_idx):
    """Run one training step: forward pass, composite loss, scalar and image logging.

    Total loss = mel + duration (+ CTC forward-sum and warmed-up binarization
    losses when alignment is learned) + pitch.

    Returns:
        The total training loss tensor.
    """
    attn_prior = durs = speakers = None
    if self.learn_alignment:
        audio, audio_lens, text, text_lens, attn_prior, pitch = batch
    else:
        audio, audio_lens, text, text_lens, durs, pitch, speakers = batch

    mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens)

    # The forward pass re-emits `pitch`, intentionally shadowing the batch value.
    (mels_pred, _, log_durs_pred, pitch_pred, attn_soft,
     attn_logprob, attn_hard, attn_hard_dur, pitch) = self(
        text=text,
        durs=durs,
        pitch=pitch,
        speaker=speakers,
        pace=1.0,
        spec=mels if self.learn_alignment else None,
        attn_prior=attn_prior,
        mel_lens=spec_len,
        input_lens=text_lens,
    )
    if durs is None:
        durs = attn_hard_dur

    mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
    dur_loss = self.duration_loss(
        log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
    loss = mel_loss + dur_loss
    if self.learn_alignment:
        ctc_loss = self.forward_sum_loss(
            attn_logprob=attn_logprob, in_lens=text_lens, out_lens=spec_len)
        # Binarization loss is ramped linearly over the warm-up epochs, then capped.
        bin_loss_weight = min(
            self.current_epoch / self.bin_loss_warmup_epochs, 1.0) * 1.0
        bin_loss = bin_loss_weight * self.bin_loss(
            hard_attention=attn_hard, soft_attention=attn_soft)
        loss = loss + ctc_loss + bin_loss
    pitch_loss = self.pitch_loss(
        pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
    loss = loss + pitch_loss

    scalar_logs = {
        "t_loss": loss,
        "t_mel_loss": mel_loss,
        "t_dur_loss": dur_loss,
        "t_pitch_loss": pitch_loss,
    }
    if self.learn_alignment:
        scalar_logs["t_ctc_loss"] = ctc_loss
        scalar_logs["t_bin_loss"] = bin_loss
    for tag, value in scalar_logs.items():
        self.log(tag, value)

    # Log images to tensorboard (first step after the flag is set only).
    if self.log_train_images:
        self.log_train_images = False
        step = self.global_step
        self.tb_logger.add_image(
            "train_mel_target",
            plot_spectrogram_to_numpy(mels[0].data.cpu().numpy()),
            step,
            dataformats="HWC",
        )
        self.tb_logger.add_image(
            "train_mel_predicted",
            plot_spectrogram_to_numpy(mels_pred[0].data.cpu().numpy().T),
            step,
            dataformats="HWC",
        )
        if self.learn_alignment:
            for tag, attn_tensor in (("train_attn", attn_hard),
                                     ("train_soft_attn", attn_soft)):
                self.tb_logger.add_image(
                    tag,
                    plot_alignment_to_numpy(
                        attn_tensor[0].data.cpu().numpy().squeeze().T),
                    step,
                    dataformats="HWC",
                )

    return loss