Example #1
    def test_inference_large(self):
        # Load the large checkpoint and its matching feature extractor.
        model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "microsoft/wavlm-large", return_attention_mask=True
        )

        input_speech = self._load_datasamples(2)

        inputs = feature_extractor(input_speech, return_tensors="pt", padding=True)

        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        # Compare only the trailing 2x2 corner of the hidden states against
        # reference values recorded from a known-good run.
        with torch.no_grad():
            hidden_states_slice = (
                model(input_values, attention_mask=attention_mask)
                .last_hidden_state[:, -2:, -2:]
                .cpu()
            )

        EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
            [[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
        )

        self.assertTrue(
            torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2)
        )
Example #2
    def test_inference_base(self):
        # Same check as above, against the base-plus checkpoint.
        model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "microsoft/wavlm-base-plus", return_attention_mask=True
        )

        input_speech = self._load_datasamples(2)

        inputs = feature_extractor(input_speech, return_tensors="pt", padding=True)

        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        with torch.no_grad():
            hidden_states_slice = (
                model(input_values, attention_mask=attention_mask)
                .last_hidden_state[:, -2:, -2:]
                .cpu()
            )

        EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
            [[[0.0577, 0.1161], [0.0579, 0.1165]], [[0.0199, 0.1237], [0.0059, 0.0605]]]
        )
        # TODO: update the tolerance after the CI moves to torch 1.10
        self.assertTrue(
            torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, atol=5e-2)
        )
Example #3
    def test_model_from_pretrained(self):
        # Smoke test: the checkpoint downloads and instantiates without error.
        model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
        self.assertIsNotNone(model)
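Beyond the None check, a sketch of cheap follow-up sanity checks one might add, using the standard PreTrainedModel API (the 768 value assumes the base-plus checkpoint uses the usual base-sized hidden dimension):

from transformers import WavLMModel

model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")

# Verify the config and weights loaded as expected.
assert model.config.hidden_size == 768  # base-sized WavLM hidden dimension
print(f"{model.num_parameters():,} parameters")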