def test_inference_pretrained_batched(self):
    """Check batched hidden states of the pretrained checkpoint against reference values.

    Two padded audio samples are run through the model; the leading and
    trailing 4x4 corners of the last hidden state and its overall sum are
    compared with stored reference numbers.
    """
    checkpoint = "asapp/sew-tiny-100k"
    model = SEWModel.from_pretrained(checkpoint).to(torch_device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)

    speech_samples = self._load_datasamples(2)
    encoded = feature_extractor(speech_samples, return_tensors="pt", padding=True)
    input_values = encoded.input_values.to(torch_device)

    with torch.no_grad():
        outputs = model(input_values).last_hidden_state

    # expected outputs taken from the original SEW implementation
    expected_outputs_first = torch.tensor(
        [
            [
                [0.1509, 0.5372, 0.3061, -0.1694],
                [-0.1700, 0.5764, 0.2753, -0.1299],
                [0.1281, 0.7949, 0.2342, -0.1624],
                [-0.1627, 0.6710, 0.2215, -0.1317],
            ],
            [
                [0.0408, 1.4355, 0.8605, -0.0968],
                [0.0393, 1.2368, 0.6826, 0.0364],
                [-0.1269, 1.9215, 1.1677, -0.1297],
                [-0.1654, 1.6524, 0.6877, -0.0196],
            ],
        ],
        device=torch_device,
    )
    expected_outputs_last = torch.tensor(
        [
            [
                [1.3379, -0.1450, -0.1500, -0.0515],
                [0.8364, -0.1680, -0.1248, -0.0689],
                [1.2791, -0.1507, -0.1523, -0.0564],
                [0.8208, -0.1690, -0.1199, -0.0751],
            ],
            [
                [0.6959, -0.0861, -0.1235, -0.0861],
                [0.4700, -0.1686, -0.1141, -0.1199],
                [1.0776, -0.1137, -0.0124, -0.0472],
                [0.5774, -0.1675, -0.0376, -0.0823],
            ],
        ],
        device=torch_device,
    )
    expected_output_sum = 62146.7422

    # first four frames / first four features of every sample
    leading_block = outputs[:, :4, :4]
    self.assertTrue(torch.allclose(leading_block, expected_outputs_first, atol=5e-3))

    # last four frames / last four features of every sample
    trailing_block = outputs[:, -4:, -4:]
    self.assertTrue(torch.allclose(trailing_block, expected_outputs_last, atol=5e-3))

    sum_gap = abs(outputs.sum() - expected_output_sum)
    self.assertTrue(sum_gap < 5)
def create_and_check_batch_inference(self, config, input_values, *args):
    """Check that padded batch inference matches per-sample inference.

    The first three rows of ``input_values`` are zero-padded (in place) to
    1/4, 1/2 and the full input length and run as one batch together with an
    attention mask. Each row is then run on its own, unpadded, and its hidden
    states must match the corresponding batch output within ``atol=1e-3``.
    """
    # test does not pass for models making use of `group_norm`
    # check: https://github.com/pytorch/fairseq/issues/3227
    model = SEWModel(config=config)
    model.to(torch_device)
    model.eval()

    input_values = input_values[:3]
    attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)

    full_length = input_values.shape[-1]
    input_lengths = [full_length // divisor for divisor in (4, 2, 1)]

    # pad input
    for row, valid_length in enumerate(input_lengths):
        input_values[row, valid_length:] = 0.0
        attention_mask[row, valid_length:] = 0.0

    batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state

    for row in range(input_values.shape[0]):
        valid_length = input_lengths[row]
        single_input = input_values[row : row + 1, :valid_length]
        single_output = model(single_input).last_hidden_state

        # only the frames produced for the unpadded input are comparable
        matching_batch_output = batch_outputs[row : row + 1, : single_output.shape[1]]
        self.parent.assertTrue(torch.allclose(single_output, matching_batch_output, atol=1e-3))
def create_and_check_model(self, config, input_values, attention_mask):
    """Run one forward pass and verify the shape of the last hidden state.

    The expected shape is ``(batch_size, output_seq_length, hidden_size)`` as
    stored on this tester object.
    """
    model = SEWModel(config=config)
    model.to(torch_device)
    model.eval()

    result = model(input_values, attention_mask=attention_mask)

    expected_shape = (self.batch_size, self.output_seq_length, self.hidden_size)
    self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)
