def masked_mlm():
    from transformers import ReformerConfig, ReformerForMaskedLM
    from reformer_utils import encode, decode
    import torch

    # Load the pretrained character-level Reformer and run it as an encoder
    # (is_decoder=False disables the causal mask) for masked language modelling.
    config = ReformerConfig.from_pretrained('google/reformer-enwik8')
    config.is_decoder = False
    model = ReformerForMaskedLM.from_pretrained('google/reformer-enwik8', config=config)

    sentence = sentence2 = "The quick brown fox jumps over the lazy dog."
    input_ids, attention_masks = encode([sentence])

    if True:
        # Hide a few character positions through the attention mask only;
        # the input ids themselves are left intact and reused as the labels.
        _input_ids, a = input_ids.clone(), attention_masks.clone()
        for i in [19, 27, 37]:
            a[0, i] = 0
            sentence2 = sentence2[:i] + "%" + sentence2[i + 1:]
    else:
        _input_ids, a = input_ids, attention_masks

    f = model(input_ids=_input_ids,
              position_ids=None,
              attention_mask=a,
              head_mask=None,
              inputs_embeds=None,
              num_hashes=None,
              labels=_input_ids)
    prediction = decode(torch.argmax(f.logits, 2))[0]
    print(sentence2)   # sentence with masked characters shown as '%'
    print(prediction)  # model's reconstruction of the character sequence
from transformers import (ReformerConfig, ReformerForMaskedLM, ReformerTokenizer,
                          LineByLineTextDataset, DataCollatorForLanguageModeling)
from reformer_utils import encode, decode, CharTokenizer
import torch
from general_tools.utils import get_root
from pathlib import Path

ROOT = get_root("internn")
MSK = 44  # id written into the input at masked character positions

config = ReformerConfig.from_pretrained('google/reformer-enwik8')
config.is_decoder = False
model = ReformerForMaskedLM.from_pretrained('google/reformer-enwik8', config=config)

sentence = "The quick brown fox jumps over the lazy dog."
input_ids, attention_masks = encode([sentence])
label_ids, _ = encode([sentence])  # unmasked copy of the ids, used as the targets

# Mask a handful of character positions: overwrite the input id with MSK
# and zero out the corresponding attention-mask entry.
for idx in [10, 21, 26, 32, 35]:
    input_ids[0, idx] = MSK
    attention_masks[0, idx] = 0

f = model(input_ids=input_ids,
          position_ids=None,
          attention_mask=attention_masks,
          head_mask=None,
          inputs_embeds=None,
          num_hashes=None,
          labels=label_ids)
loss = f.loss
prediction = decode(torch.argmax(f.logits, 2))[0]
print(prediction)
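
# Optional entry point, a minimal sketch assuming this file is run directly as a
# script: it simply calls the masked_mlm() demo defined above (the module-level
# example runs regardless when the file is executed).
if __name__ == "__main__":
    masked_mlm()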