Example #1
    def forward(self,
                batch,
                task,
                tokenizer,
                compute_loss: bool = False,
                loss_weights: torch.Tensor = None):
        # Per-example loss weighting is not supported for span prediction.
        if loss_weights is not None:
            raise NotImplementedError()
        encoder_output = get_output_from_encoder_and_batch(
            encoder=self.encoder, batch=batch)
        # Unpooled encoder states -> per-token start/end span logits.
        logits = self.span_prediction_head(unpooled=encoder_output.unpooled)
        # Ensure logits in the valid range are at least self.offset_margin
        # higher than the others, so predictions stay inside the selectable tokens.
        logits_offset = logits.max() - logits.min() + self.offset_margin
        logits = logits + logits_offset * batch.selection_token_mask.unsqueeze(dim=2)
        if compute_loss:
            # Start and end prediction are two classifications over sequence positions.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.transpose(dim0=1, dim1=2).flatten(end_dim=1),
                batch.gt_span_idxs.flatten(),
            )
            return LogitsAndLossOutput(logits=logits,
                                       loss=loss,
                                       other=encoder_output.other)
        else:
            return LogitsOutput(logits=logits, other=encoder_output.other)
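
The key trick in Example #1 is the mask offset: every position allowed by selection_token_mask is shifted above every disallowed position, so the span argmax (and the cross-entropy targets) always land inside the selectable tokens. Below is a minimal, self-contained sketch of that idea with plain torch; the tensor shapes, the offset_margin value, and the toy targets are assumptions for illustration, not values from the original model.

import torch
import torch.nn as nn

# Toy shapes: batch of 2, sequence length of 5, 2 span endpoints (start/end).
logits = torch.randn(2, 5, 2)
selection_token_mask = torch.tensor([[1., 1., 1., 0., 0.],
                                     [0., 1., 1., 1., 0.]])
offset_margin = 1000.0  # assumed value; plays the role of self.offset_margin

# Shift all selectable positions by more than the current logit range, so any
# masked-in position outranks every masked-out one.
logits_offset = logits.max() - logits.min() + offset_margin
shifted = logits + logits_offset * selection_token_mask.unsqueeze(dim=2)

# Predicted start/end indices now always fall inside the selection mask.
print(shifted.argmax(dim=1))

# Loss reshaping as in the example: (batch, seq, 2) -> (batch*2, seq), so start
# and end become two independent classifications over sequence positions.
gt_span_idxs = torch.tensor([[0, 2], [1, 3]])  # toy gold spans
loss = nn.CrossEntropyLoss()(
    shifted.transpose(dim0=1, dim1=2).flatten(end_dim=1),
    gt_span_idxs.flatten())
print(loss.item())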
Example #2
    def forward(self,
                batch,
                task,
                tokenizer,
                compute_loss: bool = False,
                loss_weights: torch.Tensor = None):
        input_ids = batch.input_ids
        segment_ids = batch.segment_ids
        input_mask = batch.input_mask

        choice_score_list = []
        encoder_outputs_other_ls = []
        # Run the encoder separately on each answer choice (inputs are indexed
        # over the choice dimension) and score each with the shared head.
        for i in range(self.num_choices):
            encoder_output = get_output_from_encoder(
                encoder=self.encoder,
                input_ids=input_ids[:, i],
                segment_ids=segment_ids[:, i],
                input_mask=input_mask[:, i],
            )
            choice_score = self.choice_scoring_head(
                pooled=encoder_output.pooled)
            choice_score_list.append(choice_score)
            encoder_outputs_other_ls.append(encoder_output.other)

        # Regroup the per-choice encoder extras (e.g. hidden states): stack the
        # per-choice tensors along a new choice dimension (dim=1), layer by layer.
        reshaped_outputs = []
        if encoder_outputs_other_ls[0]:
            for j in range(len(encoder_outputs_other_ls[0])):
                reshaped_outputs.append([
                    torch.stack(
                        [misc[j][layer_i] for misc in encoder_outputs_other_ls],
                        dim=1)
                    for layer_i in range(len(encoder_outputs_other_ls[0][0]))
                ])
            reshaped_outputs = tuple(reshaped_outputs)

        # Concatenate per-choice scores into a (batch_size, num_choices) logits matrix.
        logits = torch.cat(
            [choice_score.unsqueeze(1).squeeze(-1)
             for choice_score in choice_score_list],
            dim=1)

        if compute_loss:
            if loss_weights is not None:
                # Per-example weights: compute unreduced losses, scale, then average.
                loss_fct = nn.CrossEntropyLoss(reduction='none')
                loss = loss_fct(logits.view(-1, self.num_choices),
                                batch.label_id.view(-1))
                loss = (loss * loss_weights).mean()
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_choices),
                                batch.label_id.view(-1))
            return LogitsAndLossOutput(logits=logits,
                                       loss=loss,
                                       other=reshaped_outputs)
        else:
            return LogitsOutput(logits=logits, other=reshaped_outputs)
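
Example #2 scores each answer choice independently and only afterwards assembles the scores into a (batch_size, num_choices) logits matrix for a standard multi-class cross-entropy. Below is a minimal, self-contained sketch of that cat/loss step with plain torch; num_choices, the batch size, the labels, and the weights are made-up values for illustration.

import torch
import torch.nn as nn

num_choices = 4
batch_size = 3

# One (batch_size, 1) score per answer choice, as a scoring head would produce.
choice_score_list = [torch.randn(batch_size, 1) for _ in range(num_choices)]

# unsqueeze(1) adds the choice axis, squeeze(-1) drops the scalar-score axis;
# concatenating along dim=1 yields a (batch_size, num_choices) logits matrix.
logits = torch.cat(
    [choice_score.unsqueeze(1).squeeze(-1) for choice_score in choice_score_list],
    dim=1)

label_id = torch.tensor([0, 2, 3])  # toy gold choices

# Unweighted path, as in the example's else-branch.
loss = nn.CrossEntropyLoss()(logits.view(-1, num_choices), label_id.view(-1))

# Weighted path: unreduced per-example losses, scaled, then averaged.
loss_weights = torch.tensor([1.0, 0.5, 2.0])
per_example = nn.CrossEntropyLoss(reduction='none')(
    logits.view(-1, num_choices), label_id.view(-1))
weighted_loss = (per_example * loss_weights).mean()

print(logits.shape, loss.item(), weighted_loss.item())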