def output_types(self):
    """Output port types for NeMo's type-checking system.

    Fix: single-axis specs previously used ``('B')``, which is just the
    string ``'B'`` (parenthesized expression, not a tuple). It only behaved
    correctly because iterating a one-character string yields that same
    character; a multi-character axis name written that way would be split
    into letters. ``('B',)`` is the explicit one-element tuple.
    """
    return {
        "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
        "num_frames": NeuralType(('B',), TokenDurationType()),
        "durs_predicted": NeuralType(('B', 'T_text'), TokenDurationType()),
        "log_durs_predicted": NeuralType(('B', 'T_text'), TokenLogDurationType()),
        "pitch_predicted": NeuralType(('B', 'T_text'), RegressionValuesType()),
        "attn_soft": NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
        "attn_logprob": NeuralType(('B', 'S', 'T_spec', 'T_text'), LogprobsType()),
        "attn_hard": NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
        "attn_hard_dur": NeuralType(('B', 'T_text'), TokenDurationType()),
        "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
    }
def input_types(self):
    """Input port types for NeMo's type-checking system.

    Fix: single-axis specs previously used ``('B')``, which is the string
    ``'B'`` rather than a tuple (it only worked by accident of one-char
    string iteration). ``('B',)`` is the explicit one-element tuple.
    """
    return {
        "text": NeuralType(('B', 'T_text'), TokenIndex()),
        "durs": NeuralType(('B', 'T_text'), TokenDurationType()),
        "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
        "speaker": NeuralType(('B',), Index(), optional=True),
        # Scalar (axis-less) optional input.
        "pace": NeuralType(optional=True),
        "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
        "attn_prior": NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
        "mel_lens": NeuralType(('B',), LengthsType(), optional=True),
        "input_lens": NeuralType(('B',), LengthsType(), optional=True),
    }
