def test_online_tokenizer_config(self):
    """Fetch a hub-hosted tokenizer via its tokenizer_config.json and
    sanity-check the configured language pair and vocab sizes.

    Uses a tiny checkpoint, so it is fast enough for the normal CI run.
    """
    tokenizer = FSMTTokenizer.from_pretrained(FSMT_TINY2)
    langs = [tokenizer.src_lang, tokenizer.tgt_lang]
    self.assertListEqual(langs, ["en", "ru"])
    # both sides of the tiny model share the same vocab size
    self.assertEqual(tokenizer.src_vocab_size, 21)
    self.assertEqual(tokenizer.tgt_vocab_size, 21)
def test_tokenizer_lower(self):
    """Check that ``do_lower_case=True`` lowercases input before BPE splitting."""
    tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-ru-en", do_lower_case=True)
    text = "USA is United States of America"
    expected_tokens = [
        "us",
        "a</w>",
        "is</w>",
        "un",
        "i",
        "ted</w>",
        "st",
        "ates</w>",
        "of</w>",
        "am",
        "er",
        "ica</w>",
    ]
    self.assertListEqual(tokenizer.tokenize(text), expected_tokens)
def translate(src, tgt, text, num_beams=5):
    """Translate ``text`` with the FSMT wmt19 model for the ``src``->``tgt`` pair.

    Args:
        src: source language code, e.g. ``"ru"``.
        tgt: target language code, e.g. ``"en"``.
        text: the sentence to translate.
        num_beams: beam-search width forwarded to ``generate`` (default 5,
            matching the previous hard-coded value).

    Returns:
        The decoded translation string with special tokens stripped.
    """
    # s3-uploaded model; to test a local checkpoint, point mname at its path instead
    mname = f"stas/wmt19-{src}-{tgt}"
    tokenizer = FSMTTokenizer.from_pretrained(mname)
    model = FSMTForConditionalGeneration.from_pretrained(mname)
    encoded = tokenizer.encode(text, return_tensors="pt")
    # [0] takes the single (best) beam result for the single input sentence
    output = model.generate(encoded, num_beams=num_beams, early_stopping=True)[0]
    return tokenizer.decode(output, skip_special_tokens=True)
def tokenizer_en_ru(self):
    """Return a tokenizer for the facebook wmt19 en->ru checkpoint."""
    mname = "facebook/wmt19-en-ru"
    return FSMTTokenizer.from_pretrained(mname)
def tokenizer_ru_en(self):
    """Return a tokenizer for the facebook wmt19 ru->en checkpoint."""
    mname = "facebook/wmt19-ru-en"
    return FSMTTokenizer.from_pretrained(mname)
if k in d1 and k in d2: if not cmp_func(d1[k], d2[k]): ok = 0 print(f"! Key {k} mismatches:") if d1[k].shape != d2[k].shape: print(f"- Shapes: \n{d1[k].shape}\n{d2[k].shape}") print(f"- Values:\n{d1[k]}\n{d2[k]}\n") else: ok = 0 which = "1st" if k in d2 else "2nd" print(f"{which} dict doesn't have key {k}\n") if ok: print('Models match') tokenizer = FSMTTokenizer.from_pretrained(mname) model = FSMTForConditionalGeneration.from_pretrained(mname) # this fixes the problem import torch d2 = torch.load("/tmp/new.pt") compare_state_dicts(model.state_dict(), d2) #model.load_state_dict(d2) #model.load_state_dict(torch.load("/tmp/new.pt")) print("Wrong shape?", model.state_dict()['model.decoder.embed_tokens.weight'].shape) sentence = "Машинное обучение - это здорово! Ты молодец." input_ids = tokenizer.encode(sentence, return_tensors='pt') print(input_ids)
#!/usr/bin/env python
# coding: utf-8

# Decode a hard-coded batch of generated token ids back into text.

import sys

# make the in-progress fair-wmt port importable ahead of any installed version
sys.path.insert(0, "/code/huggingface/transformers-fair-wmt/src")
from transformers.tokenization_fsmt import FSMTTokenizer

tokenizer = FSMTTokenizer.from_pretrained("stas/wmt19-ru-en")

# a batch of one sequence; 2 is the BOS/EOS id bracketing the sentence
output_ids = [[2, 5494, 3221, 21, 1054, 427, 739, 4952, 11, 700, 18128, 7, 2]]
text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(text)