コード例 #1
0
 def test_online_tokenizer_config(self):
     """Check that a tokenizer loaded by name from ``FSMT_TINY2`` is configured as expected.

     Per the original note, this exercises fetching the online tokenizer
     files and loading them through ``tokenizer_config.json``; it is not
     slow, so it runs as part of normal CI.
     """
     tok = FSMTTokenizer.from_pretrained(FSMT_TINY2)

     # Language pair picked up from the loaded configuration.
     language_pair = [tok.src_lang, tok.tgt_lang]
     self.assertListEqual(language_pair, ["en", "ru"])

     # Both vocabularies are expected to hold 21 entries.
     self.assertEqual(tok.src_vocab_size, 21)
     self.assertEqual(tok.tgt_vocab_size, 21)
コード例 #2
0
 def test_tokenizer_lower(self):
     """Check tokenization of a mixed-case sentence with ``do_lower_case=True``.

     The expected sub-word tokens are all lowercase even though the input
     sentence contains uppercase letters.
     """
     expected = [
         "us",
         "a</w>",
         "is</w>",
         "un",
         "i",
         "ted</w>",
         "st",
         "ates</w>",
         "of</w>",
         "am",
         "er",
         "ica</w>",
     ]
     lowercasing_tokenizer = FSMTTokenizer.from_pretrained(
         "facebook/wmt19-ru-en", do_lower_case=True
     )
     self.assertListEqual(
         lowercasing_tokenizer.tokenize("USA is United States of America"),
         expected,
     )
コード例 #3
0
ファイル: fsmt-paraphrase.py プロジェクト: stjordanis/porting
def translate(src, tgt, text):
    """Run ``text`` through the ``stas/wmt19-{src}-{tgt}`` FSMT checkpoint.

    The tokenizer and the model are loaded by name on every call. The text
    is encoded to PyTorch tensors, a sequence is generated with beam search
    (5 beams, early stopping), and the first returned sequence is decoded
    back to a string with special tokens skipped.

    Args:
        src: First component of the checkpoint name (presumably the source
            language code, e.g. "en").
        tgt: Second component of the checkpoint name (presumably the target
            language code, e.g. "ru").
        text: The input text to encode and feed to the model.

    Returns:
        The decoded string for the first generated sequence.
    """
    # Uploaded model, referenced by name. To switch to a local model, use a
    # path such as
    # "/code/huggingface/transformers-fair-wmt/data/wmt19-{src}-{tgt}" instead.
    checkpoint = f"stas/wmt19-{src}-{tgt}"
    tokenizer = FSMTTokenizer.from_pretrained(checkpoint)
    model = FSMTForConditionalGeneration.from_pretrained(checkpoint)

    input_ids = tokenizer.encode(text, return_tensors='pt')
    generated = model.generate(input_ids, num_beams=5, early_stopping=True)

    # Only the first generated sequence is returned.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
コード例 #4
0
 def tokenizer_en_ru(self):
     """Return an ``FSMTTokenizer`` loaded from the ``facebook/wmt19-en-ru`` checkpoint."""
     checkpoint = "facebook/wmt19-en-ru"
     return FSMTTokenizer.from_pretrained(checkpoint)
コード例 #5
0
 def tokenizer_ru_en(self):
     """Return an ``FSMTTokenizer`` loaded from the ``facebook/wmt19-ru-en`` checkpoint."""
     checkpoint = "facebook/wmt19-ru-en"
     return FSMTTokenizer.from_pretrained(checkpoint)
コード例 #6
0
        if k in d1 and k in d2:
            if not cmp_func(d1[k], d2[k]):
                ok = 0
                print(f"! Key {k} mismatches:")
                if d1[k].shape != d2[k].shape:
                    print(f"- Shapes: \n{d1[k].shape}\n{d2[k].shape}")
                print(f"- Values:\n{d1[k]}\n{d2[k]}\n")
        else:
            ok = 0
            which = "1st" if k in d2 else "2nd"
            print(f"{which} dict doesn't have key {k}\n")
    if ok:
        print('Models match')


# Load the tokenizer and the model from the same checkpoint name
# (`mname` is defined earlier in the file, outside this excerpt).
tokenizer = FSMTTokenizer.from_pretrained(mname)
model = FSMTForConditionalGeneration.from_pretrained(mname)

# this fixes the problem
# NOTE(review): the remark above presumably refers to the disabled
# `load_state_dict` calls below -- confirm. As written, the script only
# compares the loaded model's state dict with the one stored in /tmp/new.pt.
# NOTE(review): torch.load unpickles the file; only use it on trusted files.
import torch
d2 = torch.load("/tmp/new.pt")
compare_state_dicts(model.state_dict(), d2)
# Disabled: overwrite the model weights with the reference state dict.
#model.load_state_dict(d2)
#model.load_state_dict(torch.load("/tmp/new.pt"))

# Print the shape of the decoder embedding weight so it can be checked by eye.
print("Wrong shape?",
      model.state_dict()['model.decoder.embed_tokens.weight'].shape)
sentence = "Машинное обучение - это здорово! Ты молодец."

# Encode the sample sentence to PyTorch tensors and show the resulting ids.
input_ids = tokenizer.encode(sentence, return_tensors='pt')
print(input_ids)
コード例 #7
0
ファイル: fsmt-decode.py プロジェクト: stjordanis/porting
#!/usr/bin/env python
# coding: utf-8

# this script just does a decode of outputs codes

import sys

# Put the local transformers checkout first on the import path so it is
# picked up ahead of any other copy.
sys.path.insert(0, "/code/huggingface/transformers-fair-wmt/src")

from transformers.tokenization_fsmt import FSMTTokenizer

tokenizer = FSMTTokenizer.from_pretrained('stas/wmt19-ru-en')

# Output codes to decode; only the first row is used.
outputs = [
    [2, 5494, 3221, 21, 1054, 427, 739, 4952, 11, 700, 18128, 7, 2],
]

# Decode the ids back to text, skipping special tokens, and print the result.
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)