def forward(self, batch, task, tokenizer, compute_loss: bool = False, loss_weights: torch.Tensor = None):
    """Run the encoder and span-prediction head, optionally computing loss.

    Positions flagged by ``batch.selection_token_mask`` have their logits
    raised so that they exceed every other position by at least
    ``self.offset_margin``.

    Args:
        batch: Input batch; must expose ``selection_token_mask`` and, when
            ``compute_loss`` is true, ``gt_span_idxs``.
        task: Not used by this method.
        tokenizer: Not used by this method.
        compute_loss: When true, also return a cross-entropy loss.
        loss_weights: Must be ``None``; weighted loss is not supported.

    Returns:
        ``LogitsAndLossOutput`` when ``compute_loss`` is true, otherwise
        ``LogitsOutput``. Both carry the shifted logits and the encoder's
        ``other`` payload.

    Raises:
        NotImplementedError: If ``loss_weights`` is not ``None``.
    """
    if loss_weights is not None:
        raise NotImplementedError()

    encoded = get_output_from_encoder_and_batch(
        encoder=self.encoder,
        batch=batch,
    )
    raw_logits = self.span_prediction_head(unpooled=encoded.unpooled)

    # The full spread of the logits plus the margin is a shift large enough
    # that every selectable position ends up above every masked-out one.
    shift = raw_logits.max() - raw_logits.min() + self.offset_margin
    # The mask gains a trailing axis so it broadcasts against the logits
    # (presumably (batch, seq_len) vs (batch, seq_len, 2) -- TODO confirm).
    selectable = batch.selection_token_mask.unsqueeze(dim=2)
    logits = raw_logits + shift * selectable

    if not compute_loss:
        return LogitsOutput(logits=logits, other=encoded.other)

    # Swap axes 1 and 2, then merge the first two axes so each row is one
    # classification over the last axis, lined up with the flattened targets.
    flat_logits = logits.transpose(dim0=1, dim1=2).flatten(end_dim=1)
    flat_targets = batch.gt_span_idxs.flatten()
    criterion = nn.CrossEntropyLoss()
    loss = criterion(flat_logits, flat_targets)
    return LogitsAndLossOutput(logits=logits, loss=loss, other=encoded.other)
def forward(self, batch, task, tokenizer, compute_loss: bool = False, loss_weights: torch.Tensor = None):
    """Score each answer choice with the shared encoder and build logits.

    Each of the ``self.num_choices`` choices is encoded separately by
    slicing index ``i`` along axis 1 of the batch tensors. The pooled
    encoder output goes through ``self.choice_scoring_head`` and the
    per-choice scores are concatenated along axis 1.

    Args:
        batch: Input batch exposing ``input_ids``, ``segment_ids`` and
            ``input_mask`` and, when ``compute_loss`` is true, ``label_id``.
        task: Not used by this method.
        tokenizer: Not used by this method.
        compute_loss: When true, also return a cross-entropy loss.
        loss_weights: Optional weights multiplied into the unreduced
            per-example loss before it is averaged.

    Returns:
        ``LogitsAndLossOutput`` when ``compute_loss`` is true, otherwise
        ``LogitsOutput``. ``other`` is a tuple of per-choice-stacked encoder
        extras, or an empty list when the encoder returned none.
    """
    all_input_ids = batch.input_ids
    all_segment_ids = batch.segment_ids
    all_input_mask = batch.input_mask

    scores_per_choice = []
    extras_per_choice = []
    for choice_idx in range(self.num_choices):
        encoded = get_output_from_encoder(
            encoder=self.encoder,
            input_ids=all_input_ids[:, choice_idx],
            segment_ids=all_segment_ids[:, choice_idx],
            input_mask=all_input_mask[:, choice_idx],
        )
        scores_per_choice.append(
            self.choice_scoring_head(pooled=encoded.pooled)
        )
        extras_per_choice.append(encoded.other)

    # Regroup the encoder extras: for each extra kind and each layer, stack
    # the per-choice tensors along a new axis 1.
    other = []
    first_extras = extras_per_choice[0]
    if first_extras:
        for kind_idx in range(len(first_extras)):
            # The layer count is always read from the first extra kind.
            num_layers = len(extras_per_choice[0][0])
            stacked_layers = []
            for layer_idx in range(num_layers):
                per_choice = [
                    extras[kind_idx][layer_idx]
                    for extras in extras_per_choice
                ]
                stacked_layers.append(torch.stack(per_choice, dim=1))
            other.append(stacked_layers)
        other = tuple(other)

    score_columns = [
        score.unsqueeze(1).squeeze(-1) for score in scores_per_choice
    ]
    logits = torch.cat(score_columns, dim=1)

    if not compute_loss:
        return LogitsOutput(logits=logits, other=other)

    flat_logits = logits.view(-1, self.num_choices)
    flat_labels = batch.label_id.view(-1)
    if loss_weights is None:
        criterion = nn.CrossEntropyLoss()
        loss = criterion(flat_logits, flat_labels)
    else:
        # Keep per-example losses so the weights apply before averaging.
        criterion = nn.CrossEntropyLoss(reduction='none')
        per_example_loss = criterion(flat_logits, flat_labels)
        loss = (per_example_loss * loss_weights).mean()

    return LogitsAndLossOutput(logits=logits, loss=loss, other=other)