Code example #1
File: hifigan.py Project: Tpt/NeMo
def output_types(self):
    return {
        "audio": NeuralType(('B', 'S', 'T'), AudioSignal(self.sample_rate)),
    }
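For reference, a minimal sketch of what this declaration expresses, assuming NeMo 1.x import paths: the axes tuple ('B', 'S', 'T') marks batch, a size-1 channel axis of the generated waveform, and time, while AudioSignal carries the sample rate.

# Illustrative only; assumes NeMo 1.x, which re-exports element types
# from nemo.core.neural_types.
from nemo.core.neural_types import AudioSignal, NeuralType

# 'B' = batch, 'S' = size-1 channel axis of the waveform, 'T' = time (samples)
audio_type = NeuralType(('B', 'S', 'T'), AudioSignal(22050))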
Code example #2
File: hifigan.py Project: Tpt/NeMo
class HifiGanModel(Vocoder):
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # use a different melspec extractor because:
        # 1. we need to pass grads
        # 2. we need to remove the fmax limitation
        self.trg_melspec_fn = instantiate(cfg.preprocessor, highfreq=None, use_grads=True)
        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()
        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        self.sample_rate = self._cfg.preprocessor.sample_rate
        self.stft_bias = None

        # guard: _train_dl may not be set up yet at construction time
        if self._train_dl and isinstance(self._train_dl.dataset, MelAudioDataset):
            self.finetune = True
            logging.info("fine-tuning on pre-computed mels")
        else:
            self.finetune = False
            logging.info("training on ground-truth mels")

    def configure_optimizers(self):
        self.optim_g = instantiate(self._cfg.optim, params=self.generator.parameters(),)
        self.optim_d = instantiate(
            self._cfg.optim, params=itertools.chain(self.msd.parameters(), self.mpd.parameters()),
        )

        max_steps = self._cfg.max_steps
        warmup_steps = 0 if self.finetune else np.ceil(0.2 * max_steps)
        self.scheduler_g = CosineAnnealing(
            self.optim_g, max_steps=max_steps, min_lr=1e-5, warmup_steps=warmup_steps
        )  # Use warmup to delay start
        sch1_dict = {
            'scheduler': self.scheduler_g,
            'interval': 'step',
        }
        self.scheduler_d = CosineAnnealing(self.optim_d, max_steps=max_steps, min_lr=1e-5)
        sch2_dict = {
            'scheduler': self.scheduler_d,
            'interval': 'step',
        }

        return [self.optim_g, self.optim_d], [sch1_dict, sch2_dict]

    @property
    def input_types(self):
        return {
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        }

    @property
    def output_types(self):
        return {
            "audio": NeuralType(('B', 'S', 'T'), AudioSignal(self.sample_rate)),
        }

    @typecheck()
    def forward(self, *, spec):
        """
        Runs the generator. For inputs and outputs, see input_types and output_types.
        """
        return self.generator(x=spec)

    @typecheck(output_types={"audio": NeuralType(('B', 'T'), AudioSignal())})
    def convert_spectrogram_to_audio(self, spec: 'torch.tensor') -> 'torch.tensor':
        return self(spec=spec).squeeze(1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # if in finetune mode the mels are pre-computed using a
        # spectrogram generator
        if self.finetune:
            audio, audio_len, audio_mel = batch
        # else, we compute the mel using the ground truth audio
        else:
            audio, audio_len = batch
            # mel as input for generator
            audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)

        # mel as input for L1 mel loss
        audio_trg_mel, _ = self.trg_melspec_fn(audio, audio_len)
        audio = audio.unsqueeze(1)

        audio_pred = self.generator(x=audio_mel)
        audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1), audio_len)

        # train discriminator
        self.optim_d.zero_grad()
        mpd_score_real, mpd_score_gen, _, _ = self.mpd(y=audio, y_hat=audio_pred.detach())
        loss_disc_mpd, _, _ = self.discriminator_loss(
            disc_real_outputs=mpd_score_real, disc_generated_outputs=mpd_score_gen
        )
        msd_score_real, msd_score_gen, _, _ = self.msd(y=audio, y_hat=audio_pred.detach())
        loss_disc_msd, _, _ = self.discriminator_loss(
            disc_real_outputs=msd_score_real, disc_generated_outputs=msd_score_gen
        )
        loss_d = loss_disc_msd + loss_disc_mpd
        self.manual_backward(loss_d, self.optim_d)
        self.optim_d.step()

        # train generator
        self.optim_g.zero_grad()
        loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel) * 45
        _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(y=audio, y_hat=audio_pred)
        _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(y=audio, y_hat=audio_pred)
        loss_fm_mpd = self.feature_loss(fmap_r=fmap_mpd_real, fmap_g=fmap_mpd_gen)
        loss_fm_msd = self.feature_loss(fmap_r=fmap_msd_real, fmap_g=fmap_msd_gen)
        loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen)
        loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
        loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel
        self.manual_backward(loss_g, self.optim_g)
        self.optim_g.step()

        metrics = {
            "g_l1_loss": loss_mel,
            "g_loss_fm_mpd": loss_fm_mpd,
            "g_loss_fm_msd": loss_fm_msd,
            "g_loss_gen_mpd": loss_gen_mpd,
            "g_loss_gen_msd": loss_gen_msd,
            "g_loss": loss_g,
            "d_loss_mpd": loss_disc_mpd,
            "d_loss_msd": loss_disc_msd,
            "d_loss": loss_d,
            "global_step": self.global_step,
            "lr": self.optim_g.param_groups[0]['lr'],
        }
        self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        if self.finetune:
            audio, audio_len, audio_mel = batch
            audio_mel_len = [audio_mel.shape[1]] * audio_mel.shape[0]
        else:
            audio, audio_len = batch
            audio_mel, audio_mel_len = self.audio_to_melspec_precessor(audio, audio_len)
        audio_pred = self(spec=audio_mel)

        # perform bias denoising
        pred_denoised = self._bias_denoise(audio_pred, audio_mel).squeeze(1)
        pred_denoised_mel, _ = self.audio_to_melspec_precessor(pred_denoised, audio_len)

        if self.finetune:
            gt_mel, gt_mel_len = self.audio_to_melspec_precessor(audio, audio_len)
        audio_pred_mel, _ = self.audio_to_melspec_precessor(audio_pred.squeeze(1), audio_len)
        loss_mel = F.l1_loss(audio_mel, audio_pred_mel)

        self.log("val_loss", loss_mel, prog_bar=True, sync_dist=True)

        # plot audio once per epoch
        if batch_idx == 0 and isinstance(self.logger, WandbLogger) and HAVE_WANDB:
            clips = []
            specs = []
            for i in range(min(5, audio.shape[0])):
                clips += [
                    wandb.Audio(
                        audio[i, : audio_len[i]].data.cpu().numpy(),
                        caption=f"real audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        audio_pred[i, 0, : audio_len[i]].data.cpu().numpy().astype('float32'),
                        caption=f"generated audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        pred_denoised[i, : audio_len[i]].data.cpu().numpy(),
                        caption=f"denoised audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                ]
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"input mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_pred_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"output mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(pred_denoised_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"denoised mel {i}",
                    ),
                ]
                if self.finetune:
                    specs += [
                        wandb.Image(
                            plot_spectrogram_to_numpy(gt_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
                            caption=f"gt mel {i}",
                        ),
                    ]

            self.logger.experiment.log({"audio": clips, "specs": specs}, commit=False)

    def _bias_denoise(self, audio, mel):
        def stft(x):
            comp = stft_patch(x.squeeze(1), n_fft=1024, hop_length=256, win_length=1024)
            real, imag = comp[..., 0], comp[..., 1]
            mags = torch.sqrt(real ** 2 + imag ** 2)
            phase = torch.atan2(imag, real)
            return mags, phase

        def istft(mags, phase):
            comp = torch.stack([mags * torch.cos(phase), mags * torch.sin(phase)], dim=-1)
            x = torch.istft(comp, n_fft=1024, hop_length=256, win_length=1024)
            return x

        # create bias tensor
        if self.stft_bias is None:
            audio_bias = self(spec=torch.zeros_like(mel, device=mel.device))
            self.stft_bias, _ = stft(audio_bias)
            self.stft_bias = self.stft_bias[:, :, 0][:, :, None]

        audio_mags, audio_phase = stft(audio)
        audio_mags = audio_mags - self.cfg.denoise_strength * self.stft_bias
        audio_mags = torch.clamp(audio_mags, 0.0)
        audio_denoised = istft(audio_mags, audio_phase).unsqueeze(1)

        return audio_denoised

    def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True"
                )
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'Optional[Dict[str, str]]':
        # TODO
        pass
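A minimal inference sketch for this class, assuming a trained checkpoint restored through NeMo's standard ModelPT API; the checkpoint path and mel shape below are placeholders.

# Hedged usage sketch; "hifigan.nemo" is a placeholder path and (1, 80, 96)
# is an illustrative (batch, n_mels, frames) shape.
import torch

from nemo.collections.tts.models import HifiGanModel

model = HifiGanModel.restore_from("hifigan.nemo")
model.eval()
spec = torch.randn(1, 80, 96)  # n_mels must match the training preprocessor
with torch.no_grad():
    audio = model.convert_spectrogram_to_audio(spec=spec)  # (batch, samples)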
Code example #3
class GlowVocoder(Vocoder):
    """ Base class for all Vocoders that use a Glow or reversible Flow-based setup. All child class are expected
    to have a parameter called audio_to_melspec_precessor that is an instance of
    nemo.collections.asr.parts.FilterbankFeatures"""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._mode = OperationMode.infer
        self.stft = None
        self.istft = None
        self.n_mel = None
        self.bias_spect = None

    @property
    def mode(self):
        return self._mode

    @mode.setter
    def mode(self, new_mode):
        # A setter is needed so temp_mode() below can assign self.mode;
        # switching train/eval keeps the module state consistent with the mode.
        if new_mode == OperationMode.training:
            self.train()
        else:
            self.eval()
        self._mode = new_mode

    @contextmanager
    def temp_mode(self, mode):
        old_mode = self.mode
        self.mode = mode
        try:
            yield
        finally:
            self.mode = old_mode

    @contextmanager
    def nemo_infer(self):
        # Prepended with "nemo" to avoid .infer() clashes with lightning or pytorch
        with ExitStack() as stack:
            stack.enter_context(self.temp_mode(OperationMode.infer))
            stack.enter_context(torch.no_grad())
            yield

    def check_children_attributes(self):
        if self.stft is None:
            if isinstance(self.audio_to_melspec_precessor.stft, STFT):
                logging.warning(
                    "torch_stft is deprecated. Please change your model to use torch.stft and torch.istft instead."
                )
                self.stft = self.audio_to_melspec_precessor.stft.transform
                self.istft = self.audio_to_melspec_precessor.stft.inverse
            else:
                try:
                    n_fft = self.audio_to_melspec_precessor.n_fft
                    hop_length = self.audio_to_melspec_precessor.hop_length
                    win_length = self.audio_to_melspec_precessor.win_length
                    window = self.audio_to_melspec_precessor.window.to(
                        self.device)
                except AttributeError as e:
                    raise AttributeError(
                        f"{self} could not find a valid audio_to_melspec_precessor. GlowVocoder requires child class "
                        "to have audio_to_melspec_precessor defined to obtain stft parameters. "
                        "audio_to_melspec_precessor requires n_fft, hop_length, win_length, window, and nfilt to be "
                        "defined.") from e

                def yet_another_patch(audio, n_fft, hop_length, win_length,
                                      window):
                    spec = stft_patch(audio,
                                      n_fft=n_fft,
                                      hop_length=hop_length,
                                      win_length=win_length,
                                      window=window)
                    if spec.dtype in [torch.cfloat, torch.cdouble]:
                        spec = torch.view_as_real(spec)
                    return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(
                        spec[..., -1], spec[..., 0])

                self.stft = lambda x: yet_another_patch(
                    x,
                    n_fft=n_fft,
                    hop_length=hop_length,
                    win_length=win_length,
                    window=window,
                )
                self.istft = lambda x, y: istft_patch(
                    torch.complex(x * torch.cos(y), x * torch.sin(y)),
                    n_fft=n_fft,
                    hop_length=hop_length,
                    win_length=win_length,
                    window=window,
                )

        if self.n_mel is None:
            try:
                self.n_mel = self.audio_to_melspec_precessor.nfilt
            except AttributeError as e:
                raise AttributeError(
                    f"{self} could not find a valid audio_to_melspec_precessor. GlowVocoder requires child class to "
                    "have audio_to_melspec_precessor defined to obtain stft parameters. audio_to_melspec_precessor "
                    "requires nfilt to be defined.") from e

    def update_bias_spect(self):
        self.check_children_attributes()  # Ensure stft parameters are defined

        with self.nemo_infer():
            spect = torch.zeros((1, self.n_mel, 88)).to(self.device)
            bias_audio = self.convert_spectrogram_to_audio(spec=spect,
                                                           sigma=0.0,
                                                           denoise=False)
            bias_spect, _ = self.stft(bias_audio)
            self.bias_spect = bias_spect[..., 0][..., None]

    @typecheck(
        input_types={
            "audio": NeuralType(('B', 'T'), AudioSignal()),
            "strength": NeuralType(optional=True)
        },
        output_types={"audio": NeuralType(('B', 'T'), AudioSignal())},
    )
    def denoise(self, audio: 'torch.tensor', strength: float = 0.01):
        self.check_children_attributes()  # Ensure self.n_mel and self.stft are defined

        if self.bias_spect is None:
            self.update_bias_spect()
        audio_spect, audio_angles = self.stft(audio)
        audio_spect_denoised = audio_spect - self.bias_spect.to(
            audio.device) * strength
        audio_spect_denoised = torch.clamp(audio_spect_denoised, 0.0)
        audio_denoised = self.istft(audio_spect_denoised, audio_angles)
        return audio_denoised
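A hedged sketch of how a concrete subclass would use this base class; WaveGlowModel inherits GlowVocoder in NeMo, the checkpoint path is a placeholder, and the sigma/denoise keyword arguments are assumed from the internal call in update_bias_spect above.

# Illustrative only; WaveGlowModel is one GlowVocoder subclass in NeMo.
import torch

from nemo.collections.tts.models import WaveGlowModel

vocoder = WaveGlowModel.restore_from("waveglow.nemo")  # placeholder path
spec = torch.randn(1, 80, 96)
with vocoder.nemo_infer():  # OperationMode.infer + torch.no_grad()
    audio = vocoder.convert_spectrogram_to_audio(spec=spec, sigma=1.0, denoise=False)
    audio = vocoder.denoise(audio=audio, strength=0.01)  # bias denoising defined above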
Code example #4
File: squeezewave.py Project: climatepals/NeMo
def input_types(self):
    return {
        "audio": NeuralType(('B', 'T'), AudioSignal()),
        "audio_len": NeuralType(('B',), LengthsType()),
        "run_inverse": NeuralType(optional=True),
    }
Code example #5
File: squeezewave.py Project: climatepals/NeMo
class SqueezeWaveModel(Vocoder):
    """ SqueezeWave model that generates audio conditioned on mel-spectrogram
    """
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(SqueezeWaveConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        self.pad_value = self._cfg.preprocessor.params.pad_value
        self.sigma = self._cfg.sigma
        self.audio_to_melspec_precessor = instantiate(self._cfg.preprocessor)
        self.squeezewave = instantiate(self._cfg.squeezewave)
        self.mode = OperationMode.infer
        self.loss = WaveGlowLoss()  # Same loss as WaveGlow

    @property
    def input_types(self):
        return {
            "audio": NeuralType(('B', 'T'), AudioSignal()),
            "audio_len": NeuralType(('B'), LengthsType()),
            "run_inverse": NeuralType(optional=True),
        }

    @property
    def output_types(self):
        if self.mode == OperationMode.training or self.mode == OperationMode.validation:
            output_dict = {
                "pred_normal_dist": NeuralType(('B', 'flowgroup', 'T'), NormalDistributionSamplesType()),
                "log_s_list": NeuralType(('B', 'flowgroup', 'T'), VoidType()),  # TODO: Figure out a good typing
                "log_det_W_list": NeuralType(elements_type=VoidType()),  # TODO: Figure out a good typing
            }
            if self.mode == OperationMode.validation:
                output_dict["audio_pred"] = NeuralType(('B', 'T'), AudioSignal())
                output_dict["spec"] = NeuralType(('B', 'T', 'D'), MelSpectrogramType())
                output_dict["spec_len"] = NeuralType(('B',), LengthsType())
            return output_dict
        return {
            "audio_pred": NeuralType(('B', 'T'), AudioSignal()),
        }

    @typecheck()
    def forward(self, *, audio, audio_len, run_inverse=True):
        if self.mode != self.squeezewave.mode:
            raise ValueError(
                f"SqueezeWaveModel's mode {self.mode} does not match SqueezeWaveModule's mode {self.squeezewave.mode}"
            )
        spec, spec_len = self.audio_to_melspec_precessor(audio, audio_len)
        tensors = self.squeezewave(spec=spec,
                                   audio=audio,
                                   run_inverse=run_inverse)
        if self.mode == OperationMode.training:
            return tensors[:-1]  # z, log_s_list, log_det_W_list
        elif self.mode == OperationMode.validation:
            z, log_s_list, log_det_W_list, audio_pred = tensors
            return z, log_s_list, log_det_W_list, audio_pred, spec, spec_len
        return tensors  # audio_pred

    @typecheck(
        input_types={
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "sigma": NeuralType(optional=True)
        },
        output_types={"audio": NeuralType(('B', 'T'), AudioSignal())},
    )
    def convert_spectrogram_to_audio(self,
                                     spec: torch.Tensor,
                                     sigma: float = 1.0) -> torch.Tensor:
        self.eval()
        self.mode = OperationMode.infer
        self.squeezewave.mode = OperationMode.infer

        with torch.no_grad():
            audio = self.squeezewave(spec=spec,
                                     run_inverse=True,
                                     audio=None,
                                     sigma=sigma)

        return audio

    def training_step(self, batch, batch_idx):
        self.mode = OperationMode.training
        self.squeezewave.mode = OperationMode.training
        audio, audio_len = batch
        z, log_s_list, log_det_W_list = self.forward(audio=audio,
                                                     audio_len=audio_len,
                                                     run_inverse=False)

        loss = self.loss(z=z,
                         log_s_list=log_s_list,
                         log_det_W_list=log_det_W_list,
                         sigma=self.sigma)
        return {
            'loss': loss,
            'progress_bar': {
                'training_loss': loss
            },
            'log': {
                'loss': loss
            },
        }

    def validation_step(self, batch, batch_idx):
        self.mode = OperationMode.validation
        self.squeezewave.mode = OperationMode.validation
        audio, audio_len = batch
        z, log_s_list, log_det_W_list, audio_pred, spec, spec_len = self.forward(
            audio=audio, audio_len=audio_len, run_inverse=(batch_idx == 0))
        loss = self.loss(z=z,
                         log_s_list=log_s_list,
                         log_det_W_list=log_det_W_list,
                         sigma=self.sigma)
        return {
            "val_loss": loss,
            "audio_pred": audio_pred,
            "mel_target": spec,
            "mel_len": spec_len,
        }

    def validation_epoch_end(self, outputs):
        if self.logger is not None and self.logger.experiment is not None:
            waveglow_log_to_tb_func(
                self.logger.experiment,
                outputs[0].values(),
                self.global_step,
                tag="eval",
                mel_fb=self.audio_to_melspec_precessor.fb,
            )
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.log('val_loss', avg_loss)

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")  # TODO
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")  # TODO
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="SqueezeWave-22050Hz",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemottsmodels/versions/1.0.0a5/files/SqueezeWave-22050Hz.nemo",
            description=
            "This model is trained on LJSpeech sampled at 22050Hz, and can be used as an universal vocoder.",
            class_=cls,
        )
        list_of_models.append(model)
        return list_of_models
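A hedged usage sketch; the model name matches the entry returned by list_available_models() above, and the mel shape is illustrative.

# Illustrative only; from_pretrained downloads the NGC checkpoint listed above.
import torch

from nemo.collections.tts.models import SqueezeWaveModel

model = SqueezeWaveModel.from_pretrained("SqueezeWave-22050Hz")
spec = torch.randn(1, 80, 96)  # (batch, n_mels, frames); n_mels must match cfg
audio = model.convert_spectrogram_to_audio(spec=spec, sigma=1.0)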
Code example #6
class HifiGanModel(Vocoder, Exportable):
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # use a different melspec extractor because:
        # 1. we need to pass grads
        # 2. we need to remove the fmax limitation
        self.trg_melspec_fn = instantiate(cfg.preprocessor,
                                          highfreq=None,
                                          use_grads=True)
        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator(
            debug=cfg.debug if "debug" in cfg else False)
        self.msd = MultiScaleDiscriminator(
            debug=cfg.debug if "debug" in cfg else False)
        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        self.l1_factor = cfg.get("l1_loss_factor", 45)

        self.sample_rate = self._cfg.preprocessor.sample_rate
        self.stft_bias = None

        if self._train_dl and isinstance(self._train_dl.dataset,
                                         MelAudioDataset):
            self.input_as_mel = True
        else:
            self.input_as_mel = False

        self.automatic_optimization = False

    def configure_optimizers(self):
        self.optim_g = instantiate(
            self._cfg.optim,
            params=self.generator.parameters(),
        )
        self.optim_d = instantiate(
            self._cfg.optim,
            params=itertools.chain(self.msd.parameters(),
                                   self.mpd.parameters()),
        )

        self.scheduler_g = CosineAnnealing(
            optimizer=self.optim_g,
            max_steps=self._cfg.max_steps,
            min_lr=self._cfg.sched.min_lr,
            warmup_steps=self._cfg.sched.warmup_ratio * self._cfg.max_steps,
        )  # Use warmup to delay start
        sch1_dict = {
            'scheduler': self.scheduler_g,
            'interval': 'step',
        }

        self.scheduler_d = CosineAnnealing(
            optimizer=self.optim_d,
            max_steps=self._cfg.max_steps,
            min_lr=self._cfg.sched.min_lr,
        )
        sch2_dict = {
            'scheduler': self.scheduler_d,
            'interval': 'step',
        }

        return [self.optim_g, self.optim_d], [sch1_dict, sch2_dict]

    @property
    def input_types(self):
        return {
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        }

    @property
    def output_types(self):
        return {
            "audio": NeuralType(('B', 'S', 'T'),
                                AudioSignal(self.sample_rate)),
        }

    @typecheck()
    def forward(self, *, spec):
        """
        Runs the generator. For inputs and outputs, see input_types and output_types.
        """
        return self.generator(x=spec)

    @typecheck(
        input_types={
            "spec": NeuralType(('B', 'C', 'T'), MelSpectrogramType())
        },
        output_types={"audio": NeuralType(('B', 'T'), AudioSignal())},
    )
    def convert_spectrogram_to_audio(self,
                                     spec: 'torch.tensor') -> 'torch.tensor':
        return self(spec=spec).squeeze(1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # if in finetune mode the mels are pre-computed using a
        # spectrogram generator
        if self.input_as_mel:
            audio, audio_len, audio_mel = batch
        # else, we compute the mel using the ground truth audio
        else:
            audio, audio_len = batch
            # mel as input for generator
            audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)

        # mel as input for L1 mel loss
        audio_trg_mel, _ = self.trg_melspec_fn(audio, audio_len)
        audio = audio.unsqueeze(1)

        audio_pred = self.generator(x=audio_mel)
        audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1),
                                                audio_len)

        # train discriminator
        self.optim_d.zero_grad()
        mpd_score_real, mpd_score_gen, _, _ = self.mpd(
            y=audio, y_hat=audio_pred.detach())
        loss_disc_mpd, _, _ = self.discriminator_loss(
            disc_real_outputs=mpd_score_real,
            disc_generated_outputs=mpd_score_gen)
        msd_score_real, msd_score_gen, _, _ = self.msd(
            y=audio, y_hat=audio_pred.detach())
        loss_disc_msd, _, _ = self.discriminator_loss(
            disc_real_outputs=msd_score_real,
            disc_generated_outputs=msd_score_gen)
        loss_d = loss_disc_msd + loss_disc_mpd
        self.manual_backward(loss_d)
        self.optim_d.step()

        # train generator
        self.optim_g.zero_grad()
        loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel)
        _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(
            y=audio, y_hat=audio_pred)
        _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(
            y=audio, y_hat=audio_pred)
        loss_fm_mpd = self.feature_loss(fmap_r=fmap_mpd_real,
                                        fmap_g=fmap_mpd_gen)
        loss_fm_msd = self.feature_loss(fmap_r=fmap_msd_real,
                                        fmap_g=fmap_msd_gen)
        loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen)
        loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
        loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel * self.l1_factor
        self.manual_backward(loss_g)
        self.optim_g.step()

        # run schedulers
        sch1, sch2 = self.lr_schedulers()
        sch1.step()
        sch2.step()

        metrics = {
            "g_loss_fm_mpd": loss_fm_mpd,
            "g_loss_fm_msd": loss_fm_msd,
            "g_loss_gen_mpd": loss_gen_mpd,
            "g_loss_gen_msd": loss_gen_msd,
            "g_loss": loss_g,
            "d_loss_mpd": loss_disc_mpd,
            "d_loss_msd": loss_disc_msd,
            "d_loss": loss_d,
            "global_step": self.global_step,
            "lr": self.optim_g.param_groups[0]['lr'],
        }
        self.log_dict(metrics, on_step=True, sync_dist=True)
        self.log("g_l1_loss",
                 loss_mel,
                 prog_bar=True,
                 logger=False,
                 sync_dist=True)

    def validation_step(self, batch, batch_idx):
        if self.input_as_mel:
            audio, audio_len, audio_mel = batch
            audio_mel_len = [audio_mel.shape[1]] * audio_mel.shape[0]
        else:
            audio, audio_len = batch
            audio_mel, audio_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred = self(spec=audio_mel)

        # perform bias denoising
        pred_denoised = self._bias_denoise(audio_pred, audio_mel).squeeze(1)
        pred_denoised_mel, _ = self.audio_to_melspec_precessor(
            pred_denoised, audio_len)

        if self.input_as_mel:
            gt_mel, gt_mel_len = self.audio_to_melspec_precessor(
                audio, audio_len)
        audio_pred_mel, _ = self.audio_to_melspec_precessor(
            audio_pred.squeeze(1), audio_len)
        loss_mel = F.l1_loss(audio_mel, audio_pred_mel)

        self.log_dict({"val_loss": loss_mel}, on_epoch=True, sync_dist=True)

        # plot audio once per epoch
        if batch_idx == 0 and isinstance(self.logger,
                                         WandbLogger) and HAVE_WANDB:
            clips = []
            specs = []
            for i in range(min(5, audio.shape[0])):
                clips += [
                    wandb.Audio(
                        audio[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"real audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        audio_pred[i,
                                   0, :audio_len[i]].data.cpu().numpy().astype(
                                       'float32'),
                        caption=f"generated audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        pred_denoised[i, :audio_len[i]].data.cpu().numpy(),
                        caption=f"denoised audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                ]
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"input mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_pred_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"output mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(pred_denoised_mel[
                            i, :, :audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"denoised mel {i}",
                    ),
                ]
                if self.input_as_mel:
                    specs += [
                        wandb.Image(
                            plot_spectrogram_to_numpy(gt_mel[
                                i, :, :audio_mel_len[i]].data.cpu().numpy()),
                            caption=f"gt mel {i}",
                        ),
                    ]

            self.logger.experiment.log({"audio": clips, "specs": specs})

    def _bias_denoise(self, audio, mel):
        def stft(x):
            comp = stft_patch(x.squeeze(1),
                              n_fft=1024,
                              hop_length=256,
                              win_length=1024)
            real, imag = comp[..., 0], comp[..., 1]
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase

        def istft(mags, phase):
            comp = torch.stack(
                [mags * torch.cos(phase), mags * torch.sin(phase)], dim=-1)
            x = torch.istft(comp, n_fft=1024, hop_length=256, win_length=1024)
            return x

        # create bias tensor
        if self.stft_bias is None or self.stft_bias.shape[0] != audio.shape[0]:
            audio_bias = self(spec=torch.zeros_like(mel, device=mel.device))
            self.stft_bias, _ = stft(audio_bias)
            self.stft_bias = self.stft_bias[:, :, 0][:, :, None]

        audio_mags, audio_phase = stft(audio)
        audio_mags = audio_mags - self.cfg.get("denoise_strength",
                                               0.0025) * self.stft_bias
        audio_mags = torch.clamp(audio_mags, 0.0)
        audio_denoised = istft(audio_mags, audio_phase).unsqueeze(1)

        return audio_denoised

    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(
            cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'Optional[Dict[str, str]]':
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_hifigan",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_hifigan/versions/1.0.0rc1/files/tts_hifigan.nemo",
            description=
            "This model is trained on LJSpeech audio sampled at 22050Hz and mel spectrograms generated from Tacotron2, TalkNet, and FastPitch. This model has been tested on generating female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)
        return list_of_models

    def load_state_dict(self, state_dict, strict=True):
        # override load_state_dict to give us some flexibility to be backward-compatible
        # with old checkpoints
        new_state_dict = {}
        num_resblocks = len(self.cfg['generator']['resblock_kernel_sizes'])
        for k, v in state_dict.items():
            new_k = k
            if 'resblocks' in k:
                parts = k.split(".")
                # only do this if the checkpoint is the older type
                if len(parts) == 6:
                    layer = int(parts[2])
                    new_layer = f"{layer // num_resblocks}.{layer % num_resblocks}"
                    new_k = f"generator.resblocks.{new_layer}.{'.'.join(parts[3:])}"
            new_state_dict[new_k] = v
        super().load_state_dict(new_state_dict, strict=strict)

    def _prepare_for_export(self, **kwargs):
        """
        Override this method to prepare module for export. This is in-place operation.
        Base version does common necessary module replacements (Apex etc)
        """
        if self.generator is not None:
            self.generator.remove_weight_norm()

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        par = next(self.parameters())
        mel = torch.randn((1, self.cfg['preprocessor']['nfilt'], 96),
                          device=par.device,
                          dtype=par.dtype)
        return ({'spec': mel}, )

    def forward_for_export(self, spec):
        """
        Runs the generator. For inputs and outputs, see input_types and output_types.
        """
        return self.generator(x=spec)
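Since this version mixes in Exportable, a hedged export sketch: export() is assumed to go through _prepare_for_export(), input_example(), and forward_for_export() shown above. Both file paths are placeholders.

# Illustrative only; both paths are placeholders.
from nemo.collections.tts.models import HifiGanModel

model = HifiGanModel.restore_from("tts_hifigan.nemo")
model.export("hifigan.onnx")  # removes weight norm, then traces the generator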
Code example #7
File: uniglow.py Project: askaydevs/ITN_Phore
def input_types(self):
    return {
        "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        "audio": NeuralType(('B', 'T'), AudioSignal(), optional=True),
        "sigma": NeuralType(optional=True),
    }
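To show how @typecheck consumes declarations like this one, a hedged, invented minimal module; Demo and its shapes are illustrative and not part of UniGlow.

# Invented example: inputs declared optional=True may be omitted by the
# caller, while "spec" is validated against its declared axes.
import torch

from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import AudioSignal, MelSpectrogramType, NeuralType


class Demo(NeuralModule):
    @property
    def input_types(self):
        return {
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "audio": NeuralType(('B', 'T'), AudioSignal(), optional=True),
        }

    @property
    def output_types(self):
        return {"out": NeuralType(('B', 'T'), AudioSignal())}

    @typecheck()
    def forward(self, *, spec, audio=None):
        # average over the mel axis just to return something (B, T)-shaped
        return audio if audio is not None else spec.mean(dim=1)


demo = Demo()
out = demo(spec=torch.randn(1, 80, 96))  # "audio" omitted: it is optional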