class MixerTTSModel(SpectrogramGenerator, Exportable):
    """Mixer-TTS spectrogram-generation pipeline.

    Combines a token encoder, an alignment module (soft/hard attention between
    text tokens and mel frames), duration and pitch predictors, and a decoder
    that projects to mel channels. Optionally conditions the encoder output on
    frozen ALBERT LM embeddings via a self-attention module
    (``cond_on_lm_embeddings``).
    """

    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        super().__init__(cfg=cfg, trainer=trainer)
        cfg = self._cfg

        # Optional text normalizer (e.g. number/abbreviation expansion) taken
        # from the training-dataset config.
        if "text_normalizer" in cfg.train_ds.dataset:
            self.normalizer = instantiate(cfg.train_ds.dataset.text_normalizer)
            self.text_normalizer_call = self.normalizer.normalize
            self.text_normalizer_call_args = {}
            if cfg.train_ds.dataset.get("text_normalizer_call_args", None) is not None:
                self.text_normalizer_call_args = cfg.train_ds.dataset.text_normalizer_call_args

        self.tokenizer = instantiate(cfg.train_ds.dataset.text_tokenizer)
        num_tokens = len(self.tokenizer.tokens)
        self.tokenizer_pad = self.tokenizer.pad
        self.tokenizer_unk = self.tokenizer.oov

        # Per-term loss weights.
        self.pitch_loss_scale = cfg.pitch_loss_scale
        self.durs_loss_scale = cfg.durs_loss_scale
        self.mel_loss_scale = cfg.mel_loss_scale

        self.aligner = instantiate(cfg.alignment_module)
        self.forward_sum_loss = ForwardSumLoss()
        self.bin_loss = BinLoss()
        # Binarization loss is disabled at start and switched on by
        # on_train_epoch_start() once bin_loss_start_ratio of max_epochs passed.
        self.add_bin_loss = False
        self.bin_loss_scale = 0.0
        self.bin_loss_start_ratio = cfg.bin_loss_start_ratio
        self.bin_loss_warmup_epochs = cfg.bin_loss_warmup_epochs

        self.cond_on_lm_embeddings = cfg.get("cond_on_lm_embeddings", False)

        if self.cond_on_lm_embeddings:
            self.lm_padding_value = (
                self._train_dl.dataset.lm_padding_value
                if self._train_dl is not None
                else self._get_lm_padding_value(cfg.train_ds.dataset.lm_model)
            )
            self.lm_embeddings = self._get_lm_embeddings(cfg.train_ds.dataset.lm_model)
            # LM embeddings are used as a frozen feature extractor.
            self.lm_embeddings.weight.requires_grad = False

            self.self_attention_module = instantiate(
                cfg.self_attention_module, n_lm_tokens_channels=self.lm_embeddings.weight.shape[1]
            )

        self.encoder = instantiate(cfg.encoder, num_tokens=num_tokens, padding_idx=self.tokenizer_pad)
        # Shared token embedding, reused by run_aligner().
        self.symbol_emb = self.encoder.to_embed

        self.duration_predictor = instantiate(cfg.duration_predictor)

        # Dataset-level pitch statistics used to normalize pitch targets.
        self.pitch_mean, self.pitch_std = float(cfg.pitch_mean), float(cfg.pitch_std)
        self.pitch_predictor = instantiate(cfg.pitch_predictor)
        self.pitch_emb = instantiate(cfg.pitch_emb)

        self.preprocessor = instantiate(cfg.preprocessor)

        self.decoder = instantiate(cfg.decoder)
        self.proj = nn.Linear(self.decoder.d_model, cfg.n_mel_channels)

    def _get_lm_model_tokenizer(self, lm_model="albert"):
        """Return (and cache) the LM tokenizer used for LM conditioning.

        NOTE(review): when the train dataloader supplies a tokenizer, the
        assignment below is immediately overwritten by the albert branch —
        the dataset-provided tokenizer is never actually used; confirm intent.
        """
        if getattr(self, "_lm_model_tokenizer", None) is not None:
            return self._lm_model_tokenizer

        if self._train_dl is not None and self._train_dl.dataset is not None:
            self._lm_model_tokenizer = self._train_dl.dataset.lm_model_tokenizer

        if lm_model == "albert":
            self._lm_model_tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

        return self._lm_model_tokenizer

    def _get_lm_embeddings(self, lm_model="albert"):
        """Return the pretrained LM word-embedding layer (albert only)."""
        if lm_model == "albert":
            return transformers.AlbertModel.from_pretrained('albert-base-v2').embeddings.word_embeddings
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

    def _get_lm_padding_value(self, lm_model="albert"):
        """Return the LM tokenizer's pad-token id (albert only)."""
        if lm_model == "albert":
            return transformers.AlbertTokenizer.from_pretrained('albert-base-v2')._convert_token_to_id('<pad>')
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

    def _metrics(
        self,
        true_durs,
        true_text_len,
        pred_durs,
        true_pitch,
        pred_pitch,
        true_spect=None,
        pred_spect=None,
        true_spect_len=None,
        attn_logprob=None,
        attn_soft=None,
        attn_hard=None,
        attn_hard_dur=None,
    ):
        """Compute the total training loss and per-term metrics.

        Returns a 9-tuple:
        (loss, durs_loss, acc, acc_dist_1, acc_dist_3, pitch_loss, mel_loss,
        ctc_loss, bin_loss). ``bin_loss`` is None until hard-attention loss is
        enabled. NOTE(review): despite the ``=None`` defaults, the body
        dereferences pred_spect/true_spect_len etc. unconditionally — callers
        must pass them all.
        """
        text_mask = get_mask_from_lengths(true_text_len)
        mel_mask = get_mask_from_lengths(true_spect_len)
        loss = 0.0

        # dur loss and metrics
        # Durations are predicted in log(1 + d) space; compare there, masked
        # to real tokens only.
        durs_loss = F.mse_loss(pred_durs, (true_durs + 1).float().log(), reduction='none')
        durs_loss = durs_loss * text_mask.float()
        durs_loss = durs_loss.sum() / text_mask.sum()

        durs_pred = pred_durs.exp() - 1
        durs_pred = torch.clamp_min(durs_pred, min=0)
        durs_pred = durs_pred.round().long()

        # Duration accuracy: exact match, and within 1 / 3 frames (percent).
        acc = ((true_durs == durs_pred) * text_mask).sum().float() / text_mask.sum() * 100
        acc_dist_1 = (((true_durs - durs_pred).abs() <= 1) * text_mask).sum().float() / text_mask.sum() * 100
        acc_dist_3 = (((true_durs - durs_pred).abs() <= 3) * text_mask).sum().float() / text_mask.sum() * 100

        # pred_spect arrives as (B, T_spec, D); transpose to match true_spect.
        pred_spect = pred_spect.transpose(1, 2)

        # mel loss
        mel_loss = F.mse_loss(pred_spect, true_spect, reduction='none').mean(dim=-2)
        mel_loss = mel_loss * mel_mask.float()
        mel_loss = mel_loss.sum() / mel_mask.sum()

        loss = loss + self.durs_loss_scale * durs_loss + self.mel_loss_scale * mel_loss

        # aligner loss
        bin_loss, ctc_loss = None, None
        ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob, in_lens=true_text_len, out_lens=true_spect_len)
        loss = loss + ctc_loss
        if self.add_bin_loss:
            # Encourages soft attention to agree with the binarized alignment.
            bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft)
            loss = loss + self.bin_loss_scale * bin_loss

        # Pitch target is frame-level; average it per token using hard durations.
        true_avg_pitch = average_pitch(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)

        # pitch loss
        pitch_loss = F.mse_loss(pred_pitch, true_avg_pitch, reduction='none')  # noqa
        pitch_loss = (pitch_loss * text_mask).sum() / text_mask.sum()

        loss = loss + self.pitch_loss_scale * pitch_loss

        return loss, durs_loss, acc, acc_dist_1, acc_dist_3, pitch_loss, mel_loss, ctc_loss, bin_loss

    @torch.jit.unused
    def run_aligner(self, text, text_len, text_mask, spect, spect_len, attn_prior):
        """Run the alignment module and binarize its soft attention.

        Returns (attn_soft, attn_logprob, attn_hard, attn_hard_dur), where
        attn_hard_dur sums the hard attention over mel frames to get per-token
        durations. The assert checks durations account for every mel frame.
        """
        text_emb = self.symbol_emb(text)
        attn_soft, attn_logprob = self.aligner(
            spect, text_emb.permute(0, 2, 1), mask=text_mask == 0, attn_prior=attn_prior,
        )
        attn_hard = binarize_attention_parallel(attn_soft, text_len, spect_len)
        attn_hard_dur = attn_hard.sum(2)[:, 0, :]
        assert torch.all(torch.eq(attn_hard_dur.sum(dim=1), spect_len))
        return attn_soft, attn_logprob, attn_hard, attn_hard_dur

    @typecheck(
        input_types={
            "text": NeuralType(('B', 'T_text'), TokenIndex()),
            "text_len": NeuralType(('B',), LengthsType()),
            "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType(), optional=True),
            "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
            "spect_len": NeuralType(('B',), LengthsType(), optional=True),
            "attn_prior": NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
            "lm_tokens": NeuralType(('B', 'T_lm_tokens'), TokenIndex(), optional=True),
        },
        output_types={
            "pred_spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
            "durs_predicted": NeuralType(('B', 'T_text'), TokenDurationType()),
            "log_durs_predicted": NeuralType(('B', 'T_text'), TokenLogDurationType()),
            "pitch_predicted": NeuralType(('B', 'T_text'), RegressionValuesType()),
            "attn_soft": NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
            "attn_logprob": NeuralType(('B', 'S', 'T_spec', 'T_text'), LogprobsType()),
            "attn_hard": NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
            "attn_hard_dur": NeuralType(('B', 'T_text'), TokenDurationType()),
        },
    )
    def forward(self, text, text_len, pitch=None, spect=None, spect_len=None, attn_prior=None, lm_tokens=None):
        """Teacher-forced forward pass used by training/validation steps.

        Always length-regulates with aligner (ground-truth) durations, so
        ``spect`` is effectively required here; free-running synthesis goes
        through infer() instead.
        """
        if self.training:
            # Training always needs pitch targets for the pitch embedding path.
            assert pitch is not None

        text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # aligner
        attn_soft, attn_logprob, attn_hard, attn_hard_dur = None, None, None, None
        if spect is not None:
            attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
                text, text_len, text_mask, spect, spect_len, attn_prior
            )

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
            )

        # duration predictor
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        # Invert the log(1 + d) parameterization; clamp to non-negative.
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # pitch predictor
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)

        # avg pitch, add pitch_emb
        if not self.training:
            if pitch is not None:
                pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
                pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
            else:
                pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
        else:
            # Training: always embed the (token-averaged) ground-truth pitch.
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))

        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        # regulate length
        len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)

        dec_out, dec_lens = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return (
            pred_spect,
            durs_predicted,
            log_durs_predicted,
            pitch_predicted,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        )

    def infer(
        self,
        text,
        text_len=None,
        text_mask=None,
        spect=None,
        spect_len=None,
        attn_prior=None,
        use_gt_durs=False,
        lm_tokens=None,
        pitch=None,
    ):
        """Free-running synthesis: returns a predicted spectrogram.

        With ``use_gt_durs=True`` (requires ``spect``) the aligner provides
        ground-truth durations; otherwise predicted durations drive length
        regulation.
        """
        if text_mask is None:
            text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # aligner
        attn_hard_dur = None
        if use_gt_durs:
            attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
                text, text_len, text_mask, spect, spect_len, attn_prior
            )

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
            )

        # duration predictor
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # avg pitch, pitch predictor
        if use_gt_durs and pitch is not None:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

        # add pitch emb
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        if use_gt_durs:
            if attn_hard_dur is not None:
                len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)
            else:
                raise NotImplementedError
        else:
            len_regulated_enc_out, dec_lens = regulate_len(durs_predicted, enc_out)

        dec_out, _ = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return pred_spect

    def on_train_epoch_start(self):
        """Enable and ramp up the binarization loss on a fixed epoch schedule."""
        bin_loss_start_epoch = np.ceil(self.bin_loss_start_ratio * self._trainer.max_epochs)

        # Add bin loss when current_epoch >= bin_start_epoch
        if not self.add_bin_loss and self.current_epoch >= bin_loss_start_epoch:
            logging.info(f"Using hard attentions after epoch: {self.current_epoch}")
            self.add_bin_loss = True

        if self.add_bin_loss:
            # Linear warmup of the bin-loss weight, capped at 1.0.
            self.bin_loss_scale = min((self.current_epoch - bin_loss_start_epoch) / self.bin_loss_warmup_epochs, 1.0)

    def training_step(self, batch, batch_idx):
        """One training step: preprocess audio, forward, compute losses, log."""
        attn_prior, lm_tokens = None, None
        if self.cond_on_lm_embeddings:
            audio, audio_len, text, text_len, attn_prior, pitch, _, lm_tokens = batch
        else:
            audio, audio_len, text, text_len, attn_prior, pitch, _ = batch

        spect, spect_len = self.preprocessor(input_signal=audio, length=audio_len)

        # pitch normalization
        # Zero frames mark unvoiced regions; keep them at exactly 0 after
        # mean/std normalization.
        zero_pitch_idx = pitch == 0
        pitch = (pitch - self.pitch_mean) / self.pitch_std
        pitch[zero_pitch_idx] = 0.0

        (
            pred_spect,
            _,
            pred_log_durs,
            pred_pitch,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        ) = self(
            text=text,
            text_len=text_len,
            pitch=pitch,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        (
            loss,
            durs_loss,
            acc,
            acc_dist_1,
            acc_dist_3,
            pitch_loss,
            mel_loss,
            ctc_loss,
            bin_loss,
        ) = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        # None losses (e.g. bin loss before warmup) are logged as 1.0 placeholders.
        train_log = {
            'train_loss': loss,
            'train_durs_loss': durs_loss,
            'train_pitch_loss': torch.tensor(1.0).to(durs_loss.device) if pitch_loss is None else pitch_loss,
            'train_mel_loss': mel_loss,
            'train_durs_acc': acc,
            'train_durs_acc_dist_3': acc_dist_3,
            'train_ctc_loss': torch.tensor(1.0).to(durs_loss.device) if ctc_loss is None else ctc_loss,
            'train_bin_loss': torch.tensor(1.0).to(durs_loss.device) if bin_loss is None else bin_loss,
        }
        return {'loss': loss, 'progress_bar': train_log, 'log': train_log}

    def validation_step(self, batch, batch_idx):
        """Validation: losses with GT features, plus a mel loss with predicted
        features, and periodic wandb image logging of mels and pitch curves."""
        attn_prior, lm_tokens = None, None
        if self.cond_on_lm_embeddings:
            audio, audio_len, text, text_len, attn_prior, pitch, _, lm_tokens = batch
        else:
            audio, audio_len, text, text_len, attn_prior, pitch, _ = batch

        spect, spect_len = self.preprocessor(input_signal=audio, length=audio_len)

        # pitch normalization
        zero_pitch_idx = pitch == 0
        pitch = (pitch - self.pitch_mean) / self.pitch_std
        pitch[zero_pitch_idx] = 0.0

        (
            pred_spect,
            _,
            pred_log_durs,
            pred_pitch,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        ) = self(
            text=text,
            text_len=text_len,
            pitch=pitch,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        (
            loss,
            durs_loss,
            acc,
            acc_dist_1,
            acc_dist_3,
            pitch_loss,
            mel_loss,
            ctc_loss,
            bin_loss,
        ) = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        # without ground truth internal features except for durations
        pred_spect, _, pred_log_durs, pred_pitch, attn_soft, attn_logprob, attn_hard, attn_hard_dur = self(
            text=text,
            text_len=text_len,
            pitch=None,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        # Only the mel-loss term (index 6 of 9) is kept from this second pass.
        *_, with_pred_features_mel_loss, _, _ = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        val_log = {
            'val_loss': loss,
            'val_durs_loss': durs_loss,
            'val_pitch_loss': torch.tensor(1.0).to(durs_loss.device) if pitch_loss is None else pitch_loss,
            'val_mel_loss': mel_loss,
            'val_with_pred_features_mel_loss': with_pred_features_mel_loss,
            'val_durs_acc': acc,
            'val_durs_acc_dist_3': acc_dist_3,
            'val_ctc_loss': torch.tensor(1.0).to(durs_loss.device) if ctc_loss is None else ctc_loss,
            'val_bin_loss': torch.tensor(1.0).to(durs_loss.device) if bin_loss is None else bin_loss,
        }
        self.log_dict(val_log, prog_bar=False, on_epoch=True, logger=True, sync_dist=True)

        # Log up to 3 GT/pred mel and pitch images every 5th epoch, wandb only.
        if batch_idx == 0 and self.current_epoch % 5 == 0 and isinstance(self.logger, WandbLogger):
            specs = []
            pitches = []
            for i in range(min(3, spect.shape[0])):
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(spect[i, :, :spect_len[i]].data.cpu().numpy()),
                        caption=f"gt mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(
                            pred_spect.transpose(1, 2)[i, :, :spect_len[i]].data.cpu().numpy()
                        ),
                        caption=f"pred mel {i}",
                    ),
                ]

                pitches += [
                    wandb.Image(
                        plot_pitch_to_numpy(
                            average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
                            [i, :text_len[i]].data.cpu().numpy(),
                            ylim_range=[-2.5, 2.5],
                        ),
                        caption=f"gt pitch {i}",
                    ),
                ]

                pitches += [
                    wandb.Image(
                        plot_pitch_to_numpy(pred_pitch[i, :text_len[i]].data.cpu().numpy(), ylim_range=[-2.5, 2.5]),
                        caption=f"pred pitch {i}",
                    ),
                ]

            self.logger.experiment.log({"specs": specs, "pitches": pitches})

    @typecheck(
        input_types={
            "tokens": NeuralType(('B', 'T_text'), TokenIndex(), optional=True),
            "tokens_len": NeuralType(('B'), LengthsType(), optional=True),
            "lm_tokens": NeuralType(('B', 'T_lm_tokens'), TokenIndex(), optional=True),
            "raw_texts": [NeuralType(optional=True)],
            "lm_model": NeuralType(optional=True),
        },
        output_types={
            "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
        },
    )
    def generate_spectrogram(
        self,
        tokens: Optional[torch.Tensor] = None,
        tokens_len: Optional[torch.Tensor] = None,
        lm_tokens: Optional[torch.Tensor] = None,
        raw_texts: Optional[List[str]] = None,
        lm_model: str = "albert",
    ):
        """Generate a (B, D, T_spec) mel spectrogram from tokens or raw texts.

        Either ``tokens`` or ``raw_texts`` must be given; when LM conditioning
        is enabled and ``lm_tokens`` is absent, they are built from raw_texts
        with the LM tokenizer.
        """
        if tokens is not None:
            if tokens_len is None:
                # it is assumed that padding is consecutive and only at the end
                tokens_len = (tokens != self.tokenizer.pad).sum(dim=-1)
        else:
            if raw_texts is None:
                logging.error("raw_texts must be specified if tokens is None")

            t_seqs = [self.tokenizer(t) for t in raw_texts]
            tokens = torch.nn.utils.rnn.pad_sequence(
                sequences=[torch.tensor(t, dtype=torch.long, device=self.device) for t in t_seqs],
                batch_first=True,
                padding_value=self.tokenizer.pad,
            )
            tokens_len = torch.tensor([len(t) for t in t_seqs], dtype=torch.long, device=tokens.device)

        if self.cond_on_lm_embeddings and lm_tokens is None:
            if raw_texts is None:
                logging.error("raw_texts must be specified if lm_tokens is None")

            lm_model_tokenizer = self._get_lm_model_tokenizer(lm_model)
            lm_padding_value = lm_model_tokenizer._convert_token_to_id('<pad>')
            # '▁' is SentencePiece's word-boundary marker, used as the
            # "space" token when padding with spaces.
            lm_space_value = lm_model_tokenizer._convert_token_to_id('▁')

            assert isinstance(self.tokenizer, EnglishCharsTokenizer) or isinstance(
                self.tokenizer, EnglishPhonemesTokenizer)

            preprocess_texts_as_tts_input = [self.tokenizer.text_preprocessing_func(t) for t in raw_texts]
            lm_tokens_as_ids_list = [
                lm_model_tokenizer.encode(t, add_special_tokens=False) for t in preprocess_texts_as_tts_input
            ]

            if self.tokenizer.pad_with_space:
                lm_tokens_as_ids_list = [[lm_space_value] + t + [lm_space_value] for t in lm_tokens_as_ids_list]

            lm_tokens = torch.full(
                (len(lm_tokens_as_ids_list), max([len(t) for t in lm_tokens_as_ids_list])),
                fill_value=lm_padding_value,
                device=tokens.device,
            )
            for i, lm_tokens_i in enumerate(lm_tokens_as_ids_list):
                lm_tokens[i, :len(lm_tokens_i)] = torch.tensor(lm_tokens_i, device=tokens.device)

        pred_spect = self.infer(tokens, tokens_len, lm_tokens=lm_tokens).transpose(1, 2)
        return pred_spect

    def parse(self, text: str, normalize=True) -> torch.Tensor:
        """Normalize (optionally) and tokenize text to a (1, T_text) LongTensor."""
        if normalize and getattr(self, "text_normalizer_call", None) is not None:
            text = self.text_normalizer_call(text, **self.text_normalizer_call_args)
        return torch.tensor(self.tokenizer.encode(text)).long().unsqueeze(0).to(self.device)

    @staticmethod
    def _loader(cfg):
        """Build a DataLoader from cfg, or None when manifest_filepath is missing."""
        try:
            # Accessing the field raises if the mandatory value was not set.
            _ = cfg.dataset.manifest_filepath
        except omegaconf.errors.MissingMandatoryValue:
            logging.warning("manifest_filepath was skipped. No dataset for this model.")
            return None

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(  # noqa
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            **cfg.dataloader_params,
        )

    def setup_training_data(self, cfg):
        self._train_dl = self._loader(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self._loader(cfg)

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls):
        """Empty."""
        pass

    @property
    def input_types(self):
        # Export-time input ports (see forward_for_export).
        return {
            "text": NeuralType(('B', 'T_text'), TokenIndex()),
            "lm_tokens": NeuralType(('B', 'T_lm_tokens'), TokenIndex(), optional=True),
        }

    @property
    def output_types(self):
        # Export-time output port.
        return {
            "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
        }

    def forward_for_export(self, text, lm_tokens=None):
        """Export entry point: derive the mask from pad tokens and run infer()."""
        text_mask = (text != self.tokenizer_pad).unsqueeze(2)
        spect = self.infer(text=text, text_mask=text_mask, lm_tokens=lm_tokens).transpose(1, 2)
        return spect.to(torch.float)
