Beispiel #1
0
class FastSpeech2HifiGanE2EModel(TextToWaveform):
    """An end-to-end speech synthesis model based on FastSpeech2 and HiFiGan that converts strings to audio without
    using the intermediate mel spectrogram representation."""
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Build the encoder, variance adapter, HiFiGAN generator, discriminators and losses from ``cfg``.

        Args:
            cfg: Model configuration. A plain ``dict`` is converted to a ``DictConfig`` first. It must provide
                ``mappings_filepath``, pointing to a JSON file with ``word2phones`` and ``phone2idx`` entries.
            trainer: Optional trainer, forwarded to the parent constructor.

        Raises:
            ValueError: If ``cfg`` does not contain ``mappings_filepath``.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        self.encoder = instantiate(cfg.encoder)
        self.variance_adapter = instantiate(cfg.variance_adaptor)

        self.generator = instantiate(cfg.generator)
        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiscaledisc = MultiScaleDiscriminator()

        # Second instance of the preprocessor, configured with highfreq=None and use_grads=True; it is the one
        # applied to generated audio when computing mel losses.
        self.melspec_fn = instantiate(cfg.preprocessor,
                                      highfreq=None,
                                      use_grads=True)
        self.mel_val_loss = L1MelLoss()
        self.durationloss = DurationLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.mseloss = torch.nn.MSELoss()

        self.energy = cfg.add_energy_predictor
        self.pitch = cfg.add_pitch_predictor
        self.mel_loss_coeff = cfg.mel_loss_coeff
        self.pitch_loss_coeff = cfg.pitch_loss_coeff
        self.energy_loss_coeff = cfg.energy_loss_coeff
        self.splice_length = cfg.splice_length

        # Both flags start False and are switched on by on_train_epoch_start().
        self.use_energy_pred = False
        self.use_pitch_pred = False
        self.log_train_images = False
        self.logged_real_samples = False
        self._tb_logger = None
        self.sample_rate = cfg.sample_rate
        self.hop_size = cfg.hop_size

        # Parser and mappings are used for inference only.
        self.parser = parsers.make_parser(name='en')
        if 'mappings_filepath' not in cfg:
            message = "ERROR: You must specify a mappings.json file in the config file under model.mappings_filepath."
            logging.error(message)
            # Fail here with a clear error instead of a NameError on the undefined path below.
            raise ValueError(message)
        mappings_filepath = self.register_artifact('mappings_filepath',
                                                   cfg.get('mappings_filepath'))
        with open(mappings_filepath, 'r') as f:
            mappings = json.load(f)
            self.word2phones = mappings['word2phones']
            self.phone2idx = mappings['phone2idx']

    @property
    def tb_logger(self):
        """Return the cached experiment logger, resolving it on first access.

        Returns:
            ``None`` when there is no logger or the logger has no experiment. Otherwise ``self.logger.experiment``,
            or, when ``self.logger`` is a ``LoggerCollection``, the experiment of its first ``TensorBoardLogger``.
        """
        if self._tb_logger is None:
            # `or`, not `and`: with `and`, a missing logger was dereferenced to read `.experiment`.
            if self.logger is None or self.logger.experiment is None:
                return None
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            self._tb_logger = tb_logger
        return self._tb_logger

    def configure_optimizers(self):
        """Create the discriminator and generator optimizers with per-step Noam schedulers.

        Returns:
            ``([disc_optimizer, gen_optimizer], [disc_scheduler_dict, gen_scheduler_dict])``. The optimizer order
            matters: ``training_step`` treats ``optimizer_idx == 0`` as the discriminator step.
        """
        gen_params = chain(
            self.encoder.parameters(),
            self.generator.parameters(),
            self.variance_adapter.parameters(),
        )
        disc_params = chain(self.multiscaledisc.parameters(),
                            self.multiperioddisc.parameters())
        opt1 = torch.optim.AdamW(disc_params, lr=self._cfg.lr)
        opt2 = torch.optim.AdamW(gen_params, lr=self._cfg.lr)

        # With num_gpus == 0 the product is 0 and the division below would fail, so count at least one process.
        num_procs = max(1, self._trainer.num_gpus * self._trainer.num_nodes)
        num_samples = len(self._train_dl.dataset)
        batch_size = self._train_dl.batch_size
        iter_per_epoch = np.ceil(num_samples / (num_procs * batch_size))
        max_steps = iter_per_epoch * self._trainer.max_epochs
        logging.info(f"MAX STEPS: {max_steps}")

        scheduler_dicts = []
        for optimizer in (opt1, opt2):
            scheduler = NoamAnnealing(optimizer,
                                      d_model=256,
                                      warmup_steps=3000,
                                      max_steps=max_steps,
                                      min_lr=1e-5)
            # 'interval': 'step' asks for the scheduler to be stepped per optimizer step.
            scheduler_dicts.append({
                'scheduler': scheduler,
                'interval': 'step',
            })
        return [opt1, opt2], scheduler_dicts

    @typecheck(
        input_types={
            "text":
            NeuralType(('B', 'T'), TokenIndex()),
            "text_length":
            NeuralType(('B'), LengthsType()),
            "splice":
            NeuralType(optional=True),
            "spec_len":
            NeuralType(('B'), LengthsType(), optional=True),
            "durations":
            NeuralType(('B', 'T'), TokenDurationType(), optional=True),
            "pitch":
            NeuralType(('B', 'T'), RegressionValuesType(), optional=True),
            "energies":
            NeuralType(('B', 'T'), RegressionValuesType(), optional=True),
        },
        output_types={
            "audio": NeuralType(('B', 'S', 'T'), MelSpectrogramType()),
            "splices": NeuralType(),
            "log_dur_preds": NeuralType(('B', 'T'), TokenLogDurationType()),
            "pitch_preds": NeuralType(('B', 'T'), RegressionValuesType()),
            "energy_preds": NeuralType(('B', 'T'), RegressionValuesType()),
            "encoded_text_mask": NeuralType(('B', 'T', 'D'), MaskType()),
        },
    )
    def forward(self,
                *,
                text,
                text_length,
                splice=True,
                durations=None,
                pitch=None,
                energies=None,
                spec_len=None):
        """Encode text, run the variance adapter and generate audio with the HiFiGAN generator.

        Args:
            text: Token indices, passed to the encoder.
            text_length: Length of each token sequence.
            splice: When true, a random window of ``self.splice_length`` frames is cut from each sample before
                generation, so only that window is synthesized.
            durations: Optional duration targets for the variance adapter.
            pitch: Optional pitch targets for the variance adapter.
            energies: Optional energy targets for the variance adapter.
            spec_len: Optional spectrogram lengths; replaced by the value the variance adapter returns.

        Returns:
            Tuple ``(audio, splices, log_dur_preds, pitch_preds, energy_preds, encoded_text_mask)``. ``splices``
            holds the start frame chosen for each sample, or ``None`` when ``splice`` is false.
        """
        encoded_text, encoded_text_mask = self.encoder(text=text,
                                                       text_length=text_length)

        context, log_dur_preds, pitch_preds, energy_preds, spec_len = self.variance_adapter(
            x=encoded_text,
            x_len=text_length,
            dur_target=durations,
            pitch_target=pitch,
            energy_target=energies,
            spec_len=spec_len,
        )

        gen_in = context
        splices = None
        if splice:
            # Splice generated spec
            output = []
            splices = []
            for i, sample in enumerate(context):
                usable_len = min(int(sample.size(0)), int(spec_len[i]))
                # np.random.randint needs high > low; a sample exactly splice_length long gives high == 0,
                # so clamp to 1 and take start 0 in that case.
                start = np.random.randint(
                    low=0, high=max(1, usable_len - self.splice_length))
                output.append(sample[start:start + self.splice_length, :])
                splices.append(start)
            gen_in = torch.stack(output)

        output = self.generator(x=gen_in.transpose(1, 2))

        return output, splices, log_dur_preds, pitch_preds, energy_preds, encoded_text_mask

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one GAN training step for the optimizer selected by ``optimizer_idx``.

        ``optimizer_idx == 0`` computes the discriminator loss (generator output produced under ``no_grad``);
        ``optimizer_idx == 1`` computes the generator-side loss: mel L1, adversarial, feature matching and
        duration losses, plus pitch/energy losses when those predictors are enabled.

        Args:
            batch: Tuple ``(audio, audio_len, text, text_len, durations, pitch, energies)``.
            batch_idx: Index of the batch (unused).
            optimizer_idx: Which optimizer this step belongs to.

        Returns:
            The loss for the selected optimizer, or ``None`` for any other ``optimizer_idx``.
        """
        f, fl, t, tl, durations, pitch, energies = batch
        _, spec_len = self.audio_to_melspec_precessor(f, fl)

        # forward() is keyword-only and has no `spec` parameter, so only the spectrogram lengths are passed on.
        forward_kwargs = {
            "spec_len": spec_len,
            "text": t,
            "text_length": tl,
            "durations": durations,
            # Ground-truth pitch/energy are fed until the model switches to its own predictions.
            "pitch": pitch if not self.use_pitch_pred else None,
            "energies": energies if not self.use_energy_pred else None,
        }

        # train discriminator
        if optimizer_idx == 0:
            with torch.no_grad():
                audio_pred, splices, _, _, _, _ = self(**forward_kwargs)
                real_audio = self._splice_real_audio(f, splices)

            real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(
                real_audio, audio_pred)
            real_score_ms, gen_score_ms, _, _ = self.multiscaledisc(
                real_audio, audio_pred)

            loss_mp, loss_mp_real, _ = self.disc_loss(real_score_mp,
                                                      gen_score_mp)
            loss_ms, loss_ms_real, _ = self.disc_loss(real_score_ms,
                                                      gen_score_ms)
            loss_mp /= len(loss_mp_real)
            loss_ms /= len(loss_ms_real)
            loss_disc = loss_mp + loss_ms

            self.log("loss_discriminator", loss_disc, prog_bar=True)
            self.log("loss_discriminator_ms", loss_ms)
            self.log("loss_discriminator_mp", loss_mp)
            return loss_disc

        # train generator
        elif optimizer_idx == 1:
            audio_pred, splices, log_dur_preds, pitch_preds, energy_preds, encoded_text_mask = self(
                **forward_kwargs)
            real_audio = self._splice_real_audio(f, splices)

            # Do HiFiGAN generator loss
            audio_length = torch.tensor(
                [self.splice_length * self.hop_size] *
                real_audio.shape[0]).to(real_audio.device)
            # squeeze(1) drops only the channel axis; a bare squeeze() would also drop a batch axis of size 1.
            real_spliced_spec, _ = self.melspec_fn(real_audio.squeeze(1),
                                                   seq_len=audio_length)
            pred_spliced_spec, _ = self.melspec_fn(audio_pred.squeeze(1),
                                                   seq_len=audio_length)
            loss_mel = torch.nn.functional.l1_loss(real_spliced_spec,
                                                   pred_spliced_spec)
            loss_mel *= self.mel_loss_coeff
            _, gen_score_mp, real_feat_mp, gen_feat_mp = self.multiperioddisc(
                real_audio, audio_pred)
            _, gen_score_ms, real_feat_ms, gen_feat_ms = self.multiscaledisc(
                real_audio, audio_pred)
            loss_gen_mp, list_loss_gen_mp = self.gen_loss(gen_score_mp)
            loss_gen_ms, list_loss_gen_ms = self.gen_loss(gen_score_ms)
            loss_gen_mp /= len(list_loss_gen_mp)
            loss_gen_ms /= len(list_loss_gen_ms)
            total_loss = loss_gen_mp + loss_gen_ms + loss_mel
            loss_feat_mp = self.feat_matching_loss(real_feat_mp, gen_feat_mp)
            loss_feat_ms = self.feat_matching_loss(real_feat_ms, gen_feat_ms)
            total_loss += loss_feat_mp + loss_feat_ms
            self.log(name="loss_gen_disc_feat",
                     value=loss_feat_mp + loss_feat_ms)
            self.log(name="loss_gen_disc_feat_ms", value=loss_feat_ms)
            self.log(name="loss_gen_disc_feat_mp", value=loss_feat_mp)

            self.log(name="loss_gen_mel", value=loss_mel)
            self.log(name="loss_gen_disc", value=loss_gen_mp + loss_gen_ms)
            self.log(name="loss_gen_disc_mp", value=loss_gen_mp)
            self.log(name="loss_gen_disc_ms", value=loss_gen_ms)

            dur_loss = self.durationloss(log_duration_pred=log_dur_preds,
                                         duration_target=durations.float(),
                                         mask=encoded_text_mask)
            self.log(name="loss_gen_duration", value=dur_loss)
            total_loss += dur_loss
            if self.pitch:
                pitch_loss = self.mseloss(
                    pitch_preds, pitch.float()) * self.pitch_loss_coeff
                total_loss += pitch_loss
                self.log(name="loss_gen_pitch", value=pitch_loss)
            if self.energy:
                energy_loss = self.mseloss(energy_preds,
                                           energies) * self.energy_loss_coeff
                total_loss += energy_loss
                self.log(name="loss_gen_energy", value=energy_loss)

            # Log images to tensorboard
            if self.log_train_images:
                # The flag is re-armed by validation_epoch_end, so images are logged once per validation run.
                self.log_train_images = False
                if self.logger is not None and self.logger.experiment is not None:
                    self.tb_logger.add_image(
                        "train_mel_target",
                        plot_spectrogram_to_numpy(
                            real_spliced_spec[0].data.cpu().numpy()),
                        self.global_step,
                        dataformats="HWC",
                    )
                    spec_predict = pred_spliced_spec[0].data.cpu().numpy()
                    self.tb_logger.add_image(
                        "train_mel_predicted",
                        plot_spectrogram_to_numpy(spec_predict),
                        self.global_step,
                        dataformats="HWC",
                    )
            self.log(name="loss_gen", prog_bar=True, value=total_loss)
            return total_loss

    def _splice_real_audio(self, audio, splices):
        """Cut from ``audio`` the sample windows matching the spliced frames and stack them with a channel axis."""
        window = self.splice_length * self.hop_size
        clips = [
            audio[i, start * self.hop_size:start * self.hop_size + window]
            for i, start in enumerate(splices)
        ]
        return torch.stack(clips).unsqueeze(1)

    def validation_step(self, batch, batch_idx):
        """Synthesize full utterances for a validation batch and compute the mel loss.

        Args:
            batch: Tuple whose first four items are ``(audio, audio_len, text, text_len)``; the rest is ignored.
            batch_idx: Index of the batch; audio is only returned for batch 0.

        Returns:
            Dict with ``val_loss`` and, for the first batch only, ``audio_target`` and ``audio_pred``
            (``None`` for every other batch). ``validation_epoch_end`` relies on this key order.
        """
        f, fl, t, tl, _, _, _ = batch
        spec, spec_len = self.audio_to_melspec_precessor(f, fl)
        # forward() is keyword-only and has no `spec` parameter, so it is not passed.
        audio_pred, _, _, _, _, _ = self(spec_len=spec_len,
                                         text=t,
                                         text_length=tl,
                                         splice=False)
        # Drop only the channel axis; a bare squeeze would also remove a batch axis of size 1.
        audio_pred = audio_pred.squeeze(1)
        pred_spec, _ = self.melspec_fn(audio_pred, seq_len=spec_len)
        loss = self.mel_val_loss(spec_pred=pred_spec,
                                 spec_target=spec,
                                 spec_target_len=spec_len,
                                 pad_value=-11.52)

        return {
            "val_loss": loss,
            # The audio batch is returned unsqueezed so that indexing [0] always selects the first utterance.
            "audio_target": f if batch_idx == 0 else None,
            "audio_pred": audio_pred if batch_idx == 0 else None,
        }

    def on_train_epoch_start(self):
        """Turn on the use of predicted energy and pitch once training has progressed far enough.

        Each flag is set at most once: energy at 50% of ``max_epochs``, pitch at 62.5%.
        """
        # Switch to using energy predictions after 50% of training
        if not self.use_energy_pred:
            if self.current_epoch >= np.ceil(0.5 * self._trainer.max_epochs):
                logging.info(
                    f"Using energy predictions after epoch: {self.current_epoch}")
                self.use_energy_pred = True

        # Switch to using pitch predictions after 62.5% of training
        if not self.use_pitch_pred:
            if self.current_epoch >= np.ceil(0.625 * self._trainer.max_epochs):
                logging.info(
                    f"Using pitch predictions after epoch: {self.current_epoch}")
                self.use_pitch_pred = True

    def validation_epoch_end(self, outputs):
        """Log validation audio and the mean validation loss, then re-arm training image logging.

        Args:
            outputs: List of the dicts returned by ``validation_step``; the first one carries the audio.
        """
        board = self.tb_logger
        if board is not None:
            # Relies on the insertion order of the dict built by validation_step.
            _, target_batch, predicted_batch = outputs[0].values()
            if not self.logged_real_samples:
                # The ground-truth sample is only logged the first time.
                board.add_audio("val_target", target_batch[0].data.cpu(),
                                self.global_step, self.sample_rate)
                self.logged_real_samples = True
            first_prediction = predicted_batch[0].data.cpu()
            board.add_audio("val_pred", first_prediction, self.global_step,
                            self.sample_rate)

        # This reduces across batches, not workers!
        batch_losses = [step_out['val_loss'] for step_out in outputs]
        mean_loss = torch.stack(batch_losses).mean()
        self.log('val_loss', mean_loss, sync_dist=True)

        self.log_train_images = True

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        """Instantiate the dataset described by ``cfg`` and wrap it in a ``DataLoader``.

        Args:
            cfg: Config with a ``dataset`` section and a ``dataloader_params`` section (both ``DictConfig``).
            shuffle_should_be: Expected value of ``dataloader_params.shuffle``. When true and the key is missing
                it is set to ``True``; any other mismatch is only logged as an error.
            name: Dataloader name used in log and error messages.

        Returns:
            A ``torch.utils.data.DataLoader`` using the dataset's own ``collate_fn``.

        Raises:
            ValueError: If ``dataset`` or ``dataloader_params`` is missing or is not a ``DictConfig``.
        """
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloader_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        # .get() tolerates a config without a `shuffle` key, where attribute access may fail.
        elif cfg.dataloader_params.get('shuffle', False):
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        """Build the training dataloader from ``cfg`` and store it in ``self._train_dl``."""
        train_loader = self.__setup_dataloader_from_config(cfg)
        self._train_dl = train_loader

    def setup_validation_data(self, cfg):
        """Build the validation dataloader (no shuffling expected) and store it in ``self._validation_dl``."""
        validation_loader = self.__setup_dataloader_from_config(
            cfg,
            shuffle_should_be=False,
            name="validation",
        )
        self._validation_dl = validation_loader

    def parse(self,
              str_input: str,
              additional_word2phones=None) -> torch.Tensor:
        """
        Parses text input and converts it to phoneme indices.

        Args:
            str_input (str): The input text to be converted. A trailing "." is appended when the text does not
                already end in ".", "!" or "?" (this includes the empty string).
            additional_word2phones (dict): Optional dictionary mapping words to phonemes, merged into the model's
                word2phones with ``dict.update``: new words are added and entries for words that already exist
                are replaced. Defaults to None, which keeps the existing mapping.

        Returns:
            A long tensor of phoneme indices with a leading batch axis of size 1, on ``self.device``.

        Raises:
            KeyError: If a word of the normalized text is not in the model's word2phones.
        """
        # Update model's word2phones if applicable
        if additional_word2phones is not None:
            self.word2phones.update(additional_word2phones)

        # Convert text -> normalized text -> list of phones per word -> indices
        # endswith() is also safe for an empty string, where indexing [-1] would raise.
        if not str_input.endswith((".", "!", "?")):
            str_input = str_input + "."
        norm_text = re.findall(r"""[\w']+|[.,!?;"]""",
                               self.parser._normalize(str_input))

        try:
            phones = [self.word2phones[t] for t in norm_text]
        except KeyError as error:
            # str(KeyError) is the repr of the key; use the key itself so the message is not double-quoted.
            missing_word = error.args[0]
            logging.error(
                f"ERROR: The following word in the input is not in the model's dictionary and could not be converted"
                f" to phonemes: ({missing_word}).\n"
                f"You can pass in an `additional_word2phones` dictionary with a conversion for"
                f" this word, e.g. {{'{missing_word}': ['phone1', 'phone2', ...]}} to update the model's mapping."
            )
            raise

        tokens = [self.phone2idx[p] for phone_list in phones for p in phone_list]

        x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
        return x

    def convert_text_to_waveform(self, *, tokens):
        """
        Synthesize audio for tokens produced by self.parse().

        The model is put in eval mode, run without splicing, and each generated waveform is trimmed to the
        predicted total duration (in frames) times the hop size.

        Returns:
            A list of tensors, one per input sequence. Note: the tensors can have different lengths.
        """
        self.eval()
        lengths = torch.tensor([len(seq) for seq in tokens]).to(self.device)
        waveforms, _, log_durations, _, _, _ = self(text=tokens,
                                                    text_length=lengths,
                                                    splice=False)
        waveforms = waveforms.squeeze(1)
        # Total predicted frames per sample: sum over tokens of exp(log_dur) - 1, truncated to int.
        frame_counts = torch.sum(torch.exp(log_durations) - 1, 1).to(torch.int)
        return [
            clip[:frame_counts[idx] * self.hop_size]
            for idx, clip in enumerate(waveforms)
        ]

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        Lists the pre-trained checkpoints of this model that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models (currently a single entry).
        """
        return [
            PretrainedModelInfo(
                pretrained_model_name="tts_en_e2e_fastspeech2hifigan",
                location=
                "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_e2e_fastspeech2hifigan/versions/1.0.0/files/tts_en_e2e_fastspeech2hifigan.nemo",
                description=
                "This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent.",
                class_=cls,
            ),
        ]
Beispiel #2
0
class HifiGanModel(Vocoder):
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Build the HiFiGAN generator, its discriminators and losses from ``cfg``.

        Args:
            cfg: Model configuration; a plain ``dict`` is converted to a ``DictConfig`` first.
            trainer: Optional trainer, forwarded to the parent constructor.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # use a different melspec extractor because:
        # 1. we need to pass grads
        # 2. we need remove fmax limitation
        self.trg_melspec_fn = instantiate(cfg.preprocessor,
                                          highfreq=None,
                                          use_grads=True)
        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()
        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        self.sample_rate = self._cfg.preprocessor.sample_rate
        # Computed on first use and cached by _bias_denoise().
        self.stft_bias = None

        # A training dataloader may be absent (presumably when the model is built for inference only);
        # treat that as "not fine-tuning" instead of dereferencing a missing loader.
        train_dl = getattr(self, "_train_dl", None)
        if train_dl is not None and isinstance(train_dl.dataset,
                                               MelAudioDataset):
            self.finetune = True
            logging.info("fine-tuning on pre-computed mels")
        else:
            self.finetune = False
            logging.info("training on ground-truth mels")

    def configure_optimizers(self):
        """Create generator and discriminator optimizers, each with a per-step cosine-annealing scheduler.

        The generator scheduler warms up for 20% of ``max_steps`` unless fine-tuning, in which case there is no
        warmup. Optimizers and schedulers are also kept on ``self`` for ``training_step``.

        Returns:
            ``([optim_g, optim_d], [generator_scheduler_dict, discriminator_scheduler_dict])``.
        """
        optim_cfg = self._cfg.optim
        self.optim_g = instantiate(optim_cfg,
                                   params=self.generator.parameters())
        discriminator_params = itertools.chain(self.msd.parameters(),
                                               self.mpd.parameters())
        self.optim_d = instantiate(optim_cfg, params=discriminator_params)

        total_steps = self._cfg.max_steps
        if self.finetune:
            generator_warmup = 0
        else:
            generator_warmup = np.ceil(0.2 * total_steps)

        # Use warmup to delay start
        self.scheduler_g = CosineAnnealing(self.optim_g,
                                           max_steps=total_steps,
                                           min_lr=1e-5,
                                           warmup_steps=generator_warmup)
        self.scheduler_d = CosineAnnealing(self.optim_d,
                                           max_steps=total_steps,
                                           min_lr=1e-5)

        optimizers = [self.optim_g, self.optim_d]
        schedulers = [{
            'scheduler': scheduler,
            'interval': 'step',
        } for scheduler in (self.scheduler_g, self.scheduler_d)]
        return optimizers, schedulers

    @property
    def input_types(self):
        """Neural types of the ``forward`` inputs: a single mel spectrogram ``spec`` with axes (B, D, T)."""
        spec_type = NeuralType(('B', 'D', 'T'), MelSpectrogramType())
        return {"spec": spec_type}

    @property
    def output_types(self):
        """Neural types of the ``forward`` outputs: ``audio`` with axes (B, S, T) at the model's sample rate."""
        audio_type = NeuralType(('B', 'S', 'T'), AudioSignal(self.sample_rate))
        return {"audio": audio_type}

    @typecheck()
    def forward(self, *, spec):
        """
        Passes ``spec`` through the generator; see input_types and output_types for the expected types.
        """
        generated_audio = self.generator(x=spec)
        return generated_audio

    @typecheck(output_types={"audio": NeuralType(('B', 'T'), AudioSignal())})
    def convert_spectrogram_to_audio(self,
                                     spec: 'torch.tensor') -> 'torch.tensor':
        """Run the model on ``spec`` and drop axis 1 of the generated audio."""
        generated = self(spec=spec)
        return generated.squeeze(1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one manually-optimized GAN step: update the discriminators first, then the generator.

        Args:
            batch: ``(audio, audio_len, audio_mel)`` when fine-tuning on pre-computed mels, otherwise
                ``(audio, audio_len)``.
            batch_idx: Index of the batch (unused).
            optimizer_idx: Unused; both optimizers are stepped explicitly here.
        """
        if self.finetune:
            # Fine-tuning batches already carry mels produced by a spectrogram generator.
            audio, audio_len, audio_mel = batch
        else:
            audio, audio_len = batch
            # Generator input computed from the ground-truth audio.
            audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)

        # Target mel for the L1 mel loss.
        target_mel, _ = self.trg_melspec_fn(audio, audio_len)
        audio = audio.unsqueeze(1)

        generated = self.generator(x=audio_mel)
        generated_mel, _ = self.trg_melspec_fn(generated.squeeze(1),
                                               audio_len)

        # ---- discriminator update ----
        self.optim_d.zero_grad()
        # Detached so the discriminator loss does not backpropagate into the generator.
        generated_detached = generated.detach()
        real_scores, fake_scores, _, _ = self.mpd(y=audio,
                                                  y_hat=generated_detached)
        loss_disc_mpd, _, _ = self.discriminator_loss(
            disc_real_outputs=real_scores,
            disc_generated_outputs=fake_scores)
        real_scores, fake_scores, _, _ = self.msd(y=audio,
                                                  y_hat=generated_detached)
        loss_disc_msd, _, _ = self.discriminator_loss(
            disc_real_outputs=real_scores,
            disc_generated_outputs=fake_scores)
        loss_d = loss_disc_msd + loss_disc_mpd
        self.manual_backward(loss_d, self.optim_d)
        self.optim_d.step()

        # ---- generator update ----
        self.optim_g.zero_grad()
        loss_mel = F.l1_loss(generated_mel, target_mel) * 45
        _, mpd_fake_scores, mpd_real_maps, mpd_fake_maps = self.mpd(
            y=audio, y_hat=generated)
        _, msd_fake_scores, msd_real_maps, msd_fake_maps = self.msd(
            y=audio, y_hat=generated)
        loss_fm_mpd = self.feature_loss(fmap_r=mpd_real_maps,
                                        fmap_g=mpd_fake_maps)
        loss_fm_msd = self.feature_loss(fmap_r=msd_real_maps,
                                        fmap_g=msd_fake_maps)
        loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_fake_scores)
        loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_fake_scores)
        loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel
        self.manual_backward(loss_g, self.optim_g)
        self.optim_g.step()

        current_lr = self.optim_g.param_groups[0]['lr']
        metrics = {
            "g_l1_loss": loss_mel,
            "g_loss_fm_mpd": loss_fm_mpd,
            "g_loss_fm_msd": loss_fm_msd,
            "g_loss_gen_mpd": loss_gen_mpd,
            "g_loss_gen_msd": loss_gen_msd,
            "g_loss": loss_g,
            "d_loss_mpd": loss_disc_mpd,
            "d_loss_msd": loss_disc_msd,
            "d_loss": loss_d,
            "global_step": self.global_step,
            "lr": current_lr,
        }
        self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        """Compute the validation mel L1 loss and, for the first batch, log audio and mels to Weights & Biases.

        Args:
            batch: ``(audio, audio_len, audio_mel)`` when fine-tuning on pre-computed mels, otherwise
                ``(audio, audio_len)``.
            batch_idx: Index of the batch; media is only logged for batch 0.
        """
        if self.finetune:
            audio, audio_len, audio_mel = batch
            # The mel is fed to forward() as `spec`, typed (B, D, T): the time length is the last axis, not
            # axis 1 (mel bins). Every item gets the full padded length.
            audio_mel_len = [audio_mel.shape[2]] * audio_mel.shape[0]
        else:
            audio, audio_len = batch
            audio_mel, audio_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred = self(spec=audio_mel)

        # perform bias denoising
        pred_denoised = self._bias_denoise(audio_pred, audio_mel).squeeze(1)
        pred_denoised_mel, _ = self.audio_to_melspec_precessor(
            pred_denoised, audio_len)

        if self.finetune:
            # Ground-truth mel, only needed to plot next to the pre-computed input mel.
            gt_mel, gt_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred_mel, _ = self.audio_to_melspec_precessor(
            audio_pred.squeeze(1), audio_len)
        loss_mel = F.l1_loss(audio_mel, audio_pred_mel)

        self.log("val_loss", loss_mel, prog_bar=True, sync_dist=True)

        # plot audio once per epoch
        if batch_idx == 0 and isinstance(self.logger,
                                         WandbLogger) and HAVE_WANDB:
            clips = []
            specs = []
            # At most five examples are logged.
            for i in range(min(5, audio.shape[0])):
                clips += [
                    wandb.Audio(
                        audio[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"real audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        audio_pred[i,
                                   0, :audio_len[i]].data.cpu().numpy().astype(
                                       'float32'),
                        caption=f"generated audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        pred_denoised[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"denoised audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                ]
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"input mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_pred_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"output mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(pred_denoised_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"denoised mel {i}",
                    ),
                ]
                if self.finetune:
                    specs += [
                        wandb.Image(
                            plot_spectrogram_to_numpy(gt_mel[
                                i, :, :audio_mel_len[i]].data.cpu().numpy()),
                            caption=f"gt mel {i}",
                        ),
                    ]

            self.logger.experiment.log({
                "audio": clips,
                "specs": specs
            },
                                       commit=False)

    def _bias_denoise(self, audio, mel):
        """Subtract the generator's zero-input bias spectrum from ``audio`` in the STFT magnitude domain.

        The bias is the STFT magnitude (first frame only) of the audio the model generates for an all-zero mel
        shaped like ``mel``. It is cached on ``self.stft_bias`` and scaled by ``self.cfg.denoise_strength``.

        Args:
            audio: Generated audio; axis 1 is squeezed before the STFT.
            mel: Mel spectrogram that produced ``audio``; only its shape and device are used.

        Returns:
            The denoised audio with axis 1 restored.
        """
        def stft(x):
            comp = torch.stft(x.squeeze(1),
                              n_fft=1024,
                              hop_length=256,
                              win_length=1024)
            real, imag = comp[..., 0], comp[..., 1]
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase

        def istft(mags, phase):
            comp = torch.stack(
                [mags * torch.cos(phase), mags * torch.sin(phase)], dim=-1)
            x = torch.istft(comp, n_fft=1024, hop_length=256, win_length=1024)
            return x

        # create bias tensor
        # Rebuild the cache when the batch size changes (e.g. a smaller last batch); a stale bias of a
        # different batch size cannot be broadcast against the magnitudes below.
        if self.stft_bias is None or self.stft_bias.shape[0] != audio.shape[0]:
            audio_bias = self(spec=torch.zeros_like(mel, device=mel.device))
            self.stft_bias, _ = stft(audio_bias)
            # Keep only the first frame, with a trailing axis so it broadcasts over time.
            self.stft_bias = self.stft_bias[:, :, 0][:, :, None]

        audio_mags, audio_phase = stft(audio)
        audio_mags = audio_mags - self.cfg.denoise_strength * self.stft_bias
        # Magnitudes cannot be negative after the subtraction.
        audio_mags = torch.clamp(audio_mags, 0.0)
        audio_denoised = istft(audio_mags, audio_phase).unsqueeze(1)

        return audio_denoised

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        """Instantiate the dataset described by ``cfg`` and wrap it in a ``DataLoader``.

        Args:
            cfg: Config with a ``dataset`` section and a ``dataloader_params`` section (both ``DictConfig``).
            shuffle_should_be: Expected value of ``dataloader_params.shuffle``. When true and the key is missing
                it is set to ``True``; any other mismatch is only logged as an error.
            name: Dataloader name used in log and error messages.

        Returns:
            A ``torch.utils.data.DataLoader`` using the dataset's own ``collate_fn``.

        Raises:
            ValueError: If ``dataset`` or ``dataloader_params`` is missing or is not a ``DictConfig``.
        """
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloader_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        # .get() tolerates a config without a `shuffle` key, where attribute access may fail.
        elif cfg.dataloader_params.get('shuffle', False):
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        """Build the training dataloader from ``cfg`` and store it in ``self._train_dl``."""
        train_loader = self.__setup_dataloader_from_config(cfg)
        self._train_dl = train_loader

    def setup_validation_data(self, cfg):
        """Build the validation dataloader (no shuffling expected) and store it in ``self._validation_dl``."""
        validation_loader = self.__setup_dataloader_from_config(
            cfg,
            shuffle_should_be=False,
            name="validation",
        )
        self._validation_dl = validation_loader

    @classmethod
    def list_available_models(cls) -> 'Optional[Dict[str, str]]':
        """Placeholder: no pre-trained models are listed yet, so this returns ``None``."""
        # TODO
        return None
class FastPitchHifiGanE2EModel(TextToWaveform):
    """An end-to-end speech synthesis model based on FastPitch and HiFiGan that converts strings to audio without using
    the intermediate mel spectrogram representation.

    The encoder output, enriched with a pitch embedding and expanded by token durations, is fed directly to the
    HiFiGan generator. Training uses two optimizers: index 0 updates the discriminators, index 1 updates the
    encoder, predictors, pitch embedding and generator.
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Instantiate all sub-modules and losses from ``cfg``.

        Args:
            cfg: Model config (a plain ``dict`` is converted to a ``DictConfig``).
            trainer: Optional trainer forwarded to the base class.

        Raises:
            ValueError: If ``cfg`` is neither a ``dict`` nor a ``DictConfig``.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)

        # The parser is created before the base constructor runs — presumably because the base class may
        # set up dataloaders, which read ``self.parser`` (see ``_loader``). TODO confirm
        self._parser = self._make_parser(cfg.labels)

        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(FastPitchHifiGanE2EConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema (the merged result itself is not needed)
        OmegaConf.merge(cfg, schema)

        self.preprocessor = instantiate(cfg.preprocessor)
        # Second mel extractor built with use_grads=True and highfreq=None; used on generated audio
        self.melspec_fn = instantiate(cfg.preprocessor,
                                      highfreq=None,
                                      use_grads=True)

        self.encoder = instantiate(cfg.input_fft)
        self.duration_predictor = instantiate(cfg.duration_predictor)
        self.pitch_predictor = instantiate(cfg.pitch_predictor)

        self.generator = instantiate(cfg.generator)
        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiscaledisc = MultiScaleDiscriminator()
        self.mel_val_loss = L1MelLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()

        self.max_token_duration = cfg.max_token_duration

        self.pitch_emb = torch.nn.Conv1d(
            1,
            cfg.symbols_embedding_dim,
            kernel_size=cfg.pitch_embedding_kernel_size,
            padding=int((cfg.pitch_embedding_kernel_size - 1) / 2),
        )

        # Store values precomputed from training data for convenience
        self.register_buffer('pitch_mean', torch.zeros(1))
        self.register_buffer('pitch_std', torch.zeros(1))

        self.pitchloss = PitchLoss()
        self.durationloss = DurationLoss()

        self.mel_loss_coeff = cfg.mel_loss_coeff

        # Set to True at the end of each validation epoch so the next generator step logs spectrograms once
        self.log_train_images = False
        self.logged_real_samples = False
        self._tb_logger = None
        self.hann_window = None
        self.splice_length = cfg.splice_length
        self.sample_rate = cfg.sample_rate
        self.hop_size = cfg.hop_size

    @staticmethod
    def _make_parser(labels):
        """Create the English text parser with the fixed settings this model uses."""
        return parsers.make_parser(
            labels=labels,
            name='en',
            unk_id=-1,
            blank_id=-1,
            do_normalize=True,
            abbreviation_version="fastpitch",
            make_table=False,
        )

    @property
    def tb_logger(self):
        """The TensorBoard experiment object, or ``None`` when no logger is attached.

        The lookup result is cached in ``self._tb_logger`` once a logger has been found.
        """
        if self._tb_logger is None:
            # `or`, not `and`: with `and` a missing logger was dereferenced (AttributeError on None)
            if self.logger is None or self.logger.experiment is None:
                return None
            tb_logger = self.logger.experiment
            if isinstance(self.logger, LoggerCollection):
                # Prefer the TensorBoard logger when several loggers are configured
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            self._tb_logger = tb_logger
        return self._tb_logger

    @property
    def parser(self):
        """The text parser, created lazily from ``self._cfg.labels`` if it is not set yet."""
        if self._parser is None:
            self._parser = self._make_parser(self._cfg.labels)
        return self._parser

    def parse(self, str_input: str) -> 'torch.Tensor':
        """Convert a sentence into a batch of token ids.

        A trailing period is appended when the text does not already end with ".", "!" or "?"
        (this includes empty input, which previously raised ``IndexError``).

        Args:
            str_input: Text to synthesize.

        Returns:
            A long tensor with a leading batch dimension of 1, placed on ``self.device``.
        """
        if not str_input.endswith((".", "!", "?")):
            str_input = str_input + "."

        tokens = self.parser(str_input)

        x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
        return x

    def configure_optimizers(self):
        """Create the discriminator and generator optimizers with step-wise Noam schedulers.

        Returns:
            ``([disc_optimizer, gen_optimizer], [disc_scheduler_dict, gen_scheduler_dict])``. The order matters:
            ``training_step`` treats optimizer index 0 as the discriminator and index 1 as the generator.
        """
        gen_params = chain(
            self.pitch_emb.parameters(),
            self.encoder.parameters(),
            self.duration_predictor.parameters(),
            self.pitch_predictor.parameters(),
            self.generator.parameters(),
        )
        disc_params = chain(self.multiscaledisc.parameters(),
                            self.multiperioddisc.parameters())
        opt1 = torch.optim.AdamW(disc_params, lr=self._cfg.lr)
        opt2 = torch.optim.AdamW(gen_params, lr=self._cfg.lr)
        # At least one process, so CPU-only runs (num_gpus == 0) do not divide by zero below
        num_procs = max(1, self._trainer.num_gpus * self._trainer.num_nodes)
        num_samples = len(self._train_dl.dataset)
        batch_size = self._train_dl.batch_size
        iter_per_epoch = np.ceil(num_samples / (num_procs * batch_size))
        max_steps = iter_per_epoch * self._trainer.max_epochs
        logging.info(f"MAX STEPS: {max_steps}")
        sch1 = NoamAnnealing(opt1,
                             d_model=1,
                             warmup_steps=1000,
                             max_steps=max_steps,
                             last_epoch=-1)
        sch1_dict = {
            'scheduler': sch1,
            'interval': 'step',
        }
        sch2 = NoamAnnealing(opt2,
                             d_model=1,
                             warmup_steps=1000,
                             max_steps=max_steps,
                             last_epoch=-1)
        sch2_dict = {
            'scheduler': sch2,
            'interval': 'step',
        }
        return [opt1, opt2], [sch1_dict, sch2_dict]

    @typecheck(
        input_types={
            "text": NeuralType(('B', 'T'), TokenIndex()),
            "durs": NeuralType(('B', 'T'), TokenDurationType(), optional=True),
            "pitch": NeuralType(('B', 'T'),
                                RegressionValuesType(),
                                optional=True),
            "pace": NeuralType(optional=True),
            "splice": NeuralType(optional=True),
        },
        output_types={
            "audio": NeuralType(('B', 'S', 'T'), AudioSignal()),
            "splices": NeuralType(),
            "log_dur_preds": NeuralType(('B', 'T'), TokenLogDurationType()),
            "pitch_preds": NeuralType(('B', 'T'), RegressionValuesType()),
        },
    )
    def forward(self, *, text, durs=None, pitch=None, pace=1.0, splice=True):
        """Generate audio from token ids.

        Args:
            text: Token ids.
            durs: Ground-truth token durations; required in training mode. When omitted, predicted
                durations are used for length regulation.
            pitch: Ground-truth pitch; required in training mode. When omitted, predicted pitch is embedded.
            pace: Forwarded to ``regulate_len``.
            splice: When True, only a random window of ``self.splice_length`` steps of each length-regulated
                sequence is passed to the generator.

        Returns:
            Tuple ``(audio, splices, log_durs_predicted, pitch_predicted)`` where ``splices`` is a tensor of
            the window start positions (empty when ``splice`` is False).
        """
        if self.training:
            assert durs is not None
            assert pitch is not None

        # Input FFT
        enc_out, enc_mask = self.encoder(input=text, conditioning=0)

        # Embedded for predictors
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        # Predict durations
        log_durs_predicted = self.duration_predictor(pred_enc_out,
                                                     pred_enc_mask)
        durs_predicted = torch.clamp(
            torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)

        # Predict pitch; the ground-truth pitch, when given, replaces the prediction in the embedding
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
        if pitch is None:
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
        else:
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if durs is None:
            len_regulated, dec_lens = regulate_len(durs_predicted, enc_out,
                                                   pace)
        else:
            len_regulated, dec_lens = regulate_len(durs, enc_out, pace)

        gen_in = len_regulated
        splices = []
        if splice:
            output = []
            for i, sample in enumerate(len_regulated):
                # NOTE: randint raises ValueError when the sequence is not longer than splice_length.
                # The exclusive upper bound is kept as is; training_step slices the waveform with the
                # same start index, so widening the range could run past the end of the audio.
                start = np.random.randint(
                    low=0,
                    high=min(int(sample.size(0)), int(dec_lens[i])) -
                    self.splice_length)
                # Splice generated spec
                output.append(sample[start:start + self.splice_length, :])
                splices.append(start)
            gen_in = torch.stack(output)

        output = self.generator(x=gen_in.transpose(1, 2))

        return output, torch.tensor(
            splices), log_durs_predicted, pitch_predicted

    def _slice_real_audio(self, audio, splices):
        """Cut from each ground-truth waveform the window that matches the spliced generator input.

        ``splices`` holds window starts in generator-input steps; multiplying by ``self.hop_size`` converts
        them to sample offsets. Returns the stacked windows with a channel axis inserted at dim 1.
        """
        real_audio = [
            audio[i, start * self.hop_size:(start + self.splice_length) *
                  self.hop_size] for i, start in enumerate(splices)
        ]
        return torch.stack(real_audio).unsqueeze(1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimization step for the discriminators (index 0) or the generator side (index 1).

        Returns:
            The loss for the selected optimizer, or ``None`` for any other ``optimizer_idx``.
        """
        audio, _, text, text_lens, durs, pitch, _ = batch

        # train discriminator
        if optimizer_idx == 0:
            # The generator side is frozen here: no gradients flow into it during the discriminator update
            with torch.no_grad():
                audio_pred, splices, _, _ = self(text=text,
                                                 durs=durs,
                                                 pitch=pitch)
                real_audio = self._slice_real_audio(audio, splices)

            real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(
                real_audio, audio_pred)
            real_score_ms, gen_score_ms, _, _ = self.multiscaledisc(
                real_audio, audio_pred)

            loss_mp, loss_mp_real, _ = self.disc_loss(
                disc_real_outputs=real_score_mp,
                disc_generated_outputs=gen_score_mp)
            loss_ms, loss_ms_real, _ = self.disc_loss(
                disc_real_outputs=real_score_ms,
                disc_generated_outputs=gen_score_ms)
            # Normalize by the number of entries in the per-output loss lists
            loss_mp /= len(loss_mp_real)
            loss_ms /= len(loss_ms_real)
            loss_disc = loss_mp + loss_ms

            self.log("loss_discriminator", loss_disc, prog_bar=True)
            self.log("loss_discriminator_ms", loss_ms)
            self.log("loss_discriminator_mp", loss_mp)
            return loss_disc

        # train generator
        elif optimizer_idx == 1:
            audio_pred, splices, log_dur_preds, pitch_preds = self(text=text,
                                                                   durs=durs,
                                                                   pitch=pitch)
            real_audio = self._slice_real_audio(audio, splices)

            dur_loss = self.durationloss(log_durs_predicted=log_dur_preds,
                                         durs_tgt=durs,
                                         len=text_lens)
            pitch_loss = self.pitchloss(
                pitch_predicted=pitch_preds,
                pitch_tgt=pitch,
            )

            # Do HiFiGAN generator loss
            audio_length = torch.tensor([
                self.splice_length * self.hop_size
                for _ in range(real_audio.shape[0])
            ]).to(real_audio.device)
            # squeeze(1) removes only the channel axis; a bare squeeze() would also drop the batch
            # axis when the batch holds a single sample
            real_spliced_spec, _ = self.melspec_fn(real_audio.squeeze(1),
                                                   audio_length)
            pred_spliced_spec, _ = self.melspec_fn(audio_pred.squeeze(1),
                                                   audio_length)
            loss_mel = torch.nn.functional.l1_loss(real_spliced_spec,
                                                   pred_spliced_spec)
            loss_mel *= self.mel_loss_coeff
            _, gen_score_mp, _, _ = self.multiperioddisc(
                real_audio, audio_pred)
            _, gen_score_ms, _, _ = self.multiscaledisc(y=real_audio,
                                                        y_hat=audio_pred)
            loss_gen_mp, list_loss_gen_mp = self.gen_loss(
                disc_outputs=gen_score_mp)
            loss_gen_ms, list_loss_gen_ms = self.gen_loss(
                disc_outputs=gen_score_ms)
            loss_gen_mp /= len(list_loss_gen_mp)
            loss_gen_ms /= len(list_loss_gen_ms)
            total_loss = loss_gen_mp + loss_gen_ms + loss_mel
            total_loss += dur_loss
            total_loss += pitch_loss

            self.log(name="loss_gen_mel", value=loss_mel)
            self.log(name="loss_gen_disc", value=loss_gen_mp + loss_gen_ms)
            self.log(name="loss_gen_disc_mp", value=loss_gen_mp)
            self.log(name="loss_gen_disc_ms", value=loss_gen_ms)
            self.log(name="loss_gen_duration", value=dur_loss)
            self.log(name="loss_gen_pitch", value=pitch_loss)

            # Log images to tensorboard
            if self.log_train_images:
                # One-shot flag; re-armed in validation_epoch_end
                self.log_train_images = False

                if self.logger is not None and self.logger.experiment is not None:
                    self.tb_logger.add_image(
                        "train_mel_target",
                        plot_spectrogram_to_numpy(
                            real_spliced_spec[0].data.cpu().numpy()),
                        self.global_step,
                        dataformats="HWC",
                    )
                    spec_predict = pred_spliced_spec[0].data.cpu().numpy()
                    self.tb_logger.add_image(
                        "train_mel_predicted",
                        plot_spectrogram_to_numpy(spec_predict),
                        self.global_step,
                        dataformats="HWC",
                    )
            self.log(name="loss_gen", prog_bar=True, value=total_loss)
            return total_loss

    def validation_step(self, batch, batch_idx):
        """Synthesize audio from text only and compare its mel spectrogram with the ground truth.

        Returns:
            Dict with ``val_loss`` and, for the first batch only, ``audio_target`` / ``audio_pred``
            (``None`` for every other batch).
        """
        audio, audio_lens, text, _, _, _, _ = batch
        mels, mel_lens = self.preprocessor(audio, audio_lens)

        audio_pred, _, log_durs_predicted, _ = self(text=text,
                                                    durs=None,
                                                    pitch=None,
                                                    splice=False)
        # Same duration decoding as forward(): exp(x) - 1 clamped to [0, max_token_duration].
        # (The earlier exp(x - 1) without an upper clamp disagreed with forward and inference.)
        durs_predicted = torch.clamp(
            torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)
        # Durations count generator-input steps; melspec_fn takes lengths in samples (training_step
        # passes splice_length * hop_size), so convert with hop_size.
        audio_length = torch.sum(durs_predicted, axis=1) * self.hop_size
        # Drop only the channel axis so a batch of one keeps its batch dimension
        audio_pred = audio_pred.squeeze(1)
        pred_spec, _ = self.melspec_fn(audio_pred, audio_length)
        loss = self.mel_val_loss(spec_pred=pred_spec,
                                 spec_target=mels,
                                 spec_target_len=mel_lens,
                                 pad_value=-11.52,
                                 transpose=False)

        return {
            "val_loss": loss,
            "audio_target": audio if batch_idx == 0 else None,
            "audio_pred": audio_pred if batch_idx == 0 else None,
        }

    def validation_epoch_end(self, outputs):
        """Log sample audio from the first validation batch and the mean validation loss."""
        if self.tb_logger is not None:
            # Read by key rather than relying on the insertion order of the step's result dict
            first_batch = outputs[0]
            audio_target = first_batch["audio_target"]
            audio_predict = first_batch["audio_pred"]
            if not self.logged_real_samples:
                # The ground-truth clip never changes, so it is logged only once
                self.tb_logger.add_audio("val_target",
                                         audio_target[0].data.cpu(),
                                         self.global_step, self.sample_rate)
                self.logged_real_samples = True
            audio_predict = audio_predict[0].data.cpu()
            self.tb_logger.add_audio("val_pred", audio_predict,
                                     self.global_step, self.sample_rate)
        avg_loss = torch.stack([
            x['val_loss'] for x in outputs
        ]).mean()  # This reduces across batches, not workers!
        self.log('val_loss', avg_loss, sync_dist=True)

        # Ask the next generator training step to log spectrogram images
        self.log_train_images = True

    def _loader(self, cfg):
        """Build a ``DataLoader`` over a ``FastPitchDataset`` described by ``cfg``.

        ``manifest_filepath``, ``sample_rate``, ``batch_size`` and ``shuffle`` are required keys; the
        remaining options fall back to the defaults given below.
        """
        dataset = FastPitchDataset(
            manifest_filepath=cfg['manifest_filepath'],
            parser=self.parser,
            sample_rate=cfg['sample_rate'],
            int_values=cfg.get('int_values', False),
            max_duration=cfg.get('max_duration', None),
            min_duration=cfg.get('min_duration', None),
            max_utts=cfg.get('max_utts', 0),
            trim=cfg.get('trim_silence', True),
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg['batch_size'],
            collate_fn=dataset.collate_fn,
            drop_last=cfg.get('drop_last', True),
            shuffle=cfg['shuffle'],
            num_workers=cfg.get('num_workers', 16),
        )

    def setup_training_data(self, cfg):
        """Create the training dataloader and store it in ``self._train_dl``."""
        self._train_dl = self._loader(cfg)

    def setup_validation_data(self, cfg):
        """Create the validation dataloader and store it in ``self._validation_dl``."""
        self._validation_dl = self._loader(cfg)

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_e2e_fastpitchhifigan",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_e2e_fastpitchhifigan/versions/1.0.0/files/tts_en_e2e_fastpitchhifigan.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)

        return list_of_models

    def convert_text_to_waveform(self, *, tokens):
        """
        Accepts tokens returned from self.parse() and returns a list of tensors. Note: The tensors in the list can have
        different lengths.

        The model is switched to eval mode. Each waveform is trimmed to the summed predicted token
        durations times ``self.hop_size`` samples.
        """
        self.eval()
        audio, _, log_dur_pred, _ = self(text=tokens, splice=False)
        audio = audio.squeeze(1)
        durations = torch.sum(
            torch.clamp(
                torch.exp(log_dur_pred) - 1, 0, self.max_token_duration),
            1).to(torch.int)
        audio_list = []
        for i, sample in enumerate(audio):
            audio_list.append(sample[:durations[i] * self.hop_size])

        return audio_list
Beispiel #4
0
class HifiGanModel(Vocoder, Exportable):
    """HiFiGan vocoder: a generator that turns mel spectrograms into audio, trained against multi-period and
    multi-scale discriminators with manual optimization (both optimizers are stepped inside ``training_step``).
    """
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Instantiate the mel extractors, generator, discriminators and losses from ``cfg``.

        Args:
            cfg: Model config (a plain ``dict`` is converted to a ``DictConfig``).
            trainer: Optional trainer forwarded to the base class.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # use a different melspec extractor because:
        # 1. we need to pass grads
        # 2. we need remove fmax limitation
        self.trg_melspec_fn = instantiate(cfg.preprocessor,
                                          highfreq=None,
                                          use_grads=True)
        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()
        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        # Weight of the L1 mel loss inside the generator loss
        self.l1_factor = cfg.get("l1_loss_factor", 45)

        self.sample_rate = self._cfg.preprocessor.sample_rate
        # Cached bias magnitude for _bias_denoise; computed on first use
        self.stft_bias = None

        # Batches carry precomputed mels only when the training set is a MelAudioDataset
        if self._train_dl and isinstance(self._train_dl.dataset,
                                         MelAudioDataset):
            self.input_as_mel = True
        else:
            self.input_as_mel = False

        # training_step runs backward passes and optimizer steps itself
        self.automatic_optimization = False

    def configure_optimizers(self):
        """Create generator/discriminator optimizers from ``cfg.optim`` with step-wise cosine schedulers.

        Returns:
            ``([optim_g, optim_d], [generator_scheduler_dict, discriminator_scheduler_dict])``.
        """
        self.optim_g = instantiate(
            self._cfg.optim,
            params=self.generator.parameters(),
        )
        self.optim_d = instantiate(
            self._cfg.optim,
            params=itertools.chain(self.msd.parameters(),
                                   self.mpd.parameters()),
        )

        self.scheduler_g = CosineAnnealing(
            optimizer=self.optim_g,
            max_steps=self._cfg.max_steps,
            min_lr=self._cfg.sched.min_lr,
            warmup_steps=self._cfg.sched.warmup_ratio * self._cfg.max_steps,
        )  # Use warmup to delay start
        sch1_dict = {
            'scheduler': self.scheduler_g,
            'interval': 'step',
        }

        self.scheduler_d = CosineAnnealing(
            optimizer=self.optim_d,
            max_steps=self._cfg.max_steps,
            min_lr=self._cfg.sched.min_lr,
        )
        sch2_dict = {
            'scheduler': self.scheduler_d,
            'interval': 'step',
        }

        return [self.optim_g, self.optim_d], [sch1_dict, sch2_dict]

    @property
    def input_types(self):
        """Neural types of the ``forward`` inputs."""
        return {
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        }

    @property
    def output_types(self):
        """Neural types of the ``forward`` outputs."""
        return {
            "audio": NeuralType(('B', 'S', 'T'),
                                AudioSignal(self.sample_rate)),
        }

    @typecheck()
    def forward(self, *, spec):
        """
        Runs the generator, for inputs and outputs see input_types, and output_types
        """
        return self.generator(x=spec)

    def forward_for_export(self, spec):
        """Same computation as ``forward`` but without the typecheck wrapper and with a positional argument."""
        return self.generator(x=spec)

    @typecheck(
        input_types={
            "spec": NeuralType(('B', 'C', 'T'), MelSpectrogramType())
        },
        output_types={"audio": NeuralType(('B', 'T'), AudioSignal())},
    )
    def convert_spectrogram_to_audio(self,
                                     spec: 'torch.tensor') -> 'torch.tensor':
        """Vocode ``spec`` and drop the channel axis (dim 1) of the generated audio."""
        return self(spec=spec).squeeze(1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Perform one discriminator update followed by one generator update, then log all losses.

        Optimization is manual, so nothing is returned; ``optimizer_idx`` is accepted but not used.
        """
        # if in finetune mode the mels are pre-computed using a
        # spectrogram generator
        if self.input_as_mel:
            audio, audio_len, audio_mel = batch
        # else, we compute the mel using the ground truth audio
        else:
            audio, audio_len = batch
            # mel as input for generator
            audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)

        # mel as input for L1 mel loss
        audio_trg_mel, _ = self.trg_melspec_fn(audio, audio_len)
        audio = audio.unsqueeze(1)

        audio_pred = self.generator(x=audio_mel)
        audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1),
                                                audio_len)

        # train discriminator (generated audio is detached so only the discriminators receive gradients)
        self.optim_d.zero_grad()
        mpd_score_real, mpd_score_gen, _, _ = self.mpd(
            y=audio, y_hat=audio_pred.detach())
        loss_disc_mpd, _, _ = self.discriminator_loss(
            disc_real_outputs=mpd_score_real,
            disc_generated_outputs=mpd_score_gen)
        msd_score_real, msd_score_gen, _, _ = self.msd(
            y=audio, y_hat=audio_pred.detach())
        loss_disc_msd, _, _ = self.discriminator_loss(
            disc_real_outputs=msd_score_real,
            disc_generated_outputs=msd_score_gen)
        loss_d = loss_disc_msd + loss_disc_mpd
        self.manual_backward(loss_d)
        self.optim_d.step()

        # train generator
        self.optim_g.zero_grad()
        loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel)
        _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(
            y=audio, y_hat=audio_pred)
        _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(
            y=audio, y_hat=audio_pred)
        loss_fm_mpd = self.feature_loss(fmap_r=fmap_mpd_real,
                                        fmap_g=fmap_mpd_gen)
        loss_fm_msd = self.feature_loss(fmap_r=fmap_msd_real,
                                        fmap_g=fmap_msd_gen)
        loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen)
        loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
        loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel * self.l1_factor
        self.manual_backward(loss_g)
        self.optim_g.step()

        metrics = {
            "g_l1_loss": loss_mel,
            "g_loss_fm_mpd": loss_fm_mpd,
            "g_loss_fm_msd": loss_fm_msd,
            "g_loss_gen_mpd": loss_gen_mpd,
            "g_loss_gen_msd": loss_gen_msd,
            "g_loss": loss_g,
            "d_loss_mpd": loss_disc_mpd,
            "d_loss_msd": loss_disc_msd,
            "d_loss": loss_d,
            "global_step": self.global_step,
            "lr": self.optim_g.param_groups[0]['lr'],
        }
        self.log_dict(metrics, on_step=True, sync_dist=True)
        # Logged a second time so the L1 loss also shows on the progress bar
        self.log("g_l1_loss",
                 loss_mel,
                 prog_bar=True,
                 logger=False,
                 sync_dist=True)

    def validation_step(self, batch, batch_idx):
        """Vocode the validation batch, log the L1 mel loss and, once per epoch, sample audio/images to W&B."""
        if self.input_as_mel:
            audio, audio_len, audio_mel = batch
            # Every precomputed mel in the batch spans the full (padded) time axis. input_types declares
            # the layout as ('B', 'D', 'T'), so the time length is the last axis (shape[1] is 'D').
            audio_mel_len = [audio_mel.shape[-1]] * audio_mel.shape[0]
        else:
            audio, audio_len = batch
            audio_mel, audio_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred = self(spec=audio_mel)

        # perform bias denoising
        pred_denoised = self._bias_denoise(audio_pred, audio_mel).squeeze(1)
        pred_denoised_mel, _ = self.audio_to_melspec_precessor(
            pred_denoised, audio_len)

        if self.input_as_mel:
            # Mel of the real audio, shown next to the precomputed input mel in the logged images
            gt_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)
        audio_pred_mel, _ = self.audio_to_melspec_precessor(
            audio_pred.squeeze(1), audio_len)
        loss_mel = F.l1_loss(audio_mel, audio_pred_mel)

        self.log_dict({"val_loss": loss_mel}, on_epoch=True, sync_dist=True)

        # plot audio once per epoch
        if batch_idx == 0 and isinstance(self.logger,
                                         WandbLogger) and HAVE_WANDB:
            clips = []
            specs = []
            for i in range(min(5, audio.shape[0])):
                clips += [
                    wandb.Audio(
                        audio[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"real audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        audio_pred[i,
                                   0, :audio_len[i]].data.cpu().numpy().astype(
                                       'float32'),
                        caption=f"generated audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        pred_denoised[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"denoised audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                ]
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"input mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_pred_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"output mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(pred_denoised_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"denoised mel {i}",
                    ),
                ]
                if self.input_as_mel:
                    specs += [
                        wandb.Image(
                            plot_spectrogram_to_numpy(gt_mel[
                                i, :, :audio_mel_len[i]].data.cpu().numpy()),
                            caption=f"gt mel {i}",
                        ),
                    ]

            self.logger.experiment.log({"audio": clips, "specs": specs})

    def _bias_denoise(self, audio, mel):
        """Subtract a scaled copy of the generator's bias spectrum from ``audio``.

        The bias is the STFT magnitude (first frame only) of the audio generated from an all-zero mel.
        It is cached in ``self.stft_bias`` and recomputed when the batch size changes.

        Args:
            audio: Generated audio with a channel axis at dim 1.
            mel: Mel input that produced ``audio``; only used as a template for the zero mel.

        Returns:
            The denoised audio with a channel axis at dim 1.
        """
        def stft(x):
            comp = stft_patch(x.squeeze(1),
                              n_fft=1024,
                              hop_length=256,
                              win_length=1024)
            real, imag = comp[..., 0], comp[..., 1]
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase

        def istft(mags, phase):
            comp = torch.stack(
                [mags * torch.cos(phase), mags * torch.sin(phase)], dim=-1)
            x = torch.istft(comp, n_fft=1024, hop_length=256, win_length=1024)
            return x

        # create bias tensor
        if self.stft_bias is None or self.stft_bias.shape[0] != audio.shape[0]:
            audio_bias = self(spec=torch.zeros_like(mel, device=mel.device))
            self.stft_bias, _ = stft(audio_bias)
            # Keep only the first frame, with a trailing axis of size 1 so it broadcasts over time
            self.stft_bias = self.stft_bias[:, :, 0][:, :, None]

        audio_mags, audio_phase = stft(audio)
        audio_mags = audio_mags - self.cfg.get("denoise_strength",
                                               0.0025) * self.stft_bias
        # Magnitudes cannot be negative after the subtraction
        audio_mags = torch.clamp(audio_mags, 0.0)
        audio_denoised = istft(audio_mags, audio_phase).unsqueeze(1)

        return audio_denoised

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        """Build a ``torch.utils.data.DataLoader`` from a dataset/dataloader config.

        Args:
            cfg: Config with a ``dataset`` section (instantiated into the dataset) and a
                ``dataloader_params`` section (expanded into ``DataLoader`` keyword arguments).
                Both sections must be ``DictConfig`` instances.
            shuffle_should_be: The shuffle setting expected for this loader. When True and
                ``shuffle`` is absent from ``dataloader_params`` it is set to True; any other
                mismatch is only reported through ``logging.error``, never corrected.
            name: Loader name used in log and error messages.

        Returns:
            A ``DataLoader`` over the instantiated dataset, using ``dataset.collate_fn``.

        Raises:
            ValueError: If ``dataset`` or ``dataloader_params`` is missing or is not a ``DictConfig``.
        """
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloader_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        # `.get` keeps a config without a `shuffle` key from failing on attribute access;
        # a missing key means the DataLoader default (no shuffling), which is what is wanted here.
        elif cfg.dataloader_params.get('shuffle', False):
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        """Create the training dataloader and store it in ``self._train_dl``."""
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        """Create the validation dataloader (shuffling not expected) and store it in ``self._validation_dl``."""
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """Return the list of pre-trained checkpoints that can be loaded for this class."""
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_hifigan",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_hifigan/versions/1.0.0rc1/files/tts_hifigan.nemo",
            description=
            "This model is trained on LJSpeech audio sampled at 22050Hz and mel spectrograms generated from Tacotron2, TalkNet, and FastPitch. This model has been tested on generating female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)
        return list_of_models

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict``, renaming resblock keys of older checkpoints to the current layout.

        Older checkpoints index resblocks with one flat number (keys with six dot-separated parts); the
        flat index is converted into the ``<group>.<index>`` pair used now.

        Returns:
            Whatever the parent ``load_state_dict`` returns (previously this result was dropped).
        """
        # override load_state_dict to give us some flexibility to be backward-compatible
        # with old checkpoints
        new_state_dict = {}
        num_resblocks = len(self.cfg['generator']['resblock_kernel_sizes'])
        for k, v in state_dict.items():
            new_k = k
            if 'resblocks' in k:
                parts = k.split(".")
                # only do this is the checkpoint type is older
                if len(parts) == 6:
                    layer = int(parts[2])
                    new_layer = f"{layer//num_resblocks}.{layer%num_resblocks}"
                    new_k = f"generator.resblocks.{new_layer}.{'.'.join(parts[3:])}"
            new_state_dict[new_k] = v
        return super().load_state_dict(new_state_dict, strict=strict)

    def _prepare_for_export(self, **kwargs):
        """
        Override this method to prepare module for export. This is in-place operation.
        Base version does common necessary module replacements (Apex etc)

        This override removes weight normalization from the generator.
        """
        self.generator.remove_weight_norm()

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A random mel tensor with batch size 1, ``cfg.preprocessor.nfilt`` channels and 96 frames,
            on the same device and with the same dtype as the model parameters.
        """
        par = next(self.parameters())
        mel = torch.randn((1, self.cfg['preprocessor']['nfilt'], 96),
                          device=par.device,
                          dtype=par.dtype)
        return mel
