def get_model(self, mname):
    """Return the FSMT model for ``mname``, loading and caching it on first use.

    The model is moved to ``torch_device``; on CUDA it is converted to
    fp16 in place to cut memory use.
    """
    if mname not in self.models_cache:
        model = FSMTForConditionalGeneration.from_pretrained(mname)
        self.models_cache[mname] = model.to(torch_device)
        if torch_device == "cuda":
            self.models_cache[mname].half()
    return self.models_cache[mname]
# Translation pairs to exercise, each direction listed explicitly.
pairs = [
    ["en", "ru"],
    ["ru", "en"],
    ["en", "de"],
    ["de", "en"],
]

for src, tgt in pairs:
    print(f"Testing {src} -> {tgt}")

    # To switch to a local model instead of the s3-uploaded one:
    # mname = "/code/huggingface/transformers-fair-wmt/data/wmt19-{src}-{tgt}"
    mname = f"stas/wmt19-{src}-{tgt}"

    src_sentence = text[src]
    tgt_sentence = text[tgt]

    tokenizer = FSMTTokenizer.from_pretrained(mname)
    model = FSMTForConditionalGeneration.from_pretrained(mname)

    # Round-trip: encode the source, generate, decode, and compare against
    # the expected target sentence.
    encoded = tokenizer.encode(src_sentence, return_tensors='pt')
    outputs = model.generate(encoded)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    assert decoded == tgt_sentence, f"\n\ngot: {decoded}\nexp: {tgt_sentence}\n"
def get_model(self, mname):
    """Load the FSMT model ``mname`` onto ``torch_device`` and return it.

    On CUDA the model is converted to fp16 in place. No caching is done;
    every call loads a fresh model.
    """
    model = FSMTForConditionalGeneration.from_pretrained(mname)
    model = model.to(torch_device)
    if torch_device == "cuda":
        model.half()
    return model