示例#1
0
    def masked_mlm(sentence="The quick brown fox jumps over the lazy dog.",
                   mask_positions=(19, 27, 37), mask=True, mask_token_id=None):
        """Run google/reformer-enwik8 as a masked LM on one sentence and print it.

        Prints two lines: ``sentence`` with every masked position shown as
        ``"%"``, followed by the model's arg-max prediction decoded back to text.

        Args:
            sentence: Text passed to ``encode`` and fed to the model.
            mask_positions: Positions in the encoded sequence to mask. The
                ``"%"`` display assumes one encoded position per character
                (presumably true for ASCII input -- TODO confirm against
                ``encode``).
            mask: When False, the sentence is fed unmasked. This replaces the
                former hard-coded ``if True:`` toggle.
            mask_token_id: Optional token id written over the masked input
                positions. With the default ``None``, only the attention mask
                is zeroed, so the original token id is still fed at those
                positions. The prediction there is then presumably not a real
                fill-in test.

        Raises:
            ValueError: If a mask position lies outside the encoded sequence.
        """
        from transformers import ReformerConfig, ReformerForMaskedLM
        config = ReformerConfig.from_pretrained('google/reformer-enwik8')
        # Masked LM uses bidirectional (non-causal) attention.
        config.is_decoder = False
        model = ReformerForMaskedLM.from_pretrained('google/reformer-enwik8',
                                                    config=config)

        input_ids, attention_masks = encode([sentence])
        seq_len = input_ids.shape[1]
        # Labels are always the untouched ids; masking is applied to copies.
        labels = input_ids.clone()
        model_input_ids = input_ids.clone()
        model_attention = attention_masks.clone()
        displayed = sentence
        if mask:
            for i in mask_positions:
                if not 0 <= i < seq_len:
                    raise ValueError(
                        f"mask position {i} outside encoded length {seq_len}")
                model_attention[0, i] = 0
                if mask_token_id is not None:
                    model_input_ids[0, i] = mask_token_id
                displayed = displayed[:i] + "%" + displayed[i + 1:]

        # Inference only, so skip autograd bookkeeping. Calling the module
        # (rather than .forward) lets nn.Module hooks run.
        with torch.no_grad():
            output = model(input_ids=model_input_ids,
                           attention_mask=model_attention,
                           labels=labels)
        prediction = decode(torch.argmax(output.logits, 2))[0]
        print(displayed)
        print(prediction)
示例#2
0
from transformers import ReformerConfig, ReformerForMaskedLM, ReformerTokenizer, LineByLineTextDataset,DataCollatorForLanguageModeling
from reformer_utils import encode, decode, CharTokenizer
import torch
from general_tools.utils import get_root
from pathlib import Path
ROOT = get_root("internn")

# Token id written over masked input positions. Presumably chosen as a
# placeholder character for the enwik8 byte-level vocabulary -- TODO confirm.
MSK = 44
# Positions in the encoded `sentence` that are hidden from the model.
MASK_POSITIONS = [10, 21, 26, 32, 35]
# Label value ignored by transformers' CrossEntropyLoss, so that position
# does not contribute to the MLM loss.
IGNORE_INDEX = -100

config = ReformerConfig.from_pretrained('google/reformer-enwik8')
# Masked LM uses bidirectional (non-causal) attention.
config.is_decoder = False
model = ReformerForMaskedLM.from_pretrained('google/reformer-enwik8', config=config)

sentence = "The quick brown fox jumps over the lazy dog."

input_ids, attention_masks = encode([sentence])
# Score only the masked positions, as DataCollatorForLanguageModeling does.
# Unmasked tokens are visible in the input and would dominate the loss.
label_ids = torch.full_like(input_ids, IGNORE_INDEX)
for idx in MASK_POSITIONS:
    # Record the true id before it is overwritten with the mask id.
    label_ids[0, idx] = input_ids[0, idx]
    input_ids[0, idx] = MSK
    attention_masks[0, idx] = 0

# Call the module itself (not .forward) so nn.Module hooks run.
f = model(input_ids=input_ids,
          attention_mask=attention_masks,
          labels=label_ids)

loss = f.loss
prediction = decode(torch.argmax(f.logits, 2))[0]
print(prediction)