def test_inference_ctc_batched(self):
    """Compare batched CTC transcriptions of two samples with reference text.

    The model and processor are loaded from the same checkpoint. A padded
    two-sample batch is run through the model together with its attention
    mask, with gradient tracking disabled. The argmax over the last logits
    dimension is decoded by the processor and the decoded strings are
    compared with the expected transcriptions.

    NOTE(review): the visible source contains another definition with this
    same name; if both live in the same scope the later one shadows this
    one -- confirm which variant is meant to run.
    """
    # TODO: enable this test once the finetuned models are available
    checkpoint = "asapp/sew-d-tiny-100k-ft-100h"

    model = SEWDForCTC.from_pretrained(checkpoint)
    model = model.to(torch_device)
    processor = Wav2Vec2Processor.from_pretrained(checkpoint, do_lower_case=True)

    samples = self._load_datasamples(2)
    batch = processor(samples, return_tensors="pt", padding=True)

    # Move both the padded inputs and their mask to the test device.
    values = batch.input_values.to(torch_device)
    mask = batch.attention_mask.to(torch_device)

    with torch.no_grad():
        output = model(values, attention_mask=mask)
        logits = output.logits

    token_ids = logits.argmax(dim=-1)
    transcriptions = processor.batch_decode(token_ids)

    expected = [
        "a man said to the universe sir i exist",
        "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
    ]
    self.assertListEqual(transcriptions, expected)
def test_inference_ctc_batched(self):
    """Compare batched CTC transcriptions of two samples with reference text.

    The model and processor are loaded from the same checkpoint. A padded
    two-sample batch is run through the model (input values only, no
    attention mask) with gradient tracking disabled. The argmax over the
    last logits dimension is decoded by the processor and the decoded
    strings are compared with the expected transcriptions.
    """
    checkpoint = "asapp/sew-d-tiny-100k-ft-ls100h"

    model = SEWDForCTC.from_pretrained(checkpoint)
    model = model.to(torch_device)
    processor = Wav2Vec2Processor.from_pretrained(checkpoint, do_lower_case=True)

    samples = self._load_datasamples(2)
    batch = processor(samples, return_tensors="pt", padding=True)
    values = batch.input_values.to(torch_device)

    with torch.no_grad():
        output = model(values)
        logits = output.logits

    token_ids = logits.argmax(dim=-1)
    transcriptions = processor.batch_decode(token_ids)

    # The reference strings are kept verbatim, misspellings included: they
    # are the literal values the assertion compares against.
    expected = [
        "a man said to the universe sir i exist",
        "swet covered breon's body trickling into the titlowing closs that was the only garmened he war",
    ]
    self.assertListEqual(transcriptions, expected)