Esempio n. 1
0
    def test_auc_computation(self, device: str):
        """Compare ``Auc`` with the ``metrics.roc_curve``/``metrics.auc`` reference.

        First the metric is accumulated over five random batches and compared
        with the reference computed on the concatenated data; then, after a
        reset, a single new batch is checked on its own.
        """

        def reference_auc(gold: torch.Tensor, scores: torch.Tensor):
            # Reference area under the ROC curve for the given labels/scores.
            fpr, tpr, _ = metrics.roc_curve(gold.cpu().numpy(), scores.cpu().numpy())
            return metrics.auc(fpr, tpr)

        metric = Auc()
        score_batches = []
        gold_batches = []
        for _ in range(5):
            scores = torch.randn(8, device=device)
            gold = torch.randint(0, 2, (8,), dtype=torch.long, device=device)
            metric(scores, gold)
            score_batches.append(scores)
            gold_batches.append(gold)

        accumulated_value = metric.get_metric(reset=True)
        expected_value = reference_auc(
            torch.cat(gold_batches, dim=0), torch.cat(score_batches, dim=0)
        )
        assert_allclose(expected_value, accumulated_value)

        # After the reset above, only this new batch may influence the result.
        scores = torch.randn(8, device=device)
        gold = torch.randint(0, 2, (8,), dtype=torch.long, device=device)
        metric(scores, gold)
        value_after_reset = metric.get_metric(reset=True)
        assert_allclose(reference_auc(gold, scores), value_after_reset)
Esempio n. 2
0
    def test_auc_computation(self):
        """Compare ``Auc`` with the ``metrics.roc_curve``/``metrics.auc`` reference.

        The metric is accumulated over five random batches and checked against
        the reference on the concatenated data, then checked once more on a
        single batch after a reset.
        """

        def reference_auc(gold: torch.Tensor, scores: torch.Tensor):
            # Reference area under the ROC curve for the given labels/scores.
            fpr, tpr, _ = metrics.roc_curve(gold.numpy(), scores.numpy())
            return metrics.auc(fpr, tpr)

        metric = Auc()
        score_batches = []
        gold_batches = []
        for _ in range(5):
            scores = torch.randn(8).float()
            gold = torch.randint(0, 2, (8, )).long()
            metric(scores, gold)
            score_batches.append(scores)
            gold_batches.append(gold)

        accumulated_value = metric.get_metric(reset=True)
        expected_value = reference_auc(torch.cat(gold_batches, dim=0),
                                       torch.cat(score_batches, dim=0))
        assert_almost_equal(expected_value, accumulated_value)

        # After the reset above, only this new batch may influence the result.
        scores = torch.randn(8).float()
        gold = torch.randint(0, 2, (8, )).long()
        metric(scores, gold)
        value_after_reset = metric.get_metric(reset=True)
        assert_almost_equal(reference_auc(gold, scores), value_after_reset)
Esempio n. 3
0
class Classifier(BasicClassifier):
    """``BasicClassifier`` that additionally tracks AUC and F1 for two-label tasks.

    The extra metrics are created, updated and reported only when
    ``self._num_labels == 2``.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 seq2vec_encoder: Seq2VecEncoder,
                 seq2seq_encoder: Seq2SeqEncoder = None,
                 num_labels: int = None) -> None:
        super().__init__(vocab,
                         text_field_embedder,
                         seq2vec_encoder,
                         seq2seq_encoder,
                         num_labels=num_labels)
        is_binary_task = self._num_labels == 2
        if is_binary_task:
            self._auc = Auc()
            self._f1 = F1Measure(1)

    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """Classify ``tokens``; with ``label`` also compute the loss and update metrics.

        Returns a dict with ``"logits"`` and ``"probs"``, plus ``"loss"`` when
        ``label`` is given.
        """
        features = self._text_field_embedder(tokens)
        token_mask = get_text_field_mask(tokens).float()

        # Optional contextualisation step before pooling to a single vector.
        if self._seq2seq_encoder:
            features = self._seq2seq_encoder(features, mask=token_mask)
        features = self._seq2vec_encoder(features, mask=token_mask)

        if self._dropout:
            features = self._dropout(features)

        logits = self._classification_layer(features)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        output_dict = {"logits": logits, "probs": probs}

        if label is None:
            return output_dict

        flat_label = label.long().view(-1)
        output_dict["loss"] = self._loss(logits, flat_label)
        self._accuracy(logits, label)

        if self._num_labels == 2:
            # AUC is scored on the probability assigned to class index 1.
            self._auc(probs[:, 1], flat_label)
            self._f1(probs, flat_label)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the parent metrics, extended with ``auc`` and ``f1`` for two-label tasks.

        ``f1`` is element ``[2]`` of what ``F1Measure.get_metric`` returns
        (presumably the F1 score of a (precision, recall, f1) tuple — confirm
        against the installed AllenNLP version).
        """
        collected = super().get_metrics(reset)
        if self._num_labels == 2:
            auc_value = self._auc.get_metric(reset)
            f1_value = self._f1.get_metric(reset)[2]
            collected.update(auc=auc_value, f1=f1_value)
        return collected
