Example #1
def get_model(self, mname):
    # `torch_device` and `self.models_cache` are assumed to be set up by the
    # surrounding test class; each checkpoint is loaded at most once and reused
    if mname not in self.models_cache:
        model = FSMTForConditionalGeneration.from_pretrained(mname).to(torch_device)
        if torch_device == "cuda":
            # switch to fp16 on GPU to halve memory usage
            model.half()
        self.models_cache[mname] = model
    return self.models_cache[mname]
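For context, a hedged sketch of how this cached getter might be wired up and exercised. The class name `FSMTModelTester` and the test scaffolding are assumptions; `models_cache`, `torch_device`, and the model name mirror names that appear in these examples:

import torch
from transformers import FSMTForConditionalGeneration

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

class FSMTModelTester:
    def __init__(self):
        # maps model name -> loaded model, so each checkpoint loads at most once
        self.models_cache = {}

    def get_model(self, mname):
        # same logic as the snippet above
        if mname not in self.models_cache:
            model = FSMTForConditionalGeneration.from_pretrained(mname).to(torch_device)
            if torch_device == "cuda":
                model.half()
            self.models_cache[mname] = model
        return self.models_cache[mname]

tester = FSMTModelTester()
first = tester.get_model("stas/wmt19-en-ru")
second = tester.get_model("stas/wmt19-en-ru")
assert first is second  # the second call is a cache hit, not a reload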
Example #2
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

# `text` is assumed to be defined elsewhere: a dict mapping each language
# code used below ("en", "ru", "de") to the same sentence in that language
pairs = [
    ["en", "ru"],
    ["ru", "en"],
    ["en", "de"],
    ["de", "en"],
]

for src, tgt in pairs:
    print(f"Testing {src} -> {tgt}")

    # to switch to a local model:
    # mname = f"/code/huggingface/transformers-fair-wmt/data/wmt19-{src}-{tgt}"
    # S3-uploaded model:
    mname = f"stas/wmt19-{src}-{tgt}"

    src_sentence = text[src]
    tgt_sentence = text[tgt]

    tokenizer = FSMTTokenizer.from_pretrained(mname)
    model = FSMTForConditionalGeneration.from_pretrained(mname)

    encoded = tokenizer.encode(src_sentence, return_tensors="pt")
    # print(encoded)

    outputs = model.generate(encoded)
    # print(outputs)

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # print(decoded)
    assert decoded == tgt_sentence, f"\n\ngot: {decoded}\nexp: {tgt_sentence}\n"
Example #3
def get_model(self, mname):
    # non-caching variant: loads the checkpoint fresh on every call
    model = FSMTForConditionalGeneration.from_pretrained(mname).to(torch_device)
    if torch_device == "cuda":
        # switch to fp16 on GPU to halve memory usage
        model.half()
    return model
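The two `get_model` variants differ only in caching: Example #1 stores each loaded checkpoint in `models_cache` so repeated requests are nearly free, while Example #3 reloads from disk (or re-downloads) on every call. A minimal sketch of that trade-off, reusing the hypothetical `FSMTModelTester` from the sketch under Example #1:

import time

mname = "stas/wmt19-en-ru"
tester = FSMTModelTester()  # hypothetical helper defined under Example #1

start = time.time()
tester.get_model(mname)  # cold call: loads (and possibly downloads) the checkpoint
print(f"cold load: {time.time() - start:.1f}s")

start = time.time()
tester.get_model(mname)  # warm call: served straight from models_cache
print(f"cached: {time.time() - start:.4f}s")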