Example No. 1
 def forward(self, input_ids, attention_mask=None, token_type_ids=None,
             position_ids=None, head_mask=None, labels=None):
     outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
     sequence_output = outputs[0]
     sequence_output = self.dropout(sequence_output)
     logits = self.classifier(sequence_output)
     outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
     if labels is not None:
         assert self.loss_type in ['lsr', 'focal', 'ce']
         if self.loss_type == 'lsr':
             loss_fct = LabelSmoothingCrossEntropy(ignore_index=0)
         elif self.loss_type == 'focal':
             loss_fct = FocalLoss(ignore_index=0)
         else:
             loss_fct = CrossEntropyLoss(ignore_index=0)
         # Only keep active parts of the loss
         if attention_mask is not None:
             active_loss = attention_mask.contiguous().view(-1) == 1
             active_logits = logits.view(-1, self.num_labels)[active_loss]
             active_labels = labels.contiguous().view(-1)[active_loss]
             loss = loss_fct(active_logits, active_labels)
         else:
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
         outputs = (loss,) + outputs
     return outputs  # (loss), scores, (hidden_states), (attentions)
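All five examples pick one of three loss functions, but only CrossEntropyLoss ships with PyTorch; FocalLoss and LabelSmoothingCrossEntropy are project-specific and never shown in the snippets. Below is a minimal sketch of what they might look like, assuming the standard formulations (focal loss per Lin et al. 2017, uniform label smoothing) and an ignore_index contract matching nn.CrossEntropyLoss; the defaults gamma=2.0 and eps=0.1 are assumptions, not taken from the examples.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Multi-class focal loss: (1 - p_t)^gamma * CE, with ignore_index support."""
    def __init__(self, gamma=2.0, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        # logits: (N, C), targets: (N,)
        ce = F.cross_entropy(logits, targets,
                             ignore_index=self.ignore_index, reduction='none')
        pt = torch.exp(-ce)                 # probability of the true class
        loss = (1 - pt) ** self.gamma * ce  # down-weight easy examples
        return loss[targets != self.ignore_index].mean()

class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a target smoothed toward the uniform distribution."""
    def __init__(self, eps=0.1, ignore_index=-100):
        super().__init__()
        self.eps = eps
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=-1)
        nll = F.nll_loss(log_probs, targets,
                         ignore_index=self.ignore_index, reduction='none')
        smooth = -log_probs.mean(dim=-1)    # loss against the uniform distribution
        loss = (1 - self.eps) * nll + self.eps * smooth
        return loss[targets != self.ignore_index].mean()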
Example No. 2
    def forward(self, input_ids,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None, 
                head_mask=None, 
                labels=None,
                input_lens=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask)
        last_hidden_state = outputs[0]
        sequence_output = self.dropout(last_hidden_state)
        logits = self.classifier(sequence_output)
        outputs = (logits,) + outputs[2:]
        if labels is not None:
            assert self.loss_type in ['ce', 'fl', 'lsc']
            if self.loss_type == 'ce':
                loss_fct = CrossEntropyLoss(ignore_index=0)
            elif self.loss_type == 'fl':
                loss_fct = FocalLoss(ignore_index=0)
            elif self.loss_type == 'lsc':
                loss_fct = LabelSmoothingCrossEntropy(ignore_index=0)
            
            if attention_mask is not None:
                active_loss = attention_mask.contiguous().view(-1) == 1
                active_logits = logits.contiguous().view(-1, self.num_labels)[active_loss]
                active_targets = labels.contiguous().view(-1)[active_loss]
                loss = loss_fct(active_logits, active_targets)
            else:
                loss = loss_fct(logits.contiguous().view(-1, self.num_labels), labels.contiguous().view(-1))
            outputs = (loss,) + outputs

        return outputs
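Examples 1, 2, 4 and 5 share the same masking pattern: flatten (batch, seq) to (batch*seq,) and keep only the positions whose attention mask is 1, so padding tokens never contribute to the loss. A self-contained illustration with toy shapes (all sizes and values here are made up):

import torch
import torch.nn.functional as F

num_labels = 5
logits = torch.randn(2, 4, num_labels)               # (batch, seq, num_labels)
labels = torch.randint(0, num_labels, (2, 4))        # (batch, seq)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])        # 1 = real token, 0 = padding

active = attention_mask.view(-1) == 1                # (batch*seq,) bool mask
active_logits = logits.view(-1, num_labels)[active]  # (n_active, num_labels)
active_labels = labels.view(-1)[active]              # (n_active,)
loss = F.cross_entropy(active_logits, active_labels)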
Example No. 3
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                start_positions=None,
                end_positions=None):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        start_logits = self.start_fc(sequence_output)
        if start_positions is not None and self.training:
            if self.soft_label:
                batch_size = input_ids.size(0)
                seq_len = input_ids.size(1)
                # One-hot encode the gold start positions on the input device.
                label_logits = torch.zeros(batch_size, seq_len, self.num_labels,
                                           device=input_ids.device)
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
            else:
                label_logits = start_positions.unsqueeze(2).float()
        else:
            label_logits = F.softmax(start_logits, -1)
            if not self.soft_label:
                label_logits = torch.argmax(label_logits,
                                            -1).unsqueeze(2).float()
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (
            start_logits,
            end_logits,
        ) + outputs[2:]

        if start_positions is not None and end_positions is not None:
            assert self.loss_type in ['lsr', 'focal', 'ce']
            if self.loss_type == 'lsr':
                loss_fct = LabelSmoothingCrossEntropy()
            elif self.loss_type == 'focal':
                loss_fct = FocalLoss()
            else:
                loss_fct = CrossEntropyLoss()
            start_logits = start_logits.view(-1, self.num_labels)
            end_logits = end_logits.view(-1, self.num_labels)
            active_loss = attention_mask.view(-1) == 1
            active_start_logits = start_logits[active_loss]
            active_end_logits = end_logits[active_loss]

            active_start_labels = start_positions.view(-1)[active_loss]
            active_end_labels = end_positions.view(-1)[active_loss]

            start_loss = loss_fct(active_start_logits, active_start_labels)
            end_loss = loss_fct(active_end_logits, active_end_labels)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss, ) + outputs
        return outputs
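Example 3 is a span-extraction head rather than per-token classification: start_fc scores start positions, and end_fc conditions the end scores on a start signal, which is the gold one-hot (or raw position) during training and the model's own softmax prediction at inference. end_fc itself is not shown; the sketch below is one plausible shape for a layer with that two-argument signature, assuming soft_label=True so the second input has num_labels channels. The name PoolerEndLogits and the internals are assumptions.

import torch
import torch.nn as nn

class PoolerEndLogits(nn.Module):
    """Sketch of an end_fc-style head: concatenates each token's hidden
    state with the start-label signal before scoring end positions."""
    def __init__(self, hidden_size, num_labels):
        super().__init__()
        self.dense = nn.Linear(hidden_size + num_labels, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states, start_states):
        # hidden_states: (B, T, H); start_states: (B, T, num_labels)
        x = torch.cat([hidden_states, start_states], dim=-1)
        x = self.layer_norm(torch.tanh(self.dense(x)))
        return self.classifier(x)  # (B, T, num_labels)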
Example No. 4
    def forward(self,
                input_ids,
                attention_mask,
                labels,
                token_type_ids=None,
                input_lens=None):
        embs = self.embedding(input_ids)
        embs = self.dropout(embs)
        embs = embs * attention_mask.float().unsqueeze(2)
        sequence_output, _ = self.bilstm(embs)
        sequence_output = self.layer_norm(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits, )
        if labels is not None:
            if self.use_crf:
                loss = self.crf(emissions=logits,
                                tags=labels,
                                mask=attention_mask)
                outputs = (-1 * loss, ) + outputs
            else:
                assert self.loss_type in ['ce', 'fl', 'lsc']
                if self.loss_type == 'ce':
                    loss_fct = CrossEntropyLoss(ignore_index=0)
                elif self.loss_type == 'fl':
                    loss_fct = FocalLoss(ignore_index=0)
                elif self.loss_type == 'lsc':
                    loss_fct = LabelSmoothingCrossEntropy(ignore_index=0)

                if attention_mask is not None:
                    active_loss = attention_mask.contiguous().view(-1) == 1
                    active_logits = logits.contiguous().view(
                        -1, self.num_labels)[active_loss]
                    active_targets = labels.contiguous().view(-1)[active_loss]
                    loss = loss_fct(active_logits, active_targets)
                else:
                    loss = loss_fct(
                        logits.contiguous().view(-1, self.num_labels),
                        labels.contiguous().view(-1))
                outputs = (loss, ) + outputs
        return outputs  # (loss), scores
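In the CRF branch above, the return value of self.crf is negated because a CRF layer in the style of the pytorch-crf package computes a log-likelihood (higher is better), not a loss. A minimal standalone sketch under that assumption; the tag count and tensor shapes are made up for illustration.

import torch
from torchcrf import CRF  # pip install pytorch-crf (assumed dependency)

num_tags, batch_size, seq_len = 9, 2, 8
crf = CRF(num_tags, batch_first=True)
emissions = torch.randn(batch_size, seq_len, num_tags)    # e.g. classifier logits
tags = torch.randint(0, num_tags, (batch_size, seq_len))
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

log_likelihood = crf(emissions, tags, mask=mask)  # higher is better
loss = -log_likelihood                            # negate to minimize
best_paths = crf.decode(emissions, mask=mask)     # Viterbi decoding at inference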
Example No. 5
 def forward(self, input_ids,
             attention_mask=None,
             token_type_ids=None,
             position_ids=None,
             head_mask=None,
             labels=None,
             input_lens=None):
     outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids,
                         position_ids=position_ids, head_mask=head_mask, attention_mask=attention_mask)
     last_hidden_state = outputs[0]  # (batch_size, sequence_length, hidden_size)
     if self.use_lstm:
         last_hidden_state, _ = self.bilstm(last_hidden_state)
     sequence_output = self.dropout(last_hidden_state)
     logits = self.classifier(sequence_output)  # (batch_size, seq_length, num_labels)
     outputs = (logits,) + outputs[2:]  # add hidden states and attentions if they are here
     if labels is not None:
         if self.use_crf:
             loss = self.crf(emissions=logits, tags=labels, mask=attention_mask)
             outputs = (-1 * loss,) + outputs
         else:
             assert self.loss_type in ['ce', 'fl', 'lsc']
             if self.loss_type == 'ce':
                 loss_fct = CrossEntropyLoss(ignore_index=0)
             elif self.loss_type == 'fl':
                 loss_fct = FocalLoss(ignore_index=0)
             elif self.loss_type == 'lsc':
                 loss_fct = LabelSmoothingCrossEntropy(ignore_index=0)
             
             if attention_mask is not None:
                 active_loss = attention_mask.contiguous().view(-1) == 1
                 active_logits = logits.contiguous().view(-1, self.num_labels)[active_loss]
                 active_targets = labels.contiguous().view(-1)[active_loss]
                 loss = loss_fct(active_logits, active_targets)
             else:
                 loss = loss_fct(logits.contiguous().view(-1, self.num_labels), labels.contiguous().view(-1))
             outputs = (loss,) + outputs
     return outputs
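Example 5's optional BiLSTM sits between BERT and the classifier, so its output width must match the classifier's input. One wiring consistent with that forward pass, assuming the usual trick of giving each direction half the hidden size so the concatenated output keeps the original width; all hyperparameter values are assumptions.

import torch.nn as nn

hidden_size, num_labels = 768, 9        # assumed values
bilstm = nn.LSTM(input_size=hidden_size,
                 hidden_size=hidden_size // 2,
                 num_layers=2,
                 bidirectional=True,
                 batch_first=True)      # output: (B, T, hidden_size)
dropout = nn.Dropout(0.1)
classifier = nn.Linear(hidden_size, num_labels)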