Esempio n. 4
0
class MortalityClassifier(Model):
    """Text classifier: embed -> Seq2Vec encode -> linear layer.

    Tracks categorical accuracy and AUC. AUC is computed from the probability
    assigned to class index 1, so the model is assumed to be used on a
    two-label task (``Auc`` is also constructed with its default positive
    label).
    """

    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 regularizer_applicator: RegularizerApplicator = None):
        super().__init__(vocab, regularizer_applicator)
        self.embedder = embedder
        self.encoder = encoder
        # Size of the "labels" vocabulary namespace = number of output classes.
        num_labels = vocab.get_vocab_size("labels")
        logger.info("num labels is as follows: %s", num_labels)
        self.classifier = torch.nn.Linear(encoder.get_output_dim(), num_labels)
        self.accuracy = CategoricalAccuracy()
        self.auc = Auc()
        self.reg_app = regularizer_applicator

    def forward(self, text: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Classify ``text``; with ``label`` also compute the loss and update metrics.

        Parameters
        ----------
        text : output of a text field, fed to the embedder.
        label : gold class indices. Optional so that the model can also be run
            for prediction, where no gold label is available.

        Returns
        -------
        A dict with ``'probs'`` and, when ``label`` is given, ``'loss'``.
        """
        # Shape: (batch_size, num_tokens, embedding_dim)
        embedded_text = self.embedder(text)
        # Shape diagnostics are emitted at DEBUG level so that they do not
        # flood stdout / the INFO log on every batch.
        logger.debug("Embedded text shape is as follows: %s",
                     embedded_text.shape)

        # Shape: (batch_size, num_tokens)
        mask = util.get_text_field_mask(text)
        # Shape: (batch_size, encoding_dim)
        encoded_text = self.encoder(embedded_text, mask)
        logger.debug("Encoded text shape: %s", encoded_text.shape)

        # Shape: (batch_size, num_labels)
        logits = self.classifier(encoded_text)
        logger.debug("Logits shape: %s", logits.shape)
        # Shape: (batch_size, num_labels)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        if label is None:
            return {'probs': probs}

        # The regularization penalty is not added here; it is expected to be
        # applied outside of forward().
        loss = torch.nn.functional.cross_entropy(logits, label)
        self.accuracy(logits, label)
        # AUC is a ranking metric: it needs continuous scores, not the argmax.
        # Feeding hard 0/1 predictions collapses the ROC curve to a single
        # operating point, so the positive-class probability is used instead.
        self.auc(probs[:, 1], label)
        return {'loss': loss, 'probs': probs}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running ``accuracy`` and ``auc``, optionally resetting both."""
        return {
            "accuracy": self.accuracy.get_metric(reset),
            "auc": self.auc.get_metric(reset)
        }
Esempio n. 5
0
    def test_auc_with_mask(self):
        """Only the unmasked (first four) elements may contribute to the AUC."""
        metric = Auc()

        scores = torch.randn(8).float()
        gold = torch.randint(0, 2, (8, )).long()
        # Keep the first four positions, hide the last four.
        keep = torch.ByteTensor([1] * 4 + [0] * 4)

        metric(scores, gold, keep)
        masked_value = metric.get_metric(reset=True)

        fpr, tpr, _ = metrics.roc_curve(gold[:4].numpy(), scores[:4].numpy())
        expected_value = metrics.auc(fpr, tpr)
        assert_almost_equal(expected_value, masked_value)
Esempio n. 6
0
class HatefulMemeModel(Model):
    """Wraps ``BertForSequenceClassification`` and tracks accuracy and AUC.

    Only the text input is used; the ``box_*`` and ``metadata`` arguments of
    ``forward`` are accepted but ignored by the code in this class.
    """

    def __init__(self, vocab: Vocabulary, text_model_name: str):
        super().__init__(vocab)
        self._text_model = BertForSequenceClassification.from_pretrained(
            text_model_name)
        # NOTE(review): this reads the default vocabulary namespace and the
        # value is not used anywhere else in this class — confirm it is needed.
        self._num_labels = vocab.get_vocab_size()

        self._accuracy = Average()
        self._auc = Auc()

        self._softmax = torch.nn.Softmax(dim=1)

    def forward(
        self,
        source_tokens: TextFieldTensors,
        box_features: Optional[Tensor] = None,
        box_coordinates: Optional[Tensor] = None,
        box_mask: Optional[Tensor] = None,
        label: Optional[Tensor] = None,
        metadata: Optional[Dict] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the wrapped text model and, when ``label`` is given, update the metrics.

        Returns the output object of the wrapped model unchanged (``labels`` is
        forwarded to it, so it is that model which produces the loss).
        """
        input_ids = source_tokens["tokens"]["token_ids"]
        input_mask = source_tokens["tokens"]["mask"]
        token_type_ids = source_tokens["tokens"]["type_ids"]
        outputs = self._text_model(
            input_ids=input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            labels=label,
        )

        if label is not None:
            probs = self._softmax(outputs.logits)
            predictions = torch.argmax(probs, dim=-1)
            # ``Average`` is updated once per example, exactly as before.
            for predicted, gold in zip(predictions, label):
                self._accuracy(int(predicted == gold))

            # AUC is a ranking metric: it needs continuous scores, not the
            # argmax. Hard 0/1 predictions collapse the ROC curve to a single
            # operating point, so the probability of class index 1 is used
            # (``Auc`` is built with its default positive label).
            self._auc(probs[:, 1], label)

        return outputs

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return ``accuracy`` and ``auc`` outside of training; an empty dict while training."""
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics["accuracy"] = self._accuracy.get_metric(reset=reset)
            metrics["auc"] = self._auc.get_metric(reset=reset)
        return metrics
Esempio n. 7
0
    def test_auc_with_mask(self, device: str):
        """Only the unmasked (first four) elements may contribute to the AUC."""
        metric = Auc()

        scores = torch.randn(8, device=device)
        gold = torch.randint(0, 2, (8,), dtype=torch.long, device=device)
        # Keep the first four positions, hide the last four.
        keep = torch.tensor([True] * 4 + [False] * 4, device=device)

        metric(scores, gold, keep)
        masked_value = metric.get_metric(reset=True)

        fpr, tpr, _ = metrics.roc_curve(
            gold[:4].cpu().numpy(), scores[:4].cpu().numpy()
        )
        expected_value = metrics.auc(fpr, tpr)
        assert_allclose(expected_value, masked_value)
class DecompensationClassifier(Model):
    """Text classifier: embed -> Seq2Vec encode -> linear layer.

    Tracks categorical accuracy and AUC. AUC is computed from the probability
    assigned to class index 1, so the model is assumed to be used on a
    two-label task (``Auc`` is also constructed with its default positive
    label).
    """

    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 regularizer_applicator: RegularizerApplicator = None):
        super().__init__(vocab, regularizer_applicator)
        self.embedder = embedder
        self.encoder = encoder
        # Size of the "labels" vocabulary namespace = number of output classes.
        num_labels = vocab.get_vocab_size("labels")
        self.classifier = torch.nn.Linear(encoder.get_output_dim(), num_labels)
        self.accuracy = CategoricalAccuracy()
        self.auc = Auc()
        self.reg_app = regularizer_applicator

    def forward(self, text: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Classify ``text``; with ``label`` also compute the loss and update metrics.

        Parameters
        ----------
        text : output of a text field, fed to the embedder.
        label : gold class indices. Optional so that the model can also be run
            for prediction, where no gold label is available.

        Returns
        -------
        A dict with ``'probs'`` and, when ``label`` is given, ``'loss'``.
        """
        # Shape: (batch_size, num_tokens, embedding_dim)
        embedded_text = self.embedder(text)
        # Shape: (batch_size, num_tokens)
        mask = util.get_text_field_mask(text)
        # Shape: (batch_size, encoding_dim)
        # Original author's note: horizontal; vertical (partial depth) might be good
        encoded_text = self.encoder(embedded_text, mask)
        # Shape: (batch_size, num_labels)
        logits = self.classifier(encoded_text)
        # Shape: (batch_size, num_labels)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        if label is None:
            return {'probs': probs}

        # The regularization penalty is not added here; it is expected to be
        # applied outside of forward().
        loss = torch.nn.functional.cross_entropy(logits, label)
        self.accuracy(logits, label)
        # AUC is a ranking metric: it needs continuous scores, not the argmax.
        # Feeding hard 0/1 predictions collapses the ROC curve to a single
        # operating point, so the positive-class probability is used instead.
        self.auc(probs[:, 1], label)
        return {'loss': loss, 'probs': probs}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running ``accuracy`` and ``auc``, optionally resetting both."""
        return {
            "accuracy": self.accuracy.get_metric(reset),
            "auc": self.auc.get_metric(reset)
        }
Esempio n. 9
0
    def test_auc_gold_labels_behaviour(self):
        """``Auc`` honours a custom ``positive_label`` and rejects more than two gold labels."""
        # Check that it works with different pos_label
        auc = Auc(positive_label=4)

        predictions = torch.randn(8).float()
        labels = torch.randint(3, 5, (8, )).long()
        # Make sure the positive label is always present: a random draw can
        # consist of 3s only, which would make this test fail sporadically.
        labels[0] = 4

        auc(predictions, labels)
        computed_auc_value = auc.get_metric(reset=True)

        false_positive_rates, true_positive_rates, _ = metrics.roc_curve(
            labels.numpy(), predictions.numpy(), pos_label=4)
        real_auc_value = metrics.auc(false_positive_rates, true_positive_rates)
        assert_almost_equal(real_auc_value, computed_auc_value)

        # Check that it errs on getting more than 2 labels. The tensor is built
        # outside the context manager so that only the call under test can
        # satisfy the expected exception.
        labels = torch.LongTensor([3, 4, 5, 6, 7, 8, 9, 10])
        with pytest.raises(ConfigurationError):
            auc(predictions, labels)
Esempio n. 10
0
    def test_auc_gold_labels_behaviour(self, device: str):
        """``Auc`` honours a custom ``positive_label`` and rejects more than two gold labels."""
        # Part 1: a positive label other than the default.
        metric = Auc(positive_label=4)

        scores = torch.randn(8, device=device)
        gold = torch.randint(3, 5, (8, ), dtype=torch.long, device=device)
        # Guarantee that at least one example carries the positive label.
        gold[0] = 4
        metric(scores, gold)
        metric_value = metric.get_metric(reset=True)

        fpr, tpr, _ = metrics.roc_curve(gold.cpu().numpy(),
                                        scores.cpu().numpy(),
                                        pos_label=4)
        expected_value = metrics.auc(fpr, tpr)
        assert_allclose(expected_value, metric_value)

        # Part 2: more than two distinct gold labels must be rejected.
        too_many_classes = torch.tensor([3, 4, 5, 6, 7, 8, 9, 10], device=device)
        with pytest.raises(ConfigurationError):
            metric(scores, too_many_classes)
Esempio n. 11
0
class L2AFBertBaseline(Model):
    """
    BERT baseline model.

    Pipeline: text-field embedder -> variational dropout -> attention pooling
    -> feed-forward scorer -> sigmoid. Trained with ``BCELoss`` and evaluated
    with AUC.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        attnpool: AttnPooling,
        output_ffl: FeedForward,
        initializer: InitializerApplicator,
        dropout: float = 0.3,
    ) -> None:
        super().__init__(vocab)
        # Sub-modules are assigned in a fixed order; keep it stable.
        self._text_field_embedder = text_field_embedder
        self._variational_dropout = InputVariationalDropout(dropout)
        self._attn_pool = attnpool
        self._output_ffl = output_ffl

        self._num_labels = vocab.get_vocab_size(namespace="labels")
        self._auc = Auc()
        self._loss = torch.nn.BCELoss()

        initializer(self)

    def forward(
            self,  # type: ignore
            combined_source: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # pylint: disable=arguments-differ
        """
        Score every instance of the batch.

        :param combined_source: text-field tensors passed to the embedder
        :param label: optional gold labels; when given, loss and AUC are computed
        :param metadata: passed through untouched into the output dict
        :return: dict with ``label_logits``, ``label_probs``, ``metadata`` and,
            when ``label`` is given, ``loss``
        """
        token_embeddings = self._text_field_embedder(
            combined_source)  # B * T * H
        token_mask = get_text_field_mask(combined_source)  # B * T
        token_embeddings = self._variational_dropout(token_embeddings)

        pooled_vector = self._attn_pool(token_embeddings, token_mask)  # B * H
        raw_score = self._output_ffl(pooled_vector)  # B * 1
        probability = torch.sigmoid(raw_score).squeeze(-1)  # B

        output_dict = {
            "label_logits": raw_score.squeeze(-1),
            "label_probs": probability,
            "metadata": metadata
        }
        if label is None:
            return output_dict

        flat_label = label.long().view(-1)
        # BCELoss needs float targets; the AUC metric receives the integer labels.
        bce = self._loss(probability, flat_label.float())
        self._auc(probability, flat_label)
        output_dict["loss"] = bce
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running AUC under the key ``'auc'``."""
        auc_value = self._auc.get_metric(reset)
        return {'auc': auc_value}
Esempio n. 12
0
class DocumentRanker(Model):
    """Scores each option (document) against a query with a relevance matcher.

    Sigmoid-squashed scores are trained with a masked mean-squared error and
    evaluated with AUC, MRR and NDCG. A label value of ``-1`` marks padding.
    """

    default_predictor = "document_ranker"

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        relevance_matcher: RelevanceMatcher,
        dropout: float = None,
        num_labels: int = None,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:

        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder
        self._relevance_matcher = TimeDistributed(relevance_matcher)

        # A falsy ``dropout`` (None or 0) is stored as-is and disables dropout.
        self._dropout = torch.nn.Dropout(dropout) if dropout else dropout

        self._auc = Auc()
        self._mrr = MRR(padding_value=-1)
        self._ndcg = NDCG(padding_value=-1)

        # Element-wise loss; masking and averaging happen in forward().
        self._loss = torch.nn.MSELoss(reduction='none')
        initializer(self)

    # @torchsnooper.snoop()
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,  # batch * words
        options: TextFieldTensors,  # batch * num_options * words
        labels: torch.IntTensor = None  # batch * num_options
    ) -> Dict[str, torch.Tensor]:
        """Score all options of every query; with ``labels`` also update metrics and loss.

        Returns a dict with ``"logits"``, ``"probs"``, ``"token_ids"`` and,
        when ``labels`` is given, ``"loss"``.
        """
        query_embeddings = self._text_field_embedder(tokens)
        query_mask = get_text_field_mask(tokens).long()

        option_embeddings = self._text_field_embedder(
            options, num_wrapping_dims=1)  # options_mask.dim() - 2
        option_mask = get_text_field_mask(options).long()

        if self._dropout:
            query_embeddings = self._dropout(query_embeddings)
            option_embeddings = self._dropout(option_embeddings)

        # This isn't exactly a 'hack', but it's definitely not the most
        # efficient way to do it. The matcher expects a single
        # (query, document) pair, but we have (query, [d_0, ..., d_n]).
        # To get around this, the query embeddings are expanded to create
        # these pairs:
        #
        #   [ (q_0, [d_{0,0}, ..., d_{0,n}]),
        #     (q_1, [d_{1,0}, ..., d_{1,n}]) ]
        #   =>
        #   [ [ (q_0, d_{0,0}), ..., (q_0, d_{0,n}) ],
        #     [ (q_1, d_{1,0}), ..., (q_1, d_{1,n}) ] ]
        #
        # The pairs are then flattened along the batch dimension into the 3D
        # tensor [batch*num_options, words, dim] the matcher expects. It would
        # likely be more efficient to rewrite the matrix multiplications in the
        # relevance matchers, but this is a more general solution.
        num_options = option_embeddings.size(1)
        # [batch, num_options, words, dim]
        query_embeddings = query_embeddings.unsqueeze(1).expand(
            -1, num_options, -1, -1)
        query_mask = query_mask.unsqueeze(1).expand(-1, num_options, -1)

        scores = self._relevance_matcher(query_embeddings, option_embeddings,
                                         query_mask, option_mask).squeeze(-1)
        probs = torch.sigmoid(scores)

        output_dict = {
            "logits": scores,
            "probs": probs,
            "token_ids": util.get_token_ids_from_text_field_tensors(tokens),
        }
        if labels is None:
            return output_dict

        # -1 marks padded options.
        label_mask = (labels != -1)

        # Ranking metrics work on the [batch, num_options] layout ...
        self._mrr(probs, labels, label_mask)
        self._ndcg(probs, labels, label_mask)

        # ... while AUC and the loss work on the flattened one.
        flat_probs = probs.view(-1)
        flat_labels = labels.view(-1)
        flat_mask = label_mask.view(-1)

        # Labels are binarised at 0.5 for the AUC.
        self._auc(flat_probs, flat_labels.ge(0.5).long(), flat_mask)

        elementwise_loss = self._loss(flat_probs, flat_labels)
        masked_total = elementwise_loss.masked_fill(~flat_mask, 0).sum()
        output_dict["loss"] = masked_total / flat_mask.sum()

        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return ``output_dict`` unchanged."""
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running ``auc``, ``mrr`` and ``ndcg`` values."""
        return {
            "auc": self._auc.get_metric(reset),
            "mrr": self._mrr.get_metric(reset),
            "ndcg": self._ndcg.get_metric(reset),
        }
Esempio n. 13
0
 def test_auc_works_without_calling_metric_at_all(self, device: str):
     """``get_metric`` on a fresh ``Auc`` that has seen no batch must simply run."""
     untouched_metric = Auc()
     untouched_metric.get_metric()
Esempio n. 14
0
File: model.py Progetto: buptzzs/MQA
class MultiStepParaRankModel(Model):
    """Jointly scores supporting paragraphs and answer candidates for a query.

    Query, supports and candidates share one text-field embedder and one
    ``phrase_layer`` encoder. Supports are enriched with query co-attention
    and self-attention, everything is pooled to vectors, and ``decoder``
    turns the pooled vectors into ``(supports_prob, candidates_score)``.

    Metrics: ``_support_accuracy`` is, despite its name, an ``Auc`` (reported
    as ``support_auc``); ``_candidate_accuracy`` is a ``CategoricalAccuracy``.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 pq_attention: MatrixAttention,
                 p_selfattention: MatrixAttention,
                 supports_pooling: Seq2VecEncoder,
                 query_pooling: Seq2VecEncoder,
                 candidates_pooling: Seq2VecEncoder,
                 decoder: Decoder,
                 dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(MultiStepParaRankModel, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder

        # Shared sequence encoder for query, supports and candidates.
        self.phrase_layer = phrase_layer

        # Support/query similarity and support self-similarity, respectively.
        self.pq_attention = pq_attention
        self.p_selfattention = p_selfattention

        self.supports_pooling = supports_pooling
        self.query_pooling = query_pooling
        self.candidates_pooling = candidates_pooling

        self.decoder = decoder

        self.dropout = InputVariationalDropout(p=dropout)

        # Named "accuracy" but this is an AUC metric (see get_metrics).
        self._support_accuracy = Auc()
        self._candidate_accuracy = CategoricalAccuracy()

        initializer(self)

    def forward(
        self,
        query: Dict[str, torch.LongTensor],
        supports: Dict[str, torch.LongTensor],
        candidates: Dict[str, torch.LongTensor],
        answer: Dict[str, torch.LongTensor] = None,
        answer_index: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
        supports_labels: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        """Score supports and candidates; with ``supports_labels`` also compute losses.

        ``answer`` and ``metadata`` are accepted but not used in this method.
        When ``supports_labels`` is given, ``answer_index`` is used as well
        (for the candidate loss and accuracy), so it must be provided too.

        Returns a dict with ``supports_prob`` and ``candidates_score`` and,
        when ``supports_labels`` is given, ``loss``, ``s_loss`` and ``c_loss``.
        """
        embedded_supports = self.text_field_embedder(supports)
        embedded_query = self.text_field_embedder(query)
        embedded_candidates = self.text_field_embedder(candidates)

        # Supports and candidates are 4-D, the query is 3-D (see the unpacking).
        # candidate_num and query_length are unpacked but not used below.
        batch_size, support_num, support_length, embed_dim = embedded_supports.size(
        )
        _, candidate_num, candidate_length, _ = embedded_candidates.size()
        _, query_length, _ = embedded_query.size()

        # Two masks per nested field: the *_para one (presumably one entry per
        # paragraph — TODO confirm) and the *_seq one with num_wrapping_dims=1,
        # which is then flattened to 2-D to match the flattened embeddings.
        supports_mask_para = util.get_text_field_mask(supports)
        supports_mask_seq = util.get_text_field_mask(supports,
                                                     num_wrapping_dims=1)
        supports_mask_seq_expand = supports_mask_seq.view(
            -1, supports_mask_seq.size(-1))

        candidates_mask_para = util.get_text_field_mask(candidates)
        candidates_mask_seq = util.get_text_field_mask(candidates,
                                                       num_wrapping_dims=1)
        candidates_mask_seq_expand = candidates_mask_seq.view(
            -1, candidates_mask_seq.size(-1))

        # Repeat the query mask once per support, then flatten to 2-D.
        query_mask = util.get_text_field_mask(query)
        query_mask_expand = query_mask.unsqueeze(1).expand(
            query_mask.size(0), support_num, query_mask.size(1))
        query_mask_expand = query_mask_expand.contiguous().view(
            -1, query_mask_expand.size(-1))

        # Fold the paragraph dimension into the batch dimension.
        embedded_supports_expand = embedded_supports.view(
            -1, support_length, embed_dim)
        embedded_candidates_expand = embedded_candidates.view(
            -1, candidate_length, embed_dim)

        encoded_query = self.phrase_layer(self.dropout(embedded_query),
                                          query_mask)
        encoded_supports = self.phrase_layer(
            self.dropout(embedded_supports_expand), supports_mask_seq_expand)
        encoded_candidates = self.phrase_layer(
            self.dropout(embedded_candidates_expand),
            candidates_mask_seq_expand)

        # Repeat the encoded query once per support so that it lines up with
        # the flattened supports.
        encoded_query_expand = encoded_query.unsqueeze(1).expand(
            batch_size, support_num, encoded_query.size(1),
            encoded_query.size(2))
        encoded_query_expand = encoded_query_expand.contiguous().view(
            -1, encoded_query.size(1), encoded_query.size(2))

        # Co-attention

        # shape: (batch_size*passage_num, passage_length, question_length )
        supports_query_similarity = self.pq_attention(encoded_supports,
                                                      encoded_query_expand)

        # shape: (batch_size*passage_num, passage_length, question_length )
        supports_query_attention = util.masked_softmax(
            supports_query_similarity,
            query_mask_expand,
            memory_efficient=True)
        # shape: (batch_size*passage_num, passage_length, encoding_dim)
        supports_query_vectors = util.weighted_sum(encoded_query_expand,
                                                   supports_query_attention)

        # shape: (batch_size*passage_num, query_length, passage_length)
        query_passage_attention = util.masked_softmax(
            supports_query_similarity.transpose(1, 2),
            supports_mask_seq_expand,
            memory_efficient=True)
        # shape: (batch_size*passage_num, query_length, encoding_dim)
        query_supports_vectors = util.weighted_sum(encoded_supports,
                                                   query_passage_attention)

        # shape: (batch_size*passage_num, passage_length, encoding_dim)
        supports_query_vectors_2 = torch.bmm(supports_query_attention,
                                             query_supports_vectors)
        # shape: (batch_size*passage_num, passage_length, encoding_dim*2)
        supports_coattention_vectors = torch.cat(
            [supports_query_vectors, supports_query_vectors_2], dim=-1)

        # Fusion: use a simple fusion function for now
        #supports_coattention_vectors = co_attention_fusion(torch.cat([encoded_supports,supports_query_vectors_final], dim=-1))
        suppports_self_similarity = self.p_selfattention(
            supports_coattention_vectors, supports_coattention_vectors)
        supports_selfattention = util.masked_softmax(suppports_self_similarity,
                                                     supports_mask_seq_expand,
                                                     memory_efficient=True)
        supports_selfatt_vectors = util.weighted_sum(
            supports_coattention_vectors, supports_selfattention)
        #support_selfatt_fusion = self_attention_fusion(util.combine_tensors('1,2,1-2,1*2',[supports_coattention_vectors, supports_selfatt_vectors]))

        # Pool every sequence to one vector and restore the paragraph dimension.
        supports_pooling_vectors = self.supports_pooling(
            supports_selfatt_vectors, supports_mask_seq_expand)
        supports_pooling_vectors = supports_pooling_vectors.view(
            batch_size, support_num, -1)
        question_pooling_vectors = self.query_pooling(encoded_query,
                                                      query_mask)

        candidates_pooling_vectors = self.candidates_pooling(
            encoded_candidates, candidates_mask_seq_expand)
        candidates_pooling_vectors = candidates_pooling_vectors.view(
            batch_size, -1, candidates_pooling_vectors.size(-1))

        # Original author's note: supports_prob is normalized,
        # candidates_score is unnormalized.
        supports_prob, candidates_score = self.decoder(
            supports_pooling_vectors, question_pooling_vectors,
            candidates_pooling_vectors, supports_mask_para)
        candidates_score = util.replace_masked_values(candidates_score,
                                                      candidates_mask_para,
                                                      -1e7)

        output_dict = {
            "supports_prob": supports_prob,
            "candidates_score": candidates_score
        }
        if supports_labels is not None:
            # NOTE(review): the note above calls supports_prob "normalized",
            # yet it is masked with -1e32 and passed to a *_with_logits loss,
            # which expects raw logits — confirm what the decoder returns.
            supports_prob = util.replace_masked_values(supports_prob,
                                                       supports_mask_para,
                                                       -1e32)
            s_loss = binary_cross_entropy_with_logits(supports_prob,
                                                      supports_labels)
            # answer_index is dereferenced here although it defaults to None.
            c_loss = nll_loss(
                util.masked_log_softmax(candidates_score,
                                        candidates_mask_para),
                answer_index.squeeze(-1))
            loss = s_loss + c_loss
            # The AUC metric gets flattened scores, labels and mask.
            self._support_accuracy(
                supports_prob.view(-1, 1).squeeze(),
                supports_labels.view(-1, 1).squeeze(),
                supports_mask_para.view(-1).squeeze().detach().cpu())
            self._candidate_accuracy(candidates_score,
                                     answer_index.squeeze(-1))

            output_dict['loss'] = loss
            output_dict['s_loss'] = s_loss
            output_dict['c_loss'] = c_loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return ``support_auc`` and ``candidate_acc``, optionally resetting both."""
        return {
            'support_auc': self._support_accuracy.get_metric(reset),
            'candidate_acc': self._candidate_accuracy.get_metric(reset)
        }
class AnswerHelpfulPredictionModel(Model):
    """
    Implementation of the RAPH model proposed in the paper.

    Given a question, its answer and the relevant reviews, the model predicts
    whether the answer is helpful. ``forward`` is organised in four stages:

    1. embed and contextually encode the question, the answer and each review;
    2. QA matching through a bidirectional attention over question and answer;
    3. review/answer coherence ("entailment") with question-guided attention
       over every review;
    4. a final prediction layer over the matching and entailment scores.

    The number of reviews per instance is hard-wired to five.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 context_encoder: Seq2SeqEncoder,
                 qa_attention_module: MatrixAttention,
                 text_encoder_qa_matching: Seq2VecEncoder,
                 qa_matching_layer: FeedForward,
                 qr_attention_module: Attention,
                 text_encoder_ra_entailment: Seq2VecEncoder,
                 ra_matching_layer: FeedForward,
                 predict_layer: FeedForward,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None):
        """
        Store the sub-modules, create the running metrics and the loss, then
        apply ``initializer`` to the whole model.
        """
        super().__init__(vocab, regularizer)

        # Sub-modules, registered in the order in which forward() uses them.
        self.text_field_embedder = text_field_embedder
        self.context_encoder = context_encoder
        self.qa_attention_module = qa_attention_module
        self.text_encoder_qa_matching = text_encoder_qa_matching
        self.qa_matching_layer = qa_matching_layer
        self.qr_attention_module = qr_attention_module
        self.text_encoder_ra_entailment = text_encoder_ra_entailment
        self.ra_matching_layer = ra_matching_layer
        self.predict_layer = predict_layer

        # Running metrics; callers reset them through get_metrics(reset=True).
        # ``accuracy`` is instantiated but never updated by forward().
        self.f1_measure = F1Measure(positive_label=1)
        self.auc_score = Auc(positive_label=1)
        self.accuracy = CategoricalAccuracy()

        self.criterion = torch.nn.CrossEntropyLoss()
        initializer(self)

    @overrides
    def forward(self, question, answer, reviews, helpful=None):
        """
        Run the four stages of the model.

        Parameters
        ----------
        question, answer :
            Text-field inputs for the question and the answer.
        reviews :
            Text-field input with one wrapping dimension; the first five
            entries along that dimension are used.
        helpful : optional
            Gold labels. When given, the loss is computed and the F1 / AUC
            metrics are updated.

        Returns
        -------
        A dictionary with ``probs``, ``entailment_score``, ``qr_attention``
        and ``qa_attention`` (the raw similarity matrix), plus ``loss`` when
        ``helpful`` is supplied.
        """
        num_reviews = 5

        # ------------------------------------------------------------------
        # Stage 1: embedding and context encoding

        # question / answer masks: (batch_size, seq_len)
        question_mask = get_text_field_mask(question)
        answer_mask = get_text_field_mask(answer)
        # reviews mask: (batch_size, num_reviews, seq_len)
        all_reviews_mask = get_text_field_mask(reviews, num_wrapping_dims=1)

        # (batch_size, seq_len, embed_dim)
        question_emb = self.text_field_embedder(question)
        answer_emb = self.text_field_embedder(answer)
        # (batch_size, num_reviews, seq_len, embed_dim)
        all_reviews_emb = self.text_field_embedder(reviews,
                                                   num_wrapping_dims=1)

        # Slice out every review so each one is encoded independently.
        review_embs = [all_reviews_emb[:, k, :, :] for k in range(num_reviews)]
        review_masks = [all_reviews_mask[:, k, :] for k in range(num_reviews)]

        context_q = self.context_encoder(question_emb, question_mask)
        context_a = self.context_encoder(answer_emb, answer_mask)
        # each element: (batch_size, seq_len, encoding_dim)
        context_r = [
            self.context_encoder(emb, msk)
            for emb, msk in zip(review_embs, review_masks)
        ]

        # ------------------------------------------------------------------
        # Stage 2: question/answer matching

        # (batch_size, len_q, len_a)
        sim_matrix = self.qa_attention_module(context_q, context_a)
        # Masked softmax keeps the padded positions out of the attention.
        attention_over_a = masked_softmax(sim_matrix, answer_mask)
        attention_over_q = masked_softmax(
            sim_matrix.transpose(1, 2).contiguous(), question_mask)

        # Answer summarised per question token, and the other way around.
        answer_for_q = weighted_sum(context_a, attention_over_a)
        question_for_a = weighted_sum(context_q, attention_over_q)

        question_features = torch.cat([context_q, answer_for_q], dim=-1)
        answer_features = torch.cat([context_a, question_for_a], dim=-1)

        # Collapse both sequences into fixed-size vectors.
        o_q = self.text_encoder_qa_matching(question_features, question_mask)
        o_a_q = self.text_encoder_qa_matching(answer_features, answer_mask)
        qa_pair = torch.cat([o_q, o_a_q], dim=-1)
        qa_matching_score = self.qa_matching_layer(qa_pair)

        # ------------------------------------------------------------------
        # Stage 3: review/answer coherence

        # The question vector attends over each review: (batch_size, len_r_i)
        qr_attention_weights = [
            self.qr_attention_module(o_q, ctx, msk)
            for ctx, msk in zip(context_r, review_masks)
        ]
        # Attention-weighted review summaries.
        q_enhanced_r = [
            torch.matmul(beta.unsqueeze(1), ctx).squeeze(1)
            for beta, ctx in zip(qr_attention_weights, context_r)
        ]
        # Kept in the output for attention visualisation.
        qr_attention_weights_op = torch.cat(qr_attention_weights, dim=-1)

        o_a_r = self.text_encoder_ra_entailment(context_a, answer_mask)
        o_r = [
            self.text_encoder_ra_entailment(ctx, msk) + boosted
            for ctx, msk, boosted in zip(context_r, review_masks,
                                         q_enhanced_r)
        ]

        # One entailment check per review, concatenated along dim 1.
        entailment_score = torch.cat([
            self.ra_matching_layer(torch.cat([o_a_r, review_vec], dim=-1))
            for review_vec in o_r
        ], dim=1)

        # ------------------------------------------------------------------
        # Stage 4: final prediction

        combined_scores = torch.cat([qa_matching_score, entailment_score],
                                    dim=-1)
        logits = self.predict_layer(combined_scores)
        probs = F.softmax(logits, dim=-1)

        output_dict = {
            'probs': probs,
            'entailment_score': entailment_score,
            'qr_attention': qr_attention_weights_op,
            "qa_attention": sim_matrix
        }
        if helpful is None:
            return output_dict

        loss = self.criterion(logits, helpful)
        self.f1_measure(logits, helpful)
        # AUC is fed the probability of class 1.
        self.auc_score(probs[:, 1], helpful)
        output_dict['loss'] = loss
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the running ``f1`` and ``auc`` values.

        ``reset`` is forwarded to both metrics. Precision and recall are
        returned by the F1 metric but not reported here.
        """
        _, _, f1_value = self.f1_measure.get_metric(reset)
        auc_value = self.auc_score.get_metric(reset)
        return {"f1": f1_value, "auc": auc_value}
Esempio n. 16
0
class ContextualAtt(Model):
    """
    Sentence classifier over a section name and a three-sentence window
    (previous, current and next sentence).

    Every text field is embedded, encoded with a ``Seq2SeqEncoder`` and pooled
    by an attention whose query is the encoder's final state. The four pooled
    vectors are concatenated with ``additional_features`` and passed to
    ``classifier_feedforward`` to obtain the logits.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sec_name_encoder: Seq2SeqEncoder,
                 sent_encoder: Seq2SeqEncoder,
                 classifier_feedforward: FeedForward,
                 encoder_attention: Attention = DotProductAttention(
                     normalize=True),
                 label_namespace: str = "labels",
                 class_weight=[1.0, 1.0],
                 dropout: Optional[float] = None,
                 calculate_f1: bool = None,
                 calculate_auc: bool = None,
                 calculate_auc_pr: bool = None,
                 positive_label: int = 1,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """
        Parameters
        ----------
        vocab : Vocabulary
            Used to size the label space (``label_namespace``).
        text_field_embedder : TextFieldEmbedder
            Shared embedder for all four text fields.
        sec_name_encoder : Seq2SeqEncoder
            Encoder for the section name.
        sent_encoder : Seq2SeqEncoder
            Encoder shared by the previous, current and next sentence.
        classifier_feedforward : FeedForward
            Maps the concatenated features to the logits.
        encoder_attention : Attention
            Attention used to pool each encoded sequence.
        label_namespace : str
            Vocabulary namespace holding the labels.
        class_weight :
            Currently unused; the loss is an unweighted cross entropy.
        dropout : float, optional
            When truthy, dropout is applied to the embeddings and to the
            encoder outputs.
        calculate_f1, calculate_auc, calculate_auc_pr : bool
            Enable the corresponding extra metric.
        positive_label : int
            Index of the positive class, used by the F1 / AUC metrics and to
            select the positive-class probability.
        initializer : InitializerApplicator
            Applied to the model once construction is finished.
        regularizer : RegularizerApplicator, optional
            Passed to the ``Model`` base class.
        """
        super(ContextualAtt, self).__init__(vocab, regularizer)

        self.label_namespace = label_namespace
        self.num_tags = self.vocab.get_vocab_size()
        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size(label_namespace)
        self.sec_name_encoder = sec_name_encoder
        self.sent_encoder = sent_encoder
        self.attention = encoder_attention

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = None
        self.classifier_feedforward = classifier_feedforward

        self.metrics = {"accuracy": CategoricalAccuracy()}

        self.loss = torch.nn.CrossEntropyLoss()

        self.positive_label = positive_label
        self.calculate_f1 = calculate_f1
        self.calculate_auc = calculate_auc
        self.calculate_auc_pr = calculate_auc_pr

        if calculate_f1:
            self._f1_metric = F1Measure(positive_label)

        if calculate_auc:
            self._auc = Auc(positive_label)
        if calculate_auc_pr:
            self._auc_pr = AucPR(positive_label)

        if classifier_feedforward is not None:
            # Three sentence vectors + one section-name vector + the
            # additional features, whose width is hard-coded to 8 here.
            check_dimensions_match(
                sent_encoder.get_output_dim() * 3 +
                sec_name_encoder.get_output_dim() + 8,
                classifier_feedforward.get_input_dim(), "encoder output dim",
                "feedforward input dim")

        initializer(self)

    def _encode_field(self, field, encoder, need_attention=True):
        """
        Embed and encode one text field and pool it into a single vector.

        Returns the encoder's final state when ``need_attention`` is false,
        otherwise the attention-weighted sum of the encoder outputs, with the
        final state used as the attention query.
        """
        mask = util.get_text_field_mask(field)
        embedded_text_input = self.text_field_embedder(field)
        if self.dropout is not None:
            embedded_text_input = self.dropout(embedded_text_input)
        encoded_text = encoder(embedded_text_input, mask)
        if self.dropout is not None:
            encoded_text = self.dropout(encoded_text)

        # (batch_size, num_directions * hidden_size)
        try:
            final_hidden_states = util.get_final_encoder_states(
                encoded_text, mask, encoder.is_bidirectional())
        except Exception:
            # Keep the diagnostic dump of the offending field, but re-raise:
            # swallowing the error would only surface later as an unrelated
            # UnboundLocalError on ``final_hidden_states``.
            print(field)
            raise

        if not need_attention:
            return final_hidden_states

        # (batch_size, seq_len, 1)
        attention_weights = self.attention(final_hidden_states,
                                           encoded_text,
                                           mask).unsqueeze(-1)

        # (batch_size, seq_len, dim) ==> (batch_size, dim, seq_len) so that
        # bmm with the weights yields one vector per instance.
        rnn_output_re_order = encoded_text.permute(0, 2, 1)
        attention_output = torch.bmm(rnn_output_re_order,
                                     attention_weights).squeeze(-1)
        return attention_output

    @overrides
    def forward(
            self,  # type: ignore
            section_name: Dict[str, torch.LongTensor],
            prev_sent: Dict[str, torch.LongTensor],
            cur_sent: Dict[str, torch.LongTensor],
            next_sent: Dict[str, torch.LongTensor],
            additional_features: torch.Tensor = None,
            label: torch.LongTensor = None,
            metadata: List[str] = None,
            **kwargs) -> Dict[str, torch.Tensor]:

        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        section_name : Dict[str, torch.LongTensor], required
            Text-field tensors for the section name, passed directly to the
            ``TextFieldEmbedder``.
        prev_sent : Dict[str, torch.LongTensor], required
            Text-field tensors for the previous sentence.
        cur_sent : Dict[str, torch.LongTensor], required
            Text-field tensors for the current sentence.
        next_sent : Dict[str, torch.LongTensor], required
            Text-field tensors for the next sentence.
        additional_features : torch.Tensor
            Extra per-instance features concatenated to the pooled vectors.
            Although it defaults to ``None``, a tensor is needed in practice
            because it is passed to ``torch.cat``.
        label : torch.LongTensor, optional (default = None)
            Gold label for each instance in the batch.
        metadata : optional (default = None)
            Returned unchanged under ``cur_sent_id``.

        Returns
        -------
        An output dictionary consisting of:
        cur_sent_id :
            The ``metadata`` argument.
        logits : torch.FloatTensor
            The output of ``classifier_feedforward``.
        golden_label :
            The ``label`` argument.
        class_probabilities : torch.FloatTensor
            Softmax of the logits over the last dimension.
        positive_class_prob : torch.FloatTensor
            Detached probability of the class at index ``positive_label``.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised, present only when ``label`` is
            given.
        """
        final_states_sec_name = self._encode_field(section_name,
                                                   self.sec_name_encoder,
                                                   need_attention=True)
        final_states_prev_sent = self._encode_field(prev_sent,
                                                    self.sent_encoder,
                                                    need_attention=True)
        final_states_cur_sent = self._encode_field(cur_sent,
                                                   self.sent_encoder,
                                                   need_attention=True)
        final_states_next_sent = self._encode_field(next_sent,
                                                    self.sent_encoder,
                                                    need_attention=True)

        embeded_features = torch.cat(
            (final_states_sec_name, final_states_prev_sent,
             final_states_cur_sent, final_states_next_sent,
             additional_features),
            dim=-1)
        logits = self.classifier_feedforward(embeded_features)

        output_dict = {
            "cur_sent_id": metadata,
            "logits": logits,
            "golden_label": label
        }

        # Probabilities are useful at prediction time too, so they are
        # computed whether or not a label is available.
        class_probabilities = F.softmax(logits, dim=-1)
        # Use the configured positive class rather than a hard-coded column,
        # so the scores fed to the AUC metrics match their positive label.
        positive_class_prob = class_probabilities[:, self.positive_label].detach()

        if label is not None:
            output_dict["loss"] = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            if self.calculate_f1:
                self._f1_metric(logits, label)
            if self.calculate_auc:
                self._auc(positive_class_prob, label)
            if self.calculate_auc_pr:
                self._auc_pr(positive_class_prob, label)

        output_dict['class_probabilities'] = class_probabilities
        output_dict['positive_class_prob'] = positive_class_prob

        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return ``output_dict`` unchanged; no post-processing is applied."""
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the running metrics.

        Always contains ``accuracy``; ``precision`` / ``recall`` / ``f1``,
        ``auc`` and ``auc_pr`` are added when the matching ``calculate_*``
        flag was enabled. ``reset`` is forwarded to every metric.
        """
        metrics_to_return = {
            metric_name: metric.get_metric(reset)
            for metric_name, metric in self.metrics.items()
        }

        if self.calculate_f1:
            precision, recall, f1_measure = self._f1_metric.get_metric(
                reset=reset)
            metrics_to_return.update({
                'precision': precision,
                'recall': recall,
                "f1": f1_measure
            })

        if self.calculate_auc:
            metrics_to_return["auc"] = self._auc.get_metric(reset=reset)
        if self.calculate_auc_pr:
            metrics_to_return["auc_pr"] = self._auc_pr.get_metric(reset=reset)

        return metrics_to_return
class MortalityClassifier(Model):
    """
    Text classifier made of an embedder, a ``Seq2VecEncoder`` and one linear
    layer, trained with a binary cross entropy on the logits.

    The ``accuracy`` and ``auc`` metric objects are created and reported by
    ``get_metrics`` but are never updated inside ``forward``.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 regularizer_applicator: RegularizerApplicator = None,
                 num_classes=1):
        """
        Register the embedder and encoder, build the linear classifier on top
        of the encoder output and create the metric objects.
        """
        super().__init__(vocab, regularizer_applicator)
        self.embedder = embedder
        self.encoder = encoder

        encoding_dim = encoder.get_output_dim()
        self.classifier = torch.nn.Linear(encoding_dim, num_classes)

        self.accuracy = CategoricalAccuracy()
        self.auc = Auc()
        self.reg_app = regularizer_applicator

    def forward(self, text: Dict[str, torch.Tensor], label: torch.Tensor,
                metadata) -> Dict[str, torch.Tensor]:
        """
        Score a batch and compute its loss.

        Parameters
        ----------
        text : Dict[str, torch.Tensor]
            Text-field tensors for the input documents.
        label : torch.Tensor
            Gold labels; cast to float for the loss.
        metadata :
            Returned unchanged in the output.

        Returns
        -------
        A dictionary with ``loss``, ``probs`` (sigmoid of the logits),
        ``metadata`` and ``label``.
        """
        # (batch_size, num_tokens, embedding_dim)
        token_embeddings = self.embedder(text)
        # (batch_size, num_tokens)
        token_mask = util.get_text_field_mask(text)
        # (batch_size, encoding_dim)
        document_vector = self.encoder(token_embeddings, token_mask)
        # (batch_size, num_labels)
        logits = self.classifier(document_vector)

        bce_with_logits = torch.nn.functional.binary_cross_entropy_with_logits
        target = label.float()

        # Everything the caller needs is bled through the output: the
        # metrics are not updated here, so the gold label and the metadata
        # travel alongside the probabilities.
        output = {}
        output['loss'] = bce_with_logits(logits, target)
        output['probs'] = torch.sigmoid(logits)
        output["metadata"] = metadata
        output["label"] = label
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the ``accuracy`` and ``auc`` values.

        This both gets and, when ``reset`` is true, resets the metrics.
        """
        accuracy_value = self.accuracy.get_metric(reset)
        auc_value = self.auc.get_metric(reset)
        return {"accuracy": accuracy_value, "auc": auc_value}
Esempio n. 18
0
class L2AF3WayAPM(Model):
    """
    3-way Attentive Pooling Network

    Scores one candidate against the QA text and, optionally, a passage.
    Two scores can contribute to the final logit:

    * a "with knowledge" score (``with_knowledge``), where the passage attends
      over the QA text and over the candidate;
    * a "no knowledge" score (``is_qac_ap``), where the QA text and the
      candidate attend over each other directly.

    The summed score goes through a sigmoid, is trained with ``BCELoss`` and
    is tracked with an AUC metric.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 pseqlevelenc: Seq2SeqEncoder,
                 qaseqlevelenc: Seq2SeqEncoder,
                 choicelevelenc: Seq2SeqEncoder,
                 cartesian_attn: SeqAttnMat,
                 pcattnmat: SeqAttnMat,
                 gate_qdep_penc: GatedEncoding,
                 qdep_penc_rnn: Seq2SeqEncoder,
                 mfa_enc: GatedMultifactorSelfAttnEnc,
                 mfa_rnn: Seq2SeqEncoder,
                 pqaattnmat: SeqAttnMat,
                 cqaattnmat: SeqAttnMat,
                 initializer: InitializerApplicator,
                 dropout: float = 0.3,
                 is_qdep_penc: bool = True,
                 is_mfa_enc: bool = True,
                 with_knowledge: bool = True,
                 is_qac_ap: bool = True,
                 shared_rnn: bool = True) -> None:
        """
        Store the sub-modules and the feature switches.

        Parameters
        ----------
        pseqlevelenc, qaseqlevelenc, choicelevelenc : Seq2SeqEncoder
            Sequence encoders for the passage, the QA text and the candidate.
            When ``shared_rnn`` is true the passage encoder is used for all
            three inputs.
        cartesian_attn, pcattnmat, pqaattnmat, cqaattnmat : SeqAttnMat
            Attention-matrix modules (passage/QA for the question-dependent
            encoding, passage/candidate, passage/QA, QA/candidate).
        gate_qdep_penc, qdep_penc_rnn :
            Gate and encoder of the question-dependent passage encoding.
        mfa_enc, mfa_rnn :
            Multi-factor self-attention encoder and the encoder applied to
            its output.
        initializer : InitializerApplicator
            Applied to the model at the end of construction.
        dropout : float
            Probability given to ``InputVariationalDropout``.
        is_qdep_penc, is_mfa_enc : bool
            Enable the question-dependent and the multi-factor passage
            encodings (both only take effect with ``with_knowledge``).
        with_knowledge : bool
            Use the passage.
        is_qac_ap : bool
            Use the direct QA/candidate attentive pooling.
        shared_rnn : bool
            Encode the QA text and the candidate with the passage encoder.

        Raises
        ------
        AssertionError
            When both ``with_knowledge`` and ``is_qac_ap`` are false, since
            no score could be produced.
        """
        super().__init__(vocab)
        self._text_field_embedder = text_field_embedder
        self._pseqlevel_enc = pseqlevelenc
        self._qaseqlevel_enc = qaseqlevelenc
        self._cseqlevel_enc = choicelevelenc

        self._cart_attn = cartesian_attn
        self._pqaattnmat = pqaattnmat
        self._pcattnmat = pcattnmat
        self._cqaattnmat = cqaattnmat

        self._gate_qdep_penc = gate_qdep_penc
        self._qdep_penc_rnn = qdep_penc_rnn
        self._multifactor_attn = mfa_enc
        self._mfarnn = mfa_rnn

        self._with_knowledge = with_knowledge
        self._qac_ap = is_qac_ap
        # At least one of the two scoring paths has to be enabled.
        if not self._with_knowledge:
            if not self._qac_ap:
                raise AssertionError
        self._is_qdep_penc = is_qdep_penc
        self._is_mfa_enc = is_mfa_enc
        self._shared_rnn = shared_rnn

        self._variational_dropout = InputVariationalDropout(dropout)

        self._num_labels = vocab.get_vocab_size(namespace="labels")
        self._auc = Auc()
        # BCELoss expects probabilities; forward() feeds it sigmoid outputs.
        self._loss = torch.nn.BCELoss()
        initializer(self)

    def forward(
            self,  # type: ignore
            passage: Dict[str, torch.LongTensor],
            all_qa: Dict[str, torch.LongTensor],
            candidate: Dict[str, torch.LongTensor],
            combined_source: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Score the candidate and, when a label is given, compute the loss.

        Parameters
        ----------
        passage : Dict[str, torch.LongTensor]
            Passage text field; only read when ``with_knowledge`` is enabled.
        all_qa : Dict[str, torch.LongTensor]
            Text field holding the QA text.
        candidate : Dict[str, torch.LongTensor]
            Text field holding the candidate to score.
        combined_source : Dict[str, torch.LongTensor]
            Accepted but not referenced in this method.
        label : torch.IntTensor, optional (default = None)
            Gold label; flattened and cast to float for the loss.
        metadata : optional (default = None)
            Returned unchanged in the output.

        Returns
        -------
        A dictionary with ``label_logits`` (pre-sigmoid score),
        ``label_probs`` (sigmoid of the score) and ``metadata``, plus
        ``loss`` when ``label`` is supplied.

        In the shape comments below, B is the batch size, T / U / V are the
        passage / QA / candidate lengths, d is the embedding size and H the
        encoder output size.
        """
        if self._with_knowledge:
            embedded_passage = self._text_field_embedder(passage)  # B * T * d
            # NOTE(review): passage_len is not used anywhere below.
            passage_len = embedded_passage.size(1)
        embedded_all_qa = self._text_field_embedder(all_qa)  # B * U * d
        embedded_choice = self._text_field_embedder(candidate)  # B * V * d

        if self._with_knowledge:
            embedded_passage = self._variational_dropout(
                embedded_passage)  # B * T * d
        embedded_all_qa = self._variational_dropout(embedded_all_qa)
        embedded_choice = self._variational_dropout(
            embedded_choice)  # B * V * d

        all_qa_mask = util.get_text_field_mask(all_qa)  # B * U
        choice_mask = util.get_text_field_mask(candidate)  # B * V

        # Encoding
        if self._with_knowledge:
            # B * T * H
            passage_mask = util.get_text_field_mask(passage)  # B * T
            encoded_passage = self._variational_dropout(
                self._pseqlevel_enc(embedded_passage, passage_mask))
        # B * U * H
        if self._shared_rnn:
            encoded_allqa = self._variational_dropout(
                self._pseqlevel_enc(embedded_all_qa, all_qa_mask))
        else:
            encoded_allqa = self._variational_dropout(
                self._qaseqlevel_enc(embedded_all_qa, all_qa_mask))

        if self._with_knowledge and self._is_qdep_penc:
            # similarity matrix
            _, normalized_attn_mat = self._cart_attn(encoded_passage,
                                                     encoded_allqa,
                                                     all_qa_mask)  # B * T * U
            # question dependent passage encoding
            q_aware_passage_rep = sequential_weighted_avg(
                encoded_allqa, normalized_attn_mat)  # B * T * H

            q_dep_passage_enc_rnn_input = torch.cat(
                [encoded_passage, q_aware_passage_rep], 2)  # B * T * 2H

            # gated question dependent passage encoding
            gated_qaware_passage_rep = self._gate_qdep_penc(
                q_dep_passage_enc_rnn_input)  # B * T * 2H
            # NOTE(review): encoded_qdep_penc is only consumed by the
            # multi-factor branch below; with is_mfa_enc disabled it is
            # computed but never used - confirm this is intended.
            encoded_qdep_penc = self._qdep_penc_rnn(gated_qaware_passage_rep,
                                                    passage_mask)  # B * T * H

        # multi factor attentive encoding
        if self._with_knowledge and self._is_mfa_enc:
            if self._is_qdep_penc:
                mfa_enc = self._multifactor_attn(encoded_qdep_penc,
                                                 passage_mask)  # B * T * 2H
            else:
                mfa_enc = self._multifactor_attn(encoded_passage,
                                                 passage_mask)  # B * T * 2H
            # The passage encoding is replaced by the multi-factor one.
            encoded_passage = self._mfarnn(mfa_enc, passage_mask)  # B * T * H

        # B * V * H
        if self._shared_rnn:
            encoded_choice = self._variational_dropout(
                self._pseqlevel_enc(embedded_choice, choice_mask))  # B * V * H
        else:
            encoded_choice = self._variational_dropout(
                self._cseqlevel_enc(embedded_choice, choice_mask))  # B * V * H

        # "With knowledge" score: the passage attends over the QA text and
        # over the candidate; each attention matrix is max-pooled over the
        # passage dimension, normalised, and used to aggregate the other side.
        if self._with_knowledge:
            attn_pq, _ = self._pqaattnmat(encoded_passage, encoded_allqa,
                                          all_qa_mask)  # B * T * U
            combined_pqa_mask = passage_mask.unsqueeze(-1) * \
                                all_qa_mask.unsqueeze(1)  # B * T * U
            max_attn_pqa = masked_max(attn_pq, combined_pqa_mask,
                                      dim=1)  # B * U
            norm_attn_pqa = masked_softmax(max_attn_pqa, all_qa_mask,
                                           dim=-1)  # B * U
            agg_prev_qa = norm_attn_pqa.unsqueeze(1).bmm(
                encoded_allqa).squeeze(1)  # B * H

            attn_pc, _ = self._pcattnmat(encoded_passage, encoded_choice,
                                         choice_mask)  # B * T * V
            combined_pc_mask = passage_mask.unsqueeze(-1) * \
                               choice_mask.unsqueeze(1)  # B * T * V
            max_attn_pc = masked_max(attn_pc, combined_pc_mask, dim=1)  # B * V
            norm_attn_pc = masked_softmax(max_attn_pc, choice_mask,
                                          dim=-1)  # B * V
            agg_c = norm_attn_pc.unsqueeze(1).bmm(encoded_choice)  # B * 1 * H

            # Dot product between the aggregated candidate and QA vectors.
            choice_scores_wk = agg_c.bmm(agg_prev_qa.unsqueeze(-1)).squeeze(
                -1)  # B * 1

        # "No knowledge" score: the QA text and the candidate attend over
        # each other directly; the matrix is max-pooled along both axes.
        if self._qac_ap:
            attn_qac, _ = self._cqaattnmat(encoded_allqa, encoded_choice,
                                           choice_mask)  # B * U * V
            combined_qac_mask = all_qa_mask.unsqueeze(-1) * \
                                choice_mask.unsqueeze(1)  # B * U * V

            max_attn_c = masked_max(attn_qac, combined_qac_mask,
                                    dim=1)  # B * V
            max_attn_qa = masked_max(attn_qac, combined_qac_mask,
                                     dim=2)  # B * U
            norm_attn_c = masked_softmax(max_attn_c, choice_mask,
                                         dim=-1)  # B * V
            norm_attn_qa = masked_softmax(max_attn_qa, all_qa_mask,
                                          dim=-1)  # B * U
            agg_c_qa = norm_attn_c.unsqueeze(1).bmm(encoded_choice).squeeze(
                1)  # B * H
            agg_qa_c = norm_attn_qa.unsqueeze(1).bmm(encoded_allqa).squeeze(
                1)  # B * H

            choice_scores_nk = agg_c_qa.unsqueeze(1).bmm(
                agg_qa_c.unsqueeze(-1)).squeeze(-1)  # B * 1

        # Combine whichever scores were computed.
        if self._with_knowledge and self._qac_ap:
            choice_score = choice_scores_wk + choice_scores_nk
        elif self._qac_ap:
            choice_score = choice_scores_nk
        elif self._with_knowledge:
            choice_score = choice_scores_wk
        else:
            # The constructor rejects this flag combination, so this branch
            # is only reachable if the flags are changed after construction.
            raise NotImplementedError

        output = torch.sigmoid(choice_score).squeeze(-1)  # B

        output_dict = {
            "label_logits": choice_score.squeeze(-1),
            "label_probs": output,
            "metadata": metadata
        }

        if label is not None:
            label = label.long().view(-1)
            # BCELoss takes the probabilities and float targets.
            loss = self._loss(output, label.float())
            self._auc(output, label)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running AUC under the key ``auc``; ``reset`` is forwarded to the metric."""
        return {
            'auc': self._auc.get_metric(reset),
        }