def _test_forward_with_speaker_id(self, batch_size):
    """Check forward-pass output shapes of a speaker-embedding GlowTTS model.

    Builds a model configured with ``use_speaker_embedding=True``, runs
    ``forward`` in train mode with random speaker ids and asserts the shape
    of every tensor in the returned dict.

    Args:
        batch_size: Batch size passed to ``self._create_inputs`` and used for
            the random speaker-id tensor.
    """
    # Single source of truth for the speaker count used by both the random
    # ids and the model config.
    num_speakers = 24
    # The speaker ids returned by the helper are not used; fresh ids in
    # [0, num_speakers) are drawn instead.
    input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size)
    speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
    # create model
    config = GlowTTSConfig(
        num_chars=32,
        use_speaker_embedding=True,
        num_speakers=num_speakers,
    )
    model = GlowTTS.init_from_config(config, verbose=False).to(device)
    model.train()
    print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
    # run the forward pass (encoder and decoder with MAS) in train mode
    y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"speaker_ids": speaker_ids})
    self.assertEqual(y["z"].shape, mel_spec.shape)
    self.assertEqual(y["logdet"].shape, torch.Size([batch_size]))
    self.assertEqual(y["y_mean"].shape, mel_spec.shape)
    self.assertEqual(y["y_log_scale"].shape, mel_spec.shape)
    # alignments: first two mel dims followed by the text length
    self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],))
    self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,))
    self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,))
def _test_forward_with_d_vector(self, batch_size):
    """Check forward-pass output shapes of a d-vector conditioned GlowTTS model.

    Builds a model configured with ``use_d_vector_file=True``, runs
    ``forward`` in train mode with random d-vectors and asserts the shape of
    every tensor in the returned dict.

    Args:
        batch_size: Batch size passed to ``self._create_inputs`` and used for
            the random d-vector tensor.
    """
    # Single source of truth for the d-vector width used by both the random
    # vectors and the model config.
    d_vector_dim = 256
    # The speaker ids returned by the helper are irrelevant for this test.
    input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size)
    d_vector = torch.rand(batch_size, d_vector_dim).to(device)
    # create model
    config = GlowTTSConfig(
        num_chars=32,
        use_d_vector_file=True,
        d_vector_dim=d_vector_dim,
        d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
    )
    model = GlowTTS.init_from_config(config, verbose=False).to(device)
    model.train()
    print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
    # run the forward pass (encoder and decoder with MAS) in train mode
    y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"d_vectors": d_vector})
    self.assertEqual(y["z"].shape, mel_spec.shape)
    self.assertEqual(y["logdet"].shape, torch.Size([batch_size]))
    self.assertEqual(y["y_mean"].shape, mel_spec.shape)
    self.assertEqual(y["y_log_scale"].shape, mel_spec.shape)
    # alignments: first two mel dims followed by the text length
    self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],))
    self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,))
    self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,))
def test_test_run(self):
    """Check that ``test_run`` returns both figures and audios.

    A default-config model is put in eval mode with ``run_data_dep_init``
    switched off, then ``test_run(None)`` is called and both returned values
    are required to be non-``None``.
    """
    config = GlowTTSConfig(num_chars=32)
    model = GlowTTS.init_from_config(config, verbose=False).to(device)
    model.run_data_dep_init = False
    model.eval()
    test_figures, test_audios = model.test_run(None)
    # assertIsNotNone reports the offending value on failure, unlike
    # assertTrue(x is not None).
    self.assertIsNotNone(test_figures)
    self.assertIsNotNone(test_audios)
def test_load_checkpoint(self):
    """Round-trip a checkpoint and verify the ``eval`` flag of ``load_checkpoint``.

    The model's state dict is saved under the ``"model"`` key, then loaded
    twice: once with default arguments (model must still be in training
    mode) and once with ``eval=True`` (model must leave training mode).
    """
    checkpoint_file = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
    model_config = GlowTTSConfig(num_chars=32)
    glow_tts = GlowTTS.init_from_config(model_config, verbose=False).to(device)

    # Persist a minimal checkpoint containing only the model weights.
    torch.save({"model": glow_tts.state_dict()}, checkpoint_file)

    # Default load keeps the module in training mode.
    glow_tts.load_checkpoint(model_config, checkpoint_file)
    self.assertTrue(glow_tts.training)

    # Loading with eval=True must switch the module out of training mode.
    glow_tts.load_checkpoint(model_config, checkpoint_file, eval=True)
    self.assertFalse(glow_tts.training)
