Example #1
0
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Build the generator, discriminators, losses and mel extractors.

        Args:
            cfg: Model configuration. A plain ``dict`` is converted to a
                ``DictConfig`` before it is handed to the parent constructor.
                The ``preprocessor`` and ``generator`` sections are
                instantiated here, and ``preprocessor.sample_rate`` is read.
            trainer: Forwarded unchanged to the parent constructor.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # use a different melspec extractor because:
        # 1. we need to pass grads
        # 2. we need remove fmax limitation
        self.trg_melspec_fn = instantiate(cfg.preprocessor,
                                          highfreq=None,
                                          use_grads=True)
        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()
        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        self.sample_rate = self._cfg.preprocessor.sample_rate
        self.stft_bias = None

        # `_train_dl` may be None (presumably when no training data has been
        # set up, e.g. inference-only use). Dereferencing `.dataset` on None
        # would abort construction, so check for it explicitly first.
        train_dl = self._train_dl
        self.finetune = train_dl is not None and isinstance(
            train_dl.dataset, MelAudioDataset)
        if self.finetune:
            logging.info("fine-tuning on pre-computed mels")
        else:
            logging.info("training on ground-truth mels")
Example #2
0
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Set up the generator, both discriminators, the losses and mel extractors.

        Args:
            cfg: Model configuration; a plain ``dict`` is wrapped into a
                ``DictConfig`` first. ``l1_loss_factor`` is optional and
                falls back to 45 when absent.
            trainer: Passed straight through to the parent constructor.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        preprocessor_cfg = cfg.preprocessor
        self.audio_to_melspec_precessor = instantiate(preprocessor_cfg)
        # A second extractor is built from the same config with two overrides:
        # gradients must flow through it, and the fmax limit is lifted.
        self.trg_melspec_fn = instantiate(
            preprocessor_cfg,
            highfreq=None,
            use_grads=True,
        )

        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()

        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        self.l1_factor = cfg.get("l1_loss_factor", 45)

        self.sample_rate = self._cfg.preprocessor.sample_rate
        self.stft_bias = None

        # Mel input is used only when a (truthy) train dataloader wraps a
        # MelAudioDataset; otherwise the flag stays False.
        self.input_as_mel = bool(self._train_dl) and isinstance(
            self._train_dl.dataset, MelAudioDataset)

        # NOTE(review): presumably switches Lightning to manual optimization;
        # confirm against the base class.
        self.automatic_optimization = False
