def forward(
        self,  # type: ignore
        premises: Dict[str, torch.LongTensor],
        hypotheses: Dict[str, torch.LongTensor],
        paragraph: Dict[str, torch.LongTensor],
        answer_index: torch.LongTensor = None,
        relevance_presence_mask: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        hypothesis_list = unbind_tensor_dict(hypotheses, dim=1)

        label_logits = []
        premises_attentions = []
        premises_aggregation_attentions = []
        coverage_losses = []
        for hypothesis in hypothesis_list:
            output_dict = super().forward(premises=premises,
                                          hypothesis=hypothesis,
                                          paragraph=paragraph)
            # Keep only the entailment logit; the other keys in
            # output_dict are not needed here.
            individual_logit = output_dict["label_logits"][
                :, self._label2idx["entailment"]]
            label_logits.append(individual_logit)

            premises_attention = output_dict.get("premises_attention", None)
            premises_attentions.append(premises_attention)
            premises_aggregation_attention = output_dict.get(
                "premises_aggregation_attention", None)
            premises_aggregation_attentions.append(
                premises_aggregation_attention)
            if relevance_presence_mask is not None:
                coverage_loss = output_dict["coverage_loss"]
                coverage_losses.append(coverage_loss)

        label_logits = torch.stack(label_logits, dim=-1)
        premises_attentions = torch.stack(premises_attentions, dim=1)
        premises_aggregation_attentions = torch.stack(
            premises_aggregation_attentions, dim=1)
        if relevance_presence_mask is not None:
            coverage_losses = torch.stack(coverage_losses, dim=0)

        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "premises_attentions": premises_attentions,
            "premises_aggregation_attentions": premises_aggregation_attentions
        }

        if answer_index is not None:
            # answer_loss
            loss = self._answer_loss(label_logits, answer_index)
            # coverage loss
            if relevance_presence_mask is not None:
                loss += coverage_losses.mean()
            output_dict["loss"] = loss

            self._accuracy(label_logits, answer_index)

        return output_dict
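All three examples rely on a helper unbind_tensor_dict that is not shown here. A minimal sketch of what it plausibly does, assuming every value in the dict is a tensor with the hypotheses stacked along dim; the name and behavior are inferred from usage, not confirmed by the source:

from typing import Dict, List

import torch


def unbind_tensor_dict(tensor_dict: Dict[str, torch.Tensor],
                       dim: int) -> List[Dict[str, torch.Tensor]]:
    # Unbind every tensor along `dim`, then regroup the slices so that
    # each element of the result holds one slice per key.
    unbound = {key: torch.unbind(tensor, dim=dim)
               for key, tensor in tensor_dict.items()}
    num_slices = len(next(iter(unbound.values())))
    return [{key: slices[i] for key, slices in unbound.items()}
            for i in range(num_slices)]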
Example #2
    def forward(self,  # type: ignore
                premises: Dict[str, torch.LongTensor],
                hypotheses: Dict[str, torch.LongTensor],
                paragraph: Dict[str, torch.LongTensor],
                answer_correctness_mask: torch.IntTensor = None,
                relevance_presence_mask: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        hypothesis_list = unbind_tensor_dict(hypotheses, dim=1)

        label_logits = []
        premises_attentions = []
        coverage_losses = []
        for hypothesis in hypothesis_list:
            output_dict = super().forward(premises=premises, hypothesis=hypothesis,
                                          paragraph=paragraph,
                                          relevance_presence_mask=relevance_presence_mask)
            individual_logit = output_dict["label_logits"]
            label_logits.append(individual_logit)

            if relevance_presence_mask is not None:
                premises_attention = output_dict["premises_attention"]
                premises_attentions.append(premises_attention)
                coverage_loss = output_dict["coverage_loss"]
                coverage_losses.append(coverage_loss)

        label_logits = torch.stack(label_logits, dim=1)
        if relevance_presence_mask is not None:
            premises_attentions = torch.stack(premises_attentions, dim=1)
            coverage_losses = torch.stack(coverage_losses, dim=0)

        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
        output_dict = {"label_logits": label_logits[:, :, self._label2idx["entailment"]],
                       "label_probs": label_probs[:, :, self._label2idx["entailment"]]}
        if relevance_presence_mask is not None:
            output_dict["premises_attentions"] = premises_attentions

        if answer_correctness_mask is not None:
            # Map the correctness mask to class labels: 1 -> entailment,
            # 0 -> neutral, -1 -> ignored by the loss.
            label = ((answer_correctness_mask == 1).long() * self._label2idx["entailment"]
                     + (answer_correctness_mask == 0).long() * self._label2idx["neutral"]
                     + (answer_correctness_mask == -1).long() * self._ignore_index)
            loss = self._answer_loss(label_logits.reshape(-1, label_logits.shape[-1]),
                                     label.reshape(-1))

            # coverage loss
            if relevance_presence_mask is not None:
                loss += coverage_losses.mean()
            output_dict["loss"] = loss

            mask = answer_correctness_mask != -1
            self._accuracy(label_logits, label, mask)
            self._entailment_f1(label_logits, label, mask)

        return output_dict
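The label construction in Example #2 turns the three-valued correctness mask into class indices. A standalone sketch of just that mapping; the concrete values (entailment = 0, neutral = 1, ignore index = -100, the torch.nn.CrossEntropyLoss default) are assumptions, not taken from the source:

import torch

label2idx = {"entailment": 0, "neutral": 1}  # assumed class indices
ignore_index = -100  # assumed; default ignore_index of CrossEntropyLoss

mask = torch.tensor([1, 0, -1])  # correct, incorrect, unannotated
label = ((mask == 1).long() * label2idx["entailment"]
         + (mask == 0).long() * label2idx["neutral"]
         + (mask == -1).long() * ignore_index)
print(label)  # tensor([   0,    1, -100])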
Example #3
    def forward(
        self,  # type: ignore
        premises: Dict[str, torch.LongTensor],
        hypotheses: Dict[str, torch.LongTensor],
        paragraph: Dict[str, torch.LongTensor],
        answer_index: torch.LongTensor = None,
        relevance_presence_mask: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        hypothesis_list = unbind_tensor_dict(hypotheses, dim=1)

        label_logits = []
        premises_attentions = []
        premises_aggregation_attentions = []
        # Feed one hypothesis at a time to the parent class.
        for hypothesis in hypothesis_list:
            output_dict = super().forward(premises=premises,
                                          hypothesis=hypothesis,
                                          paragraph=paragraph)
            # Keep only the entailment logit for this hypothesis.
            individual_logit = output_dict["label_logits"][
                :, self._label2idx["entailment"]]
            label_logits.append(individual_logit)

            premises_attention = output_dict.get("premises_attention", None)
            premises_attentions.append(premises_attention)
            premises_aggregation_attention = output_dict.get(
                "premises_aggregation_attention", None)
            premises_aggregation_attentions.append(
                premises_aggregation_attention)
            # Free per-hypothesis tensors early to reduce peak memory.
            del output_dict, individual_logit, premises_attention, \
                premises_aggregation_attention

        label_logits = torch.stack(label_logits, dim=-1)
        premises_attentions = torch.stack(premises_attentions, dim=1)
        premises_aggregation_attentions = torch.stack(
            premises_aggregation_attentions, dim=1)

        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
        # TODO: check covariance of label_logits and label_probs.

        # Pad the per-hypothesis logits to a fixed width so the fully
        # connected head always sees `max_sent_count` inputs.
        if label_logits.shape[1] < self.max_sent_count:
            label_logits = torch.nn.functional.pad(
                input=label_logits,
                pad=(0, self.max_sent_count - label_logits.shape[1], 0, 0),
                mode='constant',
                value=0)

        single_output_logit = self.fc3(self.fc2(self.fc1(label_logits)))
        sigmoid_output = self.out_sigmoid(single_output_logit)

        output_dict = {
            "label_logits": single_output_logit,
            "label_probs": sigmoid_output,
            "premises_attentions": premises_attentions,
            "premises_aggregation_attentions": premises_aggregation_attentions
        }

        if answer_index is not None:
            device = single_output_logit.device
            # Binary target: one answer index per instance, shaped
            # (batch, 1) to match the single output logit.
            target = answer_index.unsqueeze(1).float().to(device)
            sigmoid = torch.nn.Sigmoid()
            loss = self._answer_loss(sigmoid(single_output_logit),
                                     sigmoid(target))
            output_dict["loss"] = loss
            output_dict["novelty"] = single_output_logit > 0.5
            self._accuracy(single_output_logit > 0.5,
                           answer_index.unsqueeze(1).byte())
            del target, loss

        del label_logits, label_probs, hypothesis_list

        return output_dict
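Example #3 additionally assumes that fc1, fc2, fc3, out_sigmoid, and max_sent_count were set up in the model's constructor. A plausible sketch of that setup; the hidden sizes here are illustrative guesses, not taken from the source:

import torch

max_sent_count = 30  # assumed padding width for the logit vector
fc1 = torch.nn.Linear(max_sent_count, 64)  # hidden sizes are assumptions
fc2 = torch.nn.Linear(64, 16)
fc3 = torch.nn.Linear(16, 1)  # single "novelty" score per instance
out_sigmoid = torch.nn.Sigmoid()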