def test_multispeaker_inference(self):
    """Smoke-test `Vits.inference` with speaker embeddings enabled.

    Runs inference twice — single sample, then a batch of two with explicit
    input lengths — and validates the output shapes via the shared
    `_check_inference_outputs` helper.
    """
    n_speakers = 10
    cfg = VitsConfig(num_speakers=n_speakers, use_speaker_embedding=True)
    net = Vits(cfg).to(device)

    # Case 1: single sample, no explicit lengths.
    bsz = 1
    tokens, *_ = self._create_inputs(cfg, batch_size=bsz)
    sid = torch.randint(0, n_speakers, (bsz,)).long().to(device)
    out = net.inference(tokens, {"speaker_ids": sid})
    self._check_inference_outputs(cfg, out, tokens, batch_size=bsz)

    # Case 2: batch of two, explicit per-sample token lengths.
    bsz = 2
    tokens, token_lengths, *_ = self._create_inputs(cfg, batch_size=bsz)
    sid = torch.randint(0, n_speakers, (bsz,)).long().to(device)
    out = net.inference(tokens, {"x_lengths": token_lengths, "speaker_ids": sid})
    self._check_inference_outputs(cfg, out, tokens, batch_size=bsz)
def test_multilingual_inference(self):
    """Smoke-test `Vits.inference` with both speaker and language embeddings.

    Exercises three call shapes: a raw hand-built token tensor, a
    single-sample helper-built input, and a batch of two with explicit
    input lengths. Output shapes of the latter two are validated via
    `_check_inference_outputs`.
    """
    n_speakers = 10
    n_langs = 3
    margs = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
    cfg = VitsConfig(num_speakers=n_speakers, use_speaker_embedding=True, model_args=margs)
    net = Vits(cfg).to(device)

    # Case 0: hand-built token ids; result intentionally discarded.
    tokens = torch.randint(0, 24, (1, 128)).long().to(device)
    sid = torch.randint(0, n_speakers, (1,)).long().to(device)
    lid = torch.randint(0, n_langs, (1,)).long().to(device)
    _ = net.inference(tokens, {"speaker_ids": sid, "language_ids": lid})

    # Case 1: single sample via the shared input factory.
    bsz = 1
    tokens, *_ = self._create_inputs(cfg, batch_size=bsz)
    sid = torch.randint(0, n_speakers, (bsz,)).long().to(device)
    lid = torch.randint(0, n_langs, (bsz,)).long().to(device)
    out = net.inference(tokens, {"speaker_ids": sid, "language_ids": lid})
    self._check_inference_outputs(cfg, out, tokens, batch_size=bsz)

    # Case 2: batch of two with explicit per-sample token lengths.
    bsz = 2
    tokens, token_lengths, *_ = self._create_inputs(cfg, batch_size=bsz)
    sid = torch.randint(0, n_speakers, (bsz,)).long().to(device)
    lid = torch.randint(0, n_langs, (bsz,)).long().to(device)
    out = net.inference(
        tokens, {"x_lengths": token_lengths, "speaker_ids": sid, "language_ids": lid}
    )
    self._check_inference_outputs(cfg, out, tokens, batch_size=bsz)
def test_inference(self):
    """Minimal smoke-test: single-speaker `Vits.inference` on random token ids.

    NOTE(review): another, fuller `test_inference` appears in this file; if
    both live in the same class, the later definition shadows this one —
    confirm and consider renaming or removing one of them.
    """
    n_speakers = 0
    cfg = VitsConfig(num_speakers=n_speakers, use_speaker_embedding=True)
    # Keep the original statement order: `torch.randint` and the `Vits`
    # weight init share the global RNG, so reordering would change draws.
    tokens = torch.randint(0, 24, (1, 128)).long().to(device)
    net = Vits(cfg).to(device)
    _ = net.inference(tokens)
def test_inference(self):
    """Smoke-test single-speaker `Vits.inference` and validate output shapes.

    Runs a single sample without lengths, then a batch of two with explicit
    input lengths, checking each via `_check_inference_outputs`.
    """
    n_speakers = 0
    cfg = VitsConfig(num_speakers=n_speakers, use_speaker_embedding=True)
    net = Vits(cfg).to(device)

    # Case 1: single sample, no aux input.
    bsz = 1
    tokens, *_ = self._create_inputs(cfg, batch_size=bsz)
    out = net.inference(tokens)
    self._check_inference_outputs(cfg, out, tokens, batch_size=bsz)

    # Case 2: batch of two with explicit per-sample token lengths.
    bsz = 2
    tokens, token_lengths, *_ = self._create_inputs(cfg, batch_size=bsz)
    out = net.inference(tokens, aux_input={"x_lengths": token_lengths})
    self._check_inference_outputs(cfg, out, tokens, batch_size=bsz)
def test_multilingual_inference(self):
    """Minimal smoke-test of `Vits.inference` with speaker + language ids.

    NOTE(review): another, fuller `test_multilingual_inference` appears in
    this file; if both live in the same class, the later definition shadows
    this one — confirm and consider renaming or removing one of them.
    """
    n_speakers = 10
    n_langs = 3
    margs = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
    cfg = VitsConfig(num_speakers=n_speakers, use_speaker_embedding=True, model_args=margs)
    # Keep the original statement order: random draws and `Vits` weight
    # init share the global RNG, so reordering would change the tensors.
    tokens = torch.randint(0, 24, (1, 128)).long().to(device)
    sid = torch.randint(0, n_speakers, (1,)).long().to(device)
    lid = torch.randint(0, n_langs, (1,)).long().to(device)
    net = Vits(cfg).to(device)
    _ = net.inference(tokens, {"speaker_ids": sid, "language_ids": lid})