def test_d_vector_inference(self):
    """Run VITS inference conditioned on external d-vectors at batch sizes 1 and 2."""
    speakers_json = os.path.join(get_tests_data_path(), "dummy_speakers.json")
    config = VitsConfig(
        model_args=VitsArgs(
            spec_segment_size=10,
            num_chars=32,
            use_d_vector_file=True,
            d_vector_dim=256,
            d_vector_file=speakers_json,
        )
    )
    model = Vits.init_from_config(config, verbose=False).to(device)
    model.eval()

    # Single utterance: only token ids and one 256-dim d-vector are supplied.
    tokens = torch.randint(0, 24, (1, 128)).long().to(device)
    speaker_vecs = torch.randn(1, 256).to(device)
    outputs = model.inference(tokens, aux_input={"d_vectors": speaker_vecs})
    self._check_inference_outputs(config, outputs, tokens)

    # Batched case: token lengths are passed via "x_lengths". The helper's
    # default batch size presumably is 2 (it must match the 2 d-vectors below).
    tokens, token_lengths, *_ = self._create_inputs(config)
    speaker_vecs = torch.randn(2, 256).to(device)
    aux = {"x_lengths": token_lengths, "d_vectors": speaker_vecs}
    outputs = model.inference(tokens, aux_input=aux)
    self._check_inference_outputs(config, outputs, tokens, batch_size=2)
def test_test_run(self):
    """Smoke-test ``Vits.test_run``: it must return both figures and audios."""
    config = VitsConfig(model_args=VitsArgs(num_chars=32))
    model = Vits.init_from_config(config, verbose=False).to(device)
    # Presumably disables data-dependent initialisation, which has no real
    # data to run on in this test.
    model.run_data_dep_init = False
    model.eval()
    test_figures, test_audios = model.test_run(None)
    # assertIsNotNone reports the offending value on failure, unlike
    # assertTrue(x is not None), which only reports "False is not true".
    self.assertIsNotNone(test_figures)
    self.assertIsNotNone(test_audios)
def test_load_checkpoint(self):
    """Round-trip a state dict through ``load_checkpoint`` in train and eval modes."""
    # The old name ("dummy_glow_tts_checkpoint.pth") looked copied from another
    # model's tests; a VITS-specific name avoids clobbering a same-named file
    # in the shared test output directory.
    chkp_path = os.path.join(get_tests_output_path(), "dummy_vits_checkpoint.pth")
    # Pass the args by keyword, as the other tests do: a positional argument
    # binds to VitsConfig's first dataclass field, which (with inherited base
    # config fields) is presumably not ``model_args``, so num_chars=32 would
    # be silently ignored.
    config = VitsConfig(model_args=VitsArgs(num_chars=32))
    model = Vits.init_from_config(config, verbose=False).to(device)
    torch.save({"model": model.state_dict()}, chkp_path)

    # Default load leaves the model in training mode...
    model.load_checkpoint(config, chkp_path)
    self.assertTrue(model.training)
    # ...while eval=True must switch it to evaluation mode.
    model.load_checkpoint(config, chkp_path, eval=True)
    self.assertFalse(model.training)
def test_init_from_config(self):
    """Check speaker-conditioning setup for several ``VitsArgs`` variants."""
    # Plain model: only checks that construction succeeds.
    config = VitsConfig(model_args=VitsArgs(num_chars=32))
    Vits.init_from_config(config, verbose=False).to(device)

    # num_speakers alone must not create the speaker embedding table.
    config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2))
    model = Vits.init_from_config(config, verbose=False).to(device)
    self.assertFalse(hasattr(model, "emb_g"))

    # Enabling the speaker embedding creates ``emb_g`` for the given count.
    config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True))
    model = Vits.init_from_config(config, verbose=False).to(device)
    self.assertEqual(model.num_speakers, 2)
    self.assertTrue(hasattr(model, "emb_g"))

    # With a speakers file, the speaker count comes from the file (10),
    # not from num_speakers=2.
    config = VitsConfig(
        model_args=VitsArgs(
            num_chars=32,
            num_speakers=2,
            use_speaker_embedding=True,
            speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
        )
    )
    model = Vits.init_from_config(config, verbose=False).to(device)
    self.assertEqual(model.num_speakers, 10)
    self.assertTrue(hasattr(model, "emb_g"))

    # External d-vectors: no embedding table; the conditioning width
    # equals the configured d-vector dimension.
    config = VitsConfig(
        model_args=VitsArgs(
            num_chars=32,
            use_d_vector_file=True,
            d_vector_dim=256,
            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
        )
    )
    model = Vits.init_from_config(config, verbose=False).to(device)
    self.assertEqual(model.num_speakers, 1)
    self.assertFalse(hasattr(model, "emb_g"))
    self.assertEqual(model.embedded_speaker_dim, config.d_vector_dim)
def test_train_eval_log(self):
    """Run two train steps, then exercise ``train_log`` and ``eval_log`` with TensorBoard."""
    batch_size = 2
    config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
    model = Vits.init_from_config(config, verbose=False).to(device)
    model.run_data_dep_init = False
    model.train()
    batch = self._create_batch(config, batch_size)
    logger = TensorboardLogger(
        log_dir=os.path.join(get_tests_output_path(), "dummy_vits_logs"),
        model_name="vits_test_train_log",
    )
    # Always finish the logger, even if a step or log call raises, so
    # whatever it holds open is released instead of leaking into later tests.
    try:
        # get_criterion is indexed as a pair here; move both parts to the device.
        criterion = model.get_criterion()
        criterion = [criterion[0].to(device), criterion[1].to(device)]
        # Steps 0 and 1 run sequentially and in order; step 1 may depend on
        # state left by step 0 (TODO confirm), so keep them separate.
        outputs = [None] * 2
        outputs[0], _ = model.train_step(batch, criterion, 0)
        outputs[1], _ = model.train_step(batch, criterion, 1)
        model.train_log(batch, outputs, logger, None, 1)
        model.eval_log(batch, outputs, logger, None, 1)
    finally:
        logger.finish()
def test_d_vector_forward(self):
    """Training-mode forward pass conditioned on random 256-dim d-vectors."""
    batch_size = 2
    config = VitsConfig(
        model_args=VitsArgs(
            spec_segment_size=10,
            num_chars=32,
            use_d_vector_file=True,
            d_vector_dim=256,
            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
        )
    )
    model = Vits.init_from_config(config, verbose=False).to(device)
    model.train()

    # The helper yields six items; the third is not needed for this forward call.
    inputs = self._create_inputs(config, batch_size=batch_size)
    tokens, token_lengths, _, spec, spec_lengths, waveform = inputs
    speaker_vecs = torch.randn(batch_size, 256).to(device)

    output_dict = model.forward(
        tokens,
        token_lengths,
        spec,
        spec_lengths,
        waveform,
        aux_input={"d_vectors": speaker_vecs},
    )
    self._check_forward_outputs(config, output_dict)
def test_get_criterion(self):
    """``Vits.get_criterion`` must return a criterion object."""
    # Pass the args by keyword, as the other tests do: a positional argument
    # binds to VitsConfig's first dataclass field, which is presumably not
    # ``model_args``, so num_chars=32 would be silently ignored.
    config = VitsConfig(model_args=VitsArgs(num_chars=32))
    model = Vits.init_from_config(config, verbose=False).to(device)
    criterion = model.get_criterion()
    self.assertIsNotNone(criterion)