def copy_model_files(config: Coqpit, out_path, new_fields=None):
    """Copy config.json and other model files to the training folder and add new fields.

    Args:
        config (Coqpit): Coqpit config defining the training run.
        out_path (str): Output path to copy the files to.
        new_fields (dict): New fields to be added or edited in the config file.
    """
    copy_config_path = os.path.join(out_path, "config.json")
    # add extra information fields
    if new_fields:
        config.update(new_fields, allow_new=True)
    # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
    with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
        json.dump(config.to_dict(), f, indent=4)

    # copy model stats file if available
    if config.audio.stats_path is not None:
        copy_stats_path = os.path.join(out_path, "scale_stats.npy")
        filesystem = fsspec.get_mapper(copy_stats_path).fs
        if not filesystem.exists(copy_stats_path):
            with fsspec.open(config.audio.stats_path, "rb") as source_file:
                with fsspec.open(copy_stats_path, "wb") as target_file:
                    shutil.copyfileobj(source_file, target_file)
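
# A minimal usage sketch for copy_model_files(). The config class, output path,
# and extra field below are illustrative assumptions, not part of the API.
def _example_copy_model_files():  # hypothetical helper, not in the original module
    from TTS.tts.configs.tacotron_config import TacotronConfig  # assumed import path

    config = TacotronConfig()
    # Writes output/run-1/config.json with `run_name` merged in; if
    # config.audio.stats_path is set, also copies it to output/run-1/scale_stats.npy.
    copy_model_files(config, out_path="output/run-1", new_fields={"run_name": "run-1"})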
def get_characters(config: Coqpit):
    # TODO: implement CharacterProcessor
    if config.characters is not None:
        symbols, phonemes = make_symbols(**config.characters)
    else:
        from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols

        config.characters = CharactersConfig(**parse_symbols())
    model_characters = phonemes if config.use_phonemes else symbols
    num_chars = len(model_characters) + getattr(config, "add_blank", False)
    return model_characters, config, num_chars
def init_from_config(config: Coqpit) -> "LanguageManager":
    """Initialize the language manager from a Coqpit config.

    Args:
        config (Coqpit): Coqpit config.
    """
    language_manager = None
    if check_config_and_model_args(config, "use_language_embedding", True):
        if config.get("language_ids_file", None):
            language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
        else:
            language_manager = LanguageManager(config=config)
    return language_manager
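
# A minimal sketch of calling init_from_config(), assuming a VITS-style config
# with language embeddings enabled; the class and attribute names used here are
# assumptions for illustration.
def _example_language_manager():  # hypothetical helper, not in the original module
    from TTS.tts.configs.vits_config import VitsConfig  # assumed import path

    config = VitsConfig()
    config.model_args.use_language_embedding = True
    language_manager = init_from_config(config)
    if language_manager is not None:
        print(language_manager.num_languages)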
def get_characters(config: Coqpit):
    if config.characters is not None:
        symbols = Vits.make_symbols(config)
    else:
        from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
            parse_symbols,
            phonemes,
            symbols,
        )

        config.characters = parse_symbols()
    if config.use_phonemes:
        # NOTE: `phonemes` is only bound by the import in the branch above.
        symbols = phonemes  # noqa: F811
    num_chars = len(symbols) + getattr(config, "add_blank", False)
    return symbols, config, num_chars
def get_characters(config: Coqpit):
    # TODO: implement CharacterProcessor
    if config.characters is not None:
        symbols, phonemes = make_symbols(**config.characters)
    else:
        from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
            parse_symbols,
            phonemes,
            symbols,
        )

        config.characters = parse_symbols()
    model_characters = phonemes if config.use_phonemes else symbols
    return model_characters, config
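
# Sketch of how these get_characters() variants are consumed: the returned
# symbol list sizes the model's embedding table. The config class below is an
# assumption for illustration.
def _example_get_characters():  # hypothetical helper, not in the original module
    from TTS.tts.configs.tacotron_config import TacotronConfig  # assumed import path

    config = TacotronConfig(use_phonemes=True)
    model_characters, config = get_characters(config)
    config.num_chars = len(model_characters)  # later used for nn.Embedding(num_chars, ...)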
def __init__(self, config: Coqpit):
    super().__init__(config)
    chars, self.config = self.get_characters(config)
    config.num_chars = len(chars)
    self.decoder_output_dim = config.out_channels

    # pass all config fields to `self` for fewer code changes
    for key in config:
        setattr(self, key, config[key])

    # set speaker embedding channel size for determining `in_channels` for the
    # connected layers. `init_multispeaker` needs to be called once more in
    # training to initialize the speaker embedding layer based on the number of
    # speakers inferred from the dataset.
    if self.use_speaker_embedding or self.use_d_vector_file:
        self.init_multispeaker(config)
        self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

    if self.use_gst:
        self.decoder_in_features += self.gst.gst_embedding_dim

    # embedding layer
    self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0)

    # base model layers
    self.encoder = Encoder(self.encoder_in_features)
    self.decoder = Decoder(
        self.decoder_in_features,
        self.decoder_output_dim,
        self.r,
        self.attention_type,
        self.attention_win,
        self.attention_norm,
        self.prenet_type,
        self.prenet_dropout,
        self.use_forward_attn,
        self.transition_agent,
        self.forward_attn_mask,
        self.location_attn,
        self.attention_heads,
        self.separate_stopnet,
        self.max_decoder_steps,
    )
    self.postnet = Postnet(self.out_channels)

    # setup prenet dropout
    self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference

    # global style token layers
    if self.gst and self.use_gst:
        self.gst_layer = GST(
            num_mel=self.decoder_output_dim,
            num_heads=self.gst.gst_num_heads,
            num_style_tokens=self.gst.gst_num_style_tokens,
            gst_embedding_dim=self.gst.gst_embedding_dim,
        )

    # backward pass decoder
    if self.bidirectional_decoder:
        self._init_backward_decoder()

    # setup DDC
    if self.double_decoder_consistency:
        self.coarse_decoder = Decoder(
            self.decoder_in_features,
            self.decoder_output_dim,
            self.ddc_r,
            self.attention_type,
            self.attention_win,
            self.attention_norm,
            self.prenet_type,
            self.prenet_dropout,
            self.use_forward_attn,
            self.transition_agent,
            self.forward_attn_mask,
            self.location_attn,
            self.attention_heads,
            self.separate_stopnet,
            self.max_decoder_steps,
        )
def get_data_loader(
    self,
    config: Coqpit,
    assets: Dict,
    is_eval: bool,
    samples: Union[List[Dict], List[List]],
    verbose: bool,
    num_gpus: int,
    rank: int = None,
) -> "DataLoader":
    if is_eval and not config.run_eval:
        loader = None
    else:
        # setup multi-speaker attributes
        if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
            if hasattr(config, "model_args"):
                speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
                d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
                config.use_d_vector_file = config.model_args.use_d_vector_file
            else:
                speaker_id_mapping = self.speaker_manager.ids if config.use_speaker_embedding else None
                d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
        else:
            speaker_id_mapping = None
            d_vector_mapping = None

        # setup multi-lingual attributes
        if hasattr(self, "language_manager") and self.language_manager is not None:
            language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None
        else:
            language_id_mapping = None

        # init dataloader
        dataset = TTSDataset(
            outputs_per_step=config.r if "r" in config else 1,
            compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
            compute_f0=config.get("compute_f0", False),
            f0_cache_path=config.get("f0_cache_path", None),
            samples=samples,
            ap=self.ap,
            return_wav=config.return_wav if "return_wav" in config else False,
            batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
            min_text_len=config.min_text_len,
            max_text_len=config.max_text_len,
            min_audio_len=config.min_audio_len,
            max_audio_len=config.max_audio_len,
            phoneme_cache_path=config.phoneme_cache_path,
            precompute_num_workers=config.precompute_num_workers,
            use_noise_augment=False if is_eval else config.use_noise_augment,
            verbose=verbose,
            speaker_id_mapping=speaker_id_mapping,
            d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
            tokenizer=self.tokenizer,
            start_by_longest=config.start_by_longest,
            language_id_mapping=language_id_mapping,
        )

        # wait for all DDP processes to be ready
        if num_gpus > 1:
            dist.barrier()

        # sort input sequences from short to long
        dataset.preprocess_samples()

        # get samplers
        sampler = self.get_sampler(config, dataset, num_gpus)

        loader = DataLoader(
            dataset,
            batch_size=config.eval_batch_size if is_eval else config.batch_size,
            shuffle=False,  # shuffling is done in the dataset
            collate_fn=dataset.collate_fn,
            drop_last=False,  # setting this to False might cause issues in AMP training
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )
    return loader
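
# A hedged sketch of calling get_data_loader() from a training script; the
# sample list and single-GPU setup are assumptions for illustration.
def _example_get_data_loader(model, config, train_samples):  # hypothetical helper
    loader = model.get_data_loader(
        config=config,
        assets={},
        is_eval=False,
        samples=train_samples,
        verbose=True,
        num_gpus=1,
    )
    for batch in loader:
        ...  # forward the collated batch to the model's train_step()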
def get_data_loader(
    self,
    config: Coqpit,
    assets: Dict,
    is_eval: bool,
    data_items: List,
    verbose: bool,
    num_gpus: int,
    rank: int = None,
) -> "DataLoader":
    if is_eval and not config.run_eval:
        loader = None
    else:
        ap = assets["audio_processor"]
        # setup multi-speaker attributes
        if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
            if hasattr(config, "model_args"):
                speaker_id_mapping = self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
                d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
                config.use_d_vector_file = config.model_args.use_d_vector_file
            else:
                speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
                d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
        else:
            speaker_id_mapping = None
            d_vector_mapping = None

        # setup custom symbols if needed
        custom_symbols = None
        if hasattr(self, "make_symbols"):
            custom_symbols = self.make_symbols(self.config)

        if hasattr(self, "language_manager"):
            language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None
        else:
            language_id_mapping = None

        # init dataloader
        dataset = TTSDataset(
            outputs_per_step=config.r if "r" in config else 1,
            text_cleaner=config.text_cleaner,
            compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
            compute_f0=config.get("compute_f0", False),
            f0_cache_path=config.get("f0_cache_path", None),
            meta_data=data_items,
            ap=ap,
            characters=config.characters,
            custom_symbols=custom_symbols,
            add_blank=config["add_blank"],
            return_wav=config.return_wav if "return_wav" in config else False,
            batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
            min_seq_len=config.min_seq_len,
            max_seq_len=config.max_seq_len,
            phoneme_cache_path=config.phoneme_cache_path,
            use_phonemes=config.use_phonemes,
            phoneme_language=config.phoneme_language,
            enable_eos_bos=config.enable_eos_bos_chars,
            use_noise_augment=False if is_eval else config.use_noise_augment,
            verbose=verbose,
            speaker_id_mapping=speaker_id_mapping,
            d_vector_mapping=d_vector_mapping,
            language_id_mapping=language_id_mapping,
        )

        # pre-compute phonemes
        if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
            if hasattr(self, "eval_data_items") and is_eval:
                dataset.items = self.eval_data_items
            elif hasattr(self, "train_data_items") and not is_eval:
                dataset.items = self.train_data_items
            else:
                # precompute phonemes for a precise estimate of sequence lengths;
                # otherwise `dataset.sort_items()` uses raw text lengths
                dataset.compute_input_seq(config.num_loader_workers)

                # TODO: find a more efficient solution
                # cheap hack - store items in the model state to avoid recomputing
                # when the dataset is reinitialized
                if is_eval:
                    self.eval_data_items = dataset.items
                else:
                    self.train_data_items = dataset.items

        # halt DDP processes for the main process to finish computing the phoneme cache
        if num_gpus > 1:
            dist.barrier()

        # sort input sequences from short to long
        dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))

        # compute pitch frames and write them to files
        if config.compute_f0 and rank in [None, 0]:
            if not os.path.exists(config.f0_cache_path):
                dataset.pitch_extractor.compute_pitch(ap, config.get("f0_cache_path", None), config.num_loader_workers)

        # halt DDP processes for the main process to finish computing the F0 cache
        if num_gpus > 1:
            dist.barrier()

        # load the pitch stats computed above by all the workers
        if config.compute_f0:
            dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))

        # sampler for DDP
        sampler = DistributedSampler(dataset) if num_gpus > 1 else None

        # weighted samplers
        assert not (
            num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
        ), "language_weighted_sampler is not supported with DistributedSampler"
        assert not (
            num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
        ), "speaker_weighted_sampler is not supported with DistributedSampler"

        if sampler is None:
            if getattr(config, "use_language_weighted_sampler", False):
                print(" > Using Language weighted sampler")
                sampler = get_language_weighted_sampler(dataset.items)
            elif getattr(config, "use_speaker_weighted_sampler", False):
                print(" > Using Speaker weighted sampler")
                sampler = get_speaker_weighted_sampler(dataset.items)

        loader = DataLoader(
            dataset,
            batch_size=config.eval_batch_size if is_eval else config.batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )
    return loader
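
# Illustrative single-GPU config for the weighted-sampler branch above. The
# asserts forbid these flags under DistributedSampler (num_gpus > 1), and the
# if/elif means the language sampler takes precedence when both flags are set.
def _example_weighted_sampler_config(config):  # hypothetical helper
    config.use_speaker_weighted_sampler = True  # balance batches across speakers
    config.use_language_weighted_sampler = False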
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
    super().__init__(config)
    self.speaker_manager = speaker_manager
    chars, self.config, _ = self.get_characters(config)
    config.num_chars = len(chars)
    self.decoder_output_dim = config.out_channels

    # pass all config fields to `self` for fewer code changes
    for key in config:
        setattr(self, key, config[key])

    # init multi-speaker layers
    if self.use_speaker_embedding or self.use_d_vector_file:
        self.init_multispeaker(config)
        self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

    if self.use_gst:
        self.decoder_in_features += self.gst.gst_embedding_dim

    # embedding layer
    self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0)

    # base model layers
    self.encoder = Encoder(self.encoder_in_features)
    self.decoder = Decoder(
        self.decoder_in_features,
        self.decoder_output_dim,
        self.r,
        self.attention_type,
        self.attention_win,
        self.attention_norm,
        self.prenet_type,
        self.prenet_dropout,
        self.use_forward_attn,
        self.transition_agent,
        self.forward_attn_mask,
        self.location_attn,
        self.attention_heads,
        self.separate_stopnet,
        self.max_decoder_steps,
    )
    self.postnet = Postnet(self.out_channels)

    # setup prenet dropout
    self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference

    # global style token layers
    if self.gst and self.use_gst:
        self.gst_layer = GST(
            num_mel=self.decoder_output_dim,
            num_heads=self.gst.gst_num_heads,
            num_style_tokens=self.gst.gst_num_style_tokens,
            gst_embedding_dim=self.gst.gst_embedding_dim,
        )

    # backward pass decoder
    if self.bidirectional_decoder:
        self._init_backward_decoder()

    # setup DDC
    if self.double_decoder_consistency:
        self.coarse_decoder = Decoder(
            self.decoder_in_features,
            self.decoder_output_dim,
            self.ddc_r,
            self.attention_type,
            self.attention_win,
            self.attention_norm,
            self.prenet_type,
            self.prenet_dropout,
            self.use_forward_attn,
            self.transition_agent,
            self.forward_attn_mask,
            self.location_attn,
            self.attention_heads,
            self.separate_stopnet,
            self.max_decoder_steps,
        )
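
# A minimal construction sketch for this Tacotron-style model. The config class
# and the enclosing class name `Tacotron` are assumptions for illustration; the
# excerpt above does not name the class.
def _example_build_model():  # hypothetical helper, not in the original module
    from TTS.tts.configs.tacotron_config import TacotronConfig  # assumed import path

    config = TacotronConfig(use_speaker_embedding=False, use_gst=False)
    model = Tacotron(config)  # num_chars is set inside __init__ via get_characters()
    print(model.embedding.num_embeddings)  # embedding table sized to the symbol set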