Exemple #1
0
 def test_init_multispeaker(self):
     """Exercise ``GlowTTS.init_multispeaker`` under four speaker setups.

     The cases covered are:

     1. learned speaker embeddings with ``d_vector_dim`` unset,
     2. external d-vectors with an explicit ``d_vector_dim`` of 301,
     3. learned speaker embeddings with a ``SpeakerManager`` built from a
        speakers file,
     4. external d-vectors with a ``SpeakerManager`` built from a d-vector
        file.

     Each case asserts the resulting ``c_in_channels``; the manager-backed
     cases also assert that ``num_speakers`` matches the manager.
     """
     # Case 1: learned speaker embeddings, no d-vector dimension given.
     # The model is built first and the config is switched to
     # multi-speaker afterwards, right before ``init_multispeaker``.
     cfg = GlowTTSConfig(num_chars=32)
     net = GlowTTS(cfg)
     cfg.use_speaker_embedding = True
     cfg.num_speakers = 5
     cfg.d_vector_dim = None
     net.init_multispeaker(cfg)
     self.assertEqual(net.c_in_channels, net.hidden_channels_enc)

     # Case 2: external speaker embeddings (d-vectors) with d_vector_dim = 301.
     cfg = GlowTTSConfig(num_chars=32)
     cfg.use_d_vector_file = True
     cfg.d_vector_dim = 301
     net = GlowTTS(cfg)
     net.init_multispeaker(cfg)
     self.assertEqual(net.c_in_channels, 301)

     # Case 3: learned speaker embeddings, speakers supplied by a
     # speaker manager created from the config.
     cfg = GlowTTSConfig(num_chars=32)
     cfg.use_speaker_embedding = True
     speakers_json = os.path.join(get_tests_data_path(), "ljspeech", "speakers.json")
     cfg.speakers_file = speakers_json
     manager = SpeakerManager.init_from_config(cfg)
     net = GlowTTS(cfg)
     net.speaker_manager = manager
     net.init_multispeaker(cfg)
     self.assertEqual(net.c_in_channels, net.hidden_channels_enc)
     self.assertEqual(net.num_speakers, manager.num_speakers)

     # Case 4: external speaker embeddings (d-vectors), supplied by a
     # speaker manager created from the config.
     cfg = GlowTTSConfig(num_chars=32)
     cfg.use_d_vector_file = True
     cfg.d_vector_dim = 256
     d_vectors_json = os.path.join(get_tests_data_path(), "dummy_speakers.json")
     cfg.d_vector_file = d_vectors_json
     manager = SpeakerManager.init_from_config(cfg)
     net = GlowTTS(cfg)
     net.speaker_manager = manager
     net.init_multispeaker(cfg)
     self.assertEqual(net.c_in_channels, manager.embedding_dim)
     self.assertEqual(net.num_speakers, manager.num_speakers)
Exemple #2
0
# LOAD DATA SAMPLES
# Load the dataset described by `dataset_config` and split it into training and
# evaluation samples; the split is controlled by `config.eval_split_size` and
# `config.eval_split_max_size`.
# Instead of the built-in loader you can write your own sample loader that
# returns the list of samples, or define a custom formatter and pass it to
# `load_tts_samples`. Check `TTS.tts.datasets.load_tts_samples` for more details.
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    eval_split_max_size=config.eval_split_max_size,
    eval_split_size=config.eval_split_size,
)

# Init the speaker manager for multi-speaker training.
# It maps speaker-id to speaker-name in the model and data-loader.
# The ids are collected from both the training and the evaluation samples,
# reading each sample's "speaker_name" entry, and the resulting speaker count
# is written back to the config before the model is built.
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples,
                                  parse_key="speaker_name")
config.num_speakers = speaker_manager.num_speakers

# Init the model with the config, the audio processor (`ap`), the tokenizer and
# the speaker manager prepared above.
model = GlowTTS(config, ap, tokenizer, speaker_manager=speaker_manager)

# INITIALIZE THE TRAINER
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
# Default `TrainerArgs()` are used here; the trainer receives the model together
# with both sample lists and the output path.
trainer = Trainer(TrainerArgs(),
                  config,
                  output_path,
                  model=model,
                  train_samples=train_samples,
                  eval_samples=eval_samples)

# AND... 3,2,1... 🚀