# Example #1
from transformers import BertGenerationTokenizer, BertGenerationDecoder, BertGenerationConfig
import torch

# Run a single forward pass of the BBC-trained BertGeneration checkpoint used as
# a stand-alone decoder and collect its prediction logits.
checkpoint = "google/bert_for_seq_generation_L-24_bbc_encoder"

tokenizer = BertGenerationTokenizer.from_pretrained(checkpoint)

# The checkpoint's config must be switched to decoder mode before the model is
# instantiated from it, so build the config first and hand it to the model.
config = BertGenerationConfig.from_pretrained(checkpoint)
config.is_decoder = True
model = BertGenerationDecoder.from_pretrained(
    checkpoint,
    config=config,
    return_dict=True)  # return a ModelOutput so `.logits` is available below

# BertGenerationDecoder.forward accepts no `token_type_ids` argument. Ask the
# tokenizer not to produce them, otherwise `model(**inputs)` can fail with an
# unexpected-keyword TypeError on tokenizer versions that emit them by default.
inputs = tokenizer("Hello, my dog is cute",
                   return_token_type_ids=False,
                   return_tensors="pt")
outputs = model(**inputs)

prediction_logits = outputs.logits
# Example #2
import torch
from transformers import BertGenerationConfig, BertGenerationEncoder, BertGenerationDecoder, BertTokenizer, BertGenerationTokenizer, \
    DistilBertModel, DistilBertForMaskedLM, DistilBertTokenizer, DistilBertConfig, \
    DataCollatorForLanguageModeling, Trainer, TrainingArguments, EncoderDecoderModel
from datasets import load_dataset

model_name = 'distilbert-base-multilingual-cased'
tokenizer_name = 'distilbert-base-multilingual-cased'

# NOTE(review): this reads a DistilBERT config.json into a BertGenerationConfig.
# DistilBERT names its size fields differently (e.g. `dim`, `n_layers`), so the
# BertGeneration fields presumably keep their defaults — confirm this is intended.
config = BertGenerationConfig.from_pretrained(model_name)
# No tokenizer is built from `tokenizer_name`: BertGenerationTokenizer is
# SentencePiece-based and cannot load this checkpoint's WordPiece vocabulary.
# The tokenizer actually used below is the bert-large-uncased one.

# Leverage checkpoints for a Bert2Bert model: encoder and decoder are both
# warm-started from "bert-large-uncased" and therefore share its tokenizer.
# The tokenizer is created first so its special-token ids can be given to the
# models. BertTokenizer is the class that matches a BERT checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

# Use BERT's [CLS] token as BOS and its [SEP] token as EOS
# (ids 101 / 102 in this vocabulary).
encoder = BertGenerationEncoder.from_pretrained(
    "bert-large-uncased",
    bos_token_id=tokenizer.cls_token_id,
    eos_token_id=tokenizer.sep_token_id,
)
# The decoder additionally gets cross-attention layers (to attend over the
# encoder output) and is flagged as a decoder; same BOS/EOS choice as above.
decoder = BertGenerationDecoder.from_pretrained(
    "bert-large-uncased",
    add_cross_attention=True,
    is_decoder=True,
    bos_token_id=tokenizer.cls_token_id,
    eos_token_id=tokenizer.sep_token_id,
)
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# The source text is encoded without [CLS]/[SEP]; the target keeps the default
# special tokens, so the labels are wrapped in [CLS] ... [SEP].
input_ids = tokenizer('This is a long article to summarize',
                      add_special_tokens=False,
                      return_tensors="pt").input_ids
labels = tokenizer('This is a short summary', return_tensors="pt").input_ids
# train...
# loss = bert2bert(input_ids=input_ids, decoder_input_ids=labels, labels=labels).loss
# loss.backward()

# NOTE(review): `config` is not passed to any model in this code, so this
# assignment has no visible effect here, and BertGenerationConfig does not
# document an `attention_type` option — confirm what consumes it.
config.attention_type = 'performer'