Example #1
File: models.py Project: zxlzr/PURE
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, sub_obj_ids=None, sub_obj_masks=None, input_position=None):
        """
        attention_mask: [batch_size, from_seq_length, to_seq_length]
        """
        batch_size = input_ids.size(0)
        outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=False, output_attentions=False, position_ids=input_position)
        sequence_output = outputs[0]

        sub_ids = sub_obj_ids[:, :, 0].view(batch_size, -1)
        sub_embeddings = batched_index_select(sequence_output, sub_ids)
        obj_ids = sub_obj_ids[:, :, 1].view(batch_size, -1)
        obj_embeddings = batched_index_select(sequence_output, obj_ids)
        rep = torch.cat((sub_embeddings, obj_embeddings), dim=-1)
        rep = self.layer_norm(rep)
        rep = self.dropout(rep)
        logits = self.classifier(rep)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            active_loss = (sub_obj_masks.view(-1) == 1)
            active_logits = logits.view(-1, logits.shape[-1])
            active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
            loss = loss_fct(active_logits, active_labels)
            return loss
        else:
            return logits
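Every snippet on this page leans on a batched_index_select helper (defined in the project's utilities or in allennlp.nn.util, not shown here) that picks per-example rows out of a (batch_size, seq_len, dim) tensor. As a rough mental model only, here is a minimal sketch of the core operation using torch.gather, with a hypothetical name and restricted to 2-D index tensors (the library version also accepts higher-rank indices, as in Example #17, and an optional precomputed flattened-index argument):

import torch

def batched_index_select_sketch(target: torch.Tensor,
                                indices: torch.LongTensor) -> torch.Tensor:
    # target: (batch_size, seq_len, dim); indices: (batch_size, num_indices).
    # Expand the indices along the feature axis so gather selects whole vectors.
    expanded = indices.unsqueeze(-1).expand(-1, -1, target.size(-1))
    # output[b, i] == target[b, indices[b, i]] for every batch element b.
    return torch.gather(target, 1, expanded)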
Example #2
File: models.py Project: zxlzr/PURE
    def _get_span_embeddings(self,
                             input_ids,
                             spans,
                             token_type_ids=None,
                             attention_mask=None):
        sequence_output, pooled_output = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask)

        sequence_output = self.hidden_dropout(sequence_output)
        """
        spans: [batch_size, num_spans, 3]; 0: left_end, 1: right_end, 2: width
        spans_mask: (batch_size, num_spans, )
        """
        spans_start = spans[:, :, 0].view(spans.size(0), -1)
        spans_start_embedding = batched_index_select(sequence_output,
                                                     spans_start)
        spans_end = spans[:, :, 1].view(spans.size(0), -1)
        spans_end_embedding = batched_index_select(sequence_output, spans_end)

        spans_width = spans[:, :, 2].view(spans.size(0), -1)
        spans_width_embedding = self.width_embedding(spans_width)

        # Concatenate embeddings of left/right points and the width embedding
        spans_embedding = torch.cat(
            (spans_start_embedding, spans_end_embedding,
             spans_width_embedding),
            dim=-1)
        """
        spans_embedding: (batch_size, num_spans, hidden_size*2+embedding_dim)
        """
        return spans_embedding
Example #3
    def test_correct_sequence_elements_are_embedded(self):
        sequence_tensor = torch.randn([2, 5, 7])
        # Concatenate start and end points together to form our representation.
        extractor = EndpointSpanExtractor(7, "x,y")

        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 2, 14]
        assert extractor.get_output_dim() == 14
        assert extractor.get_input_dim() == 7

        start_indices, end_indices = indices.split(1, -1)
        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        start_embeddings, end_embeddings = span_representations.split(7, -1)

        correct_start_embeddings = batched_index_select(
            sequence_tensor, start_indices.squeeze())
        correct_end_embeddings = batched_index_select(sequence_tensor,
                                                      end_indices.squeeze())
        numpy.testing.assert_array_equal(start_embeddings.data.numpy(),
                                         correct_start_embeddings.data.numpy())
        numpy.testing.assert_array_equal(end_embeddings.data.numpy(),
                                         correct_end_embeddings.data.numpy())
Example #4
    def test_masked_indices_are_handled_correctly(self):
        sequence_tensor = torch.randn([2, 5, 7])
        # concatenate start and end points together to form our representation.
        extractor = EndpointSpanExtractor(7, "x,y")

        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
        span_representations = extractor(sequence_tensor, indices)

        # Make a mask with the second batch element completely masked.
        indices_mask = torch.LongTensor([[1, 1], [0, 0]])

        span_representations = extractor(sequence_tensor,
                                         indices,
                                         span_indices_mask=indices_mask)
        start_embeddings, end_embeddings = span_representations.split(7, -1)
        start_indices, end_indices = indices.split(1, -1)

        correct_start_embeddings = batched_index_select(
            sequence_tensor, start_indices.squeeze()).data
        # Completely masked second batch element, so it should all be zero.
        correct_start_embeddings[1, :, :].fill_(0)
        correct_end_embeddings = batched_index_select(
            sequence_tensor, end_indices.squeeze()).data
        correct_end_embeddings[1, :, :].fill_(0)
        numpy.testing.assert_array_equal(start_embeddings.data.numpy(),
                                         correct_start_embeddings.numpy())
        numpy.testing.assert_array_equal(end_embeddings.data.numpy(),
                                         correct_end_embeddings.numpy())
Example #5
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            # It's not strictly necessary to multiply the span indices by the mask here,
            # but it's possible that the span representation was padded with something other
            # than 0 (such as -1, which would be an invalid index), so we do so anyway to
            # be safe.
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask

        if not self._use_exclusive_start_indices:
            start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        else:
            # We want `exclusive` span starts, so we remove 1 from the forward span starts
            # as the AllenNLP ``SpanField`` is inclusive.
            # shape (batch_size, num_spans)
            exclusive_span_starts = span_starts - 1
            # shape (batch_size, num_spans, 1)
            start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
            exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

            # We'll check the indices here at runtime, because it's difficult to debug
            # if this goes wrong and it's tricky to get right.
            if (exclusive_span_starts < 0).any():
                raise ValueError(f"Adjusted span indices must lie inside the the sequence tensor, "
                                 f"but found: exclusive_span_starts: {exclusive_span_starts}.")

            start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

            # We're using sentinels, so we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with the start sentinel.
            float_start_sentinel_mask = start_sentinel_mask.float()
            start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel

        combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([combined_tensors, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return combined_tensors * span_indices_mask.unsqueeze(-1).float()
        return combined_tensors
Example #6
    def _compute_span_representations(
            self, text_embeddings: torch.FloatTensor,
            text_mask: torch.FloatTensor, span_starts: torch.IntTensor,
            span_ends: torch.IntTensor) -> torch.FloatTensor:
        """
        Computes an embedded representation of every candidate span. This is a concatenation
        of the contextualized endpoints of the span, an embedded representation of the width of
        the span and a representation of the span's predicted head.

        Parameters
        ----------
        text_embeddings : ``torch.FloatTensor``, required.
            The embedded document of shape (batch_size, document_length, embedding_dim)
            over which we are computing a weighted sum.
        text_mask : ``torch.FloatTensor``, required.
            A mask of shape (batch_size, document_length) representing non-padding entries of
            ``text_embeddings``.
        span_starts : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans) representing the start of each span candidate.
        span_ends : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans) representing the end of each span candidate.
        Returns
        -------
        span_embeddings : ``torch.FloatTensor``
            An embedded representation of every candidate span with shape:
            (batch_size, num_spans, context_layer.get_output_dim() * 2 + embedding_size + feature_size)
        """
        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)

        # Shape: (batch_size, num_spans, encoding_dim)
        start_embeddings = util.batched_index_select(contextualized_embeddings,
                                                     span_starts.squeeze(-1))
        end_embeddings = util.batched_index_select(contextualized_embeddings,
                                                   span_ends.squeeze(-1))

        # Compute and embed the span_widths (strictly speaking the span_widths - 1)
        # Shape: (batch_size, num_spans, 1)
        span_widths = span_ends - span_starts
        # Shape: (batch_size, num_spans, encoding_dim)
        span_width_embeddings = self._span_width_embedding(
            span_widths.squeeze(-1))

        # Shape: (batch_size, document_length, 1)
        head_scores = self._head_scorer(contextualized_embeddings)

        # Shape: (batch_size, num_spans, embedding_dim)
        # Note that we used the original text embeddings, not the contextual ones here.
        attended_text_embeddings = self._create_attended_span_representations(
            head_scores, text_embeddings, span_ends, span_widths)
        # (batch_size, num_spans, context_layer.get_output_dim() * 2 + embedding_dim + feature_dim)
        span_embeddings = torch.cat([
            start_embeddings, end_embeddings, span_width_embeddings,
            attended_text_embeddings
        ], -1)
        return span_embeddings
Example #7
    def forward(
        self,
        input_ids,
        spans,
        spans_mask,
        token_type_ids=None,
        attention_mask=None,
        spans_ner_label=None,
    ):
        sequence_output = self.albert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )["last_hidden_state"]

        sequence_output = self.hidden_dropout(sequence_output)
        """
        spans: [batch_size, num_spans, 3]; 0: left_end, 1: right_end, 2: width
        spans_mask: (batch_size, num_spans, )
        """
        spans_start = spans[:, :, 0]
        spans_start_embedding = batched_index_select(sequence_output,
                                                     spans_start)
        spans_end = spans[:, :, 1]
        spans_end_embedding = batched_index_select(sequence_output, spans_end)

        spans_width = spans[:, :, 2]
        spans_width_embedding = self.width_embedding(spans_width)

        spans_embedding = torch.cat(
            (
                spans_start_embedding,
                spans_end_embedding,
                spans_width_embedding,
            ),
            dim=-1,
        )
        """
        spans_embedding: (batch_size, num_spans, hidden_size*2+embedding_dim)
        """

        logits = self.ner_classifier(spans_embedding)[:, :, 0]
        logits = torch.sigmoid(logits)

        if spans_ner_label is not None:
            gold = spans_ner_label.type(torch.float)
            pred = logits
            loss_fct = BCELoss(
                weight=(gold > 0.5) * 0.97 + 0.03, reduction="sum"
            )  # 0 is 98.56% of the time, 1 is 1.44% of the time
            loss = loss_fct(pred, gold)
            # print(gold[0], pred[0])

            return loss, f1_loss(gold, pred), logits, spans_embedding
        else:
            return logits, spans_embedding
Example #8
    def forward(
        self,
        span: torch.Tensor,  # SHAPE: (batch_size, num_spans, span_dim)
        span_pairs: torch.LongTensor  # SHAPE: (batch_size, num_span_pairs)
    ):
        span1 = span2 = span
        if self.dim_reduce_layer1 is not None:
            span1 = self.dim_reduce_layer1(span)
        if self.dim_reduce_layer2 is not None:
            span2 = self.dim_reduce_layer2(span)

        if not self.pair:
            return span1, span2

        num_spans = span.size(1)

        # get span pair embedding
        span_pairs_p = span_pairs[:, :, 0]
        span_pairs_c = span_pairs[:, :, 1]
        # SHAPE: (batch_size * num_span_pairs)
        flat_span_pairs_p = util.flatten_and_batch_shift_indices(
            span_pairs_p, num_spans)
        flat_span_pairs_c = util.flatten_and_batch_shift_indices(
            span_pairs_c, num_spans)
        # SHAPE: (batch_size, num_span_pairs, span_dim)
        span_pair_p_emb = util.batched_index_select(span1, span_pairs_p,
                                                    flat_span_pairs_p)
        span_pair_c_emb = util.batched_index_select(span2, span_pairs_c,
                                                    flat_span_pairs_c)
        if self.combine == 'concat':
            # SHAPE: (batch_size, num_span_pairs, span_dim * 2)
            span_pair_emb = torch.cat([span_pair_p_emb, span_pair_c_emb], -1)
        elif self.combine == 'coref':
            # use the index gap as the distance, which requires the indices to be
            # consistent with the order in which the spans appear in the sentences
            distance = span_pairs_p - span_pairs_c
            # SHAPE: (batch_size, num_span_pairs, dist_emb_dim)
            distance_embeddings = self.distance_embedding(
                util.bucket_values(
                    distance, num_total_buckets=self.num_distance_buckets))
            # SHAPE: (batch_size, num_span_pairs, span_dim * 3)
            span_pair_emb = torch.cat([
                span_pair_p_emb, span_pair_c_emb,
                span_pair_p_emb * span_pair_c_emb, distance_embeddings
            ], -1)

        if self.repr_layer is not None:
            # SHAPE: (batch_size, num_span_pairs, out_dim)
            span_pair_emb = self.repr_layer(span_pair_emb)

        return span_pair_emb
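Several of the examples (here and in Examples #14, #21 and #31) pair batched_index_select with util.flatten_and_batch_shift_indices, which turns per-batch indices into positions in a flattened (batch_size * num_items, dim) view by adding a per-batch offset; batched_index_select then accepts that flat tensor as an optional precomputed third argument (flat_span_pairs_p above). A rough sketch of the offsetting step, with a hypothetical name and assuming 2-D indices:

import torch

def flatten_and_batch_shift_indices_sketch(indices: torch.LongTensor,
                                            sequence_length: int) -> torch.LongTensor:
    # indices: (batch_size, num_indices), each entry in [0, sequence_length).
    batch_size = indices.size(0)
    # Shift row b by b * sequence_length so the indices address a flattened
    # (batch_size * sequence_length, dim) view of the sequence tensor.
    offsets = torch.arange(batch_size, device=indices.device).unsqueeze(1) * sequence_length
    # Shape: (batch_size * num_indices,)
    return (indices + offsets).view(-1)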
Example #9
    def test_scorer_works_for_completely_masked_rows(self):
        # Really simple scorer - sum up the embedding_dim.
        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
        pruner = Pruner(scorer=scorer) # type: ignore

        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        items[0, :2, :] = 1
        items[1, 2:, :] = 1
        items[2, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        mask[2, :] = 0 # fully masked last batch element.

        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(items, mask, 2)

        # We can't check the last row here, because it's completely masked.
        # Instead we'll check that the scores for these elements are very small.
        numpy.testing.assert_array_equal(pruned_indices[:2].data.numpy(), numpy.array([[0, 1],
                                                                                       [1, 2]]))
        numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.array([[1, 1],
                                                                                [1, 1],
                                                                                [0, 0]]))
        # embeddings should be the result of index_selecting the pruned_indices.
        correct_embeddings = batched_index_select(items, pruned_indices)
        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                         pruned_embeddings.data.numpy())
        # scores should be the sum of the correct embedding elements, with masked elements very
        # small (but not -inf, because that can cause problems).  We'll test these two cases
        # separately.
        correct_scores = correct_embeddings.sum(-1).unsqueeze(-1).data.numpy()
        numpy.testing.assert_array_equal(correct_scores[:2], pruned_scores[:2].data.numpy())
        numpy.testing.assert_array_equal(pruned_scores[2] < -1e15, [[1], [1]])
        numpy.testing.assert_array_equal(pruned_scores[2] == float('-inf'), [[0], [0]])
Example #10
    def coref_propagation_doc(self, output_dict):
        coreference_scores = output_dict["coreference_scores"]
        top_span_embeddings = output_dict["top_span_embeddings"]
        antecedent_indices = output_dict["antecedent_indices"]
        for t in range(self.coref_prop):
            assert coreference_scores.shape[1] == antecedent_indices.shape[0]
            assert coreference_scores.shape[2] - 1 == antecedent_indices.shape[1]
            assert top_span_embeddings.shape[1] == coreference_scores.shape[1]
            assert antecedent_indices.max() <= top_span_embeddings.shape[1]

            antecedent_distribution = self.antecedent_softmax(coreference_scores)[:, :, 1:]
            top_span_emb_repeated = top_span_embeddings.repeat(antecedent_distribution.shape[2], 1, 1)
            if antecedent_indices.shape[0] == antecedent_indices.shape[1]:
                selected_top_span_embs = util.batched_index_select(top_span_emb_repeated, antecedent_indices).unsqueeze(0)
                entity_embs = (selected_top_span_embs.permute([3, 0, 1, 2]) * antecedent_distribution).permute([1, 2, 3, 0]).sum(dim=2)
            else:
                ant_var1 = antecedent_indices.unsqueeze(0).unsqueeze(-1).repeat(1, 1, 1, top_span_embeddings.shape[-1])
                top_var1 = top_span_embeddings.unsqueeze(1).repeat(1, antecedent_distribution.shape[1], 1, 1)
                entity_embs = (torch.gather(top_var1, 2, ant_var1).permute([3, 0, 1, 2]) * antecedent_distribution).permute([1, 2, 3, 0]).sum(dim=2)

            #entity_embs = F.dropout(entity_embs)

            f_network_input = torch.cat([top_span_embeddings, entity_embs], dim=-1)
            f_weights = self._f_network(f_network_input)
            top_span_embeddings = f_weights * top_span_embeddings + (1.0 - f_weights) * entity_embs

            #f_weights2 = self._f_network2(f_network_input)
            #top_span_embeddings = f_weights2 * top_span_embeddings + (1.0 - f_weights2) * entity_embs
            coreference_scores = self.get_coref_scores(
                top_span_embeddings,
                self._mention_pruner._scorer(top_span_embeddings),
                output_dict["antecedent_indices"],
                output_dict["valid_antecedent_offsets"],
                output_dict["valid_antecedent_log_mask"])

        output_dict["coreference_scores"] = coreference_scores
        output_dict["top_span_embeddings"] = top_span_embeddings
        return output_dict
Example #11
    def forward(self, text: Dict[str, torch.LongTensor],
                predicate_indicator: torch.LongTensor,
                predicate_index: torch.LongTensor, **kwargs):
        # slot_name -> Shape: batch_size, 1
        gold_slot_labels = self._get_gold_slot_labels(kwargs)
        if gold_slot_labels is None:
            raise ConfigurationError(
                "QfirstQuestionGenerator requires gold labels for teacher forcing when running forward. "
                "You may wish to run beam_decode instead.")
        # Shape: batch_size, num_tokens, self._sentence_encoder.get_output_dim()
        encoded_text, text_mask = self._sentence_encoder(
            text, predicate_indicator)
        # Shape: batch_size, self._sentence_encoder.get_output_dim()
        pred_rep = batched_index_select(encoded_text,
                                        predicate_index).squeeze(1)
        # slot_name -> Shape: batch_size, slot_name_vocab_size
        slot_logits = self._question_generator(pred_rep, **gold_slot_labels)

        batch_size, _ = pred_rep.size()

        # Shape: <scalar>
        slot_nlls, neg_log_likelihood = self._get_cross_entropy(
            slot_logits, gold_slot_labels)
        self.metric(slot_logits, gold_slot_labels, torch.ones([batch_size]),
                    slot_nlls, neg_log_likelihood)
        return {**slot_logits, "loss": neg_log_likelihood}
Example #12
 def forward(self,
             text: Dict[str, torch.LongTensor],
             predicate_indicator: torch.LongTensor,
             predicate_index: torch.LongTensor,
             clause_dist: torch.FloatTensor = None,
             **kwargs):
     # Shape: batch_size, num_tokens, self._sentence_encoder.get_output_dim()
     encoded_text, text_mask = self._sentence_encoder(text, predicate_indicator)
     # Shape: batch_size, encoder_output_dim
     pred_rep = batched_index_select(encoded_text, predicate_index).squeeze(1)
     # Shape: batch_size, get_vocab_size(self._label_namespace)
     frame_logits = self._frame_pred(pred_rep)
     frame_probs = F.softmax(frame_logits, dim = 1)
     frames = F.softmax(self._frames_matrix, dim = 1)
     clause_probs = torch.matmul(frame_probs, frames)
     clause_log_probs = clause_probs.log()
     output_dict = { "probs": frame_probs }
     # TODO figure out how to do this with logits
     # TODO figure out how to handle the null case
     if clause_dist is not None and clause_dist.sum().item() > 0.1:
         gold_clause_probs = F.normalize(clause_dist.float(), p = 1, dim = 1)
         cross_entropy = torch.sum(-gold_clause_probs * clause_log_probs, 1)
         output_dict["loss"] = torch.mean(cross_entropy)
         gold_entropy = -gold_clause_probs * gold_clause_probs.log() # entropy summands
         gold_entropy[gold_entropy != gold_entropy] = 0.0 # zero out nans
         gold_entropy = torch.sum(gold_entropy, dim = 1) # compute sum for per-batch-item entropy
         kl_divergence = cross_entropy - gold_entropy
         self._metric(clause_probs, clause_dist > 0.0)
         self._kl_divergence_metric(kl_divergence)
     return output_dict
Example #13
 def forward(
         self,  # type: ignore
         text: Dict[str, torch.LongTensor],
         predicate_indicator: torch.LongTensor,
         predicate_index: torch.LongTensor,
         clause: torch.LongTensor,
         answer_slot: torch.LongTensor,
         answer_spans: torch.LongTensor = None,
         span_counts: torch.LongTensor = None,
         num_answers: torch.LongTensor = None,
         metadata=None,
         **kwargs):
     embedded_clause = self._clause_embedding(clause)
     embedded_slot = self._slot_embedding(answer_slot)
     encoded_text, text_mask = self._sentence_encoder(
         text, predicate_indicator)
     pred_rep = batched_index_select(encoded_text,
                                     predicate_index).squeeze(1)
     combined_embedding = torch.cat(
         [embedded_clause, embedded_slot, pred_rep], -1)
     question_embedding = self._question_projection(combined_embedding)
     return self._span_selector(encoded_text,
                                text_mask,
                                extra_input_embedding=question_embedding,
                                answer_spans=answer_spans,
                                span_counts=span_counts,
                                num_answers=num_answers,
                                metadata=metadata)
Example #14
    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(
            max_batch_span_width,
            util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(
            raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(
            span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor,
                                                    span_indices,
                                                    flat_span_indices)

        #  text_embeddings = span_embeddings * span_mask.unsqueeze(-1)
        batch_size, num_spans, max_batch_span_width, _ = span_embeddings.size()

        view_text_embeddings = span_embeddings.view(batch_size * num_spans,
                                                    max_batch_span_width, -1)
        span_mask = span_mask.view(batch_size * num_spans,
                                   max_batch_span_width)
        cnn_text_embeddings = self.cnn(view_text_embeddings, span_mask)
        cnn_text_embeddings = cnn_text_embeddings.view(batch_size, num_spans,
                                                       self._output_dim)
        return cnn_text_embeddings
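To make the index arithmetic in Example #14 concrete, here is a small hand-worked sketch with hypothetical toy numbers, showing how the range vector, the width mask and raw_span_indices interact for a single span:

import torch

# One span with start = 2 and end = 5 (inclusive), so width = end - start = 3
# and the span covers positions [2, 3, 4, 5]. Suppose the widest span in the
# batch has width 6, giving max_batch_span_width = 7.
span_ends = torch.tensor([[[5]]])       # (batch_size=1, num_spans=1, 1)
span_widths = torch.tensor([[[3]]])     # (1, 1, 1)
max_span_range_indices = torch.arange(7).view(1, 1, -1)      # (1, 1, 7)

span_mask = (max_span_range_indices <= span_widths).float()  # keeps offsets 0..3
raw_span_indices = span_ends - max_span_range_indices
# raw_span_indices == [5, 4, 3, 2, 1, 0, -1]: the unmasked entries are the span
# itself read right-to-left, and the (raw_span_indices >= 0) check in the
# extractor drops the -1 that appears for spans near the start of the sequence.
span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()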
Example #15
    def predict_labels_doc(self, output_dict):
        # Shape: (batch_size, num_spans_to_keep)
        coref_labels = output_dict["coref_labels"]
        coreference_scores = output_dict["coreference_scores"]
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict["predicted_antecedents"] = predicted_antecedents

        top_span_indices = output_dict["top_span_indices"]
        flat_top_span_indices = output_dict["flat_top_span_indices"]
        valid_antecedent_indices = output_dict["antecedent_indices"]
        valid_antecedent_log_mask = output_dict["valid_antecedent_log_mask"]
        top_spans = output_dict["top_spans"]
        top_span_mask = output_dict["top_span_mask"]
        metadata = output_dict["metadata"]
        sentence_lengths = output_dict["sentence_lengths"]

        if coref_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                coref_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            # There's an integer wrap-around happening here. It occurs in the original code.
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = (
                coreference_log_probs + gold_antecedent_labels.log())
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            # Need to get cluster data in same form as for original AllenNLP coref code so that the
            # evaluation code works.
            evaluation_metadata = self._make_evaluation_metadata(
                metadata, sentence_lengths)

            self._mention_recall(top_spans, evaluation_metadata)

            # TODO(dwadden) Shouldn't need to do the unsqueeze here; figure out what's happening.
            self._conll_coref_scores(top_spans,
                                     valid_antecedent_indices.unsqueeze(0),
                                     predicted_antecedents,
                                     evaluation_metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        return output_dict
Example #16
    def test_pruner_selects_top_scored_items_and_respects_masking(self):
        # Really simple scorer - sum up the embedding_dim.
        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
        pruner = Pruner(scorer=scorer)

        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        items[0, :2, :] = 1
        items[1, 2:, :] = 1
        items[2, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(
            items, mask, 2)

        # Second element in the batch would have indices 2, 3, but
        # 3 and 0 are masked, so instead it has 1, 2.
        numpy.testing.assert_array_equal(pruned_indices.data.numpy(),
                                         numpy.array([[0, 1], [1, 2], [2, 3]]))
        numpy.testing.assert_array_equal(pruned_mask.data.numpy(),
                                         numpy.ones([3, 2]))

        # embeddings should be the result of index_selecting the pruned_indices.
        correct_embeddings = batched_index_select(items, pruned_indices)
        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                         pruned_embeddings.data.numpy())
        # scores should be the sum of the correct embedding elements.
        numpy.testing.assert_array_equal(
            correct_embeddings.sum(-1).unsqueeze(-1).data.numpy(),
            pruned_scores.data.numpy())
Example #17
    def test_batched_index_select(self):
        indices = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        # Each element is a vector of its index.
        targets = torch.ones([2, 10, 3]).cumsum(1) - 1
        # Make the second batch double its index so they're different.
        targets[1, :, :] *= 2
        indices = Variable(torch.LongTensor(indices))
        targets = Variable(targets)
        selected = util.batched_index_select(targets, indices)

        assert list(selected.size()) == [2, 2, 2, 3]
        ones = numpy.ones([3])
        numpy.testing.assert_array_equal(selected[0, 0, 0, :].data.numpy(),
                                         ones)
        numpy.testing.assert_array_equal(selected[0, 0, 1, :].data.numpy(),
                                         ones * 2)
        numpy.testing.assert_array_equal(selected[0, 1, 0, :].data.numpy(),
                                         ones * 3)
        numpy.testing.assert_array_equal(selected[0, 1, 1, :].data.numpy(),
                                         ones * 4)

        numpy.testing.assert_array_equal(selected[1, 0, 0, :].data.numpy(),
                                         ones * 10)
        numpy.testing.assert_array_equal(selected[1, 0, 1, :].data.numpy(),
                                         ones * 12)
        numpy.testing.assert_array_equal(selected[1, 1, 0, :].data.numpy(),
                                         ones * 14)
        numpy.testing.assert_array_equal(selected[1, 1, 1, :].data.numpy(),
                                         ones * 16)
Example #18
    def test_span_scorer_works_for_completely_masked_rows(self):
        # Really simple scorer - sum up the embedding_dim.
        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
        pruner = SpanPruner(scorer=scorer)  # type: ignore

        spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        spans[0, :2, :] = 1
        spans[1, 2:, :] = 1
        spans[2, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        mask[2, :] = 0  # fully masked last batch element.

        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(
            spans, mask, 2)

        # We can't check the last row here, because it's completely masked.
        # Instead we'll check that the scores for these elements are -inf.
        numpy.testing.assert_array_equal(pruned_indices[:2].data.numpy(),
                                         numpy.array([[0, 1], [1, 2]]))
        numpy.testing.assert_array_equal(pruned_mask.data.numpy(),
                                         numpy.array([[1, 1], [1, 1], [0, 0]]))
        # embeddings should be the result of index_selecting the pruned_indices.
        correct_embeddings = batched_index_select(spans, pruned_indices)
        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                         pruned_embeddings.data.numpy())
        # scores should be the sum of the correct embedding elements, with
        # masked elements equal to -inf.
        correct_scores = correct_embeddings.sum(-1).unsqueeze(-1).data.numpy()
        correct_scores[2, :] = float("-inf")
        numpy.testing.assert_array_equal(correct_scores,
                                         pruned_scores.data.numpy())
Example #19
    def test_span_scorer_works_for_completely_masked_rows(self):
        # Really simple scorer - sum up the embedding_dim.
        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
        pruner = SpanPruner(scorer=scorer) # type: ignore

        spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        spans[0, :2, :] = 1
        spans[1, 2:, :] = 1
        spans[2, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        mask[2, :] = 0 # fully masked last batch element.

        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2)

        # We can't check the last row here, because it's completely masked.
        # Instead we'll check that the scores for these elements are -inf.
        numpy.testing.assert_array_equal(pruned_indices[:2].data.numpy(), numpy.array([[0, 1],
                                                                                       [1, 2]]))
        numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.array([[1, 1],
                                                                                [1, 1],
                                                                                [0, 0]]))
        # embeddings should be the result of index_selecting the pruned_indices.
        correct_embeddings = batched_index_select(spans, pruned_indices)
        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                         pruned_embeddings.data.numpy())
        # scores should be the sum of the correct embedding elements, with
        # masked elements equal to -inf.
        correct_scores = correct_embeddings.sum(-1).unsqueeze(-1).data.numpy()
        correct_scores[2, :] = float("-inf")
        numpy.testing.assert_array_equal(correct_scores,
                                         pruned_scores.data.numpy())
Example #20
    def test_span_pruner_selects_top_scored_spans_and_respects_masking(self):
        # Really simple scorer - sum up the embedding_dim.
        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
        pruner = SpanPruner(scorer=scorer)

        spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        spans[0, :2, :] = 1
        spans[1, 2:, :] = 1
        spans[2, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2)

        # Second element in the batch would have indices 2, 3, but
        # 3 and 0 are masked, so instead it has 1, 2.
        numpy.testing.assert_array_equal(pruned_indices.data.numpy(), numpy.array([[0, 1],
                                                                                   [1, 2],
                                                                                   [2, 3]]))
        numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.ones([3, 2]))

        # embeddings should be the result of index_selecting the pruned_indices.
        correct_embeddings = batched_index_select(spans, pruned_indices)
        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                         pruned_embeddings.data.numpy())
        # scores should be the sum of the correct embedding elements.
        numpy.testing.assert_array_equal(correct_embeddings.sum(-1).unsqueeze(-1).data.numpy(),
                                         pruned_scores.data.numpy())
Example #21
    def inference_coref(self, batch, embedded_text_input_relation, mask):
        submodel = self.model._tagger_coref

        ### Fast inference of coreference ###
        spans = batch["spans"]

        document_length = mask.size(1)
        num_spans = spans.size(1)

        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        spans = F.relu(spans.float()).long()

        encoded_text_coref = submodel._context_layer(
            embedded_text_input_relation, mask)
        endpoint_span_embeddings = submodel._endpoint_span_extractor(
            encoded_text_coref, spans)
        attended_span_embeddings = submodel._attentive_span_extractor(
            embedded_text_input_relation, spans)

        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)
        num_spans_to_keep = int(
            math.floor(submodel._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = submodel._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        max_antecedents = min(submodel._max_antecedents, num_spans_to_keep)

        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            submodel._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(mask))
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        span_pair_embeddings = submodel._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        coreference_scores = submodel._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        _, predicted_antecedents = coreference_scores.max(2)
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": predicted_antecedents
        }

        return output_dict
Example #22
    def score_spans_if_labels(
        self,
        output_dict,
        span_labels,
        metadata,
        top_span_indices,
        flat_top_span_indices,
        top_span_mask,
        top_spans,
        valid_antecedent_indices,
        valid_antecedent_log_mask,
        coreference_scores,
        predicted_antecedents,
    ):
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = (
                coreference_log_probs + gold_antecedent_labels.log())
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
Example #23
    def test_masked_indices_are_handled_correctly_with_exclusive_indices(self):
        sequence_tensor = Variable(torch.randn([2, 5, 8]))
        # concatenate start and end points together to form our representation
        # for both the forward and backward directions.
        extractor = EndpointSpanExtractor(8,
                                          "x,y",
                                          use_exclusive_start_indices=True)
        indices = Variable(
            torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [0, 1]]]))
        sequence_mask = Variable(
            torch.LongTensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]))

        span_representations = extractor(sequence_tensor,
                                         indices,
                                         sequence_mask=sequence_mask)

        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        start_embeddings, end_embeddings = span_representations.split(8, -1)

        correct_start_indices = Variable(torch.LongTensor([[0, 1], [-1, -1]]))
        # These indices should be -1, so they'll be replaced with a sentinel. Here,
        # we'll set them to a value other than -1 so we can index select the indices and
        # replace them later.
        correct_start_indices[1, 0] = 1
        correct_start_indices[1, 1] = 1

        correct_end_indices = Variable(torch.LongTensor([[3, 4], [2, 1]]))

        correct_start_embeddings = batched_index_select(
            sequence_tensor.contiguous(), correct_start_indices)
        # This element had sequence_tensor index of 0, so its exclusive index is the start sentinel.
        correct_start_embeddings[1, 0] = extractor._start_sentinel.data
        correct_start_embeddings[1, 1] = extractor._start_sentinel.data
        numpy.testing.assert_array_equal(start_embeddings.data.numpy(),
                                         correct_start_embeddings.data.numpy())

        correct_end_embeddings = batched_index_select(
            sequence_tensor.contiguous(), correct_end_indices)
        numpy.testing.assert_array_equal(end_embeddings.data.numpy(),
                                         correct_end_embeddings.data.numpy())
Example #24
    def forward(self, # pylint: disable=arguments-differ
                sequence_tensor: torch.FloatTensor,
                indices: torch.LongTensor) -> torch.FloatTensor:
        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in indices.split(1, dim=-1)]
        start_embeddings = batched_index_select(sequence_tensor, span_starts)
        end_embeddings = batched_index_select(sequence_tensor, span_ends)

        combined_tensors = combine_tensors(self._combination, [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = bucket_values(span_ends - span_starts,
                                            num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([combined_tensors, span_width_embeddings], -1)

        return combined_tensors
Example #25
    def forward(self,
                text: Dict[str, torch.LongTensor],
                predicate_indicator: torch.LongTensor,
                predicate_index: torch.LongTensor,
                qarg_labeled_clauses,
                qarg_labeled_spans,
                qarg_labels=None,
                **kwargs):
        # Shape: batch_size, num_tokens, encoder_output_dim
        encoded_text, text_mask = self._sentence_encoder(
            text, predicate_indicator)
        # Shape: batch_size, encoder_output_dim
        pred_rep = batched_index_select(encoded_text,
                                        predicate_index).squeeze(1)

        batch_size, num_labeled_instances, _ = qarg_labeled_spans.size()
        # Shape: batch_size, num_labeled_instances
        qarg_labeled_mask = (qarg_labeled_spans[:, :, 0] >=
                             0).squeeze(-1).long()
        if len(qarg_labeled_mask.size()) == 1:
            qarg_labeled_mask = qarg_labeled_mask.unsqueeze(-1)

        # max to prevent the padded labels from messing up the embedding module
        # Shape: batch_size, num_labeled_instances, self._clause_embedding_dim
        input_clauses = self._clause_embedding(
            qarg_labeled_clauses.max(torch.zeros_like(qarg_labeled_clauses)))

        # Shape: batch_size, num_spans, 2 * encoder_output_projected_dim
        span_embeddings = self._span_extractor(encoded_text,
                                               qarg_labeled_spans, text_mask,
                                               qarg_labeled_mask)
        # Shape: batch_size, num_spans, self._span_hidden_dim
        input_span_hidden = self._span_hidden(span_embeddings)

        # Shape: batch_size, 1, self._final_input_dim
        pred_embedding = self._predicate_hidden(pred_rep)
        # Shape: batch_size, num_labeled_instances, self._final_input_dim
        qarg_inputs = F.relu(
            pred_embedding.unsqueeze(1) + input_clauses + input_span_hidden)
        # Shape: batch_size, num_labeled_instances, get_vocab_size("qarg-labels")
        qarg_logits = self._qarg_predictor(self._qarg_ffnn(qarg_inputs))

        final_mask = qarg_labeled_mask.unsqueeze(-1) \
                        .expand(batch_size, num_labeled_instances, self.vocab.get_vocab_size("qarg-labels")) \
                        .float()
        qarg_probs = torch.sigmoid(qarg_logits).squeeze(-1) * final_mask

        output_dict = {"logits": qarg_logits, "probs": qarg_probs}
        if qarg_labels is not None:
            output_dict["loss"] = F.binary_cross_entropy_with_logits(
                qarg_logits, qarg_labels, weight=final_mask, reduction="sum")
            self._metric(qarg_probs, qarg_labels)
        return output_dict
Example #26
 def forward(self, text: Dict[str, torch.LongTensor],
             predicate_indicator: torch.LongTensor,
             predicate_index: torch.LongTensor, **kwargs):
     # Shape: batch_size, num_tokens, self._sentence_encoder.get_output_dim()
     encoded_text, text_mask = self._sentence_encoder(
         text, predicate_indicator)
     # Shape: batch_size, encoder_output_dim
     pred_rep = batched_index_select(encoded_text,
                                     predicate_index).squeeze(1)
     # Shape: batch_size, get_vocab_size(self._label_namespace)
     logits = self._final_pred(pred_rep)
     return self._classifier(logits, None, kwargs.get(self._label_name),
                             kwargs.get(self._label_name + "_counts"))
Example #27
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # shape (batch_size, num_spans)
        span_starts, span_ends = [
            index.squeeze(-1) for index in span_indices.split(1, dim=-1)
        ]

        if span_indices_mask is not None:
            # It's not strictly necessary to multiply the span indices by the mask here,
            # but it's possible that the span representation was padded with something other
            # than 0 (such as -1, which would be an invalid index), so we do so anyway to
            # be safe.
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask

        start_embeddings = batched_index_select(sequence_tensor, span_starts)
        end_embeddings = batched_index_select(sequence_tensor, span_ends)

        combined_tensors = combine_tensors(self._combination,
                                           [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = bucket_values(
                    span_ends - span_starts,
                    num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([combined_tensors, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return combined_tensors * span_indices_mask.unsqueeze(-1).float()
        return combined_tensors
Example #28
 def beam_decode(self,
                 text: Dict[str, torch.LongTensor],
                 predicate_indicator: torch.LongTensor,
                 predicate_index: torch.LongTensor,
                 max_beam_size: int,
                 min_beam_probability: float,
                 clause_mode: bool = False):
     # Shape: batch_size, num_tokens, self._sentence_encoder.get_output_dim()
     encoded_text, text_mask = self._sentence_encoder(
         text, predicate_indicator)
     # Shape: batch_size, self._sentence_encoder.get_output_dim()
     pred_rep = batched_index_select(encoded_text,
                                     predicate_index).squeeze(1)
     return self._question_generator.beam_decode(pred_rep, max_beam_size,
                                                 min_beam_probability,
                                                 clause_mode)
Example #29
def nd_batched_index_select(target: torch.Tensor,
                            indices: torch.IntTensor) -> torch.Tensor:
    """
    Multidimensional version of `util.batched_index_select`.
    """
    batch_axes = target.size()[:-2]
    num_batch_axes = len(batch_axes)
    target_shape = target.size()
    indices_shape = indices.size()

    target_reshaped = target.view(-1, *target_shape[num_batch_axes:])
    indices_reshaped = indices.view(-1, *indices_shape[num_batch_axes:])

    output_reshaped = util.batched_index_select(target_reshaped,
                                                indices_reshaped)

    return output_reshaped.view(*indices_shape, -1)
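A quick shape check for Example #29 (toy tensors with hypothetical values, assuming nd_batched_index_select and AllenNLP's util are importable): with two leading batch axes the helper should behave like batched_index_select applied independently per batch element.

import torch

# target has batch axes (2, 3), a sequence axis of length 10 and feature dim 4.
target = torch.randn(2, 3, 10, 4)
indices = torch.randint(0, 10, (2, 3, 5))

out = nd_batched_index_select(target, indices)
assert out.shape == (2, 3, 5, 4)
# out[a, b, i] == target[a, b, indices[a, b, i]] for every a, b and i.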
Example #30
    def forward(self,
                text: Dict[str, torch.LongTensor],
                predicate_indicator: torch.LongTensor,
                predicate_index: torch.LongTensor,
                tan_spans,
                tan_labels=None,
                **kwargs):
        # Shape: batch_size, num_tokens, encoder_output_dim
        encoded_text, text_mask = self._sentence_encoder(
            text, predicate_indicator)

        # Shape: batch_size, num_labeled_instances
        span_mask = (tan_spans[:, :, 0] >= 0).squeeze(-1).float()
        if len(span_mask.size()) == 1:
            span_mask = span_mask.unsqueeze(-1)

        # Shape: batch_size, num_spans, 2 * encoder_output_projected_dim
        span_embeddings = self._span_extractor(encoded_text, tan_spans,
                                               text_mask, span_mask.long())
        batch_size, num_spans, _ = span_embeddings.size()

        if self._inject_predicate:
            expanded_pred_embedding = batched_index_select(encoded_text, predicate_index) \
                                      .expand(batch_size, num_spans, self._sentence_encoder.get_output_dim())
            input_embeddings = torch.cat(
                [expanded_pred_embedding, span_embeddings], -1)
        else:
            input_embeddings = span_embeddings

        tan_mask = span_mask.unsqueeze(-1)  # Shape: batch_size, num_spans, 1; broadcasts over the TAN label dimension

        # Shape: batch_size, num_labeled_instances, self.vocab.get_vocab_size("tan-string-labels")
        tan_logits = self._tan_pred(input_embeddings)
        tan_probs = torch.sigmoid(tan_logits) * tan_mask

        output_dict = {"logits": tan_logits, "probs": tan_probs}
        if tan_labels is not None:
            output_dict["loss"] = F.binary_cross_entropy_with_logits(
                tan_logits,
                tan_labels.float(),
                weight=tan_mask,
                reduction="sum")
            self._metric(tan_probs, tan_labels, tan_mask.long())
        return output_dict
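A brief sketch (hypothetical shapes, not from the original model) of how the masked multi-label loss above behaves: passing the span mask as the BCE `weight` zeroes the contribution of padded spans, so only real spans add to the summed loss.

import torch
import torch.nn.functional as F

tan_logits = torch.randn(2, 3, 5)                      # (batch, num_spans, num_tan_labels)
tan_labels = torch.randint(0, 2, (2, 3, 5)).float()
tan_mask = torch.tensor([[1., 1., 0.],
                         [1., 0., 0.]]).unsqueeze(-1)  # trailing spans are padding
loss = F.binary_cross_entropy_with_logits(
    tan_logits, tan_labels, weight=tan_mask, reduction="sum")
# Positions where tan_mask == 0 contribute exactly zero to the summed loss.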
Example #31
0
    def _prune_spans(self, spans, span_mask, span_embeddings, sentence_lengths):
        # Prune
        num_spans = spans.size(1)  # Max number of spans for the minibatch.

        # Keep different number of spans for each minibatch entry.
        num_spans_to_keep = torch.ceil(sentence_lengths.float() * self._spans_per_word).long()

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores, num_spans_kept) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)

        top_span_mask = top_span_mask.unsqueeze(-1)

        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        return top_span_embeddings, top_span_mention_scores, num_spans_to_keep, top_span_mask, top_span_indices, top_spans
Example #32
0
 def _get_question_inputs(self, text, predicate_indicator, predicate_index,
                          answer_spans):
     encoded_text, text_mask = self._sentence_encoder(
         text, predicate_indicator)
     span_mask = (answer_spans[:, :, 0] >= 0).long()
     # Shape: batch_size, num_spans, 2 * self._sentence_encoder.get_output_dim()
     span_reps = self._span_extractor(encoded_text,
                                      answer_spans,
                                      sequence_mask=text_mask,
                                      span_indices_mask=span_mask)
     batch_size, _, encoding_dim = encoded_text.size()
     num_spans = span_reps.size(1)
     if self._inject_predicate:
         pred_rep_expanded = batched_index_select(encoded_text, predicate_index) \
                             .expand(batch_size, num_spans, encoding_dim)
         question_inputs = torch.cat([pred_rep_expanded, span_reps], -1)
     else:
         question_inputs = span_reps
     return question_inputs, span_mask
Example #33
0
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            predicate_indicator: torch.LongTensor,
            predicate_index: torch.LongTensor,
            answer_spans: torch.LongTensor = None,
            num_answers: torch.LongTensor = None,
            num_invalids: torch.LongTensor = None,
            metadata=None,
            **kwargs):
        # each of gold_slot_labels[slot_name] is of
        # Shape: batch_size
        slot_labels = self._get_slot_labels(**kwargs)
        if slot_labels is None:
            raise ConfigurationError(
                "QuestionAnswerer must receive question slots as input.")

        encoded_text, text_mask = self._sentence_encoder(
            text, predicate_indicator)
        pred_rep = batched_index_select(encoded_text,
                                        predicate_index).squeeze(1)
        question_encoding = self._question_encoder(pred_rep, slot_labels)
        question_rep = torch.cat([pred_rep, question_encoding], -1)
        output_dict = self._span_selector(encoded_text,
                                          text_mask,
                                          extra_input_embedding=question_rep,
                                          answer_spans=answer_spans,
                                          num_answers=num_answers,
                                          metadata=metadata)

        if self._classify_invalids:
            invalid_logits = self._invalid_pred(question_rep).squeeze(-1)
            invalid_probs = torch.sigmoid(invalid_logits)
            output_dict["invalid_prob"] = invalid_probs
            if num_invalids is not None:
                invalid_labels = (num_invalids > 0.0).float()
                invalid_loss = F.binary_cross_entropy_with_logits(
                    invalid_logits, invalid_labels, reduction="sum")
                output_dict["loss"] += invalid_loss
                self._invalid_metric(invalid_probs, invalid_labels)
        return output_dict
Example #34
0
 def forward(
         self,  # type: ignore
         text: Dict[str, torch.LongTensor],
         predicate_indicator: torch.LongTensor = None,
         predicate_index: torch.LongTensor = None,
         answer_spans: torch.LongTensor = None,
         span_counts: torch.LongTensor = None,
         num_answers: torch.LongTensor = None,
         metadata=None,
         **kwargs):
     encoded_text, text_mask = self._sentence_encoder(
         text, predicate_indicator)
     extra_input = batched_index_select(
         encoded_text, predicate_index) if self._inject_predicate else None
     return self._span_selector(encoded_text,
                                text_mask,
                                extra_input_embedding=extra_input,
                                answer_spans=answer_spans,
                                span_counts=span_counts,
                                num_answers=num_answers,
                                metadata=metadata)
Example #35
0
    def test_batched_index_select(self):
        indices = numpy.array([[[1, 2],
                                [3, 4]],
                               [[5, 6],
                                [7, 8]]])
        # Each element is a vector of its index.
        targets = torch.ones([2, 10, 3]).cumsum(1) - 1
        # Make the second batch double its index so they're different.
        targets[1, :, :] *= 2
        indices = torch.tensor(indices, dtype=torch.long)
        selected = util.batched_index_select(targets, indices)

        assert list(selected.size()) == [2, 2, 2, 3]
        ones = numpy.ones([3])
        numpy.testing.assert_array_equal(selected[0, 0, 0, :].data.numpy(), ones)
        numpy.testing.assert_array_equal(selected[0, 0, 1, :].data.numpy(), ones * 2)
        numpy.testing.assert_array_equal(selected[0, 1, 0, :].data.numpy(), ones * 3)
        numpy.testing.assert_array_equal(selected[0, 1, 1, :].data.numpy(), ones * 4)

        numpy.testing.assert_array_equal(selected[1, 0, 0, :].data.numpy(), ones * 10)
        numpy.testing.assert_array_equal(selected[1, 0, 1, :].data.numpy(), ones * 12)
        numpy.testing.assert_array_equal(selected[1, 1, 0, :].data.numpy(), ones * 14)
        numpy.testing.assert_array_equal(selected[1, 1, 1, :].data.numpy(), ones * 16)
    def test_correct_sequence_elements_are_embedded_with_a_masked_sequence(self):
        sequence_tensor = torch.randn([2, 5, 8])
        # concatenate start and end points together to form our representation
        # for both the forward and backward directions.
        extractor = BidirectionalEndpointSpanExtractor(input_dim=8,
                                                       forward_combination="x,y",
                                                       backward_combination="x,y")
        indices = torch.LongTensor([[[1, 3],
                                     [2, 4]],
                                    # This span has an end index at the
                                    # end of the padded sequence.
                                    [[0, 2],
                                     [0, 1]]])
        sequence_mask = torch.LongTensor([[1, 1, 1, 1, 1],
                                          [1, 1, 1, 0, 0]])

        span_representations = extractor(sequence_tensor, indices, sequence_mask=sequence_mask)

        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        (forward_start_embeddings, forward_end_embeddings,
         backward_start_embeddings, backward_end_embeddings) = span_representations.split(4, -1)

        forward_sequence_tensor, backward_sequence_tensor = sequence_tensor.split(4, -1)

        # Forward direction => subtract 1 from start indices to make them exclusive.
        correct_forward_start_indices = torch.LongTensor([[0, 1], [-1, -1]])
        # These indices should be -1, so they'll be replaced with a sentinel. Here,
        # we'll set them to a value other than -1 so we can index select the indices and
        # replace them later.
        correct_forward_start_indices[1, 0] = 1
        correct_forward_start_indices[1, 1] = 1

        # Forward direction => end indices are the same.
        correct_forward_end_indices = torch.LongTensor([[3, 4], [2, 1]])

        # Backward direction => start indices are exclusive, so add 1 to the end indices.
        correct_backward_start_indices = torch.LongTensor([[4, 5], [3, 2]])
        # These exclusive backward start indices are outside the tensor, so will be replaced
        # with the end sentinel. Here we replace them with ones so we can index select using
        # these indices without torch complaining.
        correct_backward_start_indices[0, 1] = 1

        # Backward direction => end indices are inclusive and equal to the forward start indices.
        correct_backward_end_indices = torch.LongTensor([[1, 2], [0, 0]])

        correct_forward_start_embeddings = batched_index_select(forward_sequence_tensor.contiguous(),
                                                                correct_forward_start_indices)
        # These elements had a sequence_tensor index of 0, so their exclusive indices are the start sentinel.
        correct_forward_start_embeddings[1, 0] = extractor._start_sentinel.data
        correct_forward_start_embeddings[1, 1] = extractor._start_sentinel.data
        numpy.testing.assert_array_equal(forward_start_embeddings.data.numpy(),
                                         correct_forward_start_embeddings.data.numpy())

        correct_forward_end_embeddings = batched_index_select(forward_sequence_tensor.contiguous(),
                                                              correct_forward_end_indices)
        numpy.testing.assert_array_equal(forward_end_embeddings.data.numpy(),
                                         correct_forward_end_embeddings.data.numpy())

        correct_backward_end_embeddings = batched_index_select(backward_sequence_tensor.contiguous(),
                                                               correct_backward_end_indices)
        numpy.testing.assert_array_equal(backward_end_embeddings.data.numpy(),
                                         correct_backward_end_embeddings.data.numpy())

        correct_backward_start_embeddings = batched_index_select(backward_sequence_tensor.contiguous(),
                                                                 correct_backward_start_indices)
        # This element had sequence_tensor index == sequence_tensor.size(1),
        # so its exclusive index is the end sentinel.
        correct_backward_start_embeddings[0, 1] = extractor._end_sentinel.data
        # This element has sequence_tensor index == the masked length of the batch element,
        # so it should be the end_sentinel even though it isn't greater than sequence_tensor.size(1).
        correct_backward_start_embeddings[1, 0] = extractor._end_sentinel.data

        numpy.testing.assert_array_equal(backward_start_embeddings.data.numpy(),
                                         correct_backward_start_embeddings.data.numpy())
    def test_correct_sequence_elements_are_embedded(self):
        sequence_tensor = Variable(torch.randn([2, 5, 8]))
        # concatenate start and end points together to form our representation
        # for both the forward and backward directions.
        extractor = BidirectionalEndpointSpanExtractor(input_dim=8,
                                                       forward_combination="x,y",
                                                       backward_combination="x,y")
        indices = Variable(torch.LongTensor([[[1, 3],
                                              [2, 4]],
                                             [[0, 2],
                                              [3, 4]]]))

        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 2, 16]
        assert extractor.get_output_dim() == 16
        assert extractor.get_input_dim() == 8

        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        (forward_start_embeddings, forward_end_embeddings,
         backward_start_embeddings, backward_end_embeddings) = span_representations.split(4, -1)

        forward_sequence_tensor, backward_sequence_tensor = sequence_tensor.split(4, -1)

        # Forward direction => subtract 1 from start indices to make them exclusive.
        correct_forward_start_indices = Variable(torch.LongTensor([[0, 1],
                                                                   [-1, 2]]))
        # This index should be -1, so it will be replaced with a sentinel. Here,
        # we'll set it to a value other than -1 so we can index select the indices and
        # replace it later.
        correct_forward_start_indices[1, 0] = 1

        # Forward direction => end indices are the same.
        correct_forward_end_indices = Variable(torch.LongTensor([[3, 4], [2, 4]]))

        # Backward direction => start indices are exclusive, so add 1 to the end indices.
        correct_backward_start_indices = Variable(torch.LongTensor([[4, 5], [3, 5]]))
        # These exclusive backward start indices are outside the tensor, so they will be replaced with the end sentinel.
        # Here we replace them with ones so we can index select using these indices without torch
        # complaining.
        correct_backward_start_indices[0, 1] = 1
        correct_backward_start_indices[1, 1] = 1
        # Backward direction => end indices are inclusive and equal to the forward start indices.
        correct_backward_end_indices = Variable(torch.LongTensor([[1, 2], [0, 3]]))

        correct_forward_start_embeddings = batched_index_select(forward_sequence_tensor.contiguous(),
                                                                correct_forward_start_indices)
        # This element had sequence_tensor index of 0, so its exclusive index is the start sentinel.
        correct_forward_start_embeddings[1, 0] = extractor._start_sentinel.data
        numpy.testing.assert_array_equal(forward_start_embeddings.data.numpy(),
                                         correct_forward_start_embeddings.data.numpy())

        correct_forward_end_embeddings = batched_index_select(forward_sequence_tensor.contiguous(),
                                                              correct_forward_end_indices)
        numpy.testing.assert_array_equal(forward_end_embeddings.data.numpy(),
                                         correct_forward_end_embeddings.data.numpy())

        correct_backward_end_embeddings = batched_index_select(backward_sequence_tensor.contiguous(),
                                                               correct_backward_end_indices)
        numpy.testing.assert_array_equal(backward_end_embeddings.data.numpy(),
                                         correct_backward_end_embeddings.data.numpy())

        correct_backward_start_embeddings = batched_index_select(backward_sequence_tensor.contiguous(),
                                                                 correct_backward_start_indices)
        # These elements had sequence_tensor index == sequence_tensor.size(1),
        # so their exclusive indices are the end sentinel.
        correct_backward_start_embeddings[0, 1] = extractor._end_sentinel.data
        correct_backward_start_embeddings[1, 1] = extractor._end_sentinel.data
        numpy.testing.assert_array_equal(backward_start_embeddings.data.numpy(),
                                         correct_backward_start_embeddings.data.numpy())
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:

        # Both of shape (batch_size, sequence_length, embedding_size / 2)
        forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
        forward_sequence = forward_sequence.contiguous()
        backward_sequence = backward_sequence.contiguous()

        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

        # We want `exclusive` span ends for the backward direction
        # (so that the `start` of the span in that direction is exclusive), so
        # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
        exclusive_span_ends = span_ends + 1
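        # A worked sketch of the index bookkeeping (illustrative values): for an inclusive
        # span (start=0, end=4) in a sequence of length 5,
        #   forward start  -> exclusive: 0 - 1 = -1  (replaced by the start sentinel below,
        #                                             when sentinels are in use)
        #   forward end    -> inclusive: 4
        #   backward start -> exclusive: 4 + 1 = 5   (equals the sequence length, so it is
        #                                             likewise replaced by the end sentinel)
        #   backward end   -> inclusive: 0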

        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length of the sequence_tensor.
            sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

        # shape (batch_size, num_spans, 1)
        end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

        # As we added 1 to the span_ends to make them exclusive, which might have caused indices
        # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
        # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
        # The same argument follows for the exclusive span start indices.
        exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
            raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                             f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                             f"{sequence_lengths}.")

        # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
        forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
        # Forward Direction: end indices are inclusive, so we can just use span_ends.
        # Shape (batch_size, num_spans, input_size / 2)
        forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

        # Backward Direction: The backward start embeddings use the `forward` end
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
        # Backward Direction: The backward end embeddings use the `forward` start
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

        if self._use_sentinels:
            # If we're using sentinels, we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with either the start sentinel,
            # or the end sentinel.
            float_end_sentinel_mask = end_sentinel_mask.float()
            float_start_sentinel_mask = start_sentinel_mask.float()
            forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel
            backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                        + float_end_sentinel_mask * self._end_sentinel

        # Now we combine the forward and backward spans in the manner specified by the
        # respective combinations and concatenate these representations.
        # Shape (batch_size, num_spans, forward_combination_dim)
        forward_spans = util.combine_tensors(self._forward_combination,
                                             [forward_start_embeddings, forward_end_embeddings])
        # Shape (batch_size, num_spans, backward_combination_dim)
        backward_spans = util.combine_tensors(self._backward_combination,
                                              [backward_start_embeddings, backward_end_embeddings])
        # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
        span_embeddings = torch.cat([forward_spans, backward_spans], -1)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return span_embeddings * span_indices_mask.float().unsqueeze(-1)
        return span_embeddings
Example #39
0
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be non-negative. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.
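        # A small worked sketch (illustrative numbers): with num_spans_to_keep = 4 and
        # max_antecedents = 2, span i is allowed the antecedents [i - 1, i - 2], giving an
        # index matrix roughly like
        #     [[0, 0],   # span 0: no valid antecedents (both masked by the log mask)
        #      [0, 0],   # span 1: only span 0 is valid
        #      [1, 0],
        #      [2, 1]]
        # where negative raw indices are clamped to zero and marked invalid in
        # valid_antecedent_log_mask.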

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
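            # In symbols (hedged notation, not from the original source): for each kept span i
            # with gold cluster C(i), let Y(i) be the antecedents of i that lie in C(i) (or the
            # dummy antecedent if none do). The loss is then
            #     loss = - sum_i log( sum_{j in Y(i)} P(j | i) )
            # where P(. | i) is the softmax over coreference_scores for span i.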
            coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
    def _get_initial_rnn_and_grammar_state(self,
                                           question: Dict[str, torch.LongTensor],
                                           table: Dict[str, torch.LongTensor],
                                           world: List[WikiTablesWorld],
                                           actions: List[List[ProductionRule]],
                                           outputs: Dict[str, Any]) -> Tuple[List[RnnStatelet],
                                                                             List[LambdaGrammarStatelet]]:
        """
        Encodes the question and table, computes a linking between the two, and constructs an
        initial RnnStatelet and LambdaGrammarStatelet for each batch instance to pass to the
        decoder.

        We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
        visualize in a demo.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)


        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            entity_types[i])
                                 for i in range(batch_size)]
        if not self.training:
            # We add a few things to the outputs that will be returned from `forward` at evaluation
            # time, for visualization in a demo.
            outputs['linking_scores'] = linking_scores
            if feature_scores is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
        return initial_rnn_state, initial_grammar_state
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape (batch_size, sequence_length, 1)
        global_attention_logits = self._global_attention(sequence_tensor)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
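        # A worked sketch (illustrative numbers): for a span with start=2, end=4 and
        # max_batch_span_width = 4, we have span_width = 2 and
        # max_span_range_indices = [0, 1, 2, 3], so
        #   span_mask        = [1, 1, 1, 0]
        #   raw_span_indices = 4 - [0, 1, 2, 3] = [4, 3, 2, 1]
        # i.e. each span is gathered from its end backwards; indices that would fall below
        # zero (for spans near the start of the sequence) were just masked, and are clamped
        # to zero by the relu below.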
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = util.batched_index_select(global_attention_logits,
                                                          span_indices,
                                                          flat_span_indices).squeeze(-1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

        if span_indices_mask is not None:
            # Above we were masking the widths of spans with respect to the max
            # span width in the batch. Here we are masking the spans which were
            # originally passed in as padding.
            return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float()

        return attended_text_embeddings
Example #42
0
    def forward(self, # pylint: disable=arguments-differ
                embeddings: torch.FloatTensor,
                mask: torch.LongTensor,
                num_items_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor,
                                                 torch.LongTensor, torch.FloatTensor]:
        """
        Extracts the top-k scoring items with respect to the scorer. We additionally return
        the indices of the top-k in their original order, not ordered by score, so that downstream
        components can rely on the original ordering (e.g., for knowing what spans are valid
        antecedents in a coreference resolution model).

        Parameters
        ----------
        embeddings : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
            each item in the list that we want to prune.
        mask : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_items), denoting unpadded elements of
            ``embeddings``.
        num_items_to_keep : ``int``, required.
            The number of items to keep when pruning.

        Returns
        -------
        top_embeddings : ``torch.FloatTensor``
            The representations of the top-k scoring items.
            Has shape (batch_size, num_items_to_keep, embedding_size).
        top_mask : ``torch.LongTensor``
            The corresponding mask for ``top_embeddings``.
            Has shape (batch_size, num_items_to_keep).
        top_indices : ``torch.IntTensor``
            The indices of the top-k scoring items into the original ``embeddings``
            tensor. This is returned because it can be useful to retain pointers to
            the original items, if each item is being scored by multiple distinct
            scorers, for instance. Has shape (batch_size, num_items_to_keep).
        top_item_scores : ``torch.FloatTensor``
            The values of the top-k scoring items.
            Has shape (batch_size, num_items_to_keep, 1).
        """
        mask = mask.unsqueeze(-1)
        num_items = embeddings.size(1)
        # Shape: (batch_size, num_items, 1)
        scores = self._scorer(embeddings)

        if scores.size(-1) != 1 or scores.dim() != 3:
            raise ValueError(f"The scorer passed to Pruner must produce a tensor of shape"
                             f"(batch_size, num_items, 1), but found shape {scores.size()}")
        # Make sure that we don't select any masked items by setting their scores to be very
        # negative.  These are logits, typically, so -1e20 should be plenty negative.
        scores = util.replace_masked_values(scores, mask, -1e20)

        # Shape: (batch_size, num_items_to_keep, 1)
        _, top_indices = scores.topk(num_items_to_keep, 1)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``embeddings`` tensor).
        top_indices, _ = torch.sort(top_indices, 1)

        # Shape: (batch_size, num_items_to_keep)
        top_indices = top_indices.squeeze(-1)

        # Shape: (batch_size * num_items_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select items for each element in the batch.
        flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)

        # Shape: (batch_size, num_items_to_keep, embedding_size)
        top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices)
        # Shape: (batch_size, num_items_to_keep)
        top_mask = util.batched_index_select(mask, top_indices, flat_top_indices)

        # Shape: (batch_size, num_items_to_keep, 1)
        top_scores = util.batched_index_select(scores, top_indices, flat_top_indices)

        return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
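A minimal usage sketch for the pruner above. It assumes the method belongs to a module whose constructor takes the scorer (as in AllenNLP's Pruner); the import path, scorer, and shapes are illustrative assumptions, not taken from the surrounding code.

import torch
from allennlp.modules.pruner import Pruner  # assumed import path

scorer = torch.nn.Linear(16, 1)             # maps each item embedding to a single score
pruner = Pruner(scorer)
embeddings = torch.randn(2, 10, 16)         # (batch_size, num_items, embedding_size)
mask = torch.ones(2, 10).long()             # no padding in this toy batch
top_emb, top_mask, top_idx, top_scores = pruner(embeddings, mask, 4)
# top_emb: (2, 4, 16); top_mask and top_idx: (2, 4); top_scores: (2, 4, 1)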
    def _get_initial_state_and_scores(self,
                                      question: Dict[str, torch.LongTensor],
                                      table: Dict[str, torch.LongTensor],
                                      world: List[WikiTablesWorld],
                                      actions: List[List[ProductionRuleArray]],
                                      example_lisp_string: List[str] = None,
                                      add_world_to_initial_state: bool = False,
                                      checklist_states: List[ChecklistState] = None) -> Dict:
        """
        Does initial preparation and creates an initial state for both the semantic parsers. Note
        that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
        pass it.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.nn.functional.tanh(entity_type_embeddings + projected_neighbor_embeddings)


        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as the sum of two linearly projected terms.  The
            # first is the maximum similarity between the entity's own words and the question
            # token.  The second is the maximum similarity between the words of the entity's
            # neighbors and the question token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected.  For example, a question token might have no similarity
            # with the column name, but be very similar to the cells in that column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
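            # (batch_size, num_entities, num_question_tokens): a learned linear combination
            # of the precomputed linking features for each (entity, question token) pair,
            # added on top of the similarity-based linking score.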
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
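        # For each question token this turns the raw linking scores into probabilities over
        # entities; judging by its arguments, _get_linking_probabilities normalizes
        # separately over the entities of each type, which is why entity_type_dict is passed in.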
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
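        # util.weighted_sum with these shapes is just a batched matrix product,
        #     torch.bmm(linking_probabilities, entity_embeddings)
        # so each question token gets a probability-weighted mixture of entity embeddings.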
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # The final encoder output will be the initial hidden state for the decoder LSTM;
        # the memory cell starts out as zeros.
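        # get_final_encoder_states grabs each question's encoding at its last non-padded
        # token; for a bidirectional encoder it concatenates the forward state at that
        # token with the backward state at token 0.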
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        initial_score = embedded_question.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)
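        # Roughly: the linking scores flattened to one row per (instance, entity), plus a
        # mapping from entity-producing actions to those rows, so the decoder can look up a
        # linking score whenever it considers producing an entity (a column or a cell).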
        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  That way
        # later code never needs tensor index_selects; it can just pick list elements and
        # `torch.cat()` whatever it needs.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
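        # Note that each instance's RnnState below holds references to the full
        # encoder_output_list and question_mask_list; decoder steps index into them with a
        # state's batch indices when attending over the question.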
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnState(final_encoder_output[i],
                                              memory_cell[i],
                                              self._first_action_embedding,
                                              self._first_attended_question,
                                              encoder_output_list,
                                              question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               output_action_embeddings=output_action_embeddings,
                                               action_biases=action_biases,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               world=initial_state_world,
                                               example_lisp_string=example_lisp_string,
                                               checklist_state=checklist_states,
                                               debug_info=None)
        return {"initial_state": initial_state,
                "linking_scores": linking_scores,
                "feature_scores": feature_scores,
                "similarity_scores": question_entity_similarity_max_score}