Example #1
    def validation_epoch_end(self, outputs):
        if self.logger is not None and self.logger.experiment is not None:
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            # Unpack the first batch's step outputs by dict insertion order.
            _, spec_target, spec_predict = outputs[0].values()
            tb_logger.add_image(
                "val_mel_target",
                plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = spec_predict[0].data.cpu().numpy()
            tb_logger.add_image(
                "val_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict.T),
                self.global_step,
                dataformats="HWC",
            )
        avg_loss = torch.stack([
            x['val_loss'] for x in outputs
        ]).mean()  # This reduces across batches, not workers!
        self.log('val_loss', avg_loss, sync_dist=True)

        self.log_train_images = True
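
Examples 1, 2, and 4 all inline the same search: use self.logger.experiment, but if the logger is a LoggerCollection, scan it for a TensorBoardLogger and take that experiment instead. A minimal sketch of the pattern as a reusable helper, assuming PyTorch Lightning 1.x import paths; the name get_tb_experiment is illustrative, not from the source:

from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger


def get_tb_experiment(logger):
    """Return the TensorBoard experiment behind `logger`, or None if absent."""
    if isinstance(logger, TensorBoardLogger):
        return logger.experiment
    if isinstance(logger, LoggerCollection):
        # A LoggerCollection wraps several loggers; use the first TensorBoard one.
        for sub_logger in logger:
            if isinstance(sub_logger, TensorBoardLogger):
                return sub_logger.experiment
    return None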
Example #2
    def training_epoch_end(self, outputs):
        if self.log_train_images and self.logger is not None and self.logger.experiment is not None:
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            spec_target, spec_predict = outputs[0]["outputs"]
            tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = spec_predict[0].data.cpu().numpy()
            tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict.T),
                self.global_step,
                dataformats="HWC",
            )
            self.log_train_images = False

        return super().training_epoch_end(outputs)
Example #3
    def validation_epoch_end(self, outputs):
        def collect(key):
            return torch.stack([x[key] for x in outputs]).mean()

        val_loss = collect("val_loss")
        mel_loss = collect("mel_loss")
        dur_loss = collect("dur_loss")
        pitch_loss = collect("pitch_loss")
        self.log("v_loss", val_loss)
        self.log("v_mel_loss", mel_loss)
        self.log("v_dur_loss", dur_loss)
        self.log("v_pitch_loss", pitch_loss)

        # Unpack the first batch's step outputs by dict insertion order.
        _, _, _, _, spec_target, spec_predict = outputs[0].values()

        if isinstance(self.logger, TensorBoardLogger):
            self.tb_logger.add_image(
                "val_mel_target",
                plot_spectrogram_to_numpy(
                    spec_target[0].data.cpu().float().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = spec_predict[0].data.cpu().float().numpy()
            self.tb_logger.add_image(
                "val_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            self.log_train_images = True
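
The collect helper above is the standard per-batch scalar reduction: stack each batch's loss tensor and take the mean. A toy check with made-up values:

import torch

outputs = [{"val_loss": torch.tensor(0.5)}, {"val_loss": torch.tensor(0.7)}]
avg = torch.stack([x["val_loss"] for x in outputs]).mean()
assert torch.isclose(avg, torch.tensor(0.6))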
Example #4
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_mle = torch.stack([x['l_mle'] for x in outputs]).mean()
        avg_length_loss = torch.stack([x['l_length'] for x in outputs]).mean()
        avg_logdet = torch.stack([x['logdet'] for x in outputs]).mean()
        tensorboard_logs = {
            'val_loss': avg_loss,
            'val_mle': avg_mle,
            'val_length_loss': avg_length_loss,
            'val_logdet': avg_logdet,
        }
        if self.logger is not None and self.logger.experiment is not None:
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            separated_phonemes = "|".join(
                [self.parser.symbols[c] for c in outputs[0]['x'][0]])
            tb_logger.add_text("separated phonemes", separated_phonemes,
                               self.global_step)
            tb_logger.add_image(
                "real_spectrogram",
                plot_spectrogram_to_numpy(outputs[0]['y'][0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            tb_logger.add_image(
                "generated_spectrogram",
                plot_spectrogram_to_numpy(outputs[0]['y_gen'][0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            tb_logger.add_image(
                "alignment_for_real_sp",
                plot_alignment_to_numpy(outputs[0]['attn'][0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            tb_logger.add_image(
                "alignment_for_generated_sp",
                plot_alignment_to_numpy(outputs[0]['attn_gen'][0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            log_audio_to_tb(tb_logger, outputs[0]['y'][0], "true_audio_gf",
                            self.global_step)
            log_audio_to_tb(tb_logger, outputs[0]['y_gen'][0],
                            "generated_audio_gf", self.global_step)
        self.log('val_loss', avg_loss)
        return {'val_loss': avg_loss, 'log': tensorboard_logs}
Example #5
    def validation_step(self, batch, batch_idx):
        audio, audio_len = batch
        audio_mel, audio_mel_len = self.audio_to_melspec_precessor(
            audio, audio_len)
        audio_pred = self(spec=audio_mel)

        audio_pred_mel, _ = self.audio_to_melspec_precessor(
            audio_pred.squeeze(1), audio_len)
        loss_mel = F.l1_loss(audio_mel, audio_pred_mel)

        self.log("val_loss", loss_mel, prog_bar=True, sync_dist=True)

        # plot audio once per epoch
        if batch_idx == 0 and isinstance(self.logger, WandbLogger):
            clips = []
            specs = []
            for i in range(min(5, audio.shape[0])):
                clips += [
                    wandb.Audio(
                        audio[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"real audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        audio_pred[i, 0, :audio_len[i]]
                        .data.cpu().numpy().astype('float32'),
                        caption=f"generated audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                ]
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"real audio {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_pred_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"generated audio {i}",
                    ),
                ]

            self.logger.experiment.log(
                {"audio": clips, "specs": specs},
                commit=False,
            )
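
The W&B branch above builds wandb.Audio and wandb.Image objects and logs them with commit=False so the media share a step with the next committed metric log. A standalone sketch with synthetic data; the offline mode and zero-filled arrays are only there so it runs anywhere:

import numpy as np
import wandb

run = wandb.init(mode="offline")  # offline: the sketch runs without an account
sr = 22050
clips = [wandb.Audio(np.zeros(sr, dtype=np.float32), caption="silence", sample_rate=sr)]
specs = [wandb.Image(np.zeros((80, 100)), caption="blank mel")]
run.log({"audio": clips, "specs": specs}, commit=False)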
Example #6
    def validation_epoch_end(self, outputs):
        # Log images and audio manually
        if self.logger is not None and self.logger.experiment is not None:
            if not self.logged_real_samples:
                self.logger.experiment.add_image(
                    "val_mel_target",
                    plot_spectrogram_to_numpy(
                        outputs[0]["spec"][0].data.cpu().numpy()),
                    self.global_step,
                    dataformats="HWC",
                )
                self.logger.experiment.add_audio(
                    "val_wav_target",
                    outputs[0]["audio"][0].data.cpu().numpy(),
                    self.global_step,
                    sample_rate=self.sample_rate,
                )
                self.logged_real_samples = True
            self.logger.experiment.add_image(
                "val_mel_predicted",
                plot_spectrogram_to_numpy(
                    outputs[0]["spec_pred"][0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            self.logger.experiment.add_audio(
                "val_wav_predicted",
                outputs[0]["audio_pred"][0].data.cpu().numpy(),
                self.global_step,
                sample_rate=self.sample_rate,
            )

        def get_stack(list_of_dict, key):
            """
            Helper function to take a list of losses and reduce across all validation batches
            """
            # Build independent lists; [[]] * n would alias one list n times,
            # mixing every loss into every slot.
            return_list = [[] for _ in range(len(list_of_dict[0][key]))]
            for dict_ in list_of_dict:
                list_of_losses = dict_[key]
                for i, loss in enumerate(list_of_losses):
                    return_list[i].append(loss)
            for i, loss in enumerate(return_list):
                return_list[i] = torch.mean(torch.stack(loss))
            return return_list

        loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log("val_loss", loss, sync_dist=True)

        if self.start_training_disc:
            gan_loss = get_stack(outputs, "gan_loss")
            self.log("val_loss_gan_loss",
                     sum(gan_loss) / len(gan_loss),
                     sync_dist=True)
            for i, _ in enumerate(gan_loss):
                self.log(
                    f"val_loss_gan_loss_{i}",
                    gan_loss[i] / len(gan_loss),
                    sync_dist=True,
                )

        sc_loss = get_stack(outputs, "sc_loss")
        mag_loss = get_stack(outputs, "mag_loss")
        self.log("val_loss_feat_loss",
                 torch.stack([x['loss_feat'] for x in outputs]).mean(),
                 sync_dist=True)
        self.log("val_loss_feat_loss_fb_sc",
                 sum(sc_loss) / len(sc_loss),
                 sync_dist=True)
        self.log("val_loss_feat_loss_fb_mag",
                 sum(mag_loss) / len(sc_loss),
                 sync_dist=True)
        for i, _ in enumerate(sc_loss):
            self.log(f"val_loss_feat_loss_fb_sc_{i}",
                     sc_loss[i] / len(sc_loss),
                     sync_dist=True)
            self.log(f"val_loss_feat_loss_fb_mag_{i}",
                     mag_loss[i] / len(sc_loss),
                     sync_dist=True)
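
The list comprehension in get_stack is load-bearing: [[]] * n creates n references to a single list, so every append would land in every slot and each mean would be computed over all losses. A short demonstration:

aliased = [[]] * 3
aliased[0].append(1)
assert aliased == [[1], [1], [1]]  # all three slots are the same list

independent = [[] for _ in range(3)]
independent[0].append(1)
assert independent == [[1], [], []]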
Example #7
    def training_step(self, batch, batch_idx, optimizer_idx):
        f, fl, t, tl, durations, pitch, energies = batch
        spec, spec_len = self.audio_to_melspec_precessor(f, fl)

        # train discriminator
        if optimizer_idx == 0:
            with torch.no_grad():
                audio_pred, splices, _, _, _, _ = self(
                    spec=spec,
                    spec_len=spec_len,
                    text=t,
                    text_length=tl,
                    durations=durations,
                    pitch=pitch if not self.use_pitch_pred else None,
                    energies=energies if not self.use_energy_pred else None,
                )
                real_audio = []
                for i, splice in enumerate(splices):
                    real_audio.append(
                        f[i, splice *
                          self.hop_size:(splice + self.splice_length) *
                          self.hop_size])
                real_audio = torch.stack(real_audio).unsqueeze(1)

            real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(
                real_audio, audio_pred)
            real_score_ms, gen_score_ms, _, _ = self.multiscaledisc(
                real_audio, audio_pred)

            loss_mp, loss_mp_real, _ = self.disc_loss(real_score_mp,
                                                      gen_score_mp)
            loss_ms, loss_ms_real, _ = self.disc_loss(real_score_ms,
                                                      gen_score_ms)
            loss_mp /= len(loss_mp_real)
            loss_ms /= len(loss_ms_real)
            loss_disc = loss_mp + loss_ms

            self.log("loss_discriminator", loss_disc, prog_bar=True)
            self.log("loss_discriminator_ms", loss_ms)
            self.log("loss_discriminator_mp", loss_mp)
            return loss_disc

        # train generator
        elif optimizer_idx == 1:
            audio_pred, splices, log_dur_preds, pitch_preds, energy_preds, encoded_text_mask = self(
                spec=spec,
                spec_len=spec_len,
                text=t,
                text_length=tl,
                durations=durations,
                pitch=pitch if not self.use_pitch_pred else None,
                energies=energies if not self.use_energy_pred else None,
            )
            real_audio = []
            for i, splice in enumerate(splices):
                real_audio.append(
                    f[i, splice * self.hop_size:(splice + self.splice_length) *
                      self.hop_size])
            real_audio = torch.stack(real_audio).unsqueeze(1)

            # Do HiFiGAN generator loss
            audio_length = torch.tensor([
                self.splice_length * self.hop_size
                for _ in range(real_audio.shape[0])
            ]).to(real_audio.device)
            real_spliced_spec, _ = self.melspec_fn(real_audio.squeeze(),
                                                   seq_len=audio_length)
            pred_spliced_spec, _ = self.melspec_fn(audio_pred.squeeze(),
                                                   seq_len=audio_length)
            loss_mel = torch.nn.functional.l1_loss(real_spliced_spec,
                                                   pred_spliced_spec)
            loss_mel *= self.mel_loss_coeff
            _, gen_score_mp, real_feat_mp, gen_feat_mp = self.multiperioddisc(
                real_audio, audio_pred)
            _, gen_score_ms, real_feat_ms, gen_feat_ms = self.multiscaledisc(
                real_audio, audio_pred)
            loss_gen_mp, list_loss_gen_mp = self.gen_loss(gen_score_mp)
            loss_gen_ms, list_loss_gen_ms = self.gen_loss(gen_score_ms)
            loss_gen_mp /= len(list_loss_gen_mp)
            loss_gen_ms /= len(list_loss_gen_ms)
            total_loss = loss_gen_mp + loss_gen_ms + loss_mel
            loss_feat_mp = self.feat_matching_loss(real_feat_mp, gen_feat_mp)
            loss_feat_ms = self.feat_matching_loss(real_feat_ms, gen_feat_ms)
            total_loss += loss_feat_mp + loss_feat_ms
            self.log(name="loss_gen_disc_feat",
                     value=loss_feat_mp + loss_feat_ms)
            self.log(name="loss_gen_disc_feat_ms", value=loss_feat_ms)
            self.log(name="loss_gen_disc_feat_mp", value=loss_feat_mp)

            self.log(name="loss_gen_mel", value=loss_mel)
            self.log(name="loss_gen_disc", value=loss_gen_mp + loss_gen_ms)
            self.log(name="loss_gen_disc_mp", value=loss_gen_mp)
            self.log(name="loss_gen_disc_ms", value=loss_gen_ms)

            dur_loss = self.durationloss(log_duration_pred=log_dur_preds,
                                         duration_target=durations.float(),
                                         mask=encoded_text_mask)
            self.log(name="loss_gen_duration", value=dur_loss)
            total_loss += dur_loss
            if self.pitch:
                pitch_loss = self.mseloss(
                    pitch_preds, pitch.float()) * self.pitch_loss_coeff
                total_loss += pitch_loss
                self.log(name="loss_gen_pitch", value=pitch_loss)
            if self.energy:
                energy_loss = self.mseloss(energy_preds,
                                           energies) * self.energy_loss_coeff
                total_loss += energy_loss
                self.log(name="loss_gen_energy", value=energy_loss)

            # Log images to tensorboard
            if self.log_train_images:
                self.log_train_images = False
                if self.logger is not None and self.logger.experiment is not None:
                    self.tb_logger.add_image(
                        "train_mel_target",
                        plot_spectrogram_to_numpy(
                            real_spliced_spec[0].data.cpu().numpy()),
                        self.global_step,
                        dataformats="HWC",
                    )
                    spec_predict = pred_spliced_spec[0].data.cpu().numpy()
                    self.tb_logger.add_image(
                        "train_mel_predicted",
                        plot_spectrogram_to_numpy(spec_predict),
                        self.global_step,
                        dataformats="HWC",
                    )
            self.log(name="loss_gen", prog_bar=True, value=total_loss)
            return total_loss
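
The splice indexing in Examples 7 and 12 converts mel frames to audio samples through hop_size: a window of splice_length frames starting at frame splice covers samples splice * hop_size through (splice + splice_length) * hop_size. A quick check with illustrative numbers:

hop_size, splice_length, splice = 256, 64, 10
start = splice * hop_size
end = (splice + splice_length) * hop_size
assert end - start == splice_length * hop_size == 16384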
Example #8
    def validation_step(self, batch, batch_idx):
        attn_prior, lm_tokens = None, None
        if self.cond_on_lm_embeddings:
            audio, audio_len, text, text_len, attn_prior, pitch, _, lm_tokens = batch
        else:
            audio, audio_len, text, text_len, attn_prior, pitch, _ = batch

        spect, spect_len = self.preprocessor(input_signal=audio,
                                             length=audio_len)

        # pitch normalization
        zero_pitch_idx = pitch == 0
        pitch = (pitch - self.pitch_mean) / self.pitch_std
        pitch[zero_pitch_idx] = 0.0

        (
            pred_spect,
            _,
            pred_log_durs,
            pred_pitch,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        ) = self(
            text=text,
            text_len=text_len,
            pitch=pitch,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        (
            loss,
            durs_loss,
            acc,
            acc_dist_1,
            acc_dist_3,
            pitch_loss,
            mel_loss,
            ctc_loss,
            bin_loss,
        ) = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        # without ground truth internal features except for durations
        pred_spect, _, pred_log_durs, pred_pitch, attn_soft, attn_logprob, attn_hard, attn_hard_dur = self(
            text=text,
            text_len=text_len,
            pitch=None,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        *_, with_pred_features_mel_loss, _, _ = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        val_log = {
            'val_loss': loss,
            'val_durs_loss': durs_loss,
            'val_pitch_loss': torch.tensor(1.0).to(durs_loss.device)
            if pitch_loss is None else pitch_loss,
            'val_mel_loss': mel_loss,
            'val_with_pred_features_mel_loss': with_pred_features_mel_loss,
            'val_durs_acc': acc,
            'val_durs_acc_dist_3': acc_dist_3,
            'val_ctc_loss': torch.tensor(1.0).to(durs_loss.device)
            if ctc_loss is None else ctc_loss,
            'val_bin_loss': torch.tensor(1.0).to(durs_loss.device)
            if bin_loss is None else bin_loss,
        }
        self.log_dict(val_log,
                      prog_bar=False,
                      on_epoch=True,
                      logger=True,
                      sync_dist=True)

        if batch_idx == 0 and self.current_epoch % 5 == 0 and isinstance(
                self.logger, WandbLogger):
            specs = []
            pitches = []
            for i in range(min(3, spect.shape[0])):
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(
                            spect[i, :, :spect_len[i]].data.cpu().numpy()),
                        caption=f"gt mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(
                            pred_spect.transpose(
                                1, 2)[i, :, :spect_len[i]].data.cpu().numpy()),
                        caption=f"pred mel {i}",
                    ),
                ]

                pitches += [
                    wandb.Image(
                        plot_pitch_to_numpy(
                            average_pitch(pitch.unsqueeze(1),
                                          attn_hard_dur).squeeze(1)
                            [i, :text_len[i]].data.cpu().numpy(),
                            ylim_range=[-2.5, 2.5],
                        ),
                        caption=f"gt pitch {i}",
                    ),
                ]

                pitches += [
                    wandb.Image(
                        plot_pitch_to_numpy(
                            pred_pitch[i, :text_len[i]].data.cpu().numpy(),
                            ylim_range=[-2.5, 2.5]),
                        caption=f"pred pitch {i}",
                    ),
                ]

            self.logger.experiment.log({"specs": specs, "pitches": pitches})
Example #9
    def validation_step(self, batch, batch_idx):
        if self.finetune:
            audio, audio_len, audio_mel = batch
            audio_mel_len = [audio_mel.shape[1]] * audio_mel.shape[0]
        else:
            audio, audio_len = batch
            audio_mel, audio_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred = self(spec=audio_mel)

        # perform bias denoising
        pred_denoised = self._bias_denoise(audio_pred, audio_mel).squeeze(1)
        pred_denoised_mel, _ = self.audio_to_melspec_precessor(
            pred_denoised, audio_len)

        if self.finetune:
            gt_mel, gt_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred_mel, _ = self.audio_to_melspec_precessor(
            audio_pred.squeeze(1), audio_len)
        loss_mel = F.l1_loss(audio_mel, audio_pred_mel)

        self.log("val_loss", loss_mel, prog_bar=True, sync_dist=True)

        # plot audio once per epoch
        if batch_idx == 0 and HAVE_WANDB and isinstance(self.logger, WandbLogger):
            clips = []
            specs = []
            for i in range(min(5, audio.shape[0])):
                clips += [
                    wandb.Audio(
                        audio[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"real audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        audio_pred[i, 0, :audio_len[i]]
                        .data.cpu().numpy().astype('float32'),
                        caption=f"generated audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        pred_denoised[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"denoised audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                ]
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"input mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_pred_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"output mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(pred_denoised_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"denoised mel {i}",
                    ),
                ]
                if self.finetune:
                    specs += [
                        wandb.Image(
                            plot_spectrogram_to_numpy(gt_mel[
                                i, :, :audio_mel_len[i]].data.cpu().numpy()),
                            caption=f"gt mel {i}",
                        ),
                    ]

            self.logger.experiment.log(
                {"audio": clips, "specs": specs},
                commit=False,
            )
Example #10
    def training_step(self, batch, batch_idx):
        attn_prior, durs, speaker = None, None, None
        if self.learn_alignment:
            if self.ds_class_name == "TTSDataset":
                if SpeakerID in self._train_dl.dataset.sup_data_types_set:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
                else:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
            else:
                raise ValueError(
                    f"Unknown dataset class: {self.ds_class_name}")
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

        mels, spec_len = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        mels_pred, _, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=spec_len,
            input_lens=text_lens,
        )
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        loss = mel_loss + dur_loss
        if self.learn_alignment:
            ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                             in_lens=text_lens,
                                             out_lens=spec_len)
            bin_loss_weight = min(
                self.current_epoch / self.bin_loss_warmup_epochs, 1.0)
            bin_loss = self.bin_loss(
                hard_attention=attn_hard,
                soft_attention=attn_soft) * bin_loss_weight
            loss += ctc_loss + bin_loss

        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss += pitch_loss

        self.log("t_loss", loss)
        self.log("t_mel_loss", mel_loss)
        self.log("t_dur_loss", dur_loss)
        self.log("t_pitch_loss", pitch_loss)
        if self.learn_alignment:
            self.log("t_ctc_loss", ctc_loss)
            self.log("t_bin_loss", bin_loss)

        # Log images to tensorboard
        if self.log_train_images and isinstance(self.logger,
                                                TensorBoardLogger):
            self.log_train_images = False

            self.tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = mels_pred[0].data.cpu().float().numpy()
            self.tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            if self.learn_alignment:
                attn = attn_hard[0].data.cpu().float().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_attn",
                    plot_alignment_to_numpy(attn.T),
                    self.global_step,
                    dataformats="HWC",
                )
                soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_soft_attn",
                    plot_alignment_to_numpy(soft_attn.T),
                    self.global_step,
                    dataformats="HWC",
                )

        return loss
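
The binarization-loss weight in Examples 10 and 11 ramps linearly from 0 to 1 over bin_loss_warmup_epochs and then holds at 1:

def bin_loss_weight(current_epoch: int, warmup_epochs: int) -> float:
    return min(current_epoch / warmup_epochs, 1.0)

assert bin_loss_weight(0, 100) == 0.0
assert bin_loss_weight(50, 100) == 0.5
assert bin_loss_weight(150, 100) == 1.0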
Example #11
    def training_step(self, batch, batch_idx):
        attn_prior, durs, speakers = None, None, None
        if self.learn_alignment:
            audio, audio_lens, text, text_lens, attn_prior, pitch = batch
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speakers = batch
        mels, spec_len = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        mels_pred, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speakers,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=spec_len,
            input_lens=text_lens,
        )
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        loss = mel_loss + dur_loss
        if self.learn_alignment:
            ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                             in_lens=text_lens,
                                             out_lens=spec_len)
            bin_loss_weight = min(
                self.current_epoch / self.bin_loss_warmup_epochs, 1.0)
            bin_loss = self.bin_loss(
                hard_attention=attn_hard,
                soft_attention=attn_soft) * bin_loss_weight
            loss += ctc_loss + bin_loss

        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss += pitch_loss

        self.log("t_loss", loss)
        self.log("t_mel_loss", mel_loss)
        self.log("t_dur_loss", dur_loss)
        self.log("t_pitch_loss", pitch_loss)
        if self.learn_alignment:
            self.log("t_ctc_loss", ctc_loss)
            self.log("t_bin_loss", bin_loss)

        # Log images to tensorboard
        if self.log_train_images:
            self.log_train_images = False

            self.tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(mels[0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = mels_pred[0].data.cpu().numpy().T
            self.tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            if self.learn_alignment:
                attn = attn_hard[0].data.cpu().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_attn",
                    plot_alignment_to_numpy(attn.T),
                    self.global_step,
                    dataformats="HWC",
                )
                soft_attn = attn_soft[0].data.cpu().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_soft_attn",
                    plot_alignment_to_numpy(soft_attn.T),
                    self.global_step,
                    dataformats="HWC",
                )

        return loss
Example #12
    def training_step(self, batch, batch_idx, optimizer_idx):
        audio, _, text, text_lens, durs, pitch, _ = batch

        # train discriminator
        if optimizer_idx == 0:
            with torch.no_grad():
                audio_pred, splices, _, _ = self(text=text,
                                                 durs=durs,
                                                 pitch=pitch)
                real_audio = []
                for i, splice in enumerate(splices):
                    real_audio.append(
                        audio[i, splice *
                              self.hop_size:(splice + self.splice_length) *
                              self.hop_size])
                real_audio = torch.stack(real_audio).unsqueeze(1)

            real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(
                real_audio, audio_pred)
            real_score_ms, gen_score_ms, _, _ = self.multiscaledisc(
                real_audio, audio_pred)

            loss_mp, loss_mp_real, _ = self.disc_loss(
                disc_real_outputs=real_score_mp,
                disc_generated_outputs=gen_score_mp)
            loss_ms, loss_ms_real, _ = self.disc_loss(
                disc_real_outputs=real_score_ms,
                disc_generated_outputs=gen_score_ms)
            loss_mp /= len(loss_mp_real)
            loss_ms /= len(loss_ms_real)
            loss_disc = loss_mp + loss_ms

            self.log("loss_discriminator", loss_disc, prog_bar=True)
            self.log("loss_discriminator_ms", loss_ms)
            self.log("loss_discriminator_mp", loss_mp)
            return loss_disc

        # train generator
        elif optimizer_idx == 1:
            audio_pred, splices, log_dur_preds, pitch_preds = self(text=text,
                                                                   durs=durs,
                                                                   pitch=pitch)
            real_audio = []
            for i, splice in enumerate(splices):
                real_audio.append(
                    audio[i, splice *
                          self.hop_size:(splice + self.splice_length) *
                          self.hop_size])
            real_audio = torch.stack(real_audio).unsqueeze(1)

            dur_loss = self.durationloss(log_durs_predicted=log_dur_preds,
                                         durs_tgt=durs,
                                         len=text_lens)
            pitch_loss = self.pitchloss(
                pitch_predicted=pitch_preds,
                pitch_tgt=pitch,
            )

            # Do HiFiGAN generator loss
            audio_length = torch.tensor([
                self.splice_length * self.hop_size
                for _ in range(real_audio.shape[0])
            ]).to(real_audio.device)
            real_spliced_spec, _ = self.melspec_fn(real_audio.squeeze(),
                                                   audio_length)
            pred_spliced_spec, _ = self.melspec_fn(audio_pred.squeeze(),
                                                   audio_length)
            loss_mel = torch.nn.functional.l1_loss(real_spliced_spec,
                                                   pred_spliced_spec)
            loss_mel *= self.mel_loss_coeff
            _, gen_score_mp, _, _ = self.multiperioddisc(
                real_audio, audio_pred)
            _, gen_score_ms, _, _ = self.multiscaledisc(y=real_audio,
                                                        y_hat=audio_pred)
            loss_gen_mp, list_loss_gen_mp = self.gen_loss(
                disc_outputs=gen_score_mp)
            loss_gen_ms, list_loss_gen_ms = self.gen_loss(
                disc_outputs=gen_score_ms)
            loss_gen_mp /= len(list_loss_gen_mp)
            loss_gen_ms /= len(list_loss_gen_ms)
            total_loss = loss_gen_mp + loss_gen_ms + loss_mel
            total_loss += dur_loss
            total_loss += pitch_loss

            self.log(name="loss_gen_mel", value=loss_mel)
            self.log(name="loss_gen_disc", value=loss_gen_mp + loss_gen_ms)
            self.log(name="loss_gen_disc_mp", value=loss_gen_mp)
            self.log(name="loss_gen_disc_ms", value=loss_gen_ms)
            self.log(name="loss_gen_duration", value=dur_loss)
            self.log(name="loss_gen_pitch", value=pitch_loss)

            # Log images to tensorboard
            if self.log_train_images:
                self.log_train_images = False

                if self.logger is not None and self.logger.experiment is not None:
                    self.tb_logger.add_image(
                        "train_mel_target",
                        plot_spectrogram_to_numpy(
                            real_spliced_spec[0].data.cpu().numpy()),
                        self.global_step,
                        dataformats="HWC",
                    )
                    spec_predict = pred_spliced_spec[0].data.cpu().numpy()
                    self.tb_logger.add_image(
                        "train_mel_predicted",
                        plot_spectrogram_to_numpy(spec_predict),
                        self.global_step,
                        dataformats="HWC",
                    )
            self.log(name="loss_gen", prog_bar=True, value=total_loss)
            return total_loss
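
Examples 7 and 12 rely on Lightning's multi-optimizer convention: with two optimizers configured, training_step runs once per optimizer, and optimizer_idx says which one (0 for the discriminators here, 1 for the generator). A minimal sketch of a matching configure_optimizers; self.generator and the learning rate are illustrative assumptions, only the discriminator attributes come from the source:

import torch


def configure_optimizers(self):
    # Both discriminators are stepped together under optimizer_idx == 0.
    opt_disc = torch.optim.AdamW(
        list(self.multiperioddisc.parameters()) +
        list(self.multiscaledisc.parameters()),
        lr=2e-4,
    )
    # `self.generator` is a stand-in; the source does not name the generator module.
    opt_gen = torch.optim.AdamW(self.generator.parameters(), lr=2e-4)
    # Return order defines optimizer_idx: 0 -> discriminator pass, 1 -> generator pass.
    return [opt_disc, opt_gen]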