Example #1
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()
        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        logits = self._classification_layer(embedded_text)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)

        if label is not None and self._num_labels == 2:
            self._auc(output_dict['probs'][:, 1], label.long().view(-1))
            self._f1(output_dict['probs'], label.long().view(-1))

        return output_dict
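Most of the classifiers in this collection return an output dictionary with "logits" and "probs". The snippet below is a minimal, self-contained sketch (not taken from any of the examples here) of how such a dictionary is typically post-processed into label strings; the index_to_label mapping is hypothetical.

import torch

def decode_predictions(output_dict, index_to_label):
    # probs: (batch_size, num_labels); pick the highest-probability class per row
    predicted_indices = output_dict["probs"].argmax(dim=-1)
    output_dict["predicted_labels"] = [index_to_label[int(i)] for i in predicted_indices]
    return output_dict

fake_output = {"probs": torch.tensor([[0.1, 0.9], [0.7, 0.3]])}
print(decode_predictions(fake_output, {0: "negative", 1: "positive"}))
# predicted_labels: ['positive', 'negative']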
Example #2
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: Dict[str, Any] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        input_ids = tokens[self._index]
        token_type_ids = tokens[f"{self._index}-type-ids"]
        input_mask = (input_ids != 0).long()

        _, pooled = self.bert_model(input_ids=input_ids,
                                    token_type_ids=token_type_ids,
                                    attention_mask=input_mask)

        pooled = self._dropout(pooled)

        # apply classification layer
        label_logits = self._classification_layer(pooled)

        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_probs": label_probs[..., 1]}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._f1(label_probs[..., 1], label.long().view(-1))
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["metadata"] = metadata

        return output_dict
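The `(input_ids != 0).long()` line above builds the attention mask directly from the padding id. A short stand-alone illustration in plain PyTorch (the token ids below are made up):

import torch

input_ids = torch.tensor([[101, 2023, 2003, 102, 0, 0],
                          [101, 7592, 102,   0, 0, 0]])
input_mask = (input_ids != 0).long()  # 1 for real tokens, 0 for padding (id 0)
print(input_mask)
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0, 0]])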
Example #3
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise)
        hypothesis_mask = get_text_field_mask(hypothesis)

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        output_dict = self.esim_forward(encoded_premise, encoded_hypothesis, premise_mask, hypothesis_mask, label=label)

        # If we're training, also compute loss and accuracy for the bias-only model
        if not self.evaluation_mode:
            hyp_only_logits = self._classification_layer_hyp_only(
                get_final_encoder_states(encoded_hypothesis, hypothesis_mask,
                                         self._encoder.is_bidirectional()))

            log_probs_pair = torch.log_softmax(output_dict["label_logits"], dim=1)
            log_probs_hyp = torch.log_softmax(hyp_only_logits, dim=1)

            # Combine with product of experts (normalized log space sum)
            # Do not require gradients from hyp-only classifier
            combined = log_probs_pair + log_probs_hyp.detach()

            # NLL loss over combined labels
            loss = self._nll_loss(combined, label.long().view(-1))
            hyp_loss = self._nll_loss(log_probs_hyp, label.long().view(-1))
            self._accuracy(combined, label)
            self._hyp_only_accuracy(hyp_only_logits, label)

            output_dict = {
                "loss": loss + self._beta * hyp_loss
            }
            return output_dict
        else:
            loss = self._cross_ent_loss(output_dict["label_logits"], label)
            output_dict["loss"] = loss
            self._accuracy(output_dict["label_logits"], label)

            return output_dict
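The training branch above combines the sentence-pair logits with a hypothesis-only "bias" classifier as a product of experts in log space. A minimal sketch of that combination, assuming only that the two sets of logits are given (shapes and values below are made up):

import torch
import torch.nn.functional as F

def combine_product_of_experts(pair_logits, bias_only_logits):
    # Sum of log-probabilities; detaching the bias-only expert keeps the
    # combined loss from updating it.
    log_probs_pair = F.log_softmax(pair_logits, dim=-1)
    log_probs_bias = F.log_softmax(bias_only_logits, dim=-1)
    return log_probs_pair + log_probs_bias.detach()

pair = torch.randn(4, 3, requires_grad=True)
bias = torch.randn(4, 3, requires_grad=True)
combined = combine_product_of_experts(pair, bias)
loss = F.nll_loss(combined, torch.tensor([0, 2, 1, 0]))
loss.backward()
print(pair.grad is not None, bias.grad is None)  # True True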
Example #4
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        logits = self._classification_layer(embedded_text)
        probs = F.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)
            self._f1_measure(logits, label)

        return output_dict
Example #5
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        ######################## token embedding #####################
        embedded_text = self._text_field_embedder(tokens)

        ######################## token padding #####################
        mask = get_text_field_mask(tokens).float()

        ######################## dropout+model #####################
        encoded_text = self._dropout(
            self._seq2vec_encoder(embedded_text, mask=mask))

        ######################## classification results and predicted probabilities #####################
        logits = self._classification_layer(encoded_text)
        probs = F.softmax(logits, dim=1)

        output_dict = {'logits': logits, 'probs': probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict['loss'] = loss
            self._accuracy(logits, label)

        return output_dict
Example #6
    def _hidden_to_output(
        self,
        # (batch_size, hidden_size)
        hidden_state: torch.LongTensor,
        page: torch.IntTensor = None,
    ):
        # (batch_size, n_classes)
        logits = self._classifier(hidden_state)
        # (batch_size, n_classes)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        output_dict = {
            "logits": logits,
            "probs": probs,
            "preds": torch.argmax(logits, 1)
        }

        if page is not None:
            loss = self._loss(logits, page.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, page)
        if self.top_k is not None:
            output_dict["top_k_scores"], output_dict[
                "top_k_indices"] = torch.topk(probs, self.top_k, dim=-1)

        return output_dict
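The top-k branch above relies on torch.topk returning scores and indices together. A brief self-contained illustration with made-up probabilities:

import torch

probs = torch.tensor([[0.05, 0.60, 0.25, 0.10]])
top_k_scores, top_k_indices = torch.topk(probs, k=2, dim=-1)
print(top_k_scores)   # tensor([[0.6000, 0.2500]])
print(top_k_indices)  # tensor([[1, 2]])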
Example #7
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        if self._feedforward is not None:
            embedded_text = self._feedforward(embedded_text)

        embedded_text = self.bn(embedded_text)

        logits = self._classification_layer(embedded_text)
        probs = F.softmax(logits, dim=1)

        output_dict = {'logits': logits, 'probs': probs}
        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict['loss'] = loss
            self._accuracy(logits, label)

        return output_dict
Example #8
    def forward(  # type: ignore
            self,
            tokens: Dict[str, torch.Tensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        # (batch_size, max_len, embedding_dim)
        embeddings = self.embedder(tokens)

        # the first embedding is for the [CLS] token
        # NOTE: this pre-supposes BERT encodings; not the most elegant!
        # (batch_size, embedding_dim)
        cls_embedding = embeddings[:, 0, :]

        # apply classification layer
        # (batch_size, num_labels)
        logits = self._classification_layer(cls_embedding)

        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)

        return output_dict
Example #9
    def forward(self,
                features: torch.Tensor,
                metadata: List[Dict[str, Any]] = None,
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        features: torch.Tensor,
            From a ``FloatField`` over the overlap features computed by the SimpleOverlapReader
        metadata: List[Dict[str, Any]]
            Metadata information
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        label_logits = self.linear_mlp(features)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
        output_dict = {"label_logits": label_logits, "label_probs": label_probs}
        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss
        return output_dict
Example #10
    def forward(  # type: ignore
            self,
            tokens,
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        token_ids = tokens['tokens']['token_ids']
        type_ids = tokens['tokens']['type_ids']
        # mask = tokens['tokens']['mask']
        segment_concat_mask = tokens['tokens']['segment_concat_mask']
        # print(token_ids)
        # print(type_ids)
        # print(segment_concat_mask)

        sequence_output, pooled_output = self.bert_model(
            input_ids=token_ids,
            token_type_ids=type_ids,
            attention_mask=segment_concat_mask)

        logits = self._classification_layer(pooled_output)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)

        return output_dict
Example #11
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # print(tokens)

        outputs = self._bert(tokens['bert'],
                             attention_mask=None,
                             token_type_ids=None,
                             position_ids=None,
                             head_mask=None)
        embedded_text = outputs[1]
        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        logits = self._classification_layer(embedded_text)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)
            self._f1_measure(logits, label)

        return output_dict
Example #12
    def forward(self,
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        premise = self._text_field_embedder(premise)
        hypothesis = self._text_field_embedder(hypothesis)

        premise = self.encode(premise, premise_mask)
        hypothesis = self.encode(hypothesis, hypothesis_mask)

        aggregate_input = torch.cat([
            premise, hypothesis,
            torch.abs(premise - hypothesis), premise * hypothesis
        ], 1)

        output_dict = {
            "final_hidden": aggregate_input,
        }

        if self._aggregate_feedforward:
            label_logits = self._aggregate_feedforward(aggregate_input)
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
            output_dict["label_logits"] = label_logits
            output_dict["label_probs"] = label_probs

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
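The aggregate input above is the usual sentence-pair feature vector: the two encodings, their absolute difference, and their element-wise product. A self-contained sketch with made-up dimensions:

import torch

premise = torch.randn(3, 16)      # (batch_size, encoding_dim)
hypothesis = torch.randn(3, 16)
aggregate_input = torch.cat(
    [premise, hypothesis, torch.abs(premise - hypothesis), premise * hypothesis],
    dim=1)
print(aggregate_input.shape)  # torch.Size([3, 64])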
Example #13
    def forward(  # type: ignore
        self,
        sent1: TextFieldTensors,
        sent2: TextFieldTensors,
        label: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:
        with adv_utils.forward_context("sent1"):
            encoded_sent1 = self.encoder(
                self.rotation(self.word_embedders(sent1)),
                get_text_field_mask(sent1))
        with adv_utils.forward_context("sent2"):
            encoded_sent2 = self.encoder(
                self.rotation(self.word_embedders(sent2)),
                get_text_field_mask(sent2))

        encoded = torch.cat([encoded_sent1, encoded_sent2], dim=1)

        output_hidden = self.feedforward(encoded)
        label_logits = self.output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"logits": label_logits, "probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Example #14
    def forward(
            self,  # type: ignore
            image: torch.FloatTensor,
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        output = image
        if self._im2im_encoder:
            output = self._im2im_encoder(output)

        output = self._im2vec_encoder(output)

        if self._dropout:
            output = self._dropout(output)

        logits = self._classification_layer(output)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)

        return output_dict
Example #15
    def forward(  # type: ignore
        self,
        sent1: TextFieldTensors,
        sent2: TextFieldTensors,
        label: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:
        with adv_utils.forward_context("sent1"):
            embedded_sent1 = self._text_field_embedder(sent1)
        with adv_utils.forward_context("sent2"):
            embedded_sent2 = self._text_field_embedder(sent2)
        sent1_mask = get_text_field_mask(sent1)
        sent2_mask = get_text_field_mask(sent2)

        projected_sent1 = self._attend_feedforward(embedded_sent1)
        projected_sent2 = self._attend_feedforward(embedded_sent2)
        # Shape: (batch_size, sent1_length, sent2_length)
        similarity_matrix = self._matrix_attention(projected_sent1,
                                                   projected_sent2)

        # Shape: (batch_size, sent1_length, sent2_length)
        p2h_attention = masked_softmax(similarity_matrix, sent2_mask)
        # Shape: (batch_size, sent1_length, embedding_dim)
        attended_sent2 = weighted_sum(embedded_sent2, p2h_attention)

        # Shape: (batch_size, sent2_length, sent1_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), sent1_mask)
        # Shape: (batch_size, sent2_length, embedding_dim)
        attended_sent1 = weighted_sum(embedded_sent1, h2p_attention)

        sent1_compare_input = torch.cat([embedded_sent1, attended_sent2],
                                        dim=-1)
        sent2_compare_input = torch.cat([embedded_sent2, attended_sent1],
                                        dim=-1)

        compared_sent1 = self._compare_feedforward(sent1_compare_input)
        compared_sent1 = compared_sent1 * sent1_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_sent1 = compared_sent1.sum(dim=1)

        compared_sent2 = self._compare_feedforward(sent2_compare_input)
        compared_sent2 = compared_sent2 * sent2_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_sent2 = compared_sent2.sum(dim=1)

        aggregate_input = torch.cat([compared_sent1, compared_sent2], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "logits": label_logits,
            "probs": label_probs,
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Example #16
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            premise_tags: torch.LongTensor,
            hypothesis: Dict[str, torch.LongTensor],
            hypothesis_tags: torch.LongTensor,
            input_ids: torch.Tensor,
            token_type_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # running the parser
        encoded_p_parse, p_parse_mask = self._parser(premise, premise_tags)
        p_parse_encoder_final_state = get_final_encoder_states(
            encoded_p_parse, p_parse_mask)
        encoded_h_parse, h_parse_mask = self._parser(hypothesis,
                                                     hypothesis_tags)
        h_parse_encoder_final_state = get_final_encoder_states(
            encoded_h_parse, h_parse_mask)

        logits = self.bert_sc_model(p_parse_encoder_final_state,
                                    h_parse_encoder_final_state,
                                    torch.stack(input_ids),
                                    torch.stack(token_type_ids),
                                    torch.stack(attention_mask))
        output_dict = {"logits": logits}
        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            self._accuracy(logits, label)
            output_dict["loss"] = loss

        return output_dict
Example #17
    def forward(self,
                premise: Dict[str, torch.LongTensor],
                premise_entities: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                hypothesis_entities: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None
                ) -> Dict[str, torch.Tensor]:

        # Feed the data to text and graph model
        text_out = self._text_model(premise, hypothesis).get("final_hidden")

        graph_out = self._graph_model(premise_entities,
                                      hypothesis_entities).get("final_hidden")

        # combine the results (n x 4d)
        combined_input = torch.cat((text_out, graph_out), dim=-1)

        label_logits = self._classify_feed_forward(combined_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Example #18
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        choices_list: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``
        choices_list : Dict[str, torch.LongTensor]
            From a ``List[TextField]``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of each choice being the correct answer.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of each choice being the correct answer.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        encoded_choices_aggregated = embed_encode_and_aggregate_list_text_field(
            choices_list, self._text_field_embedder, self._embeddings_dropout,
            self._choice_encoder, self._choice_aggregate)  # bs, choices, hs

        encoded_question_aggregated = embed_encode_and_aggregate_text_field(
            question, self._text_field_embedder, self._embeddings_dropout,
            self._question_encoder, self._question_aggregate)  # bs, hs

        q_to_choices_att = self._matrix_attention_question_to_choice(
            encoded_question_aggregated.unsqueeze(1),
            encoded_choices_aggregated).squeeze()

        label_logits = q_to_choices_att
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
Example #19
    def forward(self,
                claim: Dict[str, torch.LongTensor],
                evidence: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        claim : Dict[str, torch.LongTensor]
            From a ``TextField``
            The LongTensor shape is typically ``(batch_size, sent_length)``
        evidence : Dict[str, torch.LongTensor]
            From a ``TextField``
            The LongTensor shape is typically ``(batch_size, sent_length)``
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the claim and
            evidence sentences with 'claim_tokens' and 'premise_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        embedded_claim = self._embedder(claim)
        embedded_evidence = self._embedder(evidence)
        input_embeddings = torch.cat((embedded_claim, embedded_evidence), dim=1)

        projection = self._static_feedforward(input_embeddings)

        label_logits = self._feed_forward(projection)

        label_probs = F.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits,
                      "label_probs": label_probs}

        if label is not None:
           loss = self._loss(label_logits, label.long().view(-1))
           self._accuracy(label_logits, label)
           output_dict["loss"] = loss

        if metadata is not None:
           output_dict["claim_tokens"] = [x["claim_tokens"] for x in metadata]
           output_dict["evidence_tokens"] = [x["evidence_tokens"] for x in metadata]

        return output_dict
Example #20
    def forward(self, # type: ignore
                tokens: Dict[str, torch.LongTensor],
                token_type_ids: torch.LongTensor,
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
                ) -> Dict[str, torch.Tensor]:

        debug = False
        # batch_size, num_of_choices, max_premise_perchoice, L
        input_ids = tokens['tokens']
        # batch_size, L
        input_mask = (input_ids != 0).long()

        # shape: batch_size*num_choices*max_premise_perchoice, max_len
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
        flat_attention_mask = input_mask.view(-1, input_mask.size(-1))

        # shape: batch_size*num_choices*max_premise_perchoice, hidden_dim
        _, pooled_ph = self.bert_model(input_ids=flat_input_ids,
                                       token_type_ids=flat_token_type_ids,
                                       attention_mask=flat_attention_mask)

        if debug:
            print(f"input_ids.size() = {input_ids.size()}")
            print(f"token_type_ids.size() = {token_type_ids.size()}")
            print(f"pooled_ph.size() = {pooled_ph.size()}")

        # batch*choice, max_premise_per_choice, hidden_dim
        pooled_ph = pooled_ph.view(-1, input_ids.size(2), pooled_ph.size(-1))

        max_pooled_ph, _ = torch.max(pooled_ph, dim=1, keepdim=False)

        if debug:
            print(f"max_pooled_ph.size() = {max_pooled_ph.size()}")

        max_pooled_ph = self._dropout(max_pooled_ph)

        # apply classification layer
        logits = self._classification_layer(max_pooled_ph)

        # shape: batch_size,num_choices
        reshaped_logits = logits.view(-1, input_ids.size(1))
        if debug:
            print(f"reshaped_logits = {reshaped_logits}")

        probs = torch.nn.functional.softmax(reshaped_logits, dim=-1)

        output_dict = {"logits": reshaped_logits, "probs": probs}

        if label is not None:
            loss = self._loss(reshaped_logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(reshaped_logits, label)

        return output_dict
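The reshape of the per-choice logits above is what turns a multiple-choice problem into a standard classification over choices. A minimal sketch under made-up sizes (2 questions, 4 choices each):

import torch
import torch.nn.functional as F

batch_size, num_choices = 2, 4
flat_logits = torch.randn(batch_size * num_choices, 1)  # one score per (question, choice)
reshaped_logits = flat_logits.view(batch_size, num_choices)
probs = F.softmax(reshaped_logits, dim=-1)               # distribution over choices
loss = F.cross_entropy(reshaped_logits, torch.tensor([1, 3]))
print(probs.shape, float(loss))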
Example #21
    def forward(
            self,
            combined_source: Dict[str, torch.LongTensor],
            prev_turns: Dict[str, torch.LongTensor],
            curr_utt: Dict[str, torch.LongTensor],
            prev_turns_mask: torch.FloatTensor,
            utt_mask: torch.FloatTensor,
            label_scores: torch.FloatTensor = None,
            label_gold: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        embedded_source = self._text_field_embedder(combined_source)
        source_mask = get_text_field_mask(combined_source)
        embedded_source = self._variational_dropout(embedded_source)

        embedded_history = embedded_source * prev_turns_mask.unsqueeze(-1)
        embedded_question = embedded_source * utt_mask.unsqueeze(-1)

        scores = embedded_question.bmm(embedded_history.transpose(2, 1))
        mask = utt_mask.unsqueeze(1) * prev_turns_mask.unsqueeze(-1)
        alpha = masked_softmax(scores, mask, dim=-1)
        qdep_hist = alpha.bmm(embedded_history)

        x = torch.cat([embedded_question, qdep_hist], -1)
        x = self._qdep_henc_rnn(x, source_mask)
        x = x * utt_mask.unsqueeze(-1).float()

        x = self._senc_self_attn(x, utt_mask)
        x = self._variational_dropout(x)

        cls_tokens = self._attn_pool(x, utt_mask)

        pred_label_scores = self._output_ffl(cls_tokens)

        output = self._softmax(pred_label_scores)

        _, pred_label = output.max(1)

        assert output.shape[0] == pred_label.shape[0]

        output_dict = {
            'label_logits': pred_label_scores,
            'label_probs': output,
            'label_pred': pred_label,
            'metadata': metadata
        }

        if label_scores is not None:
            scores = label_scores
            label = label_gold.long()
            loss = self._loss(output, scores)
            self._accuracy(output, label)

            output_dict['loss'] = loss

        return output_dict
Example #22
    def forward(  # type: ignore
            self,
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        feedforward_output = self._feedforward_layer(embedded_text)

        logits = self._classification_layer(feedforward_output)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            for i in range(self._num_labels):
                metric = self._label_f1_metrics[
                    self.vocab.get_token_from_index(index=i,
                                                    namespace="labels")]
                metric(probs, label)
            self._accuracy(logits, label)

        return output_dict
Example #23
    def forward(
        self,  # type: ignore
        hypothesis0: Dict[str, torch.LongTensor],
        hypothesis1: Dict[str, torch.LongTensor],
        hypothesis2: Dict[str, torch.LongTensor],
        hypothesis3: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        logits = []
        for tokens in [hypothesis0, hypothesis1, hypothesis2, hypothesis3]:
            if isinstance(self.text_field_embedder, ElmoTokenEmbedder):
                self.text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states()

            embedded_text_input = self.embedding_dropout(
                self.text_field_embedder(tokens))
            mask = get_text_field_mask(tokens)

            batch_size, sequence_length, _ = embedded_text_input.size()

            encoded_text = self.encoder(embedded_text_input, mask)

            logits.append(self.output_prediction(encoded_text.max(1)[0]))

        logits = torch.cat(logits, -1)
        class_probabilities = F.softmax(logits, dim=-1).view([batch_size, 4])
        output_dict = {
            "label_logits": logits,
            "label_probs": class_probabilities
        }

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            self._accuracy(logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
Example #24
    def calc_loss(self, soft_estimation: torch.Tensor,
                  transmitted_words: torch.IntTensor) -> torch.Tensor:
        """
        Cross Entropy loss - distribution over states versus the gt state label
        :param soft_estimation: [1, transmission_length, n_states], each element is a probability
        :param transmitted_words: [1, transmission_length]
        :return: loss value
        """
        loss = self.criterion(input=soft_estimation.reshape(-1, 2),
                              target=transmitted_words.long().reshape(-1))
        return loss
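A hedged usage sketch for calc_loss, assuming self.criterion is torch.nn.CrossEntropyLoss and using the shapes from the docstring (the concrete numbers are made up):

import torch

criterion = torch.nn.CrossEntropyLoss()
soft_estimation = torch.randn(1, 10, 2)           # [1, transmission_length, n_states]
transmitted_words = torch.randint(0, 2, (1, 10))  # [1, transmission_length]
loss = criterion(input=soft_estimation.reshape(-1, 2),
                 target=transmitted_words.long().reshape(-1))
print(float(loss))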
Example #25
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        if "token" in tokens:
            tokens["tokens"] = self._pad(tokens["tokens"])
        if "token_characters" in tokens:
            tokens["token_characters"] = self._3pad(tokens["token_characters"])
        if "elmo" in tokens:
            tokens["elmo"] = self._3pad(tokens["elmo"])

        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        logits = self._classification_layer(embedded_text)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)

        return output_dict
Example #26
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        label: torch.IntTensor = None,
        metadata: MetadataField = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : `TextFieldTensors`
            From a `TextField`
        label : `torch.IntTensor`, optional (default = `None`)
            From a `LabelField`

        # Returns

        An output dictionary consisting of:

            - `logits` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                unnormalized log probabilities of the label.
            - `probs` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                probabilities of the label.
            - `loss` : (`torch.FloatTensor`, optional) :
                A scalar loss to be optimised.
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        if self._feedforward is not None:
            embedded_text = self._feedforward(embedded_text)

        logits = self._classification_layer(embedded_text)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}
        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)
        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)

        return output_dict
Example #27
    def forward(self,
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # Embed premise and hypothesis
        premise = self._text_field_embedder(premise)  # (n x p x d)
        hypothesis = self._text_field_embedder(hypothesis)  # (n x h x d)

        # encode premise and hypothesis
        # (n x p x 2d) if bidirectional else (n x p x d)
        premise = self._encoder(premise, premise_mask)
        # (n x h x 2d) if bidirectional else (n x h x d)
        hypothesis = self._encoder(hypothesis, hypothesis_mask)

        # calculate matrix attention
        similarity_matrix = self._inter_attention(hypothesis,
                                                  premise)  # (n x h x p)

        attention_softmax = last_dim_softmax(similarity_matrix,
                                             premise_mask)  # (n x h x p)
        hypothesis_tilda = weighted_sum(
            premise, attention_softmax
        )  # (n x h x 2d) assuming encoder is bidirectional

        hypothesis_matching_states = torch.cat(
            [hypothesis, hypothesis_tilda, hypothesis - hypothesis_tilda,
             hypothesis * hypothesis_tilda],
            dim=-1)

        # max pool
        hypothesis_max, _ = replace_masked_values(
            hypothesis_matching_states, hypothesis_mask.unsqueeze(-1),
            -1e7).max(dim=1)  # (n x 2d)

        output_dict = {"final_hidden": hypothesis_max}

        if self._output_feedforward:
            label_logits = self._output_feedforward(hypothesis_max)
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
            output_dict["label_logits"] = label_logits
            output_dict["label_probs"] = label_probs

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
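last_dim_softmax and weighted_sum above implement masked attention over the premise followed by a weighted sum of premise states. A rough plain-PyTorch equivalent, written only to show the shapes involved (all tensors below are made up):

import torch

def masked_attention_weighted_sum(similarity, premise_states, premise_mask):
    # similarity: (n, h, p); premise_states: (n, p, d); premise_mask: (n, p)
    scores = similarity.masked_fill(premise_mask.unsqueeze(1) == 0, -1e32)
    attention = torch.softmax(scores, dim=-1)  # softmax over premise positions
    return attention.bmm(premise_states)       # (n, h, d)

similarity = torch.randn(2, 4, 5)
premise_states = torch.randn(2, 5, 8)
premise_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(masked_attention_weighted_sum(similarity, premise_states, premise_mask).shape)
# torch.Size([2, 4, 8])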
Example #28
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise)
        hypothesis_mask = get_text_field_mask(hypothesis)

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        output_dict = self.esim_forward(encoded_premise, encoded_hypothesis, premise_mask, hypothesis_mask, label=label)

        # If we're training, also compute loss and accuracy for the bias-only model
        if not self.evaluation_mode:
            sentence_pair_logits = output_dict["label_logits"]
            hyp_only_logits = self._classification_layer_hyp_only(
                get_final_encoder_states(encoded_hypothesis, hypothesis_mask,
                                         self._encoder.is_bidirectional()))
            hyp_only_probs = torch.softmax(hyp_only_logits, dim=1)

            scaled = (1.0 - hyp_only_probs).pow(self._gamma).detach()

            weighting = torch.cat([scaled[idx, val].unsqueeze(0) for idx, val in enumerate([l.item() for l in label])])
            instance_losses = self._element_cross_ent_loss(sentence_pair_logits, label)

            hyp_loss = self._cross_ent_loss(hyp_only_logits, label.long().view(-1))

            self._accuracy(sentence_pair_logits, label)
            self._hyp_only_accuracy(hyp_only_logits, label)
            output_dict = {
                "loss": (instance_losses * weighting).mean() + self._beta * hyp_loss,
                "logits": sentence_pair_logits,

            }

            return output_dict

        else:
            loss = self._cross_ent_loss(output_dict["label_logits"], label)
            output_dict["loss"] = loss

            self._accuracy(output_dict["label_logits"], label)
            return output_dict
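The training branch above down-weights examples that the hypothesis-only model already predicts confidently, using a focal-style factor (1 - p_bias)^gamma. A minimal self-contained sketch of that weighting; the gamma and beta values here are assumptions, not taken from the model above:

import torch
import torch.nn.functional as F

def debiased_loss(pair_logits, bias_only_logits, labels, gamma=2.0, beta=0.3):
    instance_losses = F.cross_entropy(pair_logits, labels, reduction="none")
    bias_probs = torch.softmax(bias_only_logits, dim=-1)
    gold_bias_probs = bias_probs.gather(1, labels.unsqueeze(1)).squeeze(1)
    weighting = (1.0 - gold_bias_probs).pow(gamma).detach()  # no gradient through the weights
    bias_loss = F.cross_entropy(bias_only_logits, labels)
    return (instance_losses * weighting).mean() + beta * bias_loss

pair, bias = torch.randn(4, 3), torch.randn(4, 3)
print(float(debiased_loss(pair, bias, torch.tensor([0, 1, 2, 1]))))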
Example #29
    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors,
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : TextFieldTensors
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        # Returns

        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        if self._feedforward is not None:
            embedded_text = self._feedforward(embedded_text)

        logits = self._classification_layer(embedded_text)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)

        return output_dict
Example #30
    def forward(  # type: ignore
        self, tokens: TextFieldTensors, label: torch.IntTensor = None
    ) -> Dict[str, torch.Tensor]:

        """
        # Parameters

        tokens : TextFieldTensors
            From a `TextField` (that has a bert-pretrained token indexer)
        label : torch.IntTensor, optional (default = None)
            From a `LabelField`

        # Returns

        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape `(batch_size, num_labels)` representing
            unnormalized log probabilities of the label.
        probs : torch.FloatTensor
            A tensor of shape `(batch_size, num_labels)` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        inputs = tokens[self._index]
        input_ids = inputs["input_ids"]
        token_type_ids = inputs["token_type_ids"]
        input_mask = (input_ids != 0).long()

        _, pooled, *_ = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=input_mask
        )

        pooled = self._dropout(pooled)

        # apply classification layer
        logits = self._classification_layer(pooled)

        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)

        return output_dict
Example #31
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits,
                       "label_probs": label_probs,
                       "h2p_attention": h2p_attention,
                       "p2h_attention": p2h_attention}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["premise_tokens"] = [x["premise_tokens"] for x in metadata]
            output_dict["hypothesis_tokens"] = [x["hypothesis_tokens"] for x in metadata]

        return output_dict
Example #32
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
                [encoded_premise, attended_hypothesis,
                 encoded_premise - attended_hypothesis,
                 encoded_premise * attended_hypothesis],
                dim=-1
        )
        hypothesis_enhanced = torch.cat(
                [encoded_hypothesis, attended_premise,
                 encoded_hypothesis - attended_premise,
                 encoded_hypothesis * attended_premise],
                dim=-1
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max, _ = replace_masked_values(
                v_ai, premise_mask.unsqueeze(-1), -1e7
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
                v_bi, hypothesis_mask.unsqueeze(-1), -1e7
        ).max(dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
                premise_mask, 1, keepdim=True
        )
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
                hypothesis_mask, 1, keepdim=True
        )

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
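The ESIM pooling step above takes a masked max and a length-normalised average over time. A plain-PyTorch sketch of that step without the AllenNLP helpers (shapes are made up):

import torch

def masked_max_avg_pool(states, mask):
    # states: (batch, seq_len, dim); mask: (batch, seq_len) with 1 for real tokens
    mask = mask.float()
    masked_states = states.masked_fill(mask.unsqueeze(-1) == 0, -1e7)
    max_pooled, _ = masked_states.max(dim=1)
    avg_pooled = (states * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)
    return max_pooled, avg_pooled

states = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
v_max, v_avg = masked_max_avg_pool(states, mask)
print(v_max.shape, v_avg.shape)  # torch.Size([2, 8]) torch.Size([2, 8])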