Example #1
    def _passage_span_module(self, passage_out, passage_mask):
        # Shape: (batch_size, passage_length)
        passage_span_start_logits = self._passage_span_start_predictor(
            passage_out).squeeze(-1)

        # Shape: (batch_size, passage_length)
        passage_span_end_logits = self._passage_span_end_predictor(
            passage_out).squeeze(-1)

        # Shape: (batch_size, passage_length)
        passage_span_start_log_probs = util.masked_log_softmax(
            passage_span_start_logits, passage_mask)
        passage_span_end_log_probs = util.masked_log_softmax(
            passage_span_end_logits, passage_mask)

        # Info about the best passage span prediction
        passage_span_start_logits = util.replace_masked_values(
            passage_span_start_logits, passage_mask, -1e7)
        passage_span_end_logits = util.replace_masked_values(
            passage_span_end_logits, passage_mask, -1e7)

        # Shape: (batch_size, 2)
        best_passage_span = get_best_span(passage_span_start_logits,
                                          passage_span_end_logits)
        return passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span
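Most examples on this page follow this span-extraction recipe: score each token twice (start and end), normalize each score vector with masked_log_softmax so padding receives no probability mass, then push masked logits to a large negative value before searching for the best span. Below is a minimal sketch of the masking step in plain PyTorch; the helper name is ours and only mirrors what AllenNLP's util.masked_log_softmax does, it is not the library's actual implementation:

import torch
import torch.nn.functional as F

def masked_log_softmax_sketch(vector: torch.Tensor,
                              mask: torch.Tensor,
                              dim: int = -1) -> torch.Tensor:
    # log(mask + eps) is ~0 at kept positions and a very large negative
    # number at masked ones, so masked entries get ~zero probability mass.
    if mask is not None:
        vector = vector + (mask.float() + 1e-45).log()
    return F.log_softmax(vector, dim=dim)

start_logits = torch.tensor([[2.0, 0.5, -1.0, 3.0]])
passage_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last token is padding
print(masked_log_softmax_sketch(start_logits, passage_mask).exp())
# the padded position ends up with ~0.0 probability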
Example #2
    def forward(self,
                **kwargs: Any) -> Dict[str, torch.Tensor]:

        input, mask = self.get_input_and_mask(kwargs)

        # Shape: (batch_size, passage_length)
        start_logits = self._start_output_layer(input).squeeze(-1)

        # Shape: (batch_size, passage_length)
        end_logits = self._end_output_layer(input).squeeze(-1)

        start_log_probs = masked_log_softmax(start_logits, mask)
        end_log_probs = masked_log_softmax(end_logits, mask)

        # Info about the best span prediction
        start_logits = replace_masked_values(start_logits, mask, -1e7)
        end_logits = replace_masked_values(end_logits, mask, -1e7)

        # Shape: (batch_size, 2)
        best_span = get_best_span(start_logits, end_logits)

        output_dict = {
            'start_log_probs': start_log_probs,
            'end_log_probs': end_log_probs,
            'best_span': best_span
        }
        return output_dict
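get_best_span itself is not shown in these snippets. A plausible stand-in (an assumption, not the library's exact code) scores every (start, end) pair with end >= start and picks the argmax of start_logits[start] + end_logits[end]:

import torch

def get_best_span_sketch(start_logits: torch.Tensor,
                         end_logits: torch.Tensor) -> torch.Tensor:
    batch_size, length = start_logits.shape
    # Shape: (batch_size, length, length); entry [b, i, j] scores span (i, j).
    span_scores = start_logits.unsqueeze(2) + end_logits.unsqueeze(1)
    # Keep only spans with end >= start (upper triangle incl. diagonal).
    valid = torch.triu(torch.ones(length, length, dtype=torch.bool,
                                  device=start_logits.device))
    span_scores = span_scores.masked_fill(~valid, float('-inf'))
    best_flat = span_scores.view(batch_size, -1).argmax(dim=-1)
    starts = torch.div(best_flat, length, rounding_mode='floor')
    ends = best_flat % length
    # Shape: (batch_size, 2)
    return torch.stack([starts, ends], dim=-1)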
Example #3
    def _question_span_module(self, passage_vector, question_out,
                              question_mask):
        # Shape: (batch_size, question_length)
        encoded_question_for_span_prediction = \
            torch.cat([question_out,
                       passage_vector.unsqueeze(1).repeat(1, question_out.size(1), 1)], -1)
        question_span_start_logits = \
            self._question_span_start_predictor(encoded_question_for_span_prediction).squeeze(-1)
        # Shape: (batch_size, question_length)
        question_span_end_logits = \
            self._question_span_end_predictor(encoded_question_for_span_prediction).squeeze(-1)
        question_span_start_log_probs = util.masked_log_softmax(
            question_span_start_logits, question_mask)
        question_span_end_log_probs = util.masked_log_softmax(
            question_span_end_logits, question_mask)

        # Info about the best question span prediction
        question_span_start_logits = \
            util.replace_masked_values(question_span_start_logits, question_mask, -1e7)
        question_span_end_logits = \
            util.replace_masked_values(question_span_end_logits, question_mask, -1e7)

        # Shape: (batch_size, 2)
        best_question_span = get_best_span(question_span_start_logits,
                                           question_span_end_logits)
        return question_span_start_log_probs, question_span_end_log_probs, best_question_span
Example #4
    def test_masked_log_softmax_masked(self):
        # Tests replicated from test_softmax_masked - we test that, exponentiated,
        # the log softmax contains the correct elements (masked elements should be == 0).

        # Testing the general masked 1D case.
        vector_1d = torch.FloatTensor([[1.0, 2.0, 5.0]])
        mask_1d = torch.FloatTensor([[1.0, 0.0, 1.0]])
        vector_1d_softmaxed = util.masked_log_softmax(vector_1d, mask_1d).data.numpy()
        assert_array_almost_equal(numpy.exp(vector_1d_softmaxed),
                                  numpy.array([[0.01798621, 0.0, 0.98201382]]))

        vector_1d = torch.FloatTensor([[0.0, 2.0, 3.0, 4.0]])
        mask_1d = torch.FloatTensor([[1.0, 0.0, 1.0, 1.0]])
        vector_1d_softmaxed = util.masked_log_softmax(vector_1d, mask_1d).data.numpy()
        assert_array_almost_equal(numpy.exp(vector_1d_softmaxed),
                                  numpy.array([[0.01321289, 0.0,
                                                0.26538793, 0.72139918]]))

        # Testing the masked 1D case where the input is all 0s and the mask
        # is not all 0s.
        vector_1d = torch.FloatTensor([[0.0, 0.0, 0.0, 0.0]])
        mask_1d = torch.FloatTensor([[0.0, 0.0, 0.0, 1.0]])
        vector_1d_softmaxed = util.masked_log_softmax(vector_1d, mask_1d).data.numpy()
        assert_array_almost_equal(numpy.exp(vector_1d_softmaxed),
                                  numpy.array([[0., 0., 0., 1.]]))

        # Testing the masked 1D case where the input is not all 0s
        # and the mask is all 0s.  The output here will be arbitrary, but it should not be nan.
        vector_1d = torch.FloatTensor([[0.0, 2.0, 3.0, 4.0]])
        mask_1d = torch.FloatTensor([[0.0, 0.0, 0.0, 0.0]])
        vector_1d_softmaxed = util.masked_log_softmax(vector_1d, mask_1d).data.numpy()
        assert not numpy.isnan(vector_1d_softmaxed).any()
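The expected values in the first assertion can be checked by hand: with the middle element masked out, only the logits 1.0 and 5.0 participate, so the renormalized probabilities are e^1/(e^1+e^5) ≈ 0.01799 and e^5/(e^1+e^5) ≈ 0.98201. The same check as a two-liner:

import torch

logits = torch.tensor([1.0, 2.0, 5.0])
keep = torch.tensor([True, False, True])
print(torch.softmax(logits[keep], dim=-1))  # tensor([0.0180, 0.9820])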
Example #6
    def _construct_loss(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        head_indices: torch.Tensor,
        head_tags: torch.Tensor,
        mask: torch.Tensor,
        head_tag_temperature: Optional[float] = None,
        head_temperature: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        float_mask = mask.float()
        tag_mask = self._get_unknown_tag_mask(mask, head_tags)

        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        if head_temperature:
            attended_arcs /= head_temperature
        normalised_arc_logits = masked_log_softmax(
            attended_arcs,
            mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              head_indices)
        if head_tag_temperature:
            head_tag_logits /= head_tag_temperature
        normalised_head_tag_logits = masked_log_softmax(
            head_tag_logits,
            tag_mask.unsqueeze(-1)) * tag_mask.float().unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index,
                                         head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index,
                                              head_tags]
        tag_loss *= (head_tags > 1).float()
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        loss = arc_nll + tag_nll
        return loss, normalised_arc_logits, normalised_head_tag_logits
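This variant differs from the stock biaffine loss mainly in its optional temperature arguments: dividing logits by a temperature above 1 flattens the distribution that the masked log-softmax produces. In isolation:

import torch

logits = torch.tensor([2.0, 0.0])
for temperature in (1.0, 2.0, 5.0):
    print(temperature, torch.softmax(logits / temperature, dim=-1))
# 1.0 -> [0.881, 0.119]; 2.0 -> [0.731, 0.269]; 5.0 -> [0.599, 0.401]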
Example #7
    def forward(self,
                query: torch.FloatTensor,
                key: torch.FloatTensor,
                edge_head_mask: torch.ByteTensor = None,
                gold_edge_heads: torch.Tensor = None) -> Dict:
        """
        :param query: [batch_size, query_length, query_vector_dim]
        :param key: [batch_size, key_length, key_vector_dim]
        :param edge_head_mask: [batch_size, query_length, key_length]
                        1 indicates a valid position; otherwise, 0.
        :param gold_edge_heads: None or [batch_size, query_length].
                        head indices start from 1.
        :return:
            edge_heads: [batch_size, query_length].
            edge_types: [batch_size, query_length].
            edge_head_ll: [batch_size, query_length, key_length + 1(sentinel)].
            edge_type_ll: [batch_size, query_length, num_labels] (based on gold_edge_head) or None.
        """
        if gold_edge_heads is not None:
            gold_edge_heads[gold_edge_heads == -1] = 0

        key, edge_head_mask = self._add_sentinel(query, key, edge_head_mask)
        edge_head_query, edge_head_key, edge_type_query, edge_type_key = self._mlp(
            query, key)
        # [batch_size, query_length, key_length + 1]
        edge_head_score = self._get_edge_head_score(edge_head_query,
                                                    edge_head_key)
        edge_heads, edge_types = self._greedy_search(edge_type_query,
                                                     edge_type_key,
                                                     edge_head_score,
                                                     edge_head_mask)

        if gold_edge_heads is None:
            gold_edge_heads = edge_heads
        # [batch_size, query_length, num_labels]
        edge_type_score = self._get_edge_type_score(edge_type_query,
                                                    edge_type_key, edge_heads)

        return dict(
            # Note: head indices start from 1.
            edge_heads=edge_heads,
            edge_types=edge_types,
            # Log-Likelihood.
            edge_head_ll=masked_log_softmax(edge_head_score,
                                            edge_head_mask,
                                            dim=2),
            edge_type_ll=masked_log_softmax(edge_type_score, None, dim=2),
            edge_head_query=edge_head_query,
            edge_head_key=edge_head_key,
            edge_type_query=edge_type_query,
            edge_type_key=edge_type_key)
Example #8
def bidaf(curr_q: torch.Tensor, past_ctx_enc: torch.Tensor,
          past_ctx_emb: torch.Tensor, masks: torch.Tensor,
          att: MatrixAttention, men_scorer):
    """

    :param q: batch of queries (B, M)
    :param c: batch of contexts (B, N)
    :masks c: batch of masks (B, N)
    :param att: attention layer
    :return: tuple of
        q_hat - context aware query representation
        entropy - entropy of attention scores
    """
    q_c_att = att(curr_q, past_ctx_enc)  # (B, M, N)
    if men_scorer is not None:
        men_scores = men_scorer(past_ctx_enc)  # (B, N)
        men_scores_ex = men_scores.squeeze(-1).unsqueeze(1).expand_as(q_c_att)
        q_c_att = q_c_att + men_scores_ex
    else:
        men_scores = None
    sm_att = util.masked_softmax(q_c_att, masks)  # (B, M, N)
    log_sm_att = util.masked_log_softmax(q_c_att, masks)  # (B, M, N)
    entropy = torch.sum(-sm_att * log_sm_att, 2)  # (B, M)
    # print(past_ctx_emb.size(), sm_att.size())
    q_hat = util.weighted_sum(past_ctx_emb, sm_att)  # (B, M, E)
    return q_hat, entropy, sm_att, men_scores
Example #9
    def forward(self, **kwargs) -> torch.FloatTensor:
        mask = kwargs['mask']
        embedded_text = kwargs['embedded_text']
        encoded_output = self._architecture(embedded_text, mask)
        encoded_repr = []
        for aggregation in self._aggregations:
            if aggregation == "meanpool":
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoded_output * broadcast_mask
                encoded_text = masked_mean(context_vectors,
                                           broadcast_mask,
                                           dim=1,
                                           keepdim=False)
            elif aggregation == 'maxpool':
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoded_output * broadcast_mask
                encoded_text = masked_max(context_vectors,
                                          broadcast_mask,
                                          dim=1)
            elif aggregation == 'final_state':
                is_bi = self._architecture.is_bidirectional()
                encoded_text = get_final_encoder_states(encoded_output,
                                                        mask,
                                                        is_bi)
            elif aggregation == 'attention':
                alpha = self._attention_layer(encoded_output)
                alpha = masked_log_softmax(alpha, mask.unsqueeze(-1), dim=1).exp()
                encoded_text = alpha * encoded_output
                encoded_text = encoded_text.sum(dim=1)
            else:
                raise ConfigurationError(f"{aggregation} aggregation not available.")
            encoded_repr.append(encoded_text)

        encoded_repr = torch.cat(encoded_repr, 1)
        return encoded_repr
Example #10
 def forward(self, x, mask):
     """ x assumed to be logits"""
     # Softmax converts logits to probabilities; this version is numerically stable.
     b = util.masked_softmax(x, mask) * util.masked_log_softmax(x, mask)
     non_batch_dimensions = tuple(range(-len(b.shape) + 1, 0))
     b = -1.0 * b.sum(dim=non_batch_dimensions)
     return b.mean()
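Since b holds p * log(p) elementwise, negating and summing over the non-batch dimensions yields the Shannon entropy of each masked distribution. As a sanity check, a uniform distribution over k items should give entropy log(k):

import torch

logits = torch.zeros(1, 4)                     # uniform over 4 items
probs = torch.softmax(logits, dim=-1)
entropy = -(probs * probs.log()).sum(dim=-1)
print(entropy, torch.tensor(4.0).log())        # both ~1.3863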
Example #11
 def _get_next_state_info_without_agenda(
         state: NlvrDecoderState, considered_actions: List[List[int]],
         action_logits: torch.Tensor, action_mask: torch.Tensor
 ) -> List[List[Tuple[int, torch.LongTensor]]]:
     """
     We return a list of log probabilities corresponding to actions that are not padding. This
     method is related to the training scenario where we have target action sequences for
     training.
     """
     considered_action_logprobs = nn_util.masked_log_softmax(
         action_logits, action_mask)
     all_action_logprobs: List[List[Tuple[int, torch.LongTensor]]] = []
     for group_index, (score, considered_logprobs) in enumerate(
             zip(state.score, considered_action_logprobs)):
         instance_action_logprobs: List[Tuple[int, torch.Tensor]] = []
         for action_index, logprob in enumerate(considered_logprobs):
             # This is the actual index of the action from the original list of actions.
             action = considered_actions[group_index][action_index]
             if action == -1:
                 # Ignoring padding.
                 continue
             instance_action_logprobs.append(
                 (action_index, score + logprob))
         all_action_logprobs.append(instance_action_logprobs)
     return all_action_logprobs
Example #12
    def _new_entity_loss(self,
                         encoded: torch.Tensor,
                         target_inds: torch.Tensor,
                         shortlist: torch.Tensor,
                         target_mask: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ==========
        target_inds : ``torch.Tensor``
            Either the shortlist inds if using shortlist, otherwise the target entity ids.
        """
        logits = self._new_entity_logits(encoded, shortlist)
        if self._use_shortlist:
            # Take masked softmax to get log probabilities and gather the targets.
            shortlist_mask = get_text_field_mask(shortlist)
            log_probs = masked_log_softmax(logits, shortlist_mask)
        else:
            log_probs = F.log_softmax(logits, dim=-1)
            num_categories = log_probs.shape[-1]
            log_probs = log_probs.view(-1, num_categories)
            target_inds = target_inds.view(-1)
        target_log_probs = torch.gather(log_probs, -1, target_inds.unsqueeze(-1)).squeeze(-1)

        mask = ~target_inds.eq(0)
        target_log_probs[~mask] = 0

        if mask.any():
            self._new_entity_accuracy(predictions=log_probs[mask],
                                      gold_labels=target_inds[mask])
            self._new_entity_accuracy20(predictions=log_probs[mask],
                                        gold_labels=target_inds[mask])

        return -target_log_probs.sum() / (target_mask.sum() + 1e-13)
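The gather over target_inds implements the same per-target negative log-likelihood that F.nll_loss computes; a quick equivalence check in plain PyTorch:

import torch
import torch.nn.functional as F

logits = torch.randn(3, 5)
targets = torch.tensor([1, 0, 4])
log_probs = F.log_softmax(logits, dim=-1)
via_gather = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1).mean()
assert torch.allclose(via_gather, F.nll_loss(log_probs, targets))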
Example #13
    def predict_labels_doc(self, output_dict):
        # Shape: (batch_size, num_spans_to_keep)
        coref_labels = output_dict["coref_labels"]
        coreference_scores = output_dict["coreference_scores"]
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict["predicted_antecedents"] = predicted_antecedents

        top_span_indices = output_dict["top_span_indices"]
        flat_top_span_indices = output_dict["flat_top_span_indices"]
        valid_antecedent_indices = output_dict["antecedent_indices"]
        valid_antecedent_log_mask = output_dict["valid_antecedent_log_mask"]
        top_spans = output_dict["top_spans"]
        top_span_mask = output_dict["top_span_mask"]
        metadata = output_dict["metadata"]
        sentence_lengths = output_dict["sentence_lengths"]

        if coref_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                coref_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            # There's an integer wrap-around happening here. It occurs in the original code.
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = (coreference_log_probs +
                                            gold_antecedent_labels.log())
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            # Need to get cluster data in same form as for original AllenNLP coref code so that the
            # evaluation code works.
            evaluation_metadata = self._make_evaluation_metadata(
                metadata, sentence_lengths)

            self._mention_recall(top_spans, evaluation_metadata)

            # TODO(dwadden) Shouldn't need to do the unsqueeze here; figure out what's happening.
            self._conll_coref_scores(top_spans,
                                     valid_antecedent_indices.unsqueeze(0),
                                     predicted_antecedents,
                                     evaluation_metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        return output_dict
Example #14
    def forward(
            self,  # type: ignore
            question_passage: Dict[str, torch.LongTensor] = None,
            option: Dict[str, torch.LongTensor] = None,
            answer: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, option_num, option_length, _ = option[
            'token_characters'].size()
        for k in option.keys():
            if k != 'token_characters':
                option[k] = option[k].view(batch_size * option_num, -1)
            else:
                option[k] = option[k].view(batch_size * option_num,
                                           option_length, -1)
        # option['token_characters'] = option['token_characters'].view(batch_size*option_num, option_length, -1)

        question_bert = self._text_field_embedder(question_passage)
        option_bert = self._text_field_embedder(option)
        question_cls_rep = torch.index_select(question_bert, 1,
                                              torch.tensor([0]))
        option_cls_rep = torch.index_select(option_bert, 1, torch.tensor([0]))
        option_cls_rep = option_cls_rep.view(batch_size, option_num, -1)

        opt_logits = torch.sum(option_cls_rep * question_cls_rep, dim=-1)

        opt_predicts = torch.argmax(opt_logits, 1)

        predict_loss = nll_loss(util.masked_log_softmax(opt_logits, None),
                                answer)
        loss = predict_loss
        self._loss_trackers['loss'](loss)
        output_dict = {"loss": loss}

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            ids = []

            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_text'])
                passage_tokens.append(metadata[i]['original_passage'])
                # ids.append(metadata[i]['_id'])
                # passage_str = metadata[i]['original_passage']
                # offsets = metadata[i]['token_offsets']
                answer_text = metadata[i].get("answer_text")
                ans = answer[i]
                predict_ans = opt_predicts[i]
                # print(predict_ans, ans)
                output_dict['answer_texts'].append(answer_text)

                self._categorical_acc(opt_logits[i], ans)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['_id'] = ids

        return output_dict
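Note that masked_log_softmax is called with mask=None here (as in Example #7's edge_type_ll), in which case it reduces to an ordinary log-softmax, making the option loss plain cross-entropy over the candidates. The equivalence, without AllenNLP:

import torch
import torch.nn.functional as F

opt_logits = torch.randn(4, 3)
answer = torch.tensor([0, 2, 1, 2])
loss = F.nll_loss(F.log_softmax(opt_logits, dim=-1), answer)
assert torch.allclose(loss, F.cross_entropy(opt_logits, answer))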
Example #15
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                context_span=None,
                gt_span=None,
                max_context_length=0,
                mode=ForwardMode.TRAIN):
        # Precomputing max_context_length matters because the same value must be
        # shared across GPUs; computing it dynamically is not feasible.
        sequence_output, _ = self.bert_encoder(input_ids,
                                               token_type_ids,
                                               attention_mask,
                                               output_all_encoded_layers=False)

        joint_seq_logits = self.qa_outputs(sequence_output)
        context_logits, context_length = span_util.span_select(
            joint_seq_logits, context_span, max_context_length)
        context_mask = allen_util.get_mask_from_sequence_lengths(
            context_length, max_context_length)

        # The following line is from AllenNLP bidaf.
        start_logits = allen_util.replace_masked_values(
            context_logits[:, :, 0], context_mask, -1e18)
        # B, T, 2
        end_logits = allen_util.replace_masked_values(context_logits[:, :, 1],
                                                      context_mask, -1e18)

        if mode == BertSpan.ForwardMode.TRAIN:
            assert gt_span is not None
            gt_start = gt_span[:, 0]  # gt_span: [B, 2]
            gt_end = gt_span[:, 1]

            start_loss = nll_loss(
                allen_util.masked_log_softmax(start_logits, context_mask),
                gt_start.squeeze(-1))
            end_loss = nll_loss(
                allen_util.masked_log_softmax(end_logits, context_mask),
                gt_end.squeeze(-1))

            loss = start_loss + end_loss
            return loss
        else:
            return start_logits, end_logits, context_length
Example #16
    def forward(self, x, y, x_mask, y_mask):
        z_s = y[:, -1, :].unsqueeze(1)  # [B, 1, I]
        z_e = None
        s = None
        e = None
        p_s = None
        p_e = None

        for i in range(self.hop):
            z_s_ = z_s.repeat(1, x.size(1), 1)  # [B, S, I]
            s = self.FFNs_start[i](torch.cat([x, z_s_, x * z_s_],
                                             2)).squeeze(2)
            # s.data.masked_fill_(x_mask.data, -float('inf'))
            # p_s = F.softmax(s, dim=1)  # [B, S]
            p_s = util.masked_softmax(s, x_mask, dim=1)
            u_s = p_s.unsqueeze(1).bmm(x)  # [B, 1, I]
            z_e = self.SFUs_start[i](z_s, u_s)  # [B, 1, I]
            z_e_ = z_e.repeat(1, x.size(1), 1)  # [B, S, I]
            e = self.FFNs_end[i](torch.cat([x, z_e_, x * z_e_], 2)).squeeze(2)
            # e.data.masked_fill_(x_mask.data, -float('inf'))
            # p_e = F.softmax(e, dim=1)  # [B, S]
            p_e = util.masked_softmax(e, x_mask, dim=1)
            u_e = p_e.unsqueeze(1).bmm(x)  # [B, 1, I]
            z_s = self.SFUs_end[i](z_e, u_e)
        yesno = self.yesno_predictor(torch.cat([x, z_e_, x * z_e_], 2))
        if self.normalize:
            # if self.training:
            # In training we output log-softmax for NLL
            # p_s = F.log_softmax(s, dim=1)  # [B, S]
            p_s = util.masked_log_softmax(s, x_mask, dim=1)
            # p_e = F.log_softmax(e, dim=1)  # [B, S]
            p_e = util.masked_log_softmax(e, x_mask, dim=1)
            p_yesno = F.log_softmax(yesno, dim=2)
        # else:
        # ...Otherwise 0-1 probabilities
        # p_s = F.softmax(s, dim=1)  # [B, S]
        # p_e = F.softmax(e, dim=1)  # [B, S]
        # p_yesno = F.softmax(yesno, dim=2)
        else:
            p_s = s.exp()
            p_e = e.exp()
            p_yesno = yesno.exp()
        return p_s, p_e, p_yesno
Example #17
    def loss(self, edge_scores: torch.Tensor, head_indices: torch.Tensor,
             mask: torch.Tensor) -> torch.Tensor:
        """
        Computes the edge loss for a sequence given gold head indices.

        Parameters
        ----------
        edge_scores : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = edge_scores.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(edge_scores)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(
            edge_scores,
            mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(edge_scores))
        child_index = timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index,
                                         head_indices]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum()
        if self.normalize_wrt_seq_len:
            arc_nll /= valid_positions.float()
        return arc_nll
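The three-tensor indexing normalised_arc_logits[range_vector, child_index, head_indices] picks out, for every child position, the log-probability of its gold head. It is equivalent to a gather along the last dimension, which the following check makes concrete:

import torch

batch_size, sequence_length = 2, 4
log_probs = torch.randn(batch_size, sequence_length,
                        sequence_length).log_softmax(dim=-1)
head_indices = torch.randint(0, sequence_length, (batch_size, sequence_length))

range_vector = torch.arange(batch_size).unsqueeze(1)      # (batch_size, 1)
child_index = torch.arange(sequence_length).unsqueeze(0)  # (1, sequence_length)
picked = log_probs[range_vector, child_index, head_indices]

assert torch.allclose(
    picked, log_probs.gather(-1, head_indices.unsqueeze(-1)).squeeze(-1))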
Example #18
    def _parent_log_probs(self,
                          encoded_head: torch.Tensor,
                          entity_ids: torch.Tensor,
                          parent_ids: torch.Tensor) -> torch.Tensor:
        # Lookup recent entities (which are candidates for parents) and get their embeddings.
        candidate_ids, candidate_mask = self._recent_entities(entity_ids)
        logger.debug('Candidate ids shape: %s', candidate_ids.shape)
        candidate_embeddings = embedded_dropout(self._entity_embedder,
                                                words=candidate_ids,
                                                dropout=self._dropoute if self.training else 0)

        # Logits are computed using a general bilinear form that measures the similarity between
        # the projected hidden state and the embeddings of candidate entities
        encoded = self._locked_dropout(encoded_head, self._dropout)
        selection_logits = torch.bmm(
            encoded, candidate_embeddings.transpose(1, 2))

        # Get log probabilities using masked softmax (need to double check mask works properly).

        # shape: (batch_size, sequence_length, num_candidates)
        log_probs = masked_log_softmax(selection_logits, candidate_mask)

        # Now for the tricky part. We need to convert the parent ids to a mask that selects the
        # relevant probabilities from log_probs. To do this we need to align the candidates with
        # the parent ids, which can be achieved by an element-wise equality comparison. We also
        # need to ensure that null parents are not selected.

        # shape: (batch_size, sequence_length, num_parents, 1)
        _parent_ids = parent_ids.unsqueeze(-1)

        batch_size, num_candidates = candidate_ids.shape
        # shape: (batch_size, 1, 1, num_candidates)
        _candidate_ids = candidate_ids.view(batch_size, 1, 1, num_candidates)

        # shape: (batch_size, sequence_length, num_parents, num_candidates)
        is_parent = _parent_ids.eq(_candidate_ids)
        # shape: (batch_size, 1, 1, num_candidates)
        non_null = ~_candidate_ids.eq(0)

        # Since multiplication is addition in log-space, we can apply mask by adding its log (+
        # some small constant for numerical stability).
        mask = is_parent & non_null
        masked_log_probs = log_probs.unsqueeze(
            2) + (mask.float() + 1e-45).log()
        logger.debug('Masked log probs shape: %s', masked_log_probs.shape)

        # Lastly, we need to get rid of the num_candidates dimension. The easy way to do this would
        # be to marginalize it out. However, since our data is sparse (the last two dims are
        # essentially a delta function) this would add a lot of unnecessary terms to the computation graph.
        # To get around this we are going to try to use a gather.
        _, index = torch.max(mask, dim=-1, keepdim=True)
        target_log_probs = torch.gather(
            masked_log_probs, dim=-1, index=index).squeeze(-1)

        return target_log_probs
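The (mask.float() + 1e-45).log() trick above works because multiplying by a 0/1 mask in probability space is addition of its log in log space; the 1e-45 merely keeps log(0) finite. The identity in isolation:

import torch

probs = torch.tensor([0.2, 0.3, 0.5])
mask = torch.tensor([1.0, 0.0, 1.0])
masked_log = probs.log() + (mask + 1e-45).log()
assert torch.allclose(masked_log.exp(), probs * mask)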
Example #19
    def score_spans_if_labels(
        self,
        output_dict,
        span_labels,
        metadata,
        top_span_indices,
        flat_top_span_indices,
        top_span_mask,
        top_spans,
        valid_antecedent_indices,
        valid_antecedent_log_mask,
        coreference_scores,
        predicted_antecedents,
    ):
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = (coreference_log_probs +
                                            gold_antecedent_labels.log())
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
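The marginal likelihood computed above is the total probability assigned to any valid antecedent, obtained stably as a logsumexp over the valid log-probabilities (adding the log of the 0/1 gold labels sends invalid entries to -inf). In miniature:

import torch

log_probs = torch.tensor([0.1, 0.2, 0.3, 0.4]).log()
gold_labels = torch.tensor([0.0, 1.0, 0.0, 1.0])  # two valid antecedents
marginal = torch.logsumexp(log_probs + gold_labels.log(), dim=-1)
print(marginal.exp())  # ~0.6 == 0.2 + 0.4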
Example #20
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                gt_span=None,
                mode=ForwardMode.TRAIN):
        sequence_output, _ = self.bert_encoder(input_ids,
                                               token_type_ids,
                                               attention_mask,
                                               output_all_encoded_layers=False)
        joint_length = allen_util.get_lengths_from_binary_sequence_mask(
            attention_mask)

        joint_seq_logits = self.qa_outputs(sequence_output)

        # The following line is from AllenNLP bidaf.
        start_logits = allen_util.replace_masked_values(
            joint_seq_logits[:, :, 0], attention_mask, -1e18)
        # B, T, 2
        end_logits = allen_util.replace_masked_values(
            joint_seq_logits[:, :, 1], attention_mask, -1e18)

        if mode == BertSpan.ForwardMode.TRAIN:
            assert gt_span is not None
            gt_start = gt_span[:, 0]  # gt_span: [B, 2] -> [B]
            gt_end = gt_span[:, 1]

            start_loss = nll_loss(
                allen_util.masked_log_softmax(start_logits, attention_mask),
                gt_start)
            end_loss = nll_loss(
                allen_util.masked_log_softmax(end_logits, attention_mask),
                gt_end)
            # We removed the squeeze because it breaks when the batch size is 1;
            # gt_start and gt_end should already have shape [B].
            # start_loss = nll_loss(allen_util.masked_log_softmax(start_logits, context_mask), gt_start.squeeze(-1))
            # end_loss = nll_loss(allen_util.masked_log_softmax(end_logits, context_mask), gt_end.squeeze(-1))

            loss = start_loss + end_loss
            return loss
        else:
            return start_logits, end_logits, joint_length
Example #21
    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)
        fill_value = -1e32
        output_projections = self._adjust_target_outputs(
            output_projections,
            state["allowed_token_ids"],
            state["allowed_mask"],
            fill_value=fill_value)

        # shape: (group_size, num_classes)
        mask = (output_projections != fill_value).float()
        class_log_probabilities = util.masked_log_softmax(output_projections,
                                                          mask,
                                                          dim=-1)
        # class_log_probabilities = F.log_softmax(output_projections, dim=-1)
        # class_log_probabilities = class_log_probabilities * mask

        return class_log_probabilities, state
Example #22
    def _arithmetic_module(self, arithmetic_passage_vector, passage_out, number_indices, number_mask):
        if self.number_rep in ['average', 'attention']:
            
            # Shape: (batch_size, # of numbers, # of pieces) 
            number_indices = util.replace_masked_values(number_indices, number_indices != -1, 0).long()
            batch_size = number_indices.shape[0]
            num_numbers = number_indices.shape[1]
            seqlen = passage_out.shape[1]

            # Shape : (batch_size, # of numbers, seqlen)
            mask = torch.zeros((batch_size, num_numbers, seqlen), device=number_indices.device).long().scatter(
                        2, 
                        number_indices, 
                        torch.ones(number_indices.shape, device=number_indices.device).long())
            mask[:,:,0] = 0

            # Shape : (batch_size, # of numbers, seqlen, bert_dim)
            epassage_out = passage_out.unsqueeze(1).repeat(1,num_numbers,1,1)

            # Shape : (batch_size, # of numbers, bert_dim)
            encoded_numbers = self.summary_vector(epassage_out, mask, "numbers")
        else:
            number_indices = number_indices[:,:,0].long()
            clamped_number_indices = util.replace_masked_values(number_indices, number_mask, 0)
            encoded_numbers = torch.gather(
                    passage_out,
                    1,
                    clamped_number_indices.unsqueeze(-1).expand(-1, -1, passage_out.size(-1)))
            
        if self.num_special_numbers > 0:
            special_numbers = self.special_embedding(torch.arange(self.num_special_numbers, device=number_indices.device))
            special_numbers = special_numbers.expand(number_indices.shape[0],-1,-1)
            encoded_numbers = torch.cat([special_numbers, encoded_numbers], 1)
            
            mask = torch.ones((number_indices.shape[0],self.num_special_numbers), device=number_indices.device).long()
            number_mask = torch.cat([mask, number_mask], -1)
            
        # Shape: (batch_size, # of numbers, 2*bert_dim)
        encoded_numbers = torch.cat(
                [encoded_numbers, arithmetic_passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)
        
        # Shape: (batch_size, #templates, #slots, #numbers)
        arithmetic_template_slot_logits = self._arithmetic_template_slot_predictor(encoded_numbers).transpose(1,2)
        arithmetic_template_slot_log_probs = util.masked_log_softmax(arithmetic_template_slot_logits, number_mask)
        arithmetic_template_slot_log_probs = arithmetic_template_slot_log_probs.reshape(number_mask.shape[0],
                                                                                        self.num_arithmetic_templates, 
                                                                                        self.num_template_slots, 
                                                                                        number_mask.shape[-1])
        # Shape: (batch_size, #templates, #slots)
        arithmetic_best_template_slots = arithmetic_template_slot_log_probs.argmax(-1)
        return arithmetic_template_slot_log_probs, arithmetic_best_template_slots, number_mask
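The scatter call above builds, per number, a 0/1 mask over the sequence marking its wordpieces; position 0 is then cleared because clamped padding indices point there. A reduced illustration with hypothetical indices:

import torch

batch_size, num_numbers, seqlen = 1, 2, 6
number_indices = torch.tensor([[[1, 2], [4, 0]]])  # (batch, numbers, pieces)
mask = torch.zeros(batch_size, num_numbers, seqlen, dtype=torch.long).scatter(
    2, number_indices, torch.ones_like(number_indices))
mask[:, :, 0] = 0  # clear position 0, where clamped padding indices land
print(mask)
# tensor([[[0, 1, 1, 0, 0, 0],
#          [0, 0, 0, 0, 1, 0]]])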
Example #23
    def loss(self, edge_label_logits: torch.Tensor, mask: torch.Tensor,
             head_tags: torch.Tensor) -> torch.Tensor:
        """
        Computes the edge label loss for a sequence given gold labels.

        Parameters
        ----------
        edge_label_logits : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            that contains raw predictions for incoming edge labels
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the edge label loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = edge_label_logits.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(edge_label_logits)).unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        normalised_edge_label_logits = masked_log_softmax(
            edge_label_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(edge_label_logits))
        child_index = timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        tag_loss = normalised_edge_label_logits[range_vector, child_index,
                                                head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        if self.normalize_wrt_seq_len:
            return -tag_loss.sum() / valid_positions.float()
        else:
            return -tag_loss.sum()
Example #24
    def _construct_loss(
            self, head_tag_representation: torch.Tensor,
            child_tag_representation: torch.Tensor,
            attended_arcs: torch.Tensor, head_indices: torch.Tensor,
            head_tags: torch.Tensor,
            mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        float_mask = mask.float()

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(
            attended_arcs,
            mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              head_indices)
        normalised_head_tag_logits = torch.nn.functional.log_softmax(
            head_tag_logits, dim=-1) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index,
                                         head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index,
                                              head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll
Example #25
    def forward(self, frames: torch.FloatTensor, frame_lens: torch.LongTensor):
        """
        frames: (batch_size, seq_len, num_lmks, lmk_dim)
        frame_lens: (batch_size, )
        """
        if self.frame_processing == 'flatten':
            frames = frames.reshape(frames.shape[0], frames.shape[1], -1)

        # Reverse sorts the batch by unpadded seq_len.
        (sorted_frames, sorted_frame_lens, restoration_indices,
         _) = sort_batch_by_length(frames, frame_lens)

        # Returns a PackedSequence.
        packed_frames = nn.utils.rnn.pack_padded_sequence(
            sorted_frames,
            sorted_frame_lens.data.cpu().numpy()
            if sorted_frame_lens.is_cuda else sorted_frame_lens.data.numpy(),
            batch_first=True)

        # Encoder: feed frames to the model, output hidden states.
        # final_state: (num_layers * num_dir, batch_size, hidden_size) (*2 if LSTM)
        packed_hidden_states, final_state = self.rnn(packed_frames)

        # Unpack encoding, the hidden states, a Tensor.
        # (batch_size, seq_len, num_dir * hidden_size)
        hidden_states, _ = nn.utils.rnn.pad_packed_sequence(
            packed_hidden_states, batch_first=True)

        # (num_layers, batch_size, hidden_size * num_dir) (*2 if LSTM)
        if self.bidirectional:
            final_state = self._cat_directions(final_state)

        hidden_states = hidden_states.index_select(0, restoration_indices)
        if isinstance(final_state, tuple):  # LSTM
            final_state = (final_state[0].index_select(1, restoration_indices),
                           final_state[1].index_select(1, restoration_indices))
        else:
            final_state = final_state.index_select(1, restoration_indices)

        if self.enable_ctc:
            output_logits = self.output_proj(hidden_states)
            output_log_probs = masked_log_softmax(output_logits,
                                                  self.output_mask.expand(
                                                      output_logits.shape[0],
                                                      self.adj_vocab_size),
                                                  dim=-1)
            return output_log_probs, hidden_states, final_state
        else:
            return hidden_states, final_state
Example #26
    def _compute_loss(self, logits: torch.Tensor, mask: torch.Tensor,
                      targets: torch.Tensor) -> torch.Tensor:
        """

        :param logits:
        :param mask:
        :param targets:
        :return:
        """

        batch_size, seq_len = mask.shape
        normalised_emission = masked_log_softmax(logits,
                                                 mask.unsqueeze(-1),
                                                 dim=-1)
        loss = normalised_emission.gather(dim=-1, index=targets.unsqueeze(-1))
        return -1 * loss.sum() / batch_size
Example #27
    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes + max_input_sequence_length)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        source_mask = state['source_mask']
        group_size = source_mask.size(0)

        # (batch_size, num_classes + max_input_sequence_length)
        normalization_mask = torch.cat([source_mask.new_ones((group_size, self._num_classes)),
                                        source_mask], dim=-1)

        # shape: (group_size, num_classes + max_input_sequence_length)
        class_log_probabilities = util.masked_log_softmax(output_projections, normalization_mask, dim=-1)

        return class_log_probabilities, state
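Here the log-softmax is taken jointly over the generation vocabulary and the source positions (a copy mechanism), so the normalization mask admits every vocabulary entry while excluding padded source tokens:

import torch

group_size, num_classes = 2, 5
source_mask = torch.tensor([[1, 1, 0],
                            [1, 0, 0]])
normalization_mask = torch.cat(
    [source_mask.new_ones((group_size, num_classes)), source_mask], dim=-1)
print(normalization_mask.shape)  # torch.Size([2, 8])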
Example #28
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, answer_choice=None,
                sentence_span_list=None, sentence_ids=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        batch, max_sen, doc_len = doc_sen_mask.size()

        que_vec = layers.weighted_avg(que, self.que_self_attn(que, que_mask)).view(batch, 1, -1)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)

        doc_vecs = layers.weighted_avg(doc, self.doc_sen_self_attn(doc, doc_mask)).view(batch, max_sen, -1)

        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        sentence_hidden = masked_softmax(sentence_sim, 1 - sentence_mask).bmm(doc_vecs).squeeze(1)

        yesno_logits = self.yesno_predictor(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))

        sentence_scores = masked_softmax(sentence_sim, 1 - sentence_mask, dim=-1).squeeze_(1)
        output_dict = {'yesno_logits': yesno_logits,
                       'sentence_logits': sentence_scores,
                       'max_weight_index': sentence_scores.max(dim=1)[1],
                       'max_weight': sentence_scores.max(dim=1)[0]}
        loss = 0
        if answer_choice is not None:
            choice_loss = F.cross_entropy(yesno_logits, answer_choice, ignore_index=-1)
            loss += choice_loss
        if sentence_ids is not None:
            log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1), 1 - sentence_mask, dim=-1)
            sentence_loss = self.evidence_lam * F.nll_loss(log_sentence_sim, sentence_ids, ignore_index=-1)
            loss += sentence_loss
        output_dict['loss'] = loss
        return output_dict
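Note that this example stores masks with 1 for *masked* (padded) positions, which is why 1 - sentence_mask is passed to masked_softmax and masked_log_softmax, both of which expect 1 for positions to keep. A small, self-contained sketch of that convention flip, with a hand-rolled masked_softmax standing in for the library helper and made-up shapes:

import torch
import torch.nn.functional as F

def masked_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Expects 1 for positions to keep, 0 for positions to ignore.
    return F.softmax(logits + (mask.float() + 1e-45).log(), dim=dim)

sentence_sim = torch.randn(1, 1, 3)                  # (batch, 1, max_sen)
padding_mask = torch.tensor([[0, 0, 1]])             # this model's convention: 1 marks padding
keep_mask = (1 - padding_mask).unsqueeze(1)          # flip to the convention masked_softmax expects
weights = masked_softmax(sentence_sim, keep_mask, dim=-1)
print(weights)                                       # the padded sentence gets ~0 attention weight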
Beispiel #29
0
 def _get_next_state_info_without_agenda(state: NlvrDecoderState,
                                         considered_actions: List[List[int]],
                                         action_logits: torch.Tensor,
                                         action_mask: torch.Tensor) -> List[List[Tuple[int,
                                                                                       torch.LongTensor]]]:
     """
     We return a list of log probabilities corresponding to actions that are not padding. This
     method applies to the training scenario where we have target action sequences.
     """
     considered_action_logprobs = nn_util.masked_log_softmax(action_logits, action_mask)
     all_action_logprobs: List[List[Tuple[int, torch.LongTensor]]] = []
     for group_index, (score, considered_logprobs) in enumerate(zip(state.score,
                                                                    considered_action_logprobs)):
         instance_action_logprobs: List[Tuple[int, torch.Tensor]] = []
         for action_index, logprob in enumerate(considered_logprobs):
             # This is the actual index of the action from the original list of actions.
             action = considered_actions[group_index][action_index]
             if action == -1:
                 # Ignoring padding.
                 continue
             instance_action_logprobs.append((action_index, score + logprob))
         all_action_logprobs.append(instance_action_logprobs)
     return all_action_logprobs
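The loop above simply drops padded actions (index -1) and attaches the running beam score to each surviving log probability. A toy, dependency-free sketch of the same bookkeeping (the action lists, scores, and logits are invented for illustration):

import torch

considered_actions = [[3, 7, -1],        # -1 is padding
                      [2, -1, -1]]
considered_action_logprobs = torch.log_softmax(torch.randn(2, 3), dim=-1)
scores = [torch.tensor(-0.5), torch.tensor(-1.2)]    # running scores of the two group states

all_action_logprobs = []
for group_index, (score, logprobs) in enumerate(zip(scores, considered_action_logprobs)):
    instance = [(action_index, score + logprob)
                for action_index, logprob in enumerate(logprobs)
                if considered_actions[group_index][action_index] != -1]
    all_action_logprobs.append(instance)
print(all_action_logprobs)    # only non-padding actions survive, each paired with score + logprob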
Beispiel #30
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        # encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))

        # # v5:
        # # remember to set token embeddings in the CONFIG JSON
        encoded_question = self._dropout(embedded_question)

        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length) -- SIMILARITY MATRIX
        similarity_matrix = self._matrix_attention(encoded_passage,
                                                   encoded_question)

        # Shape: (batch_size, passage_length, question_length) -- CONTEXT2QUERY
        passage_question_attention = util.last_dim_softmax(
            similarity_matrix, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Our custom query2context
        q2c_attention = util.masked_softmax(similarity_matrix,
                                            question_mask,
                                            dim=1).transpose(-1, -2)
        q2c_vecs = util.weighted_sum(encoded_passage, q2c_attention)

        # Now we try the various variants
        # v1:
        # tiled_question_passage_vector = util.weighted_sum(q2c_vecs, passage_question_attention)

        # v2:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], encoded_passage.shape[1]))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).transpose(-1, -2)

        # v3:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], 1))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).squeeze().unsqueeze(1).expand(batch_size, passage_length, encoding_dim)

        # v4:
        # Re-application of query2context attention
        # new_similarity_matrix = self._matrix_attention(encoded_passage, q2c_vecs)
        # masked_similarity = util.replace_masked_values(new_similarity_matrix,
        #                                                question_mask.unsqueeze(1),
        #                                                -1e7)
        # # Shape: (batch_size, passage_length)
        # question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # # Shape: (batch_size, passage_length)
        # question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # # Shape: (batch_size, encoding_dim)
        # question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # # Shape: (batch_size, passage_length, encoding_dim)
        # tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
        #                                                                             passage_length,
        #                                                                             encoding_dim)

        # ------- Original variant
        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            similarity_matrix, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # ------- END

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        # original beta combination function
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        # # v6:
        # final_merged_passage = torch.cat([tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v7:
        # final_merged_passage = torch.cat([passage_question_vectors],
        #                                  dim=-1)
        #
        # # v8:
        # final_merged_passage = torch.cat([passage_question_vectors,
        #                                   tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v9:
        # final_merged_passage = torch.cat([encoded_passage,
        #                                   passage_question_vectors,
        #                                   encoded_passage * passage_question_vectors],
        #                                  dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
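Several of the commented-out variants above revolve around query-to-context attention. For reference, here is a compact, plain-PyTorch sketch of the "Original variant" computed at the end: mask padded question tokens, take a max over the question dimension, softmax over the passage, and tile the resulting vector. Shapes and the local masked_softmax helper are assumptions for illustration only.

import torch

def masked_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    return torch.nn.functional.softmax(logits + (mask.float() + 1e-45).log(), dim=dim)

batch_size, passage_length, question_length, encoding_dim = 2, 6, 4, 8
encoded_passage = torch.randn(batch_size, passage_length, encoding_dim)
similarity_matrix = torch.randn(batch_size, passage_length, question_length)
question_mask = torch.tensor([[1., 1., 1., 0.],
                              [1., 1., 0., 0.]])
passage_mask = torch.ones(batch_size, passage_length)

# Padded question tokens must not win the max, so push them to a very negative value.
masked_similarity = similarity_matrix.masked_fill(question_mask.unsqueeze(1) == 0, -1e7)
# Shape: (batch_size, passage_length) -- each passage token's best match in the question.
question_passage_similarity = masked_similarity.max(dim=-1)[0]
question_passage_attention = masked_softmax(question_passage_similarity, passage_mask)
# Shape: (batch_size, encoding_dim) -- attention-weighted sum over passage tokens.
question_passage_vector = torch.bmm(question_passage_attention.unsqueeze(1), encoded_passage).squeeze(1)
# Shape: (batch_size, passage_length, encoding_dim) -- tiled so it concatenates with per-token features.
tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size, passage_length, encoding_dim)
print(tiled_question_passage_vector.shape)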
Beispiel #31
0
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
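The coreference loss above marginalises over all antecedents in the same gold cluster. A minimal numerical sketch of that computation, with hand-picked scores and gold labels (the local masked_log_softmax mimics the behaviour of the AllenNLP helper but is not the library code):

import torch

def masked_log_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    return torch.nn.functional.log_softmax(logits + (mask.float() + 1e-45).log(), dim=dim)

# (batch=1, num_spans_to_keep=2, 1 + max_antecedents=3); column 0 is the "no antecedent" dummy.
coreference_scores = torch.randn(1, 2, 3)
top_span_mask = torch.ones(1, 2, 1)
# 1 wherever the candidate antecedent (or the dummy) is consistent with the gold clustering.
gold_antecedent_labels = torch.tensor([[[1., 0., 0.],
                                         [0., 1., 1.]]])

coreference_log_probs = masked_log_softmax(coreference_scores, top_span_mask)
# log(0) = -inf knocks inconsistent antecedents out of the logsumexp below.
correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
negative_marginal_log_likelihood = -torch.logsumexp(correct_antecedent_log_probs, dim=-1).sum()
print(negative_marginal_log_likelihood)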
Beispiel #32
0
    def _construct_loss(self,
                        head_tag_representation: torch.Tensor,
                        child_tag_representation: torch.Tensor,
                        attended_arcs: torch.Tensor,
                        head_indices: torch.Tensor,
                        head_tags: torch.Tensor,
                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(attended_arcs,
                                                   mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
        normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
                                                        mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll
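The indexing in _construct_loss can be easier to follow on a small standalone example: log-normalise the arc and tag scores under the mask, pick out the gold head and gold tag at every position with advanced indexing, drop the ROOT position, and average. The sketch below uses random tensors, and its local masked_log_softmax broadcasts a lower-rank mask the way the library helper does; everything else is an illustrative assumption.

import torch

def masked_log_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    mask = mask.float()
    while mask.dim() < logits.dim():
        mask = mask.unsqueeze(1)
    return torch.nn.functional.log_softmax(logits + (mask + 1e-45).log(), dim=dim)

batch_size, sequence_length, num_head_tags = 2, 5, 4
attended_arcs = torch.randn(batch_size, sequence_length, sequence_length)
head_tag_logits = torch.randn(batch_size, sequence_length, num_head_tags)
head_indices = torch.randint(0, sequence_length, (batch_size, sequence_length))
head_tags = torch.randint(0, num_head_tags, (batch_size, sequence_length))
mask = torch.ones(batch_size, sequence_length)

float_mask = mask.float()
normalised_arc_logits = masked_log_softmax(attended_arcs, mask) \
    * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)
normalised_head_tag_logits = masked_log_softmax(head_tag_logits, mask.unsqueeze(-1)) \
    * float_mask.unsqueeze(-1)

range_vector = torch.arange(batch_size).unsqueeze(1)          # shape (batch_size, 1)
child_index = torch.arange(sequence_length).unsqueeze(0).expand(batch_size, -1)
# Pick the log probability of the gold head / gold tag at every position, then drop ROOT.
arc_loss = normalised_arc_logits[range_vector, child_index, head_indices][:, 1:]
tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags][:, 1:]

valid_positions = mask.sum() - batch_size
print(-arc_loss.sum() / valid_positions, -tag_loss.sum() / valid_positions)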
Beispiel #33
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
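get_best_span performs constrained inference: it picks the (start, end) pair with the highest summed logits subject to start <= end. Below is a short dense-matrix sketch of that search; it is not necessarily the library implementation, but it returns the same constrained argmax for modest passage lengths.

import torch

def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
    batch_size, passage_length = span_start_logits.size()
    # Shape: (batch_size, passage_length, passage_length) -- score of every (start, end) pair.
    span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Only the upper triangle (end >= start) is a valid span.
    valid = torch.triu(torch.ones(passage_length, passage_length, dtype=torch.bool,
                                  device=span_scores.device))
    span_scores = span_scores.masked_fill(~valid, float("-inf"))
    best = span_scores.view(batch_size, -1).argmax(dim=-1)
    starts = torch.div(best, passage_length, rounding_mode="floor")
    ends = best % passage_length
    return torch.stack([starts, ends], dim=-1)

start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.0]])
end_logits = torch.tensor([[0.0, 0.2, 1.5, 0.1]])
print(get_best_span(start_logits, end_logits))    # tensor([[1, 2]])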
best_span = BidirectionalAttentionFlow_1.get_best_span(span_start_logits, span_end_logits)

print ("best_spans", best_span)

"""
------------------------------ GET LOSSES AND ACCURACIES -----------------------------------
"""
span_start_accuracy_function = CategoricalAccuracy()
span_end_accuracy_function = CategoricalAccuracy()
span_accuracy_function = BooleanAccuracy()
squad_metrics_function = SquadEmAndF1()

# Compute the loss for training.
if span_start is not None:
    span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
    span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
    loss = span_start_loss + span_end_loss
    
    span_start_accuracy_function(span_start_logits, span_start.squeeze(-1))
    span_end_accuracy_function(span_end_logits, span_end.squeeze(-1))
    span_accuracy_function(best_span, torch.stack([span_start, span_end], -1))

    span_start_accuracy = span_start_accuracy_function.get_metric()
    span_end_accuracy = span_end_accuracy_function.get_metric()
    span_accuracy = span_accuracy_function.get_metric()


    print ("Loss: ", loss)
    print ("span_start_accuracy: ", span_start_accuracy)
    print ("span_start_accuracy: ", span_start_accuracy)
    def _construct_loss(self,
                        head_tag_representation: torch.Tensor,
                        child_tag_representation: torch.Tensor,
                        attended_arcs: torch.Tensor,
                        head_indices: torch.Tensor,
                        head_tags: torch.Tensor,
                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(attended_arcs,
                                                   mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
        normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
                                                        mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll
    def _take_first_step(self,
                         state: WikiTablesDecoderState,
                         allowed_actions: List[Set[int]] = None) -> List[WikiTablesDecoderState]:
        # We'll just do a projection from the current hidden state (which was initialized with the
        # final encoder output) to the number of start actions that we have, normalize those
        # logits, and use that as our score.  We end up duplicating some of the logic from
        # `_compute_new_states` here, but we do things slightly differently, and it's easier to
        # just copy the parts we need than to try to re-use that code.

        # (group_size, hidden_dim)
        hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])
        # (group_size, num_start_type)
        start_action_logits = self._start_type_predictor(hidden_state)
        log_probs = util.masked_log_softmax(start_action_logits, None)
        sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)

        sorted_actions = sorted_actions.detach().cpu().numpy().tolist()
        if state.debug_info is not None:
            probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()

        # state.get_valid_actions() will return a list that is consistently sorted, so as long as
        # the set of valid start actions never changes, we can just match up the log prob indices
        # above with the position of each considered action, and we're good.
        considered_actions, _, _ = self._get_actions_to_consider(state)
        if len(considered_actions[0]) != self._num_start_types:
            raise RuntimeError("Calculated wrong number of initial actions.  Expected "
                               f"{self._num_start_types}, found {len(considered_actions[0])}.")

        best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
        for group_index, (batch_index, group_actions) in enumerate(zip(state.batch_indices, sorted_actions)):
            for action_index, action in enumerate(group_actions):
                # `action` is currently the index in `log_probs`, not the actual action ID.  To get
                # the action ID, we need to go through `considered_actions`.
                action = considered_actions[group_index][action]
                if allowed_actions is not None and action not in allowed_actions[group_index]:
                    # This happens when our _decoder trainer_ wants us to only evaluate certain
                    # actions, likely because they are the gold actions in this state.  We just skip
                    # emitting any state that isn't allowed by the trainer, because constructing the
                    # new state can be expensive.
                    continue
                best_next_states[batch_index].append((group_index, action_index, action))

        new_states = []
        for batch_index, best_states in sorted(best_next_states.items()):
            for group_index, action_index, action in best_states:
                # We'll yield a bunch of states here that all have a `group_size` of 1, so that the
                # learning algorithm can decide how many of these it wants to keep, and it can just
                # regroup them later, as that's a really easy operation.
                batch_index = state.batch_indices[group_index]
                new_action_history = state.action_history[group_index] + [action]
                new_score = state.score[group_index] + sorted_log_probs[group_index, action_index]

                production_rule = state.possible_actions[batch_index][action][0]
                new_grammar_state = state.grammar_state[group_index].take_action(production_rule)
                new_checklist_state = [state.checklist_state[group_index]]
                if state.debug_info is not None:
                    debug_info = {
                            'considered_actions': considered_actions[group_index],
                            'probabilities': probs_cpu[group_index],
                            }
                    new_debug_info = [state.debug_info[group_index] + [debug_info]]
                else:
                    new_debug_info = None

                # This part is different from `_compute_new_states` - we're just passing through
                # the previous RNN state, as predicting the start type wasn't included in the
                # decoder RNN in the original model.
                new_rnn_state = state.rnn_state[group_index]
                new_state = WikiTablesDecoderState(batch_indices=[batch_index],
                                                   action_history=[new_action_history],
                                                   score=[new_score],
                                                   rnn_state=[new_rnn_state],
                                                   grammar_state=[new_grammar_state],
                                                   action_embeddings=state.action_embeddings,
                                                   output_action_embeddings=state.output_action_embeddings,
                                                   action_biases=state.action_biases,
                                                   action_indices=state.action_indices,
                                                   possible_actions=state.possible_actions,
                                                   flattened_linking_scores=state.flattened_linking_scores,
                                                   actions_to_entities=state.actions_to_entities,
                                                   entity_types=state.entity_types,
                                                   world=state.world,
                                                   example_lisp_string=state.example_lisp_string,
                                                   checklist_state=new_checklist_state,
                                                   debug_info=new_debug_info)
                new_states.append(new_state)
        return new_states
    def take_step(self,
                  state: WikiTablesDecoderState,
                  max_actions: int = None,
                  allowed_actions: List[Set[int]] = None) -> List[WikiTablesDecoderState]:
        if not state.action_history[0]:
            # The wikitables parser did something different when predicting the start type, which
            # is our first action.  So in this case we break out into a different function.  We'll
            # ignore max_actions on our first step, assuming there aren't that many start types.
            return self._take_first_step(state, allowed_actions)
        # Outline here: first we'll construct the input to the decoder, which is a concatenation of
        # an embedding of the decoder input (the last action taken) and an attention over the
        # question.  Then we'll update our decoder's hidden state given this input, and recompute an
        # attention over the question given our new hidden state.  We'll use a concatenation of the
        # new hidden state and the new attention, and optionally the checklist balance, to predict an
        # output, then yield new states.  Each new state corresponds to one valid action that can be
        # taken from the current state, and they are ordered by their probability of being selected.
        attended_question = torch.stack([rnn_state.attended_input for rnn_state in state.rnn_state])
        hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])
        memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state])
        previous_action_embedding = torch.stack([rnn_state.previous_action_embedding
                                                 for rnn_state in state.rnn_state])

        # The scores from all prior state transitions until now.  Shape: (group_size, 1).
        scores_so_far = torch.stack(state.score).unsqueeze(-1)

        # (group_size, decoder_input_dim)
        projected_input = self._input_projection_layer(torch.cat([attended_question,
                                                                  previous_action_embedding], -1))
        decoder_input = torch.nn.functional.relu(projected_input)

        hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell))
        hidden_state = self._dropout(hidden_state)

        # (group_size, encoder_output_dim)
        encoder_outputs = torch.stack([state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices])
        encoder_output_mask = torch.stack([state.rnn_state[0].encoder_output_mask[i] for i in state.batch_indices])
        attended_question, attention_weights = self.attend_on_question(hidden_state,
                                                                       encoder_outputs,
                                                                       encoder_output_mask)
        action_query = torch.cat([hidden_state, attended_question], dim=-1)
        considered_actions, actions_to_embed, actions_to_link = self._get_actions_to_consider(state)
        # action_embeddings: (group_size, num_embedded_actions, action_embedding_dim)
        # output_action_embeddings: (group_size, num_embedded_actions, action_embedding_dim)
        # action_mask: (group_size, num_embedded_actions)
        action_embeddings, output_action_embeddings, action_biases, embedded_action_mask = \
                self._get_action_embeddings(state, actions_to_embed)
        # (group_size, action_embedding_dim)
        projected_query = torch.nn.functional.relu(self._output_projection_layer(action_query))
        predicted_action_embedding = self._dropout(projected_query)
        linked_balance = None
        if state.checklist_state[0] is not None:
            linked_balance, unlinked_balance = self._get_checklist_balance(state,
                                                                           self._unlinked_terminal_indices,
                                                                           actions_to_link)
            embedding_addition = self._get_predicted_embedding_addition(state,
                                                                        self._unlinked_terminal_indices,
                                                                        unlinked_balance)
            addition = embedding_addition * self._unlinked_checklist_multiplier
            predicted_action_embedding = predicted_action_embedding + addition

        # We'll do a batch dot product here with `bmm`.  We want `dot(predicted_action_embedding,
        # action_embedding)` for each `action_embedding`, and we can get that efficiently with
        # `bmm` and some squeezing.
        # Shape: (group_size, num_embedded_actions)
        embedded_action_logits = action_embeddings.bmm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
        embedded_action_logits = embedded_action_logits + action_biases.squeeze(-1)
        if actions_to_link:
            # entity_action_logits: (group_size, num_entity_actions)
            # entity_action_mask: (group_size, num_entity_actions)
            entity_action_logits, entity_action_mask, entity_type_embeddings = \
                    self._get_entity_action_logits(state,
                                                   actions_to_link,
                                                   attention_weights,
                                                   linked_balance)

            # The `output_action_embeddings` tensor gets used later as the input to the next
            # decoder step.  For linked actions, we don't have any action embedding, so we use the
            # entity type instead.
            output_action_embeddings = torch.cat([output_action_embeddings, entity_type_embeddings], dim=1)

            if self._mixture_feedforward is not None:
                # The entity and embedded action logits are normalised separately and then
                # combined with a mixture weight, so that the entity_action_logits cannot dominate
                # the embedded_action_logits the way they could if a single softmax were applied
                # over both together.
                mixture_weight = self._mixture_feedforward(hidden_state)
                mix1 = torch.log(mixture_weight)
                mix2 = torch.log(1 - mixture_weight)

                entity_action_probs = util.masked_log_softmax(entity_action_logits,
                                                              entity_action_mask.float()) + mix1
                embedded_action_probs = util.masked_log_softmax(embedded_action_logits,
                                                                embedded_action_mask.float()) + mix2
                current_log_probs = torch.cat([embedded_action_probs, entity_action_probs], dim=1)
            else:
                action_logits = torch.cat([embedded_action_logits, entity_action_logits], dim=1)
                action_mask = torch.cat([embedded_action_mask, entity_action_mask], dim=1).float()
                current_log_probs = util.masked_log_softmax(action_logits, action_mask)
        else:
            action_logits = embedded_action_logits
            action_mask = embedded_action_mask.float()
            current_log_probs = util.masked_log_softmax(action_logits, action_mask)

        # current_log_probs is shape (group_size, num_actions).  We're broadcasting an addition
        # here with scores_so_far, which has shape (group_size, 1).  This is now the total score
        # for each state after taking each action.  We're going to sort by this in
        # `_compute_new_states`, so it's important that this is the total score, not just the score
        # for the current action.
        log_probs = scores_so_far + current_log_probs

        return self._compute_new_states(state,
                                        log_probs,
                                        hidden_state,
                                        memory_cell,
                                        output_action_embeddings,
                                        attended_question,
                                        attention_weights,
                                        considered_actions,
                                        allowed_actions,
                                        max_actions)
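The `bmm` scoring trick used in the decoder step above computes one dot product per candidate action without an explicit Python loop. A minimal sketch with made-up sizes (independent of the decoder state objects in the example):

import torch

# Hypothetical sizes, for illustration only.
group_size, num_actions, action_dim = 4, 7, 16

action_embeddings = torch.randn(group_size, num_actions, action_dim)
predicted_action_embedding = torch.randn(group_size, action_dim)

# (group, num_actions, dim) bmm (group, dim, 1) -> (group, num_actions, 1) -> (group, num_actions)
logits = action_embeddings.bmm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)

# The same scores computed with an explicit loop, kept only as a sanity check.
expected = torch.stack(
    [action_embeddings[g] @ predicted_action_embedding[g] for g in range(group_size)]
)
assert torch.allclose(logits, expected, atol=1e-6)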
    def forward(
        self,  # type: ignore
        text: TextFieldTensors,
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        text : `TextFieldTensors`, required.
            The output of a `TextField` representing the text of
            the document.
        spans : `torch.IntTensor`, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
            indices into the text of the document.
        span_labels : `torch.IntTensor`, optional (default = None).
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : `List[Dict[str, Any]]`, optional (default = None).
            A metadata dictionary for each instance in the batch. We use the "original_text" and "clusters" keys
            from this dictionary, which respectively have the original text and the annotated gold coreference
            clusters for that instance.

        # Returns

        An output dictionary consisting of:
        top_spans : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : `torch.IntTensor`
            A tensor of shape `(num_spans_to_keep, max_antecedents)` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        batch_size = spans.size(0)
        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text)

        # Shape: (batch_size, num_spans)
        span_mask = spans[:, :, 0] >= 0
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be non-negative, so the relu below clamps the -1
        # padding values to 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, num_spans)

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)
        ).squeeze(-1)
        # Shape: (batch_size, num_spans_to_keep) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep
        )

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = util.batched_index_select(
            span_embeddings, top_span_indices, flat_top_span_indices
        )

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        # index to the indices of its allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        # we can use to make coreference decisions between valid span pairs.

        if self._coarse_to_fine:
            pruned_antecedents = self._coarse_to_fine_pruning(
                top_span_embeddings, top_span_mention_scores, top_span_mask, max_antecedents
            )
        else:
            pruned_antecedents = self._distance_pruning(
                top_span_embeddings, top_span_mention_scores, max_antecedents
            )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 4 tensors
        (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
            top_antecedent_indices,
        ) = pruned_antecedents

        flat_top_antecedent_indices = util.flatten_and_batch_shift_indices(
            top_antecedent_indices, num_spans_to_keep
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        top_antecedent_embeddings = util.batched_index_select(
            top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            top_span_embeddings,
            top_antecedent_embeddings,
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
        )

        for _ in range(self._inference_order - 1):
            dummy_mask = top_antecedent_mask.new_ones(batch_size, num_spans_to_keep, 1)
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents,)
            top_antecedent_with_dummy_mask = torch.cat([dummy_mask, top_antecedent_mask], -1)
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            attention_weight = util.masked_softmax(
                coreference_scores, top_antecedent_with_dummy_mask, memory_efficient=True
            )
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents, embedding_size)
            top_antecedent_with_dummy_embeddings = torch.cat(
                [top_span_embeddings.unsqueeze(2), top_antecedent_embeddings], 2
            )
            # Shape: (batch_size, num_spans_to_keep, embedding_size)
            attended_embeddings = util.weighted_sum(
                top_antecedent_with_dummy_embeddings, attention_weight
            )
            # Shape: (batch_size, num_spans_to_keep, embedding_size)
            top_span_embeddings = self._span_updating_gated_sum(
                top_span_embeddings, attended_embeddings
            )

            # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
            top_antecedent_embeddings = util.batched_index_select(
                top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
            )
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            coreference_scores = self._compute_coreference_scores(
                top_span_embeddings,
                top_antecedent_embeddings,
                top_partial_coreference_scores,
                top_antecedent_mask,
                top_antecedent_offsets,
            )

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": top_antecedent_indices,
            "predicted_antecedents": predicted_antecedents,
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
            )

            # Shape: (batch_size, num_spans_to_keep, max_antecedents)
            antecedent_labels = util.batched_index_select(
                pruned_gold_labels, top_antecedent_indices, flat_top_antecedent_indices
            ).squeeze(-1)
            antecedent_labels = util.replace_masked_values(
                antecedent_labels, top_antecedent_mask, -100
            )

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels
            )
            # Now, compute the loss using the negative marginal log-likelihood.
            # For each span we marginalise over all antecedents that are consistent with the gold
            # clustering: each span i predicts a single antecedent j, but there may be several
            # prior mentions k in the same gold coreference cluster that would be equally valid
            # antecedents. The loss for a span is therefore the negative log of the total
            # probability assigned to all of its valid antecedents. This is a valid objective for
            # clustering, as we don't mind which antecedent is predicted, so long as it is in the
            # same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask.unsqueeze(-1)
            )
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(
                top_spans, top_antecedent_indices, predicted_antecedents, metadata
            )

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
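The negative marginal log-likelihood described in the comments above can be reproduced on toy numbers. The sketch below uses plain `log_softmax`/`logsumexp` in place of the masked AllenNLP `util` helpers, with an invented 2-span, 2-antecedent score tensor and gold-label matrix:

import torch

# Toy scores: 1 document, 2 kept spans, a leading "no antecedent" dummy column
# followed by 2 candidate antecedents per span.
coreference_scores = torch.tensor([[[0.2, 1.5, -0.3],
                                    [1.0, 0.1, 2.2]]])
# 1.0 wherever the (dummy or real) antecedent is consistent with the gold clusters.
gold_antecedent_labels = torch.tensor([[[1.0, 0.0, 0.0],
                                        [0.0, 1.0, 1.0]]])

log_probs = torch.log_softmax(coreference_scores, dim=-1)
# log(0) = -inf knocks out the inconsistent antecedents, so logsumexp marginalises
# over the consistent ones only.
correct_log_probs = log_probs + gold_antecedent_labels.log()
negative_marginal_log_likelihood = -torch.logsumexp(correct_log_probs, dim=-1).sum()
print(negative_marginal_log_likelihood)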
Example #39
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens that belong to the
            previous answer in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, see dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two dialog turns back in the passage.
        p3_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three dialog turns back in the passage.
        yesno_list : ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (yes / no / not a yes-no question).
        followup_list : ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (follow up / maybe follow up / don't follow up).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following entries. Each entry is a nested list:
        the outer list iterates over the dialogs in the batch, and the inner list over the
        questions within each dialog.

        qid : List[List[str]]
            A list of lists of question ids.
        followup : List[List[int]]
            A list of lists of follow-up marker prediction indices
            (y: yes, m: maybe follow up, n: don't follow up).
        yesno : List[List[int]]
            A list of lists of affirmation marker prediction indices
            (y: yes, x: not a yes/no question, n: no).
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
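        # followup_list is padded with -1 for missing QA pairs (the losses below use
        # ignore_index=-1), so this mask keeps only the real question-answer pairs.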
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # Shape: (batch_size * max_qa_count, passage_length, word_embed_dim)
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)
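        # The block below builds a pairwise passage mask for each QA pair and zeroes out the
        # diagonal, so that a token cannot attend to itself in the self-attention.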

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # Shape: (batch_size * max_qa_count, passage_length, 200)
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # Shape: (batch_size * max_qa_count, passage_length)
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
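            # span_yesno_logits / span_followup_logits have shape
            # (total_qa_count, passage_length, 3); once flattened with .view(-1), the entry for
            # QA pair i, gold end position span_end[i] and class k sits at flat index
            # i * passage_length * 3 + span_end[i] * 3 + k, which is what these offsets encode.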
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and prepare the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 against each subset of N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict
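The leave-one-out F1 averaging in the final loop above can be checked in isolation. The sketch below stubs `squad_eval.f1_score` / `metric_max_over_ground_truths` with a simplified bag-of-words F1 (no SQuAD-style text normalisation), purely to show the N-1 reference averaging:

from typing import List

def f1(prediction: str, reference: str) -> float:
    # Simplified stand-in for squad_eval.f1_score: bag-of-words overlap F1.
    pred, ref = set(prediction.lower().split()), set(reference.lower().split())
    if not pred or not ref:
        return float(pred == ref)
    overlap = len(pred & ref)
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

def max_over_refs(prediction: str, references: List[str]) -> float:
    # Stand-in for squad_eval.metric_max_over_ground_truths.
    return max(f1(prediction, ref) for ref in references)

def leave_one_out_f1(prediction: str, answer_texts: List[str]) -> float:
    # With a single reference, score directly; with several, hold each one out in
    # turn, score against the remaining N-1 references, and average the scores,
    # mirroring the per-question loop in the forward() above.
    if len(answer_texts) == 1:
        return max_over_refs(prediction, answer_texts)
    scores = []
    for held_out in range(len(answer_texts)):
        refs = [a for z, a in enumerate(answer_texts) if z != held_out]
        scores.append(max_over_refs(prediction, refs))
    return sum(scores) / len(scores)

print(leave_one_out_f1("in the park", ["in the park", "at the park", "the park"]))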