Example #1
 def forward(self,
             batch,
             task,
             tokenizer,
             compute_loss: bool = False,
             loss_weights: torch.Tensor = None):
     if loss_weights is not None:
         raise NotImplementedError()
     masked_batch = batch.get_masked(
         mlm_probability=task.mlm_probability,
         tokenizer=tokenizer,
         do_mask=task.do_mask,
     )
     encoder_output = get_output_from_encoder(
         encoder=self.encoder,
         input_ids=masked_batch.masked_input_ids,
         segment_ids=masked_batch.segment_ids,
         input_mask=masked_batch.input_mask,
     )
     logits = self.mlm_head(unpooled=encoder_output.unpooled)
     if compute_loss:
         loss = compute_mlm_loss(
             logits=logits, masked_lm_labels=masked_batch.masked_lm_labels)
         return LogitsAndLossOutput(logits=logits,
                                    loss=loss,
                                    other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
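
The body of compute_mlm_loss is not shown in this example. As a hedged sketch only, a typical MLM loss is cross-entropy over the vocabulary with unmasked positions ignored; the -100 convention and shapes below are assumptions for illustration, not the helper's actual code.

import torch
import torch.nn as nn

# Hedged sketch of a typical MLM loss (assumption: compute_mlm_loss's real
# body is not shown above). Unmasked positions carry the label -100, which
# CrossEntropyLoss skips via ignore_index.
vocab_size = 10
logits = torch.randn(2, 6, vocab_size)                    # (batch, seq, vocab)
masked_lm_labels = torch.full((2, 6), -100, dtype=torch.long)
masked_lm_labels[0, 2] = 7                                # one masked position
loss = nn.CrossEntropyLoss(ignore_index=-100)(
    logits.view(-1, vocab_size), masked_lm_labels.view(-1)
)
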
Example #2
 def forward(self,
             batch,
             task,
             tokenizer,
             compute_loss: bool = False,
             loss_weights: torch.Tensor = None,
             get_encoder_output: bool = False):
     encoder_output = get_output_from_encoder_and_batch(
         encoder=self.encoder, batch=batch)
     logits = self.classification_head(pooled=encoder_output.pooled)
     if compute_loss:
         if loss_weights is not None:
             loss_fct = nn.CrossEntropyLoss(reduction='none')
             loss = loss_fct(
                 logits.view(-1, self.classification_head.num_labels),
                 batch.label_id.view(-1),
             )
             loss = (loss * loss_weights).mean()
         else:
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(
                 logits.view(-1, self.classification_head.num_labels),
                 batch.label_id.view(-1),
             )
         results = LogitsAndLossOutput(logits=logits,
                                       loss=loss,
                                       other=encoder_output.other)
     else:
         results = LogitsOutput(logits=logits, other=encoder_output.other)
     if get_encoder_output:
         return results, StaticEncoderOutput(encoder_output)
     return results
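
The loss_weights branch above depends on reduction='none' producing one loss value per example. A minimal self-contained illustration (toy tensors, names chosen for the demo):

import torch
import torch.nn as nn

# With reduction='none', CrossEntropyLoss returns a per-example vector,
# which can be scaled by per-example weights before averaging.
logits = torch.randn(4, 3)                        # batch of 4, 3 classes
label_id = torch.tensor([0, 2, 1, 2])
loss_weights = torch.tensor([1.0, 0.5, 2.0, 1.0])
per_example = nn.CrossEntropyLoss(reduction="none")(logits, label_id)
loss = (per_example * loss_weights).mean()
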
Example #3
    def forward(self, batch, tokenizer, compute_loss: bool = False):
        """Summary

        Args:
            batch (TYPE): Description
            tokenizer (TYPE): Description
            compute_loss (bool, optional): Description

        Returns:
            TYPE: Description
        """
        encoder_output = self.encoder.encode(
            input_ids=batch.input_ids,
            segment_ids=batch.segment_ids,
            input_mask=batch.input_mask,
        )
        logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans)
        if compute_loss:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.head.num_labels),
                batch.label_id.view(-1),
            )
            return LogitsAndLossOutput(logits=logits,
                                       loss=loss,
                                       other=encoder_output.other)
        else:
            return LogitsOutput(logits=logits, other=encoder_output.other)
Example #4
 def forward(self,
             batch,
             task,
             tokenizer,
             compute_loss: bool = False,
             loss_weights: torch.Tensor = None):
     if loss_weights is not None:
         raise NotImplementedError()
     encoder_output = get_output_from_encoder_and_batch(
         encoder=self.encoder, batch=batch)
     logits = self.span_prediction_head(unpooled=encoder_output.unpooled)
     # Ensure logits at valid positions are at least self.offset_margin
     # higher than logits at any other position
     logits_offset = logits.max() - logits.min() + self.offset_margin
     logits = logits + logits_offset * batch.selection_token_mask.unsqueeze(dim=2)
     if compute_loss:
         loss_fct = nn.CrossEntropyLoss()
         loss = loss_fct(
             logits.transpose(dim0=1, dim1=2).flatten(end_dim=1),
             batch.gt_span_idxs.flatten(),
         )
         return LogitsAndLossOutput(logits=logits,
                                    loss=loss,
                                    other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
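
The offset trick can be verified in isolation; the tensors below are toy values chosen only for illustration.

import torch

# Adding (max - min + margin) to the logits at allowed positions guarantees
# every allowed position outscores every disallowed one, so argmax and
# softmax effectively ignore the masked-out tokens.
logits = torch.tensor([[[0.2], [1.5], [-0.7]]])   # (batch=1, seq=3, 1)
selection_token_mask = torch.tensor([[0.0, 1.0, 1.0]])
offset_margin = 1.0
logits_offset = logits.max() - logits.min() + offset_margin
shifted = logits + logits_offset * selection_token_mask.unsqueeze(dim=2)
assert shifted.squeeze(-1).argmax(dim=1).item() != 0  # never a masked position
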
Example #5
 def forward(self,
             batch,
             task,
             tokenizer,
             compute_loss: bool = False,
             loss_weights: torch.Tensor = None,
             get_encoder_output: bool = False):
     encoder_output = get_output_from_encoder_and_batch(
         encoder=self.encoder, batch=batch)
     # TODO: Abuse of notation - these aren't really logits  (issue #1187)
     logits = self.regression_head(pooled=encoder_output.pooled)
     if compute_loss:
         if loss_weights is not None:
             loss_fct = nn.MSELoss(reduction='none')
             loss = loss_fct(logits.view(-1), batch.label.view(-1))
             loss = (loss * loss_weights).mean()
         else:
             loss_fct = nn.MSELoss()
             loss = loss_fct(logits.view(-1), batch.label.view(-1))
         results = LogitsAndLossOutput(logits=logits,
                                       loss=loss,
                                       other=encoder_output.other)
     else:
         results = LogitsOutput(logits=logits, other=encoder_output.other)
     if get_encoder_output:
         return results, StaticEncoderOutput(encoder_output)
     return results
Example #6
    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
        with transformer_utils.output_hidden_states_context(self.encoder):
            encoder_output = get_output_from_encoder_and_batch(
                encoder=self.encoder, batch=batch)
        # A tuple of layers of hidden states
        hidden_states = take_one(encoder_output.other)
        layer_hidden_states = hidden_states[self.layer]

        if isinstance(self.pooler_head, heads.MeanPoolerHead):
            logits = self.pooler_head(unpooled=layer_hidden_states,
                                      input_mask=batch.input_mask)
        elif isinstance(self.pooler_head, heads.FirstPoolerHead):
            logits = self.pooler_head(layer_hidden_states)
        else:
            raise TypeError(type(self.pooler_head))

        # TODO: Abuse of notation - these aren't really logits  (issue #1187)
        if compute_loss:
            # TODO: make this optional?   (issue #1187)
            return LogitsAndLossOutput(
                logits=logits,
                loss=torch.tensor([0.0]),  # This is a horrible hack
                other=encoder_output.other,
            )
        else:
            return LogitsOutput(logits=logits, other=encoder_output.other)
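
The MeanPoolerHead branch presumably averages token vectors over non-padding positions. Since the head's code is not shown, the following is a hedged sketch of masked mean pooling under that assumption, not the head's actual implementation.

import torch

# Masked mean pooling: sum token vectors where input_mask == 1, then divide
# by the number of real (non-padding) tokens in each example.
layer_hidden_states = torch.randn(2, 4, 8)        # (batch, seq, hidden)
input_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).float()
summed = (layer_hidden_states * input_mask.unsqueeze(-1)).sum(dim=1)
pooled = summed / input_mask.sum(dim=1, keepdim=True)
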
Example #7
    def forward(self,
                batch,
                task,
                tokenizer,
                compute_loss: bool = False,
                loss_weights: torch.Tensor = None):
        input_ids = batch.input_ids
        segment_ids = batch.segment_ids
        input_mask = batch.input_mask

        choice_score_list = []
        encoder_outputs_other_ls = []
        for i in range(self.num_choices):
            encoder_output = get_output_from_encoder(
                encoder=self.encoder,
                input_ids=input_ids[:, i],
                segment_ids=segment_ids[:, i],
                input_mask=input_mask[:, i],
            )
            choice_score = self.choice_scoring_head(
                pooled=encoder_output.pooled)
            choice_score_list.append(choice_score)
            encoder_outputs_other_ls.append(encoder_output.other)

        reshaped_outputs = []
        if encoder_outputs_other_ls[0]:
            for j in range(len(encoder_outputs_other_ls[0])):
                reshaped_outputs.append([
                    torch.stack(
                        [misc[j][layer_i] for misc in encoder_outputs_other_ls],
                        dim=1,
                    )
                    for layer_i in range(len(encoder_outputs_other_ls[0][0]))
                ])
            reshaped_outputs = tuple(reshaped_outputs)

        logits = torch.cat(
            [score.unsqueeze(1).squeeze(-1) for score in choice_score_list],
            dim=1,
        )

        if compute_loss:
            if loss_weights is not None:
                loss_fct = nn.CrossEntropyLoss(reduction='none')
                loss = loss_fct(logits.view(-1, self.num_choices),
                                batch.label_id.view(-1))
                loss = (loss * loss_weights).mean()
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_choices),
                                batch.label_id.view(-1))
            return LogitsAndLossOutput(logits=logits,
                                       loss=loss,
                                       other=reshaped_outputs)
        else:
            return LogitsOutput(logits=logits, other=reshaped_outputs)
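
The logits assembly at the end is easy to sanity-check on its own; the shapes below are illustrative only.

import torch

# Each scoring-head call returns a (batch, 1) score. unsqueeze(1) followed
# by squeeze(-1) keeps the shape (batch, 1), and concatenating along dim=1
# yields (batch, num_choices), ready for cross-entropy over choices.
batch_size, num_choices = 2, 3
choice_score_list = [torch.randn(batch_size, 1) for _ in range(num_choices)]
logits = torch.cat(
    [score.unsqueeze(1).squeeze(-1) for score in choice_score_list], dim=1
)
assert logits.shape == (batch_size, num_choices)
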
Example #8
 def forward(self, batch, task, tokenizer, compute_loss: bool = False):
     encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
     # TODO: Abuse of notation - these aren't really logits  (Issue #45)
     logits = self.regression_head(pooled=encoder_output.pooled)
     if compute_loss:
         loss_fct = nn.MSELoss()
         loss = loss_fct(logits.view(-1), batch.label.view(-1))
         return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
Example #9
 def forward(self, batch, task, tokenizer, compute_loss: bool = False):
     encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
     logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans)
     if compute_loss:
         loss_fct = nn.BCEWithLogitsLoss()
         loss = loss_fct(
             logits.view(-1, self.span_comparison_head.num_labels), batch.label_ids.float(),
         )
         return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
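
Unlike the cross-entropy examples, this loss treats each label as an independent binary decision, which is why label_ids is cast to float. A self-contained illustration:

import torch
import torch.nn as nn

# BCEWithLogitsLoss expects float targets in {0.0, 1.0}, one per label,
# rather than a single class index per example.
num_labels = 4
logits = torch.randn(3, num_labels)
label_ids = torch.tensor([[1, 0, 0, 1], [0, 1, 0, 0], [1, 1, 0, 0]])
loss = nn.BCEWithLogitsLoss()(logits.view(-1, num_labels), label_ids.float())
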
Example #10
 def forward(self, batch, task, tokenizer, compute_loss: bool = False):
     encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
     logits = self.classification_head(pooled=encoder_output.pooled)
     if compute_loss:
         loss_fct = nn.CrossEntropyLoss()
         loss = loss_fct(
             logits.view(-1, self.classification_head.num_labels), batch.label_id.view(-1),
         )
         return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
Example #11
 def forward(self, batch, task, tokenizer, compute_loss: bool = False):
     encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
     logits = self.qa_head(unpooled=encoder_output.unpooled)
     if compute_loss:
         loss = compute_qa_loss(
             logits=logits,
             start_positions=batch.start_position,
             end_positions=batch.end_position,
         )
         return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
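
The body of compute_qa_loss is likewise not shown. A common span-extraction formulation, offered as an assumption about the helper rather than its actual code, splits the head output into start and end logits and averages two cross-entropy terms.

import torch
import torch.nn as nn

# Hedged sketch of a typical span-extraction QA loss: the head emits two
# scores per token (start, end); each side gets its own cross-entropy and
# the two losses are averaged.
seq_len = 8
logits = torch.randn(2, seq_len, 2)               # (batch, seq, 2)
start_logits, end_logits = logits.unbind(dim=-1)
start_position = torch.tensor([1, 3])
end_position = torch.tensor([2, 5])
ce = nn.CrossEntropyLoss()
loss = (ce(start_logits, start_position) + ce(end_logits, end_position)) / 2
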
Example #12
 def forward(self, batch, task, tokenizer, compute_loss: bool = False):
     masked_batch = batch.get_masked(
         mlm_probability=task.mlm_probability, tokenizer=tokenizer, do_mask=task.do_mask,
     )
     encoder_output = get_output_from_encoder(
         encoder=self.encoder,
         input_ids=masked_batch.masked_input_ids,
         segment_ids=masked_batch.segment_ids,
         input_mask=masked_batch.input_mask,
     )
     logits = self.mlm_head(unpooled=encoder_output.unpooled)
     if compute_loss:
         loss = compute_mlm_loss(logits=logits, masked_lm_labels=masked_batch.masked_lm_labels)
         return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
Example #13
 def forward(self, batch, tokenizer, compute_loss: bool = False):
     encoder_output = self.encoder.encode(
         input_ids=batch.input_ids,
         segment_ids=batch.segment_ids,
         input_mask=batch.input_mask,
     )
     # TODO: Abuse of notation - these aren't really logits  (issue #1187)
     logits = self.head(pooled=encoder_output.pooled)
     if compute_loss:
         loss_fct = nn.MSELoss()
         loss = loss_fct(logits.view(-1), batch.label.view(-1))
         return LogitsAndLossOutput(logits=logits,
                                    loss=loss,
                                    other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
Example #14
 def forward(self, batch, tokenizer, compute_loss: bool = False):
     encoder_output = self.encoder.encode(
         input_ids=batch.input_ids,
         segment_ids=batch.segment_ids,
         input_mask=batch.input_mask,
     )
     logits = self.head(pooled=encoder_output.pooled)
     if compute_loss:
         loss_fct = nn.CrossEntropyLoss()
         loss = loss_fct(
             logits.view(-1, self.head.num_labels),
             batch.label_id.view(-1),
         )
         return LogitsAndLossOutput(logits=logits,
                                    loss=loss,
                                    other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
Example #15
 def forward(self, batch, tokenizer, compute_loss: bool = False):
     encoder_output = self.encoder.encode(
         input_ids=batch.input_ids,
         segment_ids=batch.segment_ids,
         input_mask=batch.input_mask,
     )
     logits = self.head(unpooled=encoder_output.unpooled)
     if compute_loss:
         loss = compute_qa_loss(
             logits=logits,
             start_positions=batch.start_position,
             end_positions=batch.end_position,
         )
         return LogitsAndLossOutput(logits=logits,
                                    loss=loss,
                                    other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
Example #16
 def forward(self, batch, tokenizer, compute_loss: bool = False):
     encoder_output = self.encoder.encode(
         input_ids=batch.input_ids,
         segment_ids=batch.segment_ids,
         input_mask=batch.input_mask,
     )
     logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans)
     if compute_loss:
         loss_fct = nn.BCEWithLogitsLoss()
         loss = loss_fct(
             logits.view(-1, self.head.num_labels),
             batch.label_ids.float(),
         )
         return LogitsAndLossOutput(logits=logits,
                                    loss=loss,
                                    other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
Example #17
 def forward(self, batch, tokenizer, compute_loss: bool = False):
     encoder_output = self.encoder.encode(
         input_ids=batch.input_ids,
         segment_ids=batch.segment_ids,
         input_mask=batch.input_mask,
     )
     logits = self.head(unpooled=encoder_output.unpooled)
     # Ensure logits at valid positions are at least self.offset_margin
     # higher than logits at any other position
     logits_offset = logits.max() - logits.min() + self.offset_margin
     logits = logits + logits_offset * batch.selection_token_mask.unsqueeze(dim=2)
     if compute_loss:
         loss_fct = nn.CrossEntropyLoss()
         loss = loss_fct(
             logits.transpose(dim0=1, dim1=2).flatten(end_dim=1),
             batch.gt_span_idxs.flatten(),
         )
         return LogitsAndLossOutput(logits=logits,
                                    loss=loss,
                                    other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
Example #18
    def forward(self, batch, tokenizer, compute_loss: bool = False):
        choice_score_list = []
        encoder_output_other_ls = []
        for i in range(self.num_choices):
            encoder_output = self.encoder.encode(
                input_ids=batch.input_ids[:, i],
                segment_ids=batch.segment_ids[:, i],
                input_mask=batch.input_mask[:, i],
            )
            choice_score = self.head(pooled=encoder_output.pooled)
            choice_score_list.append(choice_score)
            encoder_output_other_ls.append(encoder_output.other)

        reshaped_outputs = []
        if encoder_output_other_ls[0]:
            for j in range(len(encoder_output_other_ls[0])):
                reshaped_outputs.append([
                    torch.stack(
                        [misc[j][layer_i] for misc in encoder_output_other_ls],
                        dim=1,
                    )
                    for layer_i in range(len(encoder_output_other_ls[0][0]))
                ])
            reshaped_outputs = tuple(reshaped_outputs)

        logits = torch.cat(
            [score.unsqueeze(1).squeeze(-1) for score in choice_score_list],
            dim=1,
        )

        if compute_loss:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_choices),
                            batch.label_id.view(-1))
            return LogitsAndLossOutput(logits=logits,
                                       loss=loss,
                                       other=reshaped_outputs)
        else:
            return LogitsOutput(logits=logits, other=reshaped_outputs)
Example #19
 def forward(self,
             batch,
             task,
             tokenizer,
             compute_loss: bool = False,
             loss_weights: torch.Tensor = None):
     if loss_weights is not None:
         raise NotImplementedError()
     encoder_output = get_output_from_encoder_and_batch(
         encoder=self.encoder, batch=batch)
     logits = self.span_comparison_head(unpooled=encoder_output.unpooled,
                                        spans=batch.spans)
     if compute_loss:
         loss_fct = nn.BCEWithLogitsLoss()
         loss = loss_fct(
             logits.view(-1, self.span_comparison_head.num_labels),
             batch.label_ids.float(),
         )
         return LogitsAndLossOutput(logits=logits,
                                    loss=loss,
                                    other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
Example #20
 def forward(self,
             batch,
             task,
             tokenizer,
             compute_loss: bool = False,
             loss_weights: torch.Tensor = None):
     if loss_weights is not None:
         raise NotImplementedError()
     encoder_output = get_output_from_encoder_and_batch(
         encoder=self.encoder, batch=batch)
     logits = self.token_classification_head(
         unpooled=encoder_output.unpooled)
     if compute_loss:
         loss_fct = nn.CrossEntropyLoss()
         active_loss = batch.label_mask.view(-1) == 1
         active_logits = logits.view(
             -1, self.token_classification_head.num_labels)[active_loss]
         active_labels = batch.label_ids.view(-1)[active_loss]
         loss = loss_fct(active_logits, active_labels)
         return LogitsAndLossOutput(logits=logits,
                                    loss=loss,
                                    other=encoder_output.other)
     else:
         return LogitsOutput(logits=logits, other=encoder_output.other)
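
The active_loss filtering above restricts the loss to real tokens. A self-contained illustration with toy shapes:

import torch
import torch.nn as nn

# Only positions where label_mask == 1 (real tokens rather than padding or
# special tokens) contribute to the token-classification loss.
num_labels = 5
logits = torch.randn(2, 4, num_labels)            # (batch, seq, num_labels)
label_ids = torch.randint(0, num_labels, (2, 4))
label_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
active_loss = label_mask.view(-1) == 1
loss = nn.CrossEntropyLoss()(
    logits.view(-1, num_labels)[active_loss],
    label_ids.view(-1)[active_loss],
)
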