# Example 1
    def test_inference_encoder_base(self):
        """Integration check for the base-plus checkpoint: a fixed slice of
        ``last_hidden_state`` must match precomputed reference values."""
        model = UniSpeechSatModel.from_pretrained("microsoft/unispeech-sat-base-plus")
        model.to(torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-base", return_attention_mask=True
        )
        speech_samples = self._load_datasamples(2)

        features = feature_extractor(speech_samples, return_tensors="pt", padding=True)
        input_values = features.input_values.to(torch_device)
        attention_mask = features.attention_mask.to(torch_device)

        # No gradients needed — pure forward-pass comparison.
        with torch.no_grad():
            outputs = model(input_values, attention_mask=attention_mask)

        # Reference values precomputed for this checkpoint.
        # fmt: off
        expected_hidden_states_slice = torch.tensor(
            [[[-0.0743, 0.1384], [-0.0845, 0.1704]],
             [[-0.0954, 0.1936], [-0.1123, 0.2095]]],
            device=torch_device,
        )
        # fmt: on

        actual_slice = outputs.last_hidden_state[:, :2, -2:]
        self.assertTrue(
            torch.allclose(actual_slice, expected_hidden_states_slice, atol=1e-3)
        )
# Example 2
    def test_inference_encoder_large(self):
        """Integration check for the large checkpoint: a fixed slice of
        ``last_hidden_state`` must match precomputed reference values."""
        model = UniSpeechSatModel.from_pretrained("microsoft/unispeech-sat-large")
        model.to(torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-large-xlsr-53"
        )
        speech_samples = self._load_datasamples(2)

        features = feature_extractor(speech_samples, return_tensors="pt", padding=True)
        input_values = features.input_values.to(torch_device)
        attention_mask = features.attention_mask.to(torch_device)

        # No gradients needed — pure forward-pass comparison.
        with torch.no_grad():
            outputs = model(input_values, attention_mask=attention_mask)

        # Reference values precomputed for this checkpoint.
        # fmt: off
        expected_hidden_states_slice = torch.tensor(
            [[[-0.1172, -0.0797], [-0.0012, 0.0213]],
             [[-0.1225, -0.1277], [-0.0668, -0.0585]]],
            device=torch_device,
        )
        # fmt: on

        actual_slice = outputs.last_hidden_state[:, :2, -2:]
        self.assertTrue(
            torch.allclose(actual_slice, expected_hidden_states_slice, atol=1e-3)
        )
# Example 3
 def test_model_from_pretrained(self):
     """Smoke test: the large checkpoint loads and yields a non-None model."""
     loaded = UniSpeechSatModel.from_pretrained("microsoft/unispeech-sat-large")
     self.assertIsNotNone(loaded)