class FastPitchModel(SpectrogramGenerator, Exportable): """FastPitch model (https://arxiv.org/abs/2006.06873) that is used to generate mel spectrogram from text.""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Convert to Hydra 1.0 compatible DictConfig cfg = model_utils.convert_model_config_to_dict_config(cfg) cfg = model_utils.maybe_update_config_version(cfg) # Setup normalizer self.normalizer = None self.text_normalizer_call = None self.text_normalizer_call_kwargs = {} self._setup_normalizer(cfg) self.learn_alignment = cfg.get("learn_alignment", False) # Setup vocabulary (=tokenizer) and input_fft_kwargs (supported only with self.learn_alignment=True) input_fft_kwargs = {} if self.learn_alignment: self.vocab = None self.ds_class_name = cfg.train_ds.dataset._target_.split(".")[-1] if self.ds_class_name == "TTSDataset": self._setup_tokenizer(cfg) assert self.vocab is not None input_fft_kwargs["n_embed"] = len(self.vocab.tokens) input_fft_kwargs["padding_idx"] = self.vocab.pad elif self.ds_class_name == "AudioToCharWithPriorAndPitchDataset": logging.warning( "AudioToCharWithPriorAndPitchDataset class has been deprecated. No support for" " training or finetuning. 
Only inference is supported.") tokenizer_conf = self._get_default_text_tokenizer_conf() self._setup_tokenizer(tokenizer_conf) assert self.vocab is not None input_fft_kwargs["n_embed"] = len(self.vocab.tokens) input_fft_kwargs["padding_idx"] = self.vocab.pad else: raise ValueError( f"Unknown dataset class: {self.ds_class_name}") self._parser = None self._tb_logger = None super().__init__(cfg=cfg, trainer=trainer) self.bin_loss_warmup_epochs = cfg.get("bin_loss_warmup_epochs", 100) self.log_train_images = False loss_scale = 0.1 if self.learn_alignment else 1.0 dur_loss_scale = loss_scale pitch_loss_scale = loss_scale if "dur_loss_scale" in cfg: dur_loss_scale = cfg.dur_loss_scale if "pitch_loss_scale" in cfg: pitch_loss_scale = cfg.pitch_loss_scale self.mel_loss = MelLoss() self.pitch_loss = PitchLoss(loss_scale=pitch_loss_scale) self.duration_loss = DurationLoss(loss_scale=dur_loss_scale) self.aligner = None if self.learn_alignment: self.aligner = instantiate(self._cfg.alignment_module) self.forward_sum_loss = ForwardSumLoss() self.bin_loss = BinLoss() self.preprocessor = instantiate(self._cfg.preprocessor) input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs) output_fft = instantiate(self._cfg.output_fft) duration_predictor = instantiate(self._cfg.duration_predictor) pitch_predictor = instantiate(self._cfg.pitch_predictor) self.fastpitch = FastPitchModule( input_fft, output_fft, duration_predictor, pitch_predictor, self.aligner, cfg.n_speakers, cfg.symbols_embedding_dim, cfg.pitch_embedding_kernel_size, cfg.n_mel_channels, ) self._input_types = self._output_types = None def _get_default_text_tokenizer_conf(self): text_tokenizer: TextTokenizerConfig = TextTokenizerConfig() return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer)) def _setup_normalizer(self, cfg): if "text_normalizer" in cfg: normalizer_kwargs = {} if "whitelist" in cfg.text_normalizer: normalizer_kwargs["whitelist"] = self.register_artifact( 'text_normalizer.whitelist', 
cfg.text_normalizer.whitelist) self.normalizer = instantiate(cfg.text_normalizer, **normalizer_kwargs) self.text_normalizer_call = self.normalizer.normalize if "text_normalizer_call_kwargs" in cfg: self.text_normalizer_call_kwargs = cfg.text_normalizer_call_kwargs def _setup_tokenizer(self, cfg): text_tokenizer_kwargs = {} if "g2p" in cfg.text_tokenizer: g2p_kwargs = {} if "phoneme_dict" in cfg.text_tokenizer.g2p: g2p_kwargs["phoneme_dict"] = self.register_artifact( 'text_tokenizer.g2p.phoneme_dict', cfg.text_tokenizer.g2p.phoneme_dict, ) if "heteronyms" in cfg.text_tokenizer.g2p: g2p_kwargs["heteronyms"] = self.register_artifact( 'text_tokenizer.g2p.heteronyms', cfg.text_tokenizer.g2p.heteronyms, ) text_tokenizer_kwargs["g2p"] = instantiate(cfg.text_tokenizer.g2p, **g2p_kwargs) self.vocab = instantiate(cfg.text_tokenizer, **text_tokenizer_kwargs) @property def tb_logger(self): if self._tb_logger is None: if self.logger is None and self.logger.experiment is None: return None tb_logger = self.logger.experiment if isinstance(self.logger, LoggerCollection): for logger in self.logger: if isinstance(logger, TensorBoardLogger): tb_logger = logger.experiment break self._tb_logger = tb_logger return self._tb_logger @property def parser(self): if self._parser is not None: return self._parser if self.learn_alignment: ds_class_name = self._cfg.train_ds.dataset._target_.split(".")[-1] if ds_class_name == "TTSDataset": self._parser = self.vocab.encode elif ds_class_name == "AudioToCharWithPriorAndPitchDataset": if self.vocab is None: tokenizer_conf = self._get_default_text_tokenizer_conf() self._setup_tokenizer(tokenizer_conf) self._parser = self.vocab.encode else: raise ValueError(f"Unknown dataset class: {ds_class_name}") else: self._parser = parsers.make_parser( labels=self._cfg.labels, name='en', unk_id=-1, blank_id=-1, do_normalize=True, abbreviation_version="fastpitch", make_table=False, ) return self._parser def parse(self, str_input: str, normalize=True) -> 
torch.tensor: if self.training: logging.warning("parse() is meant to be called in eval mode.") if normalize and self.text_normalizer_call is not None: str_input = self.text_normalizer_call( str_input, **self.text_normalizer_call_kwargs) if self.learn_alignment: eval_phon_mode = contextlib.nullcontext() if hasattr(self.vocab, "set_phone_prob"): eval_phon_mode = self.vocab.set_phone_prob(prob=1.0) # Disable mixed g2p representation if necessary with eval_phon_mode: tokens = self.parser(str_input) else: tokens = self.parser(str_input) x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device) return x @typecheck( input_types={ "text": NeuralType(('B', 'T_text'), TokenIndex()), "durs": NeuralType(('B', 'T_text'), TokenDurationType()), "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()), "speaker": NeuralType(('B'), Index(), optional=True), "pace": NeuralType(optional=True), "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True), "attn_prior": NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True), "mel_lens": NeuralType(('B'), LengthsType(), optional=True), "input_lens": NeuralType(('B'), LengthsType(), optional=True), }) def forward( self, *, text, durs=None, pitch=None, speaker=None, pace=1.0, spec=None, attn_prior=None, mel_lens=None, input_lens=None, ): return self.fastpitch( text=text, durs=durs, pitch=pitch, speaker=speaker, pace=pace, spec=spec, attn_prior=attn_prior, mel_lens=mel_lens, input_lens=input_lens, ) @typecheck(output_types={ "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()) }) def generate_spectrogram(self, tokens: 'torch.tensor', speaker: Optional[int] = None, pace: float = 1.0) -> torch.tensor: if self.training: logging.warning( "generate_spectrogram() is meant to be called in eval mode.") if isinstance(speaker, int): speaker = torch.tensor([speaker]).to(self.device) spect, *_ = self(text=tokens, durs=None, pitch=None, speaker=speaker, pace=pace) return spect def training_step(self, 
batch, batch_idx): attn_prior, durs, speaker = None, None, None if self.learn_alignment: if self.ds_class_name == "TTSDataset": if SpeakerID in self._train_dl.dataset.sup_data_types_set: audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch else: audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch else: raise ValueError( f"Unknown vocab class: {self.vocab.__class__.__name__}") else: audio, audio_lens, text, text_lens, durs, pitch, speaker = batch mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens) mels_pred, _, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self( text=text, durs=durs, pitch=pitch, speaker=speaker, pace=1.0, spec=mels if self.learn_alignment else None, attn_prior=attn_prior, mel_lens=spec_len, input_lens=text_lens, ) if durs is None: durs = attn_hard_dur mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels) dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens) loss = mel_loss + dur_loss if self.learn_alignment: ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob, in_lens=text_lens, out_lens=spec_len) bin_loss_weight = min( self.current_epoch / self.bin_loss_warmup_epochs, 1.0) * 1.0 bin_loss = self.bin_loss( hard_attention=attn_hard, soft_attention=attn_soft) * bin_loss_weight loss += ctc_loss + bin_loss pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens) loss += pitch_loss self.log("t_loss", loss) self.log("t_mel_loss", mel_loss) self.log("t_dur_loss", dur_loss) self.log("t_pitch_loss", pitch_loss) if self.learn_alignment: self.log("t_ctc_loss", ctc_loss) self.log("t_bin_loss", bin_loss) # Log images to tensorboard if self.log_train_images and isinstance(self.logger, TensorBoardLogger): self.log_train_images = False self.tb_logger.add_image( "train_mel_target", plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()), self.global_step, 
            dataformats="HWC",
        )
        # Log the predicted mel for the same sample next to the target above.
        spec_predict = mels_pred[0].data.cpu().float().numpy()
        self.tb_logger.add_image(
            "train_mel_predicted",
            plot_spectrogram_to_numpy(spec_predict),
            self.global_step,
            dataformats="HWC",
        )
        if self.learn_alignment:
            # Also log the hard (binarized) and soft alignment maps for the first sample.
            attn = attn_hard[0].data.cpu().float().numpy().squeeze()
            self.tb_logger.add_image(
                "train_attn",
                plot_alignment_to_numpy(attn.T),
                self.global_step,
                dataformats="HWC",
            )
            soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
            self.tb_logger.add_image(
                "train_soft_attn",
                plot_alignment_to_numpy(soft_attn.T),
                self.global_step,
                dataformats="HWC",
            )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation losses (mel, duration, pitch) for one batch.

        Batch layout depends on the training mode:
        - alignment learning with a TTSDataset: 7 or 8 items depending on whether
          SpeakerID is among the dataset's supplementary data types;
        - otherwise: 7 items including precomputed durations.
        Mel target/prediction tensors are returned only for batch 0 so that
        ``validation_epoch_end`` can log images for a single example.
        """
        attn_prior, durs, speaker = None, None, None
        if self.learn_alignment:
            if self.ds_class_name == "TTSDataset":
                if SpeakerID in self._train_dl.dataset.sup_data_types_set:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
                else:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
            else:
                raise ValueError(
                    f"Unknown vocab class: {self.vocab.__class__.__name__}")
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

        mels, mel_lens = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        # Calculate val loss on ground truth durations to better align L2 loss in time
        # NOTE: the forward pass also returns an (aligned) pitch tensor that rebinds
        # the local ``pitch`` and is used as the pitch-loss target below.
        mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=text_lens,
        )
        if durs is None:
            # No ground-truth durations in this mode; fall back to hard-attention durations.
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss = mel_loss + dur_loss + pitch_loss

        return {
            "val_loss": loss,
            "mel_loss": mel_loss,
            "dur_loss": dur_loss,
            "pitch_loss": pitch_loss,
            # Only keep tensors for the first batch; other entries are None.
            "mel_target": mels if batch_idx == 0 else None,
            "mel_pred": mels_pred if batch_idx == 0 else None,
        }

    def validation_epoch_end(self, outputs):
        """Average per-batch validation losses, log scalars, and log example mels.

        Relies on the insertion order of the dict returned by ``validation_step``
        when unpacking ``outputs[0].values()`` below.
        """
        collect = lambda key: torch.stack([x[key] for x in outputs]).mean()
        val_loss = collect("val_loss")
        mel_loss = collect("mel_loss")
        dur_loss = collect("dur_loss")
        pitch_loss = collect("pitch_loss")
        self.log("v_loss", val_loss)
        self.log("v_mel_loss", mel_loss)
        self.log("v_dur_loss", dur_loss)
        self.log("v_pitch_loss", pitch_loss)

        # Positional unpack: the last two values are mel_target / mel_pred from batch 0.
        _, _, _, _, spec_target, spec_predict = outputs[0].values()
        if isinstance(self.logger, TensorBoardLogger):
            self.tb_logger.add_image(
                "val_mel_target",
                plot_spectrogram_to_numpy(
                    spec_target[0].data.cpu().float().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = spec_predict[0].data.cpu().float().numpy()
            self.tb_logger.add_image(
                "val_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            # Re-arm training-image logging for the next training step.
            self.log_train_images = True

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        """Build a DataLoader from a DictConfig with ``dataset`` and ``dataloader_params``.

        Warns/errors when the shuffle setting disagrees with what is expected for
        the split (train should shuffle, val should not).
        """
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        if cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
            # Instantiate inside a phoneme-probability context: validation uses
            # prob=None (deterministic), training uses the vocab's configured probability.
            phon_mode = contextlib.nullcontext()
            if hasattr(self.vocab, "set_phone_prob"):
                phon_mode = self.vocab.set_phone_prob(
                    prob=None if name == "val" else self.vocab.phoneme_probability)
            with phon_mode:
                # NOTE(review): __init__ above sets ``text_normalizer_call_args``;
                # confirm ``text_normalizer_call_kwargs`` actually exists on this class.
                dataset = instantiate(
                    cfg.dataset,
                    text_normalizer=self.normalizer,
                    text_normalizer_call_kwargs=self.
                    text_normalizer_call_kwargs,
                    text_tokenizer=self.vocab,
                )
        else:
            dataset = instantiate(cfg.dataset)

        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="val")

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_fastpitch",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/1.8.1/files/tts_en_fastpitch_align.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)
        return list_of_models

    # Methods for model exportability
    def _prepare_for_export(self, **kwargs):
        """Install static input/output type maps used by export()/ONNX tracing."""
        super()._prepare_for_export(**kwargs)

        # Define input_types and output_types as required by export()
        self._input_types = {
            "text": NeuralType(('B', 'T_text'), TokenIndex()),
            "pitch": NeuralType(('B', 'T_text'), RegressionValuesType()),
            "pace": NeuralType(('B', 'T_text'), optional=True),
            "volume": NeuralType(('B', 'T_text')),
            "speaker": NeuralType(('B'), Index()),
        }
        self._output_types = {
            "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
            "num_frames": NeuralType(('B'), TokenDurationType()),
            "durs_predicted": NeuralType(('B', 'T_text'), TokenDurationType()),
            "log_durs_predicted": NeuralType(('B', 'T_text'),
                                             TokenLogDurationType()),
            "pitch_predicted": NeuralType(('B', 'T_text'),
                                          RegressionValuesType()),
            "volume_aligned": NeuralType(('B', 'T_spec'),
                                         RegressionValuesType()),
        }

    def _export_teardown(self):
        # Drop the export-only type maps installed by _prepare_for_export.
        self._input_types = self._output_types = None

    @property
    def disabled_deployment_input_names(self):
        """Implement this method to return a set of input names disabled for export"""
        # NOTE(review): references self.fastpitch, which is not visible in the
        # __init__ shown for this class — confirm the enclosing class defines it.
        disabled_inputs = set()
        if self.fastpitch.speaker_emb is None:
            disabled_inputs.add("speaker")
        return disabled_inputs

    @property
    def input_types(self):
        return self._input_types

    @property
    def output_types(self):
        return self._output_types

    def input_example(self, max_batch=1, max_dim=44):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        par = next(self.fastpitch.parameters())
        sz = (max_batch, max_dim)
        # Random token ids within the embedding table's range.
        inp = torch.randint(0,
                            self.fastpitch.encoder.word_emb.num_embeddings,
                            sz,
                            device=par.device,
                            dtype=torch.int64)
        pitch = torch.randn(sz, device=par.device, dtype=torch.float32) * 0.5
        # pace/volume are clamped to stay strictly positive.
        pace = torch.clamp(
            (torch.randn(sz, device=par.device, dtype=torch.float32) + 1) *
            0.1,
            min=0.01)
        volume = torch.clamp(
            (torch.randn(sz, device=par.device, dtype=torch.float32) + 1) *
            0.1,
            min=0.01)
        inputs = {'text': inp, 'pitch': pitch, 'pace': pace, 'volume': volume}
        if self.fastpitch.speaker_emb is not None:
            inputs['speaker'] = torch.randint(
                0,
                self.fastpitch.speaker_emb.num_embeddings, (max_batch, ),
                device=par.device,
                dtype=torch.int64)
        return (inputs, )

    def forward_for_export(self, text, pitch, pace, volume, speaker=None):
        # Export path bypasses training-time logic and calls inference directly.
        return self.fastpitch.infer(text=text,
                                    pitch=pitch,
                                    pace=pace,
                                    volume=volume,
                                    speaker=speaker)
class FastPitchModel(SpectrogramGenerator):
    """FastPitch Model that is used to generate mel spectrograms from text"""

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # cfg may arrive as a plain dict (e.g. restored checkpoints); normalize first.
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)

        # When True, durations are learned via an aligner instead of being
        # supplied by the dataset.
        self.learn_alignment = False
        if "learn_alignment" in cfg:
            self.learn_alignment = cfg.learn_alignment

        # Lazily-built helpers (see the `parser` and `tb_logger` properties).
        self._parser = None
        self._tb_logger = None
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(FastPitchConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        # Epochs over which the binarization-loss weight ramps from 0 to 1.
        self.bin_loss_warmup_epochs = 100
        self.aligner = None
        self.log_train_images = False

        self.mel_loss = MelLoss()
        # With alignment learning, pitch/duration losses are down-weighted
        # relative to the mel loss.
        loss_scale = 0.1 if self.learn_alignment else 1.0
        self.pitch_loss = PitchLoss(loss_scale=loss_scale)
        self.duration_loss = DurationLoss(loss_scale=loss_scale)

        input_fft_kwargs = {}
        if self.learn_alignment:
            self.aligner = instantiate(self._cfg.alignment_module)
            self.forward_sum_loss = ForwardSumLoss()
            self.bin_loss = BinLoss()
            self.vocab = AudioToCharWithDursF0Dataset.make_vocab(
                **self._cfg.train_ds.dataset.vocab)
            input_fft_kwargs["n_embed"] = len(self.vocab.labels)
            input_fft_kwargs["padding_idx"] = self.vocab.pad

        self.preprocessor = instantiate(self._cfg.preprocessor)
        input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
        output_fft = instantiate(self._cfg.output_fft)
        duration_predictor = instantiate(self._cfg.duration_predictor)
        pitch_predictor = instantiate(self._cfg.pitch_predictor)

        self.fastpitch = FastPitchModule(
            input_fft,
            output_fft,
            duration_predictor,
            pitch_predictor,
            self.aligner,
            cfg.n_speakers,
            cfg.symbols_embedding_dim,
            cfg.pitch_embedding_kernel_size,
            cfg.n_mel_channels,
        )

    @property
    def tb_logger(self):
        """Cached TensorBoard experiment writer, or None when no logger is attached.

        BUGFIX: the original condition used `and`, so when ``self.logger`` was
        None the second operand ``self.logger.experiment`` was still evaluated
        and raised AttributeError instead of returning None.
        """
        if self._tb_logger is None:
            if self.logger is None or self.logger.experiment is None:
                return None
            tb_logger = self.logger.experiment
            # In a LoggerCollection, prefer the TensorBoard logger's experiment.
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            self._tb_logger = tb_logger
        return self._tb_logger

    @property
    def parser(self):
        """Lazily-built text-to-token parser used by :meth:`parse`."""
        if self._parser is not None:
            return self._parser
        if self.learn_alignment:
            vocab = AudioToCharWithDursF0Dataset.make_vocab(
                **self._cfg.train_ds.dataset.vocab)
            self._parser = vocab.encode
        else:
            self._parser = parsers.make_parser(
                labels=self._cfg.labels,
                name='en',
                unk_id=-1,
                blank_id=-1,
                do_normalize=True,
                abbreviation_version="fastpitch",
                make_table=False,
            )
        return self._parser

    def parse(self, str_input: str) -> torch.tensor:
        """Tokenize ``str_input`` into a (1, T) LongTensor on the model's device.

        Appends a period when the string does not already end with
        terminal punctuation.
        """
        if str_input[-1] not in [".", "!", "?"]:
            str_input = str_input + "."
        tokens = self.parser(str_input)
        x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
        return x

    @typecheck(
        input_types={
            "text": NeuralType(('B', 'T'), TokenIndex()),
            "durs": NeuralType(('B', 'T'), TokenDurationType()),
            "pitch": NeuralType(('B', 'T'), RegressionValuesType()),
            "speaker": NeuralType(('B'), Index()),
            "pace": NeuralType(optional=True),
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True),
            "attn_prior": NeuralType(('B', 'T', 'T'), ProbsType(), optional=True),
            "mel_lens": NeuralType(('B'), LengthsType(), optional=True),
            "input_lens": NeuralType(('B'), LengthsType(), optional=True),
        })
    def forward(
        self,
        *,
        text,
        durs=None,
        pitch=None,
        speaker=0,
        pace=1.0,
        spec=None,
        attn_prior=None,
        mel_lens=None,
        input_lens=None,
    ):
        """Delegate to the underlying FastPitchModule; see its output contract."""
        return self.fastpitch(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=pace,
            spec=spec,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=input_lens,
        )

    @typecheck(output_types={
        "spect": NeuralType(('B', 'C', 'T'), MelSpectrogramType())
    })
    def generate_spectrogram(self,
                             tokens: 'torch.tensor',
                             speaker: int = 0,
                             pace: float = 1.0) -> torch.tensor:
        """Run inference for ``tokens`` and return the mel spectrogram (B, C, T)."""
        self.eval()
        spect, *_ = self(text=tokens,
                         durs=None,
                         pitch=None,
                         speaker=speaker,
                         pace=pace)
        return spect.transpose(1, 2)

    def training_step(self, batch, batch_idx):
        """One optimization step: mel + duration (+ pitch, ctc, bin) losses.

        Batch layout depends on the mode: with alignment learning the dataset
        provides an attention prior instead of precomputed durations/speakers.
        """
        attn_prior, durs, speakers = None, None, None
        if self.learn_alignment:
            audio, audio_lens, text, text_lens, attn_prior, pitch = batch
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speakers = batch

        mels, spec_len = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        mels_pred, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speakers,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=spec_len,
            input_lens=text_lens,
        )
        if durs is None:
            # No ground-truth durations; use hard-attention durations instead.
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        loss = mel_loss + dur_loss
        if self.learn_alignment:
            ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                             in_lens=text_lens,
                                             out_lens=spec_len)
            # Linear warm-up of the binarization loss over bin_loss_warmup_epochs.
            bin_loss_weight = min(
                self.current_epoch / self.bin_loss_warmup_epochs, 1.0)
            bin_loss = self.bin_loss(
                hard_attention=attn_hard,
                soft_attention=attn_soft) * bin_loss_weight
            loss += ctc_loss + bin_loss

        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss += pitch_loss

        self.log("t_loss", loss)
        self.log("t_mel_loss", mel_loss)
        self.log("t_dur_loss", dur_loss)
        self.log("t_pitch_loss", pitch_loss)
        if self.learn_alignment:
            self.log("t_ctc_loss", ctc_loss)
            self.log("t_bin_loss", bin_loss)

        # Log images to tensorboard
        if self.log_train_images:
            self.log_train_images = False
            self.tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(mels[0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = mels_pred[0].data.cpu().numpy().T
            self.tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            if self.learn_alignment:
                attn = attn_hard[0].data.cpu().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_attn",
                    plot_alignment_to_numpy(attn.T),
                    self.global_step,
                    dataformats="HWC",
                )
                soft_attn = attn_soft[0].data.cpu().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_soft_attn",
                    plot_alignment_to_numpy(soft_attn.T),
                    self.global_step,
                    dataformats="HWC",
                )

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation losses; keep batch-0 mels for image logging."""
        attn_prior, durs, speakers = None, None, None
        if self.learn_alignment:
            audio, audio_lens, text, text_lens, attn_prior, pitch = batch
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speakers = batch

        mels, mel_lens = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        # Calculate val loss on ground truth durations to better align L2 loss in time
        mels_pred, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speakers,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=text_lens,
        )
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss = mel_loss + dur_loss + pitch_loss

        return {
            "val_loss": loss,
            "mel_loss": mel_loss,
            "dur_loss": dur_loss,
            "pitch_loss": pitch_loss,
            "mel_target": mels if batch_idx == 0 else None,
            "mel_pred": mels_pred if batch_idx == 0 else None,
        }

    def validation_epoch_end(self, outputs):
        """Average per-batch validation losses, log scalars and example mels."""
        collect = lambda key: torch.stack([x[key] for x in outputs]).mean()
        val_loss = collect("val_loss")
        mel_loss = collect("mel_loss")
        dur_loss = collect("dur_loss")
        pitch_loss = collect("pitch_loss")
        self.log("v_loss", val_loss)
        self.log("v_mel_loss", mel_loss)
        self.log("v_dur_loss", dur_loss)
        self.log("v_pitch_loss", pitch_loss)

        # Positional unpack relies on the insertion order of validation_step's dict.
        _, _, _, _, spec_target, spec_predict = outputs[0].values()
        self.tb_logger.add_image(
            "val_mel_target",
            plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
            self.global_step,
            dataformats="HWC",
        )
        spec_predict = spec_predict[0].data.cpu().numpy()
        self.tb_logger.add_image(
            "val_mel_predicted",
            plot_spectrogram_to_numpy(spec_predict.T),
            self.global_step,
            dataformats="HWC",
        )
        self.log_train_images = True

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        """Build a DataLoader from a DictConfig with ``dataset`` and ``dataloader_params``.

        Warns/errors when the shuffle setting disagrees with the split's
        expectation (train should shuffle, val should not).
        """
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        kwargs_dict = {}
        # FastPitchDataset needs the text parser to tokenize transcripts.
        if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
            kwargs_dict["parser"] = self.parser
        dataset = instantiate(cfg.dataset, **kwargs_dict)

        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="val")

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_fastpitch",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/1.0.0/files/tts_en_fastpitch.nemo",
            # Fixed garbled wording ("22050Hz with and can be used") in the description.
            description=
            "This model is trained on LJSpeech sampled at 22050Hz and can be used to generate female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)
        return list_of_models
def input_types(self): return { "hard_attention": NeuralType(('B', 'S', 'T', 'D'), ProbsType()), "soft_attention": NeuralType(('B', 'S', 'T', 'D'), ProbsType()), }