Example #1
    def forward(
            self,  # type: ignore
            spans: torch.IntTensor,
            span_mask: torch.IntTensor,
            span_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            ner_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        """

        # Shape: (Batch size, Number of Spans, Span Embedding Size)
        # span_embeddings

        self._active_namespace = f"{metadata.dataset}__ner_labels"
        if self._active_namespace not in self._ner_scorers:
            return {"loss": 0}

        scorer = self._ner_scorers[self._active_namespace]

        ner_scores = scorer(span_embeddings)
        # Give large negative scores to masked-out elements.
        mask = span_mask.unsqueeze(-1)
        ner_scores = util.replace_masked_values(ner_scores, mask.bool(), -1e20)
        # The dummy_scores are the score for the null label.
        dummy_dims = [ner_scores.size(0), ner_scores.size(1), 1]
        dummy_scores = ner_scores.new_zeros(*dummy_dims)
        ner_scores = torch.cat((dummy_scores, ner_scores), -1)

        _, predicted_ner = ner_scores.max(2)

        predictions = self.predict(ner_scores.detach().cpu(),
                                   spans.detach().cpu(),
                                   span_mask.detach().cpu(), metadata)
        output_dict = {"predictions": predictions}

        if ner_labels is not None:
            metrics = self._ner_metrics[self._active_namespace]
            metrics(predicted_ner, ner_labels, span_mask)
            ner_scores_flat = ner_scores.view(
                -1, self._n_labels[self._active_namespace])
            ner_labels_flat = ner_labels.view(-1)
            mask_flat = span_mask.view(-1).bool()

            loss = self._loss(ner_scores_flat[mask_flat],
                              ner_labels_flat[mask_flat])

            output_dict["loss"] = loss

        return output_dict
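A minimal, self-contained sketch (toy tensors, not part of the project) of the null-label trick used above: masked spans get a large negative score for every real label, and a zero-score column is prepended so that index 0 acts as the "no entity" label.

import torch

ner_scores = torch.tensor([[[0.2, 1.5, -0.3],
                            [0.1, 0.0, 0.4]]])     # (batch, num_spans, num_labels)
span_mask = torch.tensor([[1, 0]])                 # the second span is padding

# Large negative scores for masked-out spans, as in replace_masked_values above.
ner_scores = ner_scores.masked_fill(~span_mask.bool().unsqueeze(-1), -1e20)

# Prepend a zero column: label index 0 now means "null / no entity".
dummy_scores = ner_scores.new_zeros(ner_scores.size(0), ner_scores.size(1), 1)
ner_scores = torch.cat((dummy_scores, ner_scores), dim=-1)

_, predicted_ner = ner_scores.max(-1)
print(predicted_ner)   # tensor([[2, 0]]): a real label for span 0, the null label for the padded span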
Example #2
    def forward(
            self,  # type: ignore
            text: Dict[str, Any],
            text_mask: torch.IntTensor,
            token_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            token_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        """

        seq_scores = self._seq_scorer(token_embeddings)
        # Give large negative scores to masked-out elements.
        mask = text_mask.unsqueeze(-1)
        seq_scores = util.replace_masked_values(seq_scores, mask, -1e20)
        seq_scores[:, :, 0] *= text_mask

        _, predicted_seq = seq_scores.max(2)

        if self._label_scheme == 'flat':
            pred_spans = self._seq_metrics._decode_flat(
                predicted_seq, text_mask)
        elif self._label_scheme == 'stacked':
            pred_spans = self._seq_metrics._decode_stacked(
                predicted_seq, text_mask)
        else:
            raise RuntimeError("invalid label_scheme {}".format(
                self._label_scheme))

        output_dict = {
            "predicted_seq": predicted_seq,
            "predicted_seq_span": pred_spans
        }

        if token_labels is not None:
            self._seq_metrics(predicted_seq, token_labels, text_mask,
                              self.training)
            seq_scores_flat = seq_scores.view(-1, self._n_labels)
            seq_labels_flat = token_labels.view(-1)
            mask_flat = text_mask.view(-1).bool()

            loss = self._loss(seq_scores_flat[mask_flat],
                              seq_labels_flat[mask_flat])
            output_dict["loss"] = loss

        return output_dict
Example #3
File: ner.py  Project: zhangqixun/dygiepp
    def forward(
            self,  # type: ignore
            spans: torch.IntTensor,
            span_mask: torch.IntTensor,
            span_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            ner_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        """

        # Shape: (Batch size, Number of Spans, Span Embedding Size)
        # span_embeddings
        ner_scores = self._ner_scorer(span_embeddings)
        # Give large negative scores to masked-out elements.
        mask = span_mask.unsqueeze(-1)
        ner_scores = util.replace_masked_values(ner_scores, mask, -1e20)
        # The dummy_scores are the score for the null label.
        dummy_dims = [ner_scores.size(0), ner_scores.size(1), 1]
        dummy_scores = ner_scores.new_zeros(*dummy_dims)
        ner_scores = torch.cat((dummy_scores, ner_scores), -1)

        _, predicted_ner = ner_scores.max(2)

        output_dict = {
            "spans": spans,
            "span_mask": span_mask,
            "ner_scores": ner_scores,
            "predicted_ner": predicted_ner
        }

        if ner_labels is not None:
            self._ner_metrics(predicted_ner, ner_labels, span_mask)
            ner_scores_flat = ner_scores.view(-1, self._n_labels)
            ner_labels_flat = ner_labels.view(-1)
            mask_flat = span_mask.view(-1).bool()

            loss = self._loss(ner_scores_flat[mask_flat],
                              ner_labels_flat[mask_flat])
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["document"] = [x["sentence"] for x in metadata]

        return output_dict
Example #4
    def forward(
        self,  # type: ignore
        tokens: Dict[str, torch.LongTensor] = None,
        label: torch.IntTensor = None,
        weight: torch.FloatTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        weight: torch.FloatTensor, optional (default = None)
            Weights to apply to each sample loss.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the text.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded = self._bert(tokens)
        first_token = embedded[:, 0, :]
        pooled_output = self._pooler(first_token)
        pooled_output = self._dropout(pooled_output)

        label_logits = self._classifier(pooled_output)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits.view(-1, self._num_labels),
                              label.view(-1))
            if self.weighted_training:
                loss = (loss * weight).mean()

            self._accuracy(label_logits, label)
            if 'f1' in self.metrics:
                self._f1(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
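A rough standalone sketch of the classification head used above: take the first ([CLS]) token, pool, apply dropout, classify, and softmax. Random tensors stand in for self._bert, and the module definitions here are assumptions, not the project's:

import torch
import torch.nn as nn

batch, seq_len, hidden, num_labels = 2, 8, 16, 3
embedded = torch.randn(batch, seq_len, hidden)      # stand-in for self._bert(tokens)

pooler = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh())
dropout = nn.Dropout(0.1)
classifier = nn.Linear(hidden, num_labels)

first_token = embedded[:, 0, :]                     # [CLS]-style representation
pooled_output = dropout(pooler(first_token))
label_logits = classifier(pooled_output)            # (batch, num_labels)
label_probs = torch.softmax(label_logits, dim=-1)

# Cross-entropy against gold labels, mirroring the `label is not None` branch.
loss = nn.CrossEntropyLoss()(label_logits, torch.tensor([0, 2]))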
Example #5
    def forward(
        self,  # type: ignore
        tokens: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the text.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded = self._bert(tokens)

        first_token = embedded[:, 0, :]
        pooled_first = self._pooler(first_token)
        pooled_first = self._dropout(pooled_first)

        mask = tokens['mask'].float()
        encoded = self._encoder(embedded, mask)
        encoded = self._dropout(encoded)
        pooled_encoded = masked_max(encoded, mask.unsqueeze(-1), dim=1)

        concat = torch.cat([pooled_first, pooled_encoded], dim=-1)
        label_logits = self._classifier(concat)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits.view(-1, self._num_labels),
                              label.view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Example #6
    def forward(
            self,  # type: ignore
            spans: torch.IntTensor,
            span_mask: torch.IntTensor,
            span_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        """

        # Shape: (Batch size, Number of Spans, Span Embedding Size)
        # span_embeddings
        span_scores = self._span_scorer(span_embeddings)
        # Give large negative scores to masked-out elements.
        mask = span_mask.unsqueeze(-1)
        span_scores = util.replace_masked_values(span_scores, mask, -1e20)
        span_scores[:, :, 0] *= span_mask

        _, predicted_span = span_scores.max(2)

        output_dict = {
            "spans": spans,
            "span_mask": span_mask,
            "span_scores": span_scores,
            "predicted_span": predicted_span
        }

        if span_labels is not None:
            self._span_metrics(predicted_span, span_labels, span_mask)
            span_scores_flat = span_scores.view(-1, self._n_labels)
            span_labels_flat = span_labels.view(-1)
            mask_flat = span_mask.view(-1).bool()

            loss = self._loss(span_scores_flat[mask_flat],
                              span_labels_flat[mask_flat])
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["document"] = [x["sentence"] for x in metadata]

        return output_dict
Example #7
    def forward(
        self,  # type: ignore
        s1: Dict[str, torch.LongTensor] = None,
        s2: Dict[str, torch.LongTensor] = None,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the text.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_s1 = self._bert(s1)
        embedded_s2 = self._bert(s2)
        s1_cls = embedded_s1[:, 0, :]
        s2_cls = embedded_s2[:, 0, :]

        if self.dropout:
            s1_cls = self.dropout(s1_cls)
            s2_cls = self.dropout(s2_cls)

        label_logits = self._classifier(torch.cat([s1_cls, s2_cls], dim=-1))
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits.view(-1, self._num_labels),
                              label.view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Example #8
    def _compute_metrics_and_return_loss(self, logits: torch.Tensor,
                                         label: torch.IntTensor) -> float:
        """Helper function for the `self._make_forward_output` method."""
        for metric in self._metrics.get_dict(is_train=self.training).values():
            metric(logits, label)

        if self._multilabel:
            # casting long to float for BCELoss
            # see https://discuss.pytorch.org/t/nn-bcewithlogitsloss-cant-accept-one-hot-target/59980
            return self._loss(
                logits.view(-1, self.num_labels),
                label.view(-1, self.num_labels).type_as(logits),
            )

        return self._loss(logits, label.long())
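A small sketch (toy tensors, hypothetical names) of why the multilabel branch casts the labels: BCEWithLogitsLoss expects float targets, while a multi-hot label tensor usually arrives as a LongTensor.

import torch
import torch.nn as nn

logits = torch.randn(4, 3)                 # (batch, num_labels)
label = torch.randint(0, 2, (4, 3))        # multi-hot LongTensor
loss_fn = nn.BCEWithLogitsLoss()
# loss_fn(logits, label) fails on the integer dtype; cast as in the snippet above.
loss = loss_fn(logits, label.type_as(logits))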
Example #9
File: train_utils.py  Project: dksifoua/NMT
def accuracy(logits: torch.FloatTensor, labels: torch.IntTensor, top_k: int = 5):
    """
    Compute the top-k accuracy.

    Args:
        logits: torch.FloatTensor[seq_len, batch_size, vocab_size]
        labels: torch.IntTensor[seq_len, batch_size]
        top_k: int

    Returns:
        float: the top-k accuracy
    """
    seq_len, batch_size, vocab_size = logits.shape
    # Flatten the sequence and batch dimensions so top-k is taken over the vocabulary.
    logits_flat = logits.reshape(seq_len * batch_size, vocab_size)
    labels_flat = labels.reshape(seq_len * batch_size)
    _, indices = logits_flat.topk(top_k, dim=-1, largest=True, sorted=True)
    correct = indices.eq(labels_flat.view(-1, 1).expand_as(indices))
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total.item() * (100.0 / (seq_len * batch_size))
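A quick usage check for the helper above, assuming the shapes from its docstring:

import torch
logits = torch.randn(5, 2, 10)             # [seq_len, batch_size, vocab_size]
labels = torch.randint(0, 10, (5, 2))      # [seq_len, batch_size]
print(accuracy(logits, labels, top_k=5))   # a float between 0 and 100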
Example #10
def nd_batched_index_select(target: torch.Tensor,
                            indices: torch.IntTensor) -> torch.Tensor:
    """
    Multidimensional version of `util.batched_index_select`.
    """
    batch_axes = target.size()[:-2]
    num_batch_axes = len(batch_axes)
    target_shape = target.size()
    indices_shape = indices.size()

    target_reshaped = target.view(-1, *target_shape[num_batch_axes:])
    indices_reshaped = indices.view(-1, *indices_shape[num_batch_axes:])

    output_reshaped = util.batched_index_select(target_reshaped,
                                                indices_reshaped)

    return output_reshaped.view(*indices_shape, -1)
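Usage sketch with toy shapes (assumes AllenNLP's util.batched_index_select is importable, as in the snippet):

import torch
from allennlp.nn import util

target = torch.randn(2, 3, 7, 16)            # two leading batch axes, 7 positions, dim 16
indices = torch.randint(0, 7, (2, 3, 4))     # select 4 positions per (batch, group)
out = nd_batched_index_select(target, indices)
print(out.shape)                             # torch.Size([2, 3, 4, 16])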
Example #11
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,  # batch * words
        options: TextFieldTensors,  # batch * num_options * words
        labels: torch.IntTensor = None  # batch * num_options
    ) -> Dict[str, torch.Tensor]:
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).long()

        embedded_options = self._text_field_embedder(
            options, num_wrapping_dims=1)  # options_mask.dim() - 2
        options_mask = get_text_field_mask(options).long()

        if self._dropout:
            embedded_text = self._dropout(embedded_text)
            embedded_options = self._dropout(embedded_options)
        """
        This isn't exactly a 'hack', but it's definitely not the most efficient way to do it.
        Our matcher expects a single (query, document) pair, but we have (query, [d_0, ..., d_n]).
        To get around this, we expand the query embeddings to create these pairs, and then
        flatten both into the 3D tensor [batch*num_options, words, dim] expected by the matcher. 
        The expansion does this:

        [
            (q_0, [d_{0,0}, ..., d_{0,n}]), 
            (q_1, [d_{1,0}, ..., d_{1,n}])
        ]
        =>
        [
            [ (q_0, d_{0,0}), ..., (q_0, d_{0,n}) ],
            [ (q_1, d_{1,0}), ..., (q_1, d_{1,n}) ]
        ]

        Which we then flatten along the batch dimension. It would likely be more efficient
        to rewrite the matrix multiplications in the relevance matchers, but this is a more general solution.
        """

        embedded_text = embedded_text.unsqueeze(1).expand(
            -1, embedded_options.size(1), -1,
            -1)  # [batch, num_options, words, dim]
        mask = mask.unsqueeze(1).expand(-1, embedded_options.size(1), -1)

        scores = self._relevance_matcher(embedded_text, embedded_options, mask,
                                         options_mask).squeeze(-1)
        probs = torch.sigmoid(scores)

        output_dict = {"logits": scores, "probs": probs}
        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)
        if labels is not None:
            label_mask = (labels != -1)

            self._mrr(probs, labels, label_mask)
            self._ndcg(probs, labels, label_mask)

            probs = probs.view(-1)
            labels = labels.view(-1)
            label_mask = label_mask.view(-1)

            self._auc(probs, labels.ge(0.5).long(), label_mask)

            loss = self._loss(probs, labels)
            output_dict["loss"] = loss.masked_fill(~label_mask,
                                                   0).sum() / label_mask.sum()

        return output_dict
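A standalone sketch (toy tensors, not the project's matcher) of the expansion described in the comment inside the forward above: pair one query with each of its candidate documents, then flatten to the [batch * num_options, words, dim] layout the matcher expects.

import torch

batch, num_options, words, dim = 2, 3, 5, 8
embedded_text = torch.randn(batch, words, dim)
embedded_options = torch.randn(batch, num_options, words, dim)

expanded_text = embedded_text.unsqueeze(1).expand(-1, num_options, -1, -1)
flat_text = expanded_text.reshape(batch * num_options, words, dim)
flat_options = embedded_options.reshape(batch * num_options, words, dim)
print(flat_text.shape, flat_options.shape)    # torch.Size([6, 5, 8]) for both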
Example #12
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, num_of_passage_tokens = passage['bert'].size()

        # BERT for QA is a fully connected linear layer on top of BERT producing 2 vectors of
        # start and end spans.
        embedded_passage = self._text_field_embedder(passage)
        passage_length = embedded_passage.size(1)
        logits = self.qa_outputs(embedded_passage)
        start_logits, end_logits = logits.split(1, dim=-1)
        span_start_logits = start_logits.squeeze(-1)
        span_end_logits = end_logits.squeeze(-1)

        # Adding some masks with numerically stable values
        passage_mask = util.get_text_field_mask(passage).float()
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, 1, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            batch_size, passage_length)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        # Add span start and end logits for knowledge distillation.
        output_dict: Dict[str, Any] = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
        }

        # We may have multiple instances per question; regroup the batch per question.
        instances_question_id = [
            insta_meta['question_id'] for insta_meta in metadata
        ]
        question_instances_split_inds = np.cumsum(
            np.unique(instances_question_id, return_counts=True)[1])[:-1]
        per_question_inds = np.split(np.arange(batch_size),
                                     question_instances_split_inds)
        metadata = np.split(metadata, question_instances_split_inds)

        # Compute the loss.
        # if span_start is not None and len(np.argwhere(span_start.squeeze().cpu() >= 0)) > 0:
        if span_start is not None and len(
                np.argwhere(
                    span_start.squeeze(-1).squeeze(-1).cpu() >= 0)) > 0:
            # in evaluation some instances may not contain the gold answer, so we need to compute
            # loss only on those that do.
            inds_with_gold_answer = np.argwhere(
                span_start.view(-1).cpu().numpy() >= 0)
            inds_with_gold_answer = inds_with_gold_answer.squeeze(
            ) if len(inds_with_gold_answer) > 1 else inds_with_gold_answer
            if len(inds_with_gold_answer) > 0:
                loss = nll_loss(util.masked_log_softmax(span_start_logits[inds_with_gold_answer], \
                                                    repeated_passage_mask[inds_with_gold_answer]),\
                                span_start.view(-1)[inds_with_gold_answer], ignore_index=-1)
                output_dict["loss_start"] = loss
                loss += nll_loss(util.masked_log_softmax(span_end_logits[inds_with_gold_answer], \
                                                    repeated_passage_mask[inds_with_gold_answer]),\
                                span_end.view(-1)[inds_with_gold_answer], ignore_index=-1)
                output_dict["loss"] = loss
                output_dict["loss_end"] = loss - output_dict["loss_start"]

        # This is a hack for cases in which gold answer is not provided so we cannot compute loss...
        if 'loss' not in output_dict:
            output_dict["loss"] = torch.cuda.FloatTensor([0], device=span_end_logits.device) \
                if torch.cuda.is_available() else torch.FloatTensor([0])

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict["start_bias_weight"] = []
        output_dict["end_bias_weight"] = []

        # getting best span prediction for
        best_span = self._get_example_predications(span_start_logits,
                                                   span_end_logits,
                                                   self._max_span_length)
        best_span_cpu = best_span.detach().cpu().numpy()

        span_start_logits_numpy = span_start_logits.data.cpu().numpy()
        span_end_logits_numpy = span_end_logits.data.cpu().numpy()
        # Iterating over every question (which may contain multiple instances, one per chunk)
        for question_inds, question_instances_metadata in zip(
                per_question_inds, metadata):
            best_span_ind = np.argmax(
                span_start_logits_numpy[question_inds,
                                        best_span_cpu[question_inds][:, 0]] +
                span_end_logits_numpy[question_inds,
                                      best_span_cpu[question_inds][:, 1]])
            best_span_logit = np.max(
                span_start_logits_numpy[question_inds,
                                        best_span_cpu[question_inds][:, 0]] +
                span_end_logits_numpy[question_inds,
                                      best_span_cpu[question_inds][:, 1]])

            passage_str = question_instances_metadata[best_span_ind][
                'original_passage']
            offsets = question_instances_metadata[best_span_ind][
                'token_offsets']

            predicted_span = best_span_cpu[question_inds[best_span_ind]]
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]

            # Note: this is a hack. When predicting, AllenNLP expects a value for each instance,
            # but we may have more than one chunk per question, and thus fewer output strings than instances.
            for i in range(len(question_inds)):
                output_dict['best_span_str'].append(best_span_string)
                output_dict['qid'].append(
                    question_instances_metadata[best_span_ind]['question_id'])

                # get the scalar logit value of the predicted span start and end index as bias weight.
                output_dict["start_bias_weight"].append(
                    util.masked_softmax(span_start_logits[best_span_ind],
                                        repeated_passage_mask[best_span_ind])[
                                            best_span_cpu[best_span_ind][0]])
                output_dict["end_bias_weight"].append(
                    util.masked_softmax(span_end_logits[best_span_ind],
                                        repeated_passage_mask[best_span_ind])[
                                            best_span_cpu[best_span_ind][1]])

            f1_score = 0.0
            EM_score = 0.0
            gold_answer_texts = question_instances_metadata[best_span_ind][
                'answer_texts_list']
            if gold_answer_texts:
                f1_score = squad_eval.metric_max_over_ground_truths(
                    squad_eval.f1_score, best_span_string, gold_answer_texts)
                EM_score = squad_eval.metric_max_over_ground_truths(
                    squad_eval.exact_match_score, best_span_string,
                    gold_answer_texts)
            self._official_f1(100 * f1_score)
            self._official_EM(100 * EM_score)

            # TODO move to predict
            if self._predictions_file is not None:
                with open(self._predictions_file, 'a') as f:
                    f.write(json.dumps({'question_id':question_instances_metadata[best_span_ind]['question_id'], \
                                'best_span_logit':float(best_span_logit), \
                                'f1':100 * f1_score,
                                'EM':100 * EM_score,
                                'best_span_string':best_span_string,\
                                'gold_answer_texts':gold_answer_texts, \
                                'qas_used_fraction':1.0}) + '\n')

        return output_dict
Example #13
    def forward(
            self,
            sentence: Dict[str, torch.LongTensor],
            column: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            col_start_idx: torch.IntTensor = None,
            col_end_idx: torch.IntTensor = None,
            val_start_idx: torch.IntTensor = None,
            val_end_idx: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        ## number of characters
        batch_size, max_sent_count, max_sent_len = sentence['bert'].size()

        ## number of Chinese word-segmented tokens
        _, _, max_sent_token_len = sentence['bert-offsets'].size()

        # # total_qa_count * max_q_len * encoding_dim
        total_sent_count = batch_size * max_sent_count
        yesno_mask = torch.ge(yesno_list, 0).view(total_sent_count)

        # embedded_question = embedded_question.reshape(total_qa_count, max_q_len, self._text_field_embedder.get_output_dim())
        embedded_sentence = self._embedder(sentence['bert']).reshape(
            total_sent_count, max_sent_len, self._embedder.get_output_dim())
        embedded_passage = self._embedder(passage['bert'])
        embedded_column = self._embedder(column['bert'])

        sentence_mask = util.get_text_field_mask(
            sentence, num_wrapping_dims=1).float().squeeze(1)

        # sentence_mask = sentence_mask.reshape(total_sent_count, max_sent_len - 2)
        # sentence_mask = sentence_mask.reshape(total_sent_count, max_sent_len)
        # sentence_mask = sentence_mask.new_ones(batch_size, max_sent_count, max_sent_len)
        # sentence_mask = [[[1] + s + [1]] for s in sentence_mask]
        column_mask = util.get_text_field_mask(column).float()

        # column_mask = column_mask.reshape(total_sent_count, max_sent_len)
        # column_mask = column_mask.new_ones(batch_col_size, max_col_count, max_col_len)
        passage_mask = util.get_text_field_mask(passage).float()

        encode_passage = self._passage_BiLSTM(embedded_passage, passage_mask)
        encode_sentence = self._sentence_BiLSTM(embedded_sentence,
                                                sentence_mask)
        encode_column = self._columns_BiLSTM(embedded_column, column_mask)

        passage_length = encode_passage.size(1)
        column_length = encode_column.size(1)

        projected_passage = self.relu(self.projected_layer(encode_passage))
        projected_sentence = self.relu(self.projected_layer(encode_sentence))
        projected_column = self.relu(self.projected_layer(encode_column))

        encoded_passage = self._variational_dropout(projected_passage)
        encode_sentence = self._variational_dropout(projected_sentence)
        encode_column = self._variational_dropout(projected_column)

        # repeated_encode_column = encode_column.repeat(1, max_col_count, 1, 1)

        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
            1, max_sent_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(
            total_sent_count, passage_length, self._encoding_dim)

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_sent_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_sent_count, passage_length)

        repeated_encode_column = encode_column.unsqueeze(1).repeat(
            1, max_sent_count, 1, 1)
        repeated_encode_column = repeated_encode_column.view(
            total_sent_count, column_length, self._encoding_dim)

        repeated_column_mask = column_mask.unsqueeze(1).repeat(
            1, max_sent_count, 1)
        repeated_column_mask = repeated_column_mask.view(
            total_sent_count, column_length)

        ## S2C
        s = torch.bmm(encode_sentence, repeated_encode_column.transpose(2, 1))
        alpha = util.masked_softmax(s,
                                    sentence_mask.unsqueeze(2).expand(
                                        s.size()),
                                    dim=1)
        aligned_s2c = torch.bmm(alpha.transpose(2, 1), encode_sentence)

        ## P2C
        p = torch.bmm(repeated_encoded_passage,
                      repeated_encode_column.transpose(2, 1))
        beta = util.masked_softmax(p,
                                   repeated_passage_mask.unsqueeze(2).expand(
                                       p.size()),
                                   dim=1)
        aligned_p2c = torch.bmm(beta.transpose(2, 1), repeated_encoded_passage)

        ## C2S
        alpha1 = util.masked_softmax(s,
                                     repeated_column_mask.unsqueeze(1).expand(
                                         s.size()),
                                     dim=1)
        aligned_c2s = torch.bmm(alpha1, repeated_encode_column)

        ## C2P
        beta1 = util.masked_softmax(p,
                                    repeated_column_mask.unsqueeze(1).expand(
                                        p.size()),
                                    dim=1)
        aligned_c2p = torch.bmm(beta1, repeated_encode_column)

        fused_p = self.fuse_p(repeated_encoded_passage, aligned_c2p)
        fused_s = self.fuse_s(encode_sentence, aligned_c2s)
        fused_c = self.fuse_c(aligned_p2c, aligned_s2c)

        contextual_p = self._passage_contextual(fused_p, repeated_passage_mask)
        contextual_s = self._sentence_contextual(fused_s, sentence_mask)
        contextual_c = self._columns_contextual(fused_c, repeated_column_mask)

        contextual_c2p = torch.bmm(contextual_p, contextual_c.transpose(1, 2))
        alpha2 = util.masked_softmax(contextual_c2p,
                                     repeated_column_mask.unsqueeze(1).expand(
                                         contextual_c2p.size()),
                                     dim=1)
        aligned_contextual_c2p = torch.bmm(alpha2, contextual_c)

        contextual_c2s = torch.bmm(contextual_s, contextual_c.transpose(1, 2))
        beta2 = util.masked_softmax(contextual_c2s,
                                    repeated_column_mask.unsqueeze(1).expand(
                                        contextual_c2s.size()),
                                    dim=1)
        aligned_contextual_c2s = torch.bmm(beta2, contextual_c)

        # cnt * m
        gamma = util.masked_softmax(
            self.linear_self_align(aligned_contextual_c2s).squeeze(2),
            sentence_mask,
            dim=1)
        # cnt * h
        weighted_s = torch.bmm(gamma.unsqueeze(1),
                               aligned_contextual_c2s).squeeze(1)

        # weighted_s = torch.bmm(gamma_s.unsqueeze(1), contextual_c2s).squeeze(1)

        span_start_logits = self.bilinear_layer_s(weighted_s,
                                                  aligned_contextual_c2p)
        span_end_logits = self.bilinear_layer_e(weighted_s,
                                                aligned_contextual_c2p)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        span_yesno_logits = self.yesno_predictor(
            torch.bmm(span_end_logits.unsqueeze(2), weighted_s.unsqueeze(1)))

        best_span = self._get_best_span(span_start_logits, span_end_logits,
                                        span_yesno_logits,
                                        self._max_span_length)
        output_dict: Dict[str, Any] = {}

        # Compute the loss for training

        if col_start_idx is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
                                                    repeated_passage_mask),
                            col_start_idx.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits,
                                      col_start_idx.view(-1),
                                      mask=yesno_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask),
                             col_end_idx.view(-1),
                             ignore_index=-1)
            self._span_end_accuracy(span_end_logits,
                                    col_end_idx.view(-1),
                                    mask=yesno_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([col_start_idx, col_end_idx],
                                            -1).view(total_sent_count, 2),
                                mask=yesno_mask.unsqueeze(1).expand(-1,
                                                                    2).long())
            gold_span_end_loc = []
            col_end_idx = col_end_idx.view(
                total_sent_count).squeeze().data.cpu().numpy()
            for i in range(0, total_sent_count):
                # print(total_sent_count)

                gold_span_end_loc.append(
                    max(col_end_idx[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(col_end_idx[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(col_end_idx[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = col_start_idx.new(gold_span_end_loc)
            pred_span_end_loc = []
            for i in range(0, total_sent_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = col_start_idx.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(torch.nn.functional.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=yesno_mask)
            output_dict["loss"] = loss

        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['origin_passage']
            offsets = passage['bert-offsets'][i].cpu().numpy()
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, sql in enumerate(metadata[i]["sqls"]):

                predicted_span = tuple(best_span_cpu[i * max_sent_count +
                                                     per_dialog_query_index])
                start_offset = offsets[predicted_span[0]]
                end_offset = offsets[predicted_span[1]]
                per_dialog_query_id_list.append(sql)
                best_span_string = ''.join([
                    t.text for t in metadata[i]['passage_tokens']
                    [start_offset:end_offset]
                ])
                #print(best_span_string)
                per_dialog_best_span_list.append(best_span_string)

            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
        return output_dict
Example #14
def compute_span_representations(
        max_span_width: int, encoded_text: torch.FloatTensor,
        target_index: torch.IntTensor, span_starts: torch.IntTensor,
        span_ends: torch.IntTensor, span_width_embedding: Embedding,
        span_direction_embedding: Embedding,
        span_distance_embedding: Embedding, span_distance_bin: int,
        head_scorer: TimeDistributed) -> torch.FloatTensor:
    """
    Computes an embedded representation of every candidate span. This is a concatenation
    of the contextualized endpoints of the span, an embedded representation of the width of
    the span and a representation of the span's predicted head. Also contains a bunch of features
    with respect to the target.

    Parameters
    ----------
    encoded_text : ``torch.FloatTensor``, required.
        The deeply embedded sentence of shape (batch_size, sequence_length, embedding_dim)
        over which we are computing a weighted sum.
    span_starts : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans) representing the start of each span candidate.
    span_ends : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans) representing the end of each span candidate.
    Returns
    -------
    span_embeddings : ``torch.FloatTensor``
        An embedded representation of every candidate span with shape:
        (batch_size, sentence_length, span_width, context_layer.get_output_dim() * 2 + embedding_size + feature_size)
    """
    # Shape: (batch_size, sequence_length, encoding_dim)
    # TODO(Swabha): necessary to have this? is it going to mess with attention computation?
    # contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    _, sequence_length, _ = encoded_text.size()
    contextualized_embeddings = encoded_text

    # Shape: (batch_size, num_spans, encoding_dim)
    batch_size, num_spans = span_starts.size()
    assert num_spans == sequence_length * max_span_width

    start_embeddings = util.batched_index_select(contextualized_embeddings,
                                                 span_starts.squeeze(-1))
    end_embeddings = util.batched_index_select(contextualized_embeddings,
                                               span_ends.squeeze(-1))

    # Compute and embed the span_widths (strictly speaking the span_widths - 1)
    # Shape: (batch_size, num_spans, 1)
    span_widths = span_ends - span_starts
    # Shape: (batch_size, num_spans, encoding_dim)
    span_width_embeddings = span_width_embedding(span_widths.squeeze(-1))

    target_index = target_index.view(batch_size, 1)
    span_dist = torch.abs(span_ends - target_index)
    span_dist = span_dist * (span_dist < span_distance_bin).long()
    span_dist_embeddings = span_distance_embedding(span_dist.squeeze(-1))

    span_dir = ((span_ends - target_index) > 0).long()
    span_dir_embeddings = span_direction_embedding(span_dir.squeeze(-1))

    # Shape: (batch_size, sequence_length, 1)
    head_scores = head_scorer(contextualized_embeddings)

    # Shape: (batch_size, num_spans, embedding_dim)
    # Note that we used the original text embeddings, not the contextual ones here.
    attended_text_embeddings = create_attended_span_representations(
        max_span_width, head_scores, encoded_text, span_ends, span_widths)
    # (batch_size, num_spans, context_layer.get_output_dim() * 3 + 2 * feature_dim)
    span_embeddings = torch.cat([
        start_embeddings, end_embeddings, span_width_embeddings,
        span_dist_embeddings, span_dir_embeddings, attended_text_embeddings
    ], -1)
    span_embeddings = span_embeddings.view(batch_size, sequence_length,
                                           max_span_width, -1)
    return span_embeddings
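A toy illustration (random tensors, with torch.nn.Embedding standing in for the AllenNLP Embedding in the signature) of the endpoint and width pieces concatenated above; the distance and direction features follow the same gather-then-embed pattern.

import torch
from allennlp.nn import util

batch, seq_len, dim, max_width, num_spans = 2, 6, 8, 3, 4
encoded_text = torch.randn(batch, seq_len, dim)
span_starts = torch.randint(0, seq_len, (batch, num_spans))
span_ends = (span_starts + torch.randint(0, max_width, (batch, num_spans))).clamp(max=seq_len - 1)

start_embeddings = util.batched_index_select(encoded_text, span_starts)   # (batch, num_spans, dim)
end_embeddings = util.batched_index_select(encoded_text, span_ends)       # (batch, num_spans, dim)
width_embeddings = torch.nn.Embedding(max_width, 5)(span_ends - span_starts)
span_embeddings = torch.cat([start_embeddings, end_embeddings, width_embeddings], dim=-1)
print(span_embeddings.shape)   # torch.Size([2, 4, 21])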
Example #15
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            evd_chain_labels: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        if self._sent_labels_src == 'chain':
            batch_size, num_spans = sent_labels.size()
            sent_labels_mask = (sent_labels >= 0).float()
            print("chain:", evd_chain_labels)
            # we use the chain as the label to supervise the gate
            # In this model, we only take the first chain in ``evd_chain_labels`` for supervision,
            # right now the number of chains should only be one too.
            evd_chain_labels = evd_chain_labels[:, 0].long()
            # build the gate labels. The dim is set to 1 + num_spans to account for the end embedding
            # shape: (batch_size, 1+num_spans)
            sent_labels = sent_labels.new_zeros((batch_size, 1 + num_spans))
            sent_labels.scatter_(1, evd_chain_labels, 1.)
            # remove the column for end embedding
            # shape: (batch_size, num_spans)
            sent_labels = sent_labels[:, 1:].float()
            # make the padding be -1
            sent_labels = sent_labels * sent_labels_mask + -1. * (
                1 - sent_labels_mask)

        # word + char embedding
        embedded_question = self._text_field_embedder(question)
        embedded_passage = self._text_field_embedder(passage)
        # mask
        ques_mask = util.get_text_field_mask(question).float()
        context_mask = util.get_text_field_mask(passage).float()

        # BiDAF for answer predicion
        ques_output = self._dropout(
            self._phrase_layer(embedded_question, ques_mask))
        context_output = self._dropout(
            self._phrase_layer(embedded_passage, context_mask))

        modeled_passage, _, qc_score = self.qc_att(context_output, ques_output,
                                                   ques_mask)

        modeled_passage = self._modeling_layer(modeled_passage, context_mask)

        # BiDAF for gate prediction
        ques_output_sp = self._dropout(
            self._phrase_layer_sp(embedded_question, ques_mask))
        context_output_sp = self._dropout(
            self._phrase_layer_sp(embedded_passage, context_mask))

        modeled_passage_sp, _, qc_score_sp = self.qc_att_sp(
            context_output_sp, ques_output_sp, ques_mask)

        modeled_passage_sp = self._modeling_layer_sp(modeled_passage_sp,
                                                     context_mask)

        # gate prediction
        # Shape(spans_rep): (batch_size * num_spans, max_batch_span_width, embedding_dim)
        # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
        spans_rep_sp, spans_mask = convert_sequence_to_spans(
            modeled_passage_sp, sentence_spans)
        spans_rep, _ = convert_sequence_to_spans(modeled_passage,
                                                 sentence_spans)
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        # Shape(gate_mask): (batch_size, num_spans)
        #gate_logit, gate, pred_sent_probs = self._span_gate(spans_rep_sp, spans_mask)
        gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(
            spans_rep_sp, spans_mask, self._gate_self_attention_layer,
            self._gate_sent_encoder)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()

        strong_sup_loss = F.nll_loss(
            F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
            sent_labels.long().view(batch_size * num_spans),
            ignore_index=-1)

        gate = (gate >= 0.3).long()
        spans_rep = spans_rep * gate.unsqueeze(-1).float()
        attended_sent_embeddings = convert_span_to_sequence(
            modeled_passage_sp, spans_rep, spans_mask)

        modeled_passage = attended_sent_embeddings + modeled_passage

        self_att_passage = self._self_attention_layer(modeled_passage,
                                                      mask=context_mask)
        modeled_passage = modeled_passage + self_att_passage[0]
        self_att_score = self_att_passage[2]

        output_start = self._span_start_encoder(modeled_passage, context_mask)
        span_start_logits = self.linear_start(output_start).squeeze(
            2) - 1e30 * (1 - context_mask)
        output_end = torch.cat([modeled_passage, output_start], dim=2)
        output_end = self._span_end_encoder(output_end, context_mask)
        span_end_logits = self.linear_end(output_end).squeeze(
            2) - 1e30 * (1 - context_mask)

        output_type = torch.cat([modeled_passage, output_end, output_start],
                                dim=2)
        output_type = torch.max(output_type, 1)[0]
        # output_type = torch.max(self.rnn_type(output_type, context_mask), 1)[0]
        predict_type = self.linear_type(output_type)
        type_predicts = torch.argmax(predict_type, 1)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span": best_span,
            "pred_sent_labels": gate.view(batch_size,
                                          num_spans),  #[B, num_span]
            "gate_probs":
            pred_sent_probs[:, 1].view(batch_size, num_spans),  #[B, num_span]
        }
        if self._output_att_scores:
            if qc_score is not None:
                output_dict['qc_score'] = qc_score
            if qc_score_sp is not None:
                output_dict['qc_score_sp'] = qc_score_sp
            if self_att_score is not None:
                output_dict['self_attention_score'] = self_att_score
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        print("sent label:")
        for b_label in np.array(sent_labels.cpu()):
            b_label = b_label == 1
            indices = np.arange(len(b_label))
            print(indices[b_label] + 1)
        # Compute the loss for training.
        if span_start is not None:
            try:
                start_loss = nll_loss(
                    util.masked_log_softmax(span_start_logits, None),
                    span_start.squeeze(-1))
                end_loss = nll_loss(
                    util.masked_log_softmax(span_end_logits, None),
                    span_end.squeeze(-1))
                type_loss = nll_loss(
                    util.masked_log_softmax(predict_type, None), q_type)
                loss = start_loss + end_loss + type_loss + strong_sup_loss
                self._loss_trackers['loss'](loss)
                self._loss_trackers['start_loss'](start_loss)
                self._loss_trackers['end_loss'](end_loss)
                self._loss_trackers['type_loss'](type_loss)
                self._loss_trackers['strong_sup_loss'](strong_sup_loss)
                output_dict["loss"] = loss
            except RuntimeError:
                print('\n meta_data:', metadata)
                print(span_start_logits.shape)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            token_spans_sp = []
            token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            ids = []
            count_yes = 0
            count_no = 0
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                token_spans_sp.append(metadata[i]['token_spans_sp'])
                token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                if type_predicts[i] == 1:
                    best_span_string = 'yes'
                    count_yes += 1
                elif type_predicts[i] == 2:
                    best_span_string = 'no'
                    count_no += 1
                else:
                    predicted_span = tuple(best_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    best_span_string = passage_str[start_offset:end_offset]

                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                if answer_texts:
                    self._squad_metrics(best_span_string.lower(), answer_texts)

                # shift sentence indice back
                evd_possible_chains.append([
                    s_idx - 1
                    for s_idx in metadata[i]['evd_possible_chains'][0]
                    if s_idx > 0
                ])
                ans_sent_idxs.append(
                    [s_idx - 1 for s_idx in metadata[i]['ans_sent_idxs']])
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1),
                             gate_mask.view(-1))
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['token_spans_sp'] = token_spans_sp
            output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['_id'] = ids

        return output_dict
Example #16
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                sentence_spans: torch.IntTensor = None,
                sent_labels: torch.IntTensor = None,
                evd_chain_labels: torch.IntTensor = None,
                q_type: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # In this model, we only take the first chain in ``evd_chain_labels`` for supervision.
        evd_chain_labels = evd_chain_labels[:, 0] if evd_chain_labels is not None else None
        # Some instances have no evidence chain for training;
        # in that case, use the mask to ignore those instances.
        evd_instance_mask = (evd_chain_labels[:, 0] != 0).float() if evd_chain_labels is not None else None
        #print('passage size:', passage['bert'].shape)
        # bert embedding for answer prediction
        # shape: [batch_size, max_q_len, emb_size]
        print('\nBert wordpiece size:', passage['bert'].shape)
        embedded_question = self._text_field_embedder(question)
        # shape: [batch_size, num_sent, max_sent_len+q_len, embedding_dim]
        embedded_passage = self._text_field_embedder(passage)
        # print('\npassage size:', embedded_passage.shape)
        #embedded_question = self._bert_projection(embedded_question)
        #embedded_passage = self._bert_projection(embedded_passage)
        #print('size embedded_passage:', embedded_passage.shape)
        # mask
        ques_mask = util.get_text_field_mask(question, num_wrapping_dims=0).float()
        context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()
        #print(context_mask.shape)
        # get the word embeddings for sentences
        batch_size, num_sent, max_sent_len, embedding_dim = embedded_passage.size()
        embedded_passage = embedded_passage.view(batch_size*num_sent, max_sent_len, embedding_dim)
        sentence_spans = sentence_spans.view(batch_size*num_sent, 2).unsqueeze(1)
        # spans_rep_sp shape: (batch_size*num_sent*1, max_sent_len(no extend), embedding_dim)
        # spans_mask shape: (batch_size*num_sent, 1, max_sent_len(no_entend))
        # print(sentence_spans)
        # print(embedded_passage.shape)
        spans_rep_sp, spans_mask = convert_sequence_to_spans(embedded_passage, sentence_spans)
        max_sent_len = spans_rep_sp.size(1)
        spans_rep_sp = spans_rep_sp.view(batch_size, num_sent, max_sent_len, embedding_dim)
        spans_mask = spans_mask.view(batch_size, num_sent, max_sent_len)
        # chain prediction
        # Shape(all_predictions): (batch_size, num_decoding_steps)
        # Shape(all_logprobs): (batch_size, num_decoding_steps)
        # Shape(seq_logprobs): (batch_size,)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(gate_probs): (batch_size * num_spans, 1)
        # Shape(gate_mask): (batch_size, num_spans)
        # Shape(g_att_score): (batch_size, num_heads, num_spans, num_spans)
        # Shape(orders): (batch_size, K, num_spans)
        all_predictions,    \
        all_logprobs,       \
        seq_logprobs,       \
        gate,               \
        gate_probs,         \
        gate_mask,          \
        g_att_score,        \
        orders = self._span_gate(spans_rep_sp, spans_mask,
                                 embedded_question, ques_mask,
                                 evd_chain_labels,
                                 self._gate_self_attention_layer,
                                 self._gate_sent_encoder)
        batch_size, num_spans, max_batch_span_width = context_mask.size()

        output_dict = {
            "pred_sent_labels": gate.squeeze(1).view(batch_size, num_spans), #[B, num_span]
            "gate_probs": gate_probs.squeeze(1).view(batch_size, num_spans), #[B, num_span]
            "pred_sent_orders": orders, #[B, K, num_span]
        }
        if self._output_att_scores:
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # compute evd rl training metric, rewards, and loss
        print("sent label:")
        for b_label in np.array(sent_labels.cpu()):
            b_label = b_label == 1
            indices = np.arange(len(b_label))
            print(indices[b_label] + 1)
        evd_TP, evd_NP, evd_NT = self._f1_metrics(gate.squeeze(1).view(batch_size, num_spans),
                                                  sent_labels,
                                                  mask=gate_mask,
                                                  instance_mask=evd_instance_mask if self.training else None,
                                                  sum=False)
        # print("TP:", evd_TP)
        # print("NP:", evd_NP)
        # print("NT:", evd_NT)
        evd_ps = np.array(evd_TP) / (np.array(evd_NP) + 1e-13)
        evd_rs = np.array(evd_TP) / (np.array(evd_NT) + 1e-13)
        evd_f1s = 2. * ((evd_ps * evd_rs) / (evd_ps + evd_rs + 1e-13))
        predict_mask = get_evd_prediction_mask(all_predictions.unsqueeze(1), eos_idx=0)[0]
        gold_mask = get_evd_prediction_mask(evd_chain_labels, eos_idx=0)[0]
        # default to take multiple predicted chains, so unsqueeze dim 1
        self.evd_sup_acc_metric(predictions=all_predictions.unsqueeze(1), gold_labels=evd_chain_labels,
                                predict_mask=predict_mask, gold_mask=gold_mask, instance_mask=evd_instance_mask)
        print("gold chain:", evd_chain_labels)
        predict_mask = predict_mask.float().squeeze(1)
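        # RL training loss: negative mean over the batch of the summed log-probabilities of
        # the predicted chain steps, restricted to valid steps and to instances that actually
        # have a gold chain.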
        rl_loss = -torch.mean(torch.sum(all_logprobs * predict_mask * evd_instance_mask[:, None], dim=1))
        # torch.cuda.empty_cache()
        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        # Compute before loss for rl
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            #token_spans_sp = []
            #token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            pred_chains_include_ans = []
            beam_pred_chains_include_ans = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_sent_tokens'])
                #token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                #offsets = metadata[i]['token_offsets']
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                # shift sentence indices back to 0-based
                evd_possible_chains.append([s_idx-1 for s_idx in metadata[i]['evd_possible_chains'][0] if s_idx > 0])
                ans_sent_idxs.append([s_idx-1 for s_idx in metadata[i]['ans_sent_idxs']])
                print("ans_sent_idxs:", metadata[i]['ans_sent_idxs'])
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_orders = orders[i].detach().cpu().numpy()
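                    # orders[i] has shape (K, num_spans); an entry >= 0 is treated below as the
                    # sentence being selected somewhere in that beam's predicted chain.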
                    if any([pred_sent_orders[0][s_idx-1] >= 0 for s_idx in metadata[i]['ans_sent_idxs']]):
                        self.evd_ans_metric(1)
                        pred_chains_include_ans.append(1)
                    else:
                        self.evd_ans_metric(0)
                        pred_chains_include_ans.append(0)
                    if any([any([pred_sent_orders[beam][s_idx-1] >= 0 for s_idx in metadata[i]['ans_sent_idxs']]) 
                                                                      for beam in range(len(pred_sent_orders))]):
                        self.evd_beam_ans_metric(1)
                        beam_pred_chains_include_ans.append(1)
                    else:
                        self.evd_beam_ans_metric(0)
                        beam_pred_chains_include_ans.append(0)

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_sent_tokens'] = passage_tokens
            #output_dict['token_spans_sp'] = token_spans_sp
            #output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['pred_chains_include_ans'] = pred_chains_include_ans
            output_dict['beam_pred_chains_include_ans'] = beam_pred_chains_include_ans
            output_dict['_id'] = ids

        # Compute the loss for training.
        if evd_chain_labels is not None:
            try:
                loss = rl_loss
                self._loss_trackers['loss'](loss)
                self._loss_trackers['rl_loss'](rl_loss)
                output_dict["loss"] = loss
            except RuntimeError:
                print('\n meta_data:', metadata)
                print(output_dict['_id'])

        return output_dict
Example #17
    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, max_qa_count, max_q_len, _ = question[
            'token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(yesno_list, 0).view(total_qa_count)

        embedded_question = self._text_field_embedder(question,
                                                      num_wrapping_dims=1)
        # total_qa_count * max_q_len * encoding_dim
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
            self._text_field_embedder.get_output_dim())
        embedded_passage = self._text_field_embedder(passage)

        # Split the embedded tensors into word/char embeddings, ELMo embeddings, and feature embeddings.
        word_emb_ques, elmo_ques, ques_feat = torch.split(embedded_question,
                                                          [200, 1024, 40],
                                                          dim=2)
        word_emb_pass, elmo_pass, pass_feat = torch.split(embedded_passage,
                                                          [200, 1024, 40],
                                                          dim=2)
        # word embedding and char embedding
        embedded_question = self._variational_dropout(
            torch.cat([word_emb_ques, elmo_ques], dim=2))
        embedded_passage = self._variational_dropout(
            torch.cat([word_emb_pass, elmo_pass], dim=2))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question,
                                                 num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        encode_passage = self._phrase_layer(embedded_passage, passage_mask)
        projected_passage = self.relu(
            self.projected_layer(torch.cat([encode_passage, elmo_pass],
                                           dim=2)))

        encode_question = self._phrase_layer(embedded_question, question_mask)
        projected_question = self.relu(
            self.projected_layer(torch.cat([encode_question, elmo_ques],
                                           dim=2)))

        encoded_passage = self._variational_dropout(projected_passage)
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
            1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(
            total_qa_count, passage_length, self._encoding_dim)
        repeated_pass_feat = (pass_feat.unsqueeze(1).repeat(
            1, max_qa_count, 1, 1)).view(total_qa_count, passage_length, 40)
        encoded_question = self._variational_dropout(projected_question)

        # total_qa_count * max_q_len * passage_length
        # cnt * m * n
        s = torch.bmm(encoded_question,
                      repeated_encoded_passage.transpose(2, 1))
        alpha = util.masked_softmax(s,
                                    question_mask.unsqueeze(2).expand(
                                        s.size()),
                                    dim=1)
        # cnt * n * h
        aligned_p = torch.bmm(alpha.transpose(2, 1), encoded_question)

        # cnt * m * n
        beta = util.masked_softmax(s,
                                   repeated_passage_mask.unsqueeze(1).expand(
                                       s.size()),
                                   dim=2)
        # cnt * m * h
        aligned_q = torch.bmm(beta, repeated_encoded_passage)

        fused_p = self.fuse_p(repeated_encoded_passage, aligned_p)
        fused_q = self.fuse_q(encoded_question, aligned_q)

        # add manual features here
        q_aware_p = self.projected_lstm(
            torch.cat([fused_p, repeated_pass_feat], dim=2),
            repeated_passage_mask)

        # cnt * n * n
        # self_p = torch.bmm(q_aware_p, q_aware_p.transpose(2, 1))
        # self_p = self.bilinear_self_align(q_aware_p)
        self_p = self._self_attention(q_aware_p, q_aware_p)
        # for i in range(passage_length):
        #     self_p[:, i, i] = 0
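        # Build a pairwise mask from the passage mask and zero out the diagonal so that
        # a token cannot attend to itself.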
        mask = repeated_passage_mask.reshape(
            total_qa_count, passage_length, 1) * repeated_passage_mask.reshape(
                total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              device=self_p.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        lamb = util.masked_softmax(self_p, mask, dim=2)
        # lamb = util.masked_softmax(self_p, repeated_passage_mask, dim=2)
        # cnt * n * h
        self_aligned_p = torch.bmm(lamb, q_aware_p)

        # cnt * n * h
        fused_self_p = self.fuse_s(q_aware_p, self_aligned_p)
        # contextual_p = self._variational_dropout(self.contextual_layer_p(fused_self_p, repeated_passage_mask))
        contextual_p = self.contextual_layer_p(fused_self_p,
                                               repeated_passage_mask)

        # contextual_q = self._variational_dropout(self.contextual_layer_q(fused_q, question_mask))
        contextual_q = self.contextual_layer_q(fused_q, question_mask)
        # cnt * m
        gamma = util.masked_softmax(
            self.linear_self_align(contextual_q).squeeze(2),
            question_mask,
            dim=1)
        # cnt * h
        weighted_q = torch.bmm(gamma.unsqueeze(1), contextual_q).squeeze(1)

        span_start_logits = self.bilinear_layer_s(weighted_q, contextual_p)
        span_end_logits = self.bilinear_layer_e(weighted_q, contextual_p)

        # (cnt, n, 1) x (cnt, 1, h) -> (cnt, n, h)
        span_yesno_logits = self.yesno_predictor(
            torch.bmm(span_end_logits.unsqueeze(2), weighted_q.unsqueeze(1)))
        # span_yesno_logits = self.yesno_predictor(contextual_p)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits,
                                                       span_end_logits,
                                                       span_yesno_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss for training

        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
                                                    repeated_passage_mask),
                            span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask),
                             span_end.view(-1),
                             ignore_index=-1)
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end],
                                            -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
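            # span_yesno_logits has shape (total_qa_count, passage_length, 3); after view(-1),
            # entry (i, j, k) sits at flat index i * passage_length * 3 + j * 3 + k, so the
            # indices built below pick out the three yes/no logits at each gold end position.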
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(torch.nn.functional.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)

            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                yesno_pred = predicted_span[2]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_query_id_list.append(iid)
                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad_eval.metric_max_over_ground_truths(
                                    squad_eval.f1_score, best_span_string,
                                    refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score, best_span_string,
                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
        return output_dict
Example #18
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens that belong to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two dialog turns earlier in the passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three dialog turns earlier in the passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (yes / no / not a yes-no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (follow up / maybe follow up / don't follow up).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following entries.
        Each entry is a nested list: the outer list iterates over the dialogs in the batch, the inner list over the questions in a dialog.

        qid : List[List[str]]
            A list of lists, consisting of question ids.
        followup : List[List[int]]
            A list of lists, consisting of continuation marker prediction indices.
            (y: yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of lists, consisting of affirmation marker prediction indices.
            (y: yes, x: not a yes/no question, n: no)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict
Example #19
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()
        premise_sequence_lengths = get_lengths_from_binary_sequence_mask(
            premise_mask)
        hypothesis_sequence_lengths = get_lengths_from_binary_sequence_mask(
            hypothesis_mask)

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_sequence_lengths)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_sequence_lengths)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
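        # "Attend" step of the decomposable attention model: score every premise token
        # against every hypothesis token.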
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise,
                                                   projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat(
            [embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat(
            [embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(
            hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(
            -1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            if label.dim() == 2:
                _, label = label.max(-1)
            loss = self._loss(label_logits, label.view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Example #20
    def forward(
        self,  # type: ignore
        spans: torch.IntTensor,
        span_mask: torch.IntTensor,
        span_embeddings: torch.IntTensor,
        sentence_lengths: torch.Tensor,
        ner_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
        previous_step_output: Dict[str,
                                   Any] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        """

        # Shape: (Batch size, Number of Spans, Span Embedding Size)
        # span_embeddings
        ner_scores = self._ner_scorer(span_embeddings)
        # Give large negative scores to masked-out elements.
        mask = span_mask.unsqueeze(-1)
        ner_scores = util.replace_masked_values(ner_scores, mask, -1e20)
        dummy_dims = [ner_scores.size(0), ner_scores.size(1), 1]
        dummy_scores = ner_scores.new_zeros(*dummy_dims)
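        # The dummy column holds the score for the null label. At inference time, if the
        # previous step produced binary span predictions, force agreement with them: spans
        # it predicted as entities get a very negative dummy score, all others a very
        # positive one.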
        if previous_step_output is not None and "predicted_span" in previous_step_output and not self.training:
            dummy_scores.masked_fill_(
                previous_step_output["predicted_span"].bool().unsqueeze(-1),
                -1e20)
            dummy_scores.masked_fill_(
                (1 -
                 previous_step_output["predicted_span"]).bool().unsqueeze(-1),
                1e20)

        ner_scores = torch.cat((dummy_scores, ner_scores), -1)

        if previous_step_output is not None and "predicted_seq_span" in previous_step_output and not self.training:
            for row_idx, all_spans in enumerate(spans):
                pred_spans = previous_step_output["predicted_seq_span"][
                    row_idx]
                pred_spans = all_spans.new_tensor(pred_spans)
                for col_idx, span in enumerate(all_spans):
                    if span_mask[row_idx][col_idx] == 0:
                        continue
                    bFind = False
                    for pred_span in pred_spans:
                        if span[0] == pred_span[0] and span[1] == pred_span[1]:
                            bFind = True
                            break
                    if bFind:
                        # If the span was also predicted in the previous step, keep the NER
                        # scores by setting the dummy (null-label) score very negative.
                        ner_scores[row_idx, col_idx, 0] = -1e20
                    else:
                        # Otherwise defer to the previous step and force the null label by
                        # setting the dummy score very positive.
                        ner_scores[row_idx, col_idx, 0] = 1e20

        _, predicted_ner = ner_scores.max(2)

        output_dict = {
            "spans": spans,
            "span_mask": span_mask,
            "ner_scores": ner_scores,
            "predicted_ner": predicted_ner
        }

        if ner_labels is not None:
            self._ner_metrics(predicted_ner, ner_labels, span_mask)
            ner_scores_flat = ner_scores.view(-1, self._n_labels)
            ner_labels_flat = ner_labels.view(-1)
            mask_flat = span_mask.view(-1).bool()

            loss = self._loss(ner_scores_flat[mask_flat],
                              ner_labels_flat[mask_flat])
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["document"] = [x["sentence"] for x in metadata]

        return output_dict
Example #21
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens that belong to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two dialog turns earlier in the passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three dialog turns earlier in the passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (yes / no / not a yes-no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (follow up / maybe follow up / don't follow up).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following entries.
        Each entry is a nested list: the outer list iterates over the dialogs in the batch, the inner list over the questions in a dialog.

        qid : List[List[str]]
            A list of lists, consisting of question ids.
        followup : List[List[int]]
            A list of lists, consisting of continuation marker prediction indices.
            (y: yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of lists, consisting of affirmation marker prediction indices.
            (y: yes, x: not a yes/no question, n: no)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question[
            'token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question,
                                                      num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
            self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(
            self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question,
                                                 num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(
                max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(
                question_num_ind)
            embedded_question = torch.cat(
                [embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count,
                                                     passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat(
                [repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(
                    total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat(
                    [repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(
                        total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(
                        p3_answer_marker)
                    repeated_embedded_passage = torch.cat(
                        [repeated_embedded_passage, p3_answer_marker_emb],
                        dim=-1)

            repeated_encoded_passage = self._variational_dropout(
                self._phrase_layer(repeated_embedded_passage,
                                   repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(
                self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(
                total_qa_count, passage_length, self._encoding_dim)

        encoded_question = self._variational_dropout(
            self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(
            repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(total_qa_count, passage_length, self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            repeated_encoded_passage, passage_question_vectors,
            repeated_encoded_passage * passage_question_vectors,
            repeated_encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(
            self._residual_encoder(final_merged_passage,
                                   repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer,
                                                     residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs,
                                           residual_layer)
        self_attention_vecs = torch.cat([
            self_attention_vecs, residual_layer,
            residual_layer * self_attention_vecs
        ],
                                        dim=-1)
        residual_layer = F.relu(
            self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage,
                                             repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(
            torch.cat([final_merged_passage, start_rep], dim=-1),
            repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(
            -1)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        # Shape: (batch_size * max_qa_count, passage_length)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits,
                                                       span_end_logits,
                                                       span_yesno_logits,
                                                       span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
                                                    repeated_passage_mask),
                            span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask),
                             span_end.view(-1),
                             ignore_index=-1)
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end],
                                            -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # Select the yes/no and follow-up logits at the gold span end to compute their losses.
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
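            # span_yesno_logits has shape (total_qa_count, passage_length, 3); after view(-1),
            # the three logits at the gold end token of QA pair i live at flat indices
            # span_end[i] * 3 + i * passage_length * 3 + {0, 1, 2}.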
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1),
                             followup_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)
            self._span_followup_accuracy(_followup,
                                         followup_list.view(-1),
                                         mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and prepare the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])

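                # Map the predicted token span back to character offsets in the original passage.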
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 against each leave-one-out subset of the human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad_eval.metric_max_over_ground_truths(
                                    squad_eval.f1_score, best_span_string,
                                    refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score, best_span_string,
                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict
Example #22
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno: torch.IntTensor = None,
            question_tf: torch.FloatTensor = None,
            passage_tf: torch.FloatTensor = None,
            q_em_cased: torch.IntTensor = None,
            p_em_cased: torch.IntTensor = None,
            q_em_uncased: torch.IntTensor = None,
            p_em_uncased: torch.IntTensor = None,
            q_in_lemma: torch.IntTensor = None,
            p_in_lemma: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        x1_c_emb = self._dropout(self._char_field_embedder(passage))
        x2_c_emb = self._dropout(self._char_field_embedder(question))

        # embedded_question = torch.cat([self._dropout(self._text_field_embedder(question)),
        #                                self._features_embedder(q_em_cased),
        #                                self._features_embedder(q_em_uncased),
        #                                self._features_embedder(q_in_lemma),
        #                                question_tf.unsqueeze(2)], dim=2)
        # embedded_passage = torch.cat([self._dropout(self._text_field_embedder(passage)),
        #                               self._features_embedder(p_em_cased),
        #                               self._features_embedder(p_em_uncased),
        #                               self._features_embedder(p_in_lemma),
        #                               passage_tf.unsqueeze(2)], dim=2)
        token_emb_q = self._dropout(self._text_field_embedder(question))
        token_emb_c = self._dropout(self._text_field_embedder(passage))
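        # Split the embedder output into the 300-dim word embeddings and the 40-dim NER/POS embeddings.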
        token_emb_question, q_ner_and_pos = torch.split(token_emb_q, [300, 40],
                                                        dim=2)
        token_emb_passage, p_ner_and_pos = torch.split(token_emb_c, [300, 40],
                                                       dim=2)
        question_word_features = torch.cat([
            q_ner_and_pos,
            self._features_embedder(q_em_cased),
            self._features_embedder(q_em_uncased),
            self._features_embedder(q_in_lemma),
            question_tf.unsqueeze(2)
        ],
                                           dim=2)
        passage_word_features = torch.cat([
            p_ner_and_pos,
            self._features_embedder(p_em_cased),
            self._features_embedder(p_em_uncased),
            self._features_embedder(p_in_lemma),
            passage_tf.unsqueeze(2)
        ],
                                          dim=2)

        # embedded_question = self._highway_layer(embedded_q)
        # embedded_passage = self._highway_layer(embedded_q)

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

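        # Run the character RNN over each token's characters (batch and sequence dims flattened)
        # and keep the final step as that token's character feature.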
        char_features_c = self._char_rnn(
            x1_c_emb.reshape((x1_c_emb.size(0) * x1_c_emb.size(1),
                              x1_c_emb.size(2), x1_c_emb.size(3))),
            passage_lstm_mask.unsqueeze(2).repeat(
                1, 1, x1_c_emb.size(2)).reshape(
                    (x1_c_emb.size(0) * x1_c_emb.size(1),
                     x1_c_emb.size(2)))).reshape(
                         (x1_c_emb.size(0), x1_c_emb.size(1), x1_c_emb.size(2),
                          -1))[:, :, -1, :]
        char_features_q = self._char_rnn(
            x2_c_emb.reshape((x2_c_emb.size(0) * x2_c_emb.size(1),
                              x2_c_emb.size(2), x2_c_emb.size(3))),
            question_lstm_mask.unsqueeze(2).repeat(
                1, 1, x2_c_emb.size(2)).reshape(
                    (x2_c_emb.size(0) * x2_c_emb.size(1),
                     x2_c_emb.size(2)))).reshape(
                         (x2_c_emb.size(0), x2_c_emb.size(1), x2_c_emb.size(2),
                          -1))[:, :, -1, :]

        # token_emb_q, char_emb_q, question_word_features = torch.split(embedded_question, [300, 300, 56], dim=2)
        # token_emb_c, char_emb_c, passage_word_features = torch.split(embedded_passage, [300, 300, 56], dim=2)

        # char_features_q = self._char_rnn(char_emb_q, question_lstm_mask)
        # char_features_c = self._char_rnn(char_emb_c, passage_lstm_mask)

        emb_question = torch.cat(
            [token_emb_question, char_features_q, question_word_features],
            dim=2)
        emb_passage = torch.cat(
            [token_emb_passage, char_features_c, passage_word_features], dim=2)

        encoded_question = self._dropout(
            self._phrase_layer(emb_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(emb_passage, passage_lstm_mask))

        batch_size = encoded_question.size(0)
        passage_length = encoded_passage.size(1)

        encoding_dim = encoded_question.size(-1)

        # c_check = self._stacked_brnn(encoded_passage, passage_lstm_mask)
        # q = self._stacked_brnn(encoded_question, question_lstm_mask)
        c_check = encoded_passage
        q = encoded_question
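        # Iteratively refine the passage representation over several hops: align the passage to the
        # question, fuse with a semantic fusion unit (SFU), self-align, fuse again, then aggregate with an RNN.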
        for i in range(self.hops):
            q_tilde = self.interactive_aligners[i].forward(
                c_check, q, question_mask)
            c_bar = self.interactive_SFUs[i].forward(
                c_check,
                torch.cat([q_tilde, c_check * q_tilde, c_check - q_tilde], 2))
            c_tilde = self.self_aligners[i].forward(c_bar, passage_mask)
            c_hat = self.self_SFUs[i].forward(
                c_bar, torch.cat([c_tilde, c_bar * c_tilde, c_bar - c_tilde],
                                 2))
            c_check = self.aggregate_rnns[i].forward(c_hat, passage_mask)

        # Predict
        start_scores, end_scores, yesno_scores = self.mem_ans_ptr.forward(
            c_check, q, passage_mask, question_mask)

        best_span, yesno_predict, loc = self.get_best_span(
            start_scores, end_scores, yesno_scores)

        output_dict = {
            "span_start_logits": start_scores,
            "span_end_logits": end_scores,
            "best_span": best_span
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(start_scores, span_start.squeeze(-1))
            self._span_start_accuracy(start_scores, span_start.squeeze(-1))
            loss += nll_loss(end_scores, span_end.squeeze(-1))
            self._span_end_accuracy(end_scores, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))

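            # With yesno_scores viewed as (batch_size * passage_length, 3), the row for example i's
            # gold span end is span_end[i] + i * passage_length.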
            gold_span_end_loc = []
            span_end = span_end.view(batch_size).squeeze().data.cpu().numpy()
            for i in range(batch_size):
                gold_span_end_loc.append(
                    max(span_end[i] + i * passage_length, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(_yesno, yesno.view(-1), ignore_index=-1)

            pred_span_end_loc = []
            for i in range(batch_size):
                pred_span_end_loc.append(max(loc[i], 0))
            predicted_end = span_start.new(pred_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(0,
                                                           predicted_end).view(
                                                               -1, 3)
            self._span_yesno_accuracy(_yesno, yesno.squeeze(-1))

            output_dict['loss'] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['yesno'] = yesno_predict
        return output_dict
Example #23
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                sentence_spans: torch.IntTensor = None,
                sent_labels: torch.IntTensor = None,
                evd_chain_labels: torch.IntTensor = None,
                q_type: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        if self._sent_labels_src == 'chain':
            batch_size, num_spans = sent_labels.size()
            sent_labels_mask = (sent_labels >= 0).float()
            print("chain:", evd_chain_labels)
            # Use the chain as the label to supervise the gate.
            # Only the first chain in ``evd_chain_labels`` is used for supervision;
            # at present there should only be one chain anyway.
            evd_chain_labels = evd_chain_labels[:, 0].long()
            # build the gate labels. The dim is set to 1 + num_spans to account for the end embedding
            # shape: (batch_size, 1+num_spans)
            sent_labels = sent_labels.new_zeros((batch_size, 1+num_spans))
            sent_labels.scatter_(1, evd_chain_labels, 1.)
            # remove the column for end embedding
            # shape: (batch_size, num_spans)
            sent_labels = sent_labels[:, 1:].float()
            # make the padding be -1
            sent_labels = sent_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)

        print('\nBert wordpiece size:', passage['bert'].shape)
        # bert embedding for answer prediction
        # shape: [batch_size, max_q_len, emb_size]
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=0)
        # shape: [batch_size, num_sent, max_sent_len+q_len, embedding_dim]
        embedded_passage = self._text_field_embedder(passage, num_wrapping_dims=1)
        # print('\npassage size:', embedded_passage.shape)
        #embedded_question = self._bert_projection(embedded_question)
        #embedded_passage = self._bert_projection(embedded_passage)
        #print('size embedded_passage:', embedded_passage.shape)
        # mask
        ques_mask = util.get_text_field_mask(question, num_wrapping_dims=0).float()
        context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()

        # gate prediction
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        # Shape(gate_mask): (batch_size, num_spans)
        #gate_logit, gate, pred_sent_probs = self._span_gate(spans_rep_sp, spans_mask)
        gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(embedded_passage, context_mask,
                                                                         self._gate_self_attention_layer,
                                                                         self._gate_sent_encoder)
        batch_size, num_spans, max_batch_span_width = context_mask.size()

        loss = F.nll_loss(F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
                          sent_labels.long().view(batch_size * num_spans), ignore_index=-1)

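        # Threshold the gate probabilities at 0.3 to obtain hard sentence selections.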
        gate = (gate >= 0.3).long()
        gate = gate.view(batch_size, num_spans)

        output_dict = {
            "pred_sent_labels": gate, #[B, num_span]
            "gate_probs": pred_sent_probs[:, 1].view(batch_size, num_spans), #[B, num_span]
        }
        if self._output_att_scores:
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # Compute the loss for training.
        try:
            #loss = strong_sup_loss
            self._loss_trackers['loss'](loss)
            output_dict["loss"] = loss
        except RuntimeError:
            print('\n meta_data:', metadata)

        print("sent label:")
        for b_label in np.array(sent_labels.cpu()):
            b_label = b_label == 1
            indices = np.arange(len(b_label))
            print(indices[b_label] + 1)
        # Compute the evidence-selection F1 and add the tokenized inputs and metadata to the output.
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            #token_spans_sp = []
            #token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_sent_tokens'])
                #token_spans_sp.append(metadata[i]['token_spans_sp'])
                #token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                #offsets = metadata[i]['token_offsets']
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                # Shift the sentence indices back to be 0-based.
                evd_possible_chains.append([s_idx-1 for s_idx in metadata[i]['evd_possible_chains'][0] if s_idx > 0])
                ans_sent_idxs.append([s_idx-1 for s_idx in metadata[i]['ans_sent_idxs']])
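                # Track whether the gate selected at least one of the gold answer sentences.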
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_gate = gate[i].detach().cpu().numpy()
                    if any([pred_sent_gate[s_idx-1] > 0 for s_idx in metadata[i]['ans_sent_idxs']]):
                        self.evd_ans_metric(1)
                    else:
                        self.evd_ans_metric(0)
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1), gate_mask.view(-1))
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_sent_tokens'] = passage_tokens
            #output_dict['token_spans_sp'] = token_spans_sp
            #output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['_id'] = ids

        return output_dict
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            evidence: torch.IntTensor = None,
            pad_idx=-1,
            max_select=5,
            gamma=0.95,
            teacher_forcing_ratio=1,
            features=None,
            metadata=None) -> Dict[str, torch.Tensor]:

        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        evidence : torch.IntTensor, optional (default = None)
            From a ``ListField``
        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        #print([int(i.data[0]) for i in premise['tokens'][0,0]])

        premise_mask = get_text_field_mask(premise,
                                           num_wrapping_dims=1).float()

        hypothesis_mask = get_text_field_mask(hypothesis).float()

        aggregated_input = self._sentence_selection_esim(premise,
                                                         hypothesis,
                                                         premise_mask,
                                                         hypothesis_mask,
                                                         wrap_output=True,
                                                         features=features)

        batch_size, num_evidence, max_premise_length = premise_mask.shape
        #print(premise_mask.shape)
        aggregated_input = aggregated_input.view(batch_size, num_evidence, -1)
        evidence_mask = premise_mask.sum(dim=-1).gt(0)
        evidence_len = evidence_mask.view(batch_size, -1).sum(dim=-1)
        #print(aggregated_input.shape)
        #print(evidence_len)

        #for each element in the batch
        valid_indices = []
        indices = []
        probs = []
        baselines = []
        states = []
        selected_evidence_lengths = []
        for i in range(evidence.size(0)):
            #print(label[i].data[0], evidence[i])

            gold_evidence = None
            #teacher forcing, give a list of indices and get the probabilities
            #print(label[i])
            try:
                curr_label = label[i].data[0]
            except IndexError:
                curr_label = label[i].item()

            if random.random(
            ) > teacher_forcing_ratio and curr_label != self._nei_label and float(
                    evidence[i].ne(pad_idx).sum()) > 0:
                gold_evidence = evidence[i]
            #print(gold_evidence)

            output = self._ptr_extract_summ(aggregated_input[i],
                                            max_select,
                                            evidence_mask[i],
                                            gold_evidence,
                                            beam_size=self._beam_size)
            #print(output['states'].shape)
            #print(idxs)
            states.append(output.get('states', []))

            valid_idx = []
            try:
                curr_evidence_len = evidence_len[i].data[0]
            except IndexError:
                curr_evidence_len = evidence_len[i].item()
            for idx in output['idxs'][:min(max_select, curr_evidence_len)]:
                try:
                    curr_idx = idx.view(-1).data[0]
                except IndexError:
                    curr_idx = idx.view(-1).item()

                if curr_idx == num_evidence:
                    break
                valid_idx.append(curr_idx)

                if valid_idx[-1] >= curr_evidence_len:
                    valid_idx[-1] = 0

            #TODO: if it selects none, use the first one?

            selected_evidence_lengths.append(len(valid_idx))
            #print(selected_evidence_lengths[-1])
            indices.append(valid_idx)
            if 'scores' in output:
                baselines.append(output['scores'][:len(valid_idx)])
            if 'probs' in output:
                probs.append(output['probs'][:len(valid_idx)])

            valid_indices.append(torch.LongTensor(valid_idx + \
                                             [-1]*(max_select-len(valid_idx))))
        '''
        for q in range(label.size(0)):
            if selected_evidence_lengths[q] >= 5:
                continue
            print(label[q])
            print(evidence[q])
            print(valid_indices[q])
            if len(baselines):
                print(probs[q][0].probs)            
                print(baselines[q])
        '''

        output_dict = {'predicted_sentences': torch.stack(valid_indices)}

        predictions = torch.autograd.Variable(torch.stack(valid_indices))

        selected_premise = {}
        index = predictions.unsqueeze(2).expand(batch_size, max_select,
                                                max_premise_length)
        #B x num_selected
        l = torch.autograd.Variable(
            len_mask(selected_evidence_lengths,
                     max_len=max_select,
                     dtype=torch.FloatTensor))

        index = index * l.long().unsqueeze(-1)
        if torch.cuda.is_available() and premise_mask.is_cuda:
            idx = premise_mask.get_device()
            index = index.cuda(idx)
            l = l.cuda(idx)
            predictions = predictions.cuda(idx)

        if self._use_decoder_states:
            states = torch.cat(states, dim=0)
            label_sequence = make_label_sequence(predictions,
                                                 evidence,
                                                 label,
                                                 pad_idx=pad_idx,
                                                 nei_label=self._nei_label)
            #print(states.shape)
            batch_size, max_length, _ = states.shape
            label_logits = self._entailment_esim(
                features=states.view(batch_size * max_length, 1, -1))
            if 'loss' not in output_dict:
                output_dict['loss'] = 0
            output_dict['loss'] += sequence_loss(label_logits.view(
                batch_size, max_length, -1),
                                                 label_sequence,
                                                 self._evidence_loss,
                                                 pad_idx=pad_idx)
            output_dict['label_sequence_logits'] = label_logits.view(
                batch_size, max_length, -1)
            label_logits = output_dict['label_sequence_logits'][:, -1, :]
        else:
            for key in premise:
                selected_premise[key] = torch.gather(premise[key],
                                                     dim=1,
                                                     index=index)

            selected_mask = torch.gather(premise_mask, dim=1, index=index)

            selected_mask = selected_mask * l.unsqueeze(-1)

            selected_features = None
            if features is not None:
                index = predictions.unsqueeze(2).expand(
                    batch_size, max_select, features.size(-1))
                index = index * l.long().unsqueeze(-1)
                selected_features = torch.gather(features, dim=1, index=index)

            # FIXME: temporarily truncate the selected features to the first 200 dimensions.
            selected_features = selected_features[:, :, :200]
            label_logits = self._entailment_esim(selected_premise,
                                                 hypothesis,
                                                 premise_mask=selected_mask,
                                                 features=selected_features)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        #print(label_probs[0])
        '''
        key = 'tokens'
        for q in range(premise[key].size(0)):
            print(index[q,:,0])
            print([int(i.data[0]) for i in hypothesis[key][q]])
            print([self.vocab._index_to_token[key][i.data[0]] for i in hypothesis[key][q]])
            print([int(i.data[0]) for i in premise[key][q,0]])
            print([self.vocab._index_to_token[key][i.data[0]] for i in premise[key][q,0]])
            print([self.vocab._index_to_token[key][i.data[0]] for i in premise[key][q,index[q,0,0].data[0]]])            
            print([self.vocab._index_to_token[key][i.data[0]] for i in selected_premise[key][q,0]])
        
            print([int(i.data[0]) for i in premise_mask[q,0]])
            print(l[q])
            print([int(i.data[0]) for i in premise_mask[q,index[q,0,0].data[0]]])            
            for z in range(5):
                print([int(i.data[0]) for i in selected_mask[q,z]])

            print(label[q], label_probs[q])
        '''

        output_dict.update({
            "label_logits": label_logits,
            "label_probs": label_probs
        })

        # Get the FEVER score, recall, and accuracy.

        if len(label.shape) > 1:
            self._accuracy(label_logits, label.squeeze(-1))
        else:
            self._accuracy(label_logits, label)

        fever_reward = self._fever(label_logits,
                                   label.squeeze(-1),
                                   predictions,
                                   evidence,
                                   indices=True,
                                   pad_idx=pad_idx,
                                   metadata=metadata)

        if not self._fix_sentence_extraction_params:
            # Multiply the reward for the SUPPORTS/REFUTES labels by a constant so that the model
            # learns to select the correct evidence rather than just predicting NOT ENOUGH INFO.
            fever_reward = fever_reward * label.squeeze(-1).ne(
                self._nei_label
            ) * self._ei_reward_weight + fever_reward * label.squeeze(-1).eq(
                self._nei_label)

            #compute discounted reward
            rewards = []
            #print(fever_reward[0])
            avg_reward = 0
            for i in range(evidence.size(0)):
                avg_reward += float(fever_reward[i])
                #rewards.append(gamma ** torch.range(selected_evidence_lengths[i]-1,0,-1) * float(fever_reward[i]))
                rewards.append(
                    gamma**torch.arange(selected_evidence_lengths[i]).float() *
                    fever_reward[i].float())
            #print(fever_reward[0])
            #print(rewards[0])

            reward = torch.autograd.Variable(torch.cat(rewards),
                                             requires_grad=False)
            if torch.cuda.is_available() and fever_reward.is_cuda:
                idx = fever_reward.get_device()
                reward = reward.cuda(idx)

            #print(reward)
            if len(baselines):
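                # REINFORCE with a learned baseline: standardize the rewards, compute per-action
                # advantages, and scale each action's negative log-probability by its advantage.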
                indices = list(itertools.chain(*indices))
                probs = list(itertools.chain(*probs))
                baselines = list(itertools.chain(*baselines))

                #print(baselines)

                # standardize rewards
                reward = (reward - reward.mean()) / (
                    reward.std() + float(np.finfo(np.float32).eps))

                #print(reward)
                baseline = torch.cat(baselines).squeeze()
                avg_advantage = 0
                losses = []
                for action, p, r, b in zip(indices, probs, reward, baseline):
                    #print(action, p, r, b)
                    action = torch.autograd.Variable(torch.LongTensor([action
                                                                       ]))
                    if torch.cuda.is_available() and r.is_cuda:
                        idx = r.get_device()
                        action = action.cuda(idx)

                    advantage = r - b
                    #print(r, b, advantage)
                    avg_advantage += advantage
                    losses.append(-p.log_prob(action) *
                                  (advantage / len(indices)))  # divide by T*B
                    #print(losses[-1])

                critic_loss = F.mse_loss(baseline, reward)

                output_dict['loss'] = critic_loss + sum(losses)

                #output_dict['loss'].backward(retain_graph=True)
                #grad_log = self.grad_fn()
                #print(grad_log)

                try:
                    output_dict['advantage'] = avg_advantage.data[0] / len(
                        indices)
                    output_dict['mse'] = critic_loss.data[0]
                except IndexError:
                    output_dict['advantage'] = avg_advantage.item() / len(
                        indices)
                    output_dict['mse'] = critic_loss.item()

            #output_dict['reward'] = avg_reward / evidence.size(0)

        if self.training and self._train_gold_evidence:

            if 'loss' not in output_dict:
                output_dict['loss'] = 0
            if evidence.sum() != -1 * torch.numel(evidence):
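                # Teacher-force the pointer network on the gold evidence sequence and add its
                # sequence loss, weighted by self.lambda_weight.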
                if len(evidence.shape) > 2:
                    evidence = evidence.squeeze(-1)
                #print(evidence_len.long().data.cpu().numpy().tolist())
                #print(evidence.shape, evidence_len.shape)
                #print(evidence, evidence_len)
                output = self._ptr_extract_summ(
                    aggregated_input, None, None, evidence,
                    evidence_len.long().data.cpu().numpy().tolist())
                #print(output['states'].shape)

                loss = sequence_loss(output['scores'][:, :-1, :],
                                     evidence,
                                     self._evidence_loss,
                                     pad_idx=pad_idx)

                output_dict['loss'] += self.lambda_weight * loss

        if not self._fix_entailment_params:
            if self._use_decoder_states:
                if self.training:
                    label_sequence = make_label_sequence(
                        evidence,
                        evidence,
                        label,
                        pad_idx=pad_idx,
                        nei_label=self._nei_label)
                    batch_size, max_length, _ = output['states'].shape
                    label_logits = self._entailment_esim(
                        features=output['states'][:, 1:, :].contiguous().view(
                            batch_size * (max_length - 1), 1, -1))
                    if 'loss' not in output_dict:
                        output_dict['loss'] = 0
                    #print(label_logits.shape, label_sequence.shape)
                    output_dict['loss'] += sequence_loss(label_logits.view(
                        batch_size, max_length - 1, -1),
                                                         label_sequence,
                                                         self._evidence_loss,
                                                         pad_idx=pad_idx)
            else:
                #TODO: only update classifier if we have correct evidence
                #evidence_reward = self._fever_evidence_only(label_logits, label.squeeze(-1),
                #                                            predictions, evidence,
                #                                            indices=True, pad_idx=pad_idx)
                ###print(evidence_reward)
                ###print(label)
                #mask = evidence_reward > 0
                #target = mask * label.byte() + mask.eq(0) * self._nei_label

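                # Examples whose FEVER reward equals the 2**7 sentinel are masked out of the entailment loss.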
                mask = fever_reward != 2**7
                target = label.view(-1).masked_select(mask)

                ###print(target)

                logit = label_logits.masked_select(
                    mask.unsqueeze(1).expand_as(
                        label_logits)).contiguous().view(
                            -1, label_logits.size(-1))

                loss = self._loss(
                    logit,
                    target.long())  #label_logits, label.long().view(-1))
                if 'loss' in output_dict:
                    output_dict["loss"] += self.lambda_weight * loss
                else:
                    output_dict["loss"] = self.lambda_weight * loss

        return output_dict