Example #3
0
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Set up the generator, both discriminators, the losses and mel extractors.

        Args:
            cfg: Model configuration. It is normalised through the two
                ``model_utils`` helpers before the parent constructor runs.
                Optional keys: ``debug`` (forwarded to both discriminators,
                False when absent) and ``l1_loss_factor`` (45 when absent).
            trainer: Passed straight through to the parent constructor.
        """
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.maybe_update_config_version(
            model_utils.convert_model_config_to_dict_config(cfg))

        super().__init__(cfg=cfg, trainer=trainer)

        preprocessor_cfg = cfg.preprocessor
        self.audio_to_melspec_precessor = instantiate(preprocessor_cfg)
        # Training gets its own preprocessor: gradients have to pass through
        # it and the pitch fmax limitation is removed.
        self.trg_melspec_fn = instantiate(preprocessor_cfg, highfreq=None, use_grads=True)
        self.generator = instantiate(cfg.generator)

        debug = cfg.debug if "debug" in cfg else False
        self.mpd = MultiPeriodDiscriminator(debug=debug)
        self.msd = MultiScaleDiscriminator(debug=debug)

        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        self.l1_factor = cfg.get("l1_loss_factor", 45)

        self.sample_rate = self._cfg.preprocessor.sample_rate
        self.stft_bias = None

        # Decide whether batches carry mels, based on the train dataset type.
        self.input_as_mel = False
        if self._train_dl:
            train_dataset = self._train_dl.dataset
            # TODO(Oktai15): remove it in 1.8.0 version
            if isinstance(train_dataset, MelAudioDataset):
                self.input_as_mel = True
            elif isinstance(train_dataset, VocoderDataset):
                self.input_as_mel = train_dataset.load_precomputed_mel

        # NOTE(review): presumably switches Lightning to manual optimization;
        # confirm against the base class.
        self.automatic_optimization = False
Example #4
0
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Build the encoder, variance adaptor, generator, discriminators and losses.

        Also loads the word/phone mappings JSON referenced by
        ``cfg.mappings_filepath``.

        Args:
            cfg: Model configuration. A plain ``dict`` is converted to a
                ``DictConfig`` before it is handed to the parent constructor.
                Must contain ``mappings_filepath``.
            trainer: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If ``mappings_filepath`` is missing from ``cfg``.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        self.encoder = instantiate(cfg.encoder)
        self.variance_adapter = instantiate(cfg.variance_adaptor)

        self.generator = instantiate(cfg.generator)
        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiscaledisc = MultiScaleDiscriminator()

        # Second extractor built from the same preprocessor config, with
        # gradients enabled and `highfreq` overridden to None.
        self.melspec_fn = instantiate(cfg.preprocessor,
                                      highfreq=None,
                                      use_grads=True)
        self.mel_val_loss = L1MelLoss()
        self.durationloss = DurationLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.mseloss = torch.nn.MSELoss()

        self.energy = cfg.add_energy_predictor
        self.pitch = cfg.add_pitch_predictor
        self.mel_loss_coeff = cfg.mel_loss_coeff
        self.pitch_loss_coeff = cfg.pitch_loss_coeff
        self.energy_loss_coeff = cfg.energy_loss_coeff
        self.splice_length = cfg.splice_length

        self.use_energy_pred = False
        self.use_pitch_pred = False
        self.log_train_images = False
        self.logged_real_samples = False
        self._tb_logger = None
        self.sample_rate = cfg.sample_rate
        self.hop_size = cfg.hop_size

        # Parser and mappings are used for inference only.
        self.parser = parsers.make_parser(name='en')
        if 'mappings_filepath' not in cfg:
            # Previously this case was only logged, after which the unbound
            # local `mappings_filepath` was used and the constructor died
            # with an unrelated UnboundLocalError. Fail with a clear error.
            logging.error(
                "ERROR: You must specify a mappings.json file in the config file under model.mappings_filepath."
            )
            raise ValueError(
                "You must specify a mappings.json file in the config file under model.mappings_filepath."
            )
        mappings_filepath = self.register_artifact(
            'mappings_filepath', cfg.get('mappings_filepath'))
        with open(mappings_filepath, 'r') as f:
            mappings = json.load(f)
        self.word2phones = mappings['word2phones']
        self.phone2idx = mappings['phone2idx']
Example #5
0
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Set up the generator, MPD/MRD discriminators, losses and mel extractors.

        Args:
            cfg: Model configuration. It is normalised through the two
                ``model_utils`` helpers before the parent constructor runs.
                ``discriminator.mrd.resolutions`` is read as a sequence whose
                items are indexed ``[0]``, ``[1]``, ``[2]`` for FFT size,
                hop size and window length. ``debug`` is optional and
                treated as False when absent.
            trainer: Passed straight through to the parent constructor.
        """
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.maybe_update_config_version(
            model_utils.convert_model_config_to_dict_config(cfg))

        super().__init__(cfg=cfg, trainer=trainer)

        preprocessor_cfg = cfg.preprocessor
        self.audio_to_melspec_precessor = instantiate(preprocessor_cfg)
        # Training gets its own preprocessor: gradients have to pass through
        # it and the pitch fmax limitation is removed.
        self.trg_melspec_fn = instantiate(
            preprocessor_cfg,
            highfreq=None,
            use_grads=True,
        )
        # The generator's mel-channel count and hop length are taken from the
        # preprocessor section rather than from the generator section.
        self.generator = instantiate(
            cfg.generator,
            n_mel_channels=cfg.preprocessor.nfilt,
            hop_length=cfg.preprocessor.n_window_stride,
        )

        mpd_cfg = cfg.discriminator.mpd
        self.mpd = MultiPeriodDiscriminator(
            mpd_cfg, debug=cfg.debug if "debug" in cfg else False)
        mrd_cfg = cfg.discriminator.mrd
        self.mrd = MultiResolutionDiscriminator(
            mrd_cfg, debug=cfg.debug if "debug" in cfg else False)

        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        # Reshape MRD resolutions hyperparameter and apply them to MRSTFT loss
        resolutions = cfg.discriminator.mrd.resolutions
        self.stft_resolutions = resolutions
        self.fft_sizes = [entry[0] for entry in resolutions]
        self.hop_sizes = [entry[1] for entry in resolutions]
        self.win_lengths = [entry[2] for entry in resolutions]
        self.mrstft_loss = MultiResolutionSTFTLoss(
            self.fft_sizes, self.hop_sizes, self.win_lengths)
        self.stft_lamb = cfg.stft_lamb

        self.sample_rate = self._cfg.preprocessor.sample_rate
        self.stft_bias = None

        # Decide whether batches carry mels, based on the train dataset type.
        self.input_as_mel = False
        if self._train_dl:
            train_dataset = self._train_dl.dataset
            # TODO(Oktai15): remove it in 1.8.0 version
            if isinstance(train_dataset, MelAudioDataset):
                self.input_as_mel = True
            elif isinstance(train_dataset, VocoderDataset):
                self.input_as_mel = train_dataset.load_precomputed_mel

        # NOTE(review): presumably switches Lightning to manual optimization;
        # confirm against the base class.
        self.automatic_optimization = False
