import torch
import torch.nn as nn

# Helpers such as get_output_from_encoder, get_output_from_encoder_and_batch,
# compute_mlm_loss, compute_qa_loss, LogitsOutput, and LogitsAndLossOutput are
# assumed to be in scope from the surrounding module.


def forward(self, batch, task, tokenizer, compute_loss: bool = False,
            loss_weights: torch.Tensor = None):
    if loss_weights is not None:
        raise NotImplementedError()
    masked_batch = batch.get_masked(
        mlm_probability=task.mlm_probability,
        tokenizer=tokenizer,
        do_mask=task.do_mask,
    )
    encoder_output = get_output_from_encoder(
        encoder=self.encoder,
        input_ids=masked_batch.masked_input_ids,
        segment_ids=masked_batch.segment_ids,
        input_mask=masked_batch.input_mask,
    )
    logits = self.mlm_head(unpooled=encoder_output.unpooled)
    if compute_loss:
        loss = compute_mlm_loss(logits=logits, masked_lm_labels=masked_batch.masked_lm_labels)
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, task, tokenizer, compute_loss: bool = False,
            loss_weights: torch.Tensor = None, get_encoder_output: bool = False):
    encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
    logits = self.classification_head(pooled=encoder_output.pooled)
    if compute_loss:
        if loss_weights is not None:
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            loss = loss_fct(
                logits.view(-1, self.classification_head.num_labels),
                batch.label_id.view(-1),
            )
            loss = (loss * loss_weights).mean()
        else:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.classification_head.num_labels),
                batch.label_id.view(-1),
            )
        results = LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        results = LogitsOutput(logits=logits, other=encoder_output.other)
    return (results, StaticEncoderOutput(encoder_output)) if get_encoder_output else results

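# A minimal, self-contained sketch (toy tensors, not jiant's API) of the
# loss_weights branch above: with reduction="none", CrossEntropyLoss returns
# one loss value per example, which is rescaled per example before averaging.
def _weighted_ce_sketch():
    toy_logits = torch.randn(4, 3)                    # [batch=4, num_labels=3]
    toy_labels = torch.tensor([0, 2, 1, 2])           # one gold label per example
    toy_weights = torch.tensor([1.0, 0.5, 2.0, 1.0])  # per-example loss weights
    per_example = nn.CrossEntropyLoss(reduction="none")(toy_logits, toy_labels)  # shape [4]
    return (per_example * toy_weights).mean()         # scalar, as in the branch above
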
def forward(self, batch, tokenizer, compute_loss: bool = False): """Summary Args: batch (TYPE): Description tokenizer (TYPE): Description compute_loss (bool, optional): Description Returns: TYPE: Description """ encoder_output = self.encoder.encode( input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask, ) logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans) if compute_loss: loss_fct = nn.CrossEntropyLoss() loss = loss_fct( logits.view(-1, self.head.num_labels), batch.label_id.view(-1), ) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) else: return LogitsOutput(logits=logits, other=encoder_output.other)
def forward(self, batch, task, tokenizer, compute_loss: bool = False,
            loss_weights: torch.Tensor = None):
    if loss_weights is not None:
        raise NotImplementedError()
    encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
    logits = self.span_prediction_head(unpooled=encoder_output.unpooled)
    # Ensure logits in the valid range are at least self.offset_margin higher than
    # the others, so argmax can only select positions allowed by selection_token_mask.
    logits_offset = logits.max() - logits.min() + self.offset_margin
    logits = logits + logits_offset * batch.selection_token_mask.unsqueeze(dim=2)
    if compute_loss:
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            logits.transpose(dim0=1, dim1=2).flatten(end_dim=1),
            batch.gt_span_idxs.flatten(),
        )
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

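# A toy sketch (hypothetical shapes and margin) of the masking trick above:
# adding an offset larger than the current logit range to every position the
# mask allows guarantees that argmax can only land inside the mask, without
# pushing disallowed logits to -inf.
def _offset_mask_sketch():
    logits = torch.randn(2, 5, 2)  # [batch, seq_len, start/end scores]
    selection_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [0, 1, 1, 1, 0]])  # 1 = selectable token
    offset = logits.max() - logits.min() + 1.0        # 1.0 stands in for offset_margin
    shifted = logits + offset * selection_mask.unsqueeze(dim=2)
    return shifted.argmax(dim=1)  # predicted indices always fall inside the mask
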
def forward(self, batch, task, tokenizer, compute_loss: bool = False,
            loss_weights: torch.Tensor = None, get_encoder_output: bool = False):
    encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
    # TODO: Abuse of notation - these aren't really logits (issue #1187)
    logits = self.regression_head(pooled=encoder_output.pooled)
    if compute_loss:
        if loss_weights is not None:
            loss_fct = nn.MSELoss(reduction="none")
            loss = loss_fct(logits.view(-1), batch.label.view(-1))
            loss = (loss * loss_weights).mean()
        else:
            loss_fct = nn.MSELoss()
            loss = loss_fct(logits.view(-1), batch.label.view(-1))
        results = LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        results = LogitsOutput(logits=logits, other=encoder_output.other)
    return (results, StaticEncoderOutput(encoder_output)) if get_encoder_output else results

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
    with transformer_utils.output_hidden_states_context(self.encoder):
        encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
    # A tuple of layers of hidden states
    hidden_states = take_one(encoder_output.other)
    layer_hidden_states = hidden_states[self.layer]

    if isinstance(self.pooler_head, heads.MeanPoolerHead):
        logits = self.pooler_head(unpooled=layer_hidden_states, input_mask=batch.input_mask)
    elif isinstance(self.pooler_head, heads.FirstPoolerHead):
        logits = self.pooler_head(layer_hidden_states)
    else:
        raise TypeError(type(self.pooler_head))

    # TODO: Abuse of notation - these aren't really logits (issue #1187)
    if compute_loss:
        # TODO: make this optional? (issue #1187)
        return LogitsAndLossOutput(
            logits=logits,
            loss=torch.tensor([0.0]),  # This is a horrible hack
            other=encoder_output.other,
        )
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, task, tokenizer, compute_loss: bool = False,
            loss_weights: torch.Tensor = None):
    input_ids = batch.input_ids
    segment_ids = batch.segment_ids
    input_mask = batch.input_mask

    choice_score_list = []
    encoder_outputs_other_ls = []
    for i in range(self.num_choices):
        encoder_output = get_output_from_encoder(
            encoder=self.encoder,
            input_ids=input_ids[:, i],
            segment_ids=segment_ids[:, i],
            input_mask=input_mask[:, i],
        )
        choice_score = self.choice_scoring_head(pooled=encoder_output.pooled)
        choice_score_list.append(choice_score)
        encoder_outputs_other_ls.append(encoder_output.other)

    reshaped_outputs = []
    if encoder_outputs_other_ls[0]:
        for j in range(len(encoder_outputs_other_ls[0])):
            reshaped_outputs.append(
                [
                    torch.stack([misc[j][layer_i] for misc in encoder_outputs_other_ls], dim=1)
                    for layer_i in range(len(encoder_outputs_other_ls[0][0]))
                ]
            )
        reshaped_outputs = tuple(reshaped_outputs)

    logits = torch.cat(
        [choice_score.unsqueeze(1).squeeze(-1) for choice_score in choice_score_list], dim=1
    )
    if compute_loss:
        if loss_weights is not None:
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            loss = loss_fct(logits.view(-1, self.num_choices), batch.label_id.view(-1))
            loss = (loss * loss_weights).mean()
        else:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_choices), batch.label_id.view(-1))
        return LogitsAndLossOutput(logits=logits, loss=loss, other=reshaped_outputs)
    else:
        return LogitsOutput(logits=logits, other=reshaped_outputs)

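# A toy illustration (made-up sizes) of the reshaping loop above: each
# per-choice encoder pass yields a tuple of auxiliary output groups, each
# group a sequence of per-layer tensors; stacking the choices on dim=1 turns
# every layer tensor into [batch, num_choices, ...].
def _reshape_choices_sketch(num_choices=3, num_layers=2, batch=4, hidden=8):
    per_choice_other = [
        ([torch.randn(batch, hidden) for _ in range(num_layers)],)  # one group per choice
        for _ in range(num_choices)
    ]
    stacked = [
        torch.stack([misc[0][layer_i] for misc in per_choice_other], dim=1)
        for layer_i in range(num_layers)
    ]
    return stacked  # num_layers tensors, each [batch, num_choices, hidden]
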
def forward(self, batch, task, tokenizer, compute_loss: bool = False):
    encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
    # TODO: Abuse of notation - these aren't really logits (Issue #45)
    logits = self.regression_head(pooled=encoder_output.pooled)
    if compute_loss:
        loss_fct = nn.MSELoss()
        loss = loss_fct(logits.view(-1), batch.label.view(-1))
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
    encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
    logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans)
    if compute_loss:
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(
            logits.view(-1, self.span_comparison_head.num_labels),
            batch.label_ids.float(),
        )
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
    encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
    logits = self.classification_head(pooled=encoder_output.pooled)
    if compute_loss:
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            logits.view(-1, self.classification_head.num_labels),
            batch.label_id.view(-1),
        )
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
    encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
    logits = self.qa_head(unpooled=encoder_output.unpooled)
    if compute_loss:
        loss = compute_qa_loss(
            logits=logits,
            start_positions=batch.start_position,
            end_positions=batch.end_position,
        )
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
    masked_batch = batch.get_masked(
        mlm_probability=task.mlm_probability,
        tokenizer=tokenizer,
        do_mask=task.do_mask,
    )
    encoder_output = get_output_from_encoder(
        encoder=self.encoder,
        input_ids=masked_batch.masked_input_ids,
        segment_ids=masked_batch.segment_ids,
        input_mask=masked_batch.input_mask,
    )
    logits = self.mlm_head(unpooled=encoder_output.unpooled)
    if compute_loss:
        loss = compute_mlm_loss(logits=logits, masked_lm_labels=masked_batch.masked_lm_labels)
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

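# A hedged sketch of what an MLM loss helper like compute_mlm_loss typically
# computes (an assumption about its internals, not the actual library code):
# token-level cross-entropy over the vocabulary, where non-masked positions
# carry the ignore_index label (-100 by convention) and contribute no loss.
def _mlm_loss_sketch(logits, masked_lm_labels):
    vocab_size = logits.size(-1)
    loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
    return loss_fct(logits.view(-1, vocab_size), masked_lm_labels.view(-1))
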
def forward(self, batch, tokenizer, compute_loss: bool = False):
    encoder_output = self.encoder.encode(
        input_ids=batch.input_ids,
        segment_ids=batch.segment_ids,
        input_mask=batch.input_mask,
    )
    # TODO: Abuse of notation - these aren't really logits (issue #1187)
    logits = self.head(pooled=encoder_output.pooled)
    if compute_loss:
        loss_fct = nn.MSELoss()
        loss = loss_fct(logits.view(-1), batch.label.view(-1))
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, tokenizer, compute_loss: bool = False):
    encoder_output = self.encoder.encode(
        input_ids=batch.input_ids,
        segment_ids=batch.segment_ids,
        input_mask=batch.input_mask,
    )
    logits = self.head(pooled=encoder_output.pooled)
    if compute_loss:
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            logits.view(-1, self.head.num_labels),
            batch.label_id.view(-1),
        )
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, tokenizer, compute_loss: bool = False):
    encoder_output = self.encoder.encode(
        input_ids=batch.input_ids,
        segment_ids=batch.segment_ids,
        input_mask=batch.input_mask,
    )
    logits = self.head(unpooled=encoder_output.unpooled)
    if compute_loss:
        loss = compute_qa_loss(
            logits=logits,
            start_positions=batch.start_position,
            end_positions=batch.end_position,
        )
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, tokenizer, compute_loss: bool = False):
    encoder_output = self.encoder.encode(
        input_ids=batch.input_ids,
        segment_ids=batch.segment_ids,
        input_mask=batch.input_mask,
    )
    logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans)
    if compute_loss:
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(
            logits.view(-1, self.head.num_labels),
            batch.label_ids.float(),
        )
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, tokenizer, compute_loss: bool = False):
    encoder_output = self.encoder.encode(
        input_ids=batch.input_ids,
        segment_ids=batch.segment_ids,
        input_mask=batch.input_mask,
    )
    logits = self.head(unpooled=encoder_output.unpooled)
    # Ensure logits in the valid range are at least self.offset_margin higher than
    # the others, so argmax can only select positions allowed by selection_token_mask.
    logits_offset = logits.max() - logits.min() + self.offset_margin
    logits = logits + logits_offset * batch.selection_token_mask.unsqueeze(dim=2)
    if compute_loss:
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            logits.transpose(dim0=1, dim1=2).flatten(end_dim=1),
            batch.gt_span_idxs.flatten(),
        )
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, tokenizer, compute_loss: bool = False):
    choice_score_list = []
    encoder_output_other_ls = []
    for i in range(self.num_choices):
        encoder_output = self.encoder.encode(
            input_ids=batch.input_ids[:, i],
            segment_ids=batch.segment_ids[:, i],
            input_mask=batch.input_mask[:, i],
        )
        choice_score = self.head(pooled=encoder_output.pooled)
        choice_score_list.append(choice_score)
        encoder_output_other_ls.append(encoder_output.other)

    reshaped_outputs = []
    if encoder_output_other_ls[0]:
        for j in range(len(encoder_output_other_ls[0])):
            reshaped_outputs.append(
                [
                    torch.stack([misc[j][layer_i] for misc in encoder_output_other_ls], dim=1)
                    for layer_i in range(len(encoder_output_other_ls[0][0]))
                ]
            )
        reshaped_outputs = tuple(reshaped_outputs)

    logits = torch.cat(
        [choice_score.unsqueeze(1).squeeze(-1) for choice_score in choice_score_list], dim=1
    )
    if compute_loss:
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_choices), batch.label_id.view(-1))
        return LogitsAndLossOutput(logits=logits, loss=loss, other=reshaped_outputs)
    else:
        return LogitsOutput(logits=logits, other=reshaped_outputs)

def forward(self, batch, task, tokenizer, compute_loss: bool = False,
            loss_weights: torch.Tensor = None):
    if loss_weights is not None:
        raise NotImplementedError()
    encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
    logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans)
    if compute_loss:
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(
            logits.view(-1, self.span_comparison_head.num_labels),
            batch.label_ids.float(),
        )
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

def forward(self, batch, task, tokenizer, compute_loss: bool = False,
            loss_weights: torch.Tensor = None):
    if loss_weights is not None:
        raise NotImplementedError()
    encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
    logits = self.token_classification_head(unpooled=encoder_output.unpooled)
    if compute_loss:
        loss_fct = nn.CrossEntropyLoss()
        active_loss = batch.label_mask.view(-1) == 1
        active_logits = logits.view(-1, self.token_classification_head.num_labels)[active_loss]
        active_labels = batch.label_ids.view(-1)[active_loss]
        loss = loss_fct(active_logits, active_labels)
        return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
    else:
        return LogitsOutput(logits=logits, other=encoder_output.other)

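# A toy demonstration (made-up shapes) of the label_mask filtering above: only
# positions whose mask is 1 enter the cross-entropy, so padding and special
# tokens never contribute to the token-classification loss.
def _active_loss_sketch(num_labels=5):
    logits = torch.randn(2, 4, num_labels)   # [batch, seq_len, num_labels]
    label_ids = torch.randint(0, num_labels, (2, 4))
    label_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 0]])  # 0 = padding/special token
    active = label_mask.view(-1) == 1
    active_logits = logits.view(-1, num_labels)[active]
    active_labels = label_ids.view(-1)[active]
    return nn.CrossEntropyLoss()(active_logits, active_labels)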