Example 1
    def test_d_vector_inference(self):
        args = VitsArgs(
            spec_segment_size=10,
            num_chars=32,
            use_d_vector_file=True,
            d_vector_dim=256,
            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
        )
        config = VitsConfig(model_args=args)
        model = Vits.init_from_config(config, verbose=False).to(device)
        model.eval()
        # batch size = 1
        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
        # a random 256-dim d-vector stands in for a real speaker embedding
        d_vectors = torch.randn(1, 256).to(device)
        outputs = model.inference(input_dummy, aux_input={"d_vectors": d_vectors})
        self._check_inference_outputs(config, outputs, input_dummy)
        # batch size = 2
        input_dummy, input_lengths, *_ = self._create_inputs(config)
        d_vectors = torch.randn(2, 256).to(device)
        outputs = model.inference(
            input_dummy, aux_input={"x_lengths": input_lengths, "d_vectors": d_vectors}
        )
        self._check_inference_outputs(config, outputs, input_dummy, batch_size=2)
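
These snippets assume the usual Coqui TTS imports (VitsConfig from TTS.tts.configs.vits_config, Vits and VitsArgs from TTS.tts.models.vits) plus the suite's shared fixtures such as _check_inference_outputs and _create_inputs, which are not reproduced here. The sketch below is only a guess at a minimal _check_inference_outputs, assuming Vits.inference() returns a dict whose "model_outputs" entry is a waveform batch shaped [B, 1, T] and that also carries an "alignments" entry; the real helper may assert more.

    def _check_inference_outputs(self, config, outputs, input_dummy, batch_size=1):
        # hypothetical minimal checks, not the suite's actual helper
        self.assertEqual(outputs["model_outputs"].shape[0], batch_size)
        self.assertEqual(outputs["model_outputs"].shape[1], 1)
        self.assertIn("alignments", outputs)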
Example 2
    def test_test_run(self):
        config = VitsConfig(model_args=VitsArgs(num_chars=32))
        model = Vits.init_from_config(config, verbose=False).to(device)
        model.run_data_dep_init = False
        model.eval()
        # synthesize the config's test sentences; no training assets are needed
        test_figures, test_audios = model.test_run(None)
        self.assertTrue(test_figures is not None)
        self.assertTrue(test_audios is not None)
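
test_run(None) synthesizes whatever sentences are stored on the config, so no training assets are required; the test above relies on the config's default test_sentences. A variation, assuming test_sentences entries take the [text] list form used for multi-speaker configs, that shortens the smoke test with a hypothetical one-sentence override:

    config = VitsConfig(model_args=VitsArgs(num_chars=32))
    config.test_sentences = [["This is a short smoke-test sentence."]]  # hypothetical override
    model = Vits.init_from_config(config, verbose=False).to(device)
    model.run_data_dep_init = False
    model.eval()
    test_figures, test_audios = model.test_run(None)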
Example 3
    def test_load_checkpoint(self):
        chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
        config = VitsConfig(VitsArgs(num_chars=32))
        model = Vits.init_from_config(config, verbose=False).to(device)
        # save a minimal checkpoint that only carries the model weights
        chkp = {}
        chkp["model"] = model.state_dict()
        torch.save(chkp, chkp_path)
        model.load_checkpoint(config, chkp_path)
        self.assertTrue(model.training)
        # eval=True restores the weights and switches the model to inference mode
        model.load_checkpoint(config, chkp_path, eval=True)
        self.assertFalse(model.training)
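
Because the checkpoint is just a dict written with torch.save, a couple of extra assertions could confirm what load_checkpoint reads back. This is a sketch of such a check, not part of the original test:

    chkp = torch.load(chkp_path, map_location="cpu")
    self.assertIn("model", chkp)
    self.assertEqual(set(chkp["model"].keys()), set(model.state_dict().keys()))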
Example 4
    def test_init_from_config(self):
        # single-speaker model: no speaker embedding layer is created
        config = VitsConfig(model_args=VitsArgs(num_chars=32))
        model = Vits.init_from_config(config, verbose=False).to(device)

        # num_speakers alone does not enable the speaker embedding layer
        config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2))
        model = Vits.init_from_config(config, verbose=False).to(device)
        self.assertTrue(not hasattr(model, "emb_g"))

        # use_speaker_embedding=True creates the speaker embedding layer `emb_g`
        config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True))
        model = Vits.init_from_config(config, verbose=False).to(device)
        self.assertEqual(model.num_speakers, 2)
        self.assertTrue(hasattr(model, "emb_g"))

        # a speakers file overrides num_speakers with the number of entries it defines
        config = VitsConfig(
            model_args=VitsArgs(
                num_chars=32,
                num_speakers=2,
                use_speaker_embedding=True,
                speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
            )
        )
        model = Vits.init_from_config(config, verbose=False).to(device)
        self.assertEqual(model.num_speakers, 10)
        self.assertTrue(hasattr(model, "emb_g"))

        # external d-vectors: no embedding layer, speaker dim comes from d_vector_dim
        config = VitsConfig(
            model_args=VitsArgs(
                num_chars=32,
                use_d_vector_file=True,
                d_vector_dim=256,
                d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
            )
        )
        model = Vits.init_from_config(config, verbose=False).to(device)
        self.assertTrue(model.num_speakers == 1)
        self.assertTrue(not hasattr(model, "emb_g"))
        self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)
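
The assertions above establish that the speaker embedding layer emb_g exists exactly when use_speaker_embedding is enabled. The same behavior could be expressed as a small parametrized loop; a sketch reusing only what the test already asserts:

    for use_emb in (False, True):
        cfg = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=use_emb))
        mdl = Vits.init_from_config(cfg, verbose=False).to(device)
        self.assertEqual(hasattr(mdl, "emb_g"), use_emb)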
Example 5
    def test_train_eval_log(self):
        batch_size = 2
        config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
        model = Vits.init_from_config(config, verbose=False).to(device)
        model.run_data_dep_init = False
        model.train()
        batch = self._create_batch(config, batch_size)
        logger = TensorboardLogger(
            log_dir=os.path.join(get_tests_output_path(), "dummy_vits_logs"),
            model_name="vits_test_train_log",
        )
        # the criterion is a list with one loss module per optimization step
        criterion = model.get_criterion()
        criterion = [criterion[0].to(device), criterion[1].to(device)]
        # run both optimization steps (optimizer_idx 0 and 1) before logging
        outputs = [None] * 2
        outputs[0], _ = model.train_step(batch, criterion, 0)
        outputs[1], _ = model.train_step(batch, criterion, 1)
        model.train_log(batch, outputs, logger, None, 1)

        model.eval_log(batch, outputs, logger, None, 1)
        logger.finish()
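
train_step also returns a loss dict, which the test discards with _. A hedged extension, assuming each step's dict carries an aggregated "loss" entry in the usual trainer convention (the exact keys may differ):

    _, loss_dict_0 = model.train_step(batch, criterion, 0)
    _, loss_dict_1 = model.train_step(batch, criterion, 1)
    self.assertIn("loss", loss_dict_0)
    self.assertIn("loss", loss_dict_1)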
Example 6
    def test_d_vector_forward(self):
        batch_size = 2
        args = VitsArgs(
            spec_segment_size=10,
            num_chars=32,
            use_d_vector_file=True,
            d_vector_dim=256,
            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
        )
        config = VitsConfig(model_args=args)
        model = Vits.init_from_config(config, verbose=False).to(device)
        model.train()
        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(
            config, batch_size=batch_size
        )
        # one external speaker embedding (d-vector) per batch item, passed through aux_input
        d_vectors = torch.randn(batch_size, 256).to(device)
        output_dict = model.forward(
            input_dummy,
            input_lengths,
            spec,
            spec_lengths,
            waveform,
            aux_input={"d_vectors": d_vectors},
        )
        self._check_forward_outputs(config, output_dict)
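
_create_inputs is another shared fixture not reproduced in these examples. A rough stand-in, assuming the posterior encoder consumes linear spectrograms shaped [B, fft_size // 2 + 1, T_spec] and waveforms of T_spec * hop_length samples, with fft_size and hop_length taken from the default audio config; the suite's actual fixture may differ:

    def _create_inputs(self, config, batch_size=2):
        # hypothetical stand-in for the shared fixture used above
        t_tokens, t_spec = 128, 30
        input_dummy = torch.randint(0, 24, (batch_size, t_tokens)).long().to(device)
        input_lengths = torch.full((batch_size,), t_tokens, dtype=torch.long).to(device)
        speaker_ids = torch.zeros(batch_size, dtype=torch.long).to(device)
        spec = torch.rand(batch_size, config.audio.fft_size // 2 + 1, t_spec).to(device)
        spec_lengths = torch.full((batch_size,), t_spec, dtype=torch.long).to(device)
        waveform = torch.rand(batch_size, 1, t_spec * config.audio.hop_length).to(device)
        return input_dummy, input_lengths, speaker_ids, spec, spec_lengths, waveform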
Example 7
    def test_get_criterion(self):
        config = VitsConfig(VitsArgs(num_chars=32))
        model = Vits.init_from_config(config, verbose=False).to(device)
        criterion = model.get_criterion()
        self.assertTrue(criterion is not None)
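
Example 5 indexes criterion[0] and criterion[1], so get_criterion is expected to return a two-element list of loss modules; a slightly stronger check along those lines could read:

    # two loss modules, one per optimization step, as indexed in Example 5
    self.assertEqual(len(criterion), 2)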