Example #1
    def test_masked_max(self):
        # Testing the general masked 1D case.
        vector_1d = torch.FloatTensor([1.0, 12.0, 5.0])
        mask_1d = torch.FloatTensor([1.0, 0.0, 1.0])
        vector_1d_maxed = util.masked_max(vector_1d, mask_1d, dim=0).data.numpy()
        assert_array_almost_equal(vector_1d_maxed, 5.0)

        # Testing that when the mask is all zeros, the output is arbitrary but should not be nan.
        vector_1d = torch.FloatTensor([1.0, 12.0, 5.0])
        mask_1d = torch.FloatTensor([0.0, 0.0, 0.0])
        vector_1d_maxed = util.masked_max(vector_1d, mask_1d, dim=0).data.numpy()
        assert not numpy.isnan(vector_1d_maxed).any()

        # Testing batch value and batch masks
        matrix = torch.FloatTensor([[1.0, 12.0, 5.0], [-1.0, -2.0, 3.0]])
        mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
        matrix_maxed = util.masked_max(matrix, mask, dim=-1).data.numpy()
        assert_array_almost_equal(matrix_maxed, numpy.array([5.0, -1.0]))

        # Testing keepdim for batch value and batch masks
        matrix = torch.FloatTensor([[1.0, 12.0, 5.0], [-1.0, -2.0, 3.0]])
        mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
        matrix_maxed = util.masked_max(matrix, mask, dim=-1, keepdim=True).data.numpy()
        assert_array_almost_equal(matrix_maxed, numpy.array([[5.0], [-1.0]]))

        # Testing broadcast
        matrix = torch.FloatTensor([[[1.0, 2.0], [12.0, 3.0], [5.0, -1.0]],
                                    [[-1.0, -3.0], [-2.0, -0.5], [3.0, 8.0]]])
        mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]).unsqueeze(-1)
        matrix_maxed = util.masked_max(matrix, mask, dim=1).data.numpy()
        assert_array_almost_equal(matrix_maxed, numpy.array([[5.0, 2.0], [-1.0, -0.5]]))
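The tests above pin down the contract of util.masked_max: masked positions never win the max, an all-zero mask yields an arbitrary but finite (non-nan) value, keepdim is honoured, and the mask broadcasts against the input. A minimal sketch consistent with these tests (an illustration only; AllenNLP's actual implementation may differ, e.g. in the fill value it uses):

import torch

def masked_max_sketch(vector: torch.Tensor,
                      mask: torch.Tensor,
                      dim: int,
                      keepdim: bool = False) -> torch.Tensor:
    # Fill masked-out positions with the most negative representable value so
    # they can never be selected by the max; the mask broadcasts against `vector`.
    replaced = vector.masked_fill(~mask.bool(), torch.finfo(vector.dtype).min)
    max_value, _ = replaced.max(dim=dim, keepdim=keepdim)
    return max_value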
Example #3
    def forward(self, **kwargs) -> torch.FloatTensor:
        mask = kwargs['mask']
        embedded_text = kwargs['embedded_text']
        encoded_output = self._architecture(embedded_text, mask)
        encoded_repr = []
        for aggregation in self._aggregations:
            if aggregation == "meanpool":
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoded_output * broadcast_mask
                encoded_text = masked_mean(context_vectors,
                                           broadcast_mask,
                                           dim=1,
                                           keepdim=False)
            elif aggregation == 'maxpool':
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoded_output * broadcast_mask
                encoded_text = masked_max(context_vectors,
                                          broadcast_mask,
                                          dim=1)
            elif aggregation == 'final_state':
                is_bi = self._architecture.is_bidirectional()
                encoded_text = get_final_encoder_states(encoded_output,
                                                        mask,
                                                        is_bi)
            elif aggregation == 'attention':
                alpha = self._attention_layer(encoded_output)
                alpha = masked_log_softmax(alpha, mask.unsqueeze(-1), dim=1).exp()
                encoded_text = alpha * encoded_output
                encoded_text = encoded_text.sum(dim=1)
            else:
                raise ConfigurationError(f"{aggregation} aggregation not available.")
            encoded_repr.append(encoded_text)

        encoded_repr = torch.cat(encoded_repr, 1)
        return encoded_repr
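The meanpool, maxpool, and attention branches above all rely on the same broadcasting pattern: the (batch, seq_len) mask is expanded to (batch, seq_len, 1) so it lines up with the (batch, seq_len, hidden) encoder output. A quick shape check of that pattern (tensor sizes are illustrative):

import torch

encoded_output = torch.randn(4, 7, 16)                              # (batch, seq_len, hidden)
mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0]]).repeat(4, 1).float()   # (batch, seq_len)

broadcast_mask = mask.unsqueeze(-1)                                 # (batch, seq_len, 1)
context_vectors = encoded_output * broadcast_mask                   # padded steps zeroed out
print(context_vectors.shape)                                        # torch.Size([4, 7, 16])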
Example #4
    def forward(self,
                document,
                query=None,
                rationale=None,
                metadata=None,
                label=None) -> Dict[str, Any]:
        input_ids = document["bert"]
        input_mask = (input_ids != 0).long()
        starting_offsets = document["bert-starting-offsets"]  # (B, T)

        last_hidden_states, _ = self._bert_model(
            input_ids,
            attention_mask=input_mask,
            position_ids=document["bert-position-ids"])

        token_embeddings, span_mask = generate_embeddings_for_pooling(
            last_hidden_states, starting_offsets,
            document["bert-ending-offsets"])

        token_embeddings = util.masked_max(token_embeddings,
                                           span_mask.unsqueeze(-1),
                                           dim=2)
        token_embeddings = token_embeddings * document["mask"].unsqueeze(-1)

        logits = self._classification_layer(self._dropout(token_embeddings))
        assert logits.shape[0:2] == starting_offsets.shape

        if self._use_crf:
            best_paths = self._crf.viterbi_tags(logits, mask=document["mask"])
            best_paths = [b[0] for b in best_paths]
            best_paths = [
                x + [0] * (logits.shape[1] - len(x)) for x in best_paths
            ]
            best_paths = torch.Tensor(best_paths).to(
                logits.device) * document["mask"]
        else:
            best_paths = (logits[:, :, 1] > 0.5).long() * document["mask"]

        output_dict = {}

        output_dict["predicted_rationales"] = best_paths
        output_dict["mask"] = document["mask"]
        output_dict["metadata"] = metadata

        if rationale is not None:
            if self._use_crf:
                output_dict["loss"] = -self._crf(logits, rationale,
                                                 document["mask"])
            else:
                output_dict["loss"] = ((F.cross_entropy(
                    logits.view(-1, logits.shape[-1]),
                    rationale.view(-1),
                    reduction="none",
                    weight=self._pos_weight,
                ) * document["mask"].view(-1)).sum(-1).mean())

            best_paths = best_paths.unsqueeze(-1)
            best_paths = torch.cat([1 - best_paths, best_paths], dim=-1)
            self._token_prf(best_paths, rationale, document["mask"])
        return output_dict
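The Viterbi-path handling above right-pads each decoded tag sequence with zeros up to the time dimension of logits before turning the list into a tensor. A toy illustration of that list comprehension (values are invented):

best_paths = [[1, 0, 1], [1]]
max_len = 4  # stands in for logits.shape[1]
padded = [x + [0] * (max_len - len(x)) for x in best_paths]
print(padded)  # [[1, 0, 1, 0], [1, 0, 0, 0]]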
Example #5
    def forward(self,
                document,
                query=None,
                label=None,
                metadata=None,
                rationale=None,
                **kwargs) -> Dict[str, Any]:
        #pylint: disable=arguments-differ

        bert_document = self.combine_document_query(document, query)

        last_hidden_states, _ = self._bert_model(
            bert_document["bert"]["wordpiece-ids"],
            attention_mask=bert_document["bert"]["wordpiece-mask"],
            position_ids=bert_document["bert"]["position-ids"],
            token_type_ids=bert_document["bert"]["type-ids"],
        )

        token_embeddings, span_mask = generate_embeddings_for_pooling(
            last_hidden_states,
            bert_document["bert"]['document-starting-offsets'],
            bert_document["bert"]['document-ending-offsets'])

        token_embeddings = util.masked_max(token_embeddings,
                                           span_mask.unsqueeze(-1),
                                           dim=2)
        token_embeddings = token_embeddings * bert_document['bert'][
            "mask"].unsqueeze(-1)

        logits = self._classification_layer(self._dropout(token_embeddings))

        probs = torch.sigmoid(logits)[:, :, 0]
        mask = bert_document['bert']['mask']

        output_dict = {}
        output_dict["probs"] = probs * mask
        output_dict['mask'] = mask
        predicted_rationale = (probs > 0.5).long()

        output_dict["predicted_rationale"] = predicted_rationale * mask
        output_dict["prob_z"] = probs * mask

        if rationale is not None:
            rat_mask = (rationale.sum(1) > 0)
            if rat_mask.sum().long() == 0:
                output_dict['loss'] = 0.0
            else:
                weight = torch.Tensor([1.0,
                                       self._pos_weight]).to(logits.device)
                loss = torch.nn.functional.cross_entropy(
                    logits[rat_mask].transpose(1, 2),
                    rationale[rat_mask],
                    weight=weight)
                output_dict['loss'] = loss
                self._token_prf(logits[rat_mask], rationale[rat_mask],
                                bert_document['bert']["mask"][rat_mask])

        return output_dict
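The rationale loss above is only computed over batch rows that contain at least one rationale token; rat_mask is a boolean row selector built from the rationale labels. A toy check of that indexing (values are invented):

import torch

rationale = torch.tensor([[0, 1, 0],
                          [0, 0, 0]])
rat_mask = rationale.sum(1) > 0
print(rat_mask)                    # tensor([ True, False])
print(rationale[rat_mask].shape)   # torch.Size([1, 3])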
Example #6
    def _encode_definition(
            self, definition: Dict[str,
                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
        # [batch_size, seq_len]
        definition_mask = util.get_text_field_mask(definition)
        # [batch_size, seq_len, emb_dim]
        embedded_definition = self.text_embedder(definition)

        # either [batch_size, emb_dim] or [batch_size, seq_len, emb_dim]
        encoded_definition = self.definition_encoder(embedded_definition,
                                                     definition_mask)
        # if len(encoded_definition.size()) == 3:
        if self.definition_pooling == 'last':
            # [batch_size, emb_dim]
            encoded_definition = util.get_final_encoder_states(
                encoded_definition, definition_mask)
        elif self.definition_pooling == 'max':
            # encoded_definition = F.adaptive_max_pool1d(encoded_definition.transpose(1, 2), 1).squeeze(2)
            encoded_definition = util.masked_max(encoded_definition,
                                                 definition_mask.unsqueeze(2),
                                                 dim=1)
        elif self.definition_pooling == 'mean':
            # encoded_definition = F.adaptive_avg_pool1d(encoded_definition.transpose(1, 2), 1).squeeze(2)
            encoded_definition = util.masked_mean(encoded_definition,
                                                  definition_mask.unsqueeze(2),
                                                  dim=1)
        elif self.definition_pooling == 'self-attentive':
            self_attentive_logits = self.self_attentive_pooling_projection(
                encoded_definition).squeeze(2)
            self_weights = util.masked_softmax(self_attentive_logits,
                                               definition_mask)
            encoded_definition = util.weighted_sum(encoded_definition,
                                                   self_weights)
        # [batch_size, emb_dim]
        definition_embedding = self.definition_feedforward(encoded_definition)

        # [batch_size, vocab_size(num_class)]
        definition_logits = self.definition_decoder(definition_embedding)
        # [batch_size, seq_len, vocab_size]
        sequence_definition_logits = definition_logits.unsqueeze(1).repeat(
            1, definition_mask.size(1), 1)

        # ``average`` can be None, "batch", or "token"
        # loss for ``average==None`` is a vector of shape (batch_size,); otherwise, a scalar
        targets = definition['tokens'].clone()
        if self.limited_word_vocab_size is not None:
            targets[targets >= self.limited_word_vocab_size] = self._oov_index
        cross_entropy_loss = util.sequence_cross_entropy_with_logits(
            sequence_definition_logits,
            targets,
            # definition['tokens'],
            weights=definition_mask,
            average='token')

        return {
            "definition_embedding": definition_embedding,
            "cross_entropy_loss": cross_entropy_loss
        }
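The 'self-attentive' branch above scores each definition token, applies a masked softmax, and takes a weighted sum of the encoded tokens. A self-contained sketch of that pooling with plain torch ops (names and sizes are illustrative; the real code uses allennlp.nn.util.masked_softmax and weighted_sum, whose exact masking strategy may differ):

import torch

encoded_definition = torch.randn(2, 5, 8)             # (batch, seq_len, emb_dim)
definition_mask = torch.tensor([[1, 1, 1, 0, 0],
                                [1, 1, 0, 0, 0]])
attn_logits = torch.randn(2, 5)                       # one score per token

weights = torch.softmax(attn_logits.masked_fill(definition_mask == 0, -1e9), dim=-1)
pooled = torch.bmm(weights.unsqueeze(1), encoded_definition).squeeze(1)
print(pooled.shape)                                   # torch.Size([2, 8])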
Example #7
    def forward(
        self,  # type: ignore
        tokens: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the text.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded = self._bert(tokens)

        first_token = embedded[:, 0, :]
        pooled_first = self._pooler(first_token)
        pooled_first = self._dropout(pooled_first)

        mask = tokens['mask'].float()
        encoded = self._encoder(embedded, mask)
        encoded = self._dropout(encoded)
        pooled_encoded = masked_max(encoded, mask.unsqueeze(-1), dim=1)

        concat = torch.cat([pooled_first, pooled_encoded], dim=-1)
        label_logits = self._classifier(concat)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits.view(-1, self._num_labels),
                              label.view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Example #8
    def pool_graph(self, node_embs, node_emb_mask):
        """
        Parameters:
            node_embs: (bsz, n_nodes, graph_dim)
            node_emb_mask: (bsz, n_nodes)

        Returns:
            (bsz, graph_dim (*2))
        """
        node_emb_mask = node_emb_mask.unsqueeze(-1)
        output = masked_max(node_embs, node_emb_mask, 1)
        output = torch.where(node_emb_mask.any(1), output,
                             torch.zeros_like(output))
        return output
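The torch.where guard above matters because a masked max over an all-zero mask returns an arbitrary large-negative fill value rather than anything meaningful, so graphs with no valid nodes are reset to zero vectors. A torch-only illustration of the guard (values are invented; the large negatives stand in for the masked-max output on an empty mask):

import torch

node_emb_mask = torch.tensor([[True, True, False],
                              [False, False, False]]).unsqueeze(-1)  # (bsz, n_nodes, 1)
pooled = torch.tensor([[0.7, -0.2, 1.3],
                       [-1e32, -1e32, -1e32]])         # stand-in for masked_max output
guarded = torch.where(node_emb_mask.any(1), pooled, torch.zeros_like(pooled))
print(guarded[1])                                      # tensor([0., 0., 0.])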
Example #9
def seq2vec_seq_aggregate(seq_tensor, mask, aggregate, bidirectional, dim=1):
    """
        Takes the aggregation of sequence tensor

        :param seq_tensor: Batched sequence requires [batch, seq, hs]
        :param mask: binary mask with shape batch, seq_len, 1
        :param aggregate: max, avg, sum
        :param dim: The dimension to take the max. for batch, seq, hs it is 1
        :return:
    """

    seq_tensor_masked = seq_tensor * mask.unsqueeze(-1)
    aggr_func = None
    if aggregate == "last":
        if seq_tensor.dim() > 3:
            seq = get_final_encoder_states_after_squashing(seq_tensor, mask, bidirectional)
        else:
            seq = get_final_encoder_states(seq_tensor, mask, bidirectional)
    elif aggregate == "max":
        seq = masked_max(seq_tensor, mask.unsqueeze(-1).expand_as(seq_tensor), dim=dim)
    elif aggregate == "min":
        seq = -masked_max(-seq_tensor, mask.unsqueeze(-1).expand_as(seq_tensor), dim=dim)
    elif aggregate == "sum":
        aggr_func = torch.sum
        seq = aggr_func(seq_tensor_masked, dim=dim)
    elif aggregate == "avg":
        aggr_func = torch.sum
        seq = aggr_func(seq_tensor_masked, dim=dim)
        seq_lens = torch.sum(mask, dim=dim)  # this returns batch_size, .. 1 ..
        masked_seq_lens = replace_masked_values(seq_lens, (seq_lens != 0).float(), 1.0)
        masked_seq_lens = masked_seq_lens.unsqueeze(dim=dim).expand_as(seq)
        # print(seq.shape)
        # print(masked_seq_lens.shape)
        seq = seq / masked_seq_lens

    return seq
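The "min" branch above uses the identity min(x) = -max(-x) so that the same masked-max routine serves both aggregations. A quick torch-only check of that identity:

import torch

x = torch.tensor([3.0, -1.0, 2.0])
assert torch.equal(x.min(), -(-x).max())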
Example #10
def pool(vector: torch.Tensor,
         mask: torch.Tensor,
         dim: int,
         pooling: str,
         is_bidirectional: bool) -> torch.Tensor:
    if pooling == "max":
        return masked_max(vector, mask, dim)
    elif pooling == "mean":
        return masked_mean(vector, mask, dim)
    elif pooling == "sum":
        return torch.sum(vector, dim)
    elif pooling == "final":
        return get_final_encoder_states(vector, mask, is_bidirectional)
    else:
        raise ValueError(f"'{pooling}' is not a valid pooling operation.")
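A hypothetical call to the pool helper above (names and shapes are assumptions; a boolean mask is used, as in AllenNLP >= 1.0 where masks are BoolTensors). For the "max" and "mean" branches the caller passes a mask that broadcasts against the vector, e.g. (batch, seq_len, 1) for a (batch, seq_len, dim) input; the "final" branch instead expects the plain (batch, seq_len) mask that get_final_encoder_states takes:

import torch

vector = torch.randn(2, 5, 8)                                # (batch, seq_len, dim)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 0, 0, 0]]).bool().unsqueeze(-1)  # broadcastable over dim

pooled = pool(vector, mask, dim=1, pooling="max", is_bidirectional=False)
print(pooled.shape)                                          # torch.Size([2, 8])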
Example #11
    def forward(self,
                document,
                query=None,
                label=None,
                metadata=None,
                rationale=None,
                **kwargs) -> Dict[str, Any]:
        #pylint: disable=arguments-differ

        bert_document = self.combine_document_query(document, query)

        last_hidden_states, _ = self._bert_model(
            bert_document["bert"]["wordpiece-ids"],
            attention_mask=bert_document["bert"]["wordpiece-mask"],
            position_ids=bert_document["bert"]["position-ids"],
            token_type_ids=bert_document["bert"]["type-ids"],
        )

        token_embeddings, span_mask = generate_embeddings_for_pooling(
            last_hidden_states,
            bert_document["bert"]['document-starting-offsets'],
            bert_document["bert"]['document-ending-offsets'])

        token_embeddings = util.masked_max(token_embeddings,
                                           span_mask.unsqueeze(-1) == 1,
                                           dim=2)
        token_embeddings = token_embeddings * bert_document['bert'][
            "mask"].unsqueeze(-1)

        logits = torch.nn.functional.softplus(
            self._classification_layer(self._dropout(token_embeddings)))

        a, b = logits[:, :, 0], logits[:, :, 1]
        mask = bert_document['bert']['mask']

        output_dict = {}
        output_dict["a"] = a * mask
        output_dict["b"] = b * mask
        output_dict['mask'] = mask
        output_dict['wordpiece-to-token'] = bert_document['bert'][
            'wordpiece-to-token']
        return output_dict
Example #12
File: sse.py Project: Shuailong/SPM
    def forward(
        self,  # type: ignore
        premise: Dict[str, torch.LongTensor],
        hypothesis: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        s1_layer_1_out = self._encoder1(embedded_premise, premise_mask)
        s2_layer_1_out = self._encoder1(embedded_hypothesis, hypothesis_mask)

        s1_layer_2_out = self._encoder2(
            torch.cat([embedded_premise, s1_layer_1_out], dim=2), premise_mask)
        s2_layer_2_out = self._encoder2(
            torch.cat([embedded_hypothesis, s2_layer_1_out], dim=2),
            hypothesis_mask)

        s1_layer_3_out = self._encoder3(
            torch.cat([embedded_premise, s1_layer_1_out, s1_layer_2_out],
                      dim=2), premise_mask)
        s2_layer_3_out = self._encoder3(
            torch.cat([embedded_hypothesis, s2_layer_1_out, s2_layer_2_out],
                      dim=2), hypothesis_mask)

        premise_max = masked_max(s1_layer_3_out, premise_mask.unsqueeze(-1), dim=1)
        hypothesis_max = masked_max(s2_layer_3_out,
                                    hypothesis_mask.unsqueeze(-1),
                                    dim=1)

        features = torch.cat([
            premise_max, hypothesis_max,
            torch.abs(premise_max - hypothesis_max),
            premise_max * hypothesis_max
        ],
                             dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        output_hidden1 = self._output_feedforward1(features)
        output_hidden2 = self._output_feedforward2(output_hidden1)
        label_logits = self._output_logit(output_hidden2)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Example #13
    def forward(self,
                context_1: torch.Tensor,
                mask_1: torch.Tensor,
                context_2: torch.Tensor,
                mask_2: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        # pylint: disable=arguments-differ
        """
        Given the forward (or backward) representations of sentence1 and sentence2, apply four bilateral
        matching functions between them in one direction.

        Parameters
        ----------
        context_1 : ``torch.Tensor``
            Tensor of shape (batch_size, seq_len1, hidden_dim) representing the encoding of the first sentence.
        mask_1 : ``torch.Tensor``
            Binary Tensor of shape (batch_size, seq_len1), indicating which
            positions in the first sentence are padding (0) and which are not (1).
        context_2 : ``torch.Tensor``
            Tensor of shape (batch_size, seq_len2, hidden_dim) representing the encoding of the second sentence.
        mask_2 : ``torch.Tensor``
            Binary Tensor of shape (batch_size, seq_len2), indicating which
            positions in the second sentence are padding (0) and which are not (1).

        Returns
        -------
        A tuple of matching vectors for the two sentences. Each of which is a list of
        matching vectors of shape (batch, seq_len, num_perspectives or 1)
        """
        assert (not mask_2.requires_grad) and (not mask_1.requires_grad)
        assert context_1.size(-1) == context_2.size(-1) == self.hidden_dim

        # (batch,)
        len_1 = get_lengths_from_binary_sequence_mask(mask_1)
        len_2 = get_lengths_from_binary_sequence_mask(mask_2)

        # (batch, seq_len*)
        mask_1, mask_2 = mask_1.float(), mask_2.float()

        # explicitly set masked weights to zero
        # (batch_size, seq_len*, hidden_dim)
        context_1 = context_1 * mask_1.unsqueeze(-1)
        context_2 = context_2 * mask_2.unsqueeze(-1)

        # array to keep the matching vectors for the two sentences
        matching_vector_1: List[torch.Tensor] = []
        matching_vector_2: List[torch.Tensor] = []

        # Step 0. unweighted cosine
        # First calculate the cosine similarities between each forward
        # (or backward) contextual embedding and every forward (or backward)
        # contextual embedding of the other sentence.

        # (batch, seq_len1, seq_len2)
        cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2), context_2.unsqueeze(-3), dim=3)

        # (batch, seq_len*, 1)
        cosine_max_1 = masked_max(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
        cosine_mean_1 = masked_mean(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
        cosine_max_2 = masked_max(cosine_sim.permute(0, 2, 1), mask_1.unsqueeze(-2), dim=2, keepdim=True)
        cosine_mean_2 = masked_mean(cosine_sim.permute(0, 2, 1), mask_1.unsqueeze(-2), dim=2, keepdim=True)

        matching_vector_1.extend([cosine_max_1, cosine_mean_1])
        matching_vector_2.extend([cosine_max_2, cosine_mean_2])

        # Step 1. Full-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with the last time step of the forward (or backward)
        # contextual embedding of the other sentence
        if self.with_full_match:

            # (batch, 1, hidden_dim)
            if self.is_forward:
                # (batch, 1, hidden_dim)
                last_position_1 = (len_1 - 1).clamp(min=0)
                last_position_1 = last_position_1.view(-1, 1, 1).expand(-1, 1, self.hidden_dim)
                last_position_2 = (len_2 - 1).clamp(min=0)
                last_position_2 = last_position_2.view(-1, 1, 1).expand(-1, 1, self.hidden_dim)

                context_1_last = context_1.gather(1, last_position_1)
                context_2_last = context_2.gather(1, last_position_2)
            else:
                context_1_last = context_1[:, 0:1, :]
                context_2_last = context_2[:, 0:1, :]

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_full = multi_perspective_match(context_1,
                                                             context_2_last,
                                                             self.full_match_weights)
            matching_vector_2_full = multi_perspective_match(context_2,
                                                             context_1_last,
                                                             self.full_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_full)
            matching_vector_2.extend(matching_vector_2_full)

        # Step 2. Maxpooling-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with every time step of the forward (or backward)
        # contextual embedding of the other sentence, and only the max value of each
        # dimension is retained.
        if self.with_maxpool_match:
            # (batch, seq_len1, seq_len2, num_perspectives)
            matching_vector_max = multi_perspective_match_pairwise(context_1,
                                                                   context_2,
                                                                   self.maxpool_match_weights)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_max = masked_max(matching_vector_max,
                                               mask_2.unsqueeze(-2).unsqueeze(-1),
                                               dim=2)
            matching_vector_1_mean = masked_mean(matching_vector_max,
                                                 mask_2.unsqueeze(-2).unsqueeze(-1),
                                                 dim=2)
            matching_vector_2_max = masked_max(matching_vector_max.permute(0, 2, 1, 3),
                                               mask_1.unsqueeze(-2).unsqueeze(-1),
                                               dim=2)
            matching_vector_2_mean = masked_mean(matching_vector_max.permute(0, 2, 1, 3),
                                                 mask_1.unsqueeze(-2).unsqueeze(-1),
                                                 dim=2)

            matching_vector_1.extend([matching_vector_1_max, matching_vector_1_mean])
            matching_vector_2.extend([matching_vector_2_max, matching_vector_2_mean])


        # Step 3. Attentive-Matching
        # Each forward (or backward) similarity is taken as the weight
        # of the forward (or backward) contextual embedding, and calculate an
        # attentive vector for the sentence by weighted summing all its
        # contextual embeddings.
        # Finally match each forward (or backward) contextual embedding
        # with its corresponding attentive vector.

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_2 = context_2.unsqueeze(-3) * cosine_sim.unsqueeze(-1)

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_1 = context_1.unsqueeze(-2) * cosine_sim.unsqueeze(-1)

        if self.with_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_mean_2 = masked_softmax(att_2.sum(dim=2), mask_1.unsqueeze(-1))
            att_mean_1 = masked_softmax(att_1.sum(dim=1), mask_2.unsqueeze(-1))

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_mean = multi_perspective_match(context_1,
                                                                 att_mean_2,
                                                                 self.attentive_match_weights)
            matching_vector_2_att_mean = multi_perspective_match(context_2,
                                                                 att_mean_1,
                                                                 self.attentive_match_weights_reversed)
            matching_vector_1.extend(matching_vector_1_att_mean)
            matching_vector_2.extend(matching_vector_2_att_mean)

        # Step 4. Max-Attentive-Matching
        # Pick the contextual embeddings with the highest cosine similarity as the attentive
        # vector, and match each forward (or backward) contextual embedding with its
        # corresponding attentive vector.
        if self.with_max_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_max_2 = masked_max(att_2, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
            att_max_1 = masked_max(att_1.permute(0, 2, 1, 3), mask_1.unsqueeze(-2).unsqueeze(-1), dim=2)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_max = multi_perspective_match(context_1,
                                                                att_max_2,
                                                                self.max_attentive_match_weights)
            matching_vector_2_att_max = multi_perspective_match(context_2,
                                                                att_max_1,
                                                                self.max_attentive_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_att_max)
            matching_vector_2.extend(matching_vector_2_att_max)

        return matching_vector_1, matching_vector_2
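A recurring pattern in the matcher above is masking dimension 2 of a (batch, seq_len1, seq_len2) similarity tensor, which requires the second sentence's mask to be reshaped to (batch, 1, seq_len2); hence the mask_2.unsqueeze(-2) calls. A short shape check (sizes are illustrative):

import torch

cosine_sim = torch.randn(2, 4, 6)                     # (batch, seq_len1, seq_len2)
mask_2 = torch.tensor([[1, 1, 1, 1, 0, 0],
                       [1, 1, 0, 0, 0, 0]]).float()   # (batch, seq_len2)

print(mask_2.unsqueeze(-2).shape)  # torch.Size([2, 1, 6]) -- broadcasts over seq_len1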
Example #14
    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        answer: torch.BoolTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                encoded_passage,
                passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1,
        )

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)

        # Shape: (batch_size,)
        prediction_bool_logits = util.masked_max(span_start_logits,
                                                 passage_mask,
                                                 dim=1)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "prediction_bool_logits": prediction_bool_logits
        }

        # Compute the loss for training.
        if answer is not None:
            loss = binary_cross_entropy_with_logits(prediction_bool_logits,
                                                    answer)
            threshold = 0.5
            prediction_bool_logits = torch.where(
                torch.sigmoid(prediction_bool_logits) > threshold,
                torch.ones_like(prediction_bool_logits),
                torch.zeros_like(prediction_bool_logits))
            self._accuracy(prediction_bool_logits, answer)
            output_dict["loss"] = loss

        return output_dict
Example #15
    def esim_forward(  # type: ignore
        self,
        encoded_premise, encoded_hypothesis, premise_mask, hypothesis_mask,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
            [
                encoded_premise,
                attended_hypothesis,
                encoded_premise - attended_hypothesis,
                encoded_premise * attended_hypothesis,
            ],
            dim=-1,
        )
        hypothesis_enhanced = torch.cat(
            [
                encoded_hypothesis,
                attended_premise,
                encoded_hypothesis - attended_premise,
                encoded_hypothesis * attended_premise,
            ],
            dim=-1,
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max = masked_max(v_ai, premise_mask.unsqueeze(-1), dim=1)
        v_b_max = masked_max(v_bi, hypothesis_mask.unsqueeze(-1), dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
            premise_mask, 1, keepdim=True
        )
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
            hypothesis_mask, 1, keepdim=True
        )

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        return output_dict
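The ESIM pooling above pairs a masked max with a length-normalized masked average over the time dimension. A minimal shape sketch of the average (sizes are illustrative):

import torch

v_ai = torch.randn(2, 6, 300)                     # (batch, premise_length, model_dim)
premise_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                             [1, 1, 1, 0, 0, 0]]).float()

v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
    premise_mask, 1, keepdim=True)
print(v_a_avg.shape)                              # torch.Size([2, 300])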
Example #16
    def forward(
        self,
        context_1: torch.Tensor,
        mask_1: torch.Tensor,
        context_2: torch.Tensor,
        mask_2: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Given the forward (or backward) representations of sentence1 and sentence2, apply four bilateral
        matching functions between them in one direction.

        Parameters
        ----------
        context_1 : ``torch.Tensor``
            Tensor of shape (batch_size, seq_len1, hidden_dim) representing the encoding of the first sentence.
        mask_1 : ``torch.Tensor``
            Binary Tensor of shape (batch_size, seq_len1), indicating which
            positions in the first sentence are padding (0) and which are not (1).
        context_2 : ``torch.Tensor``
            Tensor of shape (batch_size, seq_len2, hidden_dim) representing the encoding of the second sentence.
        mask_2 : ``torch.Tensor``
            Binary Tensor of shape (batch_size, seq_len2), indicating which
            positions in the second sentence are padding (0) and which are not (1).

        Returns
        -------
        A tuple of matching vectors for the two sentences. Each of which is a list of
        matching vectors of shape (batch, seq_len, num_perspectives or 1)
        """
        assert (not mask_2.requires_grad) and (not mask_1.requires_grad)
        assert context_1.size(-1) == context_2.size(-1) == self.hidden_dim

        # (batch,)
        len_1 = get_lengths_from_binary_sequence_mask(mask_1)
        len_2 = get_lengths_from_binary_sequence_mask(mask_2)

        # (batch, seq_len*)
        mask_1, mask_2 = mask_1.float(), mask_2.float()

        # explicitly set masked weights to zero
        # (batch_size, seq_len*, hidden_dim)
        context_1 = context_1 * mask_1.unsqueeze(-1)
        context_2 = context_2 * mask_2.unsqueeze(-1)

        # array to keep the matching vectors for the two sentences
        matching_vector_1: List[torch.Tensor] = []
        matching_vector_2: List[torch.Tensor] = []

        # Step 0. unweighted cosine
        # First calculate the cosine similarities between each forward
        # (or backward) contextual embedding and every forward (or backward)
        # contextual embedding of the other sentence.

        # (batch, seq_len1, seq_len2)
        cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2),
                                         context_2.unsqueeze(-3),
                                         dim=3)

        # (batch, seq_len*, 1)
        cosine_max_1 = masked_max(cosine_sim,
                                  mask_2.unsqueeze(-2),
                                  dim=2,
                                  keepdim=True)
        cosine_mean_1 = masked_mean(cosine_sim,
                                    mask_2.unsqueeze(-2),
                                    dim=2,
                                    keepdim=True)
        cosine_max_2 = masked_max(cosine_sim.permute(0, 2, 1),
                                  mask_1.unsqueeze(-2),
                                  dim=2,
                                  keepdim=True)
        cosine_mean_2 = masked_mean(cosine_sim.permute(0, 2, 1),
                                    mask_1.unsqueeze(-2),
                                    dim=2,
                                    keepdim=True)

        matching_vector_1.extend([cosine_max_1, cosine_mean_1])
        matching_vector_2.extend([cosine_max_2, cosine_mean_2])

        # Step 1. Full-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with the last time step of the forward (or backward)
        # contextual embedding of the other sentence
        if self.with_full_match:

            # (batch, 1, hidden_dim)
            if self.is_forward:
                # (batch, 1, hidden_dim)
                last_position_1 = (len_1 - 1).clamp(min=0)
                last_position_1 = last_position_1.view(-1, 1, 1).expand(
                    -1, 1, self.hidden_dim)
                last_position_2 = (len_2 - 1).clamp(min=0)
                last_position_2 = last_position_2.view(-1, 1, 1).expand(
                    -1, 1, self.hidden_dim)

                context_1_last = context_1.gather(1, last_position_1)
                context_2_last = context_2.gather(1, last_position_2)
            else:
                context_1_last = context_1[:, 0:1, :]
                context_2_last = context_2[:, 0:1, :]

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_full = multi_perspective_match(
                context_1, context_2_last, self.full_match_weights)
            matching_vector_2_full = multi_perspective_match(
                context_2, context_1_last, self.full_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_full)
            matching_vector_2.extend(matching_vector_2_full)

        # Step 2. Maxpooling-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with every time step of the forward (or backward)
        # contextual embedding of the other sentence, and only the max value of each
        # dimension is retained.
        if self.with_maxpool_match:
            # (batch, seq_len1, seq_len2, num_perspectives)
            matching_vector_max = multi_perspective_match_pairwise(
                context_1, context_2, self.maxpool_match_weights)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_max = masked_max(
                matching_vector_max, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
            matching_vector_1_mean = masked_mean(
                matching_vector_max, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
            matching_vector_2_max = masked_max(
                matching_vector_max.permute(0, 2, 1, 3),
                mask_1.unsqueeze(-2).unsqueeze(-1),
                dim=2)
            matching_vector_2_mean = masked_mean(
                matching_vector_max.permute(0, 2, 1, 3),
                mask_1.unsqueeze(-2).unsqueeze(-1),
                dim=2)

            matching_vector_1.extend(
                [matching_vector_1_max, matching_vector_1_mean])
            matching_vector_2.extend(
                [matching_vector_2_max, matching_vector_2_mean])

        # Step 3. Attentive-Matching
        # Each forward (or backward) similarity is taken as the weight
        # of the forward (or backward) contextual embedding, and calculate an
        # attentive vector for the sentence by weighted summing all its
        # contextual embeddings.
        # Finally match each forward (or backward) contextual embedding
        # with its corresponding attentive vector.

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_2 = context_2.unsqueeze(-3) * cosine_sim.unsqueeze(-1)

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_1 = context_1.unsqueeze(-2) * cosine_sim.unsqueeze(-1)

        if self.with_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_mean_2 = masked_softmax(att_2.sum(dim=2), mask_1.unsqueeze(-1))
            att_mean_1 = masked_softmax(att_1.sum(dim=1), mask_2.unsqueeze(-1))

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_mean = multi_perspective_match(
                context_1, att_mean_2, self.attentive_match_weights)
            matching_vector_2_att_mean = multi_perspective_match(
                context_2, att_mean_1, self.attentive_match_weights_reversed)
            matching_vector_1.extend(matching_vector_1_att_mean)
            matching_vector_2.extend(matching_vector_2_att_mean)

        # Step 4. Max-Attentive-Matching
        # Pick the contextual embeddings with the highest cosine similarity as the attentive
        # vector, and match each forward (or backward) contextual embedding with its
        # corresponding attentive vector.
        if self.with_max_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_max_2 = masked_max(att_2,
                                   mask_2.unsqueeze(-2).unsqueeze(-1),
                                   dim=2)
            att_max_1 = masked_max(att_1.permute(0, 2, 1, 3),
                                   mask_1.unsqueeze(-2).unsqueeze(-1),
                                   dim=2)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_max = multi_perspective_match(
                context_1, att_max_2, self.max_attentive_match_weights)
            matching_vector_2_att_max = multi_perspective_match(
                context_2, att_max_1,
                self.max_attentive_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_att_max)
            matching_vector_2.extend(matching_vector_2_att_max)

        return matching_vector_1, matching_vector_2
Example #17
    def forward(  # type: ignore
        self,
        sent1: TextFieldTensors,
        sent2: TextFieldTensors,
        label: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:
        with adv_utils.forward_context("sent1"):
            embedded_sent1 = self.word_embedders(sent1)
        with adv_utils.forward_context("sent2"):
            embedded_sent2 = self.word_embedders(sent2)
        sent1_mask = get_text_field_mask(sent1)
        sent2_mask = get_text_field_mask(sent2)

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_sent1 = self.rnn_input_dropout(embedded_sent1)
            embedded_sent2 = self.rnn_input_dropout(embedded_sent2)

        # encode sent1 and sent2
        encoded_sent1 = self._encoder(embedded_sent1, sent1_mask)
        encoded_sent2 = self._encoder(embedded_sent2, sent2_mask)

        # Shape: (batch_size, sent1_length, sent2_length)
        similarity_matrix = self._matrix_attention(encoded_sent1,
                                                   encoded_sent2)

        # Shape: (batch_size, sent1_length, sent2_length)
        p2h_attention = masked_softmax(similarity_matrix, sent2_mask)
        # Shape: (batch_size, sent1_length, embedding_dim)
        attended_sent2 = weighted_sum(encoded_sent2, p2h_attention)

        # Shape: (batch_size, sent2_length, sent1_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), sent1_mask)
        # Shape: (batch_size, sent2_length, embedding_dim)
        attended_sent1 = weighted_sum(encoded_sent1, h2p_attention)

        # the "enhancement" layer
        sent1_enhanced = torch.cat(
            [
                encoded_sent1,
                attended_sent2,
                encoded_sent1 - attended_sent2,
                encoded_sent1 * attended_sent2,
            ],
            dim=-1,
        )
        sent2_enhanced = torch.cat(
            [
                encoded_sent2,
                attended_sent1,
                encoded_sent2 - attended_sent1,
                encoded_sent2 * attended_sent1,
            ],
            dim=-1,
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_sent1 = self._projection_feedforward(sent1_enhanced)
        projected_enhanced_sent2 = self._projection_feedforward(sent2_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_sent1 = self.rnn_input_dropout(
                projected_enhanced_sent1)
            projected_enhanced_sent2 = self.rnn_input_dropout(
                projected_enhanced_sent2)
        v_ai = self._inference_encoder(projected_enhanced_sent1, sent1_mask)
        v_bi = self._inference_encoder(projected_enhanced_sent2, sent2_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max = masked_max(v_ai, sent1_mask.unsqueeze(-1), dim=1)
        v_b_max = masked_max(v_bi, sent2_mask.unsqueeze(-1), dim=1)

        v_a_avg = torch.sum(v_ai * sent1_mask.unsqueeze(-1),
                            dim=1) / torch.sum(sent1_mask, 1, keepdim=True)
        v_b_avg = torch.sum(v_bi * sent2_mask.unsqueeze(-1),
                            dim=1) / torch.sum(sent2_mask, 1, keepdim=True)

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"logits": label_logits, "probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Example #18
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(
            max_batch_span_width,
            util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(
            raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(
            span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor,
                                                    span_indices,
                                                    flat_span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        masked_span_embeddings = span_embeddings * span_mask.unsqueeze(-1)

        batch_size, num_spans, max_batch_span_width, embedding_dim = masked_span_embeddings.size()
        # Shape: (batch_size*num_spans, embedding_dim, max_batch_span_width)
        masked_span_embeddings = masked_span_embeddings.view(
            batch_size * num_spans, max_batch_span_width,
            embedding_dim).transpose(1, 2)

        # Shape: (batch_size * num_spans, embedding_dim, max_batch_span_width)
        conv_span_embeddings = torch.nn.functional.relu(
            self._conv(masked_span_embeddings))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        conv_span_embeddings = conv_span_embeddings.transpose(1, 2).view(
            batch_size, num_spans, max_batch_span_width, embedding_dim)

        # Shape: (batch_size, num_spans, embedding_dim)
        span_embeddings = util.masked_max(conv_span_embeddings,
                                          span_mask.unsqueeze(-1),
                                          dim=2)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            span_width_embeddings = self._span_width_embedding(
                span_widths.squeeze(-1))
            span_embeddings = torch.cat(
                [span_embeddings, span_width_embeddings], -1)

        return span_embeddings
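
The span-mask construction above broadcasts a [0, 1, ..., max_width - 1] range vector against the inclusive span widths, then zeroes out positions past the span or before the start of the sequence. A standalone sketch with toy inputs (batch dimension dropped for brevity; all names illustrative):

import torch

# toy spans with inclusive end indices: (num_spans, 1) each
span_starts = torch.tensor([[0], [2]])
span_ends = torch.tensor([[1], [2]])

span_widths = span_ends - span_starts                 # off by one because ends are inclusive
max_width = int(span_widths.max().item()) + 1

# (1, max_width) range broadcast against (num_spans, 1) widths
range_indices = torch.arange(max_width).view(1, -1)
span_mask = (range_indices <= span_widths).float()    # 1 where the position lies inside the span

# indices counted backwards from the span end, clamped at zero like the relu above
raw_indices = span_ends - range_indices
span_mask = span_mask * (raw_indices >= 0).float()
span_indices = raw_indices.clamp(min=0)

print(span_mask)     # [[1., 1.], [1., 0.]]
print(span_indices)  # [[1, 0], [2, 1]]; the second row's trailing index is masked out
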
Exemplo n.º 19
0
    def forward(  # type: ignore
        self,
        premise: TextFieldTensors,
        hypothesis: TextFieldTensors,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        premise : TextFieldTensors
            From a `TextField`
        hypothesis : TextFieldTensors
            From a `TextField`
        label : torch.IntTensor, optional (default = None)
            From a `LabelField`
        metadata : `List[Dict[str, Any]]`, optional (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        # Returns

        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape `(batch_size, num_labels)` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape `(batch_size, num_labels)` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise)
        hypothesis_mask = get_text_field_mask(hypothesis)

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis,
                                           hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise,
                                                   encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
            [
                encoded_premise,
                attended_hypothesis,
                encoded_premise - attended_hypothesis,
                encoded_premise * attended_hypothesis,
            ],
            dim=-1,
        )
        hypothesis_enhanced = torch.cat(
            [
                encoded_hypothesis,
                attended_premise,
                encoded_hypothesis - attended_premise,
                encoded_hypothesis * attended_premise,
            ],
            dim=-1,
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(
            premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(
            hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(
                projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(
                projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise,
                                       premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis,
                                       hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max = masked_max(v_ai, premise_mask.unsqueeze(-1), dim=1)
        v_b_max = masked_max(v_bi, hypothesis_mask.unsqueeze(-1), dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1),
                            dim=1) / torch.sum(premise_mask, 1, keepdim=True)
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1),
                            dim=1) / torch.sum(
                                hypothesis_mask, 1, keepdim=True)

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
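
The "enhancement" layer in both ESIM-style forwards above is a plain feature concatenation of the encoded sequence, its aligned counterpart, their difference, and their element-wise product. A tiny sketch with toy tensors (shapes chosen only for illustration):

import torch

encoded = torch.randn(2, 5, 8)     # (batch, premise_length, dim)
attended = torch.randn(2, 5, 8)    # aligned hypothesis representation, same shape

# ESIM-style enhancement: original, aligned, difference, and element-wise product
enhanced = torch.cat(
    [encoded, attended, encoded - attended, encoded * attended],
    dim=-1,
)
assert enhanced.shape == (2, 5, 32)   # the last dimension grows by a factor of four
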
Exemplo n.º 20
0
    def forward(
            self,  # type: ignore
            passage: Dict[str, torch.LongTensor],
            all_qa: Dict[str, torch.LongTensor],
            candidate: Dict[str, torch.LongTensor],
            combined_source: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """"""
        if self._with_knowledge:
            embedded_passage = self._text_field_embedder(passage)  # B * T * d
            passage_len = embedded_passage.size(1)
        embedded_all_qa = self._text_field_embedder(all_qa)  # B * U * d
        embedded_choice = self._text_field_embedder(candidate)  # B * V * d

        if self._with_knowledge:
            embedded_passage = self._variational_dropout(
                embedded_passage)  # B * T * d
        embedded_all_qa = self._variational_dropout(embedded_all_qa)
        embedded_choice = self._variational_dropout(
            embedded_choice)  # B * V * d

        all_qa_mask = util.get_text_field_mask(all_qa)  # B * U
        choice_mask = util.get_text_field_mask(candidate)  # B * V

        # Encoding
        if self._with_knowledge:
            # B * T * H
            passage_mask = util.get_text_field_mask(passage)  # B * T
            encoded_passage = self._variational_dropout(
                self._pseqlevel_enc(embedded_passage, passage_mask))
        # B * U * H
        if self._shared_rnn:
            encoded_allqa = self._variational_dropout(
                self._pseqlevel_enc(embedded_all_qa, all_qa_mask))
        else:
            encoded_allqa = self._variational_dropout(
                self._qaseqlevel_enc(embedded_all_qa, all_qa_mask))

        if self._with_knowledge and self._is_qdep_penc:
            # similarity matrix
            _, normalized_attn_mat = self._cart_attn(encoded_passage,
                                                     encoded_allqa,
                                                     all_qa_mask)  # B * T * U
            # question dependent passage encoding
            q_aware_passage_rep = sequential_weighted_avg(
                encoded_allqa, normalized_attn_mat)  # B * T * H

            q_dep_passage_enc_rnn_input = torch.cat(
                [encoded_passage, q_aware_passage_rep], 2)  # B * T * 2H

            # gated question dependent passage encoding
            gated_qaware_passage_rep = self._gate_qdep_penc(
                q_dep_passage_enc_rnn_input)  # B * T * 2H
            encoded_qdep_penc = self._qdep_penc_rnn(gated_qaware_passage_rep,
                                                    passage_mask)  # B * T * H

        # multi factor attentive encoding
        if self._with_knowledge and self._is_mfa_enc:
            if self._is_qdep_penc:
                mfa_enc = self._multifactor_attn(encoded_qdep_penc,
                                                 passage_mask)  # B * T * 2H
            else:
                mfa_enc = self._multifactor_attn(encoded_passage,
                                                 passage_mask)  # B * T * 2H
            encoded_passage = self._mfarnn(mfa_enc, passage_mask)  # B * T * H

        # B * V * H
        if self._shared_rnn:
            encoded_choice = self._variational_dropout(
                self._pseqlevel_enc(embedded_choice, choice_mask))  # B * V * H
        else:
            encoded_choice = self._variational_dropout(
                self._cseqlevel_enc(embedded_choice, choice_mask))  # B * V * H

        if self._with_knowledge:
            attn_pq, _ = self._pqaattnmat(encoded_passage, encoded_allqa,
                                          all_qa_mask)  # B * T * U
            combined_pqa_mask = passage_mask.unsqueeze(-1) * \
                                all_qa_mask.unsqueeze(1)  # B * T * U
            max_attn_pqa = masked_max(attn_pq, combined_pqa_mask,
                                      dim=1)  # B * U
            norm_attn_pqa = masked_softmax(max_attn_pqa, all_qa_mask,
                                           dim=-1)  # B * U
            agg_prev_qa = norm_attn_pqa.unsqueeze(1).bmm(
                encoded_allqa).squeeze(1)  # B * H

            attn_pc, _ = self._pcattnmat(encoded_passage, encoded_choice,
                                         choice_mask)  # B * T * V
            combined_pc_mask = passage_mask.unsqueeze(-1) * \
                               choice_mask.unsqueeze(1)  # B * T * V
            max_attn_pc = masked_max(attn_pc, combined_pc_mask, dim=1)  # B * V
            norm_attn_pc = masked_softmax(max_attn_pc, choice_mask,
                                          dim=-1)  # B * V
            agg_c = norm_attn_pc.unsqueeze(1).bmm(encoded_choice)  # B * 1 * H

            choice_scores_wk = agg_c.bmm(agg_prev_qa.unsqueeze(-1)).squeeze(
                -1)  # B * 1

        if self._qac_ap:
            attn_qac, _ = self._cqaattnmat(encoded_allqa, encoded_choice,
                                           choice_mask)  # B * U * V
            combined_qac_mask = all_qa_mask.unsqueeze(-1) * \
                                choice_mask.unsqueeze(1)  # B * U * V

            max_attn_c = masked_max(attn_qac, combined_qac_mask,
                                    dim=1)  # B * V
            max_attn_qa = masked_max(attn_qac, combined_qac_mask,
                                     dim=2)  # B * U
            norm_attn_c = masked_softmax(max_attn_c, choice_mask,
                                         dim=-1)  # B * V
            norm_attn_qa = masked_softmax(max_attn_qa, all_qa_mask,
                                          dim=-1)  # B * U
            agg_c_qa = norm_attn_c.unsqueeze(1).bmm(encoded_choice).squeeze(
                1)  # B * H
            agg_qa_c = norm_attn_qa.unsqueeze(1).bmm(encoded_allqa).squeeze(
                1)  # B * H

            choice_scores_nk = agg_c_qa.unsqueeze(1).bmm(
                agg_qa_c.unsqueeze(-1)).squeeze(-1)  # B * 1

        if self._with_knowledge and self._qac_ap:
            choice_score = choice_scores_wk + choice_scores_nk
        elif self._qac_ap:
            choice_score = choice_scores_nk
        elif self._with_knowledge:
            choice_score = choice_scores_wk
        else:
            raise NotImplementedError

        output = torch.sigmoid(choice_score).squeeze(-1)  # B

        output_dict = {
            "label_logits": choice_score.squeeze(-1),
            "label_probs": output,
            "metadata": metadata
        }

        if label is not None:
            label = label.long().view(-1)
            loss = self._loss(output, label.float())
            self._auc(output, label)
            output_dict["loss"] = loss

        return output_dict
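
The aggregation used repeatedly above follows one pattern: take a masked max over one axis of a similarity matrix, renormalise the surviving scores with a masked softmax, and aggregate the encodings with a batched matrix product. A small sketch with the masking written inline rather than via the allennlp helpers (all shapes and names illustrative):

import torch

batch, t_len, u_len, dim = 2, 4, 3, 8
attn = torch.randn(batch, t_len, u_len)          # similarity between passage (T) and question (U)
encoded_u = torch.randn(batch, u_len, dim)       # encoded question tokens
t_mask = torch.ones(batch, t_len)
u_mask = torch.tensor([[1., 1., 0.], [1., 1., 1.]])

# masked max over the passage axis -> one score per question token (B * U)
pair_mask = t_mask.unsqueeze(-1) * u_mask.unsqueeze(1)          # (B, T, U)
very_neg = torch.finfo(attn.dtype).min
max_attn = attn.masked_fill(pair_mask == 0, very_neg).max(dim=1).values

# masked softmax over question tokens, then aggregate their encodings (B * H)
norm_attn = torch.softmax(max_attn.masked_fill(u_mask == 0, very_neg), dim=-1)
aggregated = norm_attn.unsqueeze(1).bmm(encoded_u).squeeze(1)
assert aggregated.shape == (batch, dim)
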
Exemplo n.º 21
0
    def forward(
        self,  # type: ignore
        qa_pairs: Dict[str, torch.LongTensor],
        answer_index: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        qa_pairs : Dict[str, torch.LongTensor]
            From a ``ListField``.
        answer_index : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is what we are trying to predict.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, question and choices for each instance 
            in the batch. The length of this list should be the batch size, and each dictionary 
            should have the keys ``qid``, ``question``, ``choices``, ``question_tokens`` and 
            ``choices_tokens``.

        Returns
        -------
        An output dictionary consisting of the following.

        qid : List[str]
            A list consisting of question ids.
        answer_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_options=5)`` representing unnormalised log
            probabilities of the choices.
        answer_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_options=5)`` representing probabilities of the
            choices.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded = self._text_field_embedder(qa_pairs, num_wrapping_dims=1)
        mask = qa_pairs['mask'].float()
        batch_size, choice_size, seq_len, hidden_size = embedded.size()
        embedded = embedded.view(-1, seq_len, hidden_size)
        mask = mask.view(-1, seq_len)
        if self.dropout:
            embedded = self.dropout(embedded)
        if self._encoder:
            embedded = self._encoder(embedded, mask)
        embedded = embedded.view(batch_size, choice_size, seq_len, -1)
        mask = mask.view(batch_size, choice_size, seq_len)
        embedded = masked_max(embedded, mask.unsqueeze(-1), dim=2)

        # the final projection from the pooled representation to a per-choice logit
        answer_logits = self._output_logit(embedded).squeeze(-1)
        answer_probs = torch.nn.functional.softmax(answer_logits, dim=-1)
        qids = [m['qid'] for m in metadata]
        output_dict = {
            "answer_logits": answer_logits,
            "answer_probs": answer_probs,
            "qid": qids
        }

        if answer_index is not None:
            answer_index = answer_index.squeeze(-1)
            loss = self._loss(answer_logits, answer_index)
            self._accuracy(answer_logits, answer_index)
            output_dict["loss"] = loss
        return output_dict
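
The reshape-then-pool pattern above flattens the answer choices into the batch dimension for encoding, restores them, and then max-pools over the token axis. A sketch with toy shapes and the masked max written by hand (the encoder call is elided):

import torch

batch, choices, seq_len, dim = 2, 5, 7, 16
embedded = torch.randn(batch, choices, seq_len, dim)
mask = torch.randint(0, 2, (batch, choices, seq_len)).float()

# flatten choices into the batch dimension so a standard encoder can run over them
flat = embedded.view(batch * choices, seq_len, dim)
flat_mask = mask.view(batch * choices, seq_len)

# (an encoder would be applied to `flat` with `flat_mask` here); restore the choice dimension
restored = flat.view(batch, choices, seq_len, dim)

# masked max over the token axis gives one vector per (batch, choice)
very_neg = torch.finfo(restored.dtype).min
pooled = restored.masked_fill(mask.unsqueeze(-1) == 0, very_neg).max(dim=2).values
assert pooled.shape == (batch, choices, dim)
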
Exemplo n.º 22
0
    def forward(self, span_embeddings, span_children, span_children_mask):
        batch, sequence, children_num, _ = span_children.size()
        # (batch, sequence, children_num)
        span_children = span_children.squeeze(-1)

        for t in range(self._tree_prop):

            flat_span_indices = util.flatten_and_batch_shift_indices(span_children, span_embeddings.size(1))
            # (batch, sequence, children_num, span_emb_dim)
            children_span_embeddings = util.batched_index_select(span_embeddings, span_children, flat_span_indices)

            if self._tree_children == 'attention':
                # (batch, sequence, children_num)
                attention_scores = self._global_attention(children_span_embeddings).squeeze(-1)
                # (batch, sequence, children_num)
                attention_scores_softmax = util.masked_softmax(attention_scores, span_children_mask, dim=2)
                # (batch, sequence, span_emb_dim)
                children_span_embeddings_merged = util.weighted_sum(children_span_embeddings, attention_scores_softmax)
            elif self._tree_children == 'pooling':
                children_span_embeddings_merged = util.masked_max(children_span_embeddings, span_children_mask.unsqueeze(-1), dim=2)
            elif self._tree_children == 'conv':
                masked_children_span_embeddings = children_span_embeddings * span_children_mask.unsqueeze(-1)

                masked_children_span_embeddings = masked_children_span_embeddings.view(batch * sequence, children_num, -1).transpose(1, 2)

                conv_children_span_embeddings = torch.nn.functional.relu(self._conv(masked_children_span_embeddings))

                conv_children_span_embeddings = conv_children_span_embeddings.transpose(1, 2).view(batch, sequence, children_num, -1)

                children_span_embeddings_merged = util.masked_max(conv_children_span_embeddings, span_children_mask.unsqueeze(-1), dim=2)
            elif self._tree_children == 'rnn':
                masked_children_span_embeddings = children_span_embeddings * span_children_mask.unsqueeze(-1)
                masked_children_span_embeddings = masked_children_span_embeddings.view(batch * sequence, children_num, -1)
                try:  # if no span in this batch has any children, the encoder will raise an error
                    rnn_children_span_embeddings = self._encoder(masked_children_span_embeddings, span_children_mask.view(batch * sequence, children_num))
                except Exception:
                    rnn_children_span_embeddings = masked_children_span_embeddings

                rnn_children_span_embeddings = rnn_children_span_embeddings.view(batch, sequence, children_num, -1)
                forward_sequence, backward_sequence = rnn_children_span_embeddings.split(int(self._span_emb_dim / 2), dim=-1)
                children_span_embeddings_merged = torch.cat([forward_sequence[:,:,-1,:], backward_sequence[:,:,0,:]], dim=-1)
            else:
                raise RuntimeError(f"unknown tree_children option: {self._tree_children}")
            # (batch, sequence, 2*span_emb_dim)
            f_network_input = torch.cat([span_embeddings, children_span_embeddings_merged], dim=-1)
            # (batch, sequence, span_emb_dim)
            f_weights = self._f_network(f_network_input)
            # (batch, sequence, 1); if f_weights_mask is 1, this span has at least one child
            f_weights_mask, _ = span_children_mask.max(dim=-1, keepdim=True)
            # (batch, sequence, span_emb_dim); set f_weights to 1 wherever f_weights_mask == 0
            f_weights = util.replace_masked_values(f_weights, f_weights_mask, 1.0)
            # (batch, sequence, span_emb_dim)
            span_embeddings = f_weights * span_embeddings + (1.0 - f_weights) * children_span_embeddings_merged

        span_embeddings = self._dropout(span_embeddings)

        return span_embeddings
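
The gated update at the end of the propagation loop keeps a span's own embedding wherever it has no children, because the gate is forced to 1 there. A sketch of that behaviour with toy tensors and a hand-rolled replacement for replace_masked_values:

import torch

batch, sequence, dim = 2, 3, 4
span_embeddings = torch.randn(batch, sequence, dim)
children_merged = torch.randn(batch, sequence, dim)
f_weights = torch.sigmoid(torch.randn(batch, sequence, dim))   # gate values in (0, 1)

# 1 where the span has at least one child, 0 otherwise
has_children = torch.tensor([[1., 1., 0.], [1., 0., 0.]]).unsqueeze(-1)

# force the gate to 1 for childless spans, so the update leaves them untouched
f_weights = f_weights * has_children + (1.0 - has_children)

updated = f_weights * span_embeddings + (1.0 - f_weights) * children_merged
assert torch.allclose(updated[0, 2], span_embeddings[0, 2])    # a childless span is unchanged
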
Exemplo n.º 23
0
    def forward(self,
                document,
                query=None,
                label=None,
                metadata=None,
                rationale=None,
                **kwargs) -> Dict[str, Any]:
        # pylint: disable=arguments-differ

        bert_document = self.combine_document_query(document, query)

        last_hidden_states, _ = self._bert_model(
            bert_document["bert"]["wordpiece-ids"],
            attention_mask=bert_document["bert"]["wordpiece-mask"],
            position_ids=bert_document["bert"]["position-ids"],
            token_type_ids=bert_document["bert"]["type-ids"],
        )

        token_embeddings, span_mask = generate_embeddings_for_pooling(
            last_hidden_states,
            bert_document["bert"]["document-starting-offsets"],
            bert_document["bert"]["document-ending-offsets"],
        )

        token_embeddings = util.masked_max(token_embeddings,
                                           span_mask.unsqueeze(-1) == 1,
                                           dim=2)
        token_embeddings = token_embeddings * bert_document["bert"][
            "mask"].unsqueeze(-1)

        logits = self._classification_layer(self._dropout(token_embeddings))

        probs = torch.sigmoid(logits)[:, :, 0]
        mask = bert_document["bert"]["mask"]

        output_dict = {}
        output_dict["probs"] = probs * mask
        output_dict["mask"] = mask
        predicted_rationale = (probs > 0.5).long()

        output_dict["predicted_rationale"] = predicted_rationale * mask
        output_dict["prob_z"] = probs * mask

        if rationale is not None and self._supervise_rationale:
            rat_mask = rationale.sum(1) > 0
            if rat_mask.sum().long() == 0:
                output_dict["loss"] = 0.0
            else:
                rat_mask = rat_mask.bool()
                loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    logits[rat_mask].squeeze(-1),
                    rationale[rat_mask],
                    reduction="none",
                    pos_weight=self._pos_weight.to(rationale.device),
                )
                loss = ((loss * mask[rat_mask]).sum(-1) /
                        mask[rat_mask].sum(-1)).mean()
                output_dict["loss"] = loss
                self._token_prf(
                    torch.cat([
                        1 - probs[rat_mask].unsqueeze(-1),
                        probs[rat_mask].unsqueeze(-1)
                    ],
                              dim=-1),
                    rationale[rat_mask].long(),
                    mask[rat_mask] == 1,
                )

        return output_dict
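
generate_embeddings_for_pooling is project-specific, but the step above amounts to pooling each token's wordpiece vectors back into a single token vector with a masked max. A hand-rolled sketch of that pooling (shapes and the piece mask are made up for illustration):

import torch

batch, num_tokens, max_pieces, dim = 2, 3, 4, 8
# wordpiece vectors gathered per token, padded up to max_pieces
piece_embeddings = torch.randn(batch, num_tokens, max_pieces, dim)
# 1 where a wordpiece actually belongs to the token, 0 for padding
piece_mask = torch.tensor([[[1, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 0]],
                           [[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0]]])

very_neg = torch.finfo(piece_embeddings.dtype).min
token_embeddings = piece_embeddings.masked_fill(
    piece_mask.unsqueeze(-1) == 0, very_neg).max(dim=2).values
assert token_embeddings.shape == (batch, num_tokens, dim)
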
    def forward(
        self,  # type: ignore
        tokens: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata to persist

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()

        encoder_output = self._encoder(embedded_text, mask)

        encoded_repr = []
        for aggregation in self._aggregations:
            if aggregation == "meanpool":
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoder_output * broadcast_mask
                encoded_text = masked_mean(context_vectors,
                                           broadcast_mask,
                                           dim=1,
                                           keepdim=False)
            elif aggregation == 'maxpool':
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoder_output * broadcast_mask
                encoded_text = masked_max(context_vectors,
                                          broadcast_mask,
                                          dim=1)
            elif aggregation == 'final_state':
                is_bi = self._encoder.is_bidirectional()
                encoded_text = get_final_encoder_states(
                    encoder_output, mask, is_bi)
            encoded_repr.append(encoded_text)

        encoded_repr = torch.cat(encoded_repr, 1)

        if self.dropout:
            encoded_repr = self.dropout(encoded_repr)

        output_hidden = self._output_feedforward(encoded_repr)
        label_logits = self._classification_layer(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
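
For the 'final_state' branch above, get_final_encoder_states selects the last unpadded timestep of each sequence (and, for bidirectional encoders, also the first timestep of the backward direction, which is omitted here). A plain-PyTorch sketch of the unidirectional selection:

import torch

batch, time, dim = 2, 5, 6
encoder_output = torch.randn(batch, time, dim)
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]])

# index of the last real token in each sequence
last_indices = mask.sum(dim=1).long() - 1                      # (batch,)
gather_index = last_indices.view(batch, 1, 1).expand(batch, 1, dim)
final_states = encoder_output.gather(1, gather_index).squeeze(1)

assert torch.allclose(final_states[0], encoder_output[0, 2])   # sequence 0 ends at position 2
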