def get_test_pipeline(self, model, tokenizer, feature_extractor):
    """Construct a TranslationPipeline plus sample inputs for the test harness.

    MBart-family models require explicit source/target language codes, so
    for them the first two codes known to the tokenizer are supplied up
    front; any other model gets a plain pipeline.

    Returns:
        A ``(pipeline, examples)`` tuple, where ``examples`` is a small
        list of strings to feed through the pipeline.
    """
    pipeline_kwargs = {"model": model, "tokenizer": tokenizer}
    if isinstance(model.config, MBartConfig):
        # MBart refuses to run without language codes; any two valid
        # codes from the tokenizer's table will do for a smoke test.
        lang_codes = list(tokenizer.lang_code_to_id.keys())
        pipeline_kwargs["src_lang"] = lang_codes[0]
        pipeline_kwargs["tgt_lang"] = lang_codes[1]
    translator = TranslationPipeline(**pipeline_kwargs)
    return translator, ["Some string", "Some other text"]
def test_pipeline(self):
    """Smoke-test the TensorFlow translation pipeline.

    Translates ``self.src_text`` and checks the extracted translation
    strings against ``self.expected_text``.
    """
    translator = TranslationPipeline(self.model, self.tokenizer, framework="tf")
    results = translator(self.src_text)
    translations = [result["translation_text"] for result in results]
    self.assertEqual(self.expected_text, translations)
def translate_model(text: str):
    """Translate English ``text`` to French using the t5-small checkpoint.

    Loads the tokenizer and model fresh on every call, builds an
    en→fr TranslationPipeline (on GPU 0 when CUDA is available,
    otherwise CPU), and returns the first translation string.

    NOTE(review): ``AutoModelWithLMHead`` is deprecated in recent
    transformers releases in favour of ``AutoModelForSeq2SeqLM`` —
    confirm the file's imports before switching.
    """
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelWithLMHead.from_pretrained("t5-small")
    # Pipeline device convention: non-negative GPU index, or -1 for CPU.
    device_index = 0 if torch.cuda.is_available() else -1
    translator = TranslationPipeline(
        model,
        tokenizer,
        task="translation_en_to_fr",
        device=device_index,
    )
    return translator(text)[0]["translation_text"]
def test_pipeline(self):
    """Smoke-test the PyTorch translation pipeline.

    Runs on GPU 0 when the test environment reports CUDA, otherwise on
    CPU (device -1), then compares the produced translations with
    ``self.expected_text``.
    """
    # Pipeline device convention: non-negative GPU index, -1 means CPU.
    device = 0 if torch_device == "cuda" else -1
    translator = TranslationPipeline(
        self.model, self.tokenizer, framework="pt", device=device
    )
    results = translator(self.src_text)
    self.assertEqual(
        self.expected_text, [result["translation_text"] for result in results]
    )
def run_pipeline_test(self, model, tokenizer, feature_extractor):
    """Translate a minimal input and verify the output structure.

    Some multilingual models (e.g. M2M100-style) raise ``ValueError``
    when no language codes are given; in that case the call is retried
    with the first two codes the tokenizer knows about.
    """
    translator = TranslationPipeline(model=model, tokenizer=tokenizer)
    try:
        outputs = translator("Some string")
    except ValueError:
        # Raised by multilingual models that require explicit languages;
        # retry with any two valid codes from the tokenizer's table.
        src_lang, tgt_lang = list(translator.tokenizer.lang_code_to_id.keys())[:2]
        outputs = translator("Some string", src_lang=src_lang, tgt_lang=tgt_lang)
    self.assertEqual(outputs, [{"translation_text": ANY(str)}])