Example 1
def output_types(self):
    return {
        "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
        "num_frames": NeuralType(('B'), TokenDurationType()),
        "durs_predicted": NeuralType(('B', 'T_text'), TokenDurationType()),
        "log_durs_predicted": NeuralType(('B', 'T_text'), TokenLogDurationType()),
        "pitch_predicted": NeuralType(('B', 'T_text'), RegressionValuesType()),
        "attn_soft": NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
        "attn_logprob": NeuralType(('B', 'S', 'T_spec', 'T_text'), LogprobsType()),
        "attn_hard": NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
        "attn_hard_dur": NeuralType(('B', 'T_text'), TokenDurationType()),
        "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
    }
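
The signature above exposes token durations both in the log domain (log_durs_predicted) and in the linear domain (durs_predicted). In the model code further down, linear durations are recovered from the log predictions with clamp(exp(x) - 1, min=0) and rounded to integer frame counts. A minimal torch sketch of that round trip (the input values are made up for illustration):

import torch

log_durs_predicted = torch.tensor([[0.0, 1.2, -0.3, 2.0]])  # (B, T_text), hypothetical predictions
durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, min=0)  # back to the linear domain
frame_durs = durs_predicted.round().long()  # integer frames per token
print(frame_durs)  # tensor([[0, 2, 0, 6]])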
Example 2
def input_types(self):
    return {
        "text": NeuralType(('B', 'T_text'), TokenIndex()),
        "durs": NeuralType(('B', 'T_text'), TokenDurationType()),
        "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
        "speaker": NeuralType(('B'), Index(), optional=True),
        "pace": NeuralType(optional=True),
        "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
        "attn_prior": NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
        "mel_lens": NeuralType(('B'), LengthsType(), optional=True),
        "input_lens": NeuralType(('B'), LengthsType(), optional=True),
    }
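
Among the inputs above, pitch is a frame-level F0 track typed as RegressionValuesType over ('B', 'T_audio'). In the MixerTTS training and validation steps further down, that track is mean/variance normalized while unvoiced frames (pitch == 0) are pinned back to exactly zero. A small torch sketch of that normalization, using assumed statistics rather than real dataset values:

import torch

pitch = torch.tensor([[0.0, 180.0, 200.0, 0.0, 220.0]])  # (B, T_audio); zeros mark unvoiced frames
pitch_mean, pitch_std = 190.0, 40.0  # assumed dataset statistics

zero_pitch_idx = pitch == 0
pitch = (pitch - pitch_mean) / pitch_std
pitch[zero_pitch_idx] = 0.0  # unvoiced frames stay exactly zero after normalization
print(pitch)  # tensor([[ 0.0000, -0.2500,  0.2500,  0.0000,  0.7500]])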
Example 3
class MixerTTSModel(SpectrogramGenerator, Exportable):
    """MixerTTS pipeline."""
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        super().__init__(cfg=cfg, trainer=trainer)
        cfg = self._cfg
        if "text_normalizer" in cfg.train_ds.dataset:
            self.normalizer = instantiate(cfg.train_ds.dataset.text_normalizer)
            self.text_normalizer_call = self.normalizer.normalize
            self.text_normalizer_call_args = {}
            if cfg.train_ds.dataset.get("text_normalizer_call_args",
                                        None) is not None:
                self.text_normalizer_call_args = cfg.train_ds.dataset.text_normalizer_call_args

        self.tokenizer = instantiate(cfg.train_ds.dataset.text_tokenizer)
        num_tokens = len(self.tokenizer.tokens)
        self.tokenizer_pad = self.tokenizer.pad
        self.tokenizer_unk = self.tokenizer.oov

        self.pitch_loss_scale = cfg.pitch_loss_scale
        self.durs_loss_scale = cfg.durs_loss_scale
        self.mel_loss_scale = cfg.mel_loss_scale

        self.aligner = instantiate(cfg.alignment_module)
        self.forward_sum_loss = ForwardSumLoss()
        self.bin_loss = BinLoss()
        self.add_bin_loss = False
        self.bin_loss_scale = 0.0
        self.bin_loss_start_ratio = cfg.bin_loss_start_ratio
        self.bin_loss_warmup_epochs = cfg.bin_loss_warmup_epochs

        self.cond_on_lm_embeddings = cfg.get("cond_on_lm_embeddings", False)

        if self.cond_on_lm_embeddings:
            self.lm_padding_value = (self._train_dl.dataset.lm_padding_value
                                     if self._train_dl is not None else
                                     self._get_lm_padding_value(
                                         cfg.train_ds.dataset.lm_model))
            self.lm_embeddings = self._get_lm_embeddings(
                cfg.train_ds.dataset.lm_model)
            self.lm_embeddings.weight.requires_grad = False

            self.self_attention_module = instantiate(
                cfg.self_attention_module,
                n_lm_tokens_channels=self.lm_embeddings.weight.shape[1])

        self.encoder = instantiate(cfg.encoder,
                                   num_tokens=num_tokens,
                                   padding_idx=self.tokenizer_pad)
        self.symbol_emb = self.encoder.to_embed

        self.duration_predictor = instantiate(cfg.duration_predictor)

        self.pitch_mean, self.pitch_std = float(cfg.pitch_mean), float(
            cfg.pitch_std)
        self.pitch_predictor = instantiate(cfg.pitch_predictor)
        self.pitch_emb = instantiate(cfg.pitch_emb)

        self.preprocessor = instantiate(cfg.preprocessor)

        self.decoder = instantiate(cfg.decoder)
        self.proj = nn.Linear(self.decoder.d_model, cfg.n_mel_channels)

    def _get_lm_model_tokenizer(self, lm_model="albert"):
        if getattr(self, "_lm_model_tokenizer", None) is not None:
            return self._lm_model_tokenizer

        if self._train_dl is not None and self._train_dl.dataset is not None:
            self._lm_model_tokenizer = self._train_dl.dataset.lm_model_tokenizer
            return self._lm_model_tokenizer

        if lm_model == "albert":
            self._lm_model_tokenizer = AlbertTokenizer.from_pretrained(
                'albert-base-v2')
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

        return self._lm_model_tokenizer

    def _get_lm_embeddings(self, lm_model="albert"):
        if lm_model == "albert":
            return transformers.AlbertModel.from_pretrained(
                'albert-base-v2').embeddings.word_embeddings
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

    def _get_lm_padding_value(self, lm_model="albert"):
        if lm_model == "albert":
            return transformers.AlbertTokenizer.from_pretrained(
                'albert-base-v2')._convert_token_to_id('<pad>')
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

    def _metrics(
        self,
        true_durs,
        true_text_len,
        pred_durs,
        true_pitch,
        pred_pitch,
        true_spect=None,
        pred_spect=None,
        true_spect_len=None,
        attn_logprob=None,
        attn_soft=None,
        attn_hard=None,
        attn_hard_dur=None,
    ):
        text_mask = get_mask_from_lengths(true_text_len)
        mel_mask = get_mask_from_lengths(true_spect_len)
        loss = 0.0

        # dur loss and metrics
        durs_loss = F.mse_loss(pred_durs, (true_durs + 1).float().log(),
                               reduction='none')
        durs_loss = durs_loss * text_mask.float()
        durs_loss = durs_loss.sum() / text_mask.sum()

        durs_pred = pred_durs.exp() - 1
        durs_pred = torch.clamp_min(durs_pred, min=0)
        durs_pred = durs_pred.round().long()

        acc = ((true_durs == durs_pred) *
               text_mask).sum().float() / text_mask.sum() * 100
        acc_dist_1 = (((true_durs - durs_pred).abs() <= 1) *
                      text_mask).sum().float() / text_mask.sum() * 100
        acc_dist_3 = (((true_durs - durs_pred).abs() <= 3) *
                      text_mask).sum().float() / text_mask.sum() * 100

        pred_spect = pred_spect.transpose(1, 2)

        # mel loss
        mel_loss = F.mse_loss(pred_spect, true_spect,
                              reduction='none').mean(dim=-2)
        mel_loss = mel_loss * mel_mask.float()
        mel_loss = mel_loss.sum() / mel_mask.sum()

        loss = loss + self.durs_loss_scale * durs_loss + self.mel_loss_scale * mel_loss

        # aligner loss
        bin_loss, ctc_loss = None, None
        ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                         in_lens=true_text_len,
                                         out_lens=true_spect_len)
        loss = loss + ctc_loss
        if self.add_bin_loss:
            bin_loss = self.bin_loss(hard_attention=attn_hard,
                                     soft_attention=attn_soft)
            loss = loss + self.bin_loss_scale * bin_loss
        true_avg_pitch = average_pitch(true_pitch.unsqueeze(1),
                                       attn_hard_dur).squeeze(1)

        # pitch loss
        pitch_loss = F.mse_loss(pred_pitch, true_avg_pitch,
                                reduction='none')  # noqa
        pitch_loss = (pitch_loss * text_mask).sum() / text_mask.sum()

        loss = loss + self.pitch_loss_scale * pitch_loss

        return loss, durs_loss, acc, acc_dist_1, acc_dist_3, pitch_loss, mel_loss, ctc_loss, bin_loss

    @torch.jit.unused
    def run_aligner(self, text, text_len, text_mask, spect, spect_len,
                    attn_prior):
        text_emb = self.symbol_emb(text)
        attn_soft, attn_logprob = self.aligner(
            spect,
            text_emb.permute(0, 2, 1),
            mask=text_mask == 0,
            attn_prior=attn_prior,
        )
        attn_hard = binarize_attention_parallel(attn_soft, text_len, spect_len)
        attn_hard_dur = attn_hard.sum(2)[:, 0, :]
        assert torch.all(torch.eq(attn_hard_dur.sum(dim=1), spect_len))
        return attn_soft, attn_logprob, attn_hard, attn_hard_dur

    @typecheck(
        input_types={
            "text":
            NeuralType(('B', 'T_text'), TokenIndex()),
            "text_len":
            NeuralType(('B', ), LengthsType()),
            "pitch":
            NeuralType(('B', 'T_audio'), RegressionValuesType(),
                       optional=True),
            "spect":
            NeuralType(('B', 'D', 'T_spec'),
                       MelSpectrogramType(),
                       optional=True),
            "spect_len":
            NeuralType(('B', ), LengthsType(), optional=True),
            "attn_prior":
            NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
            "lm_tokens":
            NeuralType(('B', 'T_lm_tokens'), TokenIndex(), optional=True),
        },
        output_types={
            "pred_spect":
            NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
            "durs_predicted":
            NeuralType(('B', 'T_text'), TokenDurationType()),
            "log_durs_predicted":
            NeuralType(('B', 'T_text'), TokenLogDurationType()),
            "pitch_predicted":
            NeuralType(('B', 'T_text'), RegressionValuesType()),
            "attn_soft":
            NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
            "attn_logprob":
            NeuralType(('B', 'S', 'T_spec', 'T_text'), LogprobsType()),
            "attn_hard":
            NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
            "attn_hard_dur":
            NeuralType(('B', 'T_text'), TokenDurationType()),
        },
    )
    def forward(self,
                text,
                text_len,
                pitch=None,
                spect=None,
                spect_len=None,
                attn_prior=None,
                lm_tokens=None):
        if self.training:
            assert pitch is not None

        text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # aligner
        attn_soft, attn_logprob, attn_hard, attn_hard_dur = None, None, None, None
        if spect is not None:
            attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
                text, text_len, text_mask, spect, spect_len, attn_prior)

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out,
                lm_emb,
                lm_emb,
                q_mask=enc_mask.squeeze(2),
                kv_mask=lm_tokens != self.lm_padding_value)

        # duration predictor
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # pitch predictor
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)

        # avg pitch, add pitch_emb
        if not self.training:
            if pitch is not None:
                pitch = average_pitch(pitch.unsqueeze(1),
                                      attn_hard_dur).squeeze(1)
                pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
            else:
                pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
        else:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))

        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        # regulate length
        len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)

        dec_out, dec_lens = self.decoder(
            len_regulated_enc_out,
            get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return (
            pred_spect,
            durs_predicted,
            log_durs_predicted,
            pitch_predicted,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        )

    def infer(
        self,
        text,
        text_len=None,
        text_mask=None,
        spect=None,
        spect_len=None,
        attn_prior=None,
        use_gt_durs=False,
        lm_tokens=None,
        pitch=None,
    ):
        if text_mask is None:
            text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # aligner
        attn_hard_dur = None
        if use_gt_durs:
            attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
                text, text_len, text_mask, spect, spect_len, attn_prior)

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out,
                lm_emb,
                lm_emb,
                q_mask=enc_mask.squeeze(2),
                kv_mask=lm_tokens != self.lm_padding_value)

        # duration predictor
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # avg pitch, pitch predictor
        if use_gt_durs and pitch is not None:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

        # add pitch emb
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        if use_gt_durs:
            if attn_hard_dur is not None:
                len_regulated_enc_out, dec_lens = regulate_len(
                    attn_hard_dur, enc_out)
            else:
                raise NotImplementedError
        else:
            len_regulated_enc_out, dec_lens = regulate_len(
                durs_predicted, enc_out)

        dec_out, _ = self.decoder(len_regulated_enc_out,
                                  get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return pred_spect

    def on_train_epoch_start(self):
        bin_loss_start_epoch = np.ceil(self.bin_loss_start_ratio *
                                       self._trainer.max_epochs)

        # Add bin loss when current_epoch >= bin_loss_start_epoch
        if not self.add_bin_loss and self.current_epoch >= bin_loss_start_epoch:
            logging.info(
                f"Using hard attentions after epoch: {self.current_epoch}")
            self.add_bin_loss = True

        if self.add_bin_loss:
            self.bin_loss_scale = min(
                (self.current_epoch - bin_loss_start_epoch) /
                self.bin_loss_warmup_epochs, 1.0)

    def training_step(self, batch, batch_idx):
        attn_prior, lm_tokens = None, None
        if self.cond_on_lm_embeddings:
            audio, audio_len, text, text_len, attn_prior, pitch, _, lm_tokens = batch
        else:
            audio, audio_len, text, text_len, attn_prior, pitch, _ = batch

        spect, spect_len = self.preprocessor(input_signal=audio,
                                             length=audio_len)

        # pitch normalization
        zero_pitch_idx = pitch == 0
        pitch = (pitch - self.pitch_mean) / self.pitch_std
        pitch[zero_pitch_idx] = 0.0

        (
            pred_spect,
            _,
            pred_log_durs,
            pred_pitch,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        ) = self(
            text=text,
            text_len=text_len,
            pitch=pitch,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        (
            loss,
            durs_loss,
            acc,
            acc_dist_1,
            acc_dist_3,
            pitch_loss,
            mel_loss,
            ctc_loss,
            bin_loss,
        ) = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        train_log = {
            'train_loss':
            loss,
            'train_durs_loss':
            durs_loss,
            'train_pitch_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if pitch_loss is None else pitch_loss,
            'train_mel_loss':
            mel_loss,
            'train_durs_acc':
            acc,
            'train_durs_acc_dist_3':
            acc_dist_3,
            'train_ctc_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if ctc_loss is None else ctc_loss,
            'train_bin_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if bin_loss is None else bin_loss,
        }

        return {'loss': loss, 'progress_bar': train_log, 'log': train_log}

    def validation_step(self, batch, batch_idx):
        attn_prior, lm_tokens = None, None
        if self.cond_on_lm_embeddings:
            audio, audio_len, text, text_len, attn_prior, pitch, _, lm_tokens = batch
        else:
            audio, audio_len, text, text_len, attn_prior, pitch, _ = batch

        spect, spect_len = self.preprocessor(input_signal=audio,
                                             length=audio_len)

        # pitch normalization
        zero_pitch_idx = pitch == 0
        pitch = (pitch - self.pitch_mean) / self.pitch_std
        pitch[zero_pitch_idx] = 0.0

        (
            pred_spect,
            _,
            pred_log_durs,
            pred_pitch,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        ) = self(
            text=text,
            text_len=text_len,
            pitch=pitch,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        (
            loss,
            durs_loss,
            acc,
            acc_dist_1,
            acc_dist_3,
            pitch_loss,
            mel_loss,
            ctc_loss,
            bin_loss,
        ) = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        # without ground truth internal features except for durations
        pred_spect, _, pred_log_durs, pred_pitch, attn_soft, attn_logprob, attn_hard, attn_hard_dur = self(
            text=text,
            text_len=text_len,
            pitch=None,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        *_, with_pred_features_mel_loss, _, _ = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        val_log = {
            'val_loss':
            loss,
            'val_durs_loss':
            durs_loss,
            'val_pitch_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if pitch_loss is None else pitch_loss,
            'val_mel_loss':
            mel_loss,
            'val_with_pred_features_mel_loss':
            with_pred_features_mel_loss,
            'val_durs_acc':
            acc,
            'val_durs_acc_dist_3':
            acc_dist_3,
            'val_ctc_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if ctc_loss is None else ctc_loss,
            'val_bin_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if bin_loss is None else bin_loss,
        }
        self.log_dict(val_log,
                      prog_bar=False,
                      on_epoch=True,
                      logger=True,
                      sync_dist=True)

        if batch_idx == 0 and self.current_epoch % 5 == 0 and isinstance(
                self.logger, WandbLogger):
            specs = []
            pitches = []
            for i in range(min(3, spect.shape[0])):
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(
                            spect[i, :, :spect_len[i]].data.cpu().numpy()),
                        caption=f"gt mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(
                            pred_spect.transpose(
                                1, 2)[i, :, :spect_len[i]].data.cpu().numpy()),
                        caption=f"pred mel {i}",
                    ),
                ]

                pitches += [
                    wandb.Image(
                        plot_pitch_to_numpy(
                            average_pitch(pitch.unsqueeze(1),
                                          attn_hard_dur).squeeze(1)
                            [i, :text_len[i]].data.cpu().numpy(),
                            ylim_range=[-2.5, 2.5],
                        ),
                        caption=f"gt pitch {i}",
                    ),
                ]

                pitches += [
                    wandb.Image(
                        plot_pitch_to_numpy(
                            pred_pitch[i, :text_len[i]].data.cpu().numpy(),
                            ylim_range=[-2.5, 2.5]),
                        caption=f"pred pitch {i}",
                    ),
                ]

            self.logger.experiment.log({"specs": specs, "pitches": pitches})

    @typecheck(
        input_types={
            "tokens":
            NeuralType(('B', 'T_text'), TokenIndex(), optional=True),
            "tokens_len":
            NeuralType(('B'), LengthsType(), optional=True),
            "lm_tokens":
            NeuralType(('B', 'T_lm_tokens'), TokenIndex(), optional=True),
            "raw_texts": [NeuralType(optional=True)],
            "lm_model":
            NeuralType(optional=True),
        },
        output_types={
            "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
        },
    )
    def generate_spectrogram(
        self,
        tokens: Optional[torch.Tensor] = None,
        tokens_len: Optional[torch.Tensor] = None,
        lm_tokens: Optional[torch.Tensor] = None,
        raw_texts: Optional[List[str]] = None,
        lm_model: str = "albert",
    ):
        if tokens is not None:
            if tokens_len is None:
                # it is assumed that padding is consecutive and only at the end
                tokens_len = (tokens != self.tokenizer.pad).sum(dim=-1)
        else:
            if raw_texts is None:
                logging.error("raw_texts must be specified if tokens is None")

            t_seqs = [self.tokenizer(t) for t in raw_texts]
            tokens = torch.nn.utils.rnn.pad_sequence(
                sequences=[
                    torch.tensor(t, dtype=torch.long, device=self.device)
                    for t in t_seqs
                ],
                batch_first=True,
                padding_value=self.tokenizer.pad,
            )
            tokens_len = torch.tensor([len(t) for t in t_seqs],
                                      dtype=torch.long,
                                      device=tokens.device)

        if self.cond_on_lm_embeddings and lm_tokens is None:
            if raw_texts is None:
                logging.error(
                    "raw_texts must be specified if lm_tokens is None")

            lm_model_tokenizer = self._get_lm_model_tokenizer(lm_model)
            lm_padding_value = lm_model_tokenizer._convert_token_to_id('<pad>')
            lm_space_value = lm_model_tokenizer._convert_token_to_id('▁')

            assert isinstance(self.tokenizer,
                              EnglishCharsTokenizer) or isinstance(
                                  self.tokenizer, EnglishPhonemesTokenizer)

            preprocess_texts_as_tts_input = [
                self.tokenizer.text_preprocessing_func(t) for t in raw_texts
            ]
            lm_tokens_as_ids_list = [
                lm_model_tokenizer.encode(t, add_special_tokens=False)
                for t in preprocess_texts_as_tts_input
            ]

            if self.tokenizer.pad_with_space:
                lm_tokens_as_ids_list = [[lm_space_value] + t +
                                         [lm_space_value]
                                         for t in lm_tokens_as_ids_list]

            lm_tokens = torch.full(
                (len(lm_tokens_as_ids_list),
                 max([len(t) for t in lm_tokens_as_ids_list])),
                fill_value=lm_padding_value,
                device=tokens.device,
            )
            for i, lm_tokens_i in enumerate(lm_tokens_as_ids_list):
                lm_tokens[i, :len(lm_tokens_i)] = torch.tensor(
                    lm_tokens_i, device=tokens.device)

        pred_spect = self.infer(tokens, tokens_len,
                                lm_tokens=lm_tokens).transpose(1, 2)
        return pred_spect

    def parse(self, text: str, normalize=True) -> torch.Tensor:
        if normalize and getattr(self, "text_normalizer_call",
                                 None) is not None:
            text = self.text_normalizer_call(text,
                                             **self.text_normalizer_call_args)
        return torch.tensor(
            self.tokenizer.encode(text)).long().unsqueeze(0).to(self.device)

    @staticmethod
    def _loader(cfg):
        try:
            _ = cfg.dataset.manifest_filepath
        except omegaconf.errors.MissingMandatoryValue:
            logging.warning(
                "manifest_filepath was skipped. No dataset for this model.")
            return None

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(  # noqa
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            **cfg.dataloader_params,
        )

    def setup_training_data(self, cfg):
        self._train_dl = self._loader(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self._loader(cfg)

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls):
        """Empty."""
        pass

    @property
    def input_types(self):
        return {
            "text":
            NeuralType(('B', 'T_text'), TokenIndex()),
            "lm_tokens":
            NeuralType(('B', 'T_lm_tokens'), TokenIndex(), optional=True),
        }

    @property
    def output_types(self):
        return {
            "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
        }

    def forward_for_export(self, text, lm_tokens=None):
        text_mask = (text != self.tokenizer_pad).unsqueeze(2)
        spect = self.infer(text=text, text_mask=text_mask,
                           lm_tokens=lm_tokens).transpose(1, 2)
        return spect.to(torch.float)
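
on_train_epoch_start above turns on the binarization loss once training passes bin_loss_start_ratio * max_epochs and then linearly ramps bin_loss_scale up to 1.0 over bin_loss_warmup_epochs. A small sketch of that schedule with assumed configuration values (the numbers are illustrative, not defaults):

import numpy as np

max_epochs = 200                # assumed trainer setting
bin_loss_start_ratio = 0.3      # assumed config value
bin_loss_warmup_epochs = 10     # assumed config value

bin_loss_start_epoch = np.ceil(bin_loss_start_ratio * max_epochs)  # 60.0
for epoch in (0, 59, 60, 65, 80):
    add_bin_loss = epoch >= bin_loss_start_epoch
    scale = min((epoch - bin_loss_start_epoch) / bin_loss_warmup_epochs, 1.0) if add_bin_loss else 0.0
    print(epoch, add_bin_loss, scale)  # e.g. 65 -> True, 0.5; 80 -> True, 1.0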
Example 4
class FastPitchModel(SpectrogramGenerator, Exportable):
    """FastPitch model (https://arxiv.org/abs/2006.06873) that is used to generate mel spectrogram from text."""
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        # Setup normalizer
        self.normalizer = None
        self.text_normalizer_call = None
        self.text_normalizer_call_kwargs = {}
        self._setup_normalizer(cfg)

        self.learn_alignment = cfg.get("learn_alignment", False)

        # Setup vocabulary (=tokenizer) and input_fft_kwargs (supported only with self.learn_alignment=True)
        input_fft_kwargs = {}
        if self.learn_alignment:
            self.vocab = None
            self.ds_class_name = cfg.train_ds.dataset._target_.split(".")[-1]

            if self.ds_class_name == "TTSDataset":
                self._setup_tokenizer(cfg)
                assert self.vocab is not None
                input_fft_kwargs["n_embed"] = len(self.vocab.tokens)
                input_fft_kwargs["padding_idx"] = self.vocab.pad
            elif self.ds_class_name == "AudioToCharWithPriorAndPitchDataset":
                logging.warning(
                    "AudioToCharWithPriorAndPitchDataset class has been deprecated. No support for"
                    " training or finetuning. Only inference is supported.")
                tokenizer_conf = self._get_default_text_tokenizer_conf()
                self._setup_tokenizer(tokenizer_conf)
                assert self.vocab is not None
                input_fft_kwargs["n_embed"] = len(self.vocab.tokens)
                input_fft_kwargs["padding_idx"] = self.vocab.pad
            else:
                raise ValueError(
                    f"Unknown dataset class: {self.ds_class_name}")

        self._parser = None
        self._tb_logger = None
        super().__init__(cfg=cfg, trainer=trainer)

        self.bin_loss_warmup_epochs = cfg.get("bin_loss_warmup_epochs", 100)
        self.log_train_images = False

        loss_scale = 0.1 if self.learn_alignment else 1.0
        dur_loss_scale = loss_scale
        pitch_loss_scale = loss_scale
        if "dur_loss_scale" in cfg:
            dur_loss_scale = cfg.dur_loss_scale
        if "pitch_loss_scale" in cfg:
            pitch_loss_scale = cfg.pitch_loss_scale

        self.mel_loss = MelLoss()
        self.pitch_loss = PitchLoss(loss_scale=pitch_loss_scale)
        self.duration_loss = DurationLoss(loss_scale=dur_loss_scale)

        self.aligner = None
        if self.learn_alignment:
            self.aligner = instantiate(self._cfg.alignment_module)
            self.forward_sum_loss = ForwardSumLoss()
            self.bin_loss = BinLoss()

        self.preprocessor = instantiate(self._cfg.preprocessor)
        input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
        output_fft = instantiate(self._cfg.output_fft)
        duration_predictor = instantiate(self._cfg.duration_predictor)
        pitch_predictor = instantiate(self._cfg.pitch_predictor)

        self.fastpitch = FastPitchModule(
            input_fft,
            output_fft,
            duration_predictor,
            pitch_predictor,
            self.aligner,
            cfg.n_speakers,
            cfg.symbols_embedding_dim,
            cfg.pitch_embedding_kernel_size,
            cfg.n_mel_channels,
        )
        self._input_types = self._output_types = None

    def _get_default_text_tokenizer_conf(self):
        text_tokenizer: TextTokenizerConfig = TextTokenizerConfig()
        return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer))

    def _setup_normalizer(self, cfg):
        if "text_normalizer" in cfg:
            normalizer_kwargs = {}

            if "whitelist" in cfg.text_normalizer:
                normalizer_kwargs["whitelist"] = self.register_artifact(
                    'text_normalizer.whitelist', cfg.text_normalizer.whitelist)

            self.normalizer = instantiate(cfg.text_normalizer,
                                          **normalizer_kwargs)
            self.text_normalizer_call = self.normalizer.normalize
            if "text_normalizer_call_kwargs" in cfg:
                self.text_normalizer_call_kwargs = cfg.text_normalizer_call_kwargs

    def _setup_tokenizer(self, cfg):
        text_tokenizer_kwargs = {}
        if "g2p" in cfg.text_tokenizer:
            g2p_kwargs = {}

            if "phoneme_dict" in cfg.text_tokenizer.g2p:
                g2p_kwargs["phoneme_dict"] = self.register_artifact(
                    'text_tokenizer.g2p.phoneme_dict',
                    cfg.text_tokenizer.g2p.phoneme_dict,
                )

            if "heteronyms" in cfg.text_tokenizer.g2p:
                g2p_kwargs["heteronyms"] = self.register_artifact(
                    'text_tokenizer.g2p.heteronyms',
                    cfg.text_tokenizer.g2p.heteronyms,
                )

            text_tokenizer_kwargs["g2p"] = instantiate(cfg.text_tokenizer.g2p,
                                                       **g2p_kwargs)

        self.vocab = instantiate(cfg.text_tokenizer, **text_tokenizer_kwargs)

    @property
    def tb_logger(self):
        if self._tb_logger is None:
            if self.logger is None or self.logger.experiment is None:
                return None
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            self._tb_logger = tb_logger
        return self._tb_logger

    @property
    def parser(self):
        if self._parser is not None:
            return self._parser

        if self.learn_alignment:
            ds_class_name = self._cfg.train_ds.dataset._target_.split(".")[-1]

            if ds_class_name == "TTSDataset":
                self._parser = self.vocab.encode
            elif ds_class_name == "AudioToCharWithPriorAndPitchDataset":
                if self.vocab is None:
                    tokenizer_conf = self._get_default_text_tokenizer_conf()
                    self._setup_tokenizer(tokenizer_conf)
                self._parser = self.vocab.encode
            else:
                raise ValueError(f"Unknown dataset class: {ds_class_name}")
        else:
            self._parser = parsers.make_parser(
                labels=self._cfg.labels,
                name='en',
                unk_id=-1,
                blank_id=-1,
                do_normalize=True,
                abbreviation_version="fastpitch",
                make_table=False,
            )
        return self._parser

    def parse(self, str_input: str, normalize=True) -> torch.tensor:
        if self.training:
            logging.warning("parse() is meant to be called in eval mode.")

        if normalize and self.text_normalizer_call is not None:
            str_input = self.text_normalizer_call(
                str_input, **self.text_normalizer_call_kwargs)

        if self.learn_alignment:
            eval_phon_mode = contextlib.nullcontext()
            if hasattr(self.vocab, "set_phone_prob"):
                eval_phon_mode = self.vocab.set_phone_prob(prob=1.0)

            # Disable mixed g2p representation if necessary
            with eval_phon_mode:
                tokens = self.parser(str_input)
        else:
            tokens = self.parser(str_input)

        x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
        return x

    @typecheck(
        input_types={
            "text":
            NeuralType(('B', 'T_text'), TokenIndex()),
            "durs":
            NeuralType(('B', 'T_text'), TokenDurationType()),
            "pitch":
            NeuralType(('B', 'T_audio'), RegressionValuesType()),
            "speaker":
            NeuralType(('B'), Index(), optional=True),
            "pace":
            NeuralType(optional=True),
            "spec":
            NeuralType(('B', 'D', 'T_spec'),
                       MelSpectrogramType(),
                       optional=True),
            "attn_prior":
            NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
            "mel_lens":
            NeuralType(('B'), LengthsType(), optional=True),
            "input_lens":
            NeuralType(('B'), LengthsType(), optional=True),
        })
    def forward(
        self,
        *,
        text,
        durs=None,
        pitch=None,
        speaker=None,
        pace=1.0,
        spec=None,
        attn_prior=None,
        mel_lens=None,
        input_lens=None,
    ):
        return self.fastpitch(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=pace,
            spec=spec,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=input_lens,
        )

    @typecheck(output_types={
        "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType())
    })
    def generate_spectrogram(self,
                             tokens: 'torch.tensor',
                             speaker: Optional[int] = None,
                             pace: float = 1.0) -> torch.tensor:
        if self.training:
            logging.warning(
                "generate_spectrogram() is meant to be called in eval mode.")
        if isinstance(speaker, int):
            speaker = torch.tensor([speaker]).to(self.device)
        spect, *_ = self(text=tokens,
                         durs=None,
                         pitch=None,
                         speaker=speaker,
                         pace=pace)
        return spect

    def training_step(self, batch, batch_idx):
        attn_prior, durs, speaker = None, None, None
        if self.learn_alignment:
            if self.ds_class_name == "TTSDataset":
                if SpeakerID in self._train_dl.dataset.sup_data_types_set:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
                else:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
            else:
                raise ValueError(
                    f"Unknown vocab class: {self.vocab.__class__.__name__}")
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

        mels, spec_len = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        mels_pred, _, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=spec_len,
            input_lens=text_lens,
        )
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        loss = mel_loss + dur_loss
        if self.learn_alignment:
            ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                             in_lens=text_lens,
                                             out_lens=spec_len)
            bin_loss_weight = min(
                self.current_epoch / self.bin_loss_warmup_epochs, 1.0) * 1.0
            bin_loss = self.bin_loss(
                hard_attention=attn_hard,
                soft_attention=attn_soft) * bin_loss_weight
            loss += ctc_loss + bin_loss

        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss += pitch_loss

        self.log("t_loss", loss)
        self.log("t_mel_loss", mel_loss)
        self.log("t_dur_loss", dur_loss)
        self.log("t_pitch_loss", pitch_loss)
        if self.learn_alignment:
            self.log("t_ctc_loss", ctc_loss)
            self.log("t_bin_loss", bin_loss)

        # Log images to tensorboard
        if self.log_train_images and isinstance(self.logger,
                                                TensorBoardLogger):
            self.log_train_images = False

            self.tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = mels_pred[0].data.cpu().float().numpy()
            self.tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            if self.learn_alignment:
                attn = attn_hard[0].data.cpu().float().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_attn",
                    plot_alignment_to_numpy(attn.T),
                    self.global_step,
                    dataformats="HWC",
                )
                soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_soft_attn",
                    plot_alignment_to_numpy(soft_attn.T),
                    self.global_step,
                    dataformats="HWC",
                )

        return loss

    def validation_step(self, batch, batch_idx):
        attn_prior, durs, speaker = None, None, None
        if self.learn_alignment:
            if self.ds_class_name == "TTSDataset":
                if SpeakerID in self._train_dl.dataset.sup_data_types_set:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
                else:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
            else:
                raise ValueError(
                    f"Unknown vocab class: {self.vocab.__class__.__name__}")
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

        mels, mel_lens = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        # Calculate val loss on ground truth durations to better align L2 loss in time
        mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=text_lens,
        )
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss = mel_loss + dur_loss + pitch_loss

        return {
            "val_loss": loss,
            "mel_loss": mel_loss,
            "dur_loss": dur_loss,
            "pitch_loss": pitch_loss,
            "mel_target": mels if batch_idx == 0 else None,
            "mel_pred": mels_pred if batch_idx == 0 else None,
        }

    def validation_epoch_end(self, outputs):
        collect = lambda key: torch.stack([x[key] for x in outputs]).mean()
        val_loss = collect("val_loss")
        mel_loss = collect("mel_loss")
        dur_loss = collect("dur_loss")
        pitch_loss = collect("pitch_loss")
        self.log("v_loss", val_loss)
        self.log("v_mel_loss", mel_loss)
        self.log("v_dur_loss", dur_loss)
        self.log("v_pitch_loss", pitch_loss)

        _, _, _, _, spec_target, spec_predict = outputs[0].values()

        if isinstance(self.logger, TensorBoardLogger):
            self.tb_logger.add_image(
                "val_mel_target",
                plot_spectrogram_to_numpy(
                    spec_target[0].data.cpu().float().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = spec_predict[0].data.cpu().float().numpy()
            self.tb_logger.add_image(
                "val_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            self.log_train_images = True

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        if cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
            phon_mode = contextlib.nullcontext()
            if hasattr(self.vocab, "set_phone_prob"):
                phon_mode = self.vocab.set_phone_prob(
                    prob=None if name ==
                    "val" else self.vocab.phoneme_probability)

            with phon_mode:
                dataset = instantiate(
                    cfg.dataset,
                    text_normalizer=self.normalizer,
                    text_normalizer_call_kwargs=self.
                    text_normalizer_call_kwargs,
                    text_tokenizer=self.vocab,
                )
        else:
            dataset = instantiate(cfg.dataset)

        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="val")

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_fastpitch",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/1.8.1/files/tts_en_fastpitch_align.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz and can be used to generate female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)

        return list_of_models

    # Methods for model exportability
    def _prepare_for_export(self, **kwargs):
        super()._prepare_for_export(**kwargs)

        # Define input_types and output_types as required by export()
        self._input_types = {
            "text": NeuralType(('B', 'T_text'), TokenIndex()),
            "pitch": NeuralType(('B', 'T_text'), RegressionValuesType()),
            "pace": NeuralType(('B', 'T_text'), optional=True),
            "volume": NeuralType(('B', 'T_text')),
            "speaker": NeuralType(('B'), Index()),
        }
        self._output_types = {
            "spect":
            NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
            "num_frames":
            NeuralType(('B'), TokenDurationType()),
            "durs_predicted":
            NeuralType(('B', 'T_text'), TokenDurationType()),
            "log_durs_predicted":
            NeuralType(('B', 'T_text'), TokenLogDurationType()),
            "pitch_predicted":
            NeuralType(('B', 'T_text'), RegressionValuesType()),
            "volume_aligned":
            NeuralType(('B', 'T_spec'), RegressionValuesType()),
        }

    def _export_teardown(self):
        self._input_types = self._output_types = None

    @property
    def disabled_deployment_input_names(self):
        """Implement this method to return a set of input names disabled for export"""
        disabled_inputs = set()
        if self.fastpitch.speaker_emb is None:
            disabled_inputs.add("speaker")
        return disabled_inputs

    @property
    def input_types(self):
        return self._input_types

    @property
    def output_types(self):
        return self._output_types

    def input_example(self, max_batch=1, max_dim=44):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        par = next(self.fastpitch.parameters())
        sz = (max_batch, max_dim)
        inp = torch.randint(0,
                            self.fastpitch.encoder.word_emb.num_embeddings,
                            sz,
                            device=par.device,
                            dtype=torch.int64)
        pitch = torch.randn(sz, device=par.device, dtype=torch.float32) * 0.5
        pace = torch.clamp(
            (torch.randn(sz, device=par.device, dtype=torch.float32) + 1) *
            0.1,
            min=0.01)
        volume = torch.clamp(
            (torch.randn(sz, device=par.device, dtype=torch.float32) + 1) *
            0.1,
            min=0.01)

        inputs = {'text': inp, 'pitch': pitch, 'pace': pace, 'volume': volume}

        if self.fastpitch.speaker_emb is not None:
            inputs['speaker'] = torch.randint(
                0,
                self.fastpitch.speaker_emb.num_embeddings, (max_batch, ),
                device=par.device,
                dtype=torch.int64)

        return (inputs, )

    def forward_for_export(self, text, pitch, pace, volume, speaker=None):
        return self.fastpitch.infer(text=text,
                                    pitch=pitch,
                                    pace=pace,
                                    volume=volume,
                                    speaker=speaker)
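
list_available_models above registers the tts_en_fastpitch checkpoint on NGC, so the typical inference path is to restore it and chain parse() into generate_spectrogram(). A minimal usage sketch, assuming the checkpoint downloads successfully (the input sentence is arbitrary):

import torch

model = FastPitchModel.from_pretrained("tts_en_fastpitch")  # restores the NGC checkpoint listed above
model.eval()

with torch.no_grad():
    tokens = model.parse("Hello world.")                # (1, T_text) token ids
    spect = model.generate_spectrogram(tokens=tokens)   # (1, n_mel_channels, T_spec) mel spectrogram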
Example 5
class FastPitchModel(SpectrogramGenerator):
    """FastPitch Model that is used to generate mel spectrograms from text"""
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)

        self.learn_alignment = False
        if "learn_alignment" in cfg:
            self.learn_alignment = cfg.learn_alignment
        self._parser = None
        self._tb_logger = None
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(FastPitchConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        self.bin_loss_warmup_epochs = 100
        self.aligner = None
        self.log_train_images = False
        self.mel_loss = MelLoss()
        loss_scale = 0.1 if self.learn_alignment else 1.0
        self.pitch_loss = PitchLoss(loss_scale=loss_scale)
        self.duration_loss = DurationLoss(loss_scale=loss_scale)
        input_fft_kwargs = {}
        if self.learn_alignment:
            self.aligner = instantiate(self._cfg.alignment_module)
            self.forward_sum_loss = ForwardSumLoss()
            self.bin_loss = BinLoss()
            self.vocab = AudioToCharWithDursF0Dataset.make_vocab(
                **self._cfg.train_ds.dataset.vocab)
            input_fft_kwargs["n_embed"] = len(self.vocab.labels)
            input_fft_kwargs["padding_idx"] = self.vocab.pad

        self.preprocessor = instantiate(self._cfg.preprocessor)

        input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
        output_fft = instantiate(self._cfg.output_fft)
        duration_predictor = instantiate(self._cfg.duration_predictor)
        pitch_predictor = instantiate(self._cfg.pitch_predictor)

        self.fastpitch = FastPitchModule(
            input_fft,
            output_fft,
            duration_predictor,
            pitch_predictor,
            self.aligner,
            cfg.n_speakers,
            cfg.symbols_embedding_dim,
            cfg.pitch_embedding_kernel_size,
            cfg.n_mel_channels,
        )

    @property
    def tb_logger(self):
        if self._tb_logger is None:
            if self.logger is None or self.logger.experiment is None:
                return None
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            self._tb_logger = tb_logger
        return self._tb_logger

    @property
    def parser(self):
        if self._parser is not None:
            return self._parser

        if self.learn_alignment:
            vocab = AudioToCharWithDursF0Dataset.make_vocab(
                **self._cfg.train_ds.dataset.vocab)
            self._parser = vocab.encode
        else:
            self._parser = parsers.make_parser(
                labels=self._cfg.labels,
                name='en',
                unk_id=-1,
                blank_id=-1,
                do_normalize=True,
                abbreviation_version="fastpitch",
                make_table=False,
            )
        return self._parser

    def parse(self, str_input: str) -> torch.tensor:
        if str_input[-1] not in [".", "!", "?"]:
            str_input = str_input + "."

        tokens = self.parser(str_input)

        x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
        return x

    @typecheck(
        input_types={
            "text": NeuralType(('B', 'T'), TokenIndex()),
            "durs": NeuralType(('B', 'T'), TokenDurationType()),
            "pitch": NeuralType(('B', 'T'), RegressionValuesType()),
            "speaker": NeuralType(('B'), Index()),
            "pace": NeuralType(optional=True),
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True),
            "attn_prior": NeuralType(('B', 'T', 'T'), ProbsType(), optional=True),
            "mel_lens": NeuralType(('B'), LengthsType(), optional=True),
            "input_lens": NeuralType(('B'), LengthsType(), optional=True),
        })
    def forward(
        self,
        *,
        text,
        durs=None,
        pitch=None,
        speaker=0,
        pace=1.0,
        spec=None,
        attn_prior=None,
        mel_lens=None,
        input_lens=None,
    ):
        return self.fastpitch(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=pace,
            spec=spec,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=input_lens,
        )

    @typecheck(output_types={
        "spect": NeuralType(('B', 'C', 'T'), MelSpectrogramType())
    })
    def generate_spectrogram(self,
                             tokens: 'torch.Tensor',
                             speaker: int = 0,
                             pace: float = 1.0) -> torch.Tensor:
        self.eval()
        spect, *_ = self(text=tokens,
                         durs=None,
                         pitch=None,
                         speaker=speaker,
                         pace=pace)
        return spect.transpose(1, 2)

    def training_step(self, batch, batch_idx):
        attn_prior, durs, speakers = None, None, None
        if self.learn_alignment:
            audio, audio_lens, text, text_lens, attn_prior, pitch = batch
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speakers = batch
        mels, spec_len = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        mels_pred, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speakers,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=spec_len,
            input_lens=text_lens,
        )
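        # If ground-truth durations are not provided (alignment-learning mode),
        # fall back to the durations extracted from the hard attention.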
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        loss = mel_loss + dur_loss
        if self.learn_alignment:
            ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                             in_lens=text_lens,
                                             out_lens=spec_len)
            # Ramp the binarization loss weight linearly over the warmup epochs.
            bin_loss_weight = min(
                self.current_epoch / self.bin_loss_warmup_epochs, 1.0)
            bin_loss = self.bin_loss(
                hard_attention=attn_hard,
                soft_attention=attn_soft) * bin_loss_weight
            loss += ctc_loss + bin_loss

        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss += pitch_loss

        self.log("t_loss", loss)
        self.log("t_mel_loss", mel_loss)
        self.log("t_dur_loss", dur_loss)
        self.log("t_pitch_loss", pitch_loss)
        if self.learn_alignment:
            self.log("t_ctc_loss", ctc_loss)
            self.log("t_bin_loss", bin_loss)

        # Log images to tensorboard
        if self.log_train_images:
            self.log_train_images = False

            self.tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(mels[0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = mels_pred[0].data.cpu().numpy().T
            self.tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            if self.learn_alignment:
                attn = attn_hard[0].data.cpu().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_attn",
                    plot_alignment_to_numpy(attn.T),
                    self.global_step,
                    dataformats="HWC",
                )
                soft_attn = attn_soft[0].data.cpu().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_soft_attn",
                    plot_alignment_to_numpy(soft_attn.T),
                    self.global_step,
                    dataformats="HWC",
                )

        return loss

    def validation_step(self, batch, batch_idx):
        attn_prior, durs, speakers = None, None, None
        if self.learn_alignment:
            audio, audio_lens, text, text_lens, attn_prior, pitch = batch
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speakers = batch
        mels, mel_lens = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        # Calculate val loss on ground truth durations to better align L2 loss in time
        mels_pred, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speakers,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=text_lens,
        )
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss = mel_loss + dur_loss + pitch_loss

        return {
            "val_loss": loss,
            "mel_loss": mel_loss,
            "dur_loss": dur_loss,
            "pitch_loss": pitch_loss,
            "mel_target": mels if batch_idx == 0 else None,
            "mel_pred": mels_pred if batch_idx == 0 else None,
        }

    def validation_epoch_end(self, outputs):
        collect = lambda key: torch.stack([x[key] for x in outputs]).mean()
        val_loss = collect("val_loss")
        mel_loss = collect("mel_loss")
        dur_loss = collect("dur_loss")
        pitch_loss = collect("pitch_loss")
        self.log("v_loss", val_loss)
        self.log("v_mel_loss", mel_loss)
        self.log("v_dur_loss", dur_loss)
        self.log("v_pitch_loss", pitch_loss)

        spec_target = outputs[0]["mel_target"]
        spec_predict = outputs[0]["mel_pred"]
        self.tb_logger.add_image(
            "val_mel_target",
            plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
            self.global_step,
            dataformats="HWC",
        )
        spec_predict = spec_predict[0].data.cpu().numpy()
        self.tb_logger.add_image(
            "val_mel_predicted",
            plot_spectrogram_to_numpy(spec_predict.T),
            self.global_step,
            dataformats="HWC",
        )
        self.log_train_images = True

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        kwargs_dict = {}
        if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
            kwargs_dict["parser"] = self.parser
        dataset = instantiate(cfg.dataset, **kwargs_dict)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="val")

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_fastpitch",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/1.0.0/files/tts_en_fastpitch.nemo",
            description="This model is trained on LJSpeech sampled at 22050 Hz and can be used to generate a female English voice with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)

        return list_of_models
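
A minimal inference sketch for the class above (illustrative only, not part of the original example): it assumes NeMo is installed with network access to NGC, restores the "tts_en_fastpitch" checkpoint via the standard from_pretrained helper, and then uses the parse and generate_spectrogram methods shown here. A separate vocoder is still needed to turn the spectrogram into audio.

import torch

from nemo.collections.tts.models import FastPitchModel

# Restore the pretrained checkpoint listed by list_available_models (downloads from NGC).
model = FastPitchModel.from_pretrained(model_name="tts_en_fastpitch")
model.eval()

with torch.no_grad():
    # parse() appends terminal punctuation and returns token ids of shape (1, T_text).
    tokens = model.parse("FastPitch generates mel spectrograms from text.")
    spect = model.generate_spectrogram(tokens=tokens, speaker=0, pace=1.0)

print(spect.shape)  # expected: (batch, n_mel_channels, n_frames)
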
Esempio n. 6
0
 def input_types(self):
     return {
         "hard_attention": NeuralType(('B', 'S', 'T', 'D'), ProbsType()),
         "soft_attention": NeuralType(('B', 'S', 'T', 'D'), ProbsType()),
     }