Example no. 1
import torch.nn as nn
from transformers import AlbertModel, AlbertPreTrainedModel
# Assumption: recent transformers; in < 4.0 releases the ELECTRA heads live in `transformers.modeling_electra`.
from transformers.models.electra.modeling_electra import ElectraDiscriminatorPredictions


# ELECTRA replaced-token-detection (discriminator) head on top of an ALBERT encoder.
class ElectraForPreTraining(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.electra = AlbertModel(config)
        self.discriminator_predictions = ElectraDiscriminatorPredictions(
            config)
        self.init_weights()
        # Per-token binary loss; masking and reduction are applied in forward().
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

    def get_input_embeddings(self):
        return self.electra.get_input_embeddings()

    def get_output_embeddings(self):
        return self.discriminator_predictions

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]

        logits = self.discriminator_predictions(discriminator_sequence_output)

        output = (logits,)

        if labels is not None:
            # Per-token BCE masked to valid positions, averaged per sequence,
            # then summed over the batch (assumes attention_mask is passed
            # whenever labels are).
            losses = ((self.loss_fct(logits.squeeze(-1), labels.float()) *
                       attention_mask).sum(1) /
                      (1e-6 + attention_mask.sum(1))).sum()
            output = (losses,) + output

        return output  # (loss), scores
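
A minimal usage sketch for the class above, assuming the imports shown in the snippet and an AlbertConfig from transformers; the configuration values, batch shape, and random labels are illustrative placeholders, not the original setup.

import torch
from transformers import AlbertConfig

# Small illustrative config (placeholder sizes, not the original training configuration).
config = AlbertConfig(vocab_size=30000, embedding_size=128, hidden_size=256,
                      num_hidden_layers=4, num_attention_heads=4,
                      intermediate_size=1024)
model = ElectraForPreTraining(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq_len)
attention_mask = torch.ones_like(input_ids)
labels = torch.randint(0, 2, (2, 16))  # 1 = replaced token, 0 = original

with torch.no_grad():
    loss, logits = model(input_ids, attention_mask=attention_mask, labels=labels)
print(loss.item(), logits.shape)  # scalar loss, per-token logits of shape (2, 16)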
Example no. 2
import warnings

import torch.nn as nn
from transformers import AlbertModel, AlbertPreTrainedModel
# Assumption: recent transformers; in < 4.0 releases the ELECTRA heads live in `transformers.modeling_electra`.
from transformers.models.electra.modeling_electra import ElectraGeneratorPredictions


# ELECTRA generator (masked-LM) head on top of an ALBERT encoder.
class ElectraForMaskedLM(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.electra = AlbertModel(config)
        self.generator_predictions = ElectraGeneratorPredictions(config)

        self.loss_fct = nn.CrossEntropyLoss(
            reduction='none')  # -100 index = padding token

        self.generator_lm_head = nn.Linear(config.embedding_size,
                                           config.vocab_size)
        self.init_weights()

    def get_input_embeddings(self):
        return self.electra.get_input_embeddings()

    def get_output_embeddings(self):
        return self.generator_lm_head

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                **kwargs):
        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        generator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        generator_sequence_output = generator_hidden_states[0]
        prediction_scores = self.generator_predictions(
            generator_sequence_output)
        prediction_scores = self.generator_lm_head(prediction_scores)

        loss = None
        # Masked language modeling loss
        if labels is not None:
            # Per-token cross-entropy reshaped to (batch, seq_len); positions with
            # label -100 contribute zero. Masked to valid positions, averaged per
            # sequence, then summed over the batch (assumes attention_mask is
            # passed whenever labels are).
            per_example_loss = self.loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1)).view(labels.shape[0], labels.shape[1])
            loss = ((per_example_loss * attention_mask).sum(1) /
                    (1e-6 + attention_mask.sum(1))).sum()

        output = (prediction_scores, )
        return ((loss, ) + output) if loss is not None else output
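
A matching minimal sketch for the masked-LM generator above, under the same assumptions; masking 15% of positions with copied labels is only an illustration of the expected label format (-100 marks positions ignored by the loss).

import torch
from transformers import AlbertConfig

config = AlbertConfig(vocab_size=30000, embedding_size=128, hidden_size=256,
                      num_hidden_layers=4, num_attention_heads=4,
                      intermediate_size=1024)
model = ElectraForMaskedLM(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)

# Labels are -100 everywhere except the positions we pretend were masked out.
labels = torch.full_like(input_ids, -100)
masked_positions = torch.rand(input_ids.shape) < 0.15
labels[masked_positions] = input_ids[masked_positions]

with torch.no_grad():
    loss, prediction_scores = model(input_ids, attention_mask=attention_mask,
                                    labels=labels)
print(loss.item(), prediction_scores.shape)  # scalar loss, logits of shape (2, 16, 30000)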