class ElectraForPreTraining(AlbertPreTrainedModel):
    """ELECTRA discriminator head (replaced-token detection) on top of an ALBERT encoder."""

    def __init__(self, config):
        super().__init__(config)
        self.electra = AlbertModel(config)
        self.discriminator_predictions = ElectraDiscriminatorPredictions(config)
        self.init_weights()
        # Per-token binary loss; reduction is deferred so padding can be masked out in forward().
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

    def get_input_embeddings(self):
        return self.electra.get_input_embeddings()

    def get_output_embeddings(self):
        return self.discriminator_predictions

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]

        logits = self.discriminator_predictions(discriminator_sequence_output)
        output = (logits,)

        if labels is not None:
            # Mean BCE loss per example over non-padding tokens, summed over the batch.
            losses = (
                (self.loss_fct(logits.squeeze(-1), labels.float()) * attention_mask).sum(1)
                / (1e-6 + attention_mask.sum(1))
            ).sum()
            output = (losses,) + output

        return output  # (loss), scores
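# Usage sketch (not part of the original file): exercises the replaced-token-detection
# head with dummy inputs, assuming the repo's ElectraDiscriminatorPredictions accepts an
# AlbertConfig. The tiny config sizes and the `_demo_*` name are illustrative assumptions.
def _demo_electra_for_pretraining():
    import torch
    from transformers import AlbertConfig

    config = AlbertConfig(
        vocab_size=1000,
        embedding_size=64,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=256,
    )
    model = ElectraForPreTraining(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    # Replaced-token-detection labels: 1 where a token was swapped in by the generator, 0 otherwise.
    labels = torch.zeros_like(input_ids)

    loss, logits = model(input_ids, attention_mask=attention_mask, labels=labels)
    return loss, logits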
class ElectraForMaskedLM(AlbertPreTrainedModel):
    """ELECTRA generator head (masked language modeling) on top of an ALBERT encoder."""

    def __init__(self, config):
        super().__init__(config)
        self.electra = AlbertModel(config)
        self.generator_predictions = ElectraGeneratorPredictions(config)
        self.loss_fct = nn.CrossEntropyLoss(reduction='none')  # -100 index = padding token
        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
        self.init_weights()

    def get_input_embeddings(self):
        return self.electra.get_input_embeddings()

    def get_output_embeddings(self):
        return self.generator_lm_head

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs,
    ):
        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        generator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        generator_sequence_output = generator_hidden_states[0]

        prediction_scores = self.generator_predictions(generator_sequence_output)
        prediction_scores = self.generator_lm_head(prediction_scores)

        loss = None
        # Masked language modeling softmax layer
        if labels is not None:
            # Per-token loss reshaped to (batch, seq_len); positions labeled -100 contribute 0.
            per_example_loss = self.loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            ).view(labels.shape[0], labels.shape[1])
            # Mean loss per example over non-padding tokens, summed over the batch.
            loss = ((per_example_loss * attention_mask).sum(1) / (1e-6 + attention_mask.sum(1))).sum()

        output = (prediction_scores,)
        return ((loss,) + output) if loss is not None else output
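# Usage sketch (not part of the original file): runs the generator head on dummy masked-LM
# inputs, assuming the repo's ElectraGeneratorPredictions accepts an AlbertConfig. Positions
# that should not contribute to the loss carry the label -100, which CrossEntropyLoss ignores
# by default; the tiny config sizes and the `_demo_*` name are illustrative assumptions.
def _demo_electra_for_masked_lm():
    import torch
    from transformers import AlbertConfig

    config = AlbertConfig(
        vocab_size=1000,
        embedding_size=64,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=256,
    )
    model = ElectraForMaskedLM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.full_like(input_ids, -100)  # ignore every position by default
    labels[:, 3] = input_ids[:, 3]             # pretend position 3 was masked and must be predicted

    loss, prediction_scores = model(input_ids, attention_mask=attention_mask, labels=labels)
    return loss, prediction_scores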