Example #1
    def test_inference_large(self):
        model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(
            torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "microsoft/wavlm-large", return_attention_mask=True)

        input_speech = self._load_datasamples(2)

        inputs = feature_extractor(input_speech,
                                   return_tensors="pt",
                                   padding=True)

        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        with torch.no_grad():
            outputs = model(input_values, attention_mask=attention_mask)
            hidden_states_slice = outputs.last_hidden_state[:, -2:, -2:].cpu()

        EXPECTED_HIDDEN_STATES_SLICE = torch.tensor([[[0.2122, 0.0500],
                                                      [0.2118, 0.0563]],
                                                     [[0.1353, 0.1818],
                                                      [0.2453, 0.0595]]])

        self.assertTrue(
            torch.allclose(hidden_states_slice,
                           EXPECTED_HIDDEN_STATES_SLICE,
                           rtol=5e-2))
Example #2
    def test_inference_base(self):
        model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(
            torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "microsoft/wavlm-base-plus", return_attention_mask=True)

        input_speech = self._load_datasamples(2)

        inputs = feature_extractor(input_speech,
                                   return_tensors="pt",
                                   padding=True)

        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        with torch.no_grad():
            outputs = model(input_values, attention_mask=attention_mask)
            hidden_states_slice = outputs.last_hidden_state[:, -2:, -2:].cpu()

        EXPECTED_HIDDEN_STATES_SLICE = torch.tensor([[[0.0577, 0.1161],
                                                      [0.0579, 0.1165]],
                                                     [[0.0199, 0.1237],
                                                      [0.0059, 0.0605]]])
        # TODO: update the tolerance after the CI moves to torch 1.10
        self.assertTrue(
            torch.allclose(hidden_states_slice,
                           EXPECTED_HIDDEN_STATES_SLICE,
                           atol=5e-2))
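Both inference tests above call self._load_datasamples, which is not shown on this page. The following is a minimal sketch of such a helper, assuming the dummy LibriSpeech dataset pattern used in the Transformers test suite; the dataset name and the layout of the "audio" column are assumptions, not copied from the original test class.

    def _load_datasamples(self, num_samples):
        from datasets import load_dataset

        # load a tiny dummy LibriSpeech split and return the first `num_samples`
        # decoded waveforms as 1-D float arrays (assumed column layout)
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        speech_samples = ds.sort("id")[:num_samples]["audio"]
        return [sample["array"] for sample in speech_samples]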
Example #3
def convert_wavlm_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):

    # load the pre-trained checkpoints
    checkpoint = torch.load(checkpoint_path)
    cfg = WavLMConfigOrig(checkpoint["cfg"])
    model = WavLMOrig(cfg)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    if config_path is not None:
        config = WavLMConfig.from_pretrained(config_path)
    else:
        config = WavLMConfig()

    hf_wavlm = WavLMModel(config)

    recursively_load_weights(model, hf_wavlm)

    hf_wavlm.save_pretrained(pytorch_dump_folder_path)
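The conversion function above is normally driven from the command line. Below is a minimal sketch of such an entry point; the flag names are assumptions modeled on the usual Transformers conversion scripts, not copied from the original file.

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # flag names below are assumed, not taken from the original script
    parser.add_argument("--checkpoint_path", required=True, type=str, help="Path to the original fairseq WavLM checkpoint")
    parser.add_argument("--pytorch_dump_folder_path", required=True, type=str, help="Directory to save the converted PyTorch model")
    parser.add_argument("--config_path", default=None, type=str, help="Optional path to a Hugging Face config.json")
    args = parser.parse_args()

    convert_wavlm_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)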
Example #4
    def create_and_check_batch_inference(self, config, input_values, *args):
        # test does not pass for models making use of `group_norm`
        # check: https://github.com/pytorch/fairseq/issues/3227
        model = WavLMModel(config=config)
        model.to(torch_device)
        model.eval()

        input_values = input_values[:3]
        attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)

        input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]

        # pad input
        for i in range(len(input_lengths)):
            input_values[i, input_lengths[i] :] = 0.0
            attention_mask[i, input_lengths[i] :] = 0.0

        batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state

        for i in range(input_values.shape[0]):
            input_slice = input_values[i : i + 1, : input_lengths[i]]
            output = model(input_slice).last_hidden_state

            batch_output = batch_outputs[i : i + 1, : output.shape[1]]
            self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
Example #5
    def create_and_check_model(self, config, input_values, attention_mask):
        model = WavLMModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_values, attention_mask=attention_mask)
        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
        )
Example #6
    def test_model_from_pretrained(self):
        model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
        self.assertIsNotNone(model)
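As a follow-up to the test above, here is a short self-contained usage sketch of the loaded checkpoint; the random dummy waveform stands in for real audio and is an illustrative assumption, not part of the test suite.

import numpy as np
import torch
from transformers import Wav2Vec2FeatureExtractor, WavLMModel

model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus")

# one second of fake 16 kHz audio stands in for a real recording
dummy_speech = [np.random.randn(16000).astype(np.float32)]
inputs = feature_extractor(dummy_speech, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    hidden_states = model(**inputs).last_hidden_state  # shape: (1, num_frames, hidden_size)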