def test_multilingual_forward(self):
    """Run ``Vits.forward`` with both speaker IDs and language IDs.

    Builds a config with speaker and language embeddings enabled, draws
    random speaker/language IDs for every batch item, and hands the
    resulting output dict to ``self._check_forward_outputs``.
    """
    n_speakers, n_langs, n_batch = 10, 3, 2

    config = VitsConfig(
        num_speakers=n_speakers,
        use_speaker_embedding=True,
        model_args=VitsArgs(
            language_ids_file=LANG_FILE,
            use_language_embedding=True,
            spec_segment_size=10,
        ),
    )

    x, x_lengths, _, y, y_lengths, wav = self._create_inputs(
        config, batch_size=n_batch
    )

    # One random ID per batch item; drawn before the model is built so the
    # order of random number consumption stays the same.
    id_shape = (n_batch,)
    aux_input = {
        "speaker_ids": torch.randint(0, n_speakers, id_shape).long().to(device),
        "language_ids": torch.randint(0, n_langs, id_shape).long().to(device),
    }

    vits = Vits(config).to(device)
    outputs = vits.forward(x, x_lengths, y, y_lengths, wav, aux_input=aux_input)
    self._check_forward_outputs(config, outputs)
def test_forward(self):
    """Run ``Vits.forward`` without any auxiliary input.

    The config is created with ``num_speakers=0`` (while
    ``use_speaker_embedding`` is still set to ``True``) and no speaker IDs
    are passed to ``forward``. The output dict is validated by
    ``self._check_forward_outputs``.
    """
    config = VitsConfig(num_speakers=0, use_speaker_embedding=True)
    config.model_args.spec_segment_size = 10

    x, x_lengths, _, y, y_lengths, wav = self._create_inputs(config)

    vits = Vits(config).to(device)
    outputs = vits.forward(x, x_lengths, y, y_lengths, wav)
    self._check_forward_outputs(config, outputs)
def test_secl_forward(self):
    """Run ``Vits.forward`` with ``use_speaker_encoder_as_loss`` enabled.

    A speaker encoder is built from ``SPEAKER_ENCODER_CONFIG`` and attached
    to a ``SpeakerManager`` that is passed to the model. The output dict is
    validated together with the speaker encoder config by
    ``self._check_forward_outputs``.
    """
    n_speakers, n_langs, n_batch = 10, 3, 2

    # Speaker encoder, attached to the manager that Vits receives.
    encoder_config = load_config(SPEAKER_ENCODER_CONFIG)
    encoder_config.model_params["use_torch_spec"] = True
    manager = SpeakerManager()
    manager.encoder = setup_encoder_model(encoder_config).to(device)

    config = VitsConfig(
        num_speakers=n_speakers,
        use_speaker_embedding=True,
        model_args=VitsArgs(
            language_ids_file=LANG_FILE,
            use_language_embedding=True,
            spec_segment_size=10,
            use_speaker_encoder_as_loss=True,
        ),
    )
    config.audio.sample_rate = 16000

    x, x_lengths, _, y, y_lengths, wav = self._create_inputs(
        config, batch_size=n_batch
    )

    # One random ID per batch item, drawn before the model is instantiated.
    id_shape = (n_batch,)
    aux_input = {
        "speaker_ids": torch.randint(0, n_speakers, id_shape).long().to(device),
        "language_ids": torch.randint(0, n_langs, id_shape).long().to(device),
    }

    vits = Vits(config, speaker_manager=manager).to(device)
    outputs = vits.forward(x, x_lengths, y, y_lengths, wav, aux_input=aux_input)
    self._check_forward_outputs(config, outputs, encoder_config)
def test_multispeaker_forward(self):
    """Run ``Vits.forward`` with speaker IDs as the only auxiliary input.

    Uses a config with 10 speakers and speaker embeddings enabled, then
    validates the output dict with ``self._check_forward_outputs``.
    """
    n_speakers = 10
    # NOTE(review): the number of speaker IDs is hard-coded; presumably it
    # has to match the default batch size of ``_create_inputs`` -- confirm.
    n_speaker_ids = 8

    config = VitsConfig(num_speakers=n_speakers, use_speaker_embedding=True)
    config.model_args.spec_segment_size = 10

    x, x_lengths, _, y, y_lengths, wav = self._create_inputs(config)

    random_ids = torch.randint(0, n_speakers, (n_speaker_ids,))
    aux_input = {"speaker_ids": random_ids.long().to(device)}

    vits = Vits(config).to(device)
    outputs = vits.forward(x, x_lengths, y, y_lengths, wav, aux_input=aux_input)
    self._check_forward_outputs(config, outputs)