import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import BartConfig
from transformers.models.bart.modeling_bart import (
    BartClassificationHead,
    BartModel,
    BartPretrainedModel,
)

# ZeroShotOutput is a project-specific ModelOutput subclass with the fields
# (loss, logits, anchor_vector, label_vectors, hidden_states); it is not part
# of the transformers library.


class BartMetricLearningModel(BartPretrainedModel):
    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )
        self.metric_hidden_size = 256
        self.metric_linear = nn.Linear(config.hidden_size,
                                       self.metric_hidden_size)
        # self.label_metric_linear = nn.Linear(config.hidden_size, self.metric_hidden_size)
        # self.predict_linear = nn.Linear(self.metric_hidden_size * 2, )
        self.scl_t = 1  # temperature of the supervised contrastive loss

        self.ce_p = 0.8    # weight of the cross-entropy term
        self.scl_p = 0.1   # weight of the supervised contrastive term
        self.lscl_p = 0.1  # weight of the label-side contrastive term (only used in the commented-out loss variant)
        self.ce_loss_fct = CrossEntropyLoss()

        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)

    def scl_func(self, anchor_vectors, labels):
        """
        <<SUPERVISED CONTRASTIVE LEARNING FOR PRE-TRAINED LANGUAGE MODEL FINE-TUNING>>
        :param anchor_vector: batch_size * hidden_size
        :param labels:
        :return:
        """

        total_losses = 0
        anchor_vectors = anchor_vectors.squeeze(dim=1)
        for i in range(anchor_vectors.shape[0]):
            anchor_vector = anchor_vectors[i, :]
            # other_index = torch.from_numpy(np.tile(np.array(list(filter(lambda x: x != i, range(anchor_vectors.shape[0])))),
            #                                        anchor_vectors.shape[1]).reshape(anchor_vectors.shape[1], -1))
            # other_vectors = torch.gather(anchor_vectors.transpose(1, 0), dim=1, index=other_index).transpose(1, 0)

            # all other vectors in the batch with row i removed (detached, so the
            # denominator does not receive gradient)
            other_vectors = torch.cat(
                [anchor_vectors[:i], anchor_vectors[i + 1:]], dim=0).detach()
            same_labels = torch.where(labels == labels[i])
            same_label_vectors = anchor_vectors[same_labels]
            if same_label_vectors.shape[0] > 0:
                up = torch.exp(
                    torch.cosine_similarity(same_label_vectors,
                                            anchor_vector.unsqueeze(0)) /
                    self.scl_t)
                down = torch.sum(
                    torch.exp(
                        torch.cosine_similarity(other_vectors,
                                                anchor_vector.unsqueeze(0)) /
                        self.scl_t))
                single_sample_loss = torch.sum(torch.log(
                    up / down)) / -(anchor_vectors.shape[0] - 1)
                total_losses += single_sample_loss
        return total_losses

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        label_positions=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        label_max_position = torch.max(label_positions[-1]).tolist()  # computed but not used below
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        if input_ids is None and inputs_embeds is not None:
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
            )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]  # last hidden state

        # As in BartForSequenceClassification, use the decoder state at the
        # final <eos> token of each sequence as the sentence representation.
        eos_mask = input_ids.eq(self.config.eos_token_id)

        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError(
                "All examples must have the same number of <eos> tokens.")
        sentence_representation = sequence_output[eos_mask, :].view(
            sequence_output.size(0), -1, sequence_output.size(-1))[:, -1, :]

        anchor_vector = sentence_representation.unsqueeze(dim=1)

        # Pool the decoder states at each candidate label's token positions
        # into one vector per label (positions are shared across the batch).
        label_vectors = None
        for positions in label_positions:
            position = positions[0]
            label_vector = sequence_output[:, position, :]
            label_vector = torch.mean(label_vector, dim=1).unsqueeze(dim=1)
            if label_vectors is None:
                label_vectors = label_vector
            else:
                label_vectors = torch.cat([label_vectors, label_vector], dim=1)

        anchor_vector = self.metric_linear(anchor_vector)
        label_vectors = self.metric_linear(label_vectors)
        logits = torch.cosine_similarity(label_vectors, anchor_vector, dim=2)

        loss = None
        if labels is not None:
            ce_loss = self.ce_loss_fct(logits, labels)
            scl_loss = self.scl_func(anchor_vector.squeeze(dim=1), labels) / 10
            # true_label_vectors = label_vectors[range(len(labels)), labels, :]
            # scl_label_loss = self.scl_func(true_label_vectors, labels) / 10

            # center_loss = self.center_loss_fct(anchor_vector, labels)
            # label_distance_loss = self.label_distance_loss_fct(label_vectors)

            loss = ce_loss * self.ce_p + scl_loss * self.scl_p
            # loss = ce_loss * self.ce_p + scl_loss * self.scl_p + scl_label_loss * self.lscl_p
            # loss = ce_loss

        if not return_dict:
            output = (logits, ) + outputs[2:]
            return ((loss, ) + output) if loss is not None else output
        return ZeroShotOutput(loss=loss,
                              logits=logits,
                              anchor_vector=anchor_vector,
                              label_vectors=label_vectors,
                              hidden_states=sequence_output)
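
# Usage sketch (illustrative, not from the original source). It assumes
# `label_positions` is a list with one entry per candidate label, each entry a
# LongTensor of shape (1, num_label_tokens) holding the token positions of that
# label inside every (identically templated) input sequence; the tokenizer name,
# label texts, and positions below are placeholders.
def _demo_metric_learning_forward():
    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    config = BartConfig.from_pretrained("facebook/bart-base", num_labels=3)
    model = BartMetricLearningModel(config)
    model.eval()

    texts = [
        "sports business science . the team won the championship",
        "sports business science . the stock market fell sharply",
    ]
    enc = tokenizer(texts, return_tensors="pt", padding=True)
    # token positions of each verbalised label inside the shared template
    # (illustrative positions; compute them from the tokenized template in practice)
    label_positions = [
        torch.tensor([[1]]),  # "sports"
        torch.tensor([[2]]),  # "business"
        torch.tensor([[3]]),  # "science"
    ]
    out = model(input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                labels=torch.tensor([0, 1]),
                label_positions=label_positions,
                return_dict=True)
    return out.logits  # (batch_size, num_labels) cosine similarities
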
Example #2

from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import nn
from transformers import BartConfig
# This example targets the pre-4.x transformers Bart API
# (PretrainedBartModel, config.classif_dropout, decoder_cached_states).
from transformers.modeling_bart import (
    BartClassificationHead,
    BartForConditionalGeneration,
    BartForSequenceClassification,
    BartModel,
    PretrainedBartModel,
)

class BartSumRank(
        BartForConditionalGeneration,
        BartForSequenceClassification  # type: ignore
):
    def __init__(self, config: BartConfig, **kwargs: Any):
        """The classification init is a super set of LM init"""
        PretrainedBartModel.__init__(self, config, **kwargs)
        self.model = BartModel(config)

        self.classification_head = BartClassificationHead(
            config.d_model, config.d_model, config.num_labels,
            config.classif_dropout)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)
        self.model._init_weights(self.lm_head)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_cached_states: Optional[torch.Tensor] = None,
        lm_labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        rank_labels: Optional[torch.Tensor] = None,
        mode: str = "summarizer",
        **kwargs: Any,
    ) -> Any:
        """Versatile forward interface. By default it should behaves as an LM head so it's
           compatible with the `generate()` interface.

        lm_batch_mask: Used when the input_ids contain negative documents which are not
                       used for LM.

        rank_labels: Labels for ranking.

        """

        model_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            decoder_cached_states=decoder_cached_states,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if mode == "summarizer":
            lm_hidden = model_outputs[0]

            # LM head
            lm_logits = self.lm_head(lm_hidden)
            if lm_labels is not None:
                lm_loss = F.cross_entropy(
                    lm_logits.view(-1, self.config.vocab_size),
                    lm_labels.reshape(-1))
                outputs = (lm_loss, lm_logits) + model_outputs[1:]
            else:
                outputs = (lm_logits, ) + model_outputs[1:]
            return outputs
        elif mode == "ranker":
            # Rank head
            rank_hidden = model_outputs[0]  # last hidden state
            bsz_idx = list(range(rank_hidden.size(0)))
            if decoder_attention_mask is not None:
                next_token_idx = decoder_attention_mask.sum(dim=1) - 1
            else:
                assert attention_mask is not None
                next_token_idx = attention_mask.sum(dim=1) - 1

            # Use next word prediction as sentence representation
            sentence_representation = rank_hidden[bsz_idx, next_token_idx]
            rank_logits = self.classification_head(sentence_representation)
            if rank_labels is not None:
                loss = F.cross_entropy(
                    rank_logits.view(-1, self.config.num_labels),
                    rank_labels.view(-1))
                outputs = (loss, rank_logits) + model_outputs[1:]
            else:
                outputs = (rank_logits, ) + model_outputs[1:]
            return outputs
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def shared_grads(self) -> Optional[torch.Tensor]:
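        """Return the gradients of the shared ``BartModel`` parameters, flattened
        and concatenated into a single CPU tensor (``None`` if no parameter has a
        gradient yet)."""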
        grads_list = []
        for name, params in self.model.named_parameters():
            if params.requires_grad and params.grad is not None:
                grads_list.append(params.grad.flatten().cpu())
        if not grads_list:
            return None
        grads = torch.cat(grads_list)
        return grads
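

# Usage sketch (illustrative, not from the original source): one BartSumRank
# instance is driven in its two modes, "summarizer" (LM loss) and "ranker"
# (classification loss), and shared_grads() is used to compare the gradients
# the two objectives induce on the shared BartModel. The `lm_batch` and
# `rank_batch` dicts are hypothetical placeholders.
def _demo_sumrank_step(model: "BartSumRank",
                       lm_batch: dict,
                       rank_batch: dict) -> None:
    # language-modelling step
    lm_loss = model(input_ids=lm_batch["input_ids"],
                    attention_mask=lm_batch["attention_mask"],
                    decoder_input_ids=lm_batch["decoder_input_ids"],
                    lm_labels=lm_batch["lm_labels"],
                    mode="summarizer")[0]
    lm_loss.backward()
    lm_grads = model.shared_grads()
    model.zero_grad()

    # ranking step
    rank_loss = model(input_ids=rank_batch["input_ids"],
                      attention_mask=rank_batch["attention_mask"],
                      rank_labels=rank_batch["rank_labels"],
                      mode="ranker")[0]
    rank_loss.backward()
    rank_grads = model.shared_grads()
    model.zero_grad()

    # cosine similarity between the two task gradients on the shared weights
    if lm_grads is not None and rank_grads is not None:
        print(torch.cosine_similarity(lm_grads, rank_grads, dim=0).item())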