def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    r"""
    labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
        Labels for computing the token classification loss. Indices should be in
        ``[0, ..., config.num_labels - 1]``.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.albert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=True,  # every layer's hidden states are needed for the fusion below
        return_dict=return_dict,
    )

    # Tuple of per-layer hidden states, each of shape [batch, seq_len, hidden]
    sequence_output = outputs[2]
    layers = len(sequence_output)
    batchsize, length, hidden_size = sequence_output[0].size()

    # Stack the layers and reorder to [batch, seq_len, layers, hidden]
    sequence_output = torch.cat(sequence_output).view(
        layers, batchsize, length, hidden_size)
    sequence_output = sequence_output.transpose(0, 1).transpose(1, 2).contiguous()

    # Attention-based fusion over the layer dimension
    sequence_output = self.attn(sequence_output)
    if self.quick_return:
        return sequence_output

    sequence_output = self.dropout(sequence_output)
    logits = self.classifier(sequence_output)

    loss = None
    if labels is not None:
        if self.lossfct == 'diceloss':
            loss_fct = MultiDiceLoss()
            if attention_mask is not None:
                # Dice loss expects one-hot labels and an explicit mask,
                # both of shape [batch * seq_len, num_labels]
                active_logits = logits.view(-1, self.num_labels)
                active_labels = F.one_hot(labels.view(-1), self.num_labels)
                mask = attention_mask.view(-1, 1).repeat(1, self.num_labels)
                loss = loss_fct(active_logits, active_labels, mask)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        elif self.lossfct == 'focalloss':
            loss_fct = FocalLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels))
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        else:
            loss_fct = CrossEntropyLoss(reduction=self.CEL_type)
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels))
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    if not return_dict:
        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    return TokenClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
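# --- Illustration (not part of the model) -----------------------------------
# A minimal, self-contained sketch of the layer-fusion reshaping performed
# above. All names and sizes here (`_layer_fusion_sketch`, num_layers=13,
# batch=2, seq_len=8, hidden=768) are hypothetical and chosen only for
# illustration; the real values come from the backbone's config. The tuple of
# per-layer hidden states is stacked and reordered to
# [batch, seq_len, num_layers, hidden], which is the layout self.attn pools
# over before classification.
def _layer_fusion_sketch():
    import torch

    num_layers, batch, seq_len, hidden = 13, 2, 8, 768
    # Stand-in for outputs.hidden_states: one tensor per layer (plus embeddings)
    hidden_states = tuple(
        torch.randn(batch, seq_len, hidden) for _ in range(num_layers))

    stacked = torch.stack(hidden_states)                # [num_layers, batch, seq_len, hidden]
    stacked = stacked.permute(1, 2, 0, 3).contiguous()  # [batch, seq_len, num_layers, hidden]
    assert stacked.shape == (batch, seq_len, num_layers, hidden)
    return stacked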
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    mems=None,
    perm_mask=None,
    target_mapping=None,
    token_type_ids=None,
    input_mask=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    # use_mems=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **kwargs,
):
    r"""
    labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
        Labels for computing the token classification loss. Indices should be in
        ``[0, ..., config.num_labels - 1]``.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.transformer(
        input_ids,
        attention_mask=attention_mask,
        mems=mems,
        perm_mask=perm_mask,
        target_mapping=target_mapping,
        token_type_ids=token_type_ids,
        input_mask=input_mask,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        # use_mems=use_mems,
        output_attentions=output_attentions,
        output_hidden_states=True,  # every layer's hidden states are needed for the fusion below
        return_dict=return_dict,
    )

    # Tuple of per-layer hidden states, each of shape [batch, seq_len, hidden]
    sequence_output = outputs.hidden_states
    layers = len(sequence_output)
    batchsize, length, hidden_size = sequence_output[0].size()

    # Stack the layers: [layers, batch, seq_len, hidden],
    # then reorder to [batch, seq_len, layers, hidden]
    sequence_output = torch.cat(sequence_output).view(
        layers, batchsize, length, hidden_size)
    sequence_output = sequence_output.transpose(0, 1).transpose(1, 2).contiguous()

    # Attention-based fusion over the layer dimension
    sequence_output = self.attn(sequence_output)
    if self.quick_return:
        return sequence_output

    logits = self.classifier(sequence_output)

    loss = None
    if labels is not None:
        if self.lossfct == 'diceloss':
            loss_fct = MultiDiceLoss()
            if attention_mask is not None:
                # Dice loss expects one-hot labels and an explicit mask,
                # both of shape [batch * seq_len, num_labels]
                active_logits = logits.view(-1, self.num_labels)
                active_labels = F.one_hot(labels.view(-1), self.num_labels)
                mask = attention_mask.view(-1, 1).repeat(1, self.num_labels)
                loss = loss_fct(active_logits, active_labels, mask)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        elif self.lossfct == 'focalloss':
            loss_fct = FocalLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels))
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        else:
            loss_fct = CrossEntropyLoss(reduction=self.CEL_type)
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels))
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    if not return_dict:
        output = (logits,) + outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return XLNetForTokenClassificationOutput(
        loss=loss,
        logits=logits,
        mems=outputs.mems,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
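# --- Illustration (not part of the model) -----------------------------------
# A minimal, self-contained sketch of the "active loss" masking used in the
# focal-loss and cross-entropy branches above. The function name
# `_masked_token_loss_sketch` and the toy tensors are hypothetical; the point
# is that padded positions (attention_mask == 0) are remapped to the loss
# function's ignore_index so they contribute nothing to the token-level loss.
def _masked_token_loss_sketch():
    import torch
    from torch.nn import CrossEntropyLoss

    num_labels = 6
    logits = torch.randn(2, 4, num_labels)         # [batch, seq_len, num_labels]
    labels = torch.randint(0, num_labels, (2, 4))  # [batch, seq_len]
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])  # 0 marks padding

    loss_fct = CrossEntropyLoss()
    active_loss = attention_mask.view(-1) == 1
    active_logits = logits.view(-1, num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1),
        torch.tensor(loss_fct.ignore_index).type_as(labels))
    return loss_fct(active_logits, active_labels)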