Example #1
0
    def test_multilingual_forward(self):
        """Run a forward pass with both speaker and language conditioning enabled."""
        n_speakers, n_langs, batch = 10, 3, 2

        model_args = VitsArgs(
            language_ids_file=LANG_FILE,
            use_language_embedding=True,
            spec_segment_size=10,
        )
        config = VitsConfig(
            num_speakers=n_speakers,
            use_speaker_embedding=True,
            model_args=model_args,
        )

        # Dummy batch plus random conditioning ids for each sample.
        tokens, token_lens, _, spec, spec_lens, wav = self._create_inputs(
            config, batch_size=batch)
        spk_ids = torch.randint(0, n_speakers, (batch, )).long().to(device)
        lng_ids = torch.randint(0, n_langs, (batch, )).long().to(device)

        model = Vits(config).to(device)
        outputs = model.forward(
            tokens,
            token_lens,
            spec,
            spec_lens,
            wav,
            aux_input={"speaker_ids": spk_ids, "language_ids": lng_ids},
        )
        self._check_forward_outputs(config, outputs)
Example #2
0
 def test_forward(self):
     """Run a plain single-speaker forward pass with speaker embedding enabled."""
     config = VitsConfig(num_speakers=0, use_speaker_embedding=True)
     config.model_args.spec_segment_size = 10

     # Dummy batch at the helper's default batch size.
     tokens, token_lens, _, spec, spec_lens, wav = self._create_inputs(config)

     model = Vits(config).to(device)
     outputs = model.forward(tokens, token_lens, spec, spec_lens, wav)
     self._check_forward_outputs(config, outputs)
Example #3
0
    def test_secl_forward(self):
        """Run a forward pass with the speaker-encoder consistency loss active."""
        n_speakers, n_langs, batch = 10, 3, 2

        # External speaker encoder wired into the manager as the loss reference.
        enc_config = load_config(SPEAKER_ENCODER_CONFIG)
        enc_config.model_params["use_torch_spec"] = True
        encoder = setup_encoder_model(enc_config).to(device)
        manager = SpeakerManager()
        manager.encoder = encoder

        model_args = VitsArgs(
            language_ids_file=LANG_FILE,
            use_language_embedding=True,
            spec_segment_size=10,
            use_speaker_encoder_as_loss=True,
        )
        config = VitsConfig(
            num_speakers=n_speakers,
            use_speaker_embedding=True,
            model_args=model_args,
        )
        # Match the sample rate the encoder's torch-spec frontend expects.
        config.audio.sample_rate = 16000

        tokens, token_lens, _, spec, spec_lens, wav = self._create_inputs(
            config, batch_size=batch)
        spk_ids = torch.randint(0, n_speakers, (batch, )).long().to(device)
        lng_ids = torch.randint(0, n_langs, (batch, )).long().to(device)

        model = Vits(config, speaker_manager=manager).to(device)
        outputs = model.forward(
            tokens,
            token_lens,
            spec,
            spec_lens,
            wav,
            aux_input={"speaker_ids": spk_ids, "language_ids": lng_ids},
        )
        self._check_forward_outputs(config, outputs, enc_config)
Example #4
0
    def test_multispeaker_forward(self):
        """Run a forward pass conditioned on per-sample speaker ids."""
        n_speakers = 10

        config = VitsConfig(num_speakers=n_speakers, use_speaker_embedding=True)
        config.model_args.spec_segment_size = 10

        tokens, token_lens, _, spec, spec_lens, wav = self._create_inputs(config)
        # NOTE(review): 8 presumably matches _create_inputs' default batch size —
        # confirm against the helper before changing either value.
        spk_ids = torch.randint(0, n_speakers, (8, )).long().to(device)

        model = Vits(config).to(device)
        outputs = model.forward(
            tokens,
            token_lens,
            spec,
            spec_lens,
            wav,
            aux_input={"speaker_ids": spk_ids},
        )
        self._check_forward_outputs(config, outputs)