Example #6
0
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Set up the generator, both discriminators, the losses and mel extractors.

        Args:
            cfg: Model configuration; a plain ``dict`` is wrapped into a
                ``DictConfig`` before the parent constructor runs.
            trainer: Passed straight through to the parent constructor.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        preprocessor_cfg = cfg.preprocessor
        self.audio_to_melspec_precessor = instantiate(preprocessor_cfg)
        # A second extractor is built from the same config with two overrides:
        # gradients must flow through it, and the fmax limit is lifted.
        self.trg_melspec_fn = instantiate(
            preprocessor_cfg,
            highfreq=None,
            use_grads=True,
        )

        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()

        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        self.sample_rate = self._cfg.preprocessor.sample_rate
Example #7
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Build the text parser, encoder, predictors, generator, discriminators and losses.

        Args:
            cfg: Model configuration; a plain ``dict`` is wrapped into a
                ``DictConfig`` first. It is merged against the
                ``FastPitchHifiGanE2EConfig`` structured schema as a
                compliance check (the merge result is discarded).
            trainer: Passed straight through to the parent constructor.

        Raises:
            ValueError: If ``cfg`` is neither a ``dict`` nor a ``DictConfig``
                when it is re-checked after the parent constructor has run.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)

        # The parser is created before the parent constructor is invoked;
        # that ordering is kept as-is.
        parser_options = {
            'labels': cfg.labels,
            'name': 'en',
            'unk_id': -1,
            'blank_id': -1,
            'do_normalize': True,
            'abbreviation_version': "fastpitch",
            'make_table': False,
        }
        self._parser = parsers.make_parser(**parser_options)

        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(FastPitchHifiGanE2EConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if not isinstance(cfg, (dict, DictConfig)):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        preprocessor_cfg = cfg.preprocessor
        self.preprocessor = instantiate(preprocessor_cfg)
        # Second extractor from the same config, with gradients enabled and
        # `highfreq` overridden to None.
        self.melspec_fn = instantiate(
            preprocessor_cfg,
            highfreq=None,
            use_grads=True,
        )

        self.encoder = instantiate(cfg.input_fft)
        self.duration_predictor = instantiate(cfg.duration_predictor)
        self.pitch_predictor = instantiate(cfg.pitch_predictor)

        self.generator = instantiate(cfg.generator)
        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiscaledisc = MultiScaleDiscriminator()
        self.mel_val_loss = L1MelLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()

        self.max_token_duration = cfg.max_token_duration

        # Single-input-channel 1-D convolution over pitch; the padding is
        # derived from the configured kernel size.
        embedding_dim = cfg.symbols_embedding_dim
        kernel_size = cfg.pitch_embedding_kernel_size
        self.pitch_emb = torch.nn.Conv1d(
            1,
            embedding_dim,
            kernel_size=kernel_size,
            padding=int((kernel_size - 1) / 2),
        )

        # Store values precomputed from training data for convenience
        for stat_name in ('pitch_mean', 'pitch_std'):
            self.register_buffer(stat_name, torch.zeros(1))

        self.pitchloss = PitchLoss()
        self.durationloss = DurationLoss()

        self.mel_loss_coeff = cfg.mel_loss_coeff

        # Logging state and lazily-filled members, all starting out unset.
        self.log_train_images = False
        self.logged_real_samples = False
        self._tb_logger = None
        self.hann_window = None

        self.splice_length = cfg.splice_length
        self.sample_rate = cfg.sample_rate
        self.hop_size = cfg.hop_size