# Example no. 1 (scraped snippet; original header: "Exemplo n.º 1", score 0)
 def create_and_check_bert_for_pretraining(
     self,
     config,
     input_ids,
     token_type_ids,
     input_mask,
     sequence_labels,
     token_labels,
     choice_labels,
 ):
     """Instantiate BertForPreTraining in eval mode and check that its
     prediction-score and next-sentence outputs have the expected shapes.

     Positional call order follows pytorch_transformers'
     BertForPreTraining.forward: (input_ids, token_type_ids,
     attention_mask, masked_lm_labels, next_sentence_label).
     """
     model = BertForPreTraining(config=config)
     model.eval()
     loss, prediction_scores, seq_relationship_score = model(
         input_ids, token_type_ids, input_mask, token_labels,
         sequence_labels)
     result = dict(
         loss=loss,
         prediction_scores=prediction_scores,
         seq_relationship_score=seq_relationship_score,
     )
     # Masked-LM head: one vocab-sized score vector per token.
     expected_mlm_shape = [self.batch_size, self.seq_length, self.vocab_size]
     self.parent.assertListEqual(
         list(result["prediction_scores"].size()), expected_mlm_shape)
     # Next-sentence head: binary (is-next / not-next) per example.
     expected_nsp_shape = [self.batch_size, 2]
     self.parent.assertListEqual(
         list(result["seq_relationship_score"].size()), expected_nsp_shape)
     self.check_loss_output(result)
# Example no. 2 (scraped snippet; original header: "Exemplo n.º 2", score 0)
import numpy as np
import torch
import torch.nn as nn

from pytorch_transformers import (
    WEIGHTS_NAME, AdamW, WarmupLinearSchedule, BertConfig, BertForMaskedLM,
    BertTokenizer, BertForPreTraining, GPT2Config, GPT2LMHeadModel,
    GPT2Tokenizer, OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
    RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)

## Extract last-layer attention from a pretrained BERT model, then
## experiment with torch.gather and a relation-type embedding.
# NOTE: this script needs `torch`, `numpy as np`, and `torch.nn as nn`
# in scope (the original used them without importing them — NameError).

# Ask the model to return attentions and hidden states alongside logits.
config = BertConfig.from_pretrained('bert-base-uncased')
config.output_attentions = True
config.output_hidden_states = True

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

model = BertForPreTraining(config)
model.eval()  # disable dropout for deterministic inspection

# Two same-length sentences so they stack into one batch tensor.
input_ids1 = tokenizer.encode("Hello, my dog is cute")  # Batch size 1
input_ids2 = tokenizer.encode("Hello, my dog is one")
input_ids = torch.tensor([input_ids1, input_ids2])
outputs = model(input_ids)

# gather experiment: pick, for each word pair, the distance score indexed
# by its relation id. index values must lie in [0, 3) = size of dim 3.
word_dot_distance = torch.randn(2, 1, 4, 3)  ## 2 batch
word_word_relation = torch.LongTensor(
    np.round(np.random.uniform(size=(2, 1, 4, 4), low=0, high=2)))
out = torch.gather(word_dot_distance, dim=3, index=word_word_relation)

# Embedding table for 3 distance types (index 0 is padding / zero vector).
distance_type = nn.Embedding(3, 5, padding_idx=0)
distance_weights = distance_type.weight  # was a bare no-op expression (REPL leftover)

hidden = torch.randn(2, 3, 4, 5)  ## 2 batch, 3 heads, 4 words, vec=5