Example #1
0
 def validation_epoch_end(self, outputs):
     """Aggregate per-batch validation metrics and log example artifacts.

     Averages the stacked loss terms over all validation batches, and —
     when a logger is attached — pushes the phoneme text, spectrogram,
     alignment, and audio examples from the first batch to TensorBoard.
     Returns the aggregated metrics for Lightning.
     """
     # Mean of each tracked loss term across the epoch.
     means = {
         key: torch.stack([out[key] for out in outputs]).mean()
         for key in ('loss', 'l_mle', 'l_length', 'logdet')
     }
     avg_loss = means['loss']
     tensorboard_logs = {
         'val_loss': avg_loss,
         'val_mle': means['l_mle'],
         'val_length_loss': means['l_length'],
         'val_logdet': means['logdet'],
     }
     if self.logger is not None and self.logger.experiment is not None:
         tb_logger = self.logger.experiment
         # In a LoggerCollection, prefer the first TensorBoard logger.
         if isinstance(self.logger, LoggerCollection):
             for sub_logger in self.logger:
                 if isinstance(sub_logger, TensorBoardLogger):
                     tb_logger = sub_logger.experiment
                     break
         sample = outputs[0]
         phoneme_text = "|".join(
             self.parser.symbols[c] for c in sample['x'][0])
         tb_logger.add_text("separated phonemes", phoneme_text,
                            self.global_step)
         # Spectrograms and alignments from the first validation sample.
         for tag, tensor, plot_fn in (
                 ("real_spectrogram", sample['y'][0],
                  plot_spectrogram_to_numpy),
                 ("generated_spectrogram", sample['y_gen'][0],
                  plot_spectrogram_to_numpy),
                 ("alignment_for_real_sp", sample['attn'][0],
                  plot_alignment_to_numpy),
                 ("alignment_for_generated_sp", sample['attn_gen'][0],
                  plot_alignment_to_numpy),
         ):
             tb_logger.add_image(tag,
                                 plot_fn(tensor.data.cpu().numpy()),
                                 self.global_step,
                                 dataformats="HWC")
         # Audio for the target and generated spectrograms (the "_gf"
         # suffix presumably denotes Griffin-Lim reconstruction — the
         # helper is defined elsewhere; confirm before relying on it).
         log_audio_to_tb(tb_logger, sample['y'][0], "true_audio_gf",
                         self.global_step)
         log_audio_to_tb(tb_logger, sample['y_gen'][0],
                         "generated_audio_gf", self.global_step)
     self.log('val_loss', avg_loss)
     return {'val_loss': avg_loss, 'log': tensorboard_logs}
Example #2
0
    def validation_step(self, batch, batch_idx):
        """Run one validation batch through the aligner and log metrics.

        Computes the alignment losses for the batch and, on the first
        batch of the epoch only, logs up to five soft/hard attention
        matrices as images to Weights & Biases.
        """
        audio, audio_len, text, text_len, attn_prior = batch
        spec, spec_len = self.preprocessor(audio, audio_len)
        attn_soft, attn_logprob = self(spec=spec,
                                       spec_len=spec_len,
                                       text=text,
                                       text_len=text_len,
                                       attn_prior=attn_prior)

        loss, forward_sum_loss, bin_loss, attn_hard = self._metrics(
            attn_soft, attn_logprob, spec_len, text_len)

        # plot once per epoch
        if batch_idx == 0 and isinstance(self.logger,
                                         WandbLogger) and HAVE_WANDB:
            if attn_hard is None:
                attn_hard = binarize_attention(attn_soft, text_len, spec_len)

            attn_matrices = []
            # Log at most five examples; for each, the soft attention
            # followed by the hard (binarized) attention.
            for i in range(min(5, audio.shape[0])):
                for attn, caption in ((attn_soft, "attn soft"),
                                      (attn_hard, "attn hard")):
                    # Crop to the valid (spec_len[i], text_len[i]) region;
                    # rotate/flip only reorients the matrix for display.
                    matrix = attn[i, 0, :spec_len[i], :text_len[i]]
                    attn_matrices.append(
                        wandb.Image(
                            plot_alignment_to_numpy(
                                np.fliplr(
                                    np.rot90(matrix.data.cpu().numpy()))),
                            caption=caption,
                        ))

            self.logger.experiment.log({"attn_matrices": attn_matrices})

        val_log = {
            'val_loss': loss,
            'val_forward_sum_loss': forward_sum_loss,
            'val_bin_loss': bin_loss
        }
        self.log_dict(val_log,
                      prog_bar=False,
                      on_epoch=True,
                      logger=True,
                      sync_dist=True)
