Beispiel #1
0
    def test_inference_encoder_large(self):
        """Check a slice of the large UniSpeechSat encoder's hidden states.

        Two samples from ``self._load_datasamples`` are featurized with the
        ``facebook/wav2vec2-large-xlsr-53`` feature extractor (padded into
        one batch) and run through ``microsoft/unispeech-sat-large``. The
        ``[:, :2, -2:]`` slice of ``last_hidden_state`` must match the
        reference values below within ``atol=1e-3``.
        """
        encoder = UniSpeechSatModel.from_pretrained(
            "microsoft/unispeech-sat-large")
        encoder.to(torch_device)
        extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-large-xlsr-53")
        speech = self._load_datasamples(2)

        features = extractor(speech, return_tensors="pt", padding=True)
        values = features.input_values.to(torch_device)
        mask = features.attention_mask.to(torch_device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            hidden = encoder(values, attention_mask=mask).last_hidden_state

        # fmt: off
        reference = torch.tensor(
            [[[-0.1172, -0.0797], [-0.0012, 0.0213]],
             [[-0.1225, -0.1277], [-0.0668, -0.0585]]],
            device=torch_device,
        )
        # fmt: on

        matches = torch.allclose(hidden[:, :2, -2:], reference, atol=1e-3)
        self.assertTrue(matches)
Beispiel #2
0
    def test_inference_encoder_base(self):
        """Check a slice of the base-plus UniSpeechSat encoder's hidden states.

        Two samples from ``self._load_datasamples`` are featurized with the
        ``facebook/wav2vec2-base`` feature extractor (attention mask
        explicitly requested, padded into one batch) and run through
        ``microsoft/unispeech-sat-base-plus``. The ``[:, :2, -2:]`` slice of
        ``last_hidden_state`` must match the reference values below within
        ``atol=1e-3``.
        """
        encoder = UniSpeechSatModel.from_pretrained(
            "microsoft/unispeech-sat-base-plus")
        encoder.to(torch_device)
        extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-base", return_attention_mask=True)
        speech = self._load_datasamples(2)

        features = extractor(speech, return_tensors="pt", padding=True)
        values = features.input_values.to(torch_device)
        mask = features.attention_mask.to(torch_device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            hidden = encoder(values, attention_mask=mask).last_hidden_state

        # fmt: off
        reference = torch.tensor(
            [[[-0.0743, 0.1384], [-0.0845, 0.1704]],
             [[-0.0954, 0.1936], [-0.1123, 0.2095]]],
            device=torch_device,
        )
        # fmt: on

        matches = torch.allclose(hidden[:, :2, -2:], reference, atol=1e-3)
        self.assertTrue(matches)
Beispiel #3
0
    def create_and_check_batch_inference(self, config, input_values, *args):
        """Check that padded batch inference matches per-sample inference.

        Up to the first three rows of ``input_values`` are cut to 1/4, 1/2
        and the full last-dimension length respectively. The cut-off tail is
        zeroed and masked out in ``attention_mask``. For each row, the
        batched ``last_hidden_state`` must agree (``atol=1e-3``) with running
        that row's unpadded prefix through the model on its own, over the
        frames the single-sample run produces.

        Args:
            config: configuration used to build ``UniSpeechSatModel``.
            input_values: batch of model inputs. Assumed 2-D
                ``(batch, samples)``; TODO confirm. The caller's tensor is
                not modified.
            *args: ignored; accepted so the tester's full input tuple can be
                splatted in.
        """
        # test does not pass for models making use of `group_norm`
        # check: https://github.com/pytorch/fairseq/issues/3227
        model = UniSpeechSatModel(config=config)
        model.to(torch_device)
        model.eval()

        # A plain slice is a view sharing storage with the caller's tensor;
        # clone it so the in-place padding below does not corrupt the input.
        input_values = input_values[:3].clone()
        attention_mask = torch.ones(input_values.shape,
                                    device=torch_device,
                                    dtype=torch.bool)

        full_length = input_values.shape[-1]
        # One target length per row actually present; the batch may hold
        # fewer than three rows.
        input_lengths = [full_length // i
                         for i in [4, 2, 1]][:input_values.shape[0]]

        # pad input
        for i, length in enumerate(input_lengths):
            input_values[i, length:] = 0.0
            attention_mask[i, length:] = False

        # Pure inference check: skip building autograd graphs.
        with torch.no_grad():
            batch_outputs = model(
                input_values, attention_mask=attention_mask).last_hidden_state

            for i, length in enumerate(input_lengths):
                input_slice = input_values[i:i + 1, :length]
                output = model(input_slice).last_hidden_state

                batch_output = batch_outputs[i:i + 1, :output.shape[1]]
                self.parent.assertTrue(
                    torch.allclose(output, batch_output, atol=1e-3))
Beispiel #4
0
 def create_and_check_model(self, config, input_values, attention_mask):
     """Build a model from ``config`` and check its output shape.

     A forward pass on ``input_values`` with ``attention_mask`` must yield a
     ``last_hidden_state`` of shape
     ``(self.batch_size, self.output_seq_length, self.hidden_size)``.
     """
     net = UniSpeechSatModel(config=config)
     net.to(torch_device)
     net.eval()
     hidden = net(input_values, attention_mask=attention_mask).last_hidden_state
     expected_shape = (self.batch_size, self.output_seq_length,
                       self.hidden_size)
     self.parent.assertEqual(hidden.shape, expected_shape)
Beispiel #5
0
 def test_model_from_pretrained(self):
     """Smoke test: the ``microsoft/unispeech-sat-large`` checkpoint loads."""
     loaded = UniSpeechSatModel.from_pretrained("microsoft/unispeech-sat-large")
     self.assertIsNotNone(loaded)