from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers import (
    BartConfig,
    BartForConditionalGeneration,
    BartForSequenceClassification,
)
# Import paths below follow the transformers 4.x layout; older releases expose
# these classes under `transformers.modeling_bart` instead.
from transformers.models.bart.modeling_bart import (
    BartClassificationHead,
    BartModel,
    BartPretrainedModel,
    PretrainedBartModel,
)

# `ZeroShotOutput` is a project-specific ModelOutput dataclass (fields: loss, logits,
# anchor_vector, label_vectors, hidden_states) defined elsewhere in this repository.


class BartMetricLearningModel(BartPretrainedModel):

    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )
        # Projection into the metric space used for cosine-similarity scoring.
        self.metric_hidden_size = 256
        self.metric_linear = nn.Linear(config.hidden_size, self.metric_hidden_size)
        # self.label_metric_linear = nn.Linear(config.hidden_size, self.metric_hidden_size)
        # self.predict_linear = nn.Linear(self.metric_hidden_size * 2, )

        # Temperature for the supervised contrastive loss and weights for the
        # cross-entropy, contrastive, and label-contrastive loss terms.
        self.scl_t = 1
        self.ce_p = 0.8
        self.scl_p = 0.1
        self.lscl_p = 0.1
        self.ce_loss_fct = CrossEntropyLoss()

        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)

    def scl_func(self, anchor_vectors, labels):
        """
        Supervised contrastive loss, following
        "Supervised Contrastive Learning for Pre-trained Language Model Fine-tuning".

        :param anchor_vectors: batch_size * hidden_size
        :param labels: batch_size
        :return: loss summed over the batch
        """
        total_losses = 0
        anchor_vectors = anchor_vectors.squeeze(dim=1)
        for i in range(anchor_vectors.shape[0]):
            anchor_vector = anchor_vectors[i, :]
            # other_index = torch.from_numpy(np.tile(np.array(list(filter(lambda x: x != i, range(anchor_vectors.shape[0])))),
            #                                        anchor_vectors.shape[1]).reshape(anchor_vectors.shape[1], -1))
            # other_vectors = torch.gather(anchor_vectors.transpose(1, 0), dim=1, index=other_index).transpose(1, 0)
            # All batch samples except the anchor itself form the denominator set.
            other_vectors = torch.cat(
                [anchor_vectors[:i], anchor_vectors[i + 1:]], dim=0)
            same_labels = torch.where(labels == labels[i])
            same_label_vectors = anchor_vectors[same_labels]
            if same_label_vectors.shape[0] > 0:
                up = torch.exp(
                    torch.cosine_similarity(same_label_vectors,
                                            anchor_vector.unsqueeze(0)) / self.scl_t)
                down = torch.sum(
                    torch.exp(
                        torch.cosine_similarity(other_vectors,
                                                anchor_vector.unsqueeze(0)) / self.scl_t))
                single_sample_loss = torch.sum(
                    torch.log(up / down)) / -(anchor_vectors.shape[0] - 1)
                total_losses += single_sample_loss
        return total_losses

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        label_positions=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in
            :obj:`[0, ..., config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification
            loss is computed (Cross-Entropy).
""" label_max_position = torch.max(label_positions[-1]).tolist() return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: use_cache = False if input_ids is None and inputs_embeds is not None: raise NotImplementedError( f"Passing input embeddings is currently not supported for {self.__class__.__name__}" ) outputs = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] # last hidden state eos_mask = input_ids.eq(self.config.eos_token_id) if len(torch.unique(eos_mask.sum(1))) > 1: raise ValueError( "All examples must have the same number of <eos> tokens.") sentence_representation = sequence_output[eos_mask, :].view( sequence_output.size(0), -1, sequence_output.size(-1))[:, -1, :] anchor_vector = sentence_representation.unsqueeze(dim=1) label_vectors = None for positions in label_positions: position = positions[0] label_vector = sequence_output[:, position, :] label_vector = torch.mean(label_vector, dim=1).unsqueeze(dim=1) if label_vectors is None: label_vectors = label_vector else: label_vectors = torch.cat([label_vectors, label_vector], dim=1) anchor_vector = self.metric_linear(anchor_vector) label_vectors = self.metric_linear(label_vectors) logits = torch.cosine_similarity(label_vectors, anchor_vector, dim=2) loss = None if labels is not None: ce_loss = self.ce_loss_fct(logits, labels) scl_loss = self.scl_func(anchor_vector.squeeze(dim=1), labels) / 10 # true_label_vectors = label_vectors[range(len(labels)), labels, :] # scl_label_loss = self.scl_func(true_label_vectors, labels) / 10 # center_loss = self.center_loss_fct(anchor_vector, labels) # label_distance_loss = self.label_distance_loss_fct(label_vectors) loss = ce_loss * self.ce_p + scl_loss * self.scl_p # loss = ce_loss * self.ce_p + scl_loss * self.scl_p + scl_label_loss * self.lscl_p # loss = ce_loss if not return_dict: output = (logits, ) + outputs[2:] return ((loss, ) + output) if loss is not None else output return ZeroShotOutput(loss=loss, logits=logits, anchor_vector=anchor_vector, label_vectors=label_vectors, hidden_states=sequence_output)
class BartSumRank(
        BartForConditionalGeneration,
        BartForSequenceClassification,  # type: ignore
):

    def __init__(self, config: BartConfig, **kwargs: Any):
        """The classification init is a superset of the LM init."""
        PretrainedBartModel.__init__(self, config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model, config.d_model, config.num_labels,
            config.classif_dropout)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)
        self.model._init_weights(self.lm_head)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_cached_states: Optional[torch.Tensor] = None,
        lm_labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        rank_labels: Optional[torch.Tensor] = None,
        mode: str = "summarizer",
        **kwargs: Any,
    ) -> Any:
        """Versatile forward interface.

        By default it behaves as an LM head, so it is compatible with the
        `generate()` interface.

        lm_batch_mask: Used when the input_ids contain negative documents
            which are not used for LM.
        rank_labels: Labels for ranking.
        """
        model_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            decoder_cached_states=decoder_cached_states,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if mode == "summarizer":
            lm_hidden = model_outputs[0]
            # LM head
            lm_logits = self.lm_head(lm_hidden)
            if lm_labels is not None:
                lm_loss = F.cross_entropy(
                    lm_logits.view(-1, self.config.vocab_size),
                    lm_labels.reshape(-1))
                outputs = (lm_loss, lm_logits) + model_outputs[1:]
            else:
                outputs = (lm_logits, ) + model_outputs[1:]
            return outputs
        elif mode == "ranker":
            # Rank head
            rank_hidden = model_outputs[0]  # last hidden state
            bsz_idx = list(range(rank_hidden.size(0)))
            if decoder_attention_mask is not None:
                next_token_idx = decoder_attention_mask.sum(dim=1) - 1
            else:
                assert attention_mask is not None
                next_token_idx = attention_mask.sum(dim=1) - 1
            # Use the next-word-prediction position as the sentence representation.
            sentence_representation = rank_hidden[bsz_idx, next_token_idx]
            rank_logits = self.classification_head(sentence_representation)
            if rank_labels is not None:
                loss = F.cross_entropy(
                    rank_logits.view(-1, self.config.num_labels),
                    rank_labels.view(-1))
                outputs = (loss, rank_logits) + model_outputs[1:]
            else:
                outputs = (rank_logits, ) + model_outputs[1:]
            return outputs
        else:
            assert False, f"Unknown mode {mode}"

    def shared_grads(self) -> Optional[torch.Tensor]:
        """Flatten and concatenate the gradients of all shared encoder-decoder parameters."""
        grads_list = []
        for name, params in self.model.named_parameters():
            if params.requires_grad and params.grad is not None:
                grads_list.append(params.grad.flatten().cpu())
        if not grads_list:
            return None
        grads = torch.cat(grads_list)
        return grads
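
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original code): one way the two heads
# of BartSumRank and `shared_grads()` could be combined to inspect gradient
# agreement between the summarization and ranking tasks on the shared
# encoder-decoder. The batch layout (`sum_batch`, `rank_batch`) and the idea
# of comparing task gradients by cosine similarity are assumptions made for
# this example, not something the class itself prescribes. Each batch dict is
# expected to contain the usual tensors accepted by `forward` for its mode
# (e.g. input_ids, attention_mask, decoder_input_ids, lm_labels / rank_labels).
# ---------------------------------------------------------------------------
def task_gradient_cosine(model: "BartSumRank",
                         sum_batch: dict,
                         rank_batch: dict) -> Optional[torch.Tensor]:
    """Cosine similarity between shared-parameter gradients of one summarizer
    step and one ranker step."""
    # Summarizer step: the first output is the LM loss when lm_labels is given.
    model.zero_grad()
    sum_loss = model(mode="summarizer", **sum_batch)[0]
    sum_loss.backward()
    sum_grads = model.shared_grads()

    # Ranker step: the first output is the ranking loss when rank_labels is given.
    model.zero_grad()
    rank_loss = model(mode="ranker", **rank_batch)[0]
    rank_loss.backward()
    rank_grads = model.shared_grads()

    model.zero_grad()
    if sum_grads is None or rank_grads is None:
        return None
    return torch.cosine_similarity(sum_grads, rank_grads, dim=0)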