def setup(self):
    """Per-process (every GPU) setup: load data, build tokenizer, split train/val.

    Only the WMT14 de-en validation split is downloaded; the "train" set used
    here is carved out of that validation split via ``random_split``.
    """
    self.dataset = load_dataset("wmt14", "de-en", "val")
    self.tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-de-en")
    self.vocab_size = self.tokenizer.vocab_size

    # Decide how many validation examples to keep for validation vs. training.
    num_examples = len(self.dataset["validation"])
    num_val = int(num_examples * self.val_fraction)
    num_train = num_examples - num_val

    # Tokenize the splits we actually evaluate on.
    for split_name in ("validation", "test"):
        self.dataset[split_name] = self.dataset[split_name].map(self.tokenize)

    # Re-partition the (already tokenized) validation data into train + val.
    self.dataset["train"], self.dataset["validation"] = random_split(
        self.dataset["validation"], [num_train, num_val]
    )
def get_tokenizer(self, mname):
    """Return the tokenizer for checkpoint *mname*, creating it on first use.

    Results are memoized in ``self.tokenizers_cache`` so repeated calls with
    the same model name reuse a single tokenizer instance.
    """
    try:
        return self.tokenizers_cache[mname]
    except KeyError:
        tokenizer = FSMTTokenizer.from_pretrained(mname)
        self.tokenizers_cache[mname] = tokenizer
        return tokenizer
# Builds a super-tiny FSMT model that is useful inside tests, where we only
# want to check that the machinery works, not the quality of the outcomes.
#
# The tiny model is derived by shrinking a normal pre-trained checkpoint while
# keeping the full vocab and merges files, so the result is comparatively
# large for a "tiny" model: ~3MB in total across all files, dominated by the
# vocab size.
#
# For a model roughly 50x smaller (at the cost of a slightly more complicated
# build) see `fsmt-make-super-tiny-model.py`.
#
# The resulting model is then used as "stas/tiny-wmt19-en-de".

# Build
from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTTokenizer

mname = "facebook/wmt19-en-de"
tokenizer = FSMTTokenizer.from_pretrained(mname)

# Start from the real checkpoint's config so vocab sizes etc. match the
# master model, then shrink every architectural dimension to the minimum.
config = FSMTConfig.from_pretrained(mname)
config.update({
    "d_model": 4,
    "encoder_layers": 1,
    "decoder_layers": 1,
    "encoder_ffn_dim": 4,
    "decoder_ffn_dim": 4,
    "encoder_attention_heads": 1,
    "decoder_attention_heads": 1,
})

tiny_model = FSMTForConditionalGeneration(config)
print(f"num of params {tiny_model.num_parameters()}")

# Test: a single forward pass to make sure the shrunken model actually runs.
batch = tokenizer(["Making tiny model"], return_tensors="pt")
outputs = tiny_model(**batch)
def get_tokenizer(self, mname):
    """Build and return a fresh FSMT tokenizer for checkpoint *mname*.

    Unlike a caching variant, every call constructs a new tokenizer instance.
    """
    tokenizer = FSMTTokenizer.from_pretrained(mname)
    return tokenizer
def prepare_data(self):
    """Download-once hook (runs on a single process/GPU).

    Warms the local caches for the tokenizer and the WMT14 de-en validation
    data; the return values are deliberately discarded — per-process loading
    happens later in ``setup``.
    """
    FSMTTokenizer.from_pretrained("facebook/wmt19-de-en")
    load_dataset("wmt14", "de-en", "val")