Example #1
 def input_types(self):
     return {
         "spec_pred": NeuralType(('B', 'C', 'T'), MelSpectrogramType()),
         "spec_target": NeuralType(('B', 'C', 'T'), MelSpectrogramType()),
         "spec_target_len": NeuralType(('B'), LengthsType()),
         "pad_value": NeuralType(),
     }
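
These input_types dictionaries are what NeMo's @typecheck() decorator reads to validate the keyword arguments of forward() against the declared NeuralType of each name. Below is a minimal sketch of that pairing, assuming the standard nemo.core imports; the ExampleL1Loss module and its loss formula are illustrative, not taken from the examples on this page.

import torch
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import LengthsType, LossType, MelSpectrogramType, NeuralType


class ExampleL1Loss(NeuralModule):
    """Illustrative module showing how input_types/output_types pair with @typecheck()."""

    @property
    def input_types(self):
        return {
            "spec_pred": NeuralType(('B', 'C', 'T'), MelSpectrogramType()),
            "spec_target": NeuralType(('B', 'C', 'T'), MelSpectrogramType()),
            "spec_target_len": NeuralType(('B',), LengthsType()),
        }

    @property
    def output_types(self):
        return {"loss": NeuralType(elements_type=LossType())}

    @typecheck()
    def forward(self, *, spec_pred, spec_target, spec_target_len):
        # Mask time steps beyond each target length, then average the L1 error.
        max_len = spec_target.shape[2]
        mask = torch.arange(max_len, device=spec_target.device)[None, :] < spec_target_len[:, None]
        mask = mask.unsqueeze(1)  # broadcast over the channel axis: (B, 1, T)
        diff = (spec_pred - spec_target).abs() * mask
        return diff.sum() / (mask.sum() * spec_target.shape[1])

Because @typecheck() enforces keyword-only calls, such a module is invoked as ExampleL1Loss()(spec_pred=..., spec_target=..., spec_target_len=...).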
Example #2
 def input_types(self):
     if self.mode == OperationMode.infer:
         return {
             "spec":
             NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
             "z":
             NeuralType(('B', 'D', 'T'),
                        MelSpectrogramType(),
                        optional=True),
             "sigma":
             NeuralType(optional=True),
         }
     else:
         return {
             "spec":
             NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
             "z":
             NeuralType(('B', 'D', 'T'),
                        MelSpectrogramType(),
                        optional=True),
             "audio":
             NeuralType(('B', 'T'), AudioSignal(), optional=True),
             "run_inverse":
             NeuralType(elements_type=IntType(), optional=True),
             "sigma":
             NeuralType(optional=True),
         }
Example #3
 def output_types(self):
     if not self.calculate_loss and not self.training:
         return {
             "spec_pred_dec":
             NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
             "spec_pred_postnet":
             NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
             "gate_pred":
             NeuralType(('B', 'T'), LogitsType()),
             "alignments":
             NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
             "pred_length":
             NeuralType(('B'), LengthsType()),
         }
     return {
         "spec_pred_dec":
         NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "spec_pred_postnet":
         NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "gate_pred":
         NeuralType(('B', 'T'), LogitsType()),
         "spec_target":
         NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "spec_target_len":
         NeuralType(('B'), LengthsType()),
         "alignments":
         NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
     }
Example #4
 def input_types(self):
     return {
         "spec_pred_dec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "spec_pred_postnet": NeuralType(('B', 'D', 'T'),
                                         MelSpectrogramType()),
         "gate_pred": NeuralType(('B', 'T'), LogitsType()),
         "spec_target": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "spec_target_len": NeuralType(('B'), LengthsType()),
         "pad_value": NeuralType(),
     }
Example #5
 def input_types(self):
     return {
         "spect_predicted": NeuralType(('B', 'D', 'T'),
                                       MelSpectrogramType()),
         "log_durs_predicted": NeuralType(('B', 'T'),
                                          TokenLogDurationType()),
         "pitch_predicted": NeuralType(('B', 'T'), RegressionValuesType()),
         "spect_tgt": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "durs_tgt": NeuralType(('B', 'T'), TokenDurationType()),
         "dur_lens": NeuralType(('B'), LengthsType()),
         "pitch_tgt": NeuralType(('B', 'T'), RegressionValuesType()),
     }
Example #6
 def input_types(self):
     return {
         "spect": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "spect_mask": NeuralType(('B', 'D', 'T'), MaskType()),
         "speaker_embeddings": NeuralType(('B', 'D'), AcousticEncodedRepresentation(), optional=True),
         "reverse": NeuralType(elements_type=IntType(), optional=True),
     }
Example #7
 def input_types(self):
     return {
         "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "audio": NeuralType(('B', 'T'), AudioSignal(), optional=True),
         "run_inverse": NeuralType(elements_type=IntType(), optional=True),
         "sigma": NeuralType(optional=True),
     }
Example #8
    def _prepare_for_export(self, **kwargs):
        super()._prepare_for_export(**kwargs)

        # Define input_types and output_types as required by export()
        self._input_types = {
            "text": NeuralType(('B', 'T_text'), TokenIndex()),
            "pitch": NeuralType(('B', 'T_text'), RegressionValuesType()),
            "pace": NeuralType(('B', 'T_text'), optional=True),
            "volume": NeuralType(('B', 'T_text')),
            "speaker": NeuralType(('B'), Index()),
        }
        self._output_types = {
            "spect":
            NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
            "num_frames":
            NeuralType(('B'), TokenDurationType()),
            "durs_predicted":
            NeuralType(('B', 'T_text'), TokenDurationType()),
            "log_durs_predicted":
            NeuralType(('B', 'T_text'), TokenLogDurationType()),
            "pitch_predicted":
            NeuralType(('B', 'T_text'), RegressionValuesType()),
            "volume_aligned":
            NeuralType(('B', 'T_spec'), RegressionValuesType()),
        }
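
When _input_types and _output_types are set like this, NeMo's Exportable mixin can use the dictionary keys as the input/output names of the exported graph; export() calls _prepare_for_export() before tracing. A hedged usage sketch follows, assuming this snippet belongs to NeMo's FastPitchModel and that the "tts_en_fastpitch" pretrained name is available on NGC; the output file name is a placeholder.

from nemo.collections.tts.models import FastPitchModel

# The pretrained name is an assumption; any checkpoint loaded via restore_from() works the same way.
model = FastPitchModel.from_pretrained(model_name="tts_en_fastpitch")
model.eval()
# export() invokes _prepare_for_export() and names the exported inputs/outputs
# after the keys of self._input_types / self._output_types.
model.export("fastpitch.onnx")  # placeholder path; the file extension selects the export format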
Example #9
 def input_types(self):
     return {
         "text": NeuralType(('B', 'T'), TokenIndex()),
         "text_lengths": NeuralType(('B'), LengthsType()),
         "spect": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "spect_lengths": NeuralType(('B'), LengthsType()),
         "speaker": NeuralType(('B'), IntType(), optional=True),
     }
Example #10
 def input_types(self):
     input_dict = {
         "memory": NeuralType(('B', 'T', 'D'), EmbeddedTextType()),
         "memory_lengths": NeuralType(('B'), LengthsType()),
     }
     if self.training:
         input_dict["decoder_inputs"] = NeuralType(('B', 'D', 'T'), MelSpectrogramType())
     return input_dict
Example #11
 def output_types(self):
     output_dict = {
         "mel_outputs": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "gate_outputs": NeuralType(('B', 'T'), LogitsType()),
         "alignments": NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
     }
     if not self.training:
         output_dict["mel_lengths"] = NeuralType(('B'), LengthsType())
     return output_dict
Example #12
 def output_types(self):
     return {
         "spect": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
         "spect_lens": NeuralType(('B'), SequenceToSequenceAlignmentType()),
         "spect_mask": NeuralType(('B', 'D', 'T'), MaskType()),
         "durs_predicted": NeuralType(('B', 'T'), TokenDurationType()),
         "log_durs_predicted": NeuralType(('B', 'T'),
                                          TokenLogDurationType()),
         "pitch_predicted": NeuralType(('B', 'T'), RegressionValuesType()),
     }
Example #13
 def input_types(self):
     return {
         "text": NeuralType(('B', 'T_text'), TokenIndex()),
         "durs": NeuralType(('B', 'T_text'), TokenDurationType()),
         "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
         "speaker": NeuralType(('B'), Index(), optional=True),
         "pace": NeuralType(optional=True),
         "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
         "attn_prior": NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
         "mel_lens": NeuralType(('B'), LengthsType(), optional=True),
         "input_lens": NeuralType(('B'), LengthsType(), optional=True),
     }
Example #14
 def output_types(self):
     return {
         "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
         "num_frames": NeuralType(('B'), TokenDurationType()),
         "durs_predicted": NeuralType(('B', 'T_text'), TokenDurationType()),
         "log_durs_predicted": NeuralType(('B', 'T_text'), TokenLogDurationType()),
         "pitch_predicted": NeuralType(('B', 'T_text'), RegressionValuesType()),
         "attn_soft": NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
         "attn_logprob": NeuralType(('B', 'S', 'T_spec', 'T_text'), LogprobsType()),
         "attn_hard": NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
         "attn_hard_dur": NeuralType(('B', 'T_text'), TokenDurationType()),
         "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
     }
Example #15
 def output_types(self):
     if self.mode == OperationMode.training or self.mode == OperationMode.validation:
         output_dict = {
             "pred_normal_dist": NeuralType(('B', 'flowgroup', 'T'), NormalDistributionSamplesType()),
             "logdet": NeuralType(elements_type=LogDeterminantType()),
             "predicted_audio": NeuralType(('B', 'T'), AudioSignal()),
         }
         if self.mode == OperationMode.validation:
             output_dict["spec"] = NeuralType(('B', 'T', 'D'), MelSpectrogramType())
             output_dict["spec_len"] = NeuralType(('B'), LengthsType())
         return output_dict
     return {
         "audio_pred": NeuralType(('B', 'T'), AudioSignal()),
     }
Example #16
 def output_types(self):
     if self.mode == OperationMode.training or self.mode == OperationMode.validation:
         output_dict = {
             "pred_normal_dist": NeuralType(('B', 'flowgroup', 'T'), NormalDistributionSamplesType()),
             "log_s_list": NeuralType(('B', 'flowgroup', 'T'), VoidType()),  # TODO: Figure out a good typing
             "log_det_W_list": NeuralType(elements_type=VoidType()),  # TODO: Figure out a good typing
         }
         if self.mode == OperationMode.validation:
             output_dict["audio_pred"] = NeuralType(('B', 'T'), AudioSignal())
             output_dict["spec"] = NeuralType(('B', 'T', 'D'), MelSpectrogramType())
             output_dict["spec_len"] = NeuralType(('B'), LengthsType())
         return output_dict
     return {
         "audio_pred": NeuralType(('B', 'T'), AudioSignal()),
     }
Example #17
class SqueezeWaveModel(Vocoder):
    """ SqueezeWave model that generates audio conditioned on mel-spectrogram
    """

    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(SqueezeWaveConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig")
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        self.pad_value = self._cfg.preprocessor.params.pad_value
        self.sigma = self._cfg.sigma
        self.audio_to_melspec_precessor = instantiate(self._cfg.preprocessor)
        self.squeezewave = instantiate(self._cfg.squeezewave)
        self.mode = OperationMode.infer
        self.loss = WaveGlowLoss()  # Same loss as WaveGlow

    @property
    def input_types(self):
        return {
            "audio": NeuralType(('B', 'T'), AudioSignal()),
            "audio_len": NeuralType(('B'), LengthsType()),
            "run_inverse": NeuralType(optional=True),
        }

    @property
    def output_types(self):
        if self.mode == OperationMode.training or self.mode == OperationMode.validation:
            output_dict = {
                "pred_normal_dist": NeuralType(('B', 'flowgroup', 'T'), NormalDistributionSamplesType()),
                "log_s_list": NeuralType(('B', 'flowgroup', 'T'), VoidType()),  # TODO: Figure out a good typing
                "log_det_W_list": NeuralType(elements_type=VoidType()),  # TODO: Figure out a good typing
            }
            if self.mode == OperationMode.validation:
                output_dict["audio_pred"] = NeuralType(('B', 'T'), AudioSignal())
                output_dict["spec"] = NeuralType(('B', 'T', 'D'), MelSpectrogramType())
                output_dict["spec_len"] = NeuralType(('B'), LengthsType())
            return output_dict
        return {
            "audio_pred": NeuralType(('B', 'T'), AudioSignal()),
        }

    @typecheck()
    def forward(self, *, audio, audio_len, run_inverse=True):
        if self.mode != self.squeezewave.mode:
            raise ValueError(
                f"SqueezeWaveModel's mode {self.mode} does not match SqueezeWaveModule's mode {self.squeezewave.mode}"
            )
        spec, spec_len = self.audio_to_melspec_precessor(audio, audio_len)
        tensors = self.squeezewave(spec=spec, audio=audio, run_inverse=run_inverse)
        if self.mode == OperationMode.training:
            return tensors[:-1]  # z, log_s_list, log_det_W_list
        elif self.mode == OperationMode.validation:
            z, log_s_list, log_det_W_list, audio_pred = tensors
            return z, log_s_list, log_det_W_list, audio_pred, spec, spec_len
        return tensors  # audio_pred

    @typecheck(
        input_types={"spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), "sigma": NeuralType(optional=True)},
        output_types={"audio": NeuralType(('B', 'T'), AudioSignal())},
    )
    def convert_spectrogram_to_audio(self, spec: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
        self.eval()
        self.mode = OperationMode.infer
        self.squeezewave.mode = OperationMode.infer

        with torch.no_grad():
            audio = self.squeezewave(spec=spec, run_inverse=True, audio=None, sigma=sigma)

        return audio

    def training_step(self, batch, batch_idx):
        self.mode = OperationMode.training
        self.squeezewave.mode = OperationMode.training
        audio, audio_len = batch
        z, log_s_list, log_det_W_list = self.forward(audio=audio, audio_len=audio_len, run_inverse=False)

        loss = self.loss(z=z, log_s_list=log_s_list, log_det_W_list=log_det_W_list, sigma=self.sigma)
        return {
            'loss': loss,
            'progress_bar': {'training_loss': loss},
            'log': {'loss': loss},
        }

    def validation_step(self, batch, batch_idx):
        self.mode = OperationMode.validation
        self.squeezewave.mode = OperationMode.validation
        audio, audio_len = batch
        z, log_s_list, log_det_W_list, audio_pred, spec, spec_len = self.forward(
            audio=audio, audio_len=audio_len, run_inverse=(batch_idx == 0)
        )
        loss = self.loss(z=z, log_s_list=log_s_list, log_det_W_list=log_det_W_list, sigma=self.sigma)
        return {
            "val_loss": loss,
            "audio_pred": audio_pred,
            "mel_target": spec,
            "mel_len": spec_len,
        }

    def validation_epoch_end(self, outputs):
        if self.logger is not None and self.logger.experiment is not None:
            waveglow_log_to_tb_func(
                self.logger.experiment,
                outputs[0].values(),
                self.global_step,
                tag="eval",
                mel_fb=self.audio_to_melspec_precessor.fb,
            )
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")  # TODO
        if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")  # TODO
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True"
                )
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="SqueezeWave-22050Hz",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemottsmodels/versions/1.0.0a5/files/SqueezeWave-22050Hz.nemo",
            description="This model is trained on LJSpeech sampled at 22050Hz, and can be used as an universal vocoder.",
            class_=cls,
        )
        list_of_models.append(model)
        return list_of_models
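
A hedged inference sketch for the model above, assuming the NeMo 1.x import path and that the "SqueezeWave-22050Hz" checkpoint listed in list_available_models() is reachable; the input spectrogram is random noise of a plausible shape, used only to illustrate the call.

import torch
from nemo.collections.tts.models import SqueezeWaveModel

vocoder = SqueezeWaveModel.from_pretrained(model_name="SqueezeWave-22050Hz")
vocoder.eval()

# Placeholder (B, D, T) mel spectrogram; 80 mel bins is an assumption, not read from the model config.
spec = torch.randn(1, 80, 200)
with torch.no_grad():
    audio = vocoder.convert_spectrogram_to_audio(spec=spec)
print(audio.shape)  # (B, T_audio)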
Example #18
class FastSpeech2HifiGanE2EModel(TextToWaveform):
    """An end-to-end speech synthesis model based on FastSpeech2 and HiFiGan that converts strings to audio without
    using the intermediate mel spectrogram representation."""
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        self.encoder = instantiate(cfg.encoder)
        self.variance_adapter = instantiate(cfg.variance_adaptor)

        self.generator = instantiate(cfg.generator)
        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiscaledisc = MultiScaleDiscriminator()

        self.melspec_fn = instantiate(cfg.preprocessor,
                                      highfreq=None,
                                      use_grads=True)
        self.mel_val_loss = L1MelLoss()
        self.durationloss = DurationLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.mseloss = torch.nn.MSELoss()

        self.energy = cfg.add_energy_predictor
        self.pitch = cfg.add_pitch_predictor
        self.mel_loss_coeff = cfg.mel_loss_coeff
        self.pitch_loss_coeff = cfg.pitch_loss_coeff
        self.energy_loss_coeff = cfg.energy_loss_coeff
        self.splice_length = cfg.splice_length

        self.use_energy_pred = False
        self.use_pitch_pred = False
        self.log_train_images = False
        self.logged_real_samples = False
        self._tb_logger = None
        self.sample_rate = cfg.sample_rate
        self.hop_size = cfg.hop_size

        # Parser and mappings are used for inference only.
        self.parser = parsers.make_parser(name='en')
        if 'mappings_filepath' in cfg:
            mappings_filepath = cfg.get('mappings_filepath')
        else:
            logging.error(
                "ERROR: You must specify a mappings.json file in the config file under model.mappings_filepath."
            )
            raise ValueError("model.mappings_filepath must be specified in the config")
        mappings_filepath = self.register_artifact('mappings_filepath',
                                                   mappings_filepath)
        with open(mappings_filepath, 'r') as f:
            mappings = json.load(f)
            self.word2phones = mappings['word2phones']
            self.phone2idx = mappings['phone2idx']

    @property
    def tb_logger(self):
        if self._tb_logger is None:
            if self.logger is None or self.logger.experiment is None:
                return None
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            self._tb_logger = tb_logger
        return self._tb_logger

    def configure_optimizers(self):
        gen_params = chain(
            self.encoder.parameters(),
            self.generator.parameters(),
            self.variance_adapter.parameters(),
        )
        disc_params = chain(self.multiscaledisc.parameters(),
                            self.multiperioddisc.parameters())
        opt1 = torch.optim.AdamW(disc_params, lr=self._cfg.lr)
        opt2 = torch.optim.AdamW(gen_params, lr=self._cfg.lr)
        num_procs = self._trainer.num_gpus * self._trainer.num_nodes
        num_samples = len(self._train_dl.dataset)
        batch_size = self._train_dl.batch_size
        iter_per_epoch = np.ceil(num_samples / (num_procs * batch_size))
        max_steps = iter_per_epoch * self._trainer.max_epochs
        logging.info(f"MAX STEPS: {max_steps}")
        sch1 = NoamAnnealing(opt1,
                             d_model=256,
                             warmup_steps=3000,
                             max_steps=max_steps,
                             min_lr=1e-5)
        sch1_dict = {
            'scheduler': sch1,
            'interval': 'step',
        }
        sch2 = NoamAnnealing(opt2,
                             d_model=256,
                             warmup_steps=3000,
                             max_steps=max_steps,
                             min_lr=1e-5)
        sch2_dict = {
            'scheduler': sch2,
            'interval': 'step',
        }
        return [opt1, opt2], [sch1_dict, sch2_dict]

    @typecheck(
        input_types={
            "text":
            NeuralType(('B', 'T'), TokenIndex()),
            "text_length":
            NeuralType(('B'), LengthsType()),
            "splice":
            NeuralType(optional=True),
            "spec_len":
            NeuralType(('B'), LengthsType(), optional=True),
            "durations":
            NeuralType(('B', 'T'), TokenDurationType(), optional=True),
            "pitch":
            NeuralType(('B', 'T'), RegressionValuesType(), optional=True),
            "energies":
            NeuralType(('B', 'T'), RegressionValuesType(), optional=True),
        },
        output_types={
            "audio": NeuralType(('B', 'S', 'T'), MelSpectrogramType()),
            "splices": NeuralType(),
            "log_dur_preds": NeuralType(('B', 'T'), TokenLogDurationType()),
            "pitch_preds": NeuralType(('B', 'T'), RegressionValuesType()),
            "energy_preds": NeuralType(('B', 'T'), RegressionValuesType()),
            "encoded_text_mask": NeuralType(('B', 'T', 'D'), MaskType()),
        },
    )
    def forward(self,
                *,
                text,
                text_length,
                splice=True,
                durations=None,
                pitch=None,
                energies=None,
                spec_len=None):
        encoded_text, encoded_text_mask = self.encoder(text=text,
                                                       text_length=text_length)

        context, log_dur_preds, pitch_preds, energy_preds, spec_len = self.variance_adapter(
            x=encoded_text,
            x_len=text_length,
            dur_target=durations,
            pitch_target=pitch,
            energy_target=energies,
            spec_len=spec_len,
        )

        gen_in = context
        splices = None
        if splice:
            # Splice generated spec
            output = []
            splices = []
            for i, sample in enumerate(context):
                start = np.random.randint(
                    low=0,
                    high=min(int(sample.size(0)), int(spec_len[i])) -
                    self.splice_length)
                output.append(sample[start:start + self.splice_length, :])
                splices.append(start)
            gen_in = torch.stack(output)

        output = self.generator(x=gen_in.transpose(1, 2))

        return output, splices, log_dur_preds, pitch_preds, energy_preds, encoded_text_mask

    def training_step(self, batch, batch_idx, optimizer_idx):
        f, fl, t, tl, durations, pitch, energies = batch
        spec, spec_len = self.audio_to_melspec_precessor(f, fl)

        # train discriminator
        if optimizer_idx == 0:
            with torch.no_grad():
                audio_pred, splices, _, _, _, _ = self(
                    spec=spec,
                    spec_len=spec_len,
                    text=t,
                    text_length=tl,
                    durations=durations,
                    pitch=pitch if not self.use_pitch_pred else None,
                    energies=energies if not self.use_energy_pred else None,
                )
                real_audio = []
                for i, splice in enumerate(splices):
                    real_audio.append(
                        f[i, splice *
                          self.hop_size:(splice + self.splice_length) *
                          self.hop_size])
                real_audio = torch.stack(real_audio).unsqueeze(1)

            real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(
                real_audio, audio_pred)
            real_score_ms, gen_score_ms, _, _ = self.multiscaledisc(
                real_audio, audio_pred)

            loss_mp, loss_mp_real, _ = self.disc_loss(real_score_mp,
                                                      gen_score_mp)
            loss_ms, loss_ms_real, _ = self.disc_loss(real_score_ms,
                                                      gen_score_ms)
            loss_mp /= len(loss_mp_real)
            loss_ms /= len(loss_ms_real)
            loss_disc = loss_mp + loss_ms

            self.log("loss_discriminator", loss_disc, prog_bar=True)
            self.log("loss_discriminator_ms", loss_ms)
            self.log("loss_discriminator_mp", loss_mp)
            return loss_disc

        # train generator
        elif optimizer_idx == 1:
            audio_pred, splices, log_dur_preds, pitch_preds, energy_preds, encoded_text_mask = self(
                spec=spec,
                spec_len=spec_len,
                text=t,
                text_length=tl,
                durations=durations,
                pitch=pitch if not self.use_pitch_pred else None,
                energies=energies if not self.use_energy_pred else None,
            )
            real_audio = []
            for i, splice in enumerate(splices):
                real_audio.append(
                    f[i, splice * self.hop_size:(splice + self.splice_length) *
                      self.hop_size])
            real_audio = torch.stack(real_audio).unsqueeze(1)

            # Do HiFiGAN generator loss
            audio_length = torch.tensor([
                self.splice_length * self.hop_size
                for _ in range(real_audio.shape[0])
            ]).to(real_audio.device)
            real_spliced_spec, _ = self.melspec_fn(real_audio.squeeze(),
                                                   seq_len=audio_length)
            pred_spliced_spec, _ = self.melspec_fn(audio_pred.squeeze(),
                                                   seq_len=audio_length)
            loss_mel = torch.nn.functional.l1_loss(real_spliced_spec,
                                                   pred_spliced_spec)
            loss_mel *= self.mel_loss_coeff
            _, gen_score_mp, real_feat_mp, gen_feat_mp = self.multiperioddisc(
                real_audio, audio_pred)
            _, gen_score_ms, real_feat_ms, gen_feat_ms = self.multiscaledisc(
                real_audio, audio_pred)
            loss_gen_mp, list_loss_gen_mp = self.gen_loss(gen_score_mp)
            loss_gen_ms, list_loss_gen_ms = self.gen_loss(gen_score_ms)
            loss_gen_mp /= len(list_loss_gen_mp)
            loss_gen_ms /= len(list_loss_gen_ms)
            total_loss = loss_gen_mp + loss_gen_ms + loss_mel
            loss_feat_mp = self.feat_matching_loss(real_feat_mp, gen_feat_mp)
            loss_feat_ms = self.feat_matching_loss(real_feat_ms, gen_feat_ms)
            total_loss += loss_feat_mp + loss_feat_ms
            self.log(name="loss_gen_disc_feat",
                     value=loss_feat_mp + loss_feat_ms)
            self.log(name="loss_gen_disc_feat_ms", value=loss_feat_ms)
            self.log(name="loss_gen_disc_feat_mp", value=loss_feat_mp)

            self.log(name="loss_gen_mel", value=loss_mel)
            self.log(name="loss_gen_disc", value=loss_gen_mp + loss_gen_ms)
            self.log(name="loss_gen_disc_mp", value=loss_gen_mp)
            self.log(name="loss_gen_disc_ms", value=loss_gen_ms)

            dur_loss = self.durationloss(log_duration_pred=log_dur_preds,
                                         duration_target=durations.float(),
                                         mask=encoded_text_mask)
            self.log(name="loss_gen_duration", value=dur_loss)
            total_loss += dur_loss
            if self.pitch:
                pitch_loss = self.mseloss(
                    pitch_preds, pitch.float()) * self.pitch_loss_coeff
                total_loss += pitch_loss
                self.log(name="loss_gen_pitch", value=pitch_loss)
            if self.energy:
                energy_loss = self.mseloss(energy_preds,
                                           energies) * self.energy_loss_coeff
                total_loss += energy_loss
                self.log(name="loss_gen_energy", value=energy_loss)

            # Log images to tensorboard
            if self.log_train_images:
                self.log_train_images = False
                if self.logger is not None and self.logger.experiment is not None:
                    self.tb_logger.add_image(
                        "train_mel_target",
                        plot_spectrogram_to_numpy(
                            real_spliced_spec[0].data.cpu().numpy()),
                        self.global_step,
                        dataformats="HWC",
                    )
                    spec_predict = pred_spliced_spec[0].data.cpu().numpy()
                    self.tb_logger.add_image(
                        "train_mel_predicted",
                        plot_spectrogram_to_numpy(spec_predict),
                        self.global_step,
                        dataformats="HWC",
                    )
            self.log(name="loss_gen", prog_bar=True, value=total_loss)
            return total_loss

    def validation_step(self, batch, batch_idx):
        f, fl, t, tl, _, _, _ = batch
        spec, spec_len = self.audio_to_melspec_precessor(f, fl)
        audio_pred, _, _, _, _, _ = self(spec=spec,
                                         spec_len=spec_len,
                                         text=t,
                                         text_length=tl,
                                         splice=False)
        audio_pred.squeeze_()
        pred_spec, _ = self.melspec_fn(audio_pred, seq_len=spec_len)
        loss = self.mel_val_loss(spec_pred=pred_spec,
                                 spec_target=spec,
                                 spec_target_len=spec_len,
                                 pad_value=-11.52)

        return {
            "val_loss": loss,
            "audio_target": f.squeeze() if batch_idx == 0 else None,
            "audio_pred": audio_pred if batch_idx == 0 else None,
        }

    def on_train_epoch_start(self):
        # Switch to using energy predictions after 50% of training
        if not self.use_energy_pred and self.current_epoch >= np.ceil(
                0.5 * self._trainer.max_epochs):
            logging.info(
                f"Using energy predictions after epoch: {self.current_epoch}")
            self.use_energy_pred = True

        # Switch to using pitch predictions after 62.5% of training
        if not self.use_pitch_pred and self.current_epoch >= np.ceil(
                0.625 * self._trainer.max_epochs):
            logging.info(
                f"Using pitch predictions after epoch: {self.current_epoch}")
            self.use_pitch_pred = True

    def validation_epoch_end(self, outputs):
        if self.tb_logger is not None:
            _, audio_target, audio_predict = outputs[0].values()
            if not self.logged_real_samples:
                self.tb_logger.add_audio("val_target",
                                         audio_target[0].data.cpu(),
                                         self.global_step, self.sample_rate)
                self.logged_real_samples = True
            audio_predict = audio_predict[0].data.cpu()
            self.tb_logger.add_audio("val_pred", audio_predict,
                                     self.global_step, self.sample_rate)
        avg_loss = torch.stack([
            x['val_loss'] for x in outputs
        ]).mean()  # This reduces across batches, not workers!
        self.log('val_loss', avg_loss, sync_dist=True)

        self.log_train_images = True

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="validation")

    def parse(self,
              str_input: str,
              additional_word2phones=None) -> torch.tensor:
        """
        Parses text input and converts it to phoneme indices.

        str_input (str): The input text to be converted.
        additional_word2phones (dict): Optional dictionary mapping words to phonemes for updating the model's
            word2phones.  This will not overwrite the existing dictionary, just update it with OOV or new mappings.
            Defaults to None, which will keep the existing mapping.
        """
        # Update model's word2phones if applicable
        if additional_word2phones is not None:
            self.word2phones.update(additional_word2phones)

        # Convert text -> normalized text -> list of phones per word -> indices
        if str_input[-1] not in [".", "!", "?"]:
            str_input = str_input + "."
        norm_text = re.findall(r"""[\w']+|[.,!?;"]""",
                               self.parser._normalize(str_input))

        try:
            phones = [self.word2phones[t] for t in norm_text]
        except KeyError as error:
            logging.error(
                f"ERROR: The following word in the input is not in the model's dictionary and could not be converted"
                f" to phonemes: ({error}).\n"
                f"You can pass in an `additional_word2phones` dictionary with a conversion for"
                f" this word, e.g. {{'{error}': \['phone1', 'phone2', ...\]}} to update the model's mapping."
            )
            raise

        tokens = []
        for phone_list in phones:
            inds = [self.phone2idx[p] for p in phone_list]
            tokens += inds

        x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
        return x

    def convert_text_to_waveform(self, *, tokens):
        """
        Accepts tokens returned from self.parse() and returns a list of tensors. Note: The tensors in the list can have
        different lengths.
        """
        self.eval()
        token_len = torch.tensor([len(i) for i in tokens]).to(self.device)
        audio, _, log_dur_pred, _, _, _ = self(text=tokens,
                                               text_length=token_len,
                                               splice=False)
        audio = audio.squeeze(1)
        durations = torch.sum(torch.exp(log_dur_pred) - 1, 1).to(torch.int)
        audio_list = []
        for i, sample in enumerate(audio):
            audio_list.append(sample[:durations[i] * self.hop_size])

        return audio_list

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_e2e_fastspeech2hifigan",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_e2e_fastspeech2hifigan/versions/1.0.0/files/tts_en_e2e_fastspeech2hifigan.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)

        return list_of_models
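
A hedged end-to-end usage sketch for the class above, using the "tts_en_e2e_fastspeech2hifigan" name from list_available_models(); the import path assumes the NeMo 1.x TTS collection, and soundfile is used only to write the result to disk.

import soundfile as sf
from nemo.collections.tts.models import FastSpeech2HifiGanE2EModel

model = FastSpeech2HifiGanE2EModel.from_pretrained(model_name="tts_en_e2e_fastspeech2hifigan")
model.eval()

tokens = model.parse("An end to end text to speech example.")
audio_list = model.convert_text_to_waveform(tokens=tokens)
# Tensors in audio_list can have different lengths, so write them one at a time.
sf.write("sample.wav", audio_list[0].cpu().numpy(), samplerate=model.sample_rate)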
Example #19
 def output_types(self):
     return {
         "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
     }
Example #20
class MixerTTSModel(SpectrogramGenerator, Exportable):
    """MixerTTS pipeline."""
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        super().__init__(cfg=cfg, trainer=trainer)
        cfg = self._cfg
        if "text_normalizer" in cfg.train_ds.dataset:
            self.normalizer = instantiate(cfg.train_ds.dataset.text_normalizer)
            self.text_normalizer_call = self.normalizer.normalize
            self.text_normalizer_call_args = {}
            if cfg.train_ds.dataset.get("text_normalizer_call_args",
                                        None) is not None:
                self.text_normalizer_call_args = cfg.train_ds.dataset.text_normalizer_call_args

        self.tokenizer = instantiate(cfg.train_ds.dataset.text_tokenizer)
        num_tokens = len(self.tokenizer.tokens)
        self.tokenizer_pad = self.tokenizer.pad
        self.tokenizer_unk = self.tokenizer.oov

        self.pitch_loss_scale = cfg.pitch_loss_scale
        self.durs_loss_scale = cfg.durs_loss_scale
        self.mel_loss_scale = cfg.mel_loss_scale

        self.aligner = instantiate(cfg.alignment_module)
        self.forward_sum_loss = ForwardSumLoss()
        self.bin_loss = BinLoss()
        self.add_bin_loss = False
        self.bin_loss_scale = 0.0
        self.bin_loss_start_ratio = cfg.bin_loss_start_ratio
        self.bin_loss_warmup_epochs = cfg.bin_loss_warmup_epochs

        self.cond_on_lm_embeddings = cfg.get("cond_on_lm_embeddings", False)

        if self.cond_on_lm_embeddings:
            self.lm_padding_value = (self._train_dl.dataset.lm_padding_value
                                     if self._train_dl is not None else
                                     self._get_lm_padding_value(
                                         cfg.train_ds.dataset.lm_model))
            self.lm_embeddings = self._get_lm_embeddings(
                cfg.train_ds.dataset.lm_model)
            self.lm_embeddings.weight.requires_grad = False

            self.self_attention_module = instantiate(
                cfg.self_attention_module,
                n_lm_tokens_channels=self.lm_embeddings.weight.shape[1])

        self.encoder = instantiate(cfg.encoder,
                                   num_tokens=num_tokens,
                                   padding_idx=self.tokenizer_pad)
        self.symbol_emb = self.encoder.to_embed

        self.duration_predictor = instantiate(cfg.duration_predictor)

        self.pitch_mean, self.pitch_std = float(cfg.pitch_mean), float(
            cfg.pitch_std)
        self.pitch_predictor = instantiate(cfg.pitch_predictor)
        self.pitch_emb = instantiate(cfg.pitch_emb)

        self.preprocessor = instantiate(cfg.preprocessor)

        self.decoder = instantiate(cfg.decoder)
        self.proj = nn.Linear(self.decoder.d_model, cfg.n_mel_channels)

    def _get_lm_model_tokenizer(self, lm_model="albert"):
        if getattr(self, "_lm_model_tokenizer", None) is not None:
            return self._lm_model_tokenizer

        if self._train_dl is not None and self._train_dl.dataset is not None:
            self._lm_model_tokenizer = self._train_dl.dataset.lm_model_tokenizer

        if lm_model == "albert":
            self._lm_model_tokenizer = AlbertTokenizer.from_pretrained(
                'albert-base-v2')
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

        return self._lm_model_tokenizer

    def _get_lm_embeddings(self, lm_model="albert"):
        if lm_model == "albert":
            return transformers.AlbertModel.from_pretrained(
                'albert-base-v2').embeddings.word_embeddings
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

    def _get_lm_padding_value(self, lm_model="albert"):
        if lm_model == "albert":
            return transformers.AlbertTokenizer.from_pretrained(
                'albert-base-v2')._convert_token_to_id('<pad>')
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

    def _metrics(
        self,
        true_durs,
        true_text_len,
        pred_durs,
        true_pitch,
        pred_pitch,
        true_spect=None,
        pred_spect=None,
        true_spect_len=None,
        attn_logprob=None,
        attn_soft=None,
        attn_hard=None,
        attn_hard_dur=None,
    ):
        text_mask = get_mask_from_lengths(true_text_len)
        mel_mask = get_mask_from_lengths(true_spect_len)
        loss = 0.0

        # dur loss and metrics
        durs_loss = F.mse_loss(pred_durs, (true_durs + 1).float().log(),
                               reduction='none')
        durs_loss = durs_loss * text_mask.float()
        durs_loss = durs_loss.sum() / text_mask.sum()

        durs_pred = pred_durs.exp() - 1
        durs_pred = torch.clamp_min(durs_pred, min=0)
        durs_pred = durs_pred.round().long()

        acc = ((true_durs == durs_pred) *
               text_mask).sum().float() / text_mask.sum() * 100
        acc_dist_1 = (((true_durs - durs_pred).abs() <= 1) *
                      text_mask).sum().float() / text_mask.sum() * 100
        acc_dist_3 = (((true_durs - durs_pred).abs() <= 3) *
                      text_mask).sum().float() / text_mask.sum() * 100

        pred_spect = pred_spect.transpose(1, 2)

        # mel loss
        mel_loss = F.mse_loss(pred_spect, true_spect,
                              reduction='none').mean(dim=-2)
        mel_loss = mel_loss * mel_mask.float()
        mel_loss = mel_loss.sum() / mel_mask.sum()

        loss = loss + self.durs_loss_scale * durs_loss + self.mel_loss_scale * mel_loss

        # aligner loss
        bin_loss, ctc_loss = None, None
        ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                         in_lens=true_text_len,
                                         out_lens=true_spect_len)
        loss = loss + ctc_loss
        if self.add_bin_loss:
            bin_loss = self.bin_loss(hard_attention=attn_hard,
                                     soft_attention=attn_soft)
            loss = loss + self.bin_loss_scale * bin_loss
        true_avg_pitch = average_pitch(true_pitch.unsqueeze(1),
                                       attn_hard_dur).squeeze(1)

        # pitch loss
        pitch_loss = F.mse_loss(pred_pitch, true_avg_pitch,
                                reduction='none')  # noqa
        pitch_loss = (pitch_loss * text_mask).sum() / text_mask.sum()

        loss = loss + self.pitch_loss_scale * pitch_loss

        return loss, durs_loss, acc, acc_dist_1, acc_dist_3, pitch_loss, mel_loss, ctc_loss, bin_loss

    @torch.jit.unused
    def run_aligner(self, text, text_len, text_mask, spect, spect_len,
                    attn_prior):
        text_emb = self.symbol_emb(text)
        attn_soft, attn_logprob = self.aligner(
            spect,
            text_emb.permute(0, 2, 1),
            mask=text_mask == 0,
            attn_prior=attn_prior,
        )
        attn_hard = binarize_attention_parallel(attn_soft, text_len, spect_len)
        attn_hard_dur = attn_hard.sum(2)[:, 0, :]
        assert torch.all(torch.eq(attn_hard_dur.sum(dim=1), spect_len))
        return attn_soft, attn_logprob, attn_hard, attn_hard_dur

    @typecheck(
        input_types={
            "text":
            NeuralType(('B', 'T_text'), TokenIndex()),
            "text_len":
            NeuralType(('B', ), LengthsType()),
            "pitch":
            NeuralType(('B', 'T_audio'), RegressionValuesType(),
                       optional=True),
            "spect":
            NeuralType(('B', 'D', 'T_spec'),
                       MelSpectrogramType(),
                       optional=True),
            "spect_len":
            NeuralType(('B', ), LengthsType(), optional=True),
            "attn_prior":
            NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
            "lm_tokens":
            NeuralType(('B', 'T_lm_tokens'), TokenIndex(), optional=True),
        },
        output_types={
            "pred_spect":
            NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
            "durs_predicted":
            NeuralType(('B', 'T_text'), TokenDurationType()),
            "log_durs_predicted":
            NeuralType(('B', 'T_text'), TokenLogDurationType()),
            "pitch_predicted":
            NeuralType(('B', 'T_text'), RegressionValuesType()),
            "attn_soft":
            NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
            "attn_logprob":
            NeuralType(('B', 'S', 'T_spec', 'T_text'), LogprobsType()),
            "attn_hard":
            NeuralType(('B', 'S', 'T_spec', 'T_text'), ProbsType()),
            "attn_hard_dur":
            NeuralType(('B', 'T_text'), TokenDurationType()),
        },
    )
    def forward(self,
                text,
                text_len,
                pitch=None,
                spect=None,
                spect_len=None,
                attn_prior=None,
                lm_tokens=None):
        if self.training:
            assert pitch is not None

        text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # aligner
        attn_soft, attn_logprob, attn_hard, attn_hard_dur = None, None, None, None
        if spect is not None:
            attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
                text, text_len, text_mask, spect, spect_len, attn_prior)

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out,
                lm_emb,
                lm_emb,
                q_mask=enc_mask.squeeze(2),
                kv_mask=lm_tokens != self.lm_padding_value)

        # duration predictor
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # pitch predictor
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)

        # avg pitch, add pitch_emb
        if not self.training:
            if pitch is not None:
                pitch = average_pitch(pitch.unsqueeze(1),
                                      attn_hard_dur).squeeze(1)
                pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
            else:
                pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
        else:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))

        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        # regulate length
        len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)

        dec_out, dec_lens = self.decoder(
            len_regulated_enc_out,
            get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return (
            pred_spect,
            durs_predicted,
            log_durs_predicted,
            pitch_predicted,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        )

    def infer(
        self,
        text,
        text_len=None,
        text_mask=None,
        spect=None,
        spect_len=None,
        attn_prior=None,
        use_gt_durs=False,
        lm_tokens=None,
        pitch=None,
    ):
        if text_mask is None:
            text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # aligner
        attn_hard_dur = None
        if use_gt_durs:
            attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
                text, text_len, text_mask, spect, spect_len, attn_prior)

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out,
                lm_emb,
                lm_emb,
                q_mask=enc_mask.squeeze(2),
                kv_mask=lm_tokens != self.lm_padding_value)

        # duration predictor
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # avg pitch, pitch predictor
        if use_gt_durs and pitch is not None:
            pitch = average_pitch(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

        # add pitch emb
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        if use_gt_durs:
            if attn_hard_dur is not None:
                len_regulated_enc_out, dec_lens = regulate_len(
                    attn_hard_dur, enc_out)
            else:
                raise NotImplementedError
        else:
            len_regulated_enc_out, dec_lens = regulate_len(
                durs_predicted, enc_out)

        dec_out, _ = self.decoder(len_regulated_enc_out,
                                  get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return pred_spect

    def on_train_epoch_start(self):
        bin_loss_start_epoch = np.ceil(self.bin_loss_start_ratio *
                                       self._trainer.max_epochs)

        # Add bin loss when current_epoch >= bin_start_epoch
        if not self.add_bin_loss and self.current_epoch >= bin_loss_start_epoch:
            logging.info(
                f"Using hard attentions after epoch: {self.current_epoch}")
            self.add_bin_loss = True

        if self.add_bin_loss:
            self.bin_loss_scale = min(
                (self.current_epoch - bin_loss_start_epoch) /
                self.bin_loss_warmup_epochs, 1.0)

    def training_step(self, batch, batch_idx):
        attn_prior, lm_tokens = None, None
        if self.cond_on_lm_embeddings:
            audio, audio_len, text, text_len, attn_prior, pitch, _, lm_tokens = batch
        else:
            audio, audio_len, text, text_len, attn_prior, pitch, _ = batch

        spect, spect_len = self.preprocessor(input_signal=audio,
                                             length=audio_len)

        # pitch normalization
        zero_pitch_idx = pitch == 0
        pitch = (pitch - self.pitch_mean) / self.pitch_std
        pitch[zero_pitch_idx] = 0.0

        (
            pred_spect,
            _,
            pred_log_durs,
            pred_pitch,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        ) = self(
            text=text,
            text_len=text_len,
            pitch=pitch,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        (
            loss,
            durs_loss,
            acc,
            acc_dist_1,
            acc_dist_3,
            pitch_loss,
            mel_loss,
            ctc_loss,
            bin_loss,
        ) = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        train_log = {
            'train_loss':
            loss,
            'train_durs_loss':
            durs_loss,
            'train_pitch_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if pitch_loss is None else pitch_loss,
            'train_mel_loss':
            mel_loss,
            'train_durs_acc':
            acc,
            'train_durs_acc_dist_3':
            acc_dist_3,
            'train_ctc_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if ctc_loss is None else ctc_loss,
            'train_bin_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if bin_loss is None else bin_loss,
        }

        return {'loss': loss, 'progress_bar': train_log, 'log': train_log}

    def validation_step(self, batch, batch_idx):
        attn_prior, lm_tokens = None, None
        if self.cond_on_lm_embeddings:
            audio, audio_len, text, text_len, attn_prior, pitch, _, lm_tokens = batch
        else:
            audio, audio_len, text, text_len, attn_prior, pitch, _ = batch

        spect, spect_len = self.preprocessor(input_signal=audio,
                                             length=audio_len)

        # pitch normalization
        zero_pitch_idx = pitch == 0
        pitch = (pitch - self.pitch_mean) / self.pitch_std
        pitch[zero_pitch_idx] = 0.0

        (
            pred_spect,
            _,
            pred_log_durs,
            pred_pitch,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        ) = self(
            text=text,
            text_len=text_len,
            pitch=pitch,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        (
            loss,
            durs_loss,
            acc,
            acc_dist_1,
            acc_dist_3,
            pitch_loss,
            mel_loss,
            ctc_loss,
            bin_loss,
        ) = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        # Second pass without ground-truth internal features (except durations): measure mel loss with predicted pitch
        pred_spect, _, pred_log_durs, pred_pitch, attn_soft, attn_logprob, attn_hard, attn_hard_dur = self(
            text=text,
            text_len=text_len,
            pitch=None,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        *_, with_pred_features_mel_loss, _, _ = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        val_log = {
            'val_loss':
            loss,
            'val_durs_loss':
            durs_loss,
            'val_pitch_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if pitch_loss is None else pitch_loss,
            'val_mel_loss':
            mel_loss,
            'val_with_pred_features_mel_loss':
            with_pred_features_mel_loss,
            'val_durs_acc':
            acc,
            'val_durs_acc_dist_3':
            acc_dist_3,
            'val_ctc_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if ctc_loss is None else ctc_loss,
            'val_bin_loss':
            torch.tensor(1.0).to(durs_loss.device)
            if bin_loss is None else bin_loss,
        }
        self.log_dict(val_log,
                      prog_bar=False,
                      on_epoch=True,
                      logger=True,
                      sync_dist=True)

        if batch_idx == 0 and self.current_epoch % 5 == 0 and isinstance(
                self.logger, WandbLogger):
            specs = []
            pitches = []
            for i in range(min(3, spect.shape[0])):
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(
                            spect[i, :, :spect_len[i]].data.cpu().numpy()),
                        caption=f"gt mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(
                            pred_spect.transpose(
                                1, 2)[i, :, :spect_len[i]].data.cpu().numpy()),
                        caption=f"pred mel {i}",
                    ),
                ]

                pitches += [
                    wandb.Image(
                        plot_pitch_to_numpy(
                            average_pitch(pitch.unsqueeze(1),
                                          attn_hard_dur).squeeze(1)
                            [i, :text_len[i]].data.cpu().numpy(),
                            ylim_range=[-2.5, 2.5],
                        ),
                        caption=f"gt pitch {i}",
                    ),
                ]

                pitches += [
                    wandb.Image(
                        plot_pitch_to_numpy(
                            pred_pitch[i, :text_len[i]].data.cpu().numpy(),
                            ylim_range=[-2.5, 2.5]),
                        caption=f"pred pitch {i}",
                    ),
                ]

            self.logger.experiment.log({"specs": specs, "pitches": pitches})

    @typecheck(
        input_types={
            "tokens":
            NeuralType(('B', 'T_text'), TokenIndex(), optional=True),
            "tokens_len":
            NeuralType(('B'), LengthsType(), optional=True),
            "lm_tokens":
            NeuralType(('B', 'T_lm_tokens'), TokenIndex(), optional=True),
            "raw_texts": [NeuralType(optional=True)],
            "lm_model":
            NeuralType(optional=True),
        },
        output_types={
            "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
        },
    )
    def generate_spectrogram(
        self,
        tokens: Optional[torch.Tensor] = None,
        tokens_len: Optional[torch.Tensor] = None,
        lm_tokens: Optional[torch.Tensor] = None,
        raw_texts: Optional[List[str]] = None,
        lm_model: str = "albert",
    ):
        if tokens is not None:
            if tokens_len is None:
                # it is assumed that padding is consecutive and only at the end
                tokens_len = (tokens != self.tokenizer.pad).sum(dim=-1)
        else:
            if raw_texts is None:
                raise ValueError("raw_texts must be specified if tokens is None")

            t_seqs = [self.tokenizer(t) for t in raw_texts]
            tokens = torch.nn.utils.rnn.pad_sequence(
                sequences=[
                    torch.tensor(t, dtype=torch.long, device=self.device)
                    for t in t_seqs
                ],
                batch_first=True,
                padding_value=self.tokenizer.pad,
            )
            tokens_len = torch.tensor([len(t) for t in t_seqs],
                                      dtype=torch.long,
                                      device=tokens.device)

        if self.cond_on_lm_embeddings and lm_tokens is None:
            if raw_texts is None:
                raise ValueError(
                    "raw_texts must be specified if lm_tokens is None")

            lm_model_tokenizer = self._get_lm_model_tokenizer(lm_model)
            lm_padding_value = lm_model_tokenizer._convert_token_to_id('<pad>')
            lm_space_value = lm_model_tokenizer._convert_token_to_id('▁')

            assert isinstance(
                self.tokenizer,
                (EnglishCharsTokenizer, EnglishPhonemesTokenizer))

            preprocess_texts_as_tts_input = [
                self.tokenizer.text_preprocessing_func(t) for t in raw_texts
            ]
            lm_tokens_as_ids_list = [
                lm_model_tokenizer.encode(t, add_special_tokens=False)
                for t in preprocess_texts_as_tts_input
            ]

            if self.tokenizer.pad_with_space:
                lm_tokens_as_ids_list = [[lm_space_value] + t +
                                         [lm_space_value]
                                         for t in lm_tokens_as_ids_list]

            # pad LM token ids into a (B, max_len) tensor filled with the LM pad id
            lm_tokens = torch.full(
                (len(lm_tokens_as_ids_list),
                 max([len(t) for t in lm_tokens_as_ids_list])),
                fill_value=lm_padding_value,
                device=tokens.device,
            )
            for i, lm_tokens_i in enumerate(lm_tokens_as_ids_list):
                lm_tokens[i, :len(lm_tokens_i)] = torch.tensor(
                    lm_tokens_i, device=tokens.device)

        pred_spect = self.infer(tokens, tokens_len,
                                lm_tokens=lm_tokens).transpose(1, 2)
        return pred_spect

    def parse(self, text: str, normalize=True) -> torch.Tensor:
        if normalize and getattr(self, "text_normalizer_call",
                                 None) is not None:
            text = self.text_normalizer_call(text,
                                             **self.text_normalizer_call_args)
        return torch.tensor(
            self.tokenizer.encode(text)).long().unsqueeze(0).to(self.device)

    @staticmethod
    def _loader(cfg):
        try:
            _ = cfg.dataset.manifest_filepath
        except omegaconf.errors.MissingMandatoryValue:
            logging.warning(
                "manifest_filepath was skipped. No dataset for this model.")
            return None

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(  # noqa
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            **cfg.dataloader_params,
        )

    def setup_training_data(self, cfg):
        self._train_dl = self._loader(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self._loader(cfg)

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls):
        """Empty."""
        pass

    @property
    def input_types(self):
        return {
            "text":
            NeuralType(('B', 'T_text'), TokenIndex()),
            "lm_tokens":
            NeuralType(('B', 'T_lm_tokens'), TokenIndex(), optional=True),
        }

    @property
    def output_types(self):
        return {
            "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
        }

    def forward_for_export(self, text, lm_tokens=None):
        text_mask = (text != self.tokenizer_pad).unsqueeze(2)
        spect = self.infer(text=text, text_mask=text_mask,
                           lm_tokens=lm_tokens).transpose(1, 2)
        return spect.to(torch.float)
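
A minimal, hedged usage sketch for the spectrogram generator above. The class name, import path, and checkpoint filename are assumptions used only for illustration; the snippet shows how generate_spectrogram() is intended to be called with raw texts.

import torch
from nemo.collections.tts.models import MixerTTSModel  # assumed class/module name

model = MixerTTSModel.restore_from("mixer_tts.nemo")  # placeholder checkpoint path
model.eval()
with torch.no_grad():
    # raw texts are normalized and tokenized internally; pre-tokenized tensors
    # may be passed via tokens/tokens_len instead
    spect = model.generate_spectrogram(raw_texts=["Hello world."])
print(spect.shape)  # (batch, n_mel_channels, T_spec)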
Example No. 21
class FastPitchModel(SpectrogramGenerator, Exportable):
    """FastPitch model (https://arxiv.org/abs/2006.06873) that is used to generate mel spectrogram from text."""
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        # Setup normalizer
        self.normalizer = None
        self.text_normalizer_call = None
        self.text_normalizer_call_kwargs = {}
        self._setup_normalizer(cfg)

        self.learn_alignment = cfg.get("learn_alignment", False)

        # Setup vocabulary (=tokenizer) and input_fft_kwargs (supported only with self.learn_alignment=True)
        input_fft_kwargs = {}
        if self.learn_alignment:
            self.vocab = None
            self.ds_class_name = cfg.train_ds.dataset._target_.split(".")[-1]

            if self.ds_class_name == "TTSDataset":
                self._setup_tokenizer(cfg)
                assert self.vocab is not None
                input_fft_kwargs["n_embed"] = len(self.vocab.tokens)
                input_fft_kwargs["padding_idx"] = self.vocab.pad
            elif self.ds_class_name == "AudioToCharWithPriorAndPitchDataset":
                logging.warning(
                    "AudioToCharWithPriorAndPitchDataset class has been deprecated. No support for"
                    " training or finetuning. Only inference is supported.")
                tokenizer_conf = self._get_default_text_tokenizer_conf()
                self._setup_tokenizer(tokenizer_conf)
                assert self.vocab is not None
                input_fft_kwargs["n_embed"] = len(self.vocab.tokens)
                input_fft_kwargs["padding_idx"] = self.vocab.pad
            else:
                raise ValueError(
                    f"Unknown dataset class: {self.ds_class_name}")

        self._parser = None
        self._tb_logger = None
        super().__init__(cfg=cfg, trainer=trainer)

        self.bin_loss_warmup_epochs = cfg.get("bin_loss_warmup_epochs", 100)
        self.log_train_images = False

        loss_scale = 0.1 if self.learn_alignment else 1.0
        dur_loss_scale = loss_scale
        pitch_loss_scale = loss_scale
        if "dur_loss_scale" in cfg:
            dur_loss_scale = cfg.dur_loss_scale
        if "pitch_loss_scale" in cfg:
            pitch_loss_scale = cfg.pitch_loss_scale

        self.mel_loss = MelLoss()
        self.pitch_loss = PitchLoss(loss_scale=pitch_loss_scale)
        self.duration_loss = DurationLoss(loss_scale=dur_loss_scale)

        self.aligner = None
        if self.learn_alignment:
            self.aligner = instantiate(self._cfg.alignment_module)
            self.forward_sum_loss = ForwardSumLoss()
            self.bin_loss = BinLoss()

        self.preprocessor = instantiate(self._cfg.preprocessor)
        input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
        output_fft = instantiate(self._cfg.output_fft)
        duration_predictor = instantiate(self._cfg.duration_predictor)
        pitch_predictor = instantiate(self._cfg.pitch_predictor)

        self.fastpitch = FastPitchModule(
            input_fft,
            output_fft,
            duration_predictor,
            pitch_predictor,
            self.aligner,
            cfg.n_speakers,
            cfg.symbols_embedding_dim,
            cfg.pitch_embedding_kernel_size,
            cfg.n_mel_channels,
        )
        self._input_types = self._output_types = None

    def _get_default_text_tokenizer_conf(self):
        text_tokenizer: TextTokenizerConfig = TextTokenizerConfig()
        return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer))

    def _setup_normalizer(self, cfg):
        if "text_normalizer" in cfg:
            normalizer_kwargs = {}

            if "whitelist" in cfg.text_normalizer:
                normalizer_kwargs["whitelist"] = self.register_artifact(
                    'text_normalizer.whitelist', cfg.text_normalizer.whitelist)

            self.normalizer = instantiate(cfg.text_normalizer,
                                          **normalizer_kwargs)
            self.text_normalizer_call = self.normalizer.normalize
            if "text_normalizer_call_kwargs" in cfg:
                self.text_normalizer_call_kwargs = cfg.text_normalizer_call_kwargs

    def _setup_tokenizer(self, cfg):
        text_tokenizer_kwargs = {}
        if "g2p" in cfg.text_tokenizer:
            g2p_kwargs = {}

            if "phoneme_dict" in cfg.text_tokenizer.g2p:
                g2p_kwargs["phoneme_dict"] = self.register_artifact(
                    'text_tokenizer.g2p.phoneme_dict',
                    cfg.text_tokenizer.g2p.phoneme_dict,
                )

            if "heteronyms" in cfg.text_tokenizer.g2p:
                g2p_kwargs["heteronyms"] = self.register_artifact(
                    'text_tokenizer.g2p.heteronyms',
                    cfg.text_tokenizer.g2p.heteronyms,
                )

            text_tokenizer_kwargs["g2p"] = instantiate(cfg.text_tokenizer.g2p,
                                                       **g2p_kwargs)

        self.vocab = instantiate(cfg.text_tokenizer, **text_tokenizer_kwargs)

    @property
    def tb_logger(self):
        if self._tb_logger is None:
            if self.logger is None or self.logger.experiment is None:
                return None
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            self._tb_logger = tb_logger
        return self._tb_logger

    @property
    def parser(self):
        if self._parser is not None:
            return self._parser

        if self.learn_alignment:
            ds_class_name = self._cfg.train_ds.dataset._target_.split(".")[-1]

            if ds_class_name == "TTSDataset":
                self._parser = self.vocab.encode
            elif ds_class_name == "AudioToCharWithPriorAndPitchDataset":
                if self.vocab is None:
                    tokenizer_conf = self._get_default_text_tokenizer_conf()
                    self._setup_tokenizer(tokenizer_conf)
                self._parser = self.vocab.encode
            else:
                raise ValueError(f"Unknown dataset class: {ds_class_name}")
        else:
            self._parser = parsers.make_parser(
                labels=self._cfg.labels,
                name='en',
                unk_id=-1,
                blank_id=-1,
                do_normalize=True,
                abbreviation_version="fastpitch",
                make_table=False,
            )
        return self._parser

    def parse(self, str_input: str, normalize=True) -> torch.tensor:
        if self.training:
            logging.warning("parse() is meant to be called in eval mode.")

        if normalize and self.text_normalizer_call is not None:
            str_input = self.text_normalizer_call(
                str_input, **self.text_normalizer_call_kwargs)

        if self.learn_alignment:
            eval_phon_mode = contextlib.nullcontext()
            if hasattr(self.vocab, "set_phone_prob"):
                eval_phon_mode = self.vocab.set_phone_prob(prob=1.0)

            # Disable mixed g2p representation if necessary
            with eval_phon_mode:
                tokens = self.parser(str_input)
        else:
            tokens = self.parser(str_input)

        x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
        return x

    @typecheck(
        input_types={
            "text":
            NeuralType(('B', 'T_text'), TokenIndex()),
            "durs":
            NeuralType(('B', 'T_text'), TokenDurationType()),
            "pitch":
            NeuralType(('B', 'T_audio'), RegressionValuesType()),
            "speaker":
            NeuralType(('B'), Index(), optional=True),
            "pace":
            NeuralType(optional=True),
            "spec":
            NeuralType(('B', 'D', 'T_spec'),
                       MelSpectrogramType(),
                       optional=True),
            "attn_prior":
            NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
            "mel_lens":
            NeuralType(('B'), LengthsType(), optional=True),
            "input_lens":
            NeuralType(('B'), LengthsType(), optional=True),
        })
    def forward(
        self,
        *,
        text,
        durs=None,
        pitch=None,
        speaker=None,
        pace=1.0,
        spec=None,
        attn_prior=None,
        mel_lens=None,
        input_lens=None,
    ):
        return self.fastpitch(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=pace,
            spec=spec,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=input_lens,
        )

    @typecheck(output_types={
        "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType())
    })
    def generate_spectrogram(self,
                             tokens: 'torch.tensor',
                             speaker: Optional[int] = None,
                             pace: float = 1.0) -> torch.tensor:
        if self.training:
            logging.warning(
                "generate_spectrogram() is meant to be called in eval mode.")
        if isinstance(speaker, int):
            speaker = torch.tensor([speaker]).to(self.device)
        spect, *_ = self(text=tokens,
                         durs=None,
                         pitch=None,
                         speaker=speaker,
                         pace=pace)
        return spect

    def training_step(self, batch, batch_idx):
        attn_prior, durs, speaker = None, None, None
        if self.learn_alignment:
            if self.ds_class_name == "TTSDataset":
                if SpeakerID in self._train_dl.dataset.sup_data_types_set:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
                else:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
            else:
                raise ValueError(
                    f"Unknown dataset class: {self.ds_class_name}")
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

        mels, spec_len = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        mels_pred, _, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=spec_len,
            input_lens=text_lens,
        )
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        loss = mel_loss + dur_loss
        if self.learn_alignment:
            ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                             in_lens=text_lens,
                                             out_lens=spec_len)
            # linearly ramp the binarization loss weight over the warmup epochs
            bin_loss_weight = min(
                self.current_epoch / self.bin_loss_warmup_epochs, 1.0)
            bin_loss = self.bin_loss(
                hard_attention=attn_hard,
                soft_attention=attn_soft) * bin_loss_weight
            loss += ctc_loss + bin_loss

        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss += pitch_loss

        self.log("t_loss", loss)
        self.log("t_mel_loss", mel_loss)
        self.log("t_dur_loss", dur_loss)
        self.log("t_pitch_loss", pitch_loss)
        if self.learn_alignment:
            self.log("t_ctc_loss", ctc_loss)
            self.log("t_bin_loss", bin_loss)

        # Log images to tensorboard
        if self.log_train_images and isinstance(self.logger,
                                                TensorBoardLogger):
            self.log_train_images = False

            self.tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = mels_pred[0].data.cpu().float().numpy()
            self.tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            if self.learn_alignment:
                attn = attn_hard[0].data.cpu().float().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_attn",
                    plot_alignment_to_numpy(attn.T),
                    self.global_step,
                    dataformats="HWC",
                )
                soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_soft_attn",
                    plot_alignment_to_numpy(soft_attn.T),
                    self.global_step,
                    dataformats="HWC",
                )

        return loss

    def validation_step(self, batch, batch_idx):
        attn_prior, durs, speaker = None, None, None
        if self.learn_alignment:
            if self.ds_class_name == "TTSDataset":
                if SpeakerID in self._train_dl.dataset.sup_data_types_set:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _, speaker = batch
                else:
                    audio, audio_lens, text, text_lens, attn_prior, pitch, _ = batch
            else:
                raise ValueError(
                    f"Unknown dataset class: {self.ds_class_name}")
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

        mels, mel_lens = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        # Calculate val loss on ground truth durations to better align L2 loss in time
        mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=text_lens,
        )
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss = mel_loss + dur_loss + pitch_loss

        return {
            "val_loss": loss,
            "mel_loss": mel_loss,
            "dur_loss": dur_loss,
            "pitch_loss": pitch_loss,
            "mel_target": mels if batch_idx == 0 else None,
            "mel_pred": mels_pred if batch_idx == 0 else None,
        }

    def validation_epoch_end(self, outputs):
        collect = lambda key: torch.stack([x[key] for x in outputs]).mean()
        val_loss = collect("val_loss")
        mel_loss = collect("mel_loss")
        dur_loss = collect("dur_loss")
        pitch_loss = collect("pitch_loss")
        self.log("v_loss", val_loss)
        self.log("v_mel_loss", mel_loss)
        self.log("v_dur_loss", dur_loss)
        self.log("v_pitch_loss", pitch_loss)

        spec_target = outputs[0]["mel_target"]
        spec_predict = outputs[0]["mel_pred"]

        if isinstance(self.logger, TensorBoardLogger):
            self.tb_logger.add_image(
                "val_mel_target",
                plot_spectrogram_to_numpy(
                    spec_target[0].data.cpu().float().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = spec_predict[0].data.cpu().float().numpy()
            self.tb_logger.add_image(
                "val_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            self.log_train_images = True

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloader_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        if cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
            phon_mode = contextlib.nullcontext()
            if hasattr(self.vocab, "set_phone_prob"):
                phon_mode = self.vocab.set_phone_prob(
                    prob=None if name == "val" else self.vocab.phoneme_probability)

            with phon_mode:
                dataset = instantiate(
                    cfg.dataset,
                    text_normalizer=self.normalizer,
                    text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
                    text_tokenizer=self.vocab,
                )
        else:
            dataset = instantiate(cfg.dataset)

        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="val")

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_fastpitch",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/1.8.1/files/tts_en_fastpitch_align.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz and can be used to generate female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)

        return list_of_models

    # Methods for model exportability
    def _prepare_for_export(self, **kwargs):
        super()._prepare_for_export(**kwargs)

        # Define input_types and output_types as required by export()
        self._input_types = {
            "text": NeuralType(('B', 'T_text'), TokenIndex()),
            "pitch": NeuralType(('B', 'T_text'), RegressionValuesType()),
            "pace": NeuralType(('B', 'T_text'), optional=True),
            "volume": NeuralType(('B', 'T_text')),
            "speaker": NeuralType(('B'), Index()),
        }
        self._output_types = {
            "spect":
            NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
            "num_frames":
            NeuralType(('B'), TokenDurationType()),
            "durs_predicted":
            NeuralType(('B', 'T_text'), TokenDurationType()),
            "log_durs_predicted":
            NeuralType(('B', 'T_text'), TokenLogDurationType()),
            "pitch_predicted":
            NeuralType(('B', 'T_text'), RegressionValuesType()),
            "volume_aligned":
            NeuralType(('B', 'T_spec'), RegressionValuesType()),
        }

    def _export_teardown(self):
        self._input_types = self._output_types = None

    @property
    def disabled_deployment_input_names(self):
        """Implement this method to return a set of input names disabled for export"""
        disabled_inputs = set()
        if self.fastpitch.speaker_emb is None:
            disabled_inputs.add("speaker")
        return disabled_inputs

    @property
    def input_types(self):
        return self._input_types

    @property
    def output_types(self):
        return self._output_types

    def input_example(self, max_batch=1, max_dim=44):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        par = next(self.fastpitch.parameters())
        sz = (max_batch, max_dim)
        inp = torch.randint(0,
                            self.fastpitch.encoder.word_emb.num_embeddings,
                            sz,
                            device=par.device,
                            dtype=torch.int64)
        pitch = torch.randn(sz, device=par.device, dtype=torch.float32) * 0.5
        pace = torch.clamp(
            (torch.randn(sz, device=par.device, dtype=torch.float32) + 1) *
            0.1,
            min=0.01)
        volume = torch.clamp(
            (torch.randn(sz, device=par.device, dtype=torch.float32) + 1) *
            0.1,
            min=0.01)

        inputs = {'text': inp, 'pitch': pitch, 'pace': pace, 'volume': volume}

        if self.fastpitch.speaker_emb is not None:
            inputs['speaker'] = torch.randint(
                0,
                self.fastpitch.speaker_emb.num_embeddings, (max_batch, ),
                device=par.device,
                dtype=torch.int64)

        return (inputs, )

    def forward_for_export(self, text, pitch, pace, volume, speaker=None):
        return self.fastpitch.infer(text=text,
                                    pitch=pitch,
                                    pace=pace,
                                    volume=volume,
                                    speaker=speaker)
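
A brief, hedged usage sketch for the FastPitchModel above: it restores the checkpoint advertised by list_available_models() and runs text through parse() and generate_spectrogram(). The from_pretrained() call and import path follow the usual NeMo conventions and are assumptions here, not part of the original listing.

import torch
from nemo.collections.tts.models import FastPitchModel  # assumed import path

fastpitch = FastPitchModel.from_pretrained("tts_en_fastpitch")
fastpitch.eval()
with torch.no_grad():
    tokens = fastpitch.parse("Hello world.")                # (1, T_text) token ids
    spect = fastpitch.generate_spectrogram(tokens=tokens)   # (1, n_mel_channels, T_spec)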
Example No. 22
class HifiGanModel(Vocoder, Exportable):
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # use a different melspec extractor because:
        # 1. we need to pass grads
        # 2. we need to remove the fmax limitation
        self.trg_melspec_fn = instantiate(cfg.preprocessor,
                                          highfreq=None,
                                          use_grads=True)
        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()
        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        # weight for the auxiliary L1 mel-spectrogram loss (the HiFi-GAN paper uses 45)
        self.l1_factor = cfg.get("l1_loss_factor", 45)

        self.sample_rate = self._cfg.preprocessor.sample_rate
        self.stft_bias = None

        if self._train_dl and isinstance(self._train_dl.dataset,
                                         MelAudioDataset):
            self.input_as_mel = True
        else:
            self.input_as_mel = False

        self.automatic_optimization = False

    def configure_optimizers(self):
        self.optim_g = instantiate(
            self._cfg.optim,
            params=self.generator.parameters(),
        )
        self.optim_d = instantiate(
            self._cfg.optim,
            params=itertools.chain(self.msd.parameters(),
                                   self.mpd.parameters()),
        )

        self.scheduler_g = CosineAnnealing(
            optimizer=self.optim_g,
            max_steps=self._cfg.max_steps,
            min_lr=self._cfg.sched.min_lr,
            warmup_steps=self._cfg.sched.warmup_ratio * self._cfg.max_steps,
        )  # Use warmup to delay start
        sch1_dict = {
            'scheduler': self.scheduler_g,
            'interval': 'step',
        }

        self.scheduler_d = CosineAnnealing(
            optimizer=self.optim_d,
            max_steps=self._cfg.max_steps,
            min_lr=self._cfg.sched.min_lr,
        )
        sch2_dict = {
            'scheduler': self.scheduler_d,
            'interval': 'step',
        }

        return [self.optim_g, self.optim_d], [sch1_dict, sch2_dict]

    @property
    def input_types(self):
        return {
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        }

    @property
    def output_types(self):
        return {
            "audio": NeuralType(('B', 'S', 'T'),
                                AudioSignal(self.sample_rate)),
        }

    @typecheck()
    def forward(self, *, spec):
        """
        Runs the generator; for inputs and outputs see input_types and output_types.
        """
        return self.generator(x=spec)

    def forward_for_export(self, spec):
        return self.generator(x=spec)

    @typecheck(
        input_types={
            "spec": NeuralType(('B', 'C', 'T'), MelSpectrogramType())
        },
        output_types={"audio": NeuralType(('B', 'T'), AudioSignal())},
    )
    def convert_spectrogram_to_audio(self,
                                     spec: 'torch.tensor') -> 'torch.tensor':
        return self(spec=spec).squeeze(1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # if in finetune mode the mels are pre-computed using a
        # spectrogram generator
        if self.input_as_mel:
            audio, audio_len, audio_mel = batch
        # else, we compute the mel using the ground truth audio
        else:
            audio, audio_len = batch
            # mel as input for generator
            audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)

        # mel as input for L1 mel loss
        audio_trg_mel, _ = self.trg_melspec_fn(audio, audio_len)
        audio = audio.unsqueeze(1)

        audio_pred = self.generator(x=audio_mel)
        audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1),
                                                audio_len)

        # train discriminator
        self.optim_d.zero_grad()
        mpd_score_real, mpd_score_gen, _, _ = self.mpd(
            y=audio, y_hat=audio_pred.detach())
        loss_disc_mpd, _, _ = self.discriminator_loss(
            disc_real_outputs=mpd_score_real,
            disc_generated_outputs=mpd_score_gen)
        msd_score_real, msd_score_gen, _, _ = self.msd(
            y=audio, y_hat=audio_pred.detach())
        loss_disc_msd, _, _ = self.discriminator_loss(
            disc_real_outputs=msd_score_real,
            disc_generated_outputs=msd_score_gen)
        loss_d = loss_disc_msd + loss_disc_mpd
        self.manual_backward(loss_d)
        self.optim_d.step()

        # train generator
        self.optim_g.zero_grad()
        loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel)
        _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(
            y=audio, y_hat=audio_pred)
        _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(
            y=audio, y_hat=audio_pred)
        loss_fm_mpd = self.feature_loss(fmap_r=fmap_mpd_real,
                                        fmap_g=fmap_mpd_gen)
        loss_fm_msd = self.feature_loss(fmap_r=fmap_msd_real,
                                        fmap_g=fmap_msd_gen)
        loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen)
        loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
        loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel * self.l1_factor
        self.manual_backward(loss_g)
        self.optim_g.step()

        metrics = {
            "g_l1_loss": loss_mel,
            "g_loss_fm_mpd": loss_fm_mpd,
            "g_loss_fm_msd": loss_fm_msd,
            "g_loss_gen_mpd": loss_gen_mpd,
            "g_loss_gen_msd": loss_gen_msd,
            "g_loss": loss_g,
            "d_loss_mpd": loss_disc_mpd,
            "d_loss_msd": loss_disc_msd,
            "d_loss": loss_d,
            "global_step": self.global_step,
            "lr": self.optim_g.param_groups[0]['lr'],
        }
        self.log_dict(metrics, on_step=True, sync_dist=True)
        self.log("g_l1_loss",
                 loss_mel,
                 prog_bar=True,
                 logger=False,
                 sync_dist=True)

    def validation_step(self, batch, batch_idx):
        if self.input_as_mel:
            audio, audio_len, audio_mel = batch
            audio_mel_len = [audio_mel.shape[1]] * audio_mel.shape[0]
        else:
            audio, audio_len = batch
            audio_mel, audio_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred = self(spec=audio_mel)

        # perform bias denoising
        pred_denoised = self._bias_denoise(audio_pred, audio_mel).squeeze(1)
        pred_denoised_mel, _ = self.audio_to_melspec_precessor(
            pred_denoised, audio_len)

        if self.input_as_mel:
            gt_mel, gt_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred_mel, _ = self.audio_to_melspec_precessor(
            audio_pred.squeeze(1), audio_len)
        loss_mel = F.l1_loss(audio_mel, audio_pred_mel)

        self.log_dict({"val_loss": loss_mel}, on_epoch=True, sync_dist=True)

        # plot audio once per epoch
        if batch_idx == 0 and isinstance(self.logger,
                                         WandbLogger) and HAVE_WANDB:
            clips = []
            specs = []
            for i in range(min(5, audio.shape[0])):
                clips += [
                    wandb.Audio(
                        audio[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"real audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        audio_pred[i,
                                   0, :audio_len[i]].data.cpu().numpy().astype(
                                       'float32'),
                        caption=f"generated audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        pred_denoised[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"denoised audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                ]
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"input mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_pred_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"output mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(pred_denoised_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"denoised mel {i}",
                    ),
                ]
                if self.input_as_mel:
                    specs += [
                        wandb.Image(
                            plot_spectrogram_to_numpy(gt_mel[
                                i, :, :audio_mel_len[i]].data.cpu().numpy()),
                            caption=f"gt mel {i}",
                        ),
                    ]

            self.logger.experiment.log({"audio": clips, "specs": specs})

    def _bias_denoise(self, audio, mel):
        def stft(x):
            comp = stft_patch(x.squeeze(1),
                              n_fft=1024,
                              hop_length=256,
                              win_length=1024)
            real, imag = comp[..., 0], comp[..., 1]
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase

        def istft(mags, phase):
            comp = torch.stack(
                [mags * torch.cos(phase), mags * torch.sin(phase)], dim=-1)
            x = torch.istft(comp, n_fft=1024, hop_length=256, win_length=1024)
            return x

        # create bias tensor
        if self.stft_bias is None or self.stft_bias.shape[0] != audio.shape[0]:
            audio_bias = self(spec=torch.zeros_like(mel, device=mel.device))
            self.stft_bias, _ = stft(audio_bias)
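            # keep only the first STFT frame as the per-bin bias estimate (shape (B, freq, 1), broadcast over time)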
            self.stft_bias = self.stft_bias[:, :, 0][:, :, None]

        audio_mags, audio_phase = stft(audio)
        audio_mags = audio_mags - self.cfg.get("denoise_strength",
                                               0.0025) * self.stft_bias
        audio_mags = torch.clamp(audio_mags, 0.0)
        audio_denoised = istft(audio_mags, audio_phase).unsqueeze(1)

        return audio_denoised

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloader_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'Optional[Dict[str, str]]':
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_hifigan",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_hifigan/versions/1.0.0rc1/files/tts_hifigan.nemo",
            description=
            "This model is trained on LJSpeech audio sampled at 22050Hz and mel spectrograms generated from Tacotron2, TalkNet, and FastPitch. This model has been tested on generating female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)
        return list_of_models

    def load_state_dict(self, state_dict, strict=True):
        # override load_state_dict to give us some flexibility to be backward-compatible
        # with old checkpoints
        new_state_dict = {}
        num_resblocks = len(self.cfg['generator']['resblock_kernel_sizes'])
        for k, v in state_dict.items():
            new_k = k
            if 'resblocks' in k:
                parts = k.split(".")
                # only do this if the checkpoint uses the older flat resblock layout
                if len(parts) == 6:
                    layer = int(parts[2])
                    new_layer = f"{layer//num_resblocks}.{layer%num_resblocks}"
                    new_k = f"generator.resblocks.{new_layer}.{'.'.join(parts[3:])}"
            new_state_dict[new_k] = v
        super().load_state_dict(new_state_dict, strict=strict)

    def _prepare_for_export(self, **kwargs):
        """
        Override this method to prepare module for export. This is in-place operation.
        Base version does common necessary module replacements (Apex etc)
        """
        self.generator.remove_weight_norm()

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        par = next(self.parameters())
        mel = torch.randn((1, self.cfg['preprocessor']['nfilt'], 96),
                          device=par.device,
                          dtype=par.dtype)
        return mel
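
A hedged end-to-end sketch pairing the HifiGanModel above with a spectrogram generator: the mel output of generate_spectrogram() is fed to convert_spectrogram_to_audio(). Model names, import paths, and the soundfile dependency are assumptions for illustration only.

import torch
import soundfile as sf  # assumed helper for writing the waveform
from nemo.collections.tts.models import FastPitchModel, HifiGanModel  # assumed imports

fastpitch = FastPitchModel.from_pretrained("tts_en_fastpitch")
vocoder = HifiGanModel.from_pretrained("tts_hifigan")
fastpitch.eval()
vocoder.eval()

with torch.no_grad():
    tokens = fastpitch.parse("Hello world.")
    spect = fastpitch.generate_spectrogram(tokens=tokens)      # (1, n_mels, T)
    audio = vocoder.convert_spectrogram_to_audio(spec=spect)   # (1, T_audio)

sf.write("hello.wav", audio[0].cpu().numpy(), samplerate=vocoder.sample_rate)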
Example No. 23
 def output_types(self):
     return {
         'mel': NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
     }
Example No. 24
class UnivNetModel(Vocoder, Exportable):
    """UnivNet model (https://arxiv.org/abs/2106.07889) that is used to generate audio from mel spectrogram."""
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # We use separate preprocessor for training, because we need to pass grads and remove pitch fmax limitation
        self.trg_melspec_fn = instantiate(cfg.preprocessor,
                                          highfreq=None,
                                          use_grads=True)
        self.generator = instantiate(
            cfg.generator,
            n_mel_channels=cfg.preprocessor.nfilt,
            hop_length=cfg.preprocessor.n_window_stride)
        self.mpd = MultiPeriodDiscriminator(
            cfg.discriminator.mpd,
            debug=cfg.debug if "debug" in cfg else False)
        self.mrd = MultiResolutionDiscriminator(
            cfg.discriminator.mrd,
            debug=cfg.debug if "debug" in cfg else False)

        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        # Reshape MRD resolutions hyperparameter and apply them to MRSTFT loss
        self.stft_resolutions = cfg.discriminator.mrd.resolutions
        self.fft_sizes = [res[0] for res in self.stft_resolutions]
        self.hop_sizes = [res[1] for res in self.stft_resolutions]
        self.win_lengths = [res[2] for res in self.stft_resolutions]
        self.mrstft_loss = MultiResolutionSTFTLoss(self.fft_sizes,
                                                   self.hop_sizes,
                                                   self.win_lengths)
        self.stft_lamb = cfg.stft_lamb

        self.sample_rate = self._cfg.preprocessor.sample_rate
        self.stft_bias = None

        self.input_as_mel = False
        if self._train_dl:
            # TODO(Oktai15): remove it in 1.8.0 version
            if isinstance(self._train_dl.dataset, MelAudioDataset):
                self.input_as_mel = True
            elif isinstance(self._train_dl.dataset, VocoderDataset):
                self.input_as_mel = self._train_dl.dataset.load_precomputed_mel

        self.automatic_optimization = False

    def _get_max_steps(self):
        return compute_max_steps(
            max_epochs=self._cfg.max_epochs,
            accumulate_grad_batches=self.trainer.accumulate_grad_batches,
            limit_train_batches=self.trainer.limit_train_batches,
            num_workers=get_num_workers(self.trainer),
            num_samples=len(self._train_dl.dataset),
            batch_size=get_batch_size(self._train_dl),
            drop_last=self._train_dl.drop_last,
        )

    def configure_optimizers(self):
        optim_g = instantiate(
            self._cfg.optim,
            params=self.generator.parameters(),
        )
        optim_d = instantiate(
            self._cfg.optim,
            params=itertools.chain(self.mrd.parameters(),
                                   self.mpd.parameters()),
        )

        return [optim_g, optim_d]

    @typecheck()
    def forward(self, *, spec):
        """
        Runs the generator; for inputs and outputs see input_types and output_types.
        """
        return self.generator(x=spec)

    @typecheck(
        input_types={
            "spec": NeuralType(('B', 'C', 'T'), MelSpectrogramType())
        },
        output_types={"audio": NeuralType(('B', 'T'), AudioSignal())},
    )
    def convert_spectrogram_to_audio(self,
                                     spec: 'torch.tensor') -> 'torch.tensor':
        return self(spec=spec).squeeze(1)

    def training_step(self, batch, batch_idx):
        if self.input_as_mel:
            # Pre-computed spectrograms will be used as input
            audio, audio_len, audio_mel = batch
        else:
            audio, audio_len = batch
            audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)

        audio = audio.unsqueeze(1)

        audio_pred = self.generator(x=audio_mel)
        audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1),
                                                audio_len)

        optim_g, optim_d = self.optimizers()
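        # automatic_optimization is disabled (see __init__), so the discriminator
        # and generator optimizers are stepped manually below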

        # Train discriminator
        optim_d.zero_grad()
        mpd_score_real, mpd_score_gen, _, _ = self.mpd(
            y=audio, y_hat=audio_pred.detach())
        loss_disc_mpd, _, _ = self.discriminator_loss(
            disc_real_outputs=mpd_score_real,
            disc_generated_outputs=mpd_score_gen)
        mrd_score_real, mrd_score_gen, _, _ = self.mrd(
            y=audio, y_hat=audio_pred.detach())
        loss_disc_mrd, _, _ = self.discriminator_loss(
            disc_real_outputs=mrd_score_real,
            disc_generated_outputs=mrd_score_gen)
        loss_d = loss_disc_mrd + loss_disc_mpd
        self.manual_backward(loss_d)
        optim_d.step()

        # Train generator
        optim_g.zero_grad()
        loss_sc, loss_mag = self.mrstft_loss(x=audio_pred.squeeze(1),
                                             y=audio.squeeze(1),
                                             input_lengths=audio_len)
        loss_sc = torch.stack(loss_sc).mean()
        loss_mag = torch.stack(loss_mag).mean()
        loss_mrstft = (loss_sc + loss_mag) * self.stft_lamb
        _, mpd_score_gen, _, _ = self.mpd(y=audio, y_hat=audio_pred)
        _, mrd_score_gen, _, _ = self.mrd(y=audio, y_hat=audio_pred)
        loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen)
        loss_gen_mrd, _ = self.generator_loss(disc_outputs=mrd_score_gen)
        loss_g = loss_gen_mrd + loss_gen_mpd + loss_mrstft
        self.manual_backward(loss_g)
        optim_g.step()

        metrics = {
            "g_loss_sc": loss_sc,
            "g_loss_mag": loss_mag,
            "g_loss_mrstft": loss_mrstft,
            "g_loss_gen_mpd": loss_gen_mpd,
            "g_loss_gen_mrd": loss_gen_mrd,
            "g_loss": loss_g,
            "d_loss_mpd": loss_disc_mpd,
            "d_loss_mrd": loss_disc_mrd,
            "d_loss": loss_d,
            "global_step": self.global_step,
            "lr": optim_g.param_groups[0]['lr'],
        }
        self.log_dict(metrics, on_step=True, sync_dist=True)
        self.log("g_mrstft_loss",
                 loss_mrstft,
                 prog_bar=True,
                 logger=False,
                 sync_dist=True)

    def validation_step(self, batch, batch_idx):
        if self.input_as_mel:
            audio, audio_len, audio_mel = batch
            audio_mel_len = [audio_mel.shape[1]] * audio_mel.shape[0]
        else:
            audio, audio_len = batch
            audio_mel, audio_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred = self(spec=audio_mel)

        # Perform bias denoising
        pred_denoised = self._bias_denoise(audio_pred, audio_mel).squeeze(1)
        pred_denoised_mel, _ = self.audio_to_melspec_precessor(
            pred_denoised, audio_len)

        if self.input_as_mel:
            gt_mel, gt_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred_mel, _ = self.audio_to_melspec_precessor(
            audio_pred.squeeze(1), audio_len)
        loss_mel = F.l1_loss(audio_mel, audio_pred_mel)

        self.log_dict({"val_loss": loss_mel}, on_epoch=True, sync_dist=True)

        # Plot audio once per epoch
        if batch_idx == 0 and isinstance(self.logger,
                                         WandbLogger) and HAVE_WANDB:
            clips = []
            specs = []
            for i in range(min(5, audio.shape[0])):
                clips += [
                    wandb.Audio(
                        audio[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"real audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        audio_pred[i,
                                   0, :audio_len[i]].data.cpu().numpy().astype(
                                       'float32'),
                        caption=f"generated audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        pred_denoised[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"denoised audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                ]
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"input mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_pred_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"output mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(pred_denoised_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"denoised mel {i}",
                    ),
                ]
                if self.input_as_mel:
                    specs += [
                        wandb.Image(
                            plot_spectrogram_to_numpy(gt_mel[
                                i, :, :audio_mel_len[i]].data.cpu().numpy()),
                            caption=f"gt mel {i}",
                        ),
                    ]

            self.logger.experiment.log({"audio": clips, "specs": specs})

    def _bias_denoise(self, audio, mel):
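        """
        Bias denoising: run the vocoder on an all-zero mel spectrogram to
        estimate its constant bias spectrum, then subtract a scaled copy of
        that bias (denoise_strength) from the STFT magnitudes of the
        generated audio before resynthesis with the original phase.
        """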
        def stft(x):
            # Complex STFT (return_complex=True is required on recent PyTorch);
            # split into magnitude and phase for bias subtraction.
            comp = torch.stft(x.squeeze(1),
                              n_fft=1024,
                              hop_length=256,
                              win_length=1024,
                              return_complex=True)
            real, imag = comp.real, comp.imag
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase

        def istft(mags, phase):
            # Rebuild a complex spectrogram from magnitude and phase, then invert.
            comp = torch.complex(mags * torch.cos(phase),
                                 mags * torch.sin(phase))
            x = torch.istft(comp, n_fft=1024, hop_length=256, win_length=1024)
            return x

        # Create bias tensor
        if self.stft_bias is None or self.stft_bias.shape[0] != audio.shape[0]:
            audio_bias = self(spec=torch.zeros_like(mel, device=mel.device))
            self.stft_bias, _ = stft(audio_bias)
            self.stft_bias = self.stft_bias[:, :, 0][:, :, None]

        audio_mags, audio_phase = stft(audio)
        audio_mags = audio_mags - self.cfg.get("denoise_strength",
                                               0.0025) * self.stft_bias
        audio_mags = torch.clamp(audio_mags, 0.0)
        audio_denoised = istft(audio_mags, audio_phase).unsqueeze(1)

        return audio_denoised

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
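        """
        Builds a DataLoader from a config block containing `dataset` and
        `dataloader_params`, warning or erroring if the shuffle flag is
        inconsistent with the intended split.
        """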
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="validation")

    def setup_test_data(self, cfg):
        pass

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_lj_univnet",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_lj_univnet/versions/1.7.0/files/tts_en_lj_univnet.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz, and has been tested on generating female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_libritts_univnet",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_libritts_univnet/versions/1.7.0/files/tts_en_libritts_multispeaker_univnet.nemo",
            description=
            "This model is trained on all LibriTTS training data (train-clean-100, train-clean-360, and train-other-500) sampled at 22050Hz, and has been tested on generating English voices.",
            class_=cls,
        )
        list_of_models.append(model)

        return list_of_models

    # Methods for model exportability
    def _prepare_for_export(self, **kwargs):
        if self.generator is not None:
            try:
                self.generator.remove_weight_norm()
            except ValueError:
                return

    @property
    def input_types(self):
        return {
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        }

    @property
    def output_types(self):
        return {
            "audio": NeuralType(('B', 'S', 'T'),
                                AudioSignal(self.sample_rate)),
        }

    def input_example(self, max_batch=1, max_dim=256):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        par = next(self.parameters())
        mel = torch.randn(
            (max_batch, self.cfg['preprocessor']['nfilt'], max_dim),
            device=par.device,
            dtype=par.dtype)
        return ({'spec': mel}, )

    def forward_for_export(self, spec):
        """
        Runs the generator. For inputs and outputs see input_types and output_types.
        """
        return self.generator(x=spec)
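
A minimal sketch of how the export hooks above could be exercised. It only uses
methods defined in this example (_prepare_for_export, input_example,
forward_for_export); obtaining `model` (for instance by restoring one of the
checkpoints from list_available_models) is assumed and not shown.

import torch

def export_sketch(model):
    # `model` is assumed to be an instance of the vocoder model shown above.
    model.eval()
    model._prepare_for_export()  # removes weight norm from the generator
    (example,) = model.input_example(max_batch=1, max_dim=256)
    with torch.no_grad():
        audio = model.forward_for_export(spec=example["spec"])
    # Per output_types, `audio` has shape (B, S, T) at the model's sample rate.
    return audio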
Ejemplo n.º 25
0
 def input_types(self):
     return {
         "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
     }
Ejemplo n.º 26
0
class GlowTTSModel(SpectrogramGenerator):
    """
    GlowTTS model used to generate spectrograms from text
    Consists of a text encoder and an invertible spectrogram decoder
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)

        self.parser = instantiate(cfg.parser)
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(GlowTTSConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        self.preprocessor = instantiate(self._cfg.preprocessor)

        encoder = instantiate(self._cfg.encoder)
        decoder = instantiate(self._cfg.decoder)

        self.glow_tts = GlowTTSModule(encoder,
                                      decoder,
                                      n_speakers=cfg.n_speakers,
                                      gin_channels=cfg.gin_channels)
        self.loss = GlowTTSLoss()

    def parse(self, str_input: str) -> torch.tensor:
        if str_input[-1] not in [".", "!", "?"]:
            str_input = str_input + "."

        tokens = self.parser(str_input)

        x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
        return x

    @typecheck(
        input_types={
            "x": NeuralType(('B', 'T'), TokenIndex()),
            "x_lengths": NeuralType(('B'), LengthsType()),
            "y": NeuralType(('B', 'D', 'T'),
                            MelSpectrogramType(),
                            optional=True),
            "y_lengths": NeuralType(('B'), LengthsType(), optional=True),
            "gen": NeuralType(optional=True),
            "noise_scale": NeuralType(optional=True),
            "length_scale": NeuralType(optional=True),
        })
    def forward(self,
                *,
                x,
                x_lengths,
                y=None,
                y_lengths=None,
                gen=False,
                noise_scale=0.3,
                length_scale=1.0):
        if gen:
            return self.glow_tts.generate_spect(text=x,
                                                text_lengths=x_lengths,
                                                noise_scale=noise_scale,
                                                length_scale=length_scale)
        else:
            return self.glow_tts(text=x,
                                 text_lengths=x_lengths,
                                 spect=y,
                                 spect_lengths=y_lengths)

    def step(self, y, y_lengths, x, x_lengths):
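        """
        Teacher-forced pass (gen=False) followed by the GlowTTS MLE and
        duration losses for a single batch.
        """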
        z, y_m, y_logs, logdet, logw, logw_, y_lengths, attn = self(
            x=x, x_lengths=x_lengths, y=y, y_lengths=y_lengths, gen=False)

        l_mle, l_length, logdet = self.loss(
            z=z,
            y_m=y_m,
            y_logs=y_logs,
            logdet=logdet,
            logw=logw,
            logw_=logw_,
            x_lengths=x_lengths,
            y_lengths=y_lengths,
        )

        loss = sum([l_mle, l_length])

        return l_mle, l_length, logdet, loss, attn

    def training_step(self, batch, batch_idx):
        y, y_lengths, x, x_lengths = batch

        y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths)

        l_mle, l_length, logdet, loss, _ = self.step(y, y_lengths, x,
                                                     x_lengths)

        output = {
            "loss": loss,  # required
            "progress_bar": {
                "l_mle": l_mle,
                "l_length": l_length,
                "logdet": logdet
            },
            "log": {
                "loss": loss,
                "l_mle": l_mle,
                "l_length": l_length,
                "logdet": logdet
            },
        }

        return output

    def validation_step(self, batch, batch_idx):
        y, y_lengths, x, x_lengths = batch

        y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths)

        l_mle, l_length, logdet, loss, attn = self.step(
            y, y_lengths, x, x_lengths)

        y_gen, attn_gen = self(x=x, x_lengths=x_lengths, gen=True)

        return {
            "loss": loss,
            "l_mle": l_mle,
            "l_length": l_length,
            "logdet": logdet,
            "y": y,
            "y_gen": y_gen,
            "x": x,
            "attn": attn,
            "attn_gen": attn_gen,
            "progress_bar": {
                "l_mle": l_mle,
                "l_length": l_length,
                "logdet": logdet
            },
        }

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_mle = torch.stack([x['l_mle'] for x in outputs]).mean()
        avg_length_loss = torch.stack([x['l_length'] for x in outputs]).mean()
        avg_logdet = torch.stack([x['logdet'] for x in outputs]).mean()
        tensorboard_logs = {
            'val_loss': avg_loss,
            'val_mle': avg_mle,
            'val_length_loss': avg_length_loss,
            'val_logdet': avg_logdet,
        }
        if self.logger is not None and self.logger.experiment is not None:
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            separated_phonemes = "|".join(
                [self.parser.symbols[c] for c in outputs[0]['x'][0]])
            tb_logger.add_text("separated phonemes", separated_phonemes,
                               self.global_step)
            tb_logger.add_image(
                "real_spectrogram",
                plot_spectrogram_to_numpy(
                    outputs[0]['y'][0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            tb_logger.add_image(
                "generated_spectrogram",
                plot_spectrogram_to_numpy(
                    outputs[0]['y_gen'][0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            tb_logger.add_image(
                "alignment_for_real_sp",
                plot_alignment_to_numpy(
                    outputs[0]['attn'][0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            tb_logger.add_image(
                "alignment_for_generated_sp",
                plot_alignment_to_numpy(
                    outputs[0]['attn_gen'][0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            log_audio_to_tb(tb_logger, outputs[0]['y'][0], "true_audio_gf",
                            self.global_step)
            log_audio_to_tb(tb_logger, outputs[0]['y_gen'][0],
                            "generated_audio_gf", self.global_step)
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def _setup_dataloader_from_config(self, cfg: DictConfig):
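        """
        Builds an _AudioTextDataset DataLoader from a manifest-based config;
        returns None when `manifest_filepath` is not provided.
        """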

        if 'manifest_filepath' in cfg and cfg['manifest_filepath'] is None:
            logging.warning(
                f"Could not load dataset as `manifest_filepath` was None. Provided config : {cfg}"
            )
            return None

        if 'augmentor' in cfg:
            augmentor = process_augmentations(cfg['augmentor'])
        else:
            augmentor = None

        dataset = _AudioTextDataset(
            manifest_filepath=cfg['manifest_filepath'],
            parser=self.parser,
            sample_rate=cfg['sample_rate'],
            int_values=cfg.get('int_values', False),
            augmentor=augmentor,
            max_duration=cfg.get('max_duration', None),
            min_duration=cfg.get('min_duration', None),
            max_utts=cfg.get('max_utts', 0),
            trim=cfg.get('trim_silence', True),
            load_audio=cfg.get('load_audio', True),
            add_misc=cfg.get('add_misc', False),
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg['batch_size'],
            collate_fn=dataset.collate_fn,
            drop_last=cfg.get('drop_last', False),
            shuffle=cfg['shuffle'],
            num_workers=cfg.get('num_workers', 0),
        )

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config)

    def generate_spectrogram(self,
                             tokens: 'torch.tensor',
                             noise_scale: float = 0.0,
                             length_scale: float = 1.0) -> torch.tensor:

        self.eval()

        token_len = torch.tensor([tokens.shape[1]]).to(self.device)
        spect, _ = self(x=tokens,
                        x_lengths=token_len,
                        gen=True,
                        noise_scale=noise_scale,
                        length_scale=length_scale)

        return spect

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_glowtts",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_glowtts/versions/1.0.0rc1/files/tts_en_glowtts.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz, and can be used to generate female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)
        return list_of_models
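
A short usage sketch of the inference path defined above. The from_pretrained
call is NeMo's standard ModelPT loader and the checkpoint name comes from
list_available_models(); treating both as available in your environment is an
assumption of this sketch.

import torch

model = GlowTTSModel.from_pretrained("tts_en_glowtts")
model.eval()

with torch.no_grad():
    tokens = model.parse("Glow TTS converts text into a mel spectrogram")
    # noise_scale trades variability for stability; length_scale controls pace.
    spec = model.generate_spectrogram(tokens=tokens,
                                      noise_scale=0.0,
                                      length_scale=1.0)
print(spec.shape)  # (1, n_mel_channels, frames)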
Ejemplo n.º 27
0
class WaveGlowModel(GlowVocoder, Exportable):
    """Waveglow model used to convert betweeen spectrograms and audio"""
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(WaveglowConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        self.sigma = self._cfg.sigma
        self.audio_to_melspec_precessor = instantiate(self._cfg.preprocessor)
        self.waveglow = instantiate(self._cfg.waveglow)
        self.loss = WaveGlowLoss()

    @GlowVocoder.mode.setter
    def mode(self, new_mode):
        if new_mode == OperationMode.training:
            self.train()
        else:
            self.eval()
        self._mode = new_mode
        self.waveglow.mode = new_mode

    @property
    def input_types(self):
        return {
            "audio": NeuralType(('B', 'T'), AudioSignal()),
            "audio_len": NeuralType(('B'), LengthsType()),
            "run_inverse": NeuralType(optional=True),
        }

    @property
    def output_types(self):
        if self.mode == OperationMode.training or self.mode == OperationMode.validation:
            output_dict = {
                "pred_normal_dist":
                NeuralType(('B', 'flowgroup', 'T'),
                           NormalDistributionSamplesType()),
                "log_s_list":
                [NeuralType(('B', 'flowgroup', 'T'),
                            VoidType())],  # TODO: Figure out a good typing
                "log_det_W_list":
                [NeuralType(elements_type=LogDeterminantType())],
            }
            if self.mode == OperationMode.validation:
                output_dict["audio_pred"] = NeuralType(('B', 'T'),
                                                       AudioSignal())
                output_dict["spec"] = NeuralType(('B', 'T', 'D'),
                                                 MelSpectrogramType())
                output_dict["spec_len"] = NeuralType(('B'), LengthsType())
            return output_dict
        return {
            "audio_pred": NeuralType(('B', 'T'), AudioSignal()),
        }

    @typecheck()
    def forward(self, *, audio, audio_len, run_inverse=True):
        if self.mode != self.waveglow.mode:
            raise ValueError(
                f"WaveGlowModel's mode {self.mode} does not match WaveGlowModule's mode {self.waveglow.mode}"
            )
        spec, spec_len = self.audio_to_melspec_precessor(audio, audio_len)
        tensors = self.waveglow(spec=spec,
                                audio=audio,
                                run_inverse=run_inverse,
                                sigma=self.sigma)
        if self.mode == OperationMode.training:
            return tensors[:-1]  # z, log_s_list, log_det_W_list
        elif self.mode == OperationMode.validation:
            z, log_s_list, log_det_W_list, audio_pred = tensors
            return z, log_s_list, log_det_W_list, audio_pred, spec, spec_len
        return tensors  # audio_pred

    @typecheck(
        input_types={
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "sigma": NeuralType(optional=True),
            "denoise": NeuralType(optional=True),
            "denoiser_strength": NeuralType(optional=True),
        },
        output_types={"audio": NeuralType(('B', 'T'), AudioSignal())},
    )
    def convert_spectrogram_to_audio(
            self,
            spec: torch.Tensor,
            sigma: float = 1.0,
            denoise: bool = True,
            denoiser_strength: float = 0.01) -> torch.Tensor:
        with self.nemo_infer():
            self.waveglow.remove_weightnorm()
            audio = self.waveglow(spec=spec.to(
                self.waveglow.upsample.weight.dtype),
                                  run_inverse=True,
                                  audio=None,
                                  sigma=sigma)
            if denoise:
                audio = self.denoise(audio=audio, strength=denoiser_strength)

        return audio

    def training_step(self, batch, batch_idx):
        self.mode = OperationMode.training
        audio, audio_len = batch
        z, log_s_list, log_det_W_list = self(audio=audio,
                                             audio_len=audio_len,
                                             run_inverse=False)

        loss = self.loss(z=z,
                         log_s_list=log_s_list,
                         log_det_W_list=log_det_W_list,
                         sigma=self.sigma)
        output = {
            'loss': loss,
            'progress_bar': {
                'training_loss': loss
            },
            'log': {
                'loss': loss
            },
        }
        return output

    def validation_step(self, batch, batch_idx):
        self.mode = OperationMode.validation
        audio, audio_len = batch
        z, log_s_list, log_det_W_list, audio_pred, spec, spec_len = self(
            audio=audio, audio_len=audio_len, run_inverse=(batch_idx == 0))
        loss = self.loss(z=z,
                         log_s_list=log_s_list,
                         log_det_W_list=log_det_W_list,
                         sigma=self.sigma)
        return {
            "val_loss": loss,
            "audio_pred": audio_pred,
            "mel_target": spec,
            "mel_len": spec_len,
        }

    def validation_epoch_end(self, outputs):
        if self.logger is not None and self.logger.experiment is not None:
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            waveglow_log_to_tb_func(
                tb_logger,
                outputs[0].values(),
                self.global_step,
                tag="eval",
                mel_fb=self.audio_to_melspec_precessor.fb,
            )
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.log('val_loss', avg_loss)

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_waveglow_268m",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_waveglow_268m/versions/1.0.0rc1/files/tts_waveglow_268m.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz, and has been tested on generating female English voices with an American accent and Mandarin voices.",
            class_=cls,
        )
        list_of_models.append(model)
        model = PretrainedModelInfo(
            pretrained_model_name="tts_waveglow_88m",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_waveglow_88m/versions/1.0.0rc1/files/tts_waveglow_88m.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz, and has been tested on generating female English voices with an American accent and Mandarin voices.",
            class_=cls,
        )
        list_of_models.append(model)
        return list_of_models

    @property
    def input_module(self):
        return self.waveglow

    @property
    def output_module(self):
        return self.waveglow

    def _prepare_for_export(self, **kwargs):
        self.update_bias_spect()
        self.waveglow._prepare_for_export(**kwargs)

    def forward_for_export(self, spec, z=None):
        return self.waveglow(spec, z)
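
A sketch of spectrogram-to-audio inference with the model above. The checkpoint
name comes from list_available_models(); the random spectrogram and the
assumption of 80 mel bins are only for illustration.

import torch

vocoder = WaveGlowModel.from_pretrained("tts_waveglow_88m")
vocoder.eval()

spec = torch.randn(1, 80, 200)  # (B, n_mel_channels, frames); 80 bins assumed
with torch.no_grad():
    audio = vocoder.convert_spectrogram_to_audio(spec=spec,
                                                 sigma=1.0,
                                                 denoise=True,
                                                 denoiser_strength=0.01)
print(audio.shape)  # (1, samples)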
Ejemplo n.º 28
0
class FastPitchModel(SpectrogramGenerator):
    """FastPitch Model that is used to generate mel spectrograms from text"""
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)

        self.learn_alignment = False
        if "learn_alignment" in cfg:
            self.learn_alignment = cfg.learn_alignment
        self._parser = None
        self._tb_logger = None
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(FastPitchConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        self.bin_loss_warmup_epochs = 100
        self.aligner = None
        self.log_train_images = False
        self.mel_loss = MelLoss()
        loss_scale = 0.1 if self.learn_alignment else 1.0
        self.pitch_loss = PitchLoss(loss_scale=loss_scale)
        self.duration_loss = DurationLoss(loss_scale=loss_scale)
        input_fft_kwargs = {}
        if self.learn_alignment:
            self.aligner = instantiate(self._cfg.alignment_module)
            self.forward_sum_loss = ForwardSumLoss()
            self.bin_loss = BinLoss()
            self.vocab = AudioToCharWithDursF0Dataset.make_vocab(
                **self._cfg.train_ds.dataset.vocab)
            input_fft_kwargs["n_embed"] = len(self.vocab.labels)
            input_fft_kwargs["padding_idx"] = self.vocab.pad

        self.preprocessor = instantiate(self._cfg.preprocessor)

        input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
        output_fft = instantiate(self._cfg.output_fft)
        duration_predictor = instantiate(self._cfg.duration_predictor)
        pitch_predictor = instantiate(self._cfg.pitch_predictor)

        self.fastpitch = FastPitchModule(
            input_fft,
            output_fft,
            duration_predictor,
            pitch_predictor,
            self.aligner,
            cfg.n_speakers,
            cfg.symbols_embedding_dim,
            cfg.pitch_embedding_kernel_size,
            cfg.n_mel_channels,
        )

    @property
    def tb_logger(self):
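        """
        Lazily resolves a TensorBoard experiment handle from the attached
        logger, preferring the TensorBoardLogger inside a LoggerCollection.
        """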
        if self._tb_logger is None:
            if self.logger is None or self.logger.experiment is None:
                return None
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            self._tb_logger = tb_logger
        return self._tb_logger

    @property
    def parser(self):
        if self._parser is not None:
            return self._parser

        if self.learn_alignment:
            vocab = AudioToCharWithDursF0Dataset.make_vocab(
                **self._cfg.train_ds.dataset.vocab)
            self._parser = vocab.encode
        else:
            self._parser = parsers.make_parser(
                labels=self._cfg.labels,
                name='en',
                unk_id=-1,
                blank_id=-1,
                do_normalize=True,
                abbreviation_version="fastpitch",
                make_table=False,
            )
        return self._parser

    def parse(self, str_input: str) -> torch.tensor:
        if str_input[-1] not in [".", "!", "?"]:
            str_input = str_input + "."

        tokens = self.parser(str_input)

        x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
        return x

    @typecheck(
        input_types={
            "text":
            NeuralType(('B', 'T'), TokenIndex()),
            "durs":
            NeuralType(('B', 'T'), TokenDurationType()),
            "pitch":
            NeuralType(('B', 'T'), RegressionValuesType()),
            "speaker":
            NeuralType(('B'), Index()),
            "pace":
            NeuralType(optional=True),
            "spec":
            NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True),
            "attn_prior":
            NeuralType(('B', 'T', 'T'), ProbsType(), optional=True),
            "mel_lens":
            NeuralType(('B'), LengthsType(), optional=True),
            "input_lens":
            NeuralType(('B'), LengthsType(), optional=True),
        })
    def forward(
        self,
        *,
        text,
        durs=None,
        pitch=None,
        speaker=0,
        pace=1.0,
        spec=None,
        attn_prior=None,
        mel_lens=None,
        input_lens=None,
    ):
        return self.fastpitch(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speaker,
            pace=pace,
            spec=spec,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=input_lens,
        )

    @typecheck(output_types={
        "spect": NeuralType(('B', 'C', 'T'), MelSpectrogramType())
    })
    def generate_spectrogram(self,
                             tokens: 'torch.tensor',
                             speaker: int = 0,
                             pace: float = 1.0) -> torch.tensor:
        self.eval()
        spect, *_ = self(text=tokens,
                         durs=None,
                         pitch=None,
                         speaker=speaker,
                         pace=pace)
        return spect.transpose(1, 2)

    def training_step(self, batch, batch_idx):
        attn_prior, durs, speakers = None, None, None
        if self.learn_alignment:
            audio, audio_lens, text, text_lens, attn_prior, pitch = batch
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speakers = batch
        mels, spec_len = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        mels_pred, _, log_durs_pred, pitch_pred, attn_soft, attn_logprob, attn_hard, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speakers,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=spec_len,
            input_lens=text_lens,
        )
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        loss = mel_loss + dur_loss
        if self.learn_alignment:
            ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob,
                                             in_lens=text_lens,
                                             out_lens=spec_len)
            bin_loss_weight = min(
                self.current_epoch / self.bin_loss_warmup_epochs, 1.0) * 1.0
            bin_loss = self.bin_loss(
                hard_attention=attn_hard,
                soft_attention=attn_soft) * bin_loss_weight
            loss += ctc_loss + bin_loss

        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss += pitch_loss

        self.log("t_loss", loss)
        self.log("t_mel_loss", mel_loss)
        self.log("t_dur_loss", dur_loss)
        self.log("t_pitch_loss", pitch_loss)
        if self.learn_alignment:
            self.log("t_ctc_loss", ctc_loss)
            self.log("t_bin_loss", bin_loss)

        # Log images to tensorboard
        if self.log_train_images:
            self.log_train_images = False

            self.tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(mels[0].data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = mels_pred[0].data.cpu().numpy().T
            self.tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            if self.learn_alignment:
                attn = attn_hard[0].data.cpu().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_attn",
                    plot_alignment_to_numpy(attn.T),
                    self.global_step,
                    dataformats="HWC",
                )
                soft_attn = attn_soft[0].data.cpu().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_soft_attn",
                    plot_alignment_to_numpy(soft_attn.T),
                    self.global_step,
                    dataformats="HWC",
                )

        return loss

    def validation_step(self, batch, batch_idx):
        attn_prior, durs, speakers = None, None, None
        if self.learn_alignment:
            audio, audio_lens, text, text_lens, attn_prior, pitch = batch
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speakers = batch
        mels, mel_lens = self.preprocessor(input_signal=audio,
                                           length=audio_lens)

        # Calculate val loss on ground truth durations to better align L2 loss in time
        mels_pred, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch = self(
            text=text,
            durs=durs,
            pitch=pitch,
            speaker=speakers,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=text_lens,
        )
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred,
                                      durs_tgt=durs,
                                      len=text_lens)
        pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred,
                                     pitch_tgt=pitch,
                                     len=text_lens)
        loss = mel_loss + dur_loss + pitch_loss

        return {
            "val_loss": loss,
            "mel_loss": mel_loss,
            "dur_loss": dur_loss,
            "pitch_loss": pitch_loss,
            "mel_target": mels if batch_idx == 0 else None,
            "mel_pred": mels_pred if batch_idx == 0 else None,
        }

    def validation_epoch_end(self, outputs):
        collect = lambda key: torch.stack([x[key] for x in outputs]).mean()
        val_loss = collect("val_loss")
        mel_loss = collect("mel_loss")
        dur_loss = collect("dur_loss")
        pitch_loss = collect("pitch_loss")
        self.log("v_loss", val_loss)
        self.log("v_mel_loss", mel_loss)
        self.log("v_dur_loss", dur_loss)
        self.log("v_pitch_loss", pitch_loss)

        _, _, _, _, spec_target, spec_predict = outputs[0].values()
        self.tb_logger.add_image(
            "val_mel_target",
            plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
            self.global_step,
            dataformats="HWC",
        )
        spec_predict = spec_predict[0].data.cpu().numpy()
        self.tb_logger.add_image(
            "val_mel_predicted",
            plot_spectrogram_to_numpy(spec_predict.T),
            self.global_step,
            dataformats="HWC",
        )
        self.log_train_images = True

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        kwargs_dict = {}
        if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
            kwargs_dict["parser"] = self.parser
        dataset = instantiate(cfg.dataset, **kwargs_dict)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="val")

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_fastpitch",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/1.0.0/files/tts_en_fastpitch.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)

        return list_of_models
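
A sketch of FastPitch inference with pace control, using only methods defined
above. The checkpoint name comes from list_available_models(); from_pretrained
is NeMo's standard ModelPT loader and is assumed to be available here.

import torch

fastpitch = FastPitchModel.from_pretrained("tts_en_fastpitch")
fastpitch.eval()

with torch.no_grad():
    tokens = fastpitch.parse("FastPitch predicts durations and pitch explicitly.")
    # pace < 1.0 slows speech down, pace > 1.0 speeds it up.
    spec = fastpitch.generate_spectrogram(tokens=tokens, speaker=0, pace=1.0)
print(spec.shape)  # (1, n_mel_channels, frames) after the transpose above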
Ejemplo n.º 29
0
class UniGlowModel(Vocoder):
    """Waveglow model used to convert betweeen spectrograms and audio"""
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(WaveglowConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        self.pad_value = self._cfg.preprocessor.params.pad_value
        self.sigma = self._cfg.sigma
        self.audio_to_melspec_precessor = instantiate(self._cfg.preprocessor)
        self.model = UniGlowModule(
            self._cfg.uniglow.n_mel_channels,
            self._cfg.uniglow.n_flows,
            self._cfg.uniglow.n_group,
            self._cfg.uniglow.n_wn_channels,
            self._cfg.uniglow.n_wn_layers,
            self._cfg.uniglow.wn_kernel_size,
            self.get_upsample_factor(),
        )
        self.mode = OperationMode.infer
        self.loss = UniGlowLoss(self._cfg.uniglow.stft_loss_coef)
        self.removed_weightnorm = False

    @property
    def input_types(self):
        return {
            "audio": NeuralType(('B', 'T'), AudioSignal()),
            "audio_len": NeuralType(('B'), LengthsType()),
        }

    @property
    def output_types(self):
        if self.mode == OperationMode.training or self.mode == OperationMode.validation:
            output_dict = {
                "pred_normal_dist":
                NeuralType(('B', 'flowgroup', 'T'),
                           NormalDistributionSamplesType()),
                "logdet":
                NeuralType(elements_type=LogDeterminantType()),
                "predicted_audio":
                NeuralType(('B', 'T'), AudioSignal()),
            }
            if self.mode == OperationMode.validation:
                output_dict["spec"] = NeuralType(('B', 'T', 'D'),
                                                 MelSpectrogramType())
                output_dict["spec_len"] = NeuralType(('B'), LengthsType())
            return output_dict
        return {
            "audio_pred": NeuralType(('B', 'T'), AudioSignal()),
        }

    @typecheck()
    def forward(self, *, audio, audio_len):
        if self.mode != self.model.mode:
            raise ValueError(
                f"WaveGlowModel's mode {self.mode} does not match WaveGlowModule's mode {self.model.mode}"
            )
        spec, spec_len = self.audio_to_melspec_precessor(audio, audio_len)
        tensors = self.model(spec=spec, audio=audio, sigma=self.sigma)
        if self.mode == OperationMode.training:
            return tensors  # z, logdet, audio_pred
        elif self.mode == OperationMode.validation:
            z, logdet, audio_pred = tensors
            return z, logdet, audio_pred, spec, spec_len
        return tensors

    @typecheck(
        input_types={
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "sigma": NeuralType(optional=True)
        },
        output_types={"audio": NeuralType(('B', 'T'), AudioSignal())},
    )
    def convert_spectrogram_to_audio(self,
                                     spec: torch.Tensor,
                                     sigma: float = 1.0) -> torch.Tensor:
        if not self.removed_weightnorm:
            self.model.remove_weightnorm()
            self.removed_weightnorm = True
        self.eval()
        self.mode = OperationMode.infer
        self.model.mode = OperationMode.infer

        with torch.no_grad():
            audio = self.model(spec=spec, audio=None, sigma=sigma)

        return audio

    def training_step(self, batch, batch_idx):
        self.mode = OperationMode.training
        self.model.mode = OperationMode.training
        audio, audio_len = batch
        z, logdet, predicted_audio = self(audio=audio, audio_len=audio_len)
        loss = self.loss(z=z,
                         logdet=logdet,
                         gt_audio=audio,
                         predicted_audio=predicted_audio,
                         sigma=self.sigma)
        output = {
            'loss': loss,
            'progress_bar': {
                'training_loss': loss
            },
            'log': {
                'loss': loss
            },
        }
        return output

    def validation_step(self, batch, batch_idx):
        self.mode = OperationMode.validation
        self.model.mode = OperationMode.validation
        audio, audio_len = batch
        z, logdet, predicted_audio, spec, spec_len = self(audio=audio,
                                                          audio_len=audio_len)
        loss = self.loss(z=z,
                         logdet=logdet,
                         gt_audio=audio,
                         predicted_audio=predicted_audio,
                         sigma=self.sigma)

        # compute average stoi score for batch
        stoi_score = 0
        sr = self._cfg.preprocessor.params.sample_rate
        for audio_i, audio_recon_i in zip(audio.cpu(), predicted_audio.cpu()):
            stoi_score += stoi(audio_i, audio_recon_i, sr)
        stoi_score /= audio.shape[0]

        return {
            "val_loss": loss,
            "predicted_audio": predicted_audio,
            "mel_target": spec,
            "mel_len": spec_len,
            "stoi": stoi_score,
        }

    def validation_epoch_end(self, outputs):
        if self.logger is not None and self.logger.experiment is not None:
            waveglow_log_to_tb_func(
                self.logger.experiment,
                tuple(outputs[0].values())[:-1],
                self.global_step,
                tag="eval",
                mel_fb=self.audio_to_melspec_precessor.fb,
            )
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_stoi = torch.FloatTensor([x['stoi'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss, 'stoi': avg_stoi}
        logging.info(
            f"Validation summary | Epoch {self.current_epoch} | NLL {avg_loss:.2f} | STOI: {avg_stoi:.2f}"
        )
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")  # TODO
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")  # TODO
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="UniGlow-22050Hz",
            location=
            "https://drive.google.com/file/d/18JO5heoz1pBicZnGGqJzAJYMpzxiDQDa/view?usp=sharing",
            description=
            "The model is trained on LJSpeech sampled at 22050Hz, and can be used as an universal vocoder",
        )
        list_of_models.append(model)
        return list_of_models

    def get_upsample_factor(self) -> int:
        """
        As the MelSpectrogram upsampling is done using interpolation, the upsampling factor is determined
        by the ratio of the MelSpectrogram length and the waveform length
        Returns:
            An integer representing the upsampling factor
        """
        audio = torch.ones(1, self._cfg.train_ds.dataset.params.n_segments)
        spec, spec_len = self.audio_to_melspec_precessor(
            audio, torch.FloatTensor([len(audio)]))
        spec = spec[:, :, :-1]
        audio = audio.unfold(1, self._cfg.uniglow.n_group,
                             self._cfg.uniglow.n_group).permute(0, 2, 1)
        upsample_factor = audio.shape[2] // spec.shape[2]
        return upsample_factor
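
To make the arithmetic in get_upsample_factor concrete, here is a small worked
example. The segment length, hop length and n_group are assumed example values,
not read from this snippet's config.

n_segments = 16384  # samples per training segment (assumed)
hop_length = 256    # preprocessor hop length (assumed)
n_group = 8         # uniglow.n_group (assumed)

spec_frames = n_segments // hop_length        # ~64 mel frames (last frame is dropped above)
grouped_frames = n_segments // n_group        # 2048 windows of n_group samples each
upsample_factor = grouped_frames // spec_frames
print(upsample_factor)                        # 32: each mel frame is upsampled 32x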
Ejemplo n.º 30
0
class Tacotron2Model(SpectrogramGenerator):
    """Tacotron 2 Model that is used to generate mel spectrograms from text"""
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(Tacotron2Config)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        try:
            OmegaConf.merge(cfg, schema)
            self.pad_value = self._cfg.preprocessor.pad_value
        except ConfigAttributeError:
            self.pad_value = self._cfg.preprocessor.params.pad_value
            logging.warning(
                "Your config is using an old NeMo yaml configuration. Please ensure that the yaml matches the "
                "current version in the main branch for future compatibility.")

        self._parser = None
        self.audio_to_melspec_precessor = instantiate(self._cfg.preprocessor)
        self.text_embedding = nn.Embedding(len(cfg.labels) + 3, 512)
        self.encoder = instantiate(self._cfg.encoder)
        self.decoder = instantiate(self._cfg.decoder)
        self.postnet = instantiate(self._cfg.postnet)
        self.loss = Tacotron2Loss()
        self.calculate_loss = True

    @property
    def parser(self):
        if self._parser is not None:
            return self._parser
        if self._validation_dl is not None:
            return self._validation_dl.dataset.parser
        if self._test_dl is not None:
            return self._test_dl.dataset.parser
        if self._train_dl is not None:
            return self._train_dl.dataset.parser

        # Else construct a parser
        # Try to get params from validation, test, and then train
        params = {}
        try:
            params = self._cfg.validation_ds.dataset
        except ConfigAttributeError:
            pass
        if params == {}:
            try:
                params = self._cfg.test_ds.dataset
            except ConfigAttributeError:
                pass
        if params == {}:
            try:
                params = self._cfg.train_ds.dataset
            except ConfigAttributeError:
                pass

        name = params.get('parser', None) or 'en'
        unk_id = params.get('unk_index', None) or -1
        blank_id = params.get('blank_index', None) or -1
        do_normalize = params.get('normalize', None) or False
        self._parser = parsers.make_parser(
            labels=self._cfg.labels,
            name=name,
            unk_id=unk_id,
            blank_id=blank_id,
            do_normalize=do_normalize,
        )
        return self._parser

    def parse(self, str_input: str) -> torch.tensor:
        tokens = self.parser(str_input)
        # Parser doesn't add bos and eos ids, so manually add them
        tokens = [len(self._cfg.labels)] + tokens + [len(self._cfg.labels) + 1]
        tokens_tensor = torch.tensor(tokens).unsqueeze_(0).to(self.device)

        return tokens_tensor

    @property
    def input_types(self):
        if self.training:
            return {
                "tokens": NeuralType(('B', 'T'), EmbeddedTextType()),
                "token_len": NeuralType(('B'), LengthsType()),
                "audio": NeuralType(('B', 'T'), AudioSignal()),
                "audio_len": NeuralType(('B'), LengthsType()),
            }
        else:
            return {
                "tokens": NeuralType(('B', 'T'), EmbeddedTextType()),
                "token_len": NeuralType(('B'), LengthsType()),
                "audio": NeuralType(('B', 'T'), AudioSignal(), optional=True),
                "audio_len": NeuralType(('B'), LengthsType(), optional=True),
            }

    @property
    def output_types(self):
        if not self.calculate_loss and not self.training:
            return {
                "spec_pred_dec":
                NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
                "spec_pred_postnet":
                NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
                "gate_pred":
                NeuralType(('B', 'T'), LogitsType()),
                "alignments":
                NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
                "pred_length":
                NeuralType(('B'), LengthsType()),
            }
        return {
            "spec_pred_dec":
            NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "spec_pred_postnet":
            NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "gate_pred":
            NeuralType(('B', 'T'), LogitsType()),
            "spec_target":
            NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "spec_target_len":
            NeuralType(('B'), LengthsType()),
            "alignments":
            NeuralType(('B', 'T', 'T'), SequenceToSequenceAlignmentType()),
        }

    @typecheck()
    def forward(self, *, tokens, token_len, audio=None, audio_len=None):
        if audio is not None and audio_len is not None:
            spec_target, spec_target_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        token_embedding = self.text_embedding(tokens).transpose(1, 2)
        encoder_embedding = self.encoder(token_embedding=token_embedding,
                                         token_len=token_len)
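        # Training uses teacher forcing (the target spectrogram is fed to the decoder);
        # inference runs the decoder autoregressively and also returns predicted lengths.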
        if self.training:
            spec_pred_dec, gate_pred, alignments = self.decoder(
                memory=encoder_embedding,
                decoder_inputs=spec_target,
                memory_lengths=token_len)
        else:
            spec_pred_dec, gate_pred, alignments, pred_length = self.decoder(
                memory=encoder_embedding, memory_lengths=token_len)
        spec_pred_postnet = self.postnet(mel_spec=spec_pred_dec)

        if not self.calculate_loss:
            return spec_pred_dec, spec_pred_postnet, gate_pred, alignments, pred_length
        return spec_pred_dec, spec_pred_postnet, gate_pred, spec_target, spec_target_len, alignments

    @typecheck(
        input_types={"tokens": NeuralType(('B', 'T'), EmbeddedTextType())},
        output_types={
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType())
        },
    )
    def generate_spectrogram(self, *, tokens):
        self.eval()
        self.calculate_loss = False
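        # Disable loss computation so forward() takes the inference path and
        # returns pred_length as its last output (used for masking below).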
        token_len = torch.tensor([len(i) for i in tokens]).to(self.device)
        tensors = self(tokens=tokens, token_len=token_len)
        spectrogram_pred = tensors[1]

        if spectrogram_pred.shape[0] > 1:
            # Silence all frames past the predicted end
            mask = ~get_mask_from_lengths(tensors[-1])
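            # mask is (B, T); broadcast it across the mel channels to (B, D, T)
            # so frames past the predicted end can be filled with pad_value.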
            mask = mask.expand(spectrogram_pred.shape[1], mask.size(0),
                               mask.size(1))
            mask = mask.permute(1, 0, 2)
            spectrogram_pred.data.masked_fill_(mask, self.pad_value)

        return spectrogram_pred
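
    # Note: generate_spectrogram expects the bos/eos-padded token ids produced by
    # parse(); see the end-to-end sketch after list_available_models below.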

    def training_step(self, batch, batch_idx):
        audio, audio_len, tokens, token_len = batch
        spec_pred_dec, spec_pred_postnet, gate_pred, spec_target, spec_target_len, _ = self.forward(
            audio=audio,
            audio_len=audio_len,
            tokens=tokens,
            token_len=token_len)

        loss, _ = self.loss(
            spec_pred_dec=spec_pred_dec,
            spec_pred_postnet=spec_pred_postnet,
            gate_pred=gate_pred,
            spec_target=spec_target,
            spec_target_len=spec_target_len,
            pad_value=self.pad_value,
        )

        output = {
            'loss': loss,
            'progress_bar': {
                'training_loss': loss
            },
            'log': {
                'loss': loss
            },
        }
        return output

    def validation_step(self, batch, batch_idx):
        audio, audio_len, tokens, token_len = batch
        spec_pred_dec, spec_pred_postnet, gate_pred, spec_target, spec_target_len, alignments = self.forward(
            audio=audio,
            audio_len=audio_len,
            tokens=tokens,
            token_len=token_len)

        loss, gate_target = self.loss(
            spec_pred_dec=spec_pred_dec,
            spec_pred_postnet=spec_pred_postnet,
            gate_pred=gate_pred,
            spec_target=spec_target,
            spec_target_len=spec_target_len,
            pad_value=self.pad_value,
        )
        return {
            "val_loss": loss,
            "mel_target": spec_target,
            "mel_postnet": spec_pred_postnet,
            "gate": gate_pred,
            "gate_target": gate_target,
            "alignments": alignments,
        }

    def validation_epoch_end(self, outputs):
        if self.logger is not None and self.logger.experiment is not None:
            tb_logger = self.logger.experiment
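            # If several loggers are attached, prefer the TensorBoard one for image logging.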
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            tacotron2_log_to_tb_func(
                tb_logger,
                outputs[0].values(),
                self.global_step,
                tag="val",
                log_images=True,
                add_audio=False,
            )
        avg_loss = torch.stack([
            x['val_loss'] for x in outputs
        ]).mean()  # This reduces across batches, not workers!
        self.log('val_loss', avg_loss)

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting it to True.")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        labels = self._cfg.labels

        dataset = instantiate(cfg.dataset,
                              labels=labels,
                              bos_id=len(labels),
                              eos_id=len(labels) + 1,
                              pad_id=len(labels) + 2)
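        # The dataset supplies its own collate_fn so batches are padded consistently.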
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="Tacotron2-22050Hz",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemottsmodels/versions/1.0.0a5/files/Tacotron2-22050Hz.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz, and can be used to generate female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)
        return list_of_models
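
    # End-to-end usage sketch (assumptions: the enclosing class is NeMo's Tacotron2Model,
    # and from_pretrained is the standard NeMo ModelPT restore-by-name helper; the model
    # name matches the entry returned by list_available_models above):
    #   model = Tacotron2Model.from_pretrained(model_name="Tacotron2-22050Hz")
    #   tokens = model.parse("NeMo text to speech.")
    #   spectrogram = model.generate_spectrogram(tokens=tokens)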