Example #3
0
File: fastpitch.py  Project: NVIDIA/NeMo
    def training_step(self, batch, batch_idx):
        """One FastPitch training step: compute losses and log artifacts.

        Unpacks the batch according to the training mode (learned
        alignment vs. supervised durations), runs the model, sums the
        mel/duration/pitch losses (plus the CTC forward-sum and
        binarization losses when alignment is learned), and logs example
        images to TensorBoard while the ``log_train_images`` flag is set.
        """
        attn_prior, durs, speaker = None, None, None
        if self.learn_alignment:
            if self.ds_class_name == "TTSDataset":
                if SpeakerID in self._train_dl.dataset.sup_data_types_set:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
                else:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
            else:
                # The branch above keys off the dataset class, so report
                # that class. (The previous message named the vocab
                # class, which is not what is checked here.)
                raise ValueError(
                    f"Unknown dataset class: {self.ds_class_name}")
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

        mels, spec_len = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        mels_pred, _, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=spec_len,
            input_lens=text_lens,
        )
        # With learned alignment there are no ground-truth durations;
        # fall back to the ones extracted from the hard attention.
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        loss = mel_loss + dur_loss
        if self.learn_alignment:
            ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                             in_lens=text_lens,
                                             out_lens=spec_len)
            # Ramp the binarization loss weight linearly to 1.0 over the
            # warmup epochs (the redundant trailing "* 1.0" was removed).
            bin_loss_weight = min(
                self.current_epoch / self.bin_loss_warmup_epochs, 1.0)
            bin_loss = self.bin_loss(
                hard_attention=attn_hard,
                soft_attention=attn_soft) * bin_loss_weight
            loss += ctc_loss + bin_loss

        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss += pitch_loss

        self.log("t_loss", loss)
        self.log("t_mel_loss", mel_loss)
        self.log("t_dur_loss", dur_loss)
        self.log("t_pitch_loss", pitch_loss)
        if self.learn_alignment:
            self.log("t_ctc_loss", ctc_loss)
            self.log("t_bin_loss", bin_loss)

        # Log images to tensorboard (once, until the flag is re-armed
        # elsewhere).
        if self.log_train_images and isinstance(self.logger,
                                                TensorBoardLogger):
            self.log_train_images = False

            self.tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = mels_pred[0].data.cpu().float().numpy()
            self.tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            if self.learn_alignment:
                attn = attn_hard[0].data.cpu().float().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_attn",
                    plot_alignment_to_numpy(attn.T),
                    self.global_step,
                    dataformats="HWC",
                )
                soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_soft_attn",
                    plot_alignment_to_numpy(soft_attn.T),
                    self.global_step,
                    dataformats="HWC",
                )

        return loss
Example #4
0
    def training_step(self, batch, batch_idx):
        """Run a single training step and return the total loss.

        The batch layout depends on whether alignment is learned.
        Computes the mel/duration/pitch losses (plus the forward-sum and
        binarization losses when alignment is learned), logs the scalar
        metrics, and pushes example images to TensorBoard once per
        activation of the ``log_train_images`` flag.
        """
        attn_prior = durs = speakers = None
        if self.learn_alignment:
            audio, audio_lens, text, text_lens, attn_prior, pitch = batch
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speakers = batch
        mels, spec_len = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        (mels_pred, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob,
         attn_hard, attn_hard_dur, pitch) = self(
             text=text,
             durs=durs,
             pitch=pitch,
             speaker=speakers,
             pace=1.0,
             spec=mels if self.learn_alignment else None,
             attn_prior=attn_prior,
             mel_lens=spec_len,
             input_lens=text_lens,
         )
        # Fall back to durations extracted from the hard attention when
        # no ground-truth durations were provided.
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        loss = mel_loss + dur_loss
        if self.learn_alignment:
            ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                             in_lens=text_lens,
                                             out_lens=spec_len)
            # Binarization loss weight ramps linearly to 1.0 over the
            # warmup epochs.
            warmup = min(self.current_epoch / self.bin_loss_warmup_epochs,
                         1.0)
            bin_loss = warmup * self.bin_loss(hard_attention=attn_hard,
                                              soft_attention=attn_soft)
            loss += ctc_loss + bin_loss

        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss += pitch_loss

        scalars = {
            "t_loss": loss,
            "t_mel_loss": mel_loss,
            "t_dur_loss": dur_loss,
            "t_pitch_loss": pitch_loss,
        }
        if self.learn_alignment:
            scalars["t_ctc_loss"] = ctc_loss
            scalars["t_bin_loss"] = bin_loss
        for tag, value in scalars.items():
            self.log(tag, value)

        # Push example images to TensorBoard once, then disarm the flag.
        if self.log_train_images:
            self.log_train_images = False

            self.tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(mels[0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = mels_pred[0].data.cpu().numpy().T
            self.tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            if self.learn_alignment:
                for tag, matrix in (("train_attn", attn_hard),
                                    ("train_soft_attn", attn_soft)):
                    image = matrix[0].data.cpu().numpy().squeeze()
                    self.tb_logger.add_image(
                        tag,
                        plot_alignment_to_numpy(image.T),
                        self.global_step,
                        dataformats="HWC",
                    )

        return loss