Beispiel #5
0
class HifiGanModel(Vocoder):
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Build the vocoder's sub-modules and loss objects from ``cfg``.

        Args:
            cfg: model configuration; a plain ``dict`` is converted with
                ``OmegaConf.create`` before being handed to the base class.
            trainer: optional trainer, forwarded to the base class.
        """
        cfg = OmegaConf.create(cfg) if isinstance(cfg, dict) else cfg
        super().__init__(cfg=cfg, trainer=trainer)

        # Mel extractor that produces the generator's input.
        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # A second mel extractor built from the same config is kept for the loss:
        #   1. gradients have to pass through it (use_grads=True)
        #   2. the fmax limitation has to be removed (highfreq=None)
        self.trg_melspec_fn = instantiate(
            cfg.preprocessor,
            highfreq=None,
            use_grads=True,
        )

        # Generator and the two discriminators.
        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()

        # Loss objects.
        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        # Read from the config stored by the base class, not from the local ``cfg``.
        self.sample_rate = self._cfg.preprocessor.sample_rate

    def configure_optimizers(self):
        """Create AdamW optimizers and StepLR schedulers for generator and discriminators.

        All hyper-parameters are read from ``self._cfg.optim`` (``lr``, ``adam_b1``,
        ``adam_b2``, ``lr_step``, ``lr_decay``). The created objects are also stored on
        ``self`` as ``optim_g``, ``optim_d``, ``scheduler_g`` and ``scheduler_d``.

        Returns:
            ``([optim_g, optim_d], [scheduler_g, scheduler_d])``.
        """
        optim_cfg = self._cfg.optim

        def make_adamw(params):
            # A fresh betas list per optimizer, as in a hand-written call.
            return torch.optim.AdamW(
                params,
                optim_cfg.lr,
                betas=[optim_cfg.adam_b1, optim_cfg.adam_b2],
            )

        def make_step_lr(optimizer):
            return torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=optim_cfg.lr_step,
                gamma=optim_cfg.lr_decay,
            )

        self.optim_g = make_adamw(self.generator.parameters())
        # Both discriminators share one optimizer: multi-scale parameters first.
        self.optim_d = make_adamw(
            itertools.chain(self.msd.parameters(), self.mpd.parameters()))

        self.scheduler_g = make_step_lr(self.optim_g)
        self.scheduler_d = make_step_lr(self.optim_d)

        optimizers = [self.optim_g, self.optim_d]
        schedulers = [self.scheduler_g, self.scheduler_d]
        return optimizers, schedulers

    @property
    def input_types(self):
        """Describe the inputs: one ``spec`` entry typed as a ('B', 'D', 'T') mel spectrogram."""
        spec_type = NeuralType(('B', 'D', 'T'), MelSpectrogramType())
        return {"spec": spec_type}

    @property
    def output_types(self):
        """Describe the outputs: one ``audio`` entry typed as a ('B', 'S', 'T') audio signal
        at ``self.sample_rate``."""
        audio_type = NeuralType(('B', 'S', 'T'), AudioSignal(self.sample_rate))
        return {"audio": audio_type}

    @typecheck()
    def forward(self, *, spec):
        """
        Runs the generator on ``spec`` and returns its output unchanged.
        For inputs and outputs see input_types and output_types.
        """
        generated = self.generator(x=spec)
        return generated

    @typecheck(output_types={"audio": NeuralType(('B', 'T'), AudioSignal())})
    def convert_spectrogram_to_audio(self,
                                     spec: 'torch.tensor') -> 'torch.tensor':
        """Run the model on ``spec`` and drop dimension 1 of the result via ``squeeze(1)``."""
        audio = self(spec=spec)
        return audio.squeeze(1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one manual GAN update: the discriminators first, then the generator.

        Both optimizers are zeroed, back-propagated through ``manual_backward`` and
        stepped explicitly inside this method, and all losses are logged at the end.

        Args:
            batch: pair ``(audio, audio_len)``.
            batch_idx: not used in this method.
            optimizer_idx: not used in this method.
        """
        real_wave, wave_len = batch
        # Mel used as generator input.
        gen_input_mel, _ = self.audio_to_melspec_precessor(real_wave, wave_len)
        # Mel used as the target of the L1 mel loss.
        target_mel, _ = self.trg_melspec_fn(real_wave, wave_len)
        real_wave = real_wave.unsqueeze(1)

        fake_wave = self.generator(x=gen_input_mel)
        fake_mel, _ = self.trg_melspec_fn(fake_wave.squeeze(1), wave_len)

        # ---- discriminator update ----
        self.optim_d.zero_grad()
        # Detached so the discriminator loss does not back-propagate into the generator.
        fake_wave_detached = fake_wave.detach()
        mpd_real, mpd_fake, _, _ = self.mpd(y=real_wave,
                                            y_hat=fake_wave_detached)
        d_loss_mpd, _, _ = self.discriminator_loss(
            disc_real_outputs=mpd_real, disc_generated_outputs=mpd_fake)
        msd_real, msd_fake, _, _ = self.msd(y=real_wave,
                                            y_hat=fake_wave_detached)
        d_loss_msd, _, _ = self.discriminator_loss(
            disc_real_outputs=msd_real, disc_generated_outputs=msd_fake)
        d_loss = d_loss_msd + d_loss_mpd
        self.manual_backward(d_loss, self.optim_d)
        self.optim_d.step()

        # ---- generator update ----
        self.optim_g.zero_grad()
        # The L1 mel loss is scaled by a fixed factor of 45.
        g_loss_mel = F.l1_loss(fake_mel, target_mel) * 45
        _, mpd_fake, mpd_fmap_real, mpd_fmap_fake = self.mpd(y=real_wave,
                                                             y_hat=fake_wave)
        _, msd_fake, msd_fmap_real, msd_fmap_fake = self.msd(y=real_wave,
                                                             y_hat=fake_wave)
        g_loss_fm_mpd = self.feature_loss(fmap_r=mpd_fmap_real,
                                          fmap_g=mpd_fmap_fake)
        g_loss_fm_msd = self.feature_loss(fmap_r=msd_fmap_real,
                                          fmap_g=msd_fmap_fake)
        g_loss_adv_mpd, _ = self.generator_loss(disc_outputs=mpd_fake)
        g_loss_adv_msd, _ = self.generator_loss(disc_outputs=msd_fake)
        g_loss = g_loss_adv_msd + g_loss_adv_mpd + g_loss_fm_msd + g_loss_fm_mpd + g_loss_mel
        self.manual_backward(g_loss, self.optim_g)
        self.optim_g.step()

        # ---- logging ----
        metrics = {
            "g_l1_loss": g_loss_mel,
            "g_loss_fm_mpd": g_loss_fm_mpd,
            "g_loss_fm_msd": g_loss_fm_msd,
            "g_loss_gen_mpd": g_loss_adv_mpd,
            "g_loss_gen_msd": g_loss_adv_msd,
            "g_loss": g_loss,
            "d_loss_mpd": d_loss_mpd,
            "d_loss_msd": d_loss_msd,
            "d_loss": d_loss,
            "global_step": self.global_step,
            "lr": self.optim_g.param_groups[0]['lr'],
        }
        self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        """Log the L1 mel loss for one validation batch and, on the first batch, media.

        The loss compares the mel of the input audio with the mel of the re-synthesized
        audio and is logged as ``val_loss``. When ``batch_idx`` is 0 and the logger is a
        ``WandbLogger``, up to five real/generated audio clips and their spectrogram
        plots are sent to the logger as well.

        Args:
            batch: pair ``(audio, audio_len)``.
            batch_idx: index of the batch within the validation loop.
        """
        wave, wave_len = batch
        ref_mel, ref_mel_len = self.audio_to_melspec_precessor(wave, wave_len)
        wave_hat = self(spec=ref_mel)

        hat_mel, _ = self.audio_to_melspec_precessor(wave_hat.squeeze(1),
                                                     wave_len)
        mel_l1 = F.l1_loss(ref_mel, hat_mel)

        self.log("val_loss", mel_l1, prog_bar=True, sync_dist=True)

        # plot audio once per epoch, and only when logging to W&B
        if batch_idx != 0 or not isinstance(self.logger, WandbLogger):
            return

        clips = []
        specs = []
        for i in range(min(5, wave.shape[0])):
            n_samples = wave_len[i]
            n_frames = ref_mel_len[i]

            clips.append(
                wandb.Audio(
                    wave[i, :n_samples].data.cpu().numpy(),
                    caption=f"real audio {i}",
                    sample_rate=self.sample_rate,
                ))
            clips.append(
                wandb.Audio(
                    wave_hat[i, 0, :n_samples].data.cpu().numpy().astype(
                        'float32'),
                    caption=f"generated audio {i}",
                    sample_rate=self.sample_rate,
                ))

            specs.append(
                wandb.Image(
                    plot_spectrogram_to_numpy(
                        ref_mel[i, :, :n_frames].data.cpu().numpy()),
                    caption=f"real audio {i}",
                ))
            specs.append(
                wandb.Image(
                    plot_spectrogram_to_numpy(
                        hat_mel[i, :, :n_frames].data.cpu().numpy()),
                    caption=f"generated audio {i}",
                ))

        media = {"audio": clips, "specs": specs}
        self.logger.experiment.log(media, commit=False)

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        """Validate a data config and build a ``torch.utils.data.DataLoader`` from it.

        Args:
            cfg: config that must contain ``dataset`` and ``dataloader_params``
                entries, both ``DictConfig`` instances.
            shuffle_should_be: expected value of ``dataloader_params.shuffle``.
                When ``True`` and the key is absent, it is set to ``True`` (with a
                warning). Any other mismatch is only reported with ``logging.error``;
                the configured value is left as is.
            name: label used in log and error messages.

        Returns:
            A ``DataLoader`` over ``instantiate(cfg.dataset)`` that uses the dataset's
            ``collate_fn`` and every entry of ``cfg.dataloader_params`` as keyword
            arguments.

        Raises:
            ValueError: if ``dataset`` or ``dataloader_params`` is missing or is not a
                ``DictConfig``.
        """
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloader_params for {name}")

        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        # A config without a ``shuffle`` key is fine here: check membership first so
        # that the attribute access cannot fail on a missing key.
        elif 'shuffle' in cfg.dataloader_params and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        """Build the training dataloader from ``cfg`` and store it as ``self._train_dl``."""
        train_loader = self.__setup_dataloader_from_config(cfg)
        self._train_dl = train_loader

    def setup_validation_data(self, cfg):
        """Build the validation dataloader from ``cfg`` (shuffling not expected) and
        store it as ``self._validation_dl``."""
        loader_options = {"shuffle_should_be": False, "name": "validation"}
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, **loader_options)

    @classmethod
    def list_available_models(cls) -> 'Optional[Dict[str, str]]':
        # TODO
        pass