def test_model_from_pretrained(self):
    """Loading the ``asapp/sew-tiny-100k`` checkpoint must yield a model object."""
    checkpoint = "asapp/sew-tiny-100k"
    self.assertIsNotNone(SEWModel.from_pretrained(checkpoint))
def convert_sew_checkpoint(
    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
    """
    Copy/paste/tweak model's weights to transformers design.

    Args:
        checkpoint_path: Path of the fairseq checkpoint to load.
        pytorch_dump_folder_path: Folder that receives the converted model and
            its preprocessing files. It is created when it does not exist yet.
        config_path: Optional location handed to `SEWConfig.from_pretrained`.
            When omitted, the config is derived from the fairseq model with
            `convert_config`.
        dict_path: Optional path of the fairseq dictionary file. When given for
            a fine-tuned model, its parent folder is passed to fairseq as the
            `data` override and a CTC tokenizer/processor is written next to
            the model.
        is_finetuned: If true, a `SEWForCTC` model is built, otherwise a plain
            `SEWModel`.

    Returns:
        None. If a tokenizer has to be written and `pytorch_dump_folder_path`
        exists but is not a directory, an error is logged and the function
        returns early without saving anything.
    """
    if is_finetuned and dict_path is not None:
        # Point fairseq's `data` override at the folder holding the dictionary file.
        data_dir = "/".join(dict_path.split("/")[:-1])
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path], arg_overrides={"data": data_dir}
        )
    else:
        # Without a dictionary path there is nothing to override; previously a
        # fine-tuned conversion with `dict_path=None` crashed on `None.split`.
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])

    if config_path is not None:
        config = SEWConfig.from_pretrained(config_path)
    else:
        config = convert_config(model[0], is_finetuned)
    model = model[0].eval()

    return_attention_mask = config.feat_extract_norm == "layer"
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0,
        do_normalize=True,
        return_attention_mask=return_attention_mask,
    )

    if is_finetuned:
        if dict_path:
            target_dict = Dictionary.load(dict_path)

            # important change bos & pad token id since CTC symbol is <pad> and
            # not <s> as in fairseq
            target_dict.indices[target_dict.bos_word] = target_dict.pad_index
            target_dict.indices[target_dict.pad_word] = target_dict.bos_index
            config.bos_token_id = target_dict.pad_index
            config.pad_token_id = target_dict.bos_index
            config.eos_token_id = target_dict.eos_index
            config.vocab_size = len(target_dict.symbols)
            vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")

            # Only an existing non-directory path is an error. A missing folder
            # is created below (the old `not isdir` guard rejected it, which
            # made the `makedirs` call unreachable for new folders).
            if os.path.exists(pytorch_dump_folder_path) and not os.path.isdir(pytorch_dump_folder_path):
                logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
                return
            os.makedirs(pytorch_dump_folder_path, exist_ok=True)

            with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
                json.dump(target_dict.indices, vocab_handle)

            tokenizer = Wav2Vec2CTCTokenizer(
                vocab_path,
                unk_token=target_dict.unk_word,
                pad_token=target_dict.pad_word,
                bos_token=target_dict.bos_word,
                eos_token=target_dict.eos_word,
                word_delimiter_token="|",
                do_lower_case=False,
            )
            processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
            processor.save_pretrained(pytorch_dump_folder_path)

        hf_model = SEWForCTC(config)
    else:
        hf_model = SEWModel(config)
        feature_extractor.save_pretrained(pytorch_dump_folder_path)

    recursively_load_weights(model, hf_model, is_finetuned)

    hf_model.save_pretrained(pytorch_dump_folder_path)