def __init__(self, cfg: DictConfig, trainer: Trainer = None): if isinstance(cfg, dict): cfg = OmegaConf.create(cfg) self.parser = instantiate(cfg.parser) super().__init__(cfg=cfg, trainer=trainer) schema = OmegaConf.structured(GlowTTSConfig) # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes if isinstance(cfg, dict): cfg = OmegaConf.create(cfg) elif not isinstance(cfg, DictConfig): raise ValueError( f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig" ) # Ensure passed cfg is compliant with schema OmegaConf.merge(cfg, schema) self.preprocessor = instantiate(self._cfg.preprocessor) encoder = instantiate(self._cfg.encoder) decoder = instantiate(self._cfg.decoder) self.glow_tts = GlowTTSModule(encoder, decoder, n_speakers=cfg.n_speakers, gin_channels=cfg.gin_channels) self.loss = GlowTTSLoss()
class GlowTTSModel(SpectrogramGenerator): """ GlowTTS model used to generate spectrograms from text Consists of a text encoder and an invertible spectrogram decoder """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if isinstance(cfg, dict): cfg = OmegaConf.create(cfg) self.parser = instantiate(cfg.parser) super().__init__(cfg=cfg, trainer=trainer) schema = OmegaConf.structured(GlowTTSConfig) # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes if isinstance(cfg, dict): cfg = OmegaConf.create(cfg) elif not isinstance(cfg, DictConfig): raise ValueError( f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig" ) # Ensure passed cfg is compliant with schema OmegaConf.merge(cfg, schema) self.preprocessor = instantiate(self._cfg.preprocessor) encoder = instantiate(self._cfg.encoder) decoder = instantiate(self._cfg.decoder) self.glow_tts = GlowTTSModule(encoder, decoder, n_speakers=cfg.n_speakers, gin_channels=cfg.gin_channels) self.loss = GlowTTSLoss() def parse(self, str_input: str) -> torch.tensor: if str_input[-1] not in [".", "!", "?"]: str_input = str_input + "." tokens = self.parser(str_input) x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device) return x @typecheck( input_types={ "x": NeuralType(('B', 'T'), TokenIndex()), "x_lengths": NeuralType(('B'), LengthsType()), "y": NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True), "y_lengths": NeuralType(('B'), LengthsType(), optional=True), "gen": NeuralType(optional=True), "noise_scale": NeuralType(optional=True), "length_scale": NeuralType(optional=True), }) def forward(self, *, x, x_lengths, y=None, y_lengths=None, gen=False, noise_scale=0.3, length_scale=1.0): if gen: return self.glow_tts.generate_spect(text=x, text_lengths=x_lengths, noise_scale=noise_scale, length_scale=length_scale) else: return self.glow_tts(text=x, text_lengths=x_lengths, spect=y, spect_lengths=y_lengths) def step(self, y, y_lengths, x, x_lengths): z, y_m, y_logs, logdet, logw, logw_, y_lengths, attn = self( x=x, x_lengths=x_lengths, y=y, y_lengths=y_lengths, gen=False) l_mle, l_length, logdet = self.loss( z=z, y_m=y_m, y_logs=y_logs, logdet=logdet, logw=logw, logw_=logw_, x_lengths=x_lengths, y_lengths=y_lengths, ) loss = sum([l_mle, l_length]) return l_mle, l_length, logdet, loss, attn def training_step(self, batch, batch_idx): y, y_lengths, x, x_lengths = batch y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths) l_mle, l_length, logdet, loss, _ = self.step(y, y_lengths, x, x_lengths) output = { "loss": loss, # required "progress_bar": { "l_mle": l_mle, "l_length": l_length, "logdet": logdet }, "log": { "loss": loss, "l_mle": l_mle, "l_length": l_length, "logdet": logdet }, } return output def validation_step(self, batch, batch_idx): y, y_lengths, x, x_lengths = batch y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths) l_mle, l_length, logdet, loss, attn = self.step( y, y_lengths, x, x_lengths) y_gen, attn_gen = self(x=x, x_lengths=x_lengths, gen=True) return { "loss": loss, "l_mle": l_mle, "l_length": l_length, "logdet": logdet, "y": y, "y_gen": y_gen, "x": x, "attn": attn, "attn_gen": attn_gen, "progress_bar": { "l_mle": l_mle, "l_length": l_length, "logdet": logdet }, } def validation_epoch_end(self, outputs): avg_loss = torch.stack([x['loss'] for x in outputs]).mean() avg_mle = torch.stack([x['l_mle'] for x in outputs]).mean() avg_length_loss = torch.stack([x['l_length'] for x in outputs]).mean() avg_logdet = torch.stack([x['logdet'] for x in outputs]).mean() tensorboard_logs = { 'val_loss': avg_loss, 'val_mle': avg_mle, 'val_length_loss': avg_length_loss, 'val_logdet': avg_logdet, } if self.logger is not None and self.logger.experiment is not None: tb_logger = self.logger.experiment if isinstance(self.logger, LoggerCollection): for logger in self.logger: if isinstance(logger, TensorBoardLogger): tb_logger = logger.experiment break separated_phonemes = "|".join( [self.parser.symbols[c] for c in outputs[0]['x'][0]]) tb_logger.add_text("separated phonemes", separated_phonemes, self.global_step) tb_logger.add_image( "real_spectrogram", plot_spectrogram_to_numpy( outputs[0]['y'][0].data.cpu().numpy()), self.global_step, dataformats="HWC", ) tb_logger.add_image( "generated_spectrogram", plot_spectrogram_to_numpy( outputs[0]['y_gen'][0].data.cpu().numpy()), self.global_step, dataformats="HWC", ) tb_logger.add_image( "alignment_for_real_sp", plot_alignment_to_numpy( outputs[0]['attn'][0].data.cpu().numpy()), self.global_step, dataformats="HWC", ) tb_logger.add_image( "alignment_for_generated_sp", plot_alignment_to_numpy( outputs[0]['attn_gen'][0].data.cpu().numpy()), self.global_step, dataformats="HWC", ) log_audio_to_tb(tb_logger, outputs[0]['y'][0], "true_audio_gf", self.global_step) log_audio_to_tb(tb_logger, outputs[0]['y_gen'][0], "generated_audio_gf", self.global_step) return {'val_loss': avg_loss, 'log': tensorboard_logs} def _setup_dataloader_from_config(self, cfg: DictConfig): if 'manifest_filepath' in cfg and cfg['manifest_filepath'] is None: logging.warning( f"Could not load dataset as `manifest_filepath` was None. Provided config : {cfg}" ) return None if 'augmentor' in cfg: augmentor = process_augmentations(cfg['augmentor']) else: augmentor = None dataset = _AudioTextDataset( manifest_filepath=cfg['manifest_filepath'], parser=self.parser, sample_rate=cfg['sample_rate'], int_values=cfg.get('int_values', False), augmentor=augmentor, max_duration=cfg.get('max_duration', None), min_duration=cfg.get('min_duration', None), max_utts=cfg.get('max_utts', 0), trim=cfg.get('trim_silence', True), load_audio=cfg.get('load_audio', True), add_misc=cfg.get('add_misc', False), ) return torch.utils.data.DataLoader( dataset=dataset, batch_size=cfg['batch_size'], collate_fn=dataset.collate_fn, drop_last=cfg.get('drop_last', False), shuffle=cfg['shuffle'], num_workers=cfg.get('num_workers', 0), ) def setup_training_data(self, train_data_config: Optional[DictConfig]): self._train_dl = self._setup_dataloader_from_config( cfg=train_data_config) def setup_validation_data(self, val_data_config: Optional[DictConfig]): self._validation_dl = self._setup_dataloader_from_config( cfg=val_data_config) def setup_test_data(self, test_data_config: Optional[DictConfig]): self._test_dl = self._setup_dataloader_from_config( cfg=test_data_config) def generate_spectrogram(self, tokens: 'torch.tensor', noise_scale: float = 0.0, length_scale: float = 1.0) -> torch.tensor: self.eval() token_len = torch.tensor([tokens.shape[1]]).to(self.device) spect, _ = self(x=tokens, x_lengths=token_len, gen=True, noise_scale=noise_scale, length_scale=length_scale) return spect @classmethod def list_available_models(cls) -> 'List[PretrainedModelInfo]': """ This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. Returns: List of available pre-trained models. """ list_of_models = [] model = PretrainedModelInfo( pretrained_model_name="tts_en_glowtts", location= "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_glowtts/versions/1.0.0rc1/files/tts_en_glowtts.nemo", description= "This model is trained on LJSpeech sampled at 22050Hz, and can be used to generate female English voices with an American accent.", class_=cls, ) list_of_models.append(model) return list_of_models