예제 #1
0
    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        """Token classification: BERT, then dropout, an LSTM and a linear layer.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states:
                Forwarded unchanged to ``self.bert``.
            labels: Optional per-token targets. When given, a cross-entropy loss
                is computed; positions whose ``attention_mask`` is not 1 are
                replaced with the loss's ``ignore_index`` so they are skipped.
            return_dict: ``TokenClassifierOutput`` when truthy, tuple otherwise.
                ``None`` falls back to ``self.config.use_return_dict``.

        Returns:
            ``TokenClassifierOutput``, or the tuple ``(loss?, logits, *bert_out[2:])``
            where the loss entry is present only when ``labels`` was given.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        bert_out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # bert_out[0] is the last-layer sequence output.
        hidden = self.dropout(bert_out[0])
        lstm_hidden, _ = self.lstm(hidden)
        logits = self.linear(lstm_hidden)

        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            if attention_mask is None:
                loss = criterion(flat_logits, flat_labels)
            else:
                # Padded positions get ignore_index so they do not contribute.
                keep = attention_mask.view(-1) == 1
                ignore = torch.tensor(criterion.ignore_index).type_as(labels)
                loss = criterion(flat_logits, torch.where(keep, flat_labels, ignore))

        if return_dict:
            return TokenClassifierOutput(loss=loss,
                                         logits=logits,
                                         hidden_states=bert_out.hidden_states,  # from BERT
                                         attentions=bert_out.attentions)

        result = (logits,) + bert_out[2:]
        if loss is None:
            return result
        return (loss,) + result
예제 #2
0
    def forward(self, input, attention_mask=None, logits_mask=None, labels=None,
                return_dict=None, hidden_states=None):
        """Apply the linear layer and optionally compute a masked cross-entropy loss.

        Args:
            input: Features passed to the parent ``forward``.
            attention_mask: Optional mask. When ``logits_mask`` is not given and
                this is present, ``attention_mask[:, 2:] == 1`` is used as the
                loss mask.
            logits_mask: Optional mask selecting which logit rows enter the loss.
                It is converted with ``.bool()``, so a 0/1 integer mask selects
                positions rather than being used as row indices.
            labels: Optional targets; flattened with ``labels.view(-1)``.
            return_dict: ``TokenClassifierOutput`` when truthy, tuple otherwise.
            hidden_states: Optional encoder output. Its ``[1:]`` items are
                appended to the tuple output, and its ``hidden_states`` /
                ``attentions`` attributes (if any) are copied into the dict output.

        Returns:
            ``TokenClassifierOutput`` or ``(loss?, logits, *hidden_states[1:])``.
        """
        logits = super(LinearClassifier, self).forward(input)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            flat_logits = logits.view(-1, self.out_features)
            flat_labels = labels.view(-1)
            # Only derive a mask if there is an attention mask to derive it from;
            # previously a missing attention_mask raised TypeError here.
            if logits_mask is None and attention_mask is not None:
                # The first two mask columns are dropped, presumably because the
                # logits start two positions later (TODO confirm with the caller).
                logits_mask = attention_mask[:, 2:] == 1
            if logits_mask is not None:
                # .bool(): integer tensors would otherwise index rows, not mask them.
                active = logits_mask.reshape(-1).bool()
                loss = loss_fct(flat_logits[active], flat_labels[active])
            else:
                loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = ((logits,) + hidden_states[1:]) if hidden_states is not None else (logits,)
            return ((loss,) + output) if loss is not None else output

        # hidden_states may be None; fall back to None rather than AttributeError.
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(hidden_states, "hidden_states", None),
            attentions=getattr(hidden_states, "attentions", None),
        )
예제 #3
0
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Encode the input, project it to per-token emissions and decode them with a CRF.

        Args:
            input_ids / inputs_embeds: Encoder inputs; one of them is required.
            attention_mask: Optional padding mask. When omitted, every position is
                treated as valid (an all-ones mask shaped like the input tokens).
            head_mask, output_attentions, output_hidden_states:
                Forwarded to ``self.encoder``.
            labels: Optional tag sequences. When given, the loss is the negated
                value returned by ``self.crf(emissions, labels, mask=...)``.
            return_dict: Dict output when truthy; ``None`` falls back to
                ``self.config.use_return_dict``.

        Returns:
            ``TokenClassifierOutput`` whose ``logits`` are the result of
            ``self.crf.decode``. In tuple mode: ``(loss, logits, *outputs[2:])``
            with labels, or just the decoded ``logits`` without labels.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given
                and ``attention_mask`` has to be built.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is None:
            # Build the default mask from the input *shape*; the previous
            # torch.ones(input_ids, ...) raised TypeError (a tensor is not a size).
            if input_ids is not None:
                attention_mask = torch.ones_like(input_ids)
            elif inputs_embeds is not None:
                attention_mask = torch.ones(inputs_embeds.shape[:-1],
                                            dtype=torch.long,
                                            device=inputs_embeds.device)
            else:
                raise ValueError("You have to specify either input_ids or inputs_embeds")

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.dropout(outputs[0])
        emissions = self.position_wise_ff(last_hidden_state)

        # The CRF receives a uint8 mask; compute it once for loss and decoding.
        crf_mask = attention_mask.to(torch.uint8)

        loss = None
        if labels is not None:
            # self.crf presumably returns a log-likelihood; negate it to minimise.
            loss = -1 * self.crf(emissions, labels, mask=crf_mask)
        logits = self.crf.decode(emissions, crf_mask)

        if not return_dict:
            output = (logits, ) + outputs[2:]
            # NOTE(review): without labels this returns the bare decoded logits
            # instead of `output` (unlike sibling heads); kept for compatibility.
            return ((loss, ) + output) if loss is not None else logits

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
예제 #4
0
    def forward(
        self,
        input_ids=None,
        bbox=None,
        image=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Token classification on top of ``self.layoutlmv2``.

        Only the first ``seq_length`` positions of the encoder output (the text
        tokens) are classified. The remaining positions, which the original code
        named ``image_output``, are not used by this head.

        Args:
            input_ids / inputs_embeds: Text inputs; ``seq_length`` is taken from
                whichever is given.
            bbox, image, attention_mask, token_type_ids, position_ids, head_mask,
            output_attentions, output_hidden_states: Forwarded to the encoder.
            labels: Optional per-token targets for a cross-entropy loss. Only
                positions with ``attention_mask == 1`` are scored.
            return_dict: Dict output when truthy; ``None`` uses the config value.

        Returns:
            ``TokenClassifierOutput`` or ``(loss?, logits, *outputs[2:])``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.layoutlmv2(
            input_ids=input_ids,
            bbox=bbox,
            image=image,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Previously input_ids.size(1) crashed when only inputs_embeds was given.
        seq_length = input_ids.size(1) if input_ids is not None else inputs_embeds.size(1)
        sequence_output = self.dropout(outputs[0][:, :seq_length])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            if attention_mask is not None:
                # Keep only real (non-padding) tokens.
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels),
                                labels.view(-1))

        if not return_dict:
            output = (logits, ) + outputs[2:]
            return ((loss, ) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
예제 #5
0
    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            bboxes=None  # added argument
    ) -> TokenClassifierOutput:
        """Classify each token with the LAMBERT encoder, which also consumes ``bboxes``.

        Every encoder argument, plus ``bboxes``, is passed straight to
        ``self.lambert``. When ``labels`` is given, a cross-entropy loss is
        computed; tokens whose ``attention_mask`` is not 1 are relabelled with
        the loss's ``ignore_index``. ``return_dict=None`` defers to
        ``self.config.use_return_dict``; a falsy value yields the tuple
        ``(loss?, logits, *encoder_outputs[2:])``.
        """
        use_dict = self.config.use_return_dict if return_dict is None else return_dict

        encoded = self.lambert(  # same call as the RoBERTa head, on `lambert`
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=use_dict,
            bboxes=bboxes,
        )

        logits = self.classifier(self.dropout(encoded[0]))

        loss = None
        if labels is not None:
            ce = CrossEntropyLoss()
            logits_2d = logits.view(-1, self.num_labels)
            targets = labels.view(-1)
            if attention_mask is not None:
                fill = torch.tensor(ce.ignore_index).type_as(labels)
                targets = torch.where(attention_mask.view(-1) == 1, targets, fill)
            loss = ce(logits_2d, targets)

        if not use_dict:
            tail = (logits, ) + encoded[2:]
            return tail if loss is None else (loss, ) + tail

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
예제 #6
0
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Token classification head on top of DistilBERT.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Target class per token, in ``[0, ..., config.num_labels - 1]``. Tokens whose
            ``attention_mask`` is not 1 are excluded from the cross-entropy loss.

        Returns a ``TokenClassifierOutput`` when ``return_dict`` (default
        ``config.use_return_dict``) is truthy, else ``(loss?, logits, *outputs[1:])``.
        """
        wants_dict = return_dict
        if wants_dict is None:
            wants_dict = self.config.use_return_dict

        distil_out = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=wants_dict,
        )
        token_states = self.dropout(distil_out[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            targets = labels.view(-1)
            if attention_mask is not None:
                # Padding positions are relabelled with ignore_index.
                targets = targets.masked_fill(attention_mask.view(-1) != 1, criterion.ignore_index)
            loss = criterion(logits.view(-1, self.num_labels), targets)

        if wants_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=distil_out.hidden_states,
                attentions=distil_out.attentions,
            )

        # DistilBERT has no pooler, so extra outputs start at index 1.
        tail = (logits, ) + distil_out[1:]
        return tail if loss is None else (loss, ) + tail
예제 #7
0
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        center_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Classify a single token per example: the one at ``center_positions``.

        center_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`):
            Index of the token to classify in each example. Required.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Class of the selected token, in ``[0, ..., config.num_labels - 1]``.

        Returns ``TokenClassifierOutput`` or ``(loss?, logits, *outputs[2:])`` where
        ``logits`` has shape ``(batch_size, num_labels)``.

        Raises:
            ValueError: if ``center_positions`` is not given.
        """
        if center_positions is None:
            # Indexing with None would silently add an axis and yield wrong shapes.
            raise ValueError("center_positions is required")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # (batch_size, sequence_length, hidden_size)

        # Row i keeps position center_positions[i]; the row index lives on the
        # same device as the hidden states.
        row_indices = torch.arange(sequence_output.size(0), device=sequence_output.device)
        sequence_output = sequence_output[row_indices, center_positions, :]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # One prediction per example: plain cross-entropy, no masking needed.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits, ) + outputs[2:]
            return ((loss, ) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
예제 #8
0
    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        """Token classification from an intermediate BERT layer (``self.to_layer``).

        Hidden states are always requested from BERT, whatever
        ``output_hidden_states`` says, because the classifier reads layer
        ``self.to_layer`` instead of the last one.

        Args:
            labels: Optional per-token targets. Cross-entropy is computed, with
                positions whose ``attention_mask`` is not 1 set to ``ignore_index``.
            return_dict: Dict output when truthy; ``None`` uses the config value.

        Returns:
            ``TokenClassifierOutput`` (now including ``hidden_states`` and
            ``attentions``) or ``(loss?, logits, *outputs[2:])``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # BERT Output
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # outputs[2] is taken to be the hidden-states tuple (embeddings, then each
        # layer, each (batch_size, seq_len, dims)). This assumes BERT also returns
        # a pooler output at index 1 — TODO confirm if the pooler is disabled.
        hiddens = outputs[2]
        sequence_output = self.dropout(hiddens[self.to_layer])
        logits = self.linear(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # Hidden states are always computed above, so expose them like the tuple path does.
        return TokenClassifierOutput(loss=loss,
                                     logits=logits,
                                     hidden_states=outputs.hidden_states,
                                     attentions=outputs.attentions)
예제 #9
0
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Token classification head on top of ``self.encoder``.

        Args:
            input_ids, attention_mask, head_mask, inputs_embeds,
            output_attentions, output_hidden_states: Forwarded to the encoder.
            labels: Optional per-token targets. Cross-entropy is used, and
                positions whose ``attention_mask`` is not 1 get ``ignore_index``.
            return_dict: Dict output when truthy; ``None`` falls back to
                ``self.config.use_return_dict``.

        Returns:
            ``TokenClassifierOutput`` or ``(loss?, logits, *encoder_outputs[2:])``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.classifier(self.dropout(encoder_outputs[0]))

        loss = None
        if labels is not None:
            loss_fn = CrossEntropyLoss()
            if attention_mask is None:
                target = labels.view(-1)
            else:
                # Masked-out tokens are turned into ignore_index.
                ignore_fill = torch.tensor(loss_fn.ignore_index).type_as(labels)
                target = torch.where(attention_mask.view(-1) == 1, labels.view(-1), ignore_fill)
            loss = loss_fn(logits.view(-1, self.num_labels), target)

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        packed = (logits, ) + encoder_outputs[2:]
        return packed if loss is None else (loss, ) + packed
예제 #10
0
    def forward(self,
                input,
                word_index=None,
                word_attention_mask=None,
                labels=None,
                return_dict=None,
                hidden_states=None):
        """Word-level classification with a relative transformer and a linear head.

        Args:
            input: Token features, assumed (batch, seq_len, dim) — TODO confirm.
            word_index: Optional (batch, num_words) positions. When given, the
                matching rows of ``input`` are gathered along dim 1 first.
            word_attention_mask: Optional mask, passed to the transformer and
                used to select which word positions enter the loss. It is
                converted with ``.bool()``, so a 0/1 integer mask selects
                positions rather than being used as row indices.
            labels: Optional per-word targets for a cross-entropy loss.
            return_dict: ``TokenClassifierOutput`` when truthy, tuple otherwise.
            hidden_states: Optional encoder output. Its ``[1:]`` items are
                appended to the tuple, and its ``hidden_states`` / ``attentions``
                attributes (if any) are copied into the dict output.
        """
        if word_index is not None:
            # Pick one feature vector per word.
            input = torch.gather(input,
                                 dim=1,
                                 index=word_index.unsqueeze(-1).expand(
                                     -1, -1, input.size(-1)))

        sequence_output = self.relative_transformer(input, word_attention_mask)
        logits = self.classifier(sequence_output)
        num_classes = self.classifier.out_features

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            flat_logits = logits.view(-1, num_classes)
            flat_labels = labels.view(-1)
            # Only keep active parts of the loss
            if word_attention_mask is not None:
                active = word_attention_mask.reshape(-1).bool()
                loss = loss_fct(flat_logits[active], flat_labels[active])
            else:
                loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = ((logits, ) +
                      hidden_states[1:]) if hidden_states is not None else (
                          logits, )
            return ((loss, ) + output) if loss is not None else output

        # hidden_states may be None; fall back to None rather than AttributeError.
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(hidden_states, "hidden_states", None),
            attentions=getattr(hidden_states, "attentions", None),
        )
    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                labels=None,
                return_dict=None):
        """BERT, then dropout, an LSTM and a linear layer, with an optional CRF loss.

        Args:
            input_ids, attention_mask, token_type_ids: Forwarded to ``self.bert``.
            labels: Optional tag ids. ``-100`` entries are mapped to tag 0 before
                being handed to ``self.crf``; the loss is the negated CRF output.
            return_dict: Dict output when truthy, tuple otherwise. ``None`` (the
                default) uses ``self.return_dict``; previously this argument was
                silently ignored.

        Returns:
            ``TokenClassifierOutput`` or ``(loss?, logits, *outputs[2:])``.
        """
        if return_dict is None:
            return_dict = self.return_dict

        # Presumably keeps the LSTM weights in one contiguous chunk (cuDNN).
        self.lstm.flatten_parameters()

        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            return_dict=return_dict)

        sequence_output = self.dropout(outputs[0])
        lstm_output = self.lstm(sequence_output)[0]
        logits = self.classifier(lstm_output)

        loss = None
        if labels is not None:
            # TODO: positions labelled -100 ([CLS], [SEP], [PAD]) are mapped to tag 0
            # so the CRF accepts them (an older note suggested tag 32 "O" — confirm).
            # Wherever the mask is 1 they still contribute to the loss.
            active_labels = torch.where(labels != -100, labels,
                                        torch.tensor(0).type_as(labels))
            # A missing attention_mask used to crash here; pass None instead, which
            # assumes self.crf treats mask=None as "all valid" (pytorch-crf does).
            crf_mask = None if attention_mask is None else attention_mask.type(torch.uint8)
            loss = -1 * self.crf(emissions=logits, tags=active_labels, mask=crf_mask)

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        output = (logits, ) + outputs[2:]
        return ((loss, ) + output) if loss is not None else output
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        sent_bounds=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Score sentences by pooling ALBERT token states inside each sentence span.

        sent_bounds (:obj:`torch.LongTensor`, assumed shape ``(batch_size, num_sentences, 3)``):
            ``[..., 1]`` and ``[..., 2]`` are the inclusive first/last token positions
            of each sentence. Sentences with ``[..., 0] < 0`` get ``-10000`` added to
            their logit.
        labels (`optional`, presumably a float tensor shaped like ``logits``):
            Target scores per sentence. They are softmax-normalised and compared with
            ``log_softmax(logits)`` through ``KLDivLoss``.

        ``self.pooling_type`` selects ``'average'`` or ``'max'`` pooling.

        Raises:
            ValueError: if ``self.pooling_type`` is neither ``'average'`` nor ``'max'``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])

        device = sequence_output.device
        # positions: (1, 1, seq_len); if_in_sent: (batch, num_sentences, seq_len)
        pos_matrix = torch.arange(sequence_output.size()[1], device=device).view(1, 1, -1)
        if_in_sent = torch.logical_and(sent_bounds[:, :, 1].unsqueeze(-1) <= pos_matrix,
                                       pos_matrix <= sent_bounds[:, :, 2].unsqueeze(-1))

        if self.pooling_type == 'average':
            pooling_matrix = if_in_sent.float()
            # clamp avoids dividing by zero for sentences that cover no token.
            sent_len = pooling_matrix.sum(dim=2, keepdim=True).clamp(min=1)
            sentence_hiddens = torch.bmm(pooling_matrix / sent_len, sequence_output)
        elif self.pooling_type == 'max':
            # Positions outside the sentence get -inf so they never win the max;
            # filling them with 0 (as before) clipped negative activations to 0.
            neg_inf = torch.tensor(float('-inf'), device=device)
            candidates = torch.where(if_in_sent.unsqueeze(-1), sequence_output.unsqueeze(1), neg_inf).float()
            sentence_hiddens = candidates.max(dim=2)[0]
            # Sentences covering no token would stay -inf; zero them so the output
            # layer stays finite.
            empty = ~if_in_sent.any(dim=-1, keepdim=True)
            sentence_hiddens = sentence_hiddens.masked_fill(empty, 0.0)
        else:
            raise ValueError(f"Unsupported pooling_type: {self.pooling_type!r}")
        logits = self.output_layer(sentence_hiddens).squeeze(-1)

        # Push sentences with a negative bound id (presumably padding) far down.
        mask = torch.where(sent_bounds[:, :, 0] >= 0, torch.tensor(0.0, device=device), torch.tensor((-10000.0), device=device))
        logits += mask

        loss = None
        if labels is not None:
            loss_fct = KLDivLoss()
            loss = loss_fct(F.log_softmax(logits, dim=-1), F.softmax(labels, dim=-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
예제 #13
0
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Token classification over an attention-weighted mix of all ALBERT hidden states.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.

        ``self.lossfct`` selects ``'diceloss'``, ``'focalloss'`` or (otherwise)
        ``CrossEntropyLoss(reduction=self.CEL_type)``. If ``self.quick_return`` is
        set, the output of ``self.attn`` is returned directly.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # outputs[2]: tuple of hidden states, each (batch, length, hidden).
        all_layers = outputs[2]
        # (layers, batch, length, hidden) -> (batch, length, layers, hidden)
        sequence_output = torch.stack(all_layers, dim=0).permute(1, 2, 0, 3).contiguous()
        sequence_output = self.attn(sequence_output)
        if self.quick_return:
            return sequence_output

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            if self.lossfct == 'diceloss':
                loss_fct = MultiDiceLoss()
                if attention_mask is not None:
                    # NOTE(review): F.one_hot fails on negative ids such as -100;
                    # labels are assumed to be valid class ids here.
                    one_hot_labels = F.one_hot(flat_labels, self.num_labels)
                    mask = attention_mask.view(-1, 1).repeat(1, self.num_labels)
                    loss = loss_fct(flat_logits, one_hot_labels, mask)
                else:
                    loss = loss_fct(flat_logits, flat_labels)
            else:
                if self.lossfct == 'focalloss':
                    loss_fct = FocalLoss()
                else:
                    loss_fct = CrossEntropyLoss(reduction=self.CEL_type)
                # Only keep active parts of the loss: padding gets ignore_index.
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    ignore = torch.tensor(loss_fct.ignore_index).type_as(labels)
                    loss = loss_fct(flat_logits, torch.where(active_loss, flat_labels, ignore))
                else:
                    loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = (logits, ) + outputs[2:]
            return ((loss, ) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
예제 #14
0
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Masked BERT language model trained to generate segment B (token_type 1) from segment A.

        When both ``attention_mask`` and ``token_type_ids`` are given, they are
        turned into a ``(batch, seq_len, seq_len)`` mask. For attended tokens:
        segment-A rows see every segment-A position, and segment-B rows see every
        position up to and including themselves (a lower-triangular mask).

        labels (:obj:`torch.LongTensor`, `optional`):
            Next-token targets. Prediction ``t`` is scored against ``labels`` only
            where ``token_type_ids[:, t + 1] == 1``. The labels are presumably
            already shifted, i.e. shaped ``(batch_size, seq_len - 1)`` — TODO confirm.

        Raises:
            ValueError: if ``labels`` is given without ``token_type_ids``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None and token_type_ids is not None:
            seq_len = token_type_ids.size(1)
            causal = torch.ones((1, seq_len, seq_len),
                                dtype=torch.float32,
                                device=token_type_ids.device).tril()

            part_a_mask = (1 - token_type_ids) - (1 - attention_mask)
            a_cols = part_a_mask.unsqueeze(1).float()
            a_rows = part_a_mask.unsqueeze(2).float()
            b_rows = token_type_ids.unsqueeze(2).float()
            attention_mask = a_cols * a_rows + b_rows * causal

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.predictions(sequence_output)

        loss = None
        if labels is not None:
            if token_type_ids is None:
                raise ValueError("token_type_ids are required to compute the loss")
            loss_fct = CrossEntropyLoss()
            # Position t predicts token t + 1: drop the last prediction and score
            # only targets that fall in segment B.
            predictions = logits[:, :-1].contiguous()
            target_mask = token_type_ids[:, 1:].contiguous() == 1
            active_logits = predictions.view(-1, self.config.vocab_size)
            active_labels = torch.where(
                target_mask.view(-1), labels.view(-1),
                torch.tensor(loss_fct.ignore_index).type_as(labels))
            loss = loss_fct(active_logits, active_labels)

        if not return_dict:
            output = (logits, ) + outputs[2:]
            return ((loss, ) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
예제 #15
0
def get_token_classifier_output(model, logits, labels, attention_mask,
                                return_dict, outputs):
    """Compute the token-classification loss and package the model output.

    The loss head is selected from ``model.top_model["name"]``:

    * a name containing ``"softmax"`` uses ``nn.CrossEntropyLoss``. It is
      class-weighted with ``model.xargs['wt']`` when ``'class_wt'`` is present
      in ``model.xargs``, or replaced by ``SelfAdjDiceLoss`` when
      ``model.xargs['dce_loss']`` is truthy.
    * a name containing ``"crf"`` uses the negated value returned by
      ``model.crf.forward``.

    Args:
        model: Object exposing ``top_model`` (dict with ``"name"``), ``xargs``
            (dict of options), ``num_labels`` and, for the CRF head, ``crf``.
        logits: Per-token scores. They are flattened with
            ``view(-1, model.num_labels)``, so presumably shaped
            ``(batch, seq_len, num_labels)`` — TODO confirm.
        labels: Gold label ids, or ``None`` to skip the loss. The value ``-100``
            marks ignored positions in the CRF branch.
        attention_mask: Mask whose entries equal to ``1`` take part in the
            softmax loss, or ``None``.
        return_dict: When falsy, a plain tuple is returned.
        outputs: Encoder outputs. ``outputs[2:]`` (tuple mode) or
            ``outputs.hidden_states`` / ``outputs.attentions`` are forwarded.

    Returns:
        If ``return_dict`` is falsy: ``(loss, logits, *outputs[2:])``, or
        ``(logits, *outputs[2:])`` when no loss was computed. Otherwise a
        ``TokenClassifierOutput``.
    """
    loss = None
    if labels is not None:
        head_name = model.top_model["name"]

        if "softmax" in head_name:
            # Class weights are optional; only build the tensor when given.
            wt = model.xargs.get('wt')
            if wt is not None:
                wt = torch.as_tensor(wt,
                                     dtype=logits.dtype,
                                     device=logits.device)
            if "class_wt" in model.xargs:
                loss_fct = nn.CrossEntropyLoss(weight=wt)
            else:
                loss_fct = nn.CrossEntropyLoss()

            if model.xargs.get('dce_loss'):
                loss_fct = SelfAdjDiceLoss(model.xargs.get('dce_loss_alpha'),
                                           model.xargs.get('dce_loss_gamma'),
                                           wt=wt)

            # Only keep active parts of the loss
            if attention_mask is not None:
                random_pct = model.xargs.get('random')
                if random_pct:
                    # Drop roughly `random_pct` percent of the positions
                    # labelled 2 from the loss. Indices are sampled with
                    # replacement. Work on a copy so the caller's mask is
                    # untouched.
                    attention_mask = attention_mask.detach().clone().view(-1)
                    candidate_idx = torch.nonzero(
                        labels.view(-1) == 2,
                        as_tuple=True)[0].to(attention_mask.device)
                    num_drop = int(random_pct * candidate_idx.shape[0]) // 100
                    if num_drop > 0:
                        # randint needs a size *tuple* and a non-empty range.
                        rand_idx = torch.randint(0,
                                                 candidate_idx.shape[0],
                                                 (num_drop, ),
                                                 device=attention_mask.device)
                        # 0 (not ignore_index) so bool masks are excluded too.
                        attention_mask[candidate_idx[rand_idx]] = 0
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, model.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels))
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, model.num_labels),
                                labels.view(-1))

        elif "crf" in head_name:
            # The CRF receives label id 2 at ignored (-100) positions.
            labels_copy = labels.detach().clone()
            ignored = labels_copy == -100
            labels_copy[ignored] = 2

            if attention_mask is not None:
                crf_mask = attention_mask.detach().clone()
                if model.xargs.get('skip_subset', False):
                    crf_mask[ignored] = 0
                    # Keep the first time step on; presumably the CRF
                    # implementation requires it — TODO confirm.
                    crf_mask[:, 0] = 1
                loss = -model.crf.forward(logits,
                                          labels_copy,
                                          crf_mask.type(torch.uint8),
                                          reduction="mean")
            else:
                loss = -model.crf.forward(logits, labels_copy)

    if not return_dict:
        output = (logits, ) + outputs[2:]
        return ((loss, ) + output) if loss is not None else output

    return TokenClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
예제 #16
0
        def forward(
                self,
                input_ids=None,
                attention_mask=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
        ):
            """DistilBERT token classification with optional extra word features.

            This closure reads names from its enclosing scope that are not
            visible here: ``tfidf``, ``wordlist`` and ``crf`` (switches), and
            ``X``, ``feature_names``, ``corpus``, ``toxic_words`` and ``texts``
            (data).

            When ``tfidf`` is truthy, one channel is appended to each token
            vector. It holds the tfidf score (row ``offset`` of ``X``) of the
            word aligned with that token, or 0. When ``wordlist`` is truthy,
            one channel is appended that is 1 when the aligned word is in
            ``toxic_words``, else 0. Token ``j`` is aligned with word ``j - 1``
            of ``corpus[offset]``. Token 0 (presumably ``[CLS]``) and tokens
            past the end of the document have no word.

            ``offset`` is the batch row index plus ``self.batch_offset``. The
            counter advances by the batch size on every call and resets to 0
            once it reaches ``len(texts)``. Batches are therefore presumably fed
            in order and unshuffled — TODO confirm.

            The loss is the per-sequence CRF loss (on non ``-100`` positions,
            averaged over the batch) when ``crf`` is truthy. Otherwise it is
            cross-entropy restricted to ``attention_mask == 1``.
            """
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            outputs = self.distilbert(
                input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            sequence_output = self.dropout(outputs[0])
            num_docs, seq_len = sequence_output.shape[:2]

            def word_at(offset, j):
                # Word aligned with token j, or None when there is none.
                # j == 0 is handled explicitly so j - 1 == -1 does not wrap
                # around to the document's last word.
                if j < 1:
                    return None
                try:
                    return corpus[offset][j - 1]
                except IndexError:
                    return None

            def append_feature(hidden, rows):
                # Append one scalar per token as an extra last-dim channel.
                # Keeps hidden's dtype/device and its autograd graph.
                column = torch.tensor(rows, dtype=hidden.dtype, device=hidden.device)
                return torch.cat((hidden, column.unsqueeze(-1)), dim=-1)

            if tfidf:
                rows = []
                for i in range(num_docs):
                    offset = i + self.batch_offset
                    feature_index = X[offset, :].nonzero()[1]
                    # Dict lookup replaces a per-token scan of all features.
                    # An empty row just yields an empty dict.
                    word_scores = {feature_names[k]: float(X[offset, k]) for k in feature_index}
                    rows.append([word_scores.get(word_at(offset, j), 0.0) for j in range(seq_len)])
                sequence_output = append_feature(sequence_output, rows)

            if wordlist:
                rows = []
                for i in range(num_docs):
                    offset = i + self.batch_offset
                    row = []
                    for j in range(seq_len):
                        word = word_at(offset, j)
                        row.append(1.0 if word is not None and word in toxic_words else 0.0)
                    rows.append(row)
                sequence_output = append_feature(sequence_output, rows)

            logits = self.classifier(sequence_output)
            batch_size = logits.shape[0]

            loss = None
            if labels is not None:
                if crf:
                    # Score each sequence only on its labelled (non -100) positions.
                    prediction_mask = labels != -100

                    loss = 0
                    for seq_logits, seq_labels, seq_mask in zip(logits, labels, prediction_mask):
                        seq_logits = seq_logits[seq_mask].unsqueeze(0)
                        seq_labels = seq_labels[seq_mask].unsqueeze(0)
                        loss -= self.crf(seq_logits, seq_labels, reduction='token_mean')

                    loss /= batch_size
                else:
                    loss_fct = torch.nn.CrossEntropyLoss()
                    if attention_mask is not None:
                        active_loss = attention_mask.view(-1) == 1
                        active_logits = logits.view(-1, self.num_labels)
                        active_labels = torch.where(
                            active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                        )
                        loss = loss_fct(active_logits, active_labels)
                    else:
                        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            self.batch_offset += batch_size
            if self.batch_offset >= len(texts):
                self.batch_offset = 0

            if not return_dict:
                # Everything after the encoder's first output is passed through.
                output = (logits,) + outputs[1:]
                return ((loss,) + output) if loss is not None else output

            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )