class BertPretrainingLoss(BertPreTrainedModel):
    def __init__(self, bert_encoder, config):
        super(BertPretrainingLoss, self).__init__(config)
        self.bert = bert_encoder
        # Prediction heads are tied to the encoder's word-embedding weights.
        self.cls = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight)
        self.cls.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                masked_lm_labels=None,
                next_sentence_label=None):
        sequence_output, pooled_output = self.bert(
            input_ids,
            token_type_ids,
            attention_mask,
            output_all_encoded_layers=False)
        prediction_scores, seq_relationship_score = self.cls(
            sequence_output, pooled_output)

        if masked_lm_labels is not None and next_sentence_label is not None:
            # Positions labelled -1 are ignored by both losses.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2),
                                          next_sentence_label.view(-1))
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                masked_lm_labels.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
            return total_loss
        else:
            return prediction_scores, seq_relationship_score
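
A minimal usage sketch for the wrapper above, assuming the pytorch_pretrained_bert-style BertConfig/BertModel seen in the other snippets; the import path and the toy tensors are illustrative, not part of the original code.

import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

# Build an encoder, wrap it with the pre-training heads, and compute the
# combined masked-LM + next-sentence loss on a toy batch.
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12,
                    intermediate_size=3072)
encoder = BertModel(config)
model = BertPretrainingLoss(encoder, config)

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
masked_lm_labels = torch.LongTensor([[-1, 51, -1], [-1, -1, -1]])  # -1 = ignored
next_sentence_label = torch.LongTensor([0, 1])

total_loss = model(input_ids, token_type_ids, attention_mask,
                   masked_lm_labels, next_sentence_label)
total_loss.backward()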

# A heavier pre-training wrapper from the same codebase; only its constructor
# appears here. The enclosing class statement is reconstructed from the
# super() call, and the base class (BertPreTrainedModel) is assumed.
class BertForPreTrainingLossMask(BertPreTrainedModel):
    def __init__(self, config, num_labels=2, num_rel=0, num_sentlvl_labels=0, no_nsp=False):
        super(BertForPreTrainingLossMask, self).__init__(config)
        self.bert = BertModel(config)
        # Main heads are tied to the encoder's word-embedding weights.
        self.cls = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight, num_labels=num_labels)
        self.num_sentlvl_labels = num_sentlvl_labels
        self.cls2 = None
        if self.num_sentlvl_labels > 0:
            # Optional second set of heads for sentence-level prediction,
            # tied to a dedicated label-embedding matrix.
            self.secondary_pred_proj = nn.Embedding(
                num_sentlvl_labels, config.hidden_size)
            self.cls2 = BertPreTrainingHeads(
                config, self.secondary_pred_proj.weight, num_labels=num_sentlvl_labels)
        # Per-token masked-LM loss; reduced later with an explicit loss mask.
        self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none')
        if no_nsp:
            self.crit_next_sent = None
        else:
            self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1)
        self.num_labels = num_labels
        self.num_rel = num_rel
        if self.num_rel > 0:
            self.crit_pair_rel = BertPreTrainingPairRel(
                config, num_rel=num_rel)
        if hasattr(config, 'label_smoothing') and config.label_smoothing:
            # Optional label-smoothed variant of the masked-LM criterion.
            self.crit_mask_lm_smoothed = LabelSmoothingLoss(
                config.label_smoothing, config.vocab_size, ignore_index=0, reduction='none')
        else:
            self.crit_mask_lm_smoothed = None
        self.apply(self.init_bert_weights)
        self.bert.rescale_some_parameters()
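
The masked-LM criterion above is created with reduction='none' so the per-token losses can be weighted by a mask of the corrupted positions before averaging. A small self-contained sketch of that reduction pattern, with illustrative tensor names and sizes:

import torch
import torch.nn as nn

crit_mask_lm = nn.CrossEntropyLoss(reduction='none')

# Toy logits over a 10-word vocabulary for a batch of 2 sequences of length 4.
prediction_scores = torch.randn(2, 4, 10)
masked_lm_labels = torch.LongTensor([[3, 0, 7, 0], [0, 2, 0, 0]])
# 1.0 where a position was actually masked and should contribute to the loss.
loss_mask = torch.FloatTensor([[1, 0, 1, 0], [0, 1, 0, 0]])

per_token_loss = crit_mask_lm(
    prediction_scores.view(-1, 10), masked_lm_labels.view(-1)).view(2, 4)
masked_lm_loss = (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1e-5)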

# A multi-task post-training variant; the class statement is reconstructed
# from the super() call, and the base class (BertPreTrainedModel) is assumed.
class BertForMTPostTraining(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForMTPostTraining, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight)
        # Extra span-prediction head producing start/end logits.
        self.qa_outputs = torch.nn.Linear(config.hidden_size, 2)
        self.apply(self.init_bert_weights)

# Import path assumed: these classes live in a pytorch_pretrained_bert-style
# modeling module, as in the snippets above.
import torch
from pytorch_pretrained_bert.modeling import (
    BertConfig, BertEmbeddings, BertPreTrainingHeads)


def test_BertPreTrainingHeads():
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072)
    embeddings = BertEmbeddings(config)
    model = BertPreTrainingHeads(config, embeddings.word_embeddings.weight)

    embedding_output = embeddings(input_ids, token_type_ids)
    # The embedding output stands in for both the sequence output and the
    # pooled output; this only exercises the heads' output shapes.
    print(model(embedding_output, embedding_output))

# The standard BERT pre-training wrapper; the class statement is reconstructed
# from the super() call, and the base class (BertPreTrainedModel) is assumed.
class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_bert_weights)