def get_test_pipeline(self, model, tokenizer, feature_extractor):
    """Build a TranslationPipeline plus sample inputs for the shared pipeline tests.

    MBart models get explicit ``src_lang``/``tgt_lang`` arguments, taken from the
    first two entries of ``tokenizer.lang_code_to_id``. Every other model is
    wrapped with just the model and tokenizer. ``feature_extractor`` is accepted
    for interface compatibility but is not used here.

    Returns:
        A ``(translator, examples)`` tuple, where ``examples`` is a list of two
        input strings.
    """
    pipeline_kwargs = {"model": model, "tokenizer": tokenizer}
    if isinstance(model.config, MBartConfig):
        # Pick the first two known language codes as source/target.
        src_lang, tgt_lang = list(tokenizer.lang_code_to_id.keys())[:2]
        pipeline_kwargs.update(src_lang=src_lang, tgt_lang=tgt_lang)
    translator = TranslationPipeline(**pipeline_kwargs)
    examples = ["Some string", "Some other text"]
    return translator, examples
コード例 #2
0
 def test_pipeline(self):
     """Translate ``self.src_text`` with the TensorFlow pipeline and compare to ``self.expected_text``."""
     translator = TranslationPipeline(self.model, self.tokenizer, framework="tf")
     results = translator(self.src_text)
     translations = [item["translation_text"] for item in results]
     self.assertEqual(self.expected_text, translations)
コード例 #3
0
import functools


@functools.lru_cache(maxsize=1)
def _load_en_fr_translator():
    """Build the t5-small English-to-French TranslationPipeline once and cache it.

    Loading the tokenizer and model weights is expensive. The pipeline is
    therefore constructed on first use and reused by every later
    ``translate_model`` call instead of being rebuilt each time.
    """
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    # NOTE(review): AutoModelWithLMHead is deprecated in newer transformers
    # releases in favour of AutoModelForSeq2SeqLM for encoder-decoder models
    # such as T5; kept here until the pinned transformers version is confirmed.
    model = AutoModelWithLMHead.from_pretrained("t5-small")
    # Pipeline device index: 0 selects the first CUDA device, -1 runs on CPU.
    device = 0 if torch.cuda.is_available() else -1
    return TranslationPipeline(model, tokenizer, task="translation_en_to_fr", device=device)


def translate_model(text: str) -> str:
    """Translate English *text* to French using a cached t5-small pipeline.

    Args:
        text: A single input string.

    Returns:
        The ``"translation_text"`` field of the pipeline's first result.
    """
    translator = _load_en_fr_translator()
    return translator(text)[0]["translation_text"]
コード例 #4
0
 def test_pipeline(self):
     """Translate ``self.src_text`` with the PyTorch pipeline and compare to ``self.expected_text``.

     The pipeline runs on device 0 when ``torch_device`` is ``"cuda"``, otherwise on -1 (CPU).
     """
     on_cuda = torch_device == "cuda"
     translator = TranslationPipeline(
         self.model,
         self.tokenizer,
         framework="pt",
         device=0 if on_cuda else -1,
     )
     translations = [entry["translation_text"] for entry in translator(self.src_text)]
     self.assertEqual(self.expected_text, translations)
コード例 #5
0
 def run_pipeline_test(self, model, tokenizer, feature_extractor):
     """Check that translating one string yields a single ``translation_text`` string.

     If the first attempt raises ``ValueError``, the call is retried with the
     first two language codes from the tokenizer's ``lang_code_to_id`` as
     ``src_lang``/``tgt_lang``. ``feature_extractor`` is unused.
     """
     translator = TranslationPipeline(model=model, tokenizer=tokenizer)
     sample = "Some string"
     try:
         outputs = translator(sample)
     except ValueError:
         # Raised by m2m (multilingual) models, which presumably need explicit
         # source/target languages; retry with the first two known codes.
         first_codes = list(translator.tokenizer.lang_code_to_id.keys())[:2]
         src_lang, tgt_lang = first_codes
         outputs = translator(sample, src_lang=src_lang, tgt_lang=tgt_lang)
     self.assertEqual(outputs, [{"translation_text": ANY(str)}])