Example #1
    def test_multispeaker_inference(self):
        """Run inference with speaker-ID conditioning for batch sizes 1 and 2.

        Batch size 2 additionally passes explicit input lengths so padded
        sequences are handled.
        """
        n_speakers = 10
        cfg = VitsConfig(num_speakers=n_speakers, use_speaker_embedding=True)
        net = Vits(cfg).to(device)

        for bs in (1, 2):
            tokens, seq_lengths, *_ = self._create_inputs(cfg, batch_size=bs)
            spk = torch.randint(0, n_speakers, (bs, )).long().to(device)
            if bs == 1:
                aux = {"speaker_ids": spk}
            else:
                # Multi-sample batches need per-sample lengths.
                aux = {"x_lengths": seq_lengths, "speaker_ids": spk}
            out = net.inference(tokens, aux)
            self._check_inference_outputs(cfg, out, tokens, batch_size=bs)
Example #2
    def test_multilingual_inference(self):
        """Exercise inference with both speaker and language embeddings.

        First performs a quick smoke check on a hand-built token batch,
        then validates outputs for batch sizes 1 and 2.
        """
        n_speakers = 10
        n_langs = 3
        model_args = VitsArgs(language_ids_file=LANG_FILE,
                              use_language_embedding=True,
                              spec_segment_size=10)
        cfg = VitsConfig(num_speakers=n_speakers,
                         use_speaker_embedding=True,
                         model_args=model_args)
        net = Vits(cfg).to(device)

        # Smoke check with a hand-built single-sample token sequence.
        tokens = torch.randint(0, 24, (1, 128)).long().to(device)
        spk = torch.randint(0, n_speakers, (1, )).long().to(device)
        lng = torch.randint(0, n_langs, (1, )).long().to(device)
        _ = net.inference(tokens, {"speaker_ids": spk, "language_ids": lng})

        for bs in (1, 2):
            tokens, seq_lengths, *_ = self._create_inputs(cfg, batch_size=bs)
            spk = torch.randint(0, n_speakers, (bs, )).long().to(device)
            lng = torch.randint(0, n_langs, (bs, )).long().to(device)
            if bs == 1:
                aux = {"speaker_ids": spk, "language_ids": lng}
            else:
                # Multi-sample batches need per-sample lengths.
                aux = {
                    "x_lengths": seq_lengths,
                    "speaker_ids": spk,
                    "language_ids": lng
                }
            out = net.inference(tokens, aux)
            self._check_inference_outputs(cfg, out, tokens, batch_size=bs)
Example #3
 def test_inference(self):
     """Smoke-test single-speaker inference on a random token sequence."""
     cfg = VitsConfig(num_speakers=0, use_speaker_embedding=True)
     tokens = torch.randint(0, 24, (1, 128)).long().to(device)
     net = Vits(cfg).to(device)
     _ = net.inference(tokens)
Example #4
    def test_inference(self):
        """Validate single-speaker inference outputs for batch sizes 1 and 2.

        Batch size 2 additionally passes explicit input lengths so padded
        sequences are handled.
        """
        cfg = VitsConfig(num_speakers=0, use_speaker_embedding=True)
        net = Vits(cfg).to(device)

        for bs in (1, 2):
            tokens, seq_lengths, *_ = self._create_inputs(cfg, batch_size=bs)
            if bs == 1:
                out = net.inference(tokens)
            else:
                # Multi-sample batches need per-sample lengths.
                out = net.inference(tokens,
                                    aux_input={"x_lengths": seq_lengths})
            self._check_inference_outputs(cfg, out, tokens, batch_size=bs)
Example #5
 def test_multilingual_inference(self):
     """Smoke-test inference with speaker and language conditioning."""
     n_speakers = 10
     n_langs = 3
     cfg = VitsConfig(
         num_speakers=n_speakers,
         use_speaker_embedding=True,
         model_args=VitsArgs(language_ids_file=LANG_FILE,
                             use_language_embedding=True,
                             spec_segment_size=10),
     )
     tokens = torch.randint(0, 24, (1, 128)).long().to(device)
     spk = torch.randint(0, n_speakers, (1, )).long().to(device)
     lng = torch.randint(0, n_langs, (1, )).long().to(device)
     net = Vits(cfg).to(device)
     _ = net.inference(tokens, {"speaker_ids": spk, "language_ids": lng})