def test_init_from_config(self):
    """Check how ``GlowTTS.init_from_config`` reacts to speaker-related options.

    Covers, in order: a default config, ``num_speakers`` without speaker
    embedding, speaker embedding enabled, speaker embedding with a speakers
    file, and d-vector conditioning.
    """
    # Default config: only checks that construction succeeds.
    config = GlowTTSConfig(num_chars=32)
    model = GlowTTS.init_from_config(config, verbose=False).to(device)

    # num_speakers alone must not create the speaker embedding layer.
    config = GlowTTSConfig(num_chars=32, num_speakers=2)
    model = GlowTTS.init_from_config(config, verbose=False).to(device)
    self.assertEqual(model.num_speakers, 2)
    self.assertFalse(hasattr(model, "emb_g"))

    # Enabling speaker embedding creates the ``emb_g`` attribute.
    config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True)
    model = GlowTTS.init_from_config(config, verbose=False).to(device)
    self.assertEqual(model.num_speakers, 2)
    self.assertTrue(hasattr(model, "emb_g"))

    # With a speakers file the model is expected to report 10 speakers
    # rather than the configured 2 (presumably the count in the file).
    config = GlowTTSConfig(
        num_chars=32,
        num_speakers=2,
        use_speaker_embedding=True,
        speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
    )
    model = GlowTTS.init_from_config(config, verbose=False).to(device)
    self.assertEqual(model.num_speakers, 10)
    self.assertTrue(hasattr(model, "emb_g"))

    # d-vector conditioning: no embedding layer, and the conditioning
    # channel count must equal the configured d-vector dimension.
    config = GlowTTSConfig(
        num_chars=32,
        use_d_vector_file=True,
        d_vector_dim=256,
        d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
    )
    model = GlowTTS.init_from_config(config, verbose=False).to(device)
    self.assertEqual(model.num_speakers, 1)
    self.assertFalse(hasattr(model, "emb_g"))
    self.assertEqual(model.c_in_channels, config.d_vector_dim)
def _test_inference_with_speaker_ids(self, batch_size):
    """Check inference outputs of a speaker-embedding GlowTTS model.

    Args:
        batch_size: Batch size passed to ``self._create_inputs`` and used for
            the random speaker-id tensor.
    """
    # Single source of truth for the speaker count used by both the random
    # ids and the model config.
    num_speakers = 24
    # Mel lengths and the helper's speaker ids are not needed here; fresh
    # ids in [0, num_speakers) are drawn instead.
    input_dummy, input_lengths, mel_spec, _, _ = self._create_inputs(batch_size)
    speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
    # create model
    config = GlowTTSConfig(
        num_chars=32,
        use_speaker_embedding=True,
        num_speakers=num_speakers,
    )
    model = GlowTTS.init_from_config(config, verbose=False).to(device)
    # Run inference in eval mode, matching _test_inference_with_d_vector.
    model.eval()
    outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
    self._assert_inference_outputs(outputs, input_dummy, mel_spec)
def _test_inference_with_d_vector(self, batch_size):
    """Check inference outputs of a d-vector conditioned GlowTTS model.

    Args:
        batch_size: Batch size passed to ``self._create_inputs`` and used for
            the random d-vector tensor.
    """
    # Only the text, text lengths and mel spectrogram are needed here.
    text, text_lengths, mels, _, _ = self._create_inputs(batch_size)
    speaker_embedding = torch.rand(batch_size, 256).to(device)

    model_config = GlowTTSConfig(
        num_chars=32,
        use_d_vector_file=True,
        d_vector_dim=256,
        d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
    )
    glow_tts = GlowTTS.init_from_config(model_config, verbose=False).to(device)
    glow_tts.eval()

    # Auxiliary inputs consumed by ``inference`` alongside the text tensor.
    aux_input = {"x_lengths": text_lengths, "d_vectors": speaker_embedding}
    outputs = glow_tts.inference(text, aux_input)
    self._assert_inference_outputs(outputs, text, mels)
def test_train_eval_log(self):
    """Run one train step and push its outputs through ``train_log`` and ``eval_log``.

    The batch carries no speaker conditioning (``d_vectors`` and
    ``speaker_ids`` are ``None``). The logger is always finished, even when a
    step raises, so a failing test does not leave it open.
    """
    input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(BATCH_SIZE)
    batch = {
        "text_input": input_dummy,
        "text_lengths": input_lengths,
        "mel_lengths": mel_lengths,
        "mel_input": mel_spec,
        "d_vectors": None,
        "speaker_ids": None,
    }

    config = GlowTTSConfig(num_chars=32)
    model = GlowTTS.init_from_config(config, verbose=False).to(device)
    model.run_data_dep_init = False
    model.train()

    logger = TensorboardLogger(
        log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"),
        model_name="glow_tts_test_train_log",
    )
    try:
        criterion = model.get_criterion()
        outputs, _ = model.train_step(batch, criterion)
        model.train_log(batch, outputs, logger, None, 1)
        model.eval_log(batch, outputs, logger, None, 1)
    finally:
        # Finish the logger regardless of whether the steps above succeeded.
        logger.finish()
def test_get_criterion(self):
    """Check that a default-config model provides a criterion object."""
    config = GlowTTSConfig(num_chars=32)
    model = GlowTTS.init_from_config(config, verbose=False).to(device)
    criterion = model.get_criterion()
    # assertIsNotNone reports the offending value on failure, unlike
    # assertTrue(x is not None).
    self.assertIsNotNone(criterion)