def test_tie_break_categorical_accuracy(self):
        """Check the accuracy reported when ``tie_break=True`` and maxima are tied.

        The expected values encode the behaviour under test: a target that is
        one of ``k`` tied maximum scores earns ``1/k`` credit (0.25 for the
        four-way tie in the first row, 0.5 for the two-way tie in the last
        row), while a target that is the unique maximum earns full credit.
        """
        accuracy = CategoricalAccuracy(tie_break=True)
        predictions = torch.Tensor([[0.35, 0.25, 0.35, 0.35, 0.35],
                                    [0.1, 0.6, 0.1, 0.2, 0.2],
                                    [0.1, 0.0, 0.1, 0.2, 0.2]])
        # Test without mask:
        targets = torch.Tensor([2, 1, 4])
        accuracy(predictions, targets)
        # The expected ratios are not exactly representable as floats, so compare
        # with a tolerance (as the sequence case below already does) rather than
        # relying on bit-for-bit equality of the accumulated counts.
        numpy.testing.assert_almost_equal(accuracy.get_metric(reset=True),
                                          (0.25 + 1 + 0.5) / 3.0)

        # Test with mask: the second row is masked out, so only two rows count.
        mask = torch.Tensor([1, 0, 1])
        targets = torch.Tensor([2, 1, 4])
        accuracy(predictions, targets, mask)
        numpy.testing.assert_almost_equal(accuracy.get_metric(reset=True),
                                          (0.25 + 0.5) / 2.0)

        # Test tie-break with sequence
        predictions = torch.Tensor([[[0.35, 0.25, 0.35, 0.35, 0.35],
                                     [0.1, 0.6, 0.1, 0.2, 0.2],
                                     [0.1, 0.0, 0.1, 0.2, 0.2]],
                                    [[0.35, 0.25, 0.35, 0.35, 0.35],
                                     [0.1, 0.6, 0.1, 0.2, 0.2],
                                     [0.1, 0.0, 0.1, 0.2, 0.2]]])
        targets = torch.Tensor([[0, 1, 3],  # 0.25 + 1 + 0.5
                                [0, 3, 4]]) # 0.25 + 0 + 0.5; both rows sum to 2.5
        accuracy(predictions, targets)
        actual_accuracy = accuracy.get_metric(reset=True)
        numpy.testing.assert_almost_equal(actual_accuracy, 2.5/6.0)
Exemple #2
0
class CoLATask(Task):
    '''Task wrapper for the Warstadt acceptability (CoLA) data.'''

    def __init__(self, path, max_seq_len, name="acceptability"):
        '''Create the two-class task, read its splits from ``path`` and set up its scorers.'''
        super(CoLATask, self).__init__(name, 2)
        # Each example is a single sentence, not a sentence pair.
        self.pair_input = 0
        self.load_data(path, max_seq_len)
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.scorer1, self.scorer2 = Average(), CategoricalAccuracy()

    def load_data(self, path, max_seq_len):
        '''Read train/dev/test TSV files found under ``path`` and store them on the task.'''
        # train.tsv and dev.tsv share one layout; test.tsv has its own
        # (no target, an index column, and one row skipped).
        labelled_layout = dict(s1_idx=3, s2_idx=None, targ_idx=1)
        train_split = load_tsv(os.path.join(path, "train.tsv"), max_seq_len,
                               **labelled_layout)
        dev_split = load_tsv(os.path.join(path, "dev.tsv"), max_seq_len,
                             **labelled_layout)
        test_split = load_tsv(os.path.join(path, 'test.tsv'), max_seq_len,
                              s1_idx=1, s2_idx=None, targ_idx=None,
                              idx_idx=0, skip_rows=1)
        self.train_data_text, self.val_data_text, self.test_data_text = \
            train_split, dev_split, test_split
        log.info("\tFinished loading CoLA.")

    def get_metrics(self, reset=False):
        '''Return both scorers' current values, optionally resetting them.'''
        # Note that 'accuracy' carries the Average scorer (scorer1); the
        # CategoricalAccuracy value (scorer2) is reported under 'acc'.
        averaged = self.scorer1.get_metric(reset)
        categorical = self.scorer2.get_metric(reset)
        return {'accuracy': averaged, 'acc': categorical}
Exemple #3
0
class LstmTagger(Model):
    """Sequence tagger: embed the tokens, encode the sequence, project each step to tag logits.

    The embedder and the encoder are injected, so the model code does not
    depend on which concrete implementations are used.
    """

    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        """Store the injected modules and build the projection layer and the accuracy metric."""
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        # Input width comes from the encoder, output width from the size of
        # the 'labels' vocabulary namespace.
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Compute tag logits for ``sentence`` and, when ``labels`` is given, the loss.

        Returns a dict that always contains ``"tag_logits"`` and additionally
        contains ``"loss"`` when ``labels`` is not ``None``. Supplying labels
        also updates the accuracy metric as a side effect.
        """
        mask = get_text_field_mask(sentence)
        embeddings = self.word_embeddings(sentence)
        encoder_out = self.encoder(embeddings, mask)
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}
        if labels is not None:
            # The same mask is handed to both the metric and the loss.
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accumulated accuracy, resetting the metric when ``reset`` is true."""
        return {"accuracy": self.accuracy.get_metric(reset)}
 def test_top_k_categorical_accuracy(self):
     """Both targets are among the two highest scores of their row, so accuracy is 1."""
     metric = CategoricalAccuracy(top_k=2)
     scores = torch.Tensor([[0.35, 0.25, 0.1, 0.1, 0.2],
                            [0.1, 0.6, 0.1, 0.2, 0.0]])
     gold = torch.Tensor([0, 3])
     metric(scores, gold)
     assert metric.get_metric() == 1.0
 def test_top_k_categorical_accuracy_respects_mask(self):
     """With the first row masked out, one of the two remaining rows is a top-2 hit."""
     metric = CategoricalAccuracy(top_k=2)
     scores = torch.Tensor([[0.35, 0.25, 0.1, 0.1, 0.2],
                            [0.1, 0.6, 0.1, 0.2, 0.0],
                            [0.1, 0.2, 0.5, 0.2, 0.0]])
     gold = torch.Tensor([0, 3, 0])
     keep = torch.Tensor([0, 1, 1])
     metric(scores, gold, keep)
     assert metric.get_metric() == 0.50
    def test_top_k_categorical_accuracy_works_for_sequences(self):
        """Top-2 accuracy over a (batch, sequence, classes) input, unmasked and masked."""
        metric = CategoricalAccuracy(top_k=2)
        # Both batch entries hold the same three rows of class scores.
        per_sequence_scores = [[0.35, 0.25, 0.1, 0.1, 0.2],
                               [0.1, 0.6, 0.1, 0.2, 0.0],
                               [0.1, 0.6, 0.1, 0.2, 0.0]]
        scores = torch.Tensor([per_sequence_scores, per_sequence_scores])
        gold = torch.Tensor([[0, 3, 4],
                             [0, 1, 4]])

        # Unmasked: four of the six positions are top-2 hits.
        metric(scores, gold)
        numpy.testing.assert_almost_equal(metric.get_metric(reset=True), 0.6666666)

        # Masked: two of the four positions that remain are top-2 hits.
        keep = torch.Tensor([[0, 1, 1],
                             [1, 0, 1]])
        metric(scores, gold, keep)
        numpy.testing.assert_almost_equal(metric.get_metric(reset=True), 0.50)
 def test_top_k_categorical_accuracy_accumulates_and_resets_correctly(self):
     """Counts accumulate across calls and are zeroed by ``get_metric(reset=True)``."""
     metric = CategoricalAccuracy(top_k=2)
     scores = torch.Tensor([[0.35, 0.25, 0.1, 0.1, 0.2],
                            [0.1, 0.6, 0.1, 0.2, 0.0]])
     # Two batches in which every target is a top-2 hit ...
     hitting_gold = torch.Tensor([0, 3])
     for _ in range(2):
         metric(scores, hitting_gold)
     # ... followed by two batches in which no target is.
     for _ in range(2):
         metric(scores, torch.Tensor([4, 4]))
     assert metric.get_metric(reset=True) == 0.50
     assert metric.correct_count == 0.0
     assert metric.total_count == 0.0
class LstmTagger(Model):
    #### One thing that might seem unusual is that we're going to pass in the embedder and the sequence encoder as constructor parameters. This allows us to experiment with different embedders and encoders without having to change the model code.
    def __init__(self,
                 #### The embedding layer is specified as an AllenNLP <code>TextFieldEmbedder</code> which represents a general way of turning tokens into tensors. (Here we know that we want to represent each unique word with a learned tensor, but using the general class allows us to easily experiment with different types of embeddings, for example <a href = "https://allennlp.org/elmo">ELMo</a>.)
                 word_embeddings: TextFieldEmbedder,
                 #### Similarly, the encoder is specified as a general <code>Seq2SeqEncoder</code> even though we know we want to use an LSTM. Again, this makes it easy to experiment with other sequence encoders, for example a Transformer.
                 encoder: Seq2SeqEncoder,
                 #### Every AllenNLP model also expects a <code>Vocabulary</code>, which contains the namespaced mappings of tokens to indices and labels to indices.
                 vocab: Vocabulary) -> None:
        #### Notice that we have to pass the vocab to the base class constructor.
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        #### The feed forward layer is not passed in as a parameter, but is constructed by us. Notice that it looks at the encoder to find the correct input dimension and looks at the vocabulary (and, in particular, at the label -> index mapping) to find the correct output dimension.
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        #### The last thing to notice is that we also instantiate a <code>CategoricalAccuracy</code> metric, which we'll use to track accuracy during each training and validation epoch.
        self.accuracy = CategoricalAccuracy()

    #### Next we need to implement <code>forward</code>, which is where the actual computation happens. Each <code>Instance</code> in your dataset will get (batched with other instances and) fed into <code>forward</code>. The <code>forward</code> method expects dicts of tensors as input, and it expects their names to be the names of the fields in your <code>Instance</code>. In this case we have a sentence field and (possibly) a labels field, so we'll construct our <code>forward</code> accordingly. Note that it returns a dict of tensors (the logits, plus the loss when labels are given), which is what the return annotation says:
    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        #### AllenNLP is designed to operate on batched inputs, but different input sequences have different lengths. Behind the scenes AllenNLP is padding the shorter inputs so that the batch has uniform shape, which means our computations need to use a mask to exclude the padding. Here we just use the utility function <code>get_text_field_mask</code>, which returns a tensor of 0s and 1s corresponding to the padded and unpadded locations.
        mask = get_text_field_mask(sentence)
        #### We start by passing the <code>sentence</code> tensor (each sentence a sequence of token ids) to the <code>word_embeddings</code> module, which converts each sentence into a sequence of embedded tensors.
        embeddings = self.word_embeddings(sentence)
        #### We next pass the embedded tensors (and the mask) to the LSTM, which produces a sequence of encoded outputs.
        encoder_out = self.encoder(embeddings, mask)
        #### Finally, we pass each encoded output tensor to the feedforward layer to produce logits corresponding to the various tags.
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}

        #### As before, the labels were optional, as we might want to run this model to make predictions on unlabeled data. If we do have labels, then we use them to update our accuracy metric and compute the "loss" that goes in our output.
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    #### We included an accuracy metric that gets updated each forward pass. That means we need to override a <code>get_metrics</code> method that pulls the data out of it. Behind the scenes, the <code>CategoricalAccuracy</code> metric is storing the number of predictions and the number of correct predictions, updating those counts during each call to forward. Each call to get_metric returns the calculated accuracy and (optionally) resets the counts, which is what allows us to track accuracy anew for each epoch.
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}
Exemple #9
0
class Task:
    '''Base class shared by all tasks.

    A task is expected to provide:
        - load_data: read a dataset from a path and create its splits
        - the dataset for training, and its size
        - validation and test

    Handled outside the task:
        - process: pad and indexify data given a mapping
        - the optimizer
    '''
    # NOTE: ``__metaclass__`` is the Python 2 way of selecting a metaclass;
    # Python 3 ignores this attribute, in which case ``@abstractmethod`` below
    # is not enforced at instantiation time.
    __metaclass__ = ABCMeta

    def __init__(self, name, n_classes):
        '''Record the task name and class count and initialise every slot to its default.'''
        self.name = name
        self.n_classes = n_classes

        # Text splits start out empty.
        self.train_data_text = None
        self.val_data_text = None
        self.test_data_text = None

        # So do the data splits and the prediction layer.
        self.train_data = None
        self.val_data = None
        self.test_data = None
        self.pred_layer = None

        # Defaults that individual tasks may override.
        self.pair_input = 1
        self.categorical = 1  # most tasks are
        self.val_metric = "%s_accuracy" % name
        self.val_metric_decreases = False
        self.scorer1 = CategoricalAccuracy()
        self.scorer2 = None

    @abstractmethod
    def load_data(self, path, max_seq_len):
        '''Load data from ``path`` and create splits; subclasses must implement this.'''
        raise NotImplementedError

    def get_metrics(self, reset=False):
        '''Return the task's metrics as a dict, passing ``reset`` on to the scorer.'''
        return {'accuracy': self.scorer1.get_metric(reset)}
Exemple #10
0
class DialogQA(Model):
    """
    This class implements modified version of BiDAF
    (with self attention and residual layer, from Clark and Gardner ACL 17 paper) model as used in
    Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog, list of question answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    max_span_length: ``int``, optional (default=30)
        Maximum token length of the output span.
    max_turn_length: ``int``, optional (default=12)
        Maximum length of an interaction.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30,
                 max_turn_length: int = 12) -> None:
        """
        Build the model's sub-modules, apply ``initializer`` to the model, and create its metrics.

        Constructor arguments that the class docstring does not describe:

        residual_encoder : ``Seq2SeqEncoder``
            Encoder applied in ``forward`` to the merged passage representation, before the
            self-attention step.
        initializer : ``InitializerApplicator``
            Called on this model once all sub-modules below have been constructed.
        marker_embedding_dim : ``int``, optional (default=10)
            Embedding size of a single previous-answer marker; only used to build embeddings
            when ``num_context_answers > 0``, but always part of the dimension check below.
        """
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        # Every attention and projection layer below is sized from the phrase layer's output width.
        self._encoding_dim = phrase_layer.get_output_dim()

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        # Projects the four concatenated encoding_dim-wide passage views back down to encoding_dim.
        self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        # The marker embeddings only exist when previous answers are used as context.
        if num_context_answers > 0:
            # One embedding per question turn index; its width is
            # marker_embedding_dim * num_context_answers.
            self._question_num_marker = torch.nn.Embedding(max_turn_length,
                                                           marker_embedding_dim * num_context_answers)
            # Embedding table for previous-answer marker ids, with
            # (num_context_answers * 4) + 1 entries.
            self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        # Projects the three concatenated encoding_dim-wide self-attention views back to encoding_dim.
        self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3,
                                                                     self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        # Start/end predictors emit one score per position; yes/no and follow-up emit three.
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3))
        # Wraps the ``_followup_lin`` module created above rather than constructing a new Linear.
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        # The phrase layer's input must be as wide as the token embedding plus the marker
        # embeddings that ``forward`` concatenates onto it.
        # NOTE(review): presumably this raises on a mismatch -- confirm against check_dimensions_match.
        check_dimensions_match(phrase_layer.get_input_dim(),
                               text_field_embedder.get_output_dim() +
                               marker_embedding_dim * num_context_answers,
                               "phrase layer input dim",
                               "embedding dim + marker dim * num context answers")

        # Runs after every sub-module above has been constructed.
        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        # NOTE(review): "gt" presumably means scored at the ground-truth span end -- confirm
        # where these two metrics are updated.
        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        # A single dropout module, reused by ``forward`` on the embeddings and encoder outputs.
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage token will have assigned 'O', except the passage tokens belongs to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the followings.
        Each of the followings is a nested list because first iterates over dialog, then questions in dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y :yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y :yes, x: not a yes/no question, n: np)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Replace the integer label ids stored under ``"yesno"`` and ``"followup"`` with the
        corresponding vocabulary tokens.

        Both entries are read as a list of lists of ids and are looked up in the
        ``"yesno_labels"`` and ``"followup_labels"`` namespaces respectively.  ``output_dict``
        is modified in place and returned.
        """
        def to_tokens(key: str, namespace: str) -> List[List[str]]:
            # Pop first so the id lists are removed from the dict before any lookup happens.
            id_lists = output_dict.pop(key)
            converted = []
            for ids in id_lists:
                converted.append([self.vocab.get_token_from_index(label_id, namespace=namespace)
                                  for label_id in ids])
            return converted

        yesno_tokens = to_tokens("yesno", "yesno_labels")
        followup_tokens = to_tokens("followup", "followup_labels")
        # Only write back once both conversions have succeeded.
        output_dict['yesno'] = yesno_tokens
        output_dict['followup'] = followup_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the current value of every metric tracked by this model.

        ``reset`` is forwarded unchanged to each metric's ``get_metric`` call.  The metrics
        are queried, and the keys inserted, in the order start / end / span accuracy,
        yes-no accuracy, follow-up accuracy and F1.
        """
        metrics: Dict[str, float] = {}
        metrics['start_acc'] = self._span_start_accuracy.get_metric(reset)
        metrics['end_acc'] = self._span_end_accuracy.get_metric(reset)
        metrics['span_acc'] = self._span_accuracy.get_metric(reset)
        metrics['yesno'] = self._span_yesno_accuracy.get_metric(reset)
        metrics['followup'] = self._span_followup_accuracy.get_metric(reset)
        metrics['f1'] = self._official_f1.get_metric(reset)
        return metrics

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      span_followup_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        """
        Greedily pick one answer span per row, plus the yes/no and follow-up classes
        predicted at that span's end token.

        Each row is scanned left to right while remembering the position of the largest
        start logit seen so far.  A candidate span runs from that position to the current
        token and is scored as ``start_logit + end_logit``; it replaces the current choice
        only if it scores strictly higher and spans at most ``max_span_length`` positions
        (``end - start <= max_span_length``).

        Parameters
        ----------
        span_start_logits, span_end_logits : ``torch.Tensor``
            Two-dimensional ``(batch_size, passage_length)`` scores.
        span_yesno_logits, span_followup_logits : ``torch.Tensor``
            Indexed as ``[row, token]`` and arg-maxed over whatever remains.
        max_span_length : ``int``
            Largest allowed difference between the end and start indices.

        Returns
        -------
        A ``(batch_size, 4)`` long tensor whose columns are
        ``(start, end, yesno_class, followup_class)``.

        Raises
        ------
        ValueError
            If either the start or the end logits are not two-dimensional.
        """
        if not (span_start_logits.dim() == 2 and span_end_logits.dim() == 2):
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        num_rows, num_tokens = span_start_logits.size()

        # Created from the logits tensor (before it is converted to numpy) via ``new_zeros``.
        result = span_start_logits.new_zeros((num_rows, 4), dtype=torch.long)

        start_scores = span_start_logits.data.cpu().numpy()
        end_scores = span_end_logits.data.cpu().numpy()
        yesno_scores = span_yesno_logits.data.cpu().numpy()
        followup_scores = span_followup_logits.data.cpu().numpy()

        for row in range(num_rows):
            top_score = -1e20
            top_start = 0  # position of the largest start logit seen so far
            picked_start, picked_end = 0, 0
            for end in range(num_tokens):
                if start_scores[row, top_start] < start_scores[row, end]:
                    top_start = end
                candidate = start_scores[row, top_start] + end_scores[row, end]
                # Over-long candidates are skipped without touching the current best.
                if candidate > top_score and end - top_start <= max_span_length:
                    picked_start, picked_end = top_start, end
                    top_score = candidate
            result[row, 0] = picked_start
            result[row, 1] = picked_end
            # The two classification heads are read at the chosen end token.
            result[row, 2] = int(np.argmax(yesno_scores[row, picked_end]))
            result[row, 3] = int(np.argmax(followup_scores[row, picked_end]))
        return result
# Metric trackers for the span-prediction head: one categorical accuracy per span
# boundary, a boolean accuracy for the (start, end) pair, and the SQuAD EM/F1 scorer.
span_start_accuracy_function = CategoricalAccuracy()
span_end_accuracy_function = CategoricalAccuracy()
span_accuracy_function = BooleanAccuracy()
squad_metrics_function = SquadEmAndF1()

# Compute the loss for training (only possible when gold span indices are supplied).
if span_start is not None:
    # Negative log-likelihood of the gold start/end positions under the masked log-softmax.
    span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
    span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
    loss = span_start_loss + span_end_loss

    # Feed the same predictions and gold indices to the metric trackers.
    span_start_accuracy_function(span_start_logits, span_start.squeeze(-1))
    span_end_accuracy_function(span_end_logits, span_end.squeeze(-1))
    span_accuracy_function(best_span, torch.stack([span_start, span_end], -1))

    span_start_accuracy = span_start_accuracy_function.get_metric()
    span_end_accuracy = span_end_accuracy_function.get_metric()
    span_accuracy = span_accuracy_function.get_metric()

    # Report each metric exactly once; the start accuracy used to be printed twice
    # while the span accuracy computed above was never shown.
    print("Loss: ", loss)
    print("span_start_accuracy: ", span_start_accuracy)
    print("span_end_accuracy: ", span_end_accuracy)
    print("span_accuracy: ", span_accuracy)
    
# Compute the EM and F1 on SQuAD and add the tokenized input to the output.
if metadata is not None:
    best_span_str = []
    question_tokens = []
    passage_tokens = []
    for i in range(batch_size):
Exemple #12
0
class ESIM(Model):
    """
    An implementation of the ESIM sequence model from `"Enhanced LSTM for Natural Language Inference"
    <https://www.semanticscholar.org/paper/Enhanced-LSTM-for-Natural-Language-Inference-Chen-Zhu/83e7654d545fbbaaf2328df365a781fb67b841b4>`_
    (Chen et al., 2017).

    The forward pass embeds and encodes both sentences, soft-aligns them through a similarity
    matrix, "enhances" every token encoding with its aligned counterpart, projects and
    re-encodes the result, pools over the sequence dimension (max and average) and finally
    classifies the concatenated pooled vectors.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Its ``labels`` namespace is used to read the number of labels.
    text_field_embedder : ``TextFieldEmbedder``
        Embeds the ``premise`` and ``hypothesis`` ``TextFields`` given to the model.
    encoder : ``Seq2SeqEncoder``
        Shared encoder applied to the embedded premise and hypothesis.
    similarity_function : ``SimilarityFunction``
        Scores every (premise word, hypothesis word) pair to build the similarity matrix.
    projection_feedforward : ``FeedForward``
        Projects the enhanced premise and hypothesis back down; its input dimension must be
        four times the encoder output dimension.
    inference_encoder : ``Seq2SeqEncoder``
        Encodes the projected premise and hypothesis before pooling.
    output_feedforward : ``FeedForward``
        Applied to the concatenated pooled premise and hypothesis vectors.
    output_logit : ``FeedForward``
        Maps the output of ``output_feedforward`` to the label logits.
    dropout : ``float``, optional (default=0.5)
        Dropout rate.  A falsy value disables both dropout layers.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Applied to the model once all sub-modules are in place.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        Passed on to the ``Model`` base class.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        similarity_function: SimilarityFunction,
        projection_feedforward: FeedForward,
        inference_encoder: Seq2SeqEncoder,
        output_feedforward: FeedForward,
        output_logit: FeedForward,
        dropout: float = 0.5,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        # Sub-modules are assigned in pipeline order.
        self._text_field_embedder = text_field_embedder
        self._encoder = encoder
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._projection_feedforward = projection_feedforward
        self._inference_encoder = inference_encoder

        # Both dropout layers share one rate and are switched off together.
        self.dropout = torch.nn.Dropout(dropout) if dropout else None
        self.rnn_input_dropout = InputVariationalDropout(dropout) if dropout else None

        self._output_feedforward = output_feedforward
        self._output_logit = output_logit

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        embedding_dim = text_field_embedder.get_output_dim()
        check_dimensions_match(embedding_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        # Every enhanced token is the concatenation of four encoder-sized vectors.
        enhanced_dim = encoder.get_output_dim() * 4
        check_dimensions_match(enhanced_dim, projection_feedforward.get_input_dim(),
                               "encoder output dim", "projection feedforward input")
        projected_dim = projection_feedforward.get_output_dim()
        check_dimensions_match(projected_dim, inference_encoder.get_input_dim(),
                               "proj feedforward output dim", "inference lstm input dim")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    @staticmethod
    def _enhance(encoded: torch.Tensor, attended: torch.Tensor) -> torch.Tensor:
        """Concatenate ``[encoded, attended, difference, product]`` along the last dimension."""
        pieces = [encoded, attended, encoded - attended, encoded * attended]
        return torch.cat(pieces, dim=-1)

    def forward(  # type: ignore
        self,
        premise: Dict[str, torch.LongTensor],
        hypothesis: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Accepted for interface compatibility; not read by this method.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            The softmax of ``label_logits`` over the last dimension.
        loss : torch.FloatTensor, optional
            The cross-entropy loss, present only when ``label`` is given.
        """
        premise_embeddings = self._text_field_embedder(premise)
        hypothesis_embeddings = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # Dropout on the encoder inputs (premise first, then hypothesis).
        if self.rnn_input_dropout:
            premise_embeddings = self.rnn_input_dropout(premise_embeddings)
            hypothesis_embeddings = self.rnn_input_dropout(hypothesis_embeddings)

        # Contextual encoding of both sentences with the shared encoder.
        premise_encoding = self._encoder(premise_embeddings, premise_mask)
        hypothesis_encoding = self._encoder(hypothesis_embeddings, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        alignment_scores = self._matrix_attention(premise_encoding, hypothesis_encoding)

        # Premise -> hypothesis attention.
        # Shape: (batch_size, premise_length, hypothesis_length)
        premise_to_hypothesis = masked_softmax(alignment_scores, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        hypothesis_summary = weighted_sum(hypothesis_encoding, premise_to_hypothesis)

        # Hypothesis -> premise attention, computed on the transposed score matrix.
        # Shape: (batch_size, hypothesis_length, premise_length)
        transposed_scores = alignment_scores.transpose(1, 2).contiguous()
        hypothesis_to_premise = masked_softmax(transposed_scores, premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        premise_summary = weighted_sum(premise_encoding, hypothesis_to_premise)

        # The "enhancement" layer.
        enhanced_premise = self._enhance(premise_encoding, hypothesis_summary)
        enhanced_hypothesis = self._enhance(hypothesis_encoding, premise_summary)

        # Project back down to the model dimension; no dropout before the projection.
        premise_projection = self._projection_feedforward(enhanced_premise)
        hypothesis_projection = self._projection_feedforward(enhanced_hypothesis)

        # Inference composition, again with dropout on the encoder inputs.
        if self.rnn_input_dropout:
            premise_projection = self.rnn_input_dropout(premise_projection)
            hypothesis_projection = self.rnn_input_dropout(hypothesis_projection)
        premise_states = self._inference_encoder(premise_projection, premise_mask)
        hypothesis_states = self._inference_encoder(hypothesis_projection, hypothesis_mask)

        # Pooling over the sequence dimension -- max and average.
        # Each pooled tensor has shape (batch_size, model_dim).
        premise_token_mask = premise_mask.unsqueeze(-1)
        hypothesis_token_mask = hypothesis_mask.unsqueeze(-1)

        # Masked positions are pushed to -1e7 before taking the max.
        premise_max = replace_masked_values(premise_states, premise_token_mask, -1e7).max(dim=1)[0]
        hypothesis_max = replace_masked_values(hypothesis_states, hypothesis_token_mask, -1e7).max(dim=1)[0]

        # Masked mean: sum of the unmasked states divided by the mask total per row.
        premise_avg = (premise_states * premise_token_mask).sum(dim=1) \
            / premise_mask.sum(dim=1, keepdim=True)
        hypothesis_avg = (hypothesis_states * hypothesis_token_mask).sum(dim=1) \
            / hypothesis_mask.sum(dim=1, keepdim=True)

        # Four pooled vectors concatenated along the feature dimension.
        pooled = torch.cat([premise_avg, premise_max, hypothesis_avg, hypothesis_max], dim=1)

        # Final MLP, with dropout on its input.
        if self.dropout:
            pooled = self.dropout(pooled)

        hidden = self._output_feedforward(pooled)
        label_logits = self._output_logit(hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            gold = label.long().view(-1)
            loss = self._loss(label_logits, gold)
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the tracked accuracy under the key ``"accuracy"``, forwarding ``reset``."""
        accuracy = self._accuracy.get_metric(reset)
        return {"accuracy": accuracy}
 def test_does_not_divide_by_zero_with_no_count(self):
     """A freshly created metric that has seen no data must report (almost) 0.0."""
     untouched_metric = CategoricalAccuracy()
     reported = untouched_metric.get_metric()
     self.assertAlmostEqual(reported, 0.0)
class DecomposableAttention(Model):
    """
    This ``Model`` implements the Decomposable Attention model described in `"A Decomposable
    Attention Model for Natural Language Inference"
    <https://www.semanticscholar.org/paper/A-Decomposable-Attention-Model-for-Natural-Languag-Parikh-T%C3%A4ckstr%C3%B6m/07a9478e87a8304fc3267fa16e83e9f3bbd98b27>`_
    by Parikh et al., 2016, with some optional enhancements before the decomposable attention
    actually happens.  Parikh's original model allowed for computing an "intra-sentence" attention
    before doing the decomposable entailment step.  We generalize this to any
    :class:`Seq2SeqEncoder` that can be applied to the premise and/or the hypothesis before
    computing entailment.

    The basic outline of this model is to get an embedded representation of each word in the
    premise and hypothesis, align words between the two, compare the aligned phrases, and make a
    final entailment decision based on this aggregated comparison.  Each step in this process uses
    a feedforward network to modify the representation.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the
        model.
    attend_feedforward : ``FeedForward``
        This feedforward network is applied to the encoded sentence representations before the
        similarity matrix is computed between words in the premise and words in the hypothesis.
    similarity_function : ``SimilarityFunction``
        This is the similarity function used when computing the similarity matrix between words in
        the premise and words in the hypothesis.
    compare_feedforward : ``FeedForward``
        This feedforward network is applied to the aligned premise and hypothesis representations,
        individually.
    aggregate_feedforward : ``FeedForward``, may be ``None``
        This final feedforward network is applied to the concatenated, summed result of the
        ``compare_feedforward`` network, and its output is used as the entailment class logits.
        ``from_params`` passes ``None`` when the configuration has no ``aggregate_feedforward``
        entry; in that case the model produces no logits, probabilities or loss and only
        returns the attention maps and the aggregated hidden representation.
    premise_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        After embedding the premise, we can optionally apply an encoder.  If this is ``None``, we
        will do nothing.
    hypothesis_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        After embedding the hypothesis, we can optionally apply an encoder.  If this is ``None``,
        we will use the ``premise_encoder`` for the encoding (doing nothing if ``premise_encoder``
        is also ``None``).
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 attend_feedforward: FeedForward,
                 similarity_function: SimilarityFunction,
                 compare_feedforward: FeedForward,
                 aggregate_feedforward: Optional[FeedForward],
                 premise_encoder: Optional[Seq2SeqEncoder] = None,
                 hypothesis_encoder: Optional[Seq2SeqEncoder] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(DecomposableAttention, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        # The attend and compare networks are applied to every token independently.
        self._attend_feedforward = TimeDistributed(attend_feedforward)
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._compare_feedforward = TimeDistributed(compare_feedforward)
        self._aggregate_feedforward = aggregate_feedforward
        self._premise_encoder = premise_encoder
        # The hypothesis falls back to the premise encoder when it has none of its own.
        self._hypothesis_encoder = hypothesis_encoder or premise_encoder

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        check_dimensions_match(text_field_embedder.get_output_dim(),
                               attend_feedforward.get_input_dim(),
                               "text field embedding dim",
                               "attend feedforward input dim")
        # ``from_params`` may hand us ``None`` here; there is no output dimension to
        # validate in that case, so only check when the network actually exists.
        if aggregate_feedforward is not None:
            check_dimensions_match(aggregate_feedforward.get_output_dim(),
                                   self._num_labels, "final output dimension",
                                   "number of labels")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        h2p_attention, p2h_attention : torch.FloatTensor
            The hypothesis-to-premise and premise-to-hypothesis attention weights.
        final_hidden : torch.FloatTensor
            The concatenated, summed comparison vectors fed to the aggregate network.
        label_logits : torch.FloatTensor, optional
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.  Only present when the model has an
            ``aggregate_feedforward``.
        label_probs : torch.FloatTensor, optional
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.  Only present when the model has an ``aggregate_feedforward``.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.  Only present when ``label`` is given and the model
            has an ``aggregate_feedforward``.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise,
                                                   projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat(
            [embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat(
            [embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        # Zero out padded positions so they do not contribute to the sum below.
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(
            hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * \
            hypothesis_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis],
                                    dim=-1)

        output_dict = {
            "h2p_attention": h2p_attention,
            "p2h_attention": p2h_attention,
            "final_hidden": aggregate_input,
        }

        if self._aggregate_feedforward is not None:
            label_logits = self._aggregate_feedforward(aggregate_input)
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
            output_dict["label_logits"] = label_logits
            output_dict["label_probs"] = label_probs

            # The loss needs the logits, so it lives inside this branch: previously a
            # label passed to a model without an aggregate network hit an unbound
            # ``label_logits`` instead of simply yielding no loss.
            if label is not None:
                loss = self._loss(label_logits, label.long().view(-1))
                self._accuracy(label_logits, label)
                output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the tracked accuracy under the key ``'accuracy'``, forwarding ``reset``."""
        return {
            'accuracy': self._accuracy.get_metric(reset),
        }

    @classmethod
    def from_params(cls, vocab: Vocabulary,
                    params: Params) -> 'DecomposableAttention':
        """
        Build the model from a ``Params`` object.

        ``text_field_embedder``, ``attend_feedforward``, ``similarity_function`` and
        ``compare_feedforward`` are required keys.  ``premise_encoder``,
        ``hypothesis_encoder`` and ``aggregate_feedforward`` are optional and become
        ``None`` when missing or empty; ``initializer`` and ``regularizer`` default to an
        empty list.  Any key left over afterwards is reported by ``params.assert_empty``.
        """
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(
            vocab, embedder_params)

        premise_encoder_params = params.pop("premise_encoder", None)
        premise_encoder = Seq2SeqEncoder.from_params(
            premise_encoder_params) if premise_encoder_params else None

        hypothesis_encoder_params = params.pop("hypothesis_encoder", None)
        hypothesis_encoder = Seq2SeqEncoder.from_params(
            hypothesis_encoder_params) if hypothesis_encoder_params else None

        attend_feedforward = FeedForward.from_params(
            params.pop('attend_feedforward'))
        similarity_function = SimilarityFunction.from_params(
            params.pop("similarity_function"))
        compare_feedforward = FeedForward.from_params(
            params.pop('compare_feedforward'))

        aggregated_params = params.pop('aggregate_feedforward', None)
        aggregate_feedforward = FeedForward.from_params(
            aggregated_params) if aggregated_params else None

        initializer = InitializerApplicator.from_params(
            params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(
            params.pop('regularizer', []))

        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   attend_feedforward=attend_feedforward,
                   similarity_function=similarity_function,
                   compare_feedforward=compare_feedforward,
                   aggregate_feedforward=aggregate_feedforward,
                   premise_encoder=premise_encoder,
                   hypothesis_encoder=hypothesis_encoder,
                   initializer=initializer,
                   regularizer=regularizer)

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Add a human-readable ``'label'`` entry and drop the internal ``'final_hidden'`` entry.

        The label is the vocabulary token (``labels`` namespace) of the arg-max of
        ``'label_probs'``.  When ``'label_probs'`` is absent no label is added.
        ``output_dict`` is modified in place and returned.
        """
        # add label to output
        if 'label_probs' in output_dict:
            # ``Tensor.numpy()`` only works on CPU tensors, so move the indices to the
            # host first; this keeps decoding working when the model runs on a GPU.
            argmax_indices = output_dict['label_probs'].max(
                dim=-1)[1].detach().cpu().numpy()
            output_dict['label'] = [
                self.vocab.get_token_from_index(x, namespace="labels")
                for x in argmax_indices
            ]
        # do not show last hidden layer (tolerate it having been removed already)
        output_dict.pop("final_hidden", None)
        return output_dict
Exemple #15
0
class TweetJointly(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model_name: str = "bert-base-uncased",
        feedforward: Optional[FeedForward] = None,
        smoothing: bool = False,
        smooth_alpha: float = 0.7,
        sentiment_task: bool = False,
        sentiment_task_weight: float = 1.0,
        sentiment_classification_with_label: bool = True,
        sentiment_seq2vec: Optional[Seq2VecEncoder] = None,
        candidate_span_task: bool = False,
        candidate_span_task_weight: float = 1.0,
        candidate_delay: int = 30000,
        candidate_span_num: int = 5,
        candidate_classification_layer_units: int = 128,
        candidate_span_extractor: Optional[SpanExtractor] = None,
        candidate_span_with_logits: bool = False,
        dropout: Optional[float] = None,
        **kwargs,
    ) -> None:
        """
        Build the span start/end head plus the optional auxiliary heads.

        The base task scores every token as span start / span end on top of a
        transformer embedder.  Two optional tasks can be switched on:

        * ``sentiment_task`` - a sentiment classifier over the output of
          ``sentiment_seq2vec`` (required when the task is enabled).
        * ``candidate_span_task`` - a classifier over ``candidate_span_num``
          candidate spans; ``forward`` only runs it once ``candidate_delay``
          instances have been seen.

        ``smoothing`` selects a ``KLDivLoss`` (used with ``smooth_alpha``)
        instead of plain cross entropy for the span loss.  Extra ``kwargs`` are
        forwarded to the parent constructor.

        Raises ``ConfigurationError`` if the sentiment task is enabled without
        a ``sentiment_seq2vec`` encoder.
        """
        super().__init__(vocab, **kwargs)

        # Model names containing "BERTweet" get the dedicated embedder.
        if "BERTweet" in transformer_model_name:
            token_embedder = TweetBertEmbedder(transformer_model_name)
        else:
            token_embedder = PretrainedTransformerEmbedder(
                transformer_model_name)
        self._text_field_embedder = BasicTextFieldEmbedder(
            {"tokens": token_embedder})

        # span start & end task: two scores (start, end) per token
        if feedforward is not None:
            self._linear_layer = feedforward
        else:
            self._linear_layer = nn.Sequential(
                nn.Linear(self._text_field_embedder.get_output_dim(), 128),
                nn.ReLU(),
                nn.Linear(128, 2),
            )
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._jaccard = Jaccard()
        self._candidate_delay = candidate_delay
        # Running count of instances seen by ``forward``.
        self._delay = 0

        self._smoothing = smoothing
        self._smooth_alpha = smooth_alpha
        self._loss = (nn.KLDivLoss(reduction="batchmean")
                      if smoothing else nn.CrossEntropyLoss())

        # sentiment task
        self._sentiment_task = sentiment_task
        if sentiment_task:
            self._sentiment_classification_accuracy = CategoricalAccuracy()
            self._sentiment_loss_log = LossLog()
            self.register_buffer("sentiment_task_weight",
                                 torch.tensor(sentiment_task_weight))
            self._sentiment_classification_with_label = (
                sentiment_classification_with_label)
            if sentiment_seq2vec is None:
                raise ConfigurationError(
                    "sentiment task is True, we need a sentiment seq2vec encoder"
                )
            self._sentiment_encoder = sentiment_seq2vec
            self._sentiment_linear = nn.Linear(
                sentiment_seq2vec.get_output_dim(),
                vocab.get_vocab_size("labels"),
            )

        # candidate span task
        self._candidate_span_task = candidate_span_task
        if candidate_span_task:
            assert candidate_span_num > 0
            assert candidate_span_task_weight > 0
            assert candidate_classification_layer_units > 0
            self._candidate_span_num = candidate_span_num
            self.register_buffer("candidate_span_task_weight",
                                 torch.tensor(candidate_span_task_weight))
            self._candidate_classification_layer_units = (
                candidate_classification_layer_units)
            self._span_classification_accuracy = CategoricalAccuracy()
            self._candidate_loss_log = LossLog()
            self._candidate_span_linear = nn.Linear(
                self._text_field_embedder.get_output_dim(),
                candidate_classification_layer_units,
            )

            if candidate_span_extractor is not None:
                self._candidate_span_extractor = candidate_span_extractor
            else:
                self._candidate_span_extractor = EndpointSpanExtractor(
                    input_dim=candidate_classification_layer_units)

            # One extra input feature when the summed start/end logits are
            # appended to each candidate's span vector.
            self._candidate_with_logits = bool(candidate_span_with_logits)
            span_vec_dim = self._candidate_span_extractor.get_output_dim()
            if self._candidate_with_logits:
                span_vec_dim = span_vec_dim + 1
            self._candidate_span_vec_linear = nn.Linear(span_vec_dim, 1)

            self._candidate_jaccard = Jaccard()

        # The base loss is only tracked separately when an auxiliary task
        # also contributes to the total loss.
        self._base_loss_log = (LossLog() if
                               (sentiment_task or candidate_span_task) else
                               None)

        self._dropout = nn.Dropout(dropout) if dropout is not None else None

    def forward(  # type: ignore
        self,
        text: Dict[str, Dict[str, torch.LongTensor]],
        sentiment: torch.IntTensor,
        text_with_sentiment: Dict[str, Dict[str, torch.LongTensor]],
        text_span: torch.IntTensor,
        selected_text_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Score span start/end positions and run the optional auxiliary tasks.

        Parameters
        ----------
        text :
            Text field without the sentiment; only embedded when the sentiment
            task is configured to classify without the label.
        sentiment :
            Gold sentiment labels, used by the sentiment task.
        text_with_sentiment :
            Text field that is embedded for span prediction.
        text_span :
            Per-instance ``(start, end)`` pair (inclusive) delimiting the
            positions that may be predicted as span start/end.
        selected_text_span :
            Optional gold ``(start, end)`` pair per instance; ``-1`` as start
            is treated as "no answer" for the accuracy metrics and the
            cross-entropy loss.  When given, losses and metrics are computed
            and ``"loss"`` is added to the output.
        metadata :
            Per-instance dicts with ``"text"``, ``"text_with_sentiment_tokens"``
            and optionally ``"selected_text"``; used to turn predicted spans
            back into strings and to update the Jaccard metrics.

        Returns
        -------
        A dict with the span logits/probabilities, ``"best_span"``,
        ``"best_span_scores"``, ``"best_span_str"`` and, depending on the
        enabled tasks, sentiment and candidate-span outputs.
        """
        # batch_size * text_length * hidden_dims
        embedded_question = self._text_field_embedder(text_with_sentiment)
        if self._dropout is not None:
            embedded_question = self._dropout(embedded_question)
        # Count the instances seen so far; the candidate span task below only
        # starts once this counter reaches ``self._candidate_delay``.
        self._delay += int(embedded_question.size(0))
        # span start & span end task
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)

        # Only positions inside ``text_span`` (inclusive) may be predicted.
        possible_answer_mask = torch.zeros_like(
            util.get_token_ids_from_text_field_tensors(
                text_with_sentiment)).bool()
        for i, (start, end) in enumerate(text_span):
            possible_answer_mask[i, start:end + 1] = True

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       possible_answer_mask,
                                                       -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     possible_answer_mask,
                                                     -1e32)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        # Score of a span = start logit + end logit at the chosen positions.
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }

        # Accumulator for the (weighted) losses of all enabled tasks.
        loss = torch.tensor(0.0).to(embedded_question.device)
        # sentiment task
        if self._sentiment_task:
            if self._sentiment_classification_with_label:
                global_context_vec = self._sentiment_encoder(embedded_question)
            else:
                embedded_only_text = self._text_field_embedder(text)
                if self._dropout is not None:
                    embedded_only_text = self._dropout(embedded_only_text)
                global_context_vec = self._sentiment_encoder(
                    embedded_only_text)
            sentiment_logits = self._sentiment_linear(global_context_vec)
            sentiment_probs = torch.softmax(sentiment_logits, dim=-1)

            self._sentiment_classification_accuracy(sentiment_probs, sentiment)
            sentiment_loss = cross_entropy(sentiment_logits, sentiment)
            self._sentiment_loss_log(sentiment_loss)
            loss.add_(self.sentiment_task_weight * sentiment_loss)

            predict_sentiment_idx = sentiment_probs.argmax(dim=-1)
            sentiment_predicts = []
            for i in predict_sentiment_idx.tolist():
                sentiment_predicts.append(
                    self.vocab.get_token_from_index(i, "labels"))
            output_dict["sentiment_logits"] = sentiment_logits
            output_dict["sentiment_probs"] = sentiment_probs
            output_dict["sentiment_predicts"] = sentiment_predicts

        # span classification
        if self._candidate_span_task and (self._delay >=
                                          self._candidate_delay):
            # shape: (batch_size, passage_length, embedding_dim)
            text_features_for_candidate = self._candidate_span_linear(
                embedded_question)
            text_features_for_candidate = torch.relu(
                text_features_for_candidate)
            with torch.no_grad():
                # batch_size * candidate_num * 2
                candidate_span = get_candidate_span(span_start_probs,
                                                    span_end_probs,
                                                    self._candidate_span_num)
                candidate_span_list = candidate_span.tolist()
                output_dict["candidate_spans"] = candidate_span_list
            if selected_text_span is not None:
                candidate_span, candidate_span_label = self.candidate_span_with_labels(
                    candidate_span, selected_text_span)
            else:
                candidate_span_label = None
            # shape: (batch_size, candidate_num, span_extractor_output_dim)
            span_feature_vec = self._candidate_span_extractor(
                text_features_for_candidate, candidate_span)

            if self._candidate_with_logits:
                candidate_span_start_logits = torch.gather(
                    span_start_logits, 1, candidate_span[:, :, 0])
                candidate_span_end_logits = torch.gather(
                    span_end_logits, 1, candidate_span[:, :, 1])
                candidate_span_sum_logits = (candidate_span_start_logits +
                                             candidate_span_end_logits)
                span_feature_vec = torch.cat(
                    (span_feature_vec, candidate_span_sum_logits.unsqueeze(2)),
                    -1)
            # batch_size * candidate_num
            # Drop only the trailing size-1 axis of the linear layer.  A bare
            # ``squeeze()`` would also remove the batch axis for a batch of
            # one (or the candidate axis for a single candidate) and break
            # the indexing below.
            span_classification_logits = self._candidate_span_vec_linear(
                span_feature_vec).squeeze(-1)
            span_classification_probs = torch.softmax(
                span_classification_logits, -1)
            output_dict[
                "span_classification_probs"] = span_classification_probs
            candidate_best_span_idx = span_classification_probs.argmax(dim=-1)
            # Index of each instance's best candidate in the flattened
            # (batch_size * candidate_num, 2) view of the candidate spans.
            view_idx = (
                candidate_best_span_idx +
                torch.arange(0, end=candidate_best_span_idx.shape[0]).to(
                    candidate_best_span_idx.device) * self._candidate_span_num)
            candidate_span_view = candidate_span.view(-1, 2)
            candidate_best_spans = candidate_span_view.index_select(
                0, view_idx)
            output_dict["candidate_best_spans"] = candidate_best_spans.tolist()

            if selected_text_span is not None:
                self._span_classification_accuracy(span_classification_probs,
                                                   candidate_span_label)
                candidate_span_loss = cross_entropy(span_classification_logits,
                                                    candidate_span_label)
                self._candidate_loss_log(candidate_span_loss)
                weighted_loss = self.candidate_span_task_weight * candidate_span_loss
                if candidate_span_loss > 1e2:
                    # Diagnostics for a suspiciously large candidate loss.
                    logger.warning("candidate loss: %s", candidate_span_loss)
                    logger.warning("span_classification_logits: %s",
                                   span_classification_logits)
                    logger.warning("candidate_span_label: %s",
                                   candidate_span_label)
                loss.add_(weighted_loss)

            candidate_best_spans = candidate_best_spans.detach().cpu().numpy()
            output_dict["best_candidate_span_str"] = []
            for metadata_entry, best_span in zip(metadata,
                                                 candidate_best_spans):
                text_with_sentiment_tokens = metadata_entry[
                    "text_with_sentiment_tokens"]
                predicted_start, predicted_end = tuple(best_span)
                # Clamp an end index that points past the token list.
                if predicted_end >= len(text_with_sentiment_tokens):
                    predicted_end = len(text_with_sentiment_tokens) - 1
                best_span_string = self.span_tokens_to_text(
                    metadata_entry["text"],
                    text_with_sentiment_tokens,
                    predicted_start,
                    predicted_end,
                )
                output_dict["best_candidate_span_str"].append(best_span_string)
                answers = metadata_entry.get("selected_text", "")
                if len(answers) > 0:
                    self._candidate_jaccard(best_span_string, answers)

        # Compute the loss for training.
        if selected_text_span is not None:
            span_start = selected_text_span[:, 0]
            span_end = selected_text_span[:, 1]
            # A start of -1 marks an instance without a gold span.
            span_mask = span_start != -1
            self._span_accuracy(
                best_spans,
                selected_text_span,
                span_mask.unsqueeze(-1).expand_as(best_spans),
            )
            if not self._smoothing:
                start_loss = cross_entropy(span_start_logits,
                                           span_start,
                                           ignore_index=-1)
                if torch.any(start_loss > 1e9):
                    logger.critical("Start loss too high (%r)", start_loss)
                    logger.critical("span_start_logits: %r", span_start_logits)
                    logger.critical("span_start: %r", span_start)
                    logger.critical("text_with_sentiment: %r",
                                    text_with_sentiment)
                    assert False

                end_loss = cross_entropy(span_end_logits,
                                         span_end,
                                         ignore_index=-1)
                if torch.any(end_loss > 1e9):
                    logger.critical("End loss too high (%r)", end_loss)
                    logger.critical("span_end_logits: %r", span_end_logits)
                    logger.critical("span_end: %r", span_end)
                    assert False
            else:
                sequence_length = span_start_logits.size(1)
                device = span_start.device
                # Smoothed targets are proportional to alpha ** distance,
                # restricted to the allowed positions and renormalised.
                log_alpha = torch.log(
                    torch.tensor(self._smooth_alpha).to(device))
                start_distance = get_sequence_distance_from_span_endpoint(
                    sequence_length, span_start)
                start_smooth_probs = torch.exp(start_distance * log_alpha)
                start_smooth_probs = start_smooth_probs * possible_answer_mask
                start_smooth_probs = start_smooth_probs / start_smooth_probs.sum(
                    -1, keepdim=True)
                # ``log_softmax`` equals ``x - log(sum(exp(x)))`` but does not
                # overflow for large logits.
                span_start_log_probs = torch.nn.functional.log_softmax(
                    span_start_logits, dim=-1)
                end_distance = get_sequence_distance_from_span_endpoint(
                    sequence_length, span_end)
                end_smooth_probs = torch.exp(end_distance * log_alpha)
                end_smooth_probs = end_smooth_probs * possible_answer_mask
                end_smooth_probs = end_smooth_probs / end_smooth_probs.sum(
                    -1, keepdim=True)
                span_end_log_probs = torch.nn.functional.log_softmax(
                    span_end_logits, dim=-1)
                start_loss = self._loss(span_start_log_probs,
                                        start_smooth_probs)
                end_loss = self._loss(span_end_log_probs, end_smooth_probs)

            span_start_end_loss = (start_loss + end_loss) / 2
            if self._base_loss_log is not None:
                self._base_loss_log(span_start_end_loss)
            loss.add_(span_start_end_loss)
            self._span_start_accuracy(span_start_logits, span_start, span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)

            output_dict["loss"] = loss

        # compute best span jaccard
        best_spans = best_spans.detach().cpu().numpy()
        output_dict["best_span_str"] = []

        for metadata_entry, best_span in zip(metadata, best_spans):
            text_with_sentiment_tokens = metadata_entry[
                "text_with_sentiment_tokens"]

            predicted_start, predicted_end = tuple(best_span)
            best_span_string = self.span_tokens_to_text(
                metadata_entry["text"],
                text_with_sentiment_tokens,
                predicted_start,
                predicted_end,
            )
            output_dict["best_span_str"].append(best_span_string)

            answers = metadata_entry.get("selected_text", "")
            if len(answers) > 0:
                self._jaccard(best_span_string, answers)

        return output_dict

    # @staticmethod
    # def candidate_span_with_labels(
    #     candidate_span: torch.Tensor, selected_text_span: torch.Tensor
    # ) -> Tuple[torch.Tensor, torch.Tensor]:
    #     correct_span_idx = (candidate_span == selected_text_span.unsqueeze(1)).prod(-1)
    #     candidate_span_adjust = torch.where(
    #         ~(correct_span_idx.unsqueeze(-1) == 1),
    #         candidate_span,
    #         selected_text_span.unsqueeze(1),
    #     )
    #     candidate_span_label = correct_span_idx.argmax(-1)
    #     return candidate_span_adjust, candidate_span_label

    @staticmethod
    def candidate_span_with_labels(
            candidate_span: torch.Tensor, selected_text_span: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Label each instance with the candidate that best matches the gold span.

        The label is the index, along the last axis, of the largest value
        returned by ``batch_span_jaccard(candidate_span, selected_text_span)``.
        The candidate spans themselves are returned unchanged.
        """
        overlap = batch_span_jaccard(candidate_span, selected_text_span)
        best_match = overlap.max(-1)
        return candidate_span, best_match.indices

    @staticmethod
    def get_candidate_span_mask(candidate_span: torch.Tensor,
                                passage_length: int) -> torch.Tensor:
        device = candidate_span.device
        batch_size, candidate_num = candidate_span.size()[:-1]
        candidate_span_mask = torch.zeros(batch_size, candidate_num,
                                          passage_length).to(device)
        for i in range(batch_size):
            for j in range(candidate_num):
                span_start, span_end = candidate_span[i][j]
                candidate_span_mask[i][j][span_start:span_end + 1] = 1
        return candidate_span_mask

    @staticmethod
    def span_tokens_to_text(source_text, tokens, span_start, span_end):
        """
        Map a token-level span back to a substring of ``source_text``.

        Each token is expected to expose ``idx`` (its character offset in
        ``source_text``, or ``None`` when it has none) and ``text``.

        Both endpoints are moved backwards to the nearest token that has an
        offset.  If no such token exists for the start, the substring starts
        at character 0; if none exists for the end (or ``span_end`` lies
        outside the token list), it runs to the end of ``source_text``.  A
        warning is logged in either case.

        Returns the selected substring with surrounding whitespace stripped.
        """
        num_tokens = len(tokens)

        predicted_start = span_start
        while predicted_start >= 0 and tokens[predicted_start].idx is None:
            predicted_start -= 1
        if predicted_start < 0:
            logger.warning(
                f"Could not map the token '{tokens[span_start].text}' at index "
                f"'{span_start}' to an offset in the original text.")
            character_start = 0
        else:
            character_start = tokens[predicted_start].idx

        # Walk backwards as well, but stop at the first token: without the
        # lower bound the index would wrap around through negative values and
        # eventually raise IndexError when no token carries an offset.
        predicted_end = span_end
        while (0 <= predicted_end < num_tokens
               and tokens[predicted_end].idx is None):
            predicted_end -= 1

        if not 0 <= predicted_end < num_tokens:
            # Look the token text up defensively: ``span_end`` itself may be
            # outside the token list, which is one way to end up here.
            if 0 <= span_end < num_tokens:
                end_token_text = tokens[span_end].text
            else:
                end_token_text = "<out of range>"
            logger.warning(
                f"Could not map the token '{end_token_text}' at index "
                f"'{span_end}' to an offset in the original text.")
            character_end = len(source_text)
        else:
            end_token = tokens[predicted_end]
            end_token_length = len(sanitize_wordpiece(end_token.text))
            if end_token.idx == 0:
                # NOTE(review): one extra character is included when the end
                # token starts at offset 0; the reason is not visible here.
                character_end = end_token.idx + end_token_length + 1
            else:
                character_end = end_token.idx + end_token_length

        best_span_string = source_text[character_start:character_end].strip()
        return best_span_string

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        jaccard = self._jaccard.get_metric(reset)
        metrics = {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "jaccard": jaccard,
        }
        if self._candidate_span_task:
            metrics[
                "candidate_span_acc"] = self._span_classification_accuracy.get_metric(
                    reset)
            metrics["candidate_jaccard"] = self._candidate_jaccard.get_metric(
                reset)
            metrics["candidate_loss"] = self._candidate_loss_log.get_metric(
                reset)
        if self._sentiment_task:
            metrics[
                "sentiment_acc"] = self._sentiment_classification_accuracy.get_metric(
                    reset)
            metrics["sentiment_loss"] = self._sentiment_loss_log.get_metric(
                reset)
        if self._base_loss_log is not None:
            metrics["base_loss"] = self._base_loss_log.get_metric(reset)
        return metrics
Exemple #16
0
class LSTMBatchNormFreezeDetGlobalNoFinalImageFull(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 option_encoder: Seq2SeqEncoder,
                 input_dropout: float = 0.3,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 ):
        """
        Set up the frozen detector, the shared span encoder and the scoring head.

        :param vocab: Vocabulary passed to the parent model.
        :param option_encoder: Sequence encoder, wrapped in ``TimeDistributed``.
        :param input_dropout: Variational dropout rate applied to the encoder
            input; dropout is disabled when this is not positive.
        :param initializer: Applied to the whole model at the end of
            construction.
        """
        super().__init__(vocab)
        hidden_dim = 512

        if input_dropout > 0:
            self.rnn_input_dropout = TimeDistributed(InputVariationalDropout(input_dropout))
        else:
            self.rnn_input_dropout = None
        self.detector = SimpleDetector(pretrained=True, average_pool=True, semantic=False,
                                       final_dim=hidden_dim)

        # freeze everything related to conv net; ``parameters()`` already
        # recurses into every submodule
        for frozen_part in (self.detector.backbone, self.detector.after_roi_align):
            for weight in frozen_part.parameters():
                weight.requires_grad = False

        self.image_BN = BatchNorm1d(hidden_dim)

        self.option_encoder = TimeDistributed(option_encoder)
        self.option_BN = torch.nn.Sequential(BatchNorm1d(hidden_dim))
        self.query_BN = torch.nn.Sequential(BatchNorm1d(hidden_dim))
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim, hidden_dim),
            torch.nn.ReLU(inplace=True),
        )
        self.final_BN = torch.nn.Sequential(BatchNorm1d(hidden_dim))
        self.final_mlp_linear = torch.nn.Sequential(torch.nn.Linear(hidden_dim, 1))
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)
    # receive redundant parameters for convenience

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """
        span_tags_fixed = torch.clamp(span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]
 
        # Add extra diminsions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
           row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)

    def embed_span(self, span, span_tags, span_mask, object_reps):
        """
        :param span: Thing that will get embed and turned into [batch_size, ..leading_dims.., L, word_dim]
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :param span_mask: [batch_size, ..leading_dims.., span_mask
        :return:
        """
        retrieved_feats = self._collect_obj_reps(span_tags, object_reps)

        span_rep = torch.cat((span['bert'], retrieved_feats), -1)
        # add recurrent dropout here
        if self.rnn_input_dropout:
            span_rep = self.rnn_input_dropout(span_rep)

        return span_rep, retrieved_feats
    def forward(self,
                images: torch.Tensor,
                objects: torch.LongTensor,
                segms: torch.Tensor,
                boxes: torch.Tensor,
                box_mask: torch.LongTensor,
                question: Dict[str, torch.Tensor],
                question_tags: torch.LongTensor,
                question_mask: torch.LongTensor,
                answers: Dict[str, torch.Tensor],
                answer_tags: torch.LongTensor,
                answer_mask: torch.LongTensor,
                metadata: List[Dict[str, Any]] = None,
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        Score every answer option against its question.

        :param images: [batch_size, 3, im_height, im_width]
        :param objects: [batch_size, max_num_objects] Padded objects
        :param segms: Padded segmentation input for the detector, trimmed like the boxes
        :param boxes:  [batch_size, max_num_objects, 4] Padded boxes
        :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK
        :param question: AllenNLP representation of the question. [batch_size, num_answers, seq_length]
        :param question_tags: A detection label for each item in the Q [batch_size, num_answers, seq_length]
        :param question_mask: Mask for the Q [batch_size, num_answers, seq_length]
        :param answers: AllenNLP representation of the answer. [batch_size, num_answers, seq_length]
        :param answer_tags: A detection label for each item in the A [batch_size, num_answers, seq_length]
        :param answer_mask: Mask for the As [batch_size, num_answers, seq_length]
        :param metadata: Ignore, this is about which dataset item we're on
        :param label: Optional, which item is valid
        :return: dict with ``"label_logits"``, ``"label_probs"`` and, when ``label`` is given, ``"loss"``
        """

        # Trim off boxes that are too long. This is an issue because DataParallel pads more zeros
        # than are needed.
        num_boxes = int(box_mask.sum(1).max().item())
        objects, box_mask, boxes, segms = (
            padded[:, :num_boxes] for padded in (objects, box_mask, boxes, segms))

        obj_reps = self.detector(images=images, boxes=boxes, box_mask=box_mask, classes=objects, segms=segms)

        def masked_mean_encoding(span, tags, mask, batch_norm):
            """Encode a span, average over its unmasked tokens and batch-normalise the result."""
            n_batch, n_options, seq_len, _ = span['bert'].shape
            embedded, _ = self.embed_span(span, tags, mask, obj_reps['obj_reps'])
            assert (embedded.shape == (n_batch, n_options, seq_len, 1280))
            encoded = self.option_encoder(embedded, mask)  # (batch_size, 4, seq_len, emb_len(512))
            encoded = replace_masked_values(encoded, mask[..., None], 0)

            # number of real tokens per sequence, (batch_size * 4, 1)
            real_lengths = torch.sum(mask, dim=-1, dtype=torch.float).view(-1, 1)

            pooled = encoded.sum(dim=2)  # (batch_size, 4, emb_len(512))
            pooled = pooled.view(n_batch * n_options, 512).div(real_lengths)
            pooled = batch_norm(pooled)  # BatchNorm runs on (batch_size * 4, emb_len(512))
            return pooled.view(n_batch, n_options, 512), n_batch, n_options

        # option part
        option_rep, _, _ = masked_mean_encoding(answers, answer_tags, answer_mask, self.option_BN)

        # query part (shares the encoder with the options, but has its own batch norm)
        query_rep, batch_size, num_options = masked_mean_encoding(
            question, question_tags, question_mask, self.query_BN)

        # fuse option and query, then score each option
        fused = torch.cat((option_rep, query_rep), -1)
        assert (fused.shape == (batch_size, num_options, 512 * 2))
        fused = self.final_mlp(fused)
        fused = self.final_BN(fused.view(batch_size * num_options, 512))
        fused = fused.view(batch_size, num_options, 512)
        logits = self.final_mlp_linear(fused).squeeze(2)
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {"label_logits": logits, "label_probs": class_probabilities}
        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            self._accuracy(logits, label)
            output_dict["loss"] = loss[None]

        return output_dict
    def get_metrics(self,reset=False):
        return {'accuracy': self._accuracy.get_metric(reset)}
Exemple #17
0
class DependencyParser(Model):
    """
    This dependency parser follows the model of
    ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
    <https://arxiv.org/abs/1611.01734>`_ .

    Word representations are generated using a bidirectional LSTM,
    followed by separate biaffine classifiers for pairs of words,
    predicting whether a directed arc exists between the two words
    and the dependency label the arc should have. Decoding can either
    be done greedily, or the optimal Minimum Spanning Tree can be
    decoded using Edmonds' algorithm by viewing the dependency tree as
    an MST on a fully connected graph, where nodes are words and edges
    are scored dependency arcs.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use to generate representations
        of tokens.
    tag_representation_dim : ``int``, required.
        The dimension of the MLPs used for dependency tag prediction.
    arc_representation_dim : ``int``, required.
        The dimension of the MLPs used for head arc prediction.
    tag_feedforward : ``FeedForward``, optional, (default = None).
        The feedforward network used to produce tag representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    arc_feedforward : ``FeedForward``, optional, (default = None).
        The feedforward network used to produce arc representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    pos_tag_embedding : ``Embedding``, optional.
        Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model.
    use_mst_decoding_for_validation : ``bool``, optional (default = True).
        Whether to use Edmonds' algorithm to find the optimal minimum spanning tree during validation.
        If false, decoding is greedy.
    dropout : ``float``, optional, (default = 0.0)
        The variational dropout applied to the output of the encoder and MLP layers.
    input_dropout : ``float``, optional, (default = 0.0)
        The dropout applied to the embedded text input.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 lemmatize_helper: LemmatizeHelper,
                 task_config: TaskConfig,
                 morpho_vector_dim: int = 0,
                 gram_val_representation_dim: int = -1,
                 lemma_representation_dim: int = -1,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """
        Wire up the encoder and the output heads of the parser.

        On top of ``encoder`` this creates head/child feedforward projections with a
        bilinear arc scorer and a bilinear tag classifier, a grammar-value classifier
        and a lemma classifier with ``len(lemmatize_helper)`` outputs.

        Arguments not listed below are described in the class docstring.

        Parameters
        ----------
        lemmatize_helper : ``LemmatizeHelper``, required
            Stored on the model; its length is the size of the lemma output layer.
        task_config : ``TaskConfig``, required
            Stored on the model. ``task_config.params.get("use_pos_tag", False)`` has to
            agree with whether ``pos_tag_embedding`` was supplied.
        morpho_vector_dim : ``int``, optional (default = 0)
            Added to the text embedder's output size when checking the encoder input size.
        gram_val_representation_dim : ``int``, optional (default = -1)
            When positive, the grammar-value head is dropout -> linear -> dropout -> linear
            with this hidden width; otherwise it is a single linear layer.
        lemma_representation_dim : ``int``, optional (default = -1)
            The same switch for the lemma head.
        """
        super(DependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.lemmatize_helper = lemmatize_helper
        self.task_config = task_config

        encoder_output_dim = encoder.get_output_dim()

        # Arc scoring: one projection for heads, an independent copy for children.
        if arc_feedforward:
            self.head_arc_feedforward = arc_feedforward
        else:
            self.head_arc_feedforward = FeedForward(
                encoder_output_dim, 1, arc_representation_dim,
                Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(
            arc_representation_dim,
            arc_representation_dim,
            use_input_biases=True)

        head_tag_count = self.vocab.get_vocab_size("head_tags")

        # Tag scoring mirrors the arc set-up.
        if tag_feedforward:
            self.head_tag_feedforward = tag_feedforward
        else:
            self.head_tag_feedforward = FeedForward(
                encoder_output_dim, 1, tag_representation_dim,
                Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(
            tag_representation_dim, tag_representation_dim, head_tag_count)

        self._pos_tag_embedding = pos_tag_embedding or None
        # The task configuration and the presence of a POS embedding must agree.
        assert self.task_config.params.get(
            "use_pos_tag", False) == (self._pos_tag_embedding is not None)

        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        # Learned vector that _parse prepends to every encoded sentence.
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        # Grammar-value head.
        gram_val_count = self.vocab.get_vocab_size("grammar_value_tags")
        if gram_val_representation_dim > 0:
            self._gram_val_output = torch.nn.Sequential(
                Dropout(dropout),
                torch.nn.Linear(encoder_output_dim,
                                gram_val_representation_dim),
                Dropout(dropout),
                torch.nn.Linear(gram_val_representation_dim, gram_val_count))
        else:
            self._gram_val_output = torch.nn.Linear(encoder_output_dim,
                                                    gram_val_count)

        # Lemma head.
        lemma_rule_count = len(lemmatize_helper)
        if lemma_representation_dim > 0:
            self._lemma_output = torch.nn.Sequential(
                Dropout(dropout),
                torch.nn.Linear(encoder_output_dim, lemma_representation_dim),
                Dropout(dropout),
                torch.nn.Linear(lemma_representation_dim, lemma_rule_count))
        else:
            self._lemma_output = torch.nn.Linear(encoder_output_dim,
                                                 lemma_rule_count)

        # Everything concatenated in forward() must add up to the encoder input size.
        encoder_input_dim = text_field_embedder.get_output_dim()
        encoder_input_dim = encoder_input_dim + morpho_vector_dim
        if pos_tag_embedding is not None:
            encoder_input_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(encoder_input_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        # Indices of the "pos" vocabulary entries listed in POS_TO_IGNORE.
        pos_vocabulary = self.vocab.get_token_to_index_vocabulary("pos")
        ignored_pos = {
            pos_tag: pos_index
            for pos_tag, pos_index in pos_vocabulary.items()
            if pos_tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(ignored_pos.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {ignored_pos}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        self._gram_val_prediction_accuracy = CategoricalAccuracy()
        self._lemma_prediction_accuracy = CategoricalAccuracy()

        initializer(self)

    @overrides
    def forward(
            self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            metadata: List[Dict[str, Any]],
            morpho_embedding: torch.FloatTensor = None,
            pos_tags: torch.LongTensor = None,
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None,
            grammar_values: torch.LongTensor = None,
            lemma_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Embed the input, run ``_parse`` and assemble the training loss for the
        configured task.

        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, passed directly to the
            ``TextFieldEmbedder`` and used to derive the padding mask.
        metadata : List[Dict[str, Any]], required
            Per-instance metadata. Every entry must have the key ``"words"``; if the
            first entry also has ``"pos"``, that key is copied to the output as well.
        morpho_embedding : ``torch.FloatTensor``, optional (default = None)
            If given, concatenated to the embedded text along the last dimension.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            Passed to ``_get_mask_for_eval`` together with the mask when gold heads
            and tags are available, i.e. only used for the attachment metric.
        head_tags : torch.LongTensor, optional (default = None)
            Gold dependency labels, shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            Gold head indices, shape ``(batch_size, sequence_length)``.
        grammar_values : torch.LongTensor, optional (default = None)
            Gold grammar-value targets. Also the input of the POS tag embedding
            when the model was built with one.
        lemma_indices : torch.LongTensor, optional (default = None)
            Gold lemmatization targets.

        Returns
        -------
        The dictionary produced by ``_parse`` (keys ``heads``, ``head_tags``,
        ``gram_vals``, ``lemmas``, ``mask``, ``arc_nll``, ``tag_nll``,
        ``grammar_nll``, ``lemma_nll``) extended with:
        loss : ``torch.FloatTensor``
            The sum of the ``*_nll`` terms selected by ``self.task_config``.
        words : ``List``
            ``meta["words"]`` for every batch element.
        pos : ``List``, optional
            ``meta["pos"]`` for every batch element, when present.

        Raises
        ------
        ConfigurationError
            If the model has a POS embedding but ``grammar_values`` is ``None``.
        AssertionError
            If the task type or the single-task model name is not recognised.
        """
        embedded_text_input = self.text_field_embedder(words)

        if morpho_embedding is not None:
            embedded_text_input = torch.cat(
                [embedded_text_input, morpho_embedding], -1)

        # NOTE(review): the POS embedding is fed ``grammar_values`` rather than
        # ``pos_tags`` - confirm this is intended.
        if grammar_values is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(grammar_values)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)

        output_dict = self._parse(embedded_text_input, mask, head_tags,
                                  head_indices, grammar_values, lemma_indices)

        # Pick the loss terms for the configured task. Unknown configurations are
        # rejected with an explicit raise (not ``assert False``) so that the check
        # also runs under ``python -O`` instead of leaving ``losses`` unbound.
        task_type = self.task_config.task_type
        if task_type == "multitask":
            losses = ["arc_nll", "tag_nll", "grammar_nll", "lemma_nll"]
        elif task_type == "single":
            model_name = self.task_config.params["model"]
            if model_name == "morphology":
                losses = ["grammar_nll"]
            elif model_name == "lemmatization":
                losses = ["lemma_nll"]
            elif model_name == "syntax":
                losses = ["arc_nll", "tag_nll"]
            else:
                raise AssertionError(
                    "Unknown model type {}".format(model_name))
        else:
            raise AssertionError("Unknown task type {}".format(task_type))

        output_dict["loss"] = sum(output_dict[loss_name]
                                  for loss_name in losses)

        if head_indices is not None and head_tags is not None:
            evaluation_mask = self._get_mask_for_eval(mask, pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(output_dict["heads"][:, 1:],
                                    output_dict["head_tags"][:, 1:],
                                    head_indices, head_tags, evaluation_mask)

        output_dict["words"] = [meta["words"] for meta in metadata]
        if metadata and "pos" in metadata[0]:
            output_dict["pos"] = [meta["pos"] for meta in metadata]

        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Turn the index tensors in ``output_dict`` into readable predictions.

        The keys ``head_tags``, ``heads``, ``gram_vals``, ``lemmas`` and ``mask`` are
        removed from ``output_dict``. Depending on ``self.task_config`` the decoded
        results are stored under ``predicted_dependencies``, ``predicted_heads``,
        ``predicted_gram_vals`` and/or ``predicted_lemmas``. ``output_dict["words"]``
        is read to apply the lemmatization rules.

        Returns
        -------
        The same ``output_dict`` object, modified in place.

        Raises
        ------
        AssertionError
            If the task type or the single-task model name is not recognised.
        """
        head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        predicted_gram_vals = output_dict.pop(
            "gram_vals").cpu().detach().numpy()
        predicted_lemmas = output_dict.pop("lemmas").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)

        assert len(head_tags) == len(heads) == len(lengths) == len(
            predicted_gram_vals) == len(predicted_lemmas)

        head_tag_labels, head_indices, decoded_gram_vals, decoded_lemmas = [], [], [], []
        per_instance = zip(heads, head_tags, lengths, predicted_gram_vals,
                           predicted_lemmas)
        for instance_index, (instance_heads, instance_tags, length, gram_vals,
                             lemmas) in enumerate(per_instance):
            words = output_dict["words"][instance_index]

            # ``heads``/``head_tags`` and the mask carry one extra leading position
            # (index 0, dropped below), whereas words, grammar values and lemmas
            # do not, hence the two different slices.
            padded_length = length.item()
            token_count = padded_length - 1

            words = words[:token_count]
            gram_vals = gram_vals[:token_count]
            lemmas = lemmas[:token_count]

            head_indices.append(list(instance_heads[1:padded_length]))
            head_tag_labels.append([
                self.vocab.get_token_from_index(label, "head_tags")
                for label in instance_tags[1:padded_length]
            ])

            decoded_gram_vals.append([
                self.vocab.get_token_from_index(gram_val, "grammar_value_tags")
                for gram_val in gram_vals
            ])

            decoded_lemmas.append([
                self.lemmatize_helper.lemmatize(word, lemmatize_rule_index)
                for word, lemmatize_rule_index in zip(words, lemmas)
            ])

        # Only expose the predictions that belong to the configured task. Unknown
        # configurations raise explicitly so the check also runs under ``python -O``.
        task_type = self.task_config.task_type
        if task_type == "multitask":
            output_dict["predicted_dependencies"] = head_tag_labels
            output_dict["predicted_heads"] = head_indices
            output_dict["predicted_gram_vals"] = decoded_gram_vals
            output_dict["predicted_lemmas"] = decoded_lemmas
        elif task_type == "single":
            model_name = self.task_config.params["model"]
            if model_name == "morphology":
                output_dict["predicted_gram_vals"] = decoded_gram_vals
            elif model_name == "lemmatization":
                output_dict["predicted_lemmas"] = decoded_lemmas
            elif model_name == "syntax":
                output_dict["predicted_dependencies"] = head_tag_labels
                output_dict["predicted_heads"] = head_indices
            else:
                raise AssertionError(
                    "Unknown model type {}".format(model_name))
        else:
            raise AssertionError("Unknown task type {}".format(task_type))

        return output_dict

    def _parse(self,
               embedded_text_input: torch.Tensor,
               mask: torch.LongTensor,
               head_tags: torch.LongTensor = None,
               head_indices: torch.LongTensor = None,
               grammar_values: torch.LongTensor = None,
               lemma_indices: torch.LongTensor = None):
        """
        Encode the embedded sentence and run every output head on it.

        The grammar-value and lemma heads read the plain encoder output. For the
        dependency heads, ``self._head_sentinel`` is prepended to the encoded
        sequence, and the mask and any gold heads/tags get a matching leading
        column. When gold heads or tags are missing, the arc/tag losses are
        computed against the model's own predictions.

        Returns
        -------
        A dictionary with the keys ``heads``, ``head_tags``, ``gram_vals``,
        ``lemmas``, ``mask`` (the extended mask), ``arc_nll``, ``tag_nll``,
        ``grammar_nll`` and ``lemma_nll``. The last two are a zero tensor when the
        corresponding gold targets are not given.
        """
        encoder_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(encoder_input, mask)

        # Token-level heads work on the encoder output without the sentinel.
        grammar_value_logits = self._gram_val_output(encoded_text)
        predicted_gram_vals = grammar_value_logits.argmax(-1)

        lemma_logits = self._lemma_output(encoded_text)
        predicted_lemmas = lemma_logits.argmax(-1)

        batch_size, _, encoding_dim = encoded_text.size()

        # Prepend the head sentinel and grow mask / gold targets by one column.
        sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        encoded_text = torch.cat([sentinel, encoded_text], 1)
        token_mask = mask.float()
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            sentinel_heads = head_indices.new_zeros(batch_size, 1)
            head_indices = torch.cat([sentinel_heads, head_indices], 1)
        if head_tags is not None:
            sentinel_tags = head_tags.new_zeros(batch_size, 1)
            head_tags = torch.cat([sentinel_tags, head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self.head_arc_feedforward(encoded_text)
        head_arc_representation = self._dropout(head_arc_representation)
        child_arc_representation = self.child_arc_feedforward(encoded_text)
        child_arc_representation = self._dropout(child_arc_representation)

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self.head_tag_feedforward(encoded_text)
        head_tag_representation = self._dropout(head_tag_representation)
        child_tag_representation = self.child_tag_feedforward(encoded_text)
        child_tag_representation = self._dropout(child_tag_representation)

        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        # Push every arc that touches a padded position towards -inf.
        minus_inf = -1e8
        padding_penalty = (1 - float_mask) * minus_inf
        attended_arcs = (attended_arcs + padding_penalty.unsqueeze(2) +
                         padding_penalty.unsqueeze(1))

        use_greedy = self.training or not self.use_mst_decoding_for_validation
        if use_greedy:
            decoder = self._greedy_decode
        else:
            decoder = self._mst_decode
        predicted_heads, predicted_head_tags = decoder(
            head_tag_representation, child_tag_representation, attended_arcs,
            mask)

        # Score against the gold parse when available, else against the prediction.
        if head_indices is not None and head_tags is not None:
            target_heads, target_tags = head_indices, head_tags
        else:
            target_heads = predicted_heads.long()
            target_tags = predicted_head_tags.long()
        arc_nll, tag_nll = self._construct_loss(
            head_tag_representation=head_tag_representation,
            child_tag_representation=child_tag_representation,
            attended_arcs=attended_arcs,
            head_indices=target_heads,
            head_tags=target_tags,
            mask=mask)

        if grammar_values is None:
            grammar_nll = torch.tensor(0.)
        else:
            grammar_nll = self._update_multiclass_prediction_metrics(
                logits=grammar_value_logits,
                targets=grammar_values,
                mask=token_mask,
                accuracy_metric=self._gram_val_prediction_accuracy)

        if lemma_indices is None:
            lemma_nll = torch.tensor(0.)
        else:
            lemma_nll = self._update_multiclass_prediction_metrics(
                logits=lemma_logits,
                targets=lemma_indices,
                mask=token_mask,
                accuracy_metric=self._lemma_prediction_accuracy,
                masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

        return {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "gram_vals": predicted_gram_vals,
            "lemmas": predicted_lemmas,
            "mask": mask,
            "arc_nll": arc_nll,
            "tag_nll": tag_nll,
            "grammar_nll": grammar_nll,
            "lemma_nll": lemma_nll,
        }

    @staticmethod
    def _update_multiclass_prediction_metrics(logits,
                                              targets,
                                              mask,
                                              accuracy_metric,
                                              masked_index=None):
        accuracy_metric(logits, targets, mask)

        logits = logits.view(-1, logits.shape[-1])
        loss = F.cross_entropy(logits, targets.view(-1), reduction='none')
        if masked_index is not None:
            mask = mask * (targets != masked_index)
        loss_mask = mask.view(-1)
        return (loss * loss_mask).sum() / loss_mask.sum()

    def _construct_loss(
            self, head_tag_representation: torch.Tensor,
            child_tag_representation: torch.Tensor,
            attended_arcs: torch.Tensor, head_indices: torch.Tensor,
            head_tags: torch.Tensor,
            mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Negative log likelihood of the given heads and head tags.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            Shape (batch_size, sequence_length, tag_representation_dim); used to
            score the dependency tags of the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            Shape (batch_size, sequence_length, tag_representation_dim); used to
            score the dependency tags of the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            Shape (batch_size, sequence_length, sequence_length); arc scores that
            are normalised into a distribution over possible heads.
        head_indices : ``torch.Tensor``, required.
            Shape (batch_size, sequence_length); the head index of every word.
        head_tags : ``torch.Tensor``, required.
            Shape (batch_size, sequence_length); the dependency label of every word.
        mask : ``torch.Tensor``, required.
            Shape (batch_size, sequence_length); marks unpadded positions.

        Returns
        -------
        arc_nll : ``torch.Tensor``
            Negative log likelihood of the arcs.
        tag_nll : ``torch.Tensor``
            Negative log likelihood of the arc tags.
        """
        mask_as_float = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        device = get_device_of(attended_arcs)

        # shape (batch_size, 1)
        batch_index = get_range_vector(batch_size, device).unsqueeze(1)

        # shape (batch_size, sequence_length, sequence_length)
        arc_log_probs = masked_log_softmax(attended_arcs, mask)
        arc_log_probs = (arc_log_probs * mask_as_float.unsqueeze(2) *
                         mask_as_float.unsqueeze(1))

        # shape (batch_size, sequence_length, num_head_tags)
        tag_logits = self._get_head_tags(head_tag_representation,
                                         child_tag_representation,
                                         head_indices)
        tag_log_probs = masked_log_softmax(tag_logits, mask.unsqueeze(-1))
        tag_log_probs = tag_log_probs * mask_as_float.unsqueeze(-1)

        # Index matrix of shape (batch_size, sequence_length): row i is 0..len-1.
        positions = get_range_vector(sequence_length, device)
        child_index = positions.view(1, sequence_length)
        child_index = child_index.expand(batch_size, sequence_length).long()

        # shape (batch_size, sequence_length): log prob of the requested head / tag.
        selected_arcs = arc_log_probs[batch_index, child_index, head_indices]
        selected_tags = tag_log_probs[batch_index, child_index, head_tags]

        # Position 0 is the symbolic ROOT token; its own head is not scored.
        selected_arcs = selected_arcs[:, 1:]
        selected_tags = selected_tags[:, 1:]

        # Unmasked positions minus one ROOT token per sequence.
        valid_positions = mask.sum() - batch_size
        normaliser = valid_positions.float()

        arc_nll = -selected_arcs.sum() / normaliser
        tag_nll = -selected_tags.sum() / normaliser
        return arc_nll, tag_nll

    def _greedy_decode(
            self, head_tag_representation: torch.Tensor,
            child_tag_representation: torch.Tensor,
            attended_arcs: torch.Tensor,
            mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        mask : ``torch.Tensor``
            A mask of shape (batch_size, sequence_length) whose zero entries mark
            padding. May be ``None``, in which case no padding is masked out.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        # The size comes from ``attended_arcs`` so that the ``mask is not None``
        # guard below is meaningful (the mask used to be dereferenced here).
        sequence_length = attended_arcs.size(1)
        attended_arcs = attended_arcs + torch.diag(
            attended_arcs.new(sequence_length).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            # ``mask == 0`` also works for boolean masks, unlike ``1 - mask``.
            padding = (mask == 0).unsqueeze(2)
            attended_arcs.masked_fill_(padding, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = attended_arcs.max(dim=2)

        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation, heads)
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags

    def _mst_decode(self, head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build the per-label arc energies for the batch and hand them to
        ``_run_mst_decoding`` together with the sequence lengths.

        For every (head, child) pair the energy is the product of the arc
        probability and the probability of each dependency label for that arc.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            Shape (batch_size, sequence_length, tag_representation_dim); used to
            score the dependency tags of every candidate arc.
        child_tag_representation : ``torch.Tensor``, required
            Shape (batch_size, sequence_length, tag_representation_dim); used to
            score the dependency tags of every candidate arc.
        attended_arcs : ``torch.Tensor``, required.
            Shape (batch_size, sequence_length, sequence_length); arc scores that
            are normalised into a distribution over possible heads.
        mask : ``torch.Tensor``, required.
            Shape (batch_size, sequence_length); its row sums are used as the
            sequence lengths and its zeros mark padding.

        Returns
        -------
        heads : ``torch.Tensor``
            Whatever ``_run_mst_decoding`` returns as decoded heads.
        head_tags : ``torch.Tensor``
            Whatever ``_run_mst_decoding`` returns as the tags of those heads.
        """
        (batch_size, sequence_length,
         tag_representation_dim) = head_tag_representation.size()

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        # Broadcast both representations to every (head, child) pair.
        pairwise_shape = [
            batch_size, sequence_length, sequence_length,
            tag_representation_dim
        ]
        pairwise_heads = head_tag_representation.unsqueeze(2).expand(
            *pairwise_shape).contiguous()
        pairwise_children = child_tag_representation.unsqueeze(1).expand(
            *pairwise_shape).contiguous()

        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(pairwise_heads,
                                                 pairwise_children)

        # The log_softmax runs over the tag dimension; pairs that involve padding
        # are not excluded here.
        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        tag_log_probs = F.log_softmax(pairwise_head_logits, dim=3)
        tag_log_probs = tag_log_probs.permute(0, 3, 1, 2)

        # Mask padded tokens, because we only want to consider actual words as heads.
        minus_inf = -1e8
        padding_penalty = (1 - mask.float()) * minus_inf
        attended_arcs = (attended_arcs + padding_penalty.unsqueeze(2) +
                         padding_penalty.unsqueeze(1))

        # Shape (batch_size, sequence_length, sequence_length)
        arc_log_probs = F.log_softmax(attended_arcs, dim=2).transpose(1, 2)

        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        # energy[i, j] is the score that i is the head of j, i.e. heads point
        # to their children.
        batch_energy = torch.exp(arc_log_probs.unsqueeze(1) + tag_log_probs)
        return self._run_mst_decoding(batch_energy, lengths)

    @staticmethod
    def _run_mst_decoding(
            batch_energy: torch.Tensor,
            lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs ``decode_mst`` independently on every instance of a batch and collects
        the resulting heads together with the best-scoring tag for each chosen arc.

        Parameters
        ----------
        batch_energy : ``torch.Tensor``, required.
            Per-instance energy whose leading (per-instance) dimension indexes the
            candidate tags; it is reduced with ``max(dim=0)`` before decoding.
        lengths : required.
            One length per instance, passed straight through to ``decode_mst``.

        Returns
        -------
        heads : ``torch.Tensor``
            The stacked head indices for every instance.
        head_tags : ``torch.Tensor``
            The stacked tag ids belonging to the selected (head, child) arcs.
        """
        all_heads = []
        all_head_tags = []
        # Decoding happens in numpy, so work on a detached CPU copy.
        cpu_energy = batch_energy.detach().cpu()
        for instance_energy, instance_length in zip(cpu_energy, lengths):
            # For every pair of positions keep only the best tag and its score.
            best_scores, best_tag_ids = instance_energy.max(dim=0)
            # The root node has to stay in the graph so that the MST contains it,
            # but no word may become the parent of the root. This is enforced by
            # zeroing the scores of all word -> ROOT edges.
            best_scores[0, :] = 0
            # The scores were modified above, so ``decode_mst`` is run without
            # labels and the tags are looked up by hand afterwards.
            predicted_heads, _ = decode_mst(best_scores.numpy(),
                                            instance_length,
                                            has_labels=False)

            # Tag of each arc picked for the maximum spanning tree.
            predicted_tags = [
                best_tag_ids[parent, child].item()
                for child, parent in enumerate(predicted_heads)
            ]
            # The head and tag of the root token carry no meaning and are not
            # necessarily identical between batched and unbatched decoding, so
            # both are pinned to zero.
            predicted_heads[0] = 0
            predicted_tags[0] = 0
            all_heads.append(predicted_heads)
            all_head_tags.append(predicted_tags)

        stacked_heads = numpy.stack(all_heads)
        stacked_head_tags = numpy.stack(all_head_tags)
        return torch.from_numpy(stacked_heads), torch.from_numpy(
            stacked_head_tags)

    def _get_head_tags(self, head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Computes tag logits for the arcs described by ``head_indices``. The heads
        may be gold or predicted ones, depending on whether the caller is
        computing the loss or running inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        num_instances = head_tag_representation.size(0)
        device = get_device_of(head_tag_representation)
        # One batch index per instance, with a trailing singleton dimension so it
        # broadcasts against ``head_indices`` in the advanced indexing below.
        batch_index = get_range_vector(num_instances, device).unsqueeze(1)

        # Advanced (integer-array) indexing: for every instance in the batch, pick
        # from the sequence dimension the representation of each word's head. See
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # Result shape: (batch_size, sequence_length, tag_representation_dim)
        gathered_heads = head_tag_representation[batch_index,
                                                 head_indices].contiguous()

        # shape (batch_size, sequence_length, num_head_tags)
        return self.tag_bilinear(gathered_heads, child_tag_representation)

    def _get_mask_for_eval(self, mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collects the attachment scores and adds the two prediction accuracies plus
        their mean with ``LAS``.

        Parameters
        ----------
        reset : ``bool``, optional (default = False)
            Forwarded to every underlying metric's ``get_metric``.

        Returns
        -------
        The dictionary returned by the attachment-score metric, extended with the
        keys ``GramValAcc``, ``LemmaAcc`` and ``MeanAcc``.
        """
        metrics = self._attachment_scores.get_metric(reset)

        gram_val_accuracy = self._gram_val_prediction_accuracy.get_metric(
            reset)
        metrics['GramValAcc'] = gram_val_accuracy

        lemma_accuracy = self._lemma_prediction_accuracy.get_metric(reset)
        metrics['LemmaAcc'] = lemma_accuracy

        # Unweighted mean of the two accuracies and the labelled attachment score.
        metrics['MeanAcc'] = (gram_val_accuracy + lemma_accuracy +
                              metrics['LAS']) / 3.
        return metrics
Exemple #18
0
class ImageTextMatchingModel(Model):
    """
    Scores (text, image) pairs: a ResNet-50 with its classification layer replaced
    by an identity produces image features, these are fed together with the token
    ids into a ``VisBertModel``, and a single-output linear layer turns the joint
    representation into a matching logit.

    The model operates in one of two modes, fixed at construction time:

    ``"binary_matching"`` (default)
        The per-pair logits of a batch are reshaped to ``(batch_size, -1)`` and
        trained with cross-entropy against ``labels``.
    ``"full_matching"``
        Every image listed in ``retrieval_file`` is loaded and preprocessed in the
        constructor. ``forward`` then scores each text against all of those images
        (without gradients) and tracks top-1/5/10/20 accuracy.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to the ``Model`` constructor.
    scibert_path : ``str``
        Directory containing ``config.json`` and ``vocab.txt`` (and
        ``pytorch_model.bin`` when ``pretrained_bert`` is set).
    pretrained : ``bool``, optional (default = True)
        Forwarded to ``torchvision.models.resnet50``.
    fusion_layer : ``int``, optional (default = 0)
        Forwarded to ``VisBertModel``.
    num_layers : ``int``, optional (default = None)
        If given, overrides ``num_hidden_layers`` of the BERT config.
    image_root : ``str``, optional (default = None)
        Directory that the ``image_id`` values of ``retrieval_file`` are relative
        to. Only read when ``full_matching`` is set.
    full_matching : ``bool``, optional (default = False)
        Selects ``"full_matching"`` mode.
    retrieval_file : ``str``, optional (default = None)
        JSON-lines file; each line must have an ``image_id`` key. Only read when
        ``full_matching`` is set.
    dropout : ``float``, optional (default = None)
        If given, overrides ``hidden_dropout_prob`` of the BERT config.
    pretrained_bert : ``bool``, optional (default = False)
        If set, weights whose key starts with ``bert.`` are loaded from
        ``pytorch_model.bin`` into the ``VisBertModel`` (non-strictly).
    tokens_namespace : ``str``, optional (default = "tokens")
        Key used to pull the id tensor out of ``token_ids`` in ``forward``.
    labels_namespace : ``str``, optional (default = "labels")
        Accepted for interface compatibility; not used by this class.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 scibert_path: str,
                 pretrained: bool = True,
                 fusion_layer: int = 0,
                 num_layers: int = None,
                 image_root: str = None,
                 full_matching: bool = False,
                 retrieval_file: str = None,
                 dropout: float = None,
                 pretrained_bert: bool = False,
                 tokens_namespace: str = "tokens",
                 labels_namespace: str = "labels"):
        super().__init__(vocab)
        # Image encoder: ResNet-50 with the final fc layer turned into a no-op so
        # that it outputs the pooled features instead of class scores.
        image_model = torchvision.models.resnet50(pretrained=pretrained)
        self.image_classifier_in_features = image_model.fc.in_features
        image_model.fc = torch.nn.Identity()
        self.image_feature_extractor = image_model

        config = BertConfig.from_json_file(
            os.path.join(scibert_path, 'config.json'))
        self.tokenizer = BertTokenizer(config=config,
                                       vocab_file=os.path.join(
                                           scibert_path, 'vocab.txt'))
        if dropout is not None:
            config.hidden_dropout_prob = dropout
        if num_layers is not None:
            config.num_hidden_layers = num_layers
        num_visual_positions = 1
        self.bert = VisBertModel(config, self.image_classifier_in_features,
                                 fusion_layer, num_visual_positions)
        num_classifier_in_features = self.bert.config.hidden_size
        self.matching_classifier = torch.nn.Linear(num_classifier_in_features,
                                                   1)
        if pretrained_bert:
            # NOTE: ``torch.load`` unpickles the file, so only point this at
            # trusted checkpoints. ``map_location="cpu"`` lets a checkpoint saved
            # on a GPU load on a CPU-only host; ``load_state_dict`` copies the
            # values into the already-allocated parameters.
            state = torch.load(os.path.join(scibert_path, 'pytorch_model.bin'),
                               map_location="cpu")
            # Keep only the ``bert.``-prefixed weights, with the prefix stripped.
            filtered_state = {
                key[5:]: value
                for key, value in state.items() if key[:5] == 'bert.'
            }
            self.bert.load_state_dict(filtered_state, strict=False)
        self.mode = "binary_matching"
        self.max_sequence_length = 512
        self.head_input_feature_dim = self.bert.config.hidden_size
        self._tokens_namespace = tokens_namespace

        if full_matching:
            expected_img_size = 224
            self.image_transform = transforms.Compose([
                transforms.Resize(expected_img_size),
                transforms.CenterCrop(expected_img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
            ])
            self.mode = "full_matching"
            self.images = []
            # Maps an image id to its row in ``self.images``.
            self.image_id_index_map = {}
            # The context manager guarantees the file handle is released even if
            # a line fails to parse or an image fails to load.
            with open(retrieval_file) as retrieval_handle:
                for line in retrieval_handle:
                    fname = json.loads(line)['image_id']
                    img = Image.open(os.path.join(image_root,
                                                  fname)).convert('RGB')
                    self.images.append(self.image_transform(img))
                    self.image_id_index_map[fname] = len(self.images) - 1
            self.images = torch.stack(self.images)
            self.top5_accuracy = CategoricalAccuracy(top_k=5)
            self.top10_accuracy = CategoricalAccuracy(top_k=10)
            self.top20_accuracy = CategoricalAccuracy(top_k=20)

        self.loss = torch.nn.CrossEntropyLoss()
        self.accuracy = CategoricalAccuracy()

    def forward(
        self,
        token_ids: Dict[str, torch.Tensor],
        segment_ids: torch.Tensor,
        images: torch.Tensor,
        image_ids: List[str],
        mask: torch.Tensor = None,
        matching_labels: torch.Tensor = None,
        labels: Dict[str, torch.Tensor] = None,
        labels_text: Dict[str, torch.Tensor] = None,
        words_tokens_map: torch.Tensor = None,
        words_to_mask: torch.Tensor = None,
        tokens_to_mask: torch.Tensor = None,
        question: List[str] = None,
        answer: List[str] = None,
        category: torch.Tensor = None,
    ):
        """
        Scores texts against images and updates the accuracy metrics.

        Parameters
        ----------
        token_ids : ``Dict[str, torch.Tensor]``
            Text-field tensors; the entry under the configured tokens namespace is
            used as the BERT input ids.
        segment_ids : ``torch.Tensor``
            Passed to BERT as ``token_type_ids``.
        images : ``torch.Tensor``
            4-dimensional image batch fed to the image feature extractor.
        image_ids : ``List[str]``
            In ``"full_matching"`` mode, the id of the correct image for each text;
            each id must be present in the retrieval file given at construction.
        mask : ``torch.Tensor``, optional (default = None)
            Attention mask; derived from ``token_ids`` when omitted.
        labels : optional (default = None)
            Cross-entropy targets for ``"binary_matching"`` mode. Ignored (and
            recomputed from ``image_ids``) in ``"full_matching"`` mode.

        The remaining parameters (``matching_labels``, ``labels_text``,
        ``words_tokens_map``, ``words_to_mask``, ``tokens_to_mask``, ``question``,
        ``answer``, ``category``) are accepted but not read.

        Returns
        -------
        A dictionary that always contains ``loss`` (a zero tensor if no loss was
        computed) and, in ``"full_matching"`` mode, also ``predictions`` (one score
        per text and retrieval image) and ``image_ids``.
        """
        if mask is not None:
            input_mask = mask
        else:
            input_mask = util.get_text_field_mask(token_ids)
        token_ids = token_ids[self._tokens_namespace]
        batch_size = min(token_ids.shape[0], images.shape[0])

        # When there are more images than texts (or vice versa), tile the smaller
        # side so that every row of text is paired with exactly one image.
        if images.shape[0] > token_ids.shape[0]:
            repeat_num = images.shape[0] // token_ids.shape[0]
            token_ids = token_ids.repeat(1, repeat_num).view(
                -1, token_ids.shape[-1])
            input_mask = input_mask.repeat(1, repeat_num).view(
                -1, input_mask.shape[-1])
            segment_ids = segment_ids.repeat(1, repeat_num).view(
                -1, segment_ids.shape[-1])
        elif token_ids.shape[0] > images.shape[0]:
            repeat_num = token_ids.shape[0] // images.shape[0]
            images = images.unsqueeze(1).repeat(1, repeat_num, 1, 1, 1).view(
                -1, images.shape[1], images.shape[2], images.shape[3])
        input_token_ids = token_ids

        visual_feats = self.image_feature_extractor(images)
        visual_feats = visual_feats.view(token_ids.shape[0],
                                         self.image_classifier_in_features)
        # All-zero position features, one row of 4 values per image.
        positions = torch.zeros(
            (visual_feats.shape[0], 4)).to(visual_feats.device).float()
        visual_inputs = (positions, visual_feats)
        sequence_encodings, joint_representation = self.bert(
            input_ids=input_token_ids,
            attention_mask=input_mask,
            token_type_ids=segment_ids,
            visual_inputs=visual_inputs)

        outputs = {'loss': torch.tensor(0.)}
        if self.mode == "binary_matching":
            joint_representation = joint_representation.view(
                batch_size, -1, joint_representation.shape[-1])
            match_predictions = self.matching_classifier(
                joint_representation).view(batch_size, -1)
            if labels is not None:
                loss = self.loss(match_predictions, labels)
                outputs['loss'] = loss
                self.accuracy(match_predictions, labels)
        if self.mode == "full_matching":
            with torch.no_grad():
                predictions = torch.zeros(
                    (batch_size, self.images.shape[0])).to(images.device)
                # Score the texts against the retrieval images in chunks to bound
                # memory use.
                sub_batch_size = 100
                for i in range(0, self.images.shape[0], sub_batch_size):
                    batch_images = self.images[i:i + sub_batch_size].to(
                        images.device)
                    # Image features are tiled as [img_0..img_k] * batch_size and
                    # the text tensors as each row repeated k times, so flattened
                    # row ``b * k + j`` pairs text ``b`` with image ``j``.
                    visual_feats = self.image_feature_extractor(batch_images)
                    visual_feats = visual_feats.view(
                        -1, self.image_classifier_in_features).repeat(
                            batch_size, 1)
                    positions = torch.zeros(
                        (visual_feats.shape[0],
                         4)).to(visual_feats.device).float()
                    visual_inputs = (positions, visual_feats)
                    batch_input_token_ids = input_token_ids.repeat(
                        1,
                        batch_images.shape[0]).view(-1,
                                                    input_token_ids.shape[-1])
                    batch_input_mask = input_mask.repeat(
                        1, batch_images.shape[0]).view(-1,
                                                       input_mask.shape[-1])
                    batch_segment_ids = segment_ids.repeat(
                        1, batch_images.shape[0]).view(-1,
                                                       segment_ids.shape[-1])
                    sequence_encodings, joint_representation = self.bert(
                        input_ids=batch_input_token_ids,
                        attention_mask=batch_input_mask,
                        token_type_ids=batch_segment_ids,
                        visual_inputs=visual_inputs)
                    joint_representation = joint_representation.view(
                        batch_size, batch_images.shape[0],
                        joint_representation.shape[-1])
                    match_predictions = self.matching_classifier(
                        joint_representation).view(batch_size,
                                                   batch_images.shape[0])
                    predictions[:, i:i +
                                batch_images.shape[0]] = match_predictions
                # Build the index targets directly as integers instead of going
                # through a float tensor and casting.
                labels = torch.tensor(
                    [self.image_id_index_map[image_id] for image_id in image_ids],
                    dtype=torch.long,
                    device=images.device)
                outputs['loss'] = self.loss(predictions, labels)
                outputs['predictions'] = predictions
                outputs['image_ids'] = image_ids
                self.accuracy(predictions, labels)
                self.top5_accuracy(predictions, labels)
                self.top10_accuracy(predictions, labels)
                self.top20_accuracy(predictions, labels)
        return outputs

    def get_metrics(self, reset: bool = False):
        """
        Returns ``accuracy`` and, in ``"full_matching"`` mode, additionally
        ``top5_accuracy``, ``top10_accuracy`` and ``top20_accuracy``. ``reset`` is
        forwarded to every metric.
        """
        metrics = {'accuracy': self.accuracy.get_metric(reset)}
        if "full_matching" in self.mode:
            metrics["top5_accuracy"] = self.top5_accuracy.get_metric(reset)
            metrics["top10_accuracy"] = self.top10_accuracy.get_metric(reset)
            metrics["top20_accuracy"] = self.top20_accuracy.get_metric(reset)
        return metrics
Exemple #19
0
class SequenceClassifier(Model):
    """
    This ``SequenceClassifier`` simply encodes a sequence of text with a stacked ``Seq2SeqEncoder``, then
    predicts a label for the sequence.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    stacked_encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and predicting the label.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 stacked_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(SequenceClassifier, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.stacked_encoder = stacked_encoder
        self.projection_layer = Linear(self.stacked_encoder.get_output_dim(),
                                       self.num_classes)

        if text_field_embedder.get_output_dim(
        ) != stacked_encoder.get_input_dim():
            raise ConfigurationError(
                "The output dimension of the text_field_embedder must match the "
                "input dimension of the phrase_encoder. Found {} and {}, "
                "respectively.".format(text_field_embedder.get_output_dim(),
                                       stacked_encoder.get_input_dim()))
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    @overrides
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        label : torch.LongTensor, optional (default = None)
            The integer gold class label of each sequence. It is flattened with ``view(-1)``
            before being handed to the loss.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            The unnormalised log probabilities of the classes for each sequence, i.e. the
            projection of the time-averaged encoder output.
        class_probabilities : torch.FloatTensor
            The softmax of ``logits`` over the class dimension.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised. Only present when ``label`` is given.

        """
        embedded_text_input = self.text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)
        encoded_text = self.stacked_encoder(embedded_text_input, mask)
        # NOTE(review): the mean runs over every timestep, including any that
        # ``mask`` marks as padding -- confirm that this is intended.
        # NOTE(review): ``.squeeze()`` drops every size-1 dimension, so a batch of
        # one loses its batch axis here -- confirm that no caller feeds a
        # single-instance batch together with ``label``.
        logits = self.projection_layer(torch.mean(encoded_text, 1).squeeze())
        # Normalise explicitly over the class (last) dimension rather than
        # relying on the deprecated implicit-dimension behaviour of softmax.
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label.squeeze(-1))

        return output_dict

    @overrides
    def forward_on_instance(self, instance: Instance,
                            cuda_device: int) -> Dict[str, numpy.ndarray]:
        """
        Takes an :class:`~allennlp.data.instance.Instance`, which typically has raw text in it,
        converts that text into arrays using this model's :class:`Vocabulary`, passes those arrays
        (with a batch dimension added) through :func:`self.forward()` and :func:`self.decode()`
        and returns the result.  Before returning the result, any ``torch.autograd.Variable``
        in the output is converted into a numpy array; other values are returned unchanged.
        """
        instance.index_fields(self.vocab)
        model_input = arrays_to_variables(instance.as_array_dict(),
                                          add_batch_dimension=True,
                                          cuda_device=cuda_device,
                                          for_training=False)
        outputs = self.decode(self.forward(**model_input))

        for name, output in list(outputs.items()):
            if isinstance(output, torch.autograd.Variable):
                output = output.data.cpu().numpy()
            outputs[name] = output
        return outputs

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Takes the argmax over ``class_probabilities`` (flattened), converts that index to its
        string label, and adds a ``"label"`` key to the dictionary with the result.
        """
        all_predictions = output_dict['class_probabilities']
        if not isinstance(all_predictions, numpy.ndarray):
            # Move to host memory first: ``.numpy()`` raises for CUDA tensors.
            all_predictions = all_predictions.data.cpu().numpy()

        # ``numpy.argmax`` without an axis works on the flattened array, so this
        # yields a single index.
        argmax_i = numpy.argmax(all_predictions)
        logger.info(argmax_i)
        label = self.vocab.get_token_from_index(argmax_i, namespace="labels")
        output_dict['label'] = label
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Returns the categorical accuracy under the key ``accuracy``."""
        return {'accuracy': self._accuracy.get_metric(reset)}

    @classmethod
    def from_params(cls, vocab: Vocabulary,
                    params: Params) -> 'SequenceClassifier':
        """
        Builds the model from ``params``, consuming the keys ``text_field_embedder``,
        ``stacked_encoder`` and, optionally, ``initializer`` and ``regularizer``.
        """
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(
            vocab, embedder_params)
        stacked_encoder = Seq2SeqEncoder.from_params(
            params.pop("stacked_encoder"))

        initializer = InitializerApplicator.from_params(
            params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(
            params.pop('regularizer', []))

        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   stacked_encoder=stacked_encoder,
                   initializer=initializer,
                   regularizer=regularizer)
Exemple #20
0
class BasicClassifier(Model):
    """
    A basic text classifier. The input text is embedded, optionally run through a
    ``Seq2SeqEncoder``, pooled into a single vector by a ``Seq2VecEncoder`` and
    finally projected into the label space by a linear layer. Without a
    ``Seq2SeqEncoder`` the embedded text goes straight to the ``Seq2VecEncoder``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the input text into a ``TextField``
    seq2seq_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        Optional Seq2Seq encoder layer for the input text.
    seq2vec_encoder : ``Seq2VecEncoder``
        Required Seq2Vec encoder layer. If `seq2seq_encoder` is provided, this encoder
        will pool its output. Otherwise, this encoder will operate directly on the output
        of the `text_field_embedder`.
    dropout : ``float``, optional (default = ``None``)
        Dropout percentage to use.
    num_labels: ``int``, optional (default = ``None``)
        Number of labels to project to in classification layer. By default, the classification layer will
        project to the size of the vocabulary namespace corresponding to labels.
    label_namespace: ``str``, optional (default = "labels")
        Vocabulary namespace corresponding to labels. By default, we use the "labels" namespace.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        If provided, will be used to initialize the model parameters.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 seq2vec_encoder: Seq2VecEncoder,
                 seq2seq_encoder: Seq2SeqEncoder = None,
                 dropout: float = None,
                 num_labels: int = None,
                 label_namespace: str = "labels",
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:

        super().__init__(vocab)
        self._text_field_embedder = text_field_embedder
        # A falsy encoder is normalised to ``None``.
        self._seq2seq_encoder = seq2seq_encoder if seq2seq_encoder else None

        self._seq2vec_encoder = seq2vec_encoder
        self._classifier_input_dim = self._seq2vec_encoder.get_output_dim()

        # No dropout layer at all when ``dropout`` is ``None`` or zero.
        self._dropout = torch.nn.Dropout(dropout) if dropout else None

        self._label_namespace = label_namespace
        # Fall back to the size of the label namespace when no explicit count is given.
        self._num_labels = (num_labels if num_labels else
                            vocab.get_vocab_size(namespace=self._label_namespace))

        self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels)
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)

    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        features = self._text_field_embedder(tokens)
        token_mask = get_text_field_mask(tokens).float()

        if self._seq2seq_encoder:
            features = self._seq2seq_encoder(features, mask=token_mask)

        # Pool the sequence into one vector per instance.
        features = self._seq2vec_encoder(features, mask=token_mask)

        if self._dropout:
            features = self._dropout(features)

        logits = self._classification_layer(features)
        output_dict = {
            "logits": logits,
            "probs": torch.nn.functional.softmax(logits, dim=-1),
        }

        if label is None:
            return output_dict

        output_dict["loss"] = self._loss(logits, label.long().view(-1))
        self._accuracy(logits, label)
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Picks the most probable class for every row of ``probs``, maps the index to
        its string label and stores the list of labels under the ``"label"`` key.
        """
        probabilities = output_dict["probs"]
        # Treat anything that is not a 2-dimensional batch as a single row.
        rows = probabilities if probabilities.dim() == 2 else probabilities.unsqueeze(0)

        index_to_label = self.vocab.get_index_to_token_vocabulary(self._label_namespace)
        predicted_labels = []
        for row_number in range(rows.shape[0]):
            best_index = rows[row_number].argmax(dim=-1).item()
            # Indices missing from the vocabulary are reported as their string form.
            predicted_labels.append(index_to_label.get(best_index, str(best_index)))

        output_dict["label"] = predicted_labels
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Returns the categorical accuracy under the key ``accuracy``."""
        return {'accuracy': self._accuracy.get_metric(reset)}
Exemple #21
0
class SyntacticEntailment(Model):
    """
    This ``Model`` implements the Decomposable Attention model with Late Fusion: the final
    encoder states that a pretrained parser produces for the premise and for the hypothesis
    are concatenated onto the compared sentence vectors before they are aggregated.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Embeds the ``premise`` and ``hypothesis`` ``TextFields`` passed to the model.
    attend_feedforward : ``FeedForward``
        Applied (time-distributed) to the encoded sentences before the premise/hypothesis
        similarity matrix is computed.
    similarity_function : ``SimilarityFunction``
        Similarity function used to build the similarity matrix between premise words and
        hypothesis words.
    compare_feedforward : ``FeedForward``
        Applied (time-distributed) to the aligned premise and hypothesis representations,
        individually.
    aggregate_feedforward : ``FeedForward``
        Applied to the concatenated, summed comparison vectors (plus the parser states); its
        output is used as the entailment class logits.
    parser_model_path : str
        Path of the pretrained parser archive, handed to ``load_archive``.
    parser_cuda_device : int
        The cuda device that the pretrained parser should be loaded on.
    freeze_parser : bool
        The parser's ``_head_sentinel`` and the parameters of all of its child modules are
        always frozen first. When this flag is ``False`` the parameters of the parser's
        ``encoder`` are made trainable again; when it is ``True`` they stay frozen.
    premise_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        If given, applied to the embedded premise before attention.
    hypothesis_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        If given, applied to the embedded hypothesis before attention. Falls back to
        ``premise_encoder`` when it is not provided.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 attend_feedforward: FeedForward,
                 similarity_function: SimilarityFunction,
                 compare_feedforward: FeedForward,
                 aggregate_feedforward: FeedForward,
                 parser_model_path: str,
                 parser_cuda_device: int,
                 freeze_parser: bool,
                 premise_encoder: Optional[Seq2SeqEncoder] = None,
                 hypothesis_encoder: Optional[Seq2SeqEncoder] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        # Sub-modules are assigned in a fixed order so parameter registration order is stable.
        self._text_field_embedder = text_field_embedder
        self._attend_feedforward = TimeDistributed(attend_feedforward)
        self._attention = LegacyMatrixAttention(similarity_function)
        self._compare_feedforward = TimeDistributed(compare_feedforward)
        self._aggregate_feedforward = aggregate_feedforward
        self._premise_encoder = premise_encoder
        # Without a dedicated hypothesis encoder the premise encoder is shared.
        self._hypothesis_encoder = (hypothesis_encoder
                                    if hypothesis_encoder else premise_encoder)

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        # Fail early on mis-wired dimensions.
        embedding_dim = text_field_embedder.get_output_dim()
        attend_input_dim = attend_feedforward.get_input_dim()
        check_dimensions_match(embedding_dim, attend_input_dim,
                               "text field embedding dim",
                               "attend feedforward input dim")
        aggregate_output_dim = aggregate_feedforward.get_output_dim()
        check_dimensions_match(aggregate_output_dim, self._num_labels,
                               "final output dimension", "number of labels")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        parser = load_archive(parser_model_path,
                              cuda_device=parser_cuda_device).model
        self._parser = parser
        # Freeze the sentinel and everything under the parser's child modules ...
        parser._head_sentinel.requires_grad = False
        for submodule in parser.children():
            for weight in submodule.parameters():
                weight.requires_grad = False
        # ... then re-enable only the parser's encoder when fine-tuning is requested.
        if not freeze_parser:
            for weight in parser.encoder.parameters():
                weight.requires_grad = True

        initializer(self)

    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            premise_tags: torch.LongTensor,
            hypothesis: Dict[str, torch.LongTensor],
            hypothesis_tags: torch.LongTensor,
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``.
        premise_tags : torch.LongTensor
            The POS tags of the premise; passed to the parser together with ``premise``.
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``.
        hypothesis_tags : torch.LongTensor
            The POS tags of the hypothesis; passed to the parser together with ``hypothesis``.
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``. When given, the loss is computed and accuracy is updated.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised. Only present when ``label`` is given.
        premise_tokens, hypothesis_tokens : list, optional
            Copied from ``metadata``, one entry per instance. Only present when ``metadata``
            is given.
        """
        premise_embedding = self._text_field_embedder(premise)
        hypothesis_embedding = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            premise_embedding = self._premise_encoder(premise_embedding,
                                                      premise_mask)
        if self._hypothesis_encoder:
            hypothesis_embedding = self._hypothesis_encoder(hypothesis_embedding,
                                                            hypothesis_mask)

        # --- Attend ---
        attend_premise = self._attend_feedforward(premise_embedding)
        attend_hypothesis = self._attend_feedforward(hypothesis_embedding)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity = self._attention(attend_premise, attend_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        premise_to_hypothesis = masked_softmax(similarity, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        hypothesis_for_premise = weighted_sum(hypothesis_embedding,
                                              premise_to_hypothesis)

        # Shape: (batch_size, hypothesis_length, premise_length)
        hypothesis_to_premise = masked_softmax(
            similarity.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        premise_for_hypothesis = weighted_sum(premise_embedding,
                                              hypothesis_to_premise)

        # --- Compare: each token next to its soft alignment, padding zeroed, summed over time ---
        premise_pairs = torch.cat((premise_embedding, hypothesis_for_premise),
                                  dim=-1)
        hypothesis_pairs = torch.cat((hypothesis_embedding, premise_for_hypothesis),
                                     dim=-1)

        premise_scores = self._compare_feedforward(premise_pairs)
        # Shape: (batch_size, compare_dim)
        premise_vector = (premise_scores * premise_mask.unsqueeze(-1)).sum(dim=1)

        hypothesis_scores = self._compare_feedforward(hypothesis_pairs)
        # Shape: (batch_size, compare_dim)
        hypothesis_vector = (hypothesis_scores
                             * hypothesis_mask.unsqueeze(-1)).sum(dim=1)

        # --- Late fusion: run the parser on both sentences ---
        # NOTE(review): the loaded parser is expected to return an ``(encoded, mask)`` pair
        # from its forward call - confirm against the archived model.
        premise_parse, premise_parse_mask = self._parser(premise, premise_tags)
        premise_parse_state = get_final_encoder_states(premise_parse,
                                                       premise_parse_mask)
        hypothesis_parse, hypothesis_parse_mask = self._parser(hypothesis,
                                                               hypothesis_tags)
        hypothesis_parse_state = get_final_encoder_states(hypothesis_parse,
                                                          hypothesis_parse_mask)

        # --- Aggregate: [premise, premise parse, hypothesis, hypothesis parse] ---
        fused = torch.cat((premise_vector, premise_parse_state,
                           hypothesis_vector, hypothesis_parse_state),
                          dim=-1)
        label_logits = self._aggregate_feedforward(fused)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {'logits': label_logits, 'label_probs': label_probs}

        if label is not None:
            gold_labels = label.long().view(-1)
            loss = self._loss(label_logits, gold_labels)
            self._accuracy(label_logits, label)
            output_dict['loss'] = loss

        if metadata is not None:
            # Pass the original tokens through for downstream inspection.
            for token_key in ('premise_tokens', 'hypothesis_tokens'):
                output_dict[token_key] = [entry[token_key] for entry in metadata]

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns the accuracy metric under the ``'accuracy'`` key; ``reset`` is forwarded to
        the metric's ``get_metric``.
        """
        accuracy = self._accuracy.get_metric(reset)
        return {'accuracy': accuracy}
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--vocab_file",
                        default='bert-base-uncased-vocab.txt',
                        type=str,
                        required=True)
    parser.add_argument("--model_file",
                        default='bert-base-uncased.tar.gz',
                        type=str,
                        required=True)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )
    parser.add_argument(
        "--predict_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the predictions will be written.")

    # Other parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json"
    )
    parser.add_argument("--test_file", default=None, type=str)
    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=2.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
        "of training.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        default=False,
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--view_id',
                        type=int,
                        default=1,
                        help="view id of multi-view co-training(two-view)")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--save_all', default=False, action='store_true')

    # Base setting
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--max_ctx', type=int, default=2)
    parser.add_argument('--task_name', type=str, default='race')
    parser.add_argument('--bert_name', type=str, default='pool-race')
    parser.add_argument('--reader_name', type=str, default='race')
    parser.add_argument('--per_eval_step', type=int, default=10000000)
    # model parameters
    parser.add_argument('--evidence_lambda', type=float, default=0.8)
    # Parameters for running labeling model
    parser.add_argument('--do_label', default=False, action='store_true')
    parser.add_argument('--sentence_id_file', nargs='*')
    parser.add_argument('--weight_threshold', type=float, default=0.0)
    parser.add_argument('--only_correct', default=False, action='store_true')
    parser.add_argument('--label_threshold', type=float, default=0.0)
    parser.add_argument('--multi_evidence', default=False, action='store_true')
    parser.add_argument('--metric', default='accuracy', type=str)
    parser.add_argument('--num_evidence', default=1, type=int)
    parser.add_argument('--power_length', default=1., type=float)
    parser.add_argument('--num_choices', default=4, type=int)

    args = parser.parse_args()

    logger = setting_logger(args.output_dir)
    logger.info('================== Program start. ========================')

    model_params = prepare_model_params(args)
    read_params = prepare_read_params(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict and not args.do_label:
        raise ValueError(
            "At least one of `do_train` or `do_predict` or `do_label` must be True."
        )

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    if args.do_train:
        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory () already exists and is not empty.")
        os.makedirs(args.output_dir, exist_ok=True)

    if args.do_predict:
        os.makedirs(args.predict_dir, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.vocab_file)

    data_reader = initialize_reader(args.reader_name)

    num_train_steps = None
    if args.do_train or args.do_label:
        train_examples = data_reader.read(input_file=args.train_file,
                                          **read_params)

        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}_{5}'.format(
            args.bert_model, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length), str(args.max_ctx), str(args.task_name))

        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except FileNotFoundError:
            train_features = data_reader.convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s",
                            cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)

        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    if args.pretrain is not None:
        logger.info('Load pretrained model from {}'.format(args.pretrain))
        model_state_dict = torch.load(args.pretrain, map_location='cuda:0')
        model = initialize_model(args.bert_name,
                                 args.model_file,
                                 state_dict=model_state_dict,
                                 **model_params)
    else:
        model = initialize_model(args.bert_name, args.model_file,
                                 **model_params)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # hack to remove pooler, which is not used
    # thus it produce None grad that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    t_total = num_train_steps if num_train_steps is not None else -1
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
        warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
                                             t_total=t_total)
        logger.info(
            f"warm up linear: warmup = {warmup_linear.warmup}, t_total = {warmup_linear.t_total}."
        )
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    # Prepare data
    eval_examples = data_reader.read(input_file=args.predict_file,
                                     **read_params)
    eval_features = data_reader.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length)

    eval_tensors = data_reader.data_to_tensors(eval_features)
    eval_data = TensorDataset(*eval_tensors)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.predict_batch_size)

    if args.do_train:

        if args.do_label:
            logger.info('Training in State Wise.')
            sentence_label_file = args.sentence_id_file
            if sentence_label_file is not None:
                for file in sentence_label_file:
                    train_features = data_reader.generate_features_sentence_ids(
                        train_features, file)
            else:
                logger.info('No sentence id supervision is found.')
        else:
            logger.info('Training in traditional way.')

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Num train total optimization steps = %d", t_total)
        logger.info("  Batch size = %d", args.predict_batch_size)
        train_loss = AverageMeter()
        best_acc = 0.0
        best_loss = 1000000
        summary_writer = SummaryWriter(log_dir=args.output_dir)
        global_step = 0
        eval_loss = AverageMeter()
        eval_accuracy = CategoricalAccuracy()
        eval_epoch = 0

        train_tensors = data_reader.data_to_tensors(train_features)
        train_data = TensorDataset(*train_tensors)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        for epoch in range(int(args.num_train_epochs)):
            logger.info(f'Running at Epoch {epoch}')
            # Train
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Iteration",
                         dynamic_ncols=True)):
                model.train()
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering it-self
                inputs = data_reader.generate_inputs(
                    batch, train_features, model_state=ModelState.Train)
                model_output = model(**inputs)
                loss = model_output['loss']
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used and handles this automatically
                    if args.fp16:
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(
                            global_step)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                        summary_writer.add_scalar('lr', lr_this_step,
                                                  global_step)
                    else:
                        summary_writer.add_scalar('lr',
                                                  optimizer.get_lr()[0],
                                                  global_step)

                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                    train_loss.update(loss.item(), 1)
                    summary_writer.add_scalar('train_loss', train_loss.avg,
                                              global_step)
                    # logger.info(f'Train loss: {train_loss.avg}')

                if (step + 1) % args.per_eval_step == 0 or step == len(
                        train_dataloader) - 1:
                    # Evaluation
                    model.eval()
                    logger.info("Start evaluating")
                    for _, eval_batch in enumerate(
                            tqdm(eval_dataloader,
                                 desc="Evaluating",
                                 dynamic_ncols=True)):
                        if n_gpu == 1:
                            eval_batch = batch_to_device(
                                eval_batch,
                                device)  # multi-gpu does scattering it-self
                        inputs = data_reader.generate_inputs(
                            eval_batch,
                            eval_features,
                            model_state=ModelState.Evaluate)
                        with torch.no_grad():
                            output_dict = model(**inputs)
                            loss, choice_logits = output_dict[
                                'loss'], output_dict['choice_logits']
                            eval_loss.update(loss.item(), 1)
                            eval_accuracy(choice_logits, inputs["labels"])

                    eval_epoch_loss = eval_loss.avg
                    summary_writer.add_scalar('eval_loss', eval_epoch_loss,
                                              eval_epoch)
                    eval_loss.reset()
                    current_acc = eval_accuracy.get_metric(reset=True)
                    summary_writer.add_scalar('eval_acc', current_acc,
                                              eval_epoch)
                    torch.cuda.empty_cache()

                    if args.save_all:
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(
                            args.output_dir, f"pytorch_model_{eval_epoch}.bin")
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)

                    if current_acc > best_acc:
                        best_acc = current_acc
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(
                            args.output_dir, "pytorch_model.bin")
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                    if eval_epoch_loss < best_loss:
                        best_loss = eval_epoch_loss
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(
                            args.output_dir, "pytorch_loss_model.bin")
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)

                    logger.info(
                        'Eval Epoch: %d, Accuracy: %.4f (Best Accuracy: %.4f)'
                        % (eval_epoch, current_acc, best_acc))
                    eval_epoch += 1
            logger.info(
                f'Epoch {epoch}: Accuracy: {best_acc}, Train Loss: {train_loss.avg}'
            )
        summary_writer.close()

    for output_model_name in ["pytorch_model.bin", "pytorch_loss_model.bin"]:
        # Loading trained model
        output_model_file = os.path.join(args.output_dir, output_model_name)
        model_state_dict = torch.load(output_model_file, map_location='cuda:0')
        model = initialize_model(args.bert_name,
                                 args.model_file,
                                 state_dict=model_state_dict,
                                 **model_params)
        model.to(device)

        # Write Yes/No predictions
        if args.do_predict and (args.local_rank == -1
                                or torch.distributed.get_rank() == 0):

            test_examples = data_reader.read(args.test_file)
            test_features = data_reader.convert_examples_to_features(
                test_examples, tokenizer, args.max_seq_length)

            test_tensors = data_reader.data_to_tensors(test_features)
            test_data = TensorDataset(*test_tensors)
            test_sampler = SequentialSampler(test_data)
            test_dataloader = DataLoader(test_data,
                                         sampler=test_sampler,
                                         batch_size=args.predict_batch_size)

            logger.info("***** Running predictions *****")
            logger.info("  Num orig examples = %d", len(test_examples))
            logger.info("  Num split examples = %d", len(test_features))
            logger.info("  Batch size = %d", args.predict_batch_size)

            model.eval()
            all_results = []
            test_acc = CategoricalAccuracy()
            logger.info("Start predicting yes/no on Dev set.")
            for batch in tqdm(test_dataloader, desc="Testing"):
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering it-self
                inputs = data_reader.generate_inputs(
                    batch, test_features, model_state=ModelState.Evaluate)
                with torch.no_grad():
                    batch_choice_logits = model(**inputs)['choice_logits']
                    test_acc(batch_choice_logits, inputs['labels'])
                example_indices = batch[-1]
                for i, example_index in enumerate(example_indices):
                    choice_logits = batch_choice_logits[i].detach().cpu(
                    ).tolist()

                    test_feature = test_features[example_index.item()]
                    unique_id = int(test_feature.unique_id)

                    all_results.append(
                        RawResultChoice(unique_id=unique_id,
                                        choice_logits=choice_logits))

            if "loss" in output_model_name:
                logger.info(
                    'Predicting question choice on test set using model with lowest loss on validation set.'
                )
                output_prediction_file = os.path.join(args.predict_dir,
                                                      'loss_predictions.json')
            else:
                logger.info(
                    'Predicting question choice on test set using model with best accuracy on validation set,'
                )
                output_prediction_file = os.path.join(args.predict_dir,
                                                      'predictions.json')
            data_reader.write_predictions(test_examples, test_features,
                                          all_results, output_prediction_file)
            logger.info(
                f"Accuracy on Test set: {test_acc.get_metric(reset=True)}")

    # Loading trained model.
    if args.metric == 'accuracy':
        logger.info("Load model with best accuracy on validation set.")
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    elif args.metric == 'loss':
        logger.info("Load model with lowest loss on validation set.")
        output_model_file = os.path.join(args.output_dir,
                                         "pytorch_loss_model.bin")
    else:
        raise RuntimeError(
            f"Wrong metric type for {args.metric}, which must be in ['accuracy', 'loss']."
        )
    model_state_dict = torch.load(output_model_file, map_location='cuda:0')
    model = initialize_model(args.bert_name,
                             args.model_file,
                             state_dict=model_state_dict,
                             **model_params)
    model.to(device)

    # Labeling sentence id.
    if args.do_label and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):

        f = open('debug_log.txt', 'w')

        def softmax(x):
            """Return the softmax of ``x`` computed over all of its elements.

            The global maximum is subtracted before exponentiating, which
            keeps the exponentials from overflowing without changing the result.
            """
            shifted = x - np.max(x)
            weights = np.exp(shifted)
            return weights / weights.sum()

        def beam_search(sentence_sim, beam_num=10):
            """
            Beam search over sentence indices for a single choice.

            The scores are left-padded with one extra entry (value 0) so that
            index 0 acts as a "stop" symbol and index ``i`` (i >= 1) stands for
            original sentence ``i - 1``. Hypotheses are extended until the best
            one ends with the stop symbol.

            Reads ``args.num_evidence`` and ``args.power_length`` and writes the
            final score to the debug file ``f``; all three come from the
            enclosing scope.

            :param sentence_sim: 1-D numpy array of sentence scores
                (presumably already truncated to the valid sentences by the caller).
            :param beam_num: number of hypotheses kept after each expansion step.
            :return: the best hypothesis as a dict with keys ``'sim'`` (the padded
                scores with chosen entries set to -1e15), ``'sentences'`` (chosen
                indices into the padded array, ending with 0) and ``'value'``
                (the length-normalised log score).
            """
            max_length = args.num_evidence
            # Prepend the stop symbol (score 0) at index 0.
            sentence_sim = np.pad(sentence_sim, (1, 0),
                                  'constant',
                                  constant_values=(0, ))
            sentences = [{'sim': sentence_sim, 'sentences': [], 'value': 0.}]
            # Continue until the top-ranked hypothesis has emitted the stop symbol.
            # NOTE(review): for an empty `sentence_sim` no candidate survives the
            # first expansion, so `sentences` becomes [] and `sentences[0]` raises
            # IndexError - confirm callers never pass an empty array.
            while sentences[0][
                    'sentences'] == [] or sentences[0]['sentences'][-1] != 0:
                new_sentences = []
                for sentence in sentences:
                    # Finished hypotheses are carried over unchanged.
                    if sentence['sentences'] != [] and sentence['sentences'][
                            -1] == 0:
                        new_sentences.append(sentence)
                        continue
                    # Distribution over the remaining entries; already chosen
                    # entries were set to -1e15 below.
                    scores = softmax(sentence['sim'])
                    for i in range(len(sentence['sim'])):
                        # Stopping is not allowed as the very first step.
                        if i == 0 and sentence['sentences'] == []:
                            continue
                        if len(sentence['sentences']) > max_length:
                            continue
                        # Once `max_length` entries are chosen, only stopping is allowed.
                        if len(sentence['sentences']) == max_length and i != 0:
                            continue
                        # Never pick the same index twice.
                        if i in sentence['sentences']:
                            continue
                        # With a single evidence sentence the stop step does not
                        # change the score; otherwise add the step's log-probability.
                        if max_length == 1 and i == 0:
                            value = sentence['value']
                        else:
                            value = sentence['value'] + np.log(scores[i])
                        # `i - 1` refers to original sentence id
                        new_sentence = {
                            'sim': np.copy(sentence['sim']),
                            'sentences': sentence['sentences'] + [i],
                            'value': value
                        }
                        # Exclude the chosen index from later softmax steps.
                        new_sentence['sim'][i] = -1e15
                        new_sentences.append(new_sentence)
                # Rank by length-normalised score and keep the `beam_num` best.
                sentences = sorted(new_sentences,
                                   key=lambda x: x['value'] / np.power(
                                       len(x['sentences']), args.power_length),
                                   reverse=True)[:beam_num]
            sentence = sentences[0]
            # Store the normalised score in the returned hypothesis (mutates it in place).
            sentence['value'] = sentence['value'] / np.power(
                len(sentence['sentences']), args.power_length)
            print(sentence['value'], file=f, flush=True)
            return sentence

        def batch_choice_beam_search(sentence_sim,
                                     sentence_mask,
                                     beam_num=10) -> List[List[Dict]]:
            """
            Run ``beam_search`` for every (example, choice) pair of a batch.

            :param sentence_sim: [batch, num_choices, max_sen] -> torch.FloatTensor (any device)
            :param sentence_mask:  [batch, num_choices, max_sen] -> torch.FloatTensor (any device),
                non-zero for valid sentences; valid sentences are expected to come first.
            :param beam_num: int, beam width forwarded to ``beam_search``
            :return: batch * num_choices nested list holding the hypothesis dict
                returned by ``beam_search`` for each choice
            """
            batch_size = sentence_sim.size(0)
            num_choices = sentence_sim.size(1)
            # `.numpy()` only works on CPU tensors that do not require grad, and the
            # caller passes raw model outputs, so normalise the tensors here.
            # Both calls are no-ops for tensors that are already detached / on CPU.
            sentence_sim = sentence_sim.detach().cpu().numpy() + 1e-15
            sentence_mask = sentence_mask.detach().cpu().numpy()
            sentence_ids = []
            for b in range(batch_size):
                choice_sentence_ids = []
                for c in range(num_choices):
                    # Number of valid sentences for this choice.
                    num_sentences = int(sentence_mask[b, c].sum())
                    choice_sentence_ids.append(
                        beam_search(sentence_sim[b, c, :num_sentences],
                                    beam_num))
                    print(
                        '=================single choice=====================',
                        file=f,
                        flush=True)
                sentence_ids.append(choice_sentence_ids)
            return sentence_ids

        test_examples = train_examples
        test_features = train_features

        test_tensors = data_reader.data_to_tensors(test_features)
        test_data = TensorDataset(*test_tensors)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.predict_batch_size)

        logger.info("***** Running labeling *****")
        logger.info("  Num orig examples = %d", len(test_examples))
        logger.info("  Num split examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start labeling.")
        for batch in tqdm(test_dataloader, desc="Testing"):
            if n_gpu == 1:
                batch = batch_to_device(batch, device)
            inputs = data_reader.generate_inputs(batch,
                                                 test_features,
                                                 model_state=ModelState.Test)
            with torch.no_grad():
                output_dict = model(**inputs)
                batch_choice_logits, batch_sentence_logits = output_dict[
                    "choice_logits"], output_dict["sentence_logits"]
                batch_sentence_mask = output_dict["sentence_mask"]
            example_indices = batch[-1]
            batch_beam_results = batch_choice_beam_search(
                batch_sentence_logits, batch_sentence_mask)
            for i, example_index in enumerate(example_indices):
                choice_logits = batch_choice_logits[i].detach().cpu()
                evidence_list = batch_beam_results[i]

                test_feature = test_features[example_index.item()]
                unique_id = int(test_feature.unique_id)

                all_results.append(
                    RawOutput(unique_id=unique_id,
                              model_output={
                                  "choice_logits": choice_logits,
                                  "evidence_list": evidence_list
                              }))

        output_prediction_file = os.path.join(args.predict_dir,
                                              'sentence_id_file.json')
        data_reader.predict_sentence_ids(
            test_examples,
            test_features,
            all_results,
            output_prediction_file,
            weight_threshold=args.weight_threshold,
            only_correct=args.only_correct,
            label_threshold=args.label_threshold)
Exemple #23
0
class AMTask(Model):
    """
    A class that implements a task-specific model. It conceptually belongs to a formalism or corpus.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 name: str,
                 edge_model: EdgeModel,
                 loss_function: EdgeLoss,
                 supertagger: Supertagger,
                 lexlabeltagger: Supertagger,
                 supertagger_loss: SupertaggingLoss,
                 lexlabel_loss: SupertaggingLoss,
                 output_null_lex_label: bool = True,
                 loss_mixing: Dict[str, float] = None,
                 dropout: float = 0.0,
                 validation_evaluator: Optional[Evaluator] = None,
                 regularizer: Optional[RegularizerApplicator] = None):
        """
        Store the task components, validate the loss weights and set up metrics.

        :param vocab: vocabulary; its "pos" namespace is read to find the POS tags ignored in evaluation.
        :param name: task name, used as prefix of the task's vocabulary namespaces and in messages.
        :param edge_model: model that scores edge existence and edge labels.
        :param loss_function: loss for edge existence and edge labels.
        :param supertagger: tagger producing the supertag logits.
        :param lexlabeltagger: tagger producing the lexical label logits.
        :param supertagger_loss: loss for supertagging.
        :param lexlabel_loss: loss for lexical labels.
        :param output_null_lex_label: if False, the "_" lexical label is masked out when predicting.
        :param loss_mixing: weights for the losses "edge_existence", "edge_label", "supertagging"
            and "lexlabel". Missing entries default to 1.0. A value of None is only allowed for
            "supertagging" and "lexlabel" and switches that component off. The dict passed in is
            copied and never modified.
        :param dropout: dropout probability applied to the encoder outputs.
        :param validation_evaluator: optional evaluator run by ``metrics`` after a validation epoch.
        :param regularizer: passed on to the base class.
        :raises ConfigurationError: if the weight of "edge_existence" or "edge_label" is None.
        :raises ValueError: if ``loss_mixing`` contains an unknown loss name.
        """

        super().__init__(vocab, regularizer)
        self.name = name
        self.edge_model = edge_model
        self.supertagger = supertagger
        self.lexlabeltagger = lexlabeltagger
        self.supertagger_loss = supertagger_loss
        self.lexlabel_loss = lexlabel_loss
        self.loss_function = loss_function
        # Work on a private copy: defaults are filled in below and must not
        # leak into the dictionary owned by the caller.
        self.loss_mixing = dict(loss_mixing) if loss_mixing else dict()
        self.validation_evaluator = validation_evaluator
        self.output_null_lex_label = output_null_lex_label

        self._dropout = InputVariationalDropout(dropout)

        loss_names = [
            "edge_existence", "edge_label", "supertagging", "lexlabel"
        ]
        for loss_name in loss_names:
            if loss_name not in self.loss_mixing:
                self.loss_mixing[loss_name] = 1.0
                logger.info(
                    f"Loss name {loss_name} not found in loss_mixing, using a weight of 1.0"
                )
            elif (self.loss_mixing[loss_name] is None
                  and loss_name not in ["supertagging", "lexlabel"]):
                # Only the two tagging components can be switched off.
                raise ConfigurationError(
                    "Only the loss mixing coefficients for supertagging and lexlabel may be None, but not "
                    + loss_name)

        not_contained = set(self.loss_mixing.keys()) - set(loss_names)
        if not_contained:
            logger.critical(
                f"The following loss name(s) are unknown: {not_contained}")
            raise ValueError(
                f"The following loss name(s) are unknown: {not_contained}")

        self._supertagging_acc = CategoricalAccuracy()
        self._lexlabel_acc = CategoricalAccuracy()
        self._attachment_scores = AttachmentScores()
        self.current_epoch = 0

        # Indices of the POS tags that are excluded from the attachment scores.
        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

    def check_all_dimensions_match(self, encoder_output_dim):
        """
        Check that ``encoder_output_dim`` equals the input dimension expected by
        the edge model, the supertagger and the lexical label tagger (in that
        order), delegating each comparison to ``check_dimensions_match``.
        """
        components = (
            (self.edge_model, " input dim edge model"),
            (self.supertagger, " supertagger input dim"),
            (self.lexlabeltagger, " lexical label tagger input dim"),
        )
        for component, description in components:
            check_dimensions_match(encoder_output_dim,
                                   component.encoder_dim(),
                                   "encoder output dim",
                                   self.name + description)

    @overrides
    def forward(
            self,  # type: ignore
            encoded_text_parsing: torch.Tensor,
            encoded_text_tagging: torch.Tensor,
            mask: torch.Tensor,
            pos_tags: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            supertags: torch.LongTensor = None,
            lexlabels: torch.LongTensor = None,
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        Score a batch of encoded sentences; returns predictions and, for annotated data, losses.

        :param encoded_text_parsing: sentence representation of shape (batch_size, seq_len, encoder_output_dim)
        :param encoded_text_tagging: sentence representation of shape (batch_size, seq_len, encoder_output_dim)
            or None if the batch needs no supertagging / lexical label prediction
        :param mask: mask matching the sentence representation, shape (batch_size, seq_len)
        :param pos_tags: the accompanying pos tags, used to exclude tokens from the attachment scores
        :param metadata: one dict per sentence; the keys "words", "attributes", "token_ranges",
            "position_in_corpus" and "is_annotated" are read
        :param supertags: the gold supertags
        :param lexlabels: the gold lexical labels
        :param head_tags: the gold edge labels (label of the incoming edge) of every word
        :param head_indices: the gold heads of every word
        :return: dictionary with the predictions and intermediate scores; if the batch is annotated
            and gold heads and edge labels are given it also contains "arc_loss", "edge_label_loss",
            "supertagging_loss", "lexlabel_loss" and the weighted sum "loss"
        :raises ValueError: if the sentences of the batch disagree on "is_annotated"
        """
        encoded_text_parsing = self._dropout(encoded_text_parsing)
        has_tagging_input = encoded_text_tagging is not None
        if has_tagging_input:
            encoded_text_tagging = self._dropout(encoded_text_tagging)

        batch_size, _, _ = encoded_text_parsing.shape

        # shape (batch_size, seq_len, seq_len)
        edge_existence_scores = self.edge_model.edge_existence(
            encoded_text_parsing, mask)

        supertagger_logits = None
        lexlabel_logits = None
        if has_tagging_input:
            # shape (batch_size, seq_len, num_supertags)
            supertagger_logits = self.supertagger.compute_logits(
                encoded_text_tagging)
            # shape (batch_size, seq_len, num label tags)
            lexlabel_logits = self.lexlabeltagger.compute_logits(
                encoded_text_tagging)

        # Heads are decoded greedily during training and with CLE otherwise;
        # edge labels are then scored for the predicted heads in both cases.
        if self.training:
            predicted_heads = self._greedy_decode_arcs(edge_existence_scores,
                                                       mask)
        else:
            predicted_heads = cle_decode(edge_existence_scores,
                                         mask.data.sum(dim=1).long())
        # shape (batch_size, seq_len, num edge labels)
        edge_label_logits = self.edge_model.label_scores(
            encoded_text_parsing, predicted_heads)
        predicted_edge_labels = self._greedy_decode_edge_labels(
            edge_label_logits)

        # these are mostly required for the projective decoder
        full_label_logits = self.edge_model.full_label_scores(
            encoded_text_parsing)

        output_dict = {
            "heads": predicted_heads,
            "edge_existence_scores": edge_existence_scores,
            "label_logits": edge_label_logits,
            "full_label_logits": full_label_logits,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            "attributes": [meta["attributes"] for meta in metadata],
            "token_ranges": [meta["token_ranges"] for meta in metadata],
            "encoded_text_parsing": encoded_text_parsing,
            "encoded_text_tagging": encoded_text_tagging,
            "position_in_corpus":
            [meta["position_in_corpus"] for meta in metadata],
            "formalism": self.name
        }

        # A tagging component is active only if there is tagging input and
        # its loss weight has not been set to None.
        use_supertagging = (has_tagging_input
                            and self.loss_mixing["supertagging"] is not None)
        use_lexlabels = (has_tagging_input
                         and self.loss_mixing["lexlabel"] is not None)

        if use_supertagging:
            # shape (batch_size, seq_len, num supertags)
            output_dict["supertag_scores"] = supertagger_logits
            # shape (batch_size, seq_len)
            output_dict["best_supertags"] = Supertagger.top_k_supertags(
                supertagger_logits, 1).squeeze(2)

        if use_lexlabels:
            if self.output_null_lex_label:
                masked_lexlabel_logits = lexlabel_logits
            else:
                # Make the "_" label practically impossible to predict, on a
                # detached copy so that the logits used for the loss stay intact.
                bottom_lex_label_index = self.vocab.get_token_index(
                    "_", namespace=self.name + "_lex_labels")
                masked_lexlabel_logits = lexlabel_logits.clone().detach()
                masked_lexlabel_logits[:, :, bottom_lex_label_index] = -1e20

            # shape (batch_size, seq_len)
            output_dict["lexlabels"] = Supertagger.top_k_supertags(
                masked_lexlabel_logits, 1).squeeze(2)

        is_annotated = metadata[0]["is_annotated"]
        if any(metadata[i]["is_annotated"] != is_annotated
               for i in range(batch_size)):
            raise ValueError(
                "Batch contained inconsistent information if data is annotated."
            )

        # Without gold annotation there is nothing left to compute.
        if not (is_annotated and head_indices is not None
                and head_tags is not None):
            return output_dict

        # Edge losses: labels are scored for the gold heads.
        gold_edge_label_logits = self.edge_model.label_scores(
            encoded_text_parsing, head_indices)
        edge_label_loss = self.loss_function.label_loss(
            gold_edge_label_logits, mask, head_tags)
        edge_existence_loss = self.loss_function.edge_existence_loss(
            edge_existence_scores, head_indices, mask)

        # The artificial root at position 0 is excluded from tagging losses and metrics.
        mask_without_root = mask[:, 1:]

        supertagging_nll = None
        if use_supertagging:
            supertagger_logits = supertagger_logits[:, 1:, :].contiguous()
            supertagging_nll = self.supertagger_loss.loss(
                supertagger_logits, supertags, mask_without_root)

        lexlabel_nll = None
        if use_lexlabels:
            lexlabel_logits = lexlabel_logits[:, 1:, :].contiguous()
            lexlabel_nll = self.lexlabel_loss.loss(lexlabel_logits, lexlabels,
                                                   mask_without_root)

        loss = (self.loss_mixing["edge_existence"] * edge_existence_loss +
                self.loss_mixing["edge_label"] * edge_label_loss)
        if supertagging_nll is not None:
            loss += self.loss_mixing["supertagging"] * supertagging_nll
        if lexlabel_nll is not None:
            loss += self.loss_mixing["lexlabel"] * lexlabel_nll

        # Update LAS/UAS, supertagging accuracy and lexical label accuracy.
        evaluation_mask = self._get_mask_for_eval(mask_without_root, pos_tags)
        if edge_existence_loss is not None and edge_label_loss is not None:
            # Attachment scores skip the symbolic ROOT token at position 0.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_edge_labels[:, 1:],
                                    head_indices[:, 1:], head_tags[:, 1:],
                                    evaluation_mask)
        if supertagging_nll is not None:
            # compare against gold data
            self._supertagging_acc(supertagger_logits, supertags,
                                   mask_without_root)
        if lexlabel_nll is not None:
            # compare against gold data
            self._lexlabel_acc(lexlabel_logits, lexlabels, mask_without_root)

        output_dict.update({
            "arc_loss": edge_existence_loss,
            "edge_label_loss": edge_label_loss,
            "supertagging_loss": supertagging_nll,
            "lexlabel_loss": lexlabel_nll,
            "loss": loss,
        })
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]):
        """
        In contrast to its name, this function does not perform the decoding but only prepares it.
        Therefore, we take the result of forward and perform the following steps (for each sentence in batch):
        - remove padding and the artificial root at position 0
        - identify the root of the sentence via ``find_root`` (which may also change the predicted heads)
          and re-score the edge labels for the possibly changed heads
        - collect a selection of supertags to speed up computation (top k selection is done later)
        :param output_dict: result of forward; it must contain the tagging entries "best_supertags",
            "supertag_scores" and "lexlabels", which forward only adds when tagging is active.
        :return: output_dict with the following keys added or replaced (one entry per sentence):
            - lexlabels: nested list: contains for each sentence, for each word the most likely lexical label (w/o artificial root)
            - supertags: nested list: contains for each sentence, for each word a list of
              (score, fragment, type) tuples for the top k supertags, plus the bottom supertag
              if it is not among them (w/o artificial root)
            - root: the root returned by ``find_root``
            - label_logits, predicted_heads, full_label_logits, edge_existence_scores, supertag_scores:
              the corresponding scores / heads with padding removed
        """
        # NOTE(review): the pops below have no default, so a KeyError is raised
        # if forward did not add the tagging entries (tagging input missing or
        # loss weight None) - confirm decode is never called in that setting.
        best_supertags = output_dict.pop(
            "best_supertags").cpu().detach().numpy()
        supertag_scores = output_dict.pop(
            "supertag_scores")  # shape (batch_size, seq_len, num supertags)
        full_label_logits = output_dict.pop("full_label_logits").cpu().detach(
        ).numpy()  # shape (batch_size, seq_len, seq_len, num edge labels)
        edge_existence_scores = output_dict.pop(
            "edge_existence_scores").cpu().detach().numpy(
            )  # shape (batch_size, seq_len, seq_len)
        # Number of supertags kept per word; defaults to 10.
        k = 10
        if self.validation_evaluator:  # take k from the validation evaluator's predictor, if it is an AMconllPredictor.
            if isinstance(self.validation_evaluator.predictor,
                          AMconllPredictor):
                k = self.validation_evaluator.predictor.k
        top_k_supertags = Supertagger.top_k_supertags(
            supertag_scores,
            k).cpu().detach().numpy()  # shape (batch_size, seq_len, k)
        supertag_scores = supertag_scores.cpu().detach().numpy()
        lexlabels = output_dict.pop(
            "lexlabels").cpu().detach().numpy()  # shape (batch_size, seq_len)
        # `heads` stays a tensor because it is updated in place below and fed
        # back into the edge model; `heads_cpu` is a numpy copy for find_root.
        heads = output_dict.pop("heads")
        heads_cpu = heads.cpu().detach().numpy()
        mask = output_dict.pop("mask")
        edge_label_logits = output_dict.pop("label_logits").cpu().detach(
        ).numpy()  # shape (batch_size, seq_len, num edge labels)
        encoded_text_parsing = output_dict.pop("encoded_text_parsing")
        output_dict.pop("encoded_text_tagging")  # not needed any more
        lengths = get_lengths_from_binary_sequence_mask(mask)

        # here we collect things, in the end we will have one entry for each sentence:
        all_edge_label_logits = []
        all_supertags = []
        head_indices = []
        roots = []
        all_predicted_lex_labels = []
        all_full_label_logits = []
        all_edge_existence_scores = []
        all_supertag_scores = []

        # we need the following to identify the root
        root_edge_label_id = self.vocab.get_token_index("ROOT",
                                                        namespace=self.name +
                                                        "_head_tags")
        bot_id = self.vocab.get_token_index(AMSentence.get_bottom_supertag(),
                                            namespace=self.name +
                                            "_supertag_labels")

        for i, length in enumerate(lengths):
            # Heads of the real words only (position 0 is the artificial root).
            instance_heads_cpu = list(heads_cpu[i, 1:length])
            # Postprocess heads and find root of sentence:
            instance_heads_cpu, root = find_root(
                instance_heads_cpu,
                best_supertags[i, 1:length],
                edge_label_logits[i, 1:length, :],
                root_edge_label_id,
                bot_id,
                modify=True)
            roots.append(root)
            # apply changes to instance_heads tensor (a view, so `heads` is updated in place):
            instance_heads = heads[i, :]
            for j, x in enumerate(instance_heads_cpu):
                instance_heads[j + 1] = torch.tensor(
                    x
                )  # +1 because instance_heads_cpu does not contain the first position

            # re-calculate edge label logits since heads might have changed:
            label_logits = self.edge_model.label_scores(
                encoded_text_parsing[i].unsqueeze(0),
                instance_heads.unsqueeze(0)).squeeze(0).detach().cpu().numpy()
            # (un)squeeze: fake batch dimension
            all_edge_label_logits.append(label_logits[1:length, :])

            # Unlike the other entries, these two keep position 0 (slices start at 0).
            all_full_label_logits.append(
                full_label_logits[i, :length, :length, :])
            all_edge_existence_scores.append(
                edge_existence_scores[i, :length, :length])

            # calculate supertags for this sentence:
            all_supertag_scores.append(supertag_scores[
                i, 1:length, :])  # new shape (sent length, num supertags)
            supertags_for_this_sentence = []
            for word in range(1, length):
                supertags_for_this_word = []
                for top_k in top_k_supertags[i, word]:
                    fragment, typ = AMSentence.split_supertag(
                        self.vocab.get_token_from_index(top_k,
                                                        namespace=self.name +
                                                        "_supertag_labels"))
                    score = supertag_scores[i, word, top_k]
                    supertags_for_this_word.append((score, fragment, typ))
                if bot_id not in top_k_supertags[
                        i,
                        word]:  # the bottom supertag is not in the top k, but we have to add it anyway in order for the decoder to work properly.
                    fragment, typ = AMSentence.split_supertag(
                        AMSentence.get_bottom_supertag())
                    supertags_for_this_word.append(
                        (supertag_scores[i, word, bot_id], fragment, typ))
                supertags_for_this_sentence.append(supertags_for_this_word)
            all_supertags.append(supertags_for_this_sentence)
            # Map the predicted lexical label ids back to their strings.
            all_predicted_lex_labels.append([
                self.vocab.get_token_from_index(label,
                                                namespace=self.name +
                                                "_lex_labels")
                for label in lexlabels[i, 1:length]
            ])
            head_indices.append(instance_heads_cpu)

        output_dict["lexlabels"] = all_predicted_lex_labels
        output_dict["supertags"] = all_supertags
        output_dict["root"] = roots
        output_dict["label_logits"] = all_edge_label_logits
        output_dict["predicted_heads"] = head_indices
        output_dict["full_label_logits"] = all_full_label_logits
        output_dict["edge_existence_scores"] = all_edge_existence_scores
        output_dict["supertag_scores"] = all_supertag_scores
        return output_dict

    def _greedy_decode_edge_labels(
            self, edge_label_logits: torch.Tensor) -> torch.Tensor:
        """
        Assigns edge labels according to (existing) edges.
        Parameters
        ----------
        edge_label_logits: ``torch.Tensor`` of shape (batch_size, sequence_length, num_head_tags)

        Returns
        -------
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded head tags (labels of incoming edges) of each word.
        """
        _, head_tags = edge_label_logits.max(dim=2)
        return head_tags

    def _greedy_decode_arcs(self, existence_scores: torch.Tensor,
                            mask: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head  predictions by decoding the unlabeled arcs
        independently for each word. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there maybe be multiple roots,
        or cycles when children are attached to their parents).

        Parameters
        ----------
        existence_scores : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        mask: torch.Tensor, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        existence_scores = existence_scores + torch.diag(
            existence_scores.new(mask.size(1)).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = (1 - mask).byte().unsqueeze(2)
            existence_scores.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = existence_scores.max(dim=2)
        return heads

    def _get_mask_for_eval(self, mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    def metrics(self,
                parser_model,
                reset: bool = False,
                model_path=None) -> Dict[str, float]:
        """
        Is called by a GraphDependencyParser
        :param parser_model: a GraphDependencyParser
        :param reset:
        :return:
        """
        r = self.get_metrics(reset)
        if reset:  #epoch done
            if self.training:  #done on the training data
                self.current_epoch += 1
            else:  #done on dev/test data
                if self.validation_evaluator:
                    metrics = self.validation_evaluator.eval(
                        parser_model, self.current_epoch, model_path)
                    for name, val in metrics.items():
                        r[name] = val
        return r

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        r = self._attachment_scores.get_metric(reset)
        if self.loss_mixing["supertagging"] is not None:
            r["Constant_Acc"] = self._supertagging_acc.get_metric(reset)
        if self.loss_mixing["lexlabel"] is not None:
            r["Label_Acc"] = self._lexlabel_acc.get_metric(reset)
        return r
Exemple #24
0
class SeqClassificationModel(Model):
    """
    Sentence-level sequence classifier on top of a BERT text-field embedder.

    Every sentence is reduced to one vector (the vector at its SEP token when
    ``use_sep`` is set, otherwise the first-token vector passed through
    ``self_attn``), optionally concatenated with extra per-sentence features,
    and scored by a linear layer.  With ``sci_sum`` the single output is
    trained as a regression score with an MSE loss; otherwise the outputs are
    class logits trained with cross entropy, or with a CRF when ``with_crf``
    is set.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        use_sep: bool = True,
        with_crf: bool = False,
        self_attn: Seq2SeqEncoder = None,
        bert_dropout: float = 0.1,
        sci_sum: bool = False,
        additional_feature_size: int = 0,
    ) -> None:
        """
        Parameters
        ----------
        vocab : ``Vocabulary``
            Its ``labels`` namespace defines the classes (unused when ``sci_sum``).
        text_field_embedder : ``TextFieldEmbedder``
            Must contain a token embedder registered under the key ``'bert'``.
        use_sep : ``bool``
            Represent each sentence by the vector at its SEP token.  When
            False, ``self_attn`` is required.
        with_crf : ``bool``
            Add a CRF over the per-sentence logits.
        self_attn : ``Seq2SeqEncoder``, optional
            Encoder applied over the first-token vectors when ``use_sep`` is False.
        bert_dropout : ``float``
            Dropout probability applied to the sentence vectors.
        sci_sum : ``bool``
            Treat labels as real-valued scores (regression) instead of classes.
        additional_feature_size : ``int``
            Width of the extra per-sentence features concatenated before the
            output layer.
        """
        super(SeqClassificationModel, self).__init__(vocab)

        self.text_field_embedder = text_field_embedder
        self.vocab = vocab
        self.use_sep = use_sep
        self.with_crf = with_crf
        self.sci_sum = sci_sum
        self.self_attn = self_attn
        self.additional_feature_size = additional_feature_size

        self.dropout = torch.nn.Dropout(p=bert_dropout)

        # define loss
        if self.sci_sum:
            self.loss = torch.nn.MSELoss(
                reduction='none')  # labels are rouge scores
            self.labels_are_scores = True
            self.num_labels = 1
        else:
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1,
                                                  reduction='none')
            self.labels_are_scores = False
            self.num_labels = self.vocab.get_vocab_size(namespace='labels')
            # define accuracy metrics
            self.label_accuracy = CategoricalAccuracy()
            self.label_f1_metrics = {}

            # define F1 metrics per label
            for label_index in range(self.num_labels):
                label_name = self.vocab.get_token_from_index(
                    namespace='labels', index=label_index)
                self.label_f1_metrics[label_name] = F1Measure(label_index)

        encoded_sentence_dim = text_field_embedder._token_embedders[
            'bert'].output_dim

        # The SEP vectors come straight from BERT; the CLS path goes through `self_attn` first.
        if self.use_sep:
            ff_in_dim = encoded_sentence_dim
        else:
            ff_in_dim = self_attn.get_output_dim()
        ff_in_dim += self.additional_feature_size

        self.time_distributed_aggregate_feedforward = TimeDistributed(
            Linear(ff_in_dim, self.num_labels))

        if self.with_crf:
            self.crf = ConditionalRandomField(
                self.num_labels,
                constraints=None,
                include_start_end_transitions=True)

    def forward(
        self,  # type: ignore
        sentences: torch.LongTensor,
        labels: torch.IntTensor = None,
        confidences: torch.Tensor = None,
        additional_features: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentences :
            Text-field input.  It is indexed with ``sentences['bert']``, so it
            is used as a dictionary of token-id tensors despite the annotation.
        labels : optional
            Gold labels per sentence.  Padding is marked with ``-1`` for class
            labels and with ``0.0`` when labels are scores.
        confidences : optional
            Per-sentence weights multiplied into the (non-CRF) loss.
        additional_features : optional
            Extra per-sentence features concatenated to the sentence vectors.

        Returns
        -------
        An output dictionary consisting of:
        action_probs : torch.FloatTensor
            Softmax of the logits (or the raw scores when labels are scores).
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised; only present when ``labels`` is given.
        action_logits : torch.FloatTensor
            The raw outputs of the linear layer.
        """
        # ===========================================================================================================
        # Layer 1: embed every token of every sentence with the text field embedder (BERT)
        # Input: sentences
        # Output: embedded_sentences

        # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
        embedded_sentences = self.text_field_embedder(sentences)
        mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
        batch_size, num_sentences, _, _ = embedded_sentences.size()

        # Sentence-level padding mask; only the CLS branch below derives one from the input.
        sent_mask = None

        if self.use_sep:
            # The following code collects vectors of the SEP tokens from all the examples in the batch,
            # and arrange them in one list. It does the same for the labels and confidences.
            # TODO: look up the id of '[SEP]' instead of hard-coding 102
            sentences_mask = sentences[
                'bert'] == 102  # mask for all the SEP tokens in the batch
            embedded_sentences = embedded_sentences[
                sentences_mask]  # given batch_size x num_sentences_per_example x sent_len x vector_len
            # returns num_sentences_per_batch x vector_len
            assert embedded_sentences.dim() == 2
            num_sentences = embedded_sentences.shape[0]
            # for the rest of the code in this model to work, think of the data we have as one example
            # with so many sentences and a batch of size 1
            batch_size = 1
            embedded_sentences = embedded_sentences.unsqueeze(dim=0)
            embedded_sentences = self.dropout(embedded_sentences)

            # NOTE(review): `confidences` and `additional_features` are only flattened when
            # `labels` is given, because the padding mask is derived from the labels.
            if labels is not None:
                if self.labels_are_scores:
                    labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
                else:
                    labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

                labels = labels[
                    labels_mask]  # given batch_size x num_sentences_per_example return num_sentences_per_batch
                assert labels.dim() == 1
                if confidences is not None:
                    confidences = confidences[labels_mask]
                    assert confidences.dim() == 1
                if additional_features is not None:
                    additional_features = additional_features[labels_mask]
                    assert additional_features.dim() == 2

                num_labels = labels.shape[0]
                if num_labels != num_sentences:  # bert truncates long sentences, so some of the SEP tokens might be gone
                    assert num_labels > num_sentences  # but `num_labels` should be at least greater than `num_sentences`
                    logger.warning(
                        f'Found {num_labels} labels but {num_sentences} sentences'
                    )
                    labels = labels[:
                                    num_sentences]  # Ignore some labels. This is ok for training but bad for testing.
                    # We are ignoring this problem for now.
                    # TODO: fix, at least for testing

                # do the same for `confidences`
                if confidences is not None:
                    num_confidences = confidences.shape[0]
                    if num_confidences != num_sentences:
                        assert num_confidences > num_sentences
                        confidences = confidences[:num_sentences]

                # and for `additional_features`
                if additional_features is not None:
                    num_additional_features = additional_features.shape[0]
                    if num_additional_features != num_sentences:
                        assert num_additional_features > num_sentences
                        additional_features = additional_features[:
                                                                  num_sentences]

                # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
                labels = labels.unsqueeze(dim=0)
                if confidences is not None:
                    confidences = confidences.unsqueeze(dim=0)
                if additional_features is not None:
                    additional_features = additional_features.unsqueeze(dim=0)
        else:
            # ['CLS'] token
            embedded_sentences = embedded_sentences[:, :, 0, :]
            embedded_sentences = self.dropout(embedded_sentences)
            batch_size, num_sentences, _ = embedded_sentences.size()
            # a sentence is real (not padding) when it has at least one unmasked token
            sent_mask = (mask.sum(dim=2) != 0)
            embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

        if additional_features is not None:
            embedded_sentences = torch.cat(
                (embedded_sentences, additional_features), dim=-1)

        label_logits = self.time_distributed_aggregate_feedforward(
            embedded_sentences)
        # label_logits: batch_size, num_sentences, num_labels

        if self.labels_are_scores:
            label_probs = label_logits
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {"action_probs": label_probs}

        # =====================================================================

        mask_sentences = None
        predicted_labels = None
        if self.with_crf:
            # Layer 4 = CRF layer across labels of sentences in an abstract
            if labels is not None:
                mask_sentences = (labels != -1)
            elif sent_mask is not None:
                # prediction without gold labels: fall back to the input padding mask
                mask_sentences = sent_mask
            else:
                # prediction on the SEP path: every collected sentence is a real one
                mask_sentences = label_logits.new_ones(
                    label_logits.shape[:2]) != 0
            best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
            # Just get the tags and ignore the score.
            predicted_labels = [tags for tags, _ in best_paths]

        if labels is not None:
            # Compute cross entropy loss
            flattened_logits = label_logits.view((batch_size * num_sentences),
                                                 self.num_labels)
            flattened_gold = labels.contiguous().view(-1)

            if not self.with_crf:
                if self.labels_are_scores:
                    # one score per sentence: drop only the trailing size-1 axis so the
                    # shape matches `flattened_gold`
                    loss_input = flattened_logits.squeeze(-1)
                else:
                    # keep the (num_sentences, num_labels) layout even for a single
                    # sentence; a bare squeeze() would also drop the row axis
                    loss_input = flattened_logits
                label_loss = self.loss(loss_input, flattened_gold)
                if confidences is not None:
                    label_loss = label_loss * confidences.type_as(
                        label_loss).view(-1)
                label_loss = label_loss.mean()
                flattened_probs = torch.softmax(flattened_logits, dim=-1)
            else:
                # the CRF cannot index with -1, so padded labels are clamped and masked out
                clamped_labels = torch.clamp(labels, min=0)
                log_likelihood = self.crf(label_logits, clamped_labels,
                                          mask_sentences)
                label_loss = -log_likelihood
                # compute categorical accuracy from a one-hot encoding of the Viterbi tags
                crf_label_probs = label_logits * 0.
                for i, instance_labels in enumerate(predicted_labels):
                    for j, label_id in enumerate(instance_labels):
                        crf_label_probs[i, j, label_id] = 1
                flattened_probs = crf_label_probs.view(
                    (batch_size * num_sentences), self.num_labels)

            if not self.labels_are_scores:
                evaluation_mask = (flattened_gold != -1)
                self.label_accuracy(flattened_probs.float().contiguous(),
                                    flattened_gold.squeeze(-1),
                                    mask=evaluation_mask)

                # compute F1 per label
                for label_index in range(self.num_labels):
                    label_name = self.vocab.get_token_from_index(
                        namespace='labels', index=label_index)
                    metric = self.label_f1_metrics[label_name]
                    metric(flattened_probs,
                           flattened_gold,
                           mask=evaluation_mask)

            output_dict["loss"] = label_loss

        output_dict['action_logits'] = label_logits
        return output_dict

    def get_metrics(self, reset: bool = False):
        """
        Return accuracy, per-label F1 (key ``<label>F``) and their mean (``avgF``).

        The dictionary is empty when labels are scores, since no classification
        metrics are tracked in that mode.
        """
        metric_dict = {}

        if not self.labels_are_scores:
            metric_dict['acc'] = self.label_accuracy.get_metric(reset)

            f1_total = 0.0
            for name, metric in self.label_f1_metrics.items():
                # index 2 of the tuple returned by the F1 metric is stored as the F score
                f1_value = metric.get_metric(reset)[2]
                metric_dict[name + 'F'] = f1_value
                f1_total += f1_value

            # avoid a ZeroDivisionError when the label namespace is empty
            num_f1_metrics = len(self.label_f1_metrics)
            metric_dict['avgF'] = (f1_total /
                                   num_f1_metrics if num_f1_metrics else 0.0)

        return metric_dict
Exemple #25
0
class BidirectionalAttentionFlow_1(Model):
    """
    This class implements a Bayesian version of Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).
    """
    
    def __init__(self, vocab: Vocabulary, cf_a, preloaded_elmo = None) -> None:
        """
        Build every layer of the model from the configuration object ``cf_a``.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Passed on to the parent constructor.
        cf_a :
            Configuration object.  Its attributes select, layer by layer, between the
            Bayesian (``*VB``) and the standard implementation, and give the sizes,
            dropout rates and priors of each layer.
        preloaded_elmo : optional
            An already instantiated ELMo embedder.  When given (and ``cf_a.use_ELMO`` is
            set) it is used instead of calling ``bidut.download_Elmo``.
        """
        super(BidirectionalAttentionFlow_1, self).__init__(vocab, cf_a.regularizer)

        # ---- Initialize some data structures ----
        self.cf_a = cf_a
        # Bayesian data models: every layer built with a prior is appended here
        self.VBmodels = []
        self.LinearModels = []

        # ############## TEXT FIELD EMBEDDER with ELMO ####################
        # text_field_embedder : ``TextFieldEmbedder``
        #     Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
        # NOTE(review): with ``use_ELMO`` disabled the embedder is None and the
        # ``get_output_dim()`` calls below would fail - confirm that configuration is unsupported.
        if cf_a.use_ELMO:
            if preloaded_elmo is not None:
                text_field_embedder = preloaded_elmo
            else:
                text_field_embedder = bidut.download_Elmo(cf_a.ELMO_num_layers, cf_a.ELMO_droput)
                print ("ELMO loaded from disk or downloaded")
        else:
            text_field_embedder = None

        self._text_field_embedder = text_field_embedder

        # Optional linear projection of the ELMo output down to a fixed width.
        elmo_projection_dim = 200
        if cf_a.Add_Linear_projection_ELMO:
            if self.cf_a.VB_Linear_projection_ELMO:
                prior = Vil.Prior(**(cf_a.VB_Linear_projection_ELMO_prior))
                print ("----------------- Bayesian Linear Projection ELMO --------------")
                linear_projection_ELMO = LinearVB(text_field_embedder.get_output_dim(),
                                                  elmo_projection_dim, prior = prior)
                self.VBmodels.append(linear_projection_ELMO)
            else:
                linear_projection_ELMO = torch.nn.Linear(text_field_embedder.get_output_dim(),
                                                         elmo_projection_dim)

            self._linear_projection_ELMO = linear_projection_ELMO

        # ############## Highway layers ####################
        # num_highway_layers : ``int``
        #     The number of highway layers to use in between embedding the input and passing it through
        #     the phrase layer.
        if cf_a.Add_Linear_projection_ELMO:
            Input_dimension_highway = elmo_projection_dim
        else:
            Input_dimension_highway = text_field_embedder.get_output_dim()

        num_highway_layers = cf_a.num_highway_layers
        if self.cf_a.VB_highway_layers:
            print ("----------------- Bayesian Highway network  --------------")
            prior = Vil.Prior(**(cf_a.VB_highway_layers_prior))
            highway_layer = HighwayVB(Input_dimension_highway,
                                      num_highway_layers, prior = prior)
            self.VBmodels.append(highway_layer)
        else:
            highway_layer = Highway(Input_dimension_highway,
                                    num_highway_layers)
        highway_layer = TimeDistributed(highway_layer)

        self._highway_layer = highway_layer

        # ############## Phrase layer ####################
        # phrase_layer : ``Seq2SeqEncoder``
        #     The encoder (with its own internal stacking) that we will use in between embedding tokens
        #     and doing the bidirectional attention.
        if cf_a.phrase_layer_dropout > 0:
            dropout_phrase_layer = torch.nn.Dropout(p=cf_a.phrase_layer_dropout)
        else:
            # identity stand-in so callers can apply the "dropout" unconditionally
            dropout_phrase_layer = lambda x: x

        phrase_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(Input_dimension_highway, hidden_size = cf_a.phrase_layer_hidden_size,
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.phrase_layer_num_layers, dropout = cf_a.phrase_layer_dropout))

        # bidirectional LSTM: forward and backward states are concatenated
        phrase_encoding_out_dim = cf_a.phrase_layer_hidden_size * 2
        self._phrase_layer = phrase_layer
        self._dropout_phrase_layer = dropout_phrase_layer

        # ############## Matrix attention layer ####################
        # similarity_function : ``SimilarityFunction``
        #     The similarity function that we will use when comparing encoded passage and question
        #     representations.
        if self.cf_a.VB_similarity_function:
            prior = Vil.Prior(**(cf_a.VB_similarity_function_prior))
            print ("----------------- Bayesian Similarity matrix --------------")
            similarity_function = LinearSimilarityVB(
                  combination = "x,y,x*y",
                  tensor_1_dim = phrase_encoding_out_dim,
                  tensor_2_dim = phrase_encoding_out_dim, prior = prior)
            self.VBmodels.append(similarity_function)
        else:
            similarity_function = LinearSimilarity(
                  combination = "x,y,x*y",
                  tensor_1_dim = phrase_encoding_out_dim,
                  tensor_2_dim = phrase_encoding_out_dim)

        matrix_attention = LegacyMatrixAttention(similarity_function)
        self._matrix_attention = matrix_attention

        # ############## Modelling Layer ####################
        # modeling_layer : ``Seq2SeqEncoder``
        #     The encoder (with its own internal stacking) that we will use in between the bidirectional
        #     attention and predicting span start and end.
        if cf_a.modeling_passage_dropout > 0:
            dropout_modeling_passage = torch.nn.Dropout(p=cf_a.modeling_passage_dropout)
        else:
            dropout_modeling_passage = lambda x: x

        modeling_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(phrase_encoding_out_dim * 4, hidden_size = cf_a.modeling_passage_hidden_size,
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.modeling_passage_num_layers, dropout = cf_a.modeling_passage_dropout))

        self._modeling_layer = modeling_layer
        self._dropout_modeling_passage = dropout_modeling_passage

        # ############## Span Start Representation #####################
        # The span start is predicted by a linear layer over the merged passage
        # representation concatenated with the modeling layer output.
        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim

        if self.cf_a.VB_span_start_predictor_linear:
            prior = Vil.Prior(**(cf_a.VB_span_start_predictor_linear_prior))
            print ("----------------- Bayesian Span Start Predictor--------------")
            span_start_predictor_linear = LinearVB(span_start_input_dim, 1, prior = prior)
            self.VBmodels.append(span_start_predictor_linear)
        else:
            span_start_predictor_linear = torch.nn.Linear(span_start_input_dim, 1)

        self._span_start_predictor_linear = span_start_predictor_linear
        self._span_start_predictor = TimeDistributed(span_start_predictor_linear)

        # ############## Span End Representation #####################
        # span_end_encoder : ``Seq2SeqEncoder``
        #     The encoder that we will use to incorporate span start predictions into the passage state
        #     before predicting span end.
        if cf_a.span_end_encoder_dropout > 0:
            dropout_span_end_encode = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
        else:
            dropout_span_end_encode = lambda x: x

        span_end_encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(encoding_dim * 4 + modeling_dim * 3, hidden_size = cf_a.modeling_span_end_hidden_size,
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.modeling_span_end_num_layers, dropout = cf_a.span_end_encoder_dropout))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim

        self._span_end_encoder = span_end_encoder
        self._dropout_span_end_encode = dropout_span_end_encode

        if self.cf_a.VB_span_end_predictor_linear:
            print ("----------------- Bayesian Span End Predictor--------------")
            prior = Vil.Prior(**(cf_a.VB_span_end_predictor_linear_prior))
            span_end_predictor_linear = LinearVB(span_end_input_dim, 1, prior = prior)
            self.VBmodels.append(span_end_predictor_linear)
        else:
            span_end_predictor_linear = torch.nn.Linear(span_end_input_dim, 1)

        self._span_end_predictor_linear = span_end_predictor_linear
        self._span_end_predictor = TimeDistributed(span_end_predictor_linear)

        # ---- Dropout of the last layers ----
        # The rate must come from the same setting that enables the layer
        # (it used to be taken from ``span_end_encoder_dropout`` by mistake).
        if cf_a.spans_output_dropout > 0:
            dropout_spans_output = torch.nn.Dropout(p=cf_a.spans_output_dropout)
        else:
            dropout_spans_output = lambda x: x

        self._dropout_spans_output = dropout_spans_output

        # ---- Checkings and accuracy ----
        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(Input_dimension_highway , phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()

        # mask_lstms : ``bool``, optional (default=True)
        #     If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        #     with only a slight performance decrease, if any.  We haven't experimented much with this
        #     yet, but have confirmed that we still get very similar performance with much faster
        #     training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        #     required when using masking with pytorch LSTMs.
        self._mask_lstms = cf_a.mask_lstms

        # ################### Initialize parameters ##############################
        #### THEY ARE ALL INITIALIZED WHEN INSTANTING THE COMPONENTS ###

        # ####################### OPTIMIZER ################
        optimizer = pytut.get_optimizers(self, cf_a)
        self._optimizer = optimizer
        #### TODO: Learning rate scheduler ####
        #scheduler = optim.ReduceLROnPlateau(optimizer, 'max')
    
    def forward_ensemble(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information = False) -> Dict[str, torch.Tensor]:
        """
        Run the model several times and merge the span predictions.

        One forward pass is made after ``set_posterior_mean(True)`` and ten more after
        ``set_posterior_mean(False)`` (presumably: posterior mean first, then ten
        sampled sets of weights - confirm against ``set_posterior_mean``).  The eleven
        outputs are merged with ``bidut.merge_span_probs``.

        Parameters
        ----------
        question, passage, span_start, span_end, metadata :
            Passed unchanged to ``forward``.
        get_sample_level_information : ``bool``
            When True, per-instance EM/F1 values and per-instance losses of the
            averaged predictions are added to the output.

        Returns
        -------
        A dictionary with:
        best_span :
            The merged best span per instance.
        best_span_str : List[str]
            The predicted answer strings (filled only when ``metadata`` is given).
        models_output : list
            The raw output dictionary of every forward pass.
        em_samples, f1_samples, span_start_sample_loss, span_end_sample_loss : list, optional
            Only present when ``get_sample_level_information`` is True.  The loss lists
            stay empty when no gold spans were provided.
        """
        self.set_posterior_mean(True)
        most_likely_output = self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information)
        self.set_posterior_mean(False)

        subresults = [most_likely_output]
        for _ in range(10):
            subresults.append(self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information))

        batch_size = len(subresults[0]["best_span"])

        best_span = bidut.merge_span_probs(subresults)

        output = {
                "best_span": best_span,
                "best_span_str": [],
                "models_output": subresults
        }
        if (get_sample_level_information):
            output["em_samples"] = []
            output["f1_samples"] = []

        # The answer strings can only be recovered when the original text is available.
        if metadata is not None:
            for index in range(batch_size):
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                # character offsets of the first and last predicted tokens
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if (get_sample_level_information):
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)

        if (get_sample_level_information):
            # Add information about the individual samples for future analysis
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            # The losses need gold spans; without them the lists are left empty.
            if span_start is not None and span_end is not None:
                # The ensemble average does not depend on the instance, so compute it once.
                num_models = len(subresults)
                span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / num_models
                span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / num_models
                # nll_loss expects log-probabilities, not probabilities.
                span_start_log_probs = torch.log(span_start_probs)
                span_end_log_probs = torch.log(span_end_probs)
                gold_starts = span_start.squeeze(-1)
                gold_ends = span_end.squeeze(-1)

                for i in range(batch_size):
                    # [[i]] keeps the batch dimension that nll_loss requires
                    span_start_loss = nll_loss(span_start_log_probs[[i],:], gold_starts[[i]])
                    span_end_loss = nll_loss(span_end_log_probs[[i],:], gold_ends[[i]])

                    output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                    output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        return output
    
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information = False,
                get_attentions = False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.
        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        
        """
        #################### Sample Bayesian weights ##################
        """
        self.sample_posterior()
        
        """
        ################## MASK COMPUTING ########################
        """
                
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None
        
        """
        ###################### EMBEDDING + HIGHWAY LAYER ########################
        """
#        self.cf_a.use_ELMO
        
        if(self.cf_a.Add_Linear_projection_ELMO):
            embedded_question = self._highway_layer(self._linear_projection_ELMO (self._text_field_embedder(question['character_ids'])["elmo_representations"][-1]))
            embedded_passage = self._highway_layer(self._linear_projection_ELMO(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1]))
        else:
            embedded_question = self._highway_layer(self._text_field_embedder(question['character_ids'])["elmo_representations"][-1])
            embedded_passage = self._highway_layer(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1])
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        
        """
        ###################### phrase_layer LAYER ########################
        """

        encoded_question = self._dropout_phrase_layer(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout_phrase_layer(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        """
        ###################### Attention LAYER ########################
        """
        
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout_modeling_passage(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)
        
        """
        ###################### Spans LAYER ########################
        """
        
        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout_spans_output(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout_span_end_encode(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout_spans_output(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        
        best_span = bidut.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            
            span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            loss = span_start_loss + span_end_loss

            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            
            output_dict["loss"] = loss
            output_dict["span_start_loss"] = span_start_loss
            output_dict["span_end_loss"] = span_end_loss
            
        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            if (get_sample_level_information):
                output_dict["em_samples"] = []
                output_dict["f1_samples"] = []
                
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if (get_sample_level_information):
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts)
                        output_dict["em_samples"].append(em_sample)
                        output_dict["f1_samples"].append(f1_sample)
                        
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            
        if (get_sample_level_information):
            # Add information about the individual samples for future analysis
            output_dict["span_start_sample_loss"] = []
            output_dict["span_end_sample_loss"] = []
            for i in range (batch_size):
                span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits[[i],:], passage_mask[[i],:]), span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits[[i],:], passage_mask[[i],:]), span_end.squeeze(-1)[[i]])
                
                output_dict["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                output_dict["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        if(get_attentions):
            output_dict["C2Q_attention"] = passage_question_attention
            output_dict["Q2C_attention"] = question_passage_attention
            output_dict["simmilarity"] = passage_question_similarity
            
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the current values of all the span metrics in one dictionary.

        reset: Forwarded unchanged to every underlying ``get_metric`` call.

        Returns a dictionary with the keys 'start_acc', 'end_acc', 'span_acc',
        'em' and 'f1'. The last two are the pair returned by the SQuAD metric.
        """
        # The SQuAD metric yields both values at once, so query it only once.
        em_value, f1_value = self._squad_metrics.get_metric(reset)

        collected = {}
        collected['start_acc'] = self._span_start_accuracy.get_metric(reset)
        collected['end_acc'] = self._span_end_accuracy.get_metric(reset)
        collected['span_acc'] = self._span_accuracy.get_metric(reset)
        collected['em'] = em_value
        collected['f1'] = f1_value
        return collected
    
    def train_batch(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Run one optimisation step on a single batch.

        The batch is passed through ``self.forward``, the data loss found in
        ``output["loss"]`` is combined with the KL divergence through
        ``self.combine_losses`` and the result is back-propagated.
        It is enough to just compute the total loss because the normal weights
        do not depend on the KL Divergence.

        The parameters are then updated either by ``self._optimizer`` or, when
        no optimizer is set (``self._optimizer is None``), by a plain gradient
        descent step with learning rate ``self.lr``.

        question, passage, span_start, span_end, metadata: forwarded
            positionally, in this order, to ``self.forward``.

        Returns the output dictionary produced by ``self.forward``.
        Raises KeyError if that output has no "loss" entry.
        """
        # Computing both losses builds the dynamic graph used by backward().
        output = self.forward(question, passage, span_start, span_end, metadata)
        data_loss = output["loss"]

        KL_div = self.get_KL_divergence()
        total_loss = self.combine_losses(data_loss, KL_div)

        self.zero_grad()     # zeroes the gradient buffers of all parameters
        total_loss.backward()

        if self._optimizer is None:
            # Manual SGD fallback: w <- w - lr * grad
            with torch.no_grad():
                for param in self.parameters():
                    # A trainable parameter that did not take part in this
                    # graph has no gradient; there is nothing to apply for it.
                    if not param.requires_grad or param.grad is None:
                        continue
                    param.data.sub_(param.grad.data * self.lr)
        else:
            self._optimizer.step()
            self._optimizer.zero_grad()

        return output
    
    def fill_batch_training_information(self, training_logger,
                                        output_batch):
        """
        Append the results of one training batch to training_logger["train"].

        training_logger: Dictionary that will hold all the training info. Its
            "train" entry must already contain a list for every "*_batch" key
            written below.
        output_batch: Output from training the batch. The three loss entries
            are detached, moved to the CPU and stored as numpy values.
        """
        # (destination key in the logger, source key in the batch output)
        loss_entries = (
            ("span_start_loss_batch", "span_start_loss"),
            ("span_end_loss_batch", "span_end_loss"),
            ("loss_batch", "loss"),
        )
        for log_key, output_key in loss_entries:
            training_logger["train"][log_key].append(output_batch[output_key].detach().cpu().numpy())

        # Training metrics, read without resetting them.
        batch_metrics = self.get_metrics()
        metric_entries = (
            ("start_acc_batch", "start_acc"),
            ("end_acc_batch", "end_acc"),
            ("span_acc_batch", "span_acc"),
            ("em_batch", "em"),
            ("f1_batch", "f1"),
        )
        for log_key, metric_key in metric_entries:
            training_logger["train"][log_key].append(batch_metrics[metric_key])
    
        
    def fill_epoch_training_information(self, training_logger,device,
                                        validation_iterable, num_batches_validation):
        """
        Fill the information per each epoch.

        First the training metrics accumulated during the epoch are stored in
        training_logger["train"] and reset. Then the model is put in evaluation
        mode (with ``set_posterior_mean(True)``), ``num_batches_validation``
        batches are drawn from ``validation_iterable`` and run through
        ``self.forward`` without gradients, and the resulting metrics and the
        mean data loss are stored in training_logger["validation"].

        training_logger: Dictionary with "train" and "validation" entries that
            already hold a list for every metric key written here.
        device: Device the validation batches are moved to.
        validation_iterable: Iterator yielding dictionaries that are passed as
            keyword arguments to ``self.forward``.
        num_batches_validation: Number of batches to draw. With 0 batches the
            logged validation data loss is NaN.

        A batch whose forward pass raises RuntimeError is retried (after
        emptying the CUDA cache and sleeping) up to 100 times; after that an
        error email is sent and the process exits with a non-zero status.
        The model is always put back in training mode before returning.
        """
        Ntrials_CUDA = 100  # Maximum number of attempts for a single batch
        metric_names = ("start_acc", "end_acc", "span_acc", "em", "f1")

        # Training Epoch final metrics. Resetting them here makes the metric
        # objects start from scratch for the validation pass below.
        metrics = self.get_metrics(reset = True)
        for name in metric_names:
            training_logger["train"][name].append(metrics[name])

        self.set_posterior_mean(True)
        self.eval()
        try:
            data_loss_validation = 0
            with torch.no_grad():
                # Compute the validation accuracy by using all the Validation dataset but in batches.
                for _ in range(num_batches_validation):
                    tensor_dict = next(validation_iterable)

                    trial_index = 0
                    while True:
                        try:
                            tensor_dict = pytut.move_to_device(tensor_dict, device) ## Move the tensor to cuda
                            output_batch = self.forward(**tensor_dict)
                            break
                        except RuntimeError as er:
                            print (er.args)
                            torch.cuda.empty_cache()
                            time.sleep(5)
                            torch.cuda.empty_cache()
                            trial_index += 1
                            if (trial_index == Ntrials_CUDA):
                                print ("Too many failed trials to allocate in memory")
                                send_error_email(str(er.args))
                                # This is a failure, so report it as one.
                                sys.exit(1)

                    data_loss_validation += output_batch["loss"].detach().cpu().numpy()

                    ## Memory management, done for every batch so that the
                    ## tensors of one batch are released before the next one.
                    if (self.cf_a.force_free_batch_memory):
                        del tensor_dict["question"]
                        del tensor_dict["passage"]
                        del tensor_dict
                        del output_batch
                        torch.cuda.empty_cache()
                    if (self.cf_a.force_call_garbage_collector):
                        gc.collect()

            if num_batches_validation > 0:
                data_loss_validation = data_loss_validation/num_batches_validation
            else:
                # No batch was evaluated: there is no meaningful mean loss.
                data_loss_validation = float("nan")

            # Validation Epoch final metrics
            metrics = self.get_metrics(reset = True)
            for name in metric_names:
                training_logger["validation"][name].append(metrics[name])

            training_logger["validation"]["data_loss"].append(data_loss_validation)
        finally:
            # Leave the model ready for training even if validation failed.
            self.train()
            self.set_posterior_mean(False)
    
    def trim_model(self, mu_sigma_ratio = 2):
        """
        Trim every variational (VB) linear layer that is enabled in self.cf_a.

        Each enabled layer is passed exactly once to
        ``Vil.trim_LinearVB_weights`` together with ``mu_sigma_ratio``, so the
        reported numbers describe that single trimming pass.

        mu_sigma_ratio: Threshold forwarded unchanged to
            ``Vil.trim_LinearVB_weights``.

        Returns four lists (total_size_w, total_removed_w, total_size_b,
        total_removed_b) holding, for every trimmed layer and in the order
        listed below, the four values returned by ``Vil.trim_LinearVB_weights``.
        """
        # (configuration flag, accessor of the layer). The accessors are lazy
        # because a layer is only looked up when its flag is enabled.
        trimmable_layers = (
            ("VB_Linear_projection_ELMO", lambda: self._linear_projection_ELMO),
            ("VB_highway_layers", lambda: self._highway_layer._module.VBmodels[0]),
            ("VB_similarity_function", lambda: self._matrix_attention._similarity_function),
            ("VB_span_start_predictor_linear", lambda: self._span_start_predictor_linear),
            ("VB_span_end_predictor_linear", lambda: self._span_end_predictor_linear),
        )

        total_size_w = []
        total_removed_w = []
        total_size_b = []
        total_removed_b = []

        for flag_name, get_layer in trimmable_layers:
            if not getattr(self.cf_a, flag_name):
                continue
            size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(get_layer(), mu_sigma_ratio)
            total_size_w.append(size_w)
            total_removed_w.append(removed_w)
            total_size_b.append(size_b)
            total_removed_b.append(removed_b)

        return  total_size_w, total_removed_w, total_size_b, total_removed_b
#    print (weights_to_remove_W.shape)

    
    """
    BAYESIAN NECESSARY FUNCTIONS
    """
    # The Bayesian helpers are taken directly from GeneralVBModel and bound as
    # attributes of this class, so they can be called on its instances:
    #   - sample_posterior:   called at the start of forward()
    #   - get_KL_divergence:  called in train_batch()
    #   - set_posterior_mean: called in fill_epoch_training_information()
    #   - combine_losses:     called in train_batch()
    # NOTE(review): presumably done this way because this class does not
    # inherit from GeneralVBModel - confirm against the class definition.
    sample_posterior = GeneralVBModel.sample_posterior
    get_KL_divergence = GeneralVBModel.get_KL_divergence
    set_posterior_mean = GeneralVBModel.set_posterior_mean
    combine_losses = GeneralVBModel.combine_losses
    
    def save_VB_weights(self):
        """
        Function meant to handle only the VB weights of the model (unfinished).

        NOTE(review): despite its name, nothing is saved here. The body filters
        a dictionary of weights against ``self.state_dict()`` and then calls
        ``self.load_state_dict``. The source of those weights is still the
        ``...`` (Ellipsis) placeholder, so calling this method raises an
        AttributeError at the ``.items()`` call below.
        """
        # TODO: placeholder, the dictionary with the weights is not provided yet.
        pretrained_dict = ...
        model_dict = self.state_dict()
        
        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict) 
        # 3. load the new state dict
        # NOTE(review): the merged model_dict built in step 2 is never used;
        # the filtered pretrained_dict is what gets loaded. Confirm which of
        # the two was intended.
        self.load_state_dict(pretrained_dict)
Exemple #26
0
class LstmClassifier(Model):
    """
    Text classifier: embed the tokens, encode them into a single vector and
    project that vector onto the labels of the 'labels' vocabulary namespace.

    Accuracy is tracked over all labels; precision, recall and F1 are tracked
    for the label given by ``positive_label``.
    """
    def __init__(self,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 vocab: Vocabulary,
                 positive_label: str = '4') -> None:
        super().__init__(vocab)
        # Converts word IDs into their vector representations.
        self.embedder = embedder

        # Converts the sequence of vectors into a single vector.
        self.encoder = encoder

        # Fully-connected layer mapping the encoded vector to one score per label.
        encoded_dim = encoder.get_output_dim()
        label_count = vocab.get_vocab_size('labels')
        self.linear = torch.nn.Linear(in_features=encoded_dim,
                                      out_features=label_count)

        # Metrics: accuracy, plus prec, rec, f1 for the positive label
        # (by default '4', very positive).
        positive_class = vocab.get_token_index(positive_label,
                                               namespace='labels')
        self.accuracy = CategoricalAccuracy()
        self.f1_measure = F1Measure(positive_class)

        # Classification task, hence cross entropy. PyTorch's CrossEntropyLoss
        # already combines softmax and log likelihood loss, so the model needs
        # no separate softmax layer for training.
        self.loss_function = torch.nn.CrossEntropyLoss()

    # Instances are fed to forward after batching.
    # Fields are passed through arguments with the same name.
    def forward(self,
                tokens: TextFieldTensors,
                label: torch.Tensor = None) -> torch.Tensor:
        """
        Classify a batch of token sequences.

        Returns a dictionary with "logits", "cls_emb" (the encoder output) and
        "probs" (softmax of the logits). When ``label`` is given, both metrics
        are updated and a "loss" entry is added, which is the key AllenNLP
        needs in order to train the model.
        """
        # Sequences of different lengths are padded when batched together;
        # the mask lets the encoder ignore that padding.
        padding_mask = get_text_field_mask(tokens)

        # Forward pass
        sentence_vector = self.encoder(self.embedder(tokens), padding_mask)
        scores = self.linear(sentence_vector)

        output = {
            "logits": scores,
            "cls_emb": sentence_vector,
            "probs": torch.softmax(scores, dim=-1),
        }
        if label is None:
            return output

        self.accuracy(scores, label)
        self.f1_measure(scores, label)
        output["loss"] = self.loss_function(scores, label)
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return accuracy, precision, recall and F1 measure in one dictionary.

        reset: Forwarded unchanged to both underlying metrics.
        """
        positive_precision, positive_recall, positive_f1 = self.f1_measure.get_metric(reset)
        overall_accuracy = self.accuracy.get_metric(reset)
        return {
            'accuracy': overall_accuracy,
            'precision': positive_precision,
            'recall': positive_recall,
            'f1_measure': positive_f1
        }
Exemple #27
0
class SimpleFusion(Model):
    """
    Classifier that concatenates image features with the final encoder state
    of a hypothesis sentence and feeds the result to a feed-forward network.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForward,
                 regularizer: Optional[RegularizerApplicator] = None,
                 detector_final_dim: int = 512,
                 dropout: float = 0.5,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        """

        :param vocab: vocabulary handed to the ``Model`` base class
        :param text_field_embedder: embeds the hypothesis tokens
        :param encoder: encodes the embedded hypothesis
        :param output_feedforward: maps the fused features to the label logits
        :param regularizer: handed to the ``Model`` base class
        :param detector_final_dim: argument of the ``SimpleDetector`` applied to the image
        :param dropout: dropout probability; a falsy value disables dropout
        :param initializer: applied to the whole model at the end of construction
        """
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder
        self.detector = SimpleDetector(detector_final_dim)

        # Both dropout attributes are None when dropout is disabled.
        # Only rnn_input_dropout is used in forward().
        self.dropout = nn.Dropout(dropout) if dropout else None
        self.rnn_input_dropout = InputVariationalDropout(dropout) if dropout else None

        self._output_feedforward = output_feedforward

        self._accuracy = CategoricalAccuracy()
        self._loss = nn.CrossEntropyLoss()
        initializer(self)

    def forward(self,
                premise_img: torch.Tensor,
                hypothesis: Dict[str, torch.Tensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """

        :param premise_img: image input given to the detector
        :param hypothesis: text field tensors of the hypothesis
        :param label: gold labels; when given, loss and accuracy are computed
        :return: dictionary with "label_logits", "label_probs" and, if a label
                 was given, "loss"
        """
        text_embedding = self._text_field_embedder(hypothesis)
        text_mask = get_text_field_mask(hypothesis).float()

        if self.rnn_input_dropout:
            text_embedding = self.rnn_input_dropout(text_embedding)

        text_encoding = self._encoder(text_embedding, text_mask)

        # Keep a single vector per hypothesis: the final state of the encoder.
        text_summary = get_final_encoder_states(text_encoding,
                                                text_mask,
                                                self._encoder.is_bidirectional())

        # Image features first, text features second.
        fused_features = torch.cat((self.detector(premise_img), text_summary), dim=-1)

        logits = self._output_feedforward(fused_features)

        output_dict = {
            "label_logits": logits,
            "label_probs": nn.functional.softmax(logits, dim=-1)
        }
        if label is None:
            return output_dict

        batch_loss = self._loss(logits, label.long().view(-1))
        self._accuracy(logits, label)
        output_dict["loss"] = batch_loss
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accuracy, forwarding ``reset`` to the metric."""
        accuracy_value = self._accuracy.get_metric(reset)
        return {'accuracy': accuracy_value}
Exemple #28
0
class BertForClassification(Model):
    """
    An AllenNLP Model that runs pretrained BERT, takes the pooled output,
    applies dropout and adds a Linear classification layer on top.
    Note that this is a somewhat non-AllenNLP-ish model architecture,
    in that it essentially requires you to use the "bert-pretrained"
    token indexer, rather than configuring whatever indexing scheme you like.
    See `allennlp/tests/fixtures/bert/bert_for_classification.jsonnet`
    for an example of what your config might look like.
    # Parameters
    vocab : ``Vocabulary``
    bert_model : ``Union[str, BertModel]``
        The BERT model to be wrapped. If a string is provided, the model is
        obtained with ``PretrainedBertModel.load(bert_model)``.
    dropout : ``float``, optional (default: 0.0)
        Probability of the dropout applied to the pooled BERT output.
    num_labels : ``int``, optional (default: None)
        How many output classes to predict. If not provided (or 0), the
        vocab_size of the ``label_namespace`` is used.
    index : ``str``, optional (default: "bert")
        The index of the token indexer that generates the BERT indices.
    label_namespace : ``str``, optional (default : "labels")
        Used to determine the number of classes if ``num_labels`` is not supplied.
    trainable : ``bool``, optional (default : True)
        Value assigned to ``requires_grad`` of every BERT parameter. If False,
        BERT is frozen and only the final linear layer is trained.
    initializer : ``InitializerApplicator``, optional
        Applied to the final linear layer *only*.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        Handed to the ``Model`` base class.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        dropout: float = 0.0,
        num_labels: int = None,
        index: str = "bert",
        label_namespace: str = "labels",
        trainable: bool = True,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        # A string names a pretrained model to load; anything else is used as is.
        self.bert_model = (PretrainedBertModel.load(bert_model)
                           if isinstance(bert_model, str) else bert_model)

        # Freeze or unfreeze the whole BERT encoder.
        for bert_weight in self.bert_model.parameters():
            bert_weight.requires_grad = trainable

        hidden_size = self.bert_model.config.hidden_size

        self._label_namespace = label_namespace

        # Fall back to the size of the label vocabulary when no (or a zero)
        # number of labels is given.
        class_count = num_labels or vocab.get_vocab_size(
            namespace=self._label_namespace)

        self._dropout = torch.nn.Dropout(p=dropout)

        self._classification_layer = torch.nn.Linear(hidden_size, class_count)
        self._accuracy = CategoricalAccuracy()

        # F1 is measured for the label with index 1.
        self._f1score = F1Measure(positive_label=1)

        self._loss = torch.nn.CrossEntropyLoss()
        self._index = index
        initializer(self._classification_layer)

    def forward(  # type: ignore
            self,
            data_id: Dict[str, torch.LongTensor],
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        # Parameters
        data_id : Dict[str, torch.LongTensor]
            Accepted but not used by this method.
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField`` (that has a bert-pretrained token indexer)
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        # Returns
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised. Only present when ``label`` is given.
        """
        word_piece_ids = tokens[self._index]
        segment_ids = tokens[f"{self._index}-type-ids"]
        # Every position with a non-zero id is attended to.
        attention_mask = (word_piece_ids != 0).long()

        bert_outputs = self.bert_model(input_ids=word_piece_ids,
                                       token_type_ids=segment_ids,
                                       attention_mask=attention_mask)
        # Only the second (pooled) output is used.
        _, pooled_output = bert_outputs

        # apply dropout and the classification layer
        logits = self._classification_layer(self._dropout(pooled_output))

        output_dict = {
            "logits": logits,
            "probs": torch.nn.functional.softmax(logits, dim=-1),
        }
        if label is None:
            return output_dict

        output_dict["loss"] = self._loss(logits, label.long().view(-1))
        self._accuracy(logits, label)
        self._f1score(logits, label)
        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple argmax over the probabilities, converts index to string label, and
        add ``"label"`` key to the dictionary with the result. An index missing
        from the label vocabulary is reported as its own string form.
        """
        probs = output_dict["probs"]
        # A 2-dimensional tensor holds one row per instance; anything else is
        # treated as a single instance.
        if predictions_are_batched := probs.dim() == 2:
            rows = [probs[row] for row in range(probs.shape[0])]
        else:
            rows = [probs]

        def to_label(row):
            best_index = row.argmax(dim=-1).item()
            index_to_label = self.vocab.get_index_to_token_vocabulary(
                self._label_namespace)
            return index_to_label.get(best_index, str(best_index))

        output_dict["label"] = [to_label(row) for row in rows]
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the accuracy and the F1 score (third value of the F1 metric),
        forwarding ``reset`` to both metrics.
        """
        accuracy_value = self._accuracy.get_metric(reset)
        f1_value = self._f1score.get_metric(reset)[2]
        return {
            "accuracy": accuracy_value,
            "f1score": f1_value
        }
Exemple #29
0
class ShallowProductOfExpertsClassifier(Model):
    """
    This `Model` implements a text classifier that is trained together with a
    bias-only ("hypothesis-only") classifier in a product-of-experts fashion.

    The input text is embedded, optionally encoded with a `Seq2SeqEncoder`, pooled
    with a `Seq2VecEncoder` and projected into the label space by a linear layer.
    When `bias_tokens` and `label` are given and the model is not in evaluation
    mode, the same embedder and encoders are applied to `bias_tokens`, a separate
    linear layer produces bias-only logits, and the loss is computed on the sum of
    the two log-softmax outputs plus `beta` times the bias-only loss.

    NOTE(review): an earlier version of this docstring stated that the model is
    registered under the name "basic_classifier"; no registration is visible next
    to this class, so confirm the registered name before relying on it.

    # Parameters

    vocab : `Vocabulary`
    beta : `float`
        Weight of the bias-only loss that is added to the combined loss.
    text_field_embedder : `TextFieldEmbedder`
        Used to embed the input text into a `TextField`
    seq2vec_encoder : `Seq2VecEncoder`
        Required Seq2Vec encoder layer. If `seq2seq_encoder` is provided, this encoder
        will pool its output. Otherwise, this encoder will operate directly on the output
        of the `text_field_embedder`.
    seq2seq_encoder : `Seq2SeqEncoder`, optional (default=`None`)
        Optional Seq2Seq encoder layer for the input text.
    feedforward : `FeedForward`, optional, (default = None).
        An optional feedforward layer to apply after the seq2vec_encoder.
    feedforward_hyp_only : `FeedForward`, optional, (default = None).
        An optional feedforward layer to apply after the seq2vec_encoder when
        encoding `bias_tokens`.
    dropout : `float`, optional (default = `None`)
        Dropout percentage to use.
    num_labels : `int`, optional (default = `None`)
        Number of labels to project to in classification layer. By default, the classification layer will
        project to the size of the vocabulary namespace corresponding to labels.
    label_namespace : `str`, optional (default = "labels")
        Vocabulary namespace corresponding to labels. By default, we use the "labels" namespace.
    evaluation_mode : `bool`, optional (default = `False`)
        When `True`, `forward` never runs the bias-only branch.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        If provided, will be used to initialize the model parameters.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        beta: float,
        text_field_embedder: TextFieldEmbedder,
        seq2vec_encoder: Seq2VecEncoder,
        seq2seq_encoder: Seq2SeqEncoder = None,
        feedforward: Optional[FeedForward] = None,
        feedforward_hyp_only: Optional[FeedForward] = None,
        dropout: float = None,
        num_labels: int = None,
        label_namespace: str = "labels",
        evaluation_mode: bool = False,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:

        super().__init__(vocab, **kwargs)
        self.evaluation_mode = evaluation_mode
        self._text_field_embedder = text_field_embedder

        if seq2seq_encoder:
            self._seq2seq_encoder = seq2seq_encoder
        else:
            self._seq2seq_encoder = None

        self._seq2vec_encoder = seq2vec_encoder
        self._feedforward = feedforward
        self._feedforward_hyp_only = feedforward_hyp_only

        # Each classification layer consumes the output of its feedforward if
        # one is configured, otherwise the pooled encoder output.
        if feedforward is not None:
            self._classifier_input_dim = self._feedforward.get_output_dim()
        else:
            self._classifier_input_dim = self._seq2vec_encoder.get_output_dim()

        if feedforward_hyp_only is not None:
            self._classifier_hyp_only_input_dim = self._feedforward_hyp_only.get_output_dim(
            )
        else:
            self._classifier_hyp_only_input_dim = self._seq2vec_encoder.get_output_dim(
            )

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = None

        self._label_namespace = label_namespace

        if num_labels:
            self._num_labels = num_labels
        else:
            self._num_labels = vocab.get_vocab_size(
                namespace=self._label_namespace)

        self._classification_layer = torch.nn.Linear(
            self._classifier_input_dim, self._num_labels)
        self._classification_layer_hyp_only = torch.nn.Linear(
            self._classifier_hyp_only_input_dim, self._num_labels)

        self._beta = beta

        self._accuracy = CategoricalAccuracy()
        self._hyp_only_accuracy = CategoricalAccuracy()
        self._cross_ent_loss = torch.nn.CrossEntropyLoss()
        self._nll_loss = torch.nn.NLLLoss()
        initializer(self)

    def finetune(self):
        """Switch to evaluation mode, which disables the bias-only branch of `forward`."""
        self.evaluation_mode = True

    def _encode(self, tokens: TextFieldTensors,
                feedforward: Optional[FeedForward]) -> torch.Tensor:
        """Embed `tokens`, encode and pool them, then apply dropout and `feedforward` if set."""
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        if feedforward is not None:
            embedded_text = feedforward(embedded_text)

        return embedded_text

    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors,
            bias_tokens: TextFieldTensors = None,
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : `TextFieldTensors`
            The full input text.
        bias_tokens : `TextFieldTensors`, optional (default = `None`)
            The input seen by the bias-only classifier.
        label : `torch.IntTensor`, optional (default = `None`)
            Gold label indices; flattened with `view(-1)` before the loss.

        # Returns

        When the bias-only branch runs (not in evaluation mode and both
        `bias_tokens` and `label` are given): a dictionary with only `"loss"`.
        Otherwise: a dictionary with `"logits"` and `"probs"`, plus `"loss"`
        when `label` is given.
        """
        sentence_pair_logits = self._classification_layer(
            self._encode(tokens, self._feedforward))

        # The bias-only branch needs labels for both of its losses, so without
        # labels we fall through to the plain prediction path instead of failing.
        use_bias_model = (not self.evaluation_mode
                          and bias_tokens is not None
                          and label is not None)

        if use_bias_model:
            # Make predictions with the bias-only input
            hyp_only_logits = self._classification_layer_hyp_only(
                self._encode(bias_tokens, self._feedforward_hyp_only))

            log_probs_pair = torch.log_softmax(sentence_pair_logits, dim=1)
            log_probs_hyp = torch.log_softmax(hyp_only_logits, dim=1)

            # Product of experts: sum of the two log-softmax outputs. The
            # bias-only term is detached so the combined loss sends no gradient
            # into the bias-only classifier.
            # NOTE(review): the sum is not re-normalised before the NLL loss --
            # confirm this is intended.
            combined = log_probs_pair + log_probs_hyp.detach()

            flat_label = label.long().view(-1)
            loss = self._nll_loss(combined, flat_label)
            hyp_loss = self._nll_loss(log_probs_hyp, flat_label)
            self._accuracy(combined, label)
            self._hyp_only_accuracy(hyp_only_logits, label)

            return {"loss": loss + self._beta * hyp_loss}

        output_dict = {
            "logits": sentence_pair_logits,
            "probs": torch.softmax(sentence_pair_logits, dim=1)
        }
        # Loss and accuracy are only defined when gold labels are available;
        # plain prediction must still return logits and probabilities.
        if label is not None:
            output_dict["loss"] = self._cross_ent_loss(
                sentence_pair_logits,
                label.long().view(-1))
            self._accuracy(sentence_pair_logits, label)
        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple argmax over the probabilities, converts index to string label, and
        add `"label"` key to the dictionary with the result.

        Requires `output_dict["probs"]`, which `forward` only produces when the
        bias-only branch did not run.
        """
        predictions = output_dict["probs"]
        if predictions.dim() == 2:
            predictions_list = [
                predictions[i] for i in range(predictions.shape[0])
            ]
        else:
            predictions_list = [predictions]
        classes = []
        for prediction in predictions_list:
            label_idx = prediction.argmax(dim=-1).item()
            # Fall back to the numeric index when the vocabulary has no entry.
            label_str = self.vocab.get_index_to_token_vocabulary(
                self._label_namespace).get(label_idx, str(label_idx))
            classes.append(label_str)
        output_dict["label"] = classes
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the bias-only accuracy and the main accuracy, forwarding `reset` to both."""
        metrics = {
            "hyp_only_accuracy": self._hyp_only_accuracy.get_metric(reset),
            "accuracy": self._accuracy.get_metric(reset)
        }
        return metrics
Exemple #30
0
class SmartdataEventxModel(Model):
    """
    Joint model for trigger classification and argument role labeling.

    Trigger spans and entity spans are extracted from the encoded token
    sequence. Triggers are classified with a linear projection; every
    (trigger, entity) pair is classified into a role by projecting both into a
    shared hidden space, adding them with a learned bias, applying ReLU and a
    final linear layer.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 span_extractor: SpanExtractor,
                 entity_embedder: TokenEmbedder,
                 hidden_dim: int,
                 loss_weight: float = 1.0,
                 trigger_gamma: float = None,
                 role_gamma: float = None,
                 positive_class_weight: float = 1.0,
                 triggers_namespace: str = 'trigger_labels',
                 roles_namespace: str = 'arg_role_labels',
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = None) -> None:
        super().__init__(vocab=vocab, regularizer=regularizer)
        self._triggers_namespace = triggers_namespace
        self._roles_namespace = roles_namespace
        self.num_trigger_classes = self.vocab.get_vocab_size(
            triggers_namespace)
        self.num_role_classes = self.vocab.get_vocab_size(roles_namespace)
        self.hidden_dim = hidden_dim
        self.loss_weight = loss_weight
        self.trigger_gamma = trigger_gamma
        self.role_gamma = role_gamma
        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.entity_embedder = entity_embedder
        self.span_extractor = span_extractor
        self.trigger_projection = Linear(self.encoder.get_output_dim(),
                                         self.num_trigger_classes)
        self.trigger_to_hidden = Linear(self.encoder.get_output_dim(),
                                        self.hidden_dim)
        self.entities_to_hidden = Linear(self.encoder.get_output_dim(),
                                         self.hidden_dim)
        self.hidden_bias = Parameter(torch.Tensor(self.hidden_dim))
        torch.nn.init.normal_(self.hidden_bias)
        self.hidden_to_roles = Linear(self.hidden_dim, self.num_role_classes)
        self.trigger_accuracy = CategoricalAccuracy()
        trigger_labels_to_idx = self.vocab.get_token_to_index_vocabulary(
            namespace=triggers_namespace)
        # The negative trigger label is left out of the labels given to the
        # F-measure.
        evaluated_trigger_idxs = list(trigger_labels_to_idx.values())
        evaluated_trigger_idxs.remove(
            trigger_labels_to_idx[NEGATIVE_TRIGGER_LABEL])
        self.trigger_f1 = MicroFBetaMeasure(
            average='micro',  # get_metrics only reads the resulting 'fscore'
            labels=evaluated_trigger_idxs)
        role_labels_to_idx = self.vocab.get_token_to_index_vocabulary(
            namespace=roles_namespace)
        evaluated_role_idxs = list(role_labels_to_idx.values())
        evaluated_role_idxs.remove(role_labels_to_idx[NEGATIVE_ARGUMENT_LABEL])
        self.role_accuracy = CategoricalAccuracy()
        self.role_f1 = MicroFBetaMeasure(
            average='micro',  # get_metrics only reads the resulting 'fscore'
            labels=evaluated_role_idxs)
        # Trigger class weighting as done in JMEE repo: every class gets
        # `positive_class_weight` except the negative label, which keeps 1.0.
        # NOTE(review): this is a plain tensor attribute, not a registered
        # buffer, so moving the model to another device does not move it --
        # confirm the loss function handles that.
        self.trigger_class_weights = torch.ones(
            len(trigger_labels_to_idx)) * positive_class_weight
        self.trigger_class_weights[
            trigger_labels_to_idx[NEGATIVE_TRIGGER_LABEL]] = 1.0
        initializer(self)

    @overrides
    def forward(
            self,
            tokens: Dict[str, torch.LongTensor],
            entity_tags: torch.LongTensor,
            entity_spans: torch.LongTensor,
            trigger_spans: torch.LongTensor,
            trigger_labels: torch.LongTensor = None,
            arg_roles: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Predict trigger types and argument roles, and compute losses if labels are given.

        # Returns

        A dictionary with `trigger_logits`, `trigger_probabilities`,
        `role_logits`, `role_probabilities`, `entity_spans` and `trigger_spans`.
        `triggers_loss` is added when `trigger_labels` is given, `role_loss`
        when `arg_roles` is given, and `loss` holds the trigger loss plus
        `loss_weight` times the role loss for whichever of the two are present.
        `words` is added when `metadata` is given.
        """
        embedded_tokens = self.text_field_embedder(tokens)
        text_mask = get_text_field_mask(tokens)
        embedded_entity_tags = self.entity_embedder(entity_tags)
        embedded_input = torch.cat([embedded_tokens, embedded_entity_tags],
                                   dim=-1)

        encoded_input = self.encoder(embedded_input, text_mask)

        ###########################
        # Trigger type prediction #
        ###########################

        # Extract the spans of the triggers; spans with a negative start are
        # treated as padding.
        trigger_spans_mask = (trigger_spans[:, :, 0] >= 0).long()
        encoded_triggers = self.span_extractor(
            sequence_tensor=encoded_input,
            span_indices=trigger_spans,
            sequence_mask=text_mask,
            span_indices_mask=trigger_spans_mask)

        # Pass the extracted triggers through a projection for classification
        trigger_logits = self.trigger_projection(encoded_triggers)

        # Add the trigger predictions to the output
        trigger_probabilities = F.softmax(trigger_logits, dim=-1)
        output_dict = {
            "trigger_logits": trigger_logits,
            "trigger_probabilities": trigger_probabilities
        }

        if trigger_labels is not None:
            # Compute loss and metrics using the given trigger labels.
            # Labels equal to -1 are masked out; multiplying by the mask turns
            # them into index 0 so they remain valid class indices.
            trigger_mask = (trigger_labels != -1)
            trigger_labels = trigger_labels * trigger_mask
            self.trigger_accuracy(trigger_logits, trigger_labels,
                                  trigger_mask.float())
            self.trigger_f1(trigger_logits, trigger_labels,
                            trigger_mask.float())

            # Move the class dimension to position 1 for the loss function.
            trigger_logits_t = trigger_logits.permute(0, 2, 1)
            # NOTE(review): `_cross_entropy_focal_loss` is not defined in this
            # class as shown, while the role branch below calls the
            # module-level `cross_entropy_focal_loss` -- confirm which one is
            # intended.
            trigger_loss = self._cross_entropy_focal_loss(
                logits=trigger_logits_t,
                target=trigger_labels,
                target_mask=trigger_mask,
                gamma=self.trigger_gamma,
                alpha=self.trigger_class_weights)

            output_dict["triggers_loss"] = trigger_loss
            output_dict["loss"] = trigger_loss

        ########################################
        # Argument detection and role labeling #
        ########################################

        # Extract the spans of the encoded entities
        entity_spans_mask = (entity_spans[:, :, 0] >= 0).long()
        encoded_entities = self.span_extractor(
            sequence_tensor=encoded_input,
            span_indices=entity_spans,
            sequence_mask=text_mask,
            span_indices_mask=entity_spans_mask)

        # Project both triggers and entities/args into a 'hidden' comparison space
        triggers_hidden = self.trigger_to_hidden(encoded_triggers)
        args_hidden = self.entities_to_hidden(encoded_entities)

        # Create the cross-product of triggers and args via broadcasting
        trigger = triggers_hidden.unsqueeze(2)  # B x T x 1 x H
        args = args_hidden.unsqueeze(1)  # B x 1 x E x H
        trigger_arg = trigger + args + self.hidden_bias  # B x T x E x H

        # Pass through activation and projection for classification
        role_activations = F.relu(trigger_arg)
        role_logits = self.hidden_to_roles(role_activations)  # B x T x E x R

        # Add the role predictions to the output
        role_probabilities = torch.softmax(role_logits, dim=-1)
        output_dict['role_logits'] = role_logits
        output_dict['role_probabilities'] = role_probabilities

        # Compute loss and metrics using the given role labels
        if arg_roles is not None:
            arg_roles = self._assert_target_shape(logits=role_logits,
                                                  target=arg_roles)

            target_mask = (arg_roles != -1)
            target = arg_roles * target_mask  # remove negative indices

            self.role_accuracy(role_logits, target, target_mask.float())
            self.role_f1(role_logits, target, target_mask.float())

            # Masked batch-wise cross entropy loss, optionally with focal-loss
            role_logits_t = role_logits.permute(0, 3, 1, 2)
            role_loss = cross_entropy_focal_loss(logits=role_logits_t,
                                                 target=target,
                                                 target_mask=target_mask,
                                                 gamma=self.role_gamma)

            output_dict['role_loss'] = role_loss
            weighted_role_loss = self.loss_weight * role_loss
            # Build the total out of place: an in-place `+=` would modify the
            # tensor that is also stored under "triggers_loss", and would fail
            # with a KeyError when no trigger labels were given.
            if 'loss' in output_dict:
                output_dict['loss'] = output_dict['loss'] + weighted_role_loss
            else:
                output_dict['loss'] = weighted_role_loss

        # Append the original tokens for visualization
        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]

        # Append the trigger and entity spans to reconstruct the event after prediction
        output_dict['entity_spans'] = entity_spans
        output_dict['trigger_spans'] = trigger_spans

        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Convert predictions into label strings and assemble event dictionaries.

        Adds `trigger_labels`, `role_labels` and `events` to `output_dict`.
        Reads `output_dict['words']`, which `forward` only sets when metadata
        was passed. Triggers predicted as the negative label, padded spans
        (negative start) and events without any argument are left out of
        `events`.
        """
        trigger_predictions = output_dict['trigger_probabilities'].cpu(
        ).data.numpy()
        trigger_labels = [[
            self.vocab.get_token_from_index(trigger_idx,
                                            namespace=self._triggers_namespace)
            for trigger_idx in example
        ] for example in np.argmax(trigger_predictions, axis=-1)]
        output_dict['trigger_labels'] = trigger_labels

        arg_role_predictions = output_dict['role_logits'].cpu().data.numpy()
        arg_role_labels = [[
            [
                self.vocab.get_token_from_index(
                    role_idx, namespace=self._roles_namespace)
                for role_idx in event
            ] for event in example
        ] for example in np.argmax(arg_role_predictions, axis=-1)]
        output_dict['role_labels'] = arg_role_labels

        events = []
        for batch_idx in range(len(trigger_labels)):
            words = output_dict['words'][batch_idx]
            batch_events = []
            for trigger_idx, trigger_label in enumerate(
                    trigger_labels[batch_idx]):
                if trigger_label == NEGATIVE_TRIGGER_LABEL:
                    continue
                trigger_span = output_dict['trigger_spans'][batch_idx][
                    trigger_idx]
                trigger_start = trigger_span[0].item()
                # +1 so the stored end can be used directly as a slice bound
                trigger_end = trigger_span[1].item() + 1
                if trigger_start < 0:
                    continue
                event = {
                    'event_type': trigger_label,
                    'trigger': {
                        'text': " ".join(words[trigger_start:trigger_end]),
                        'start': trigger_start,
                        'end': trigger_end
                    },
                    'arguments': []
                }
                for entity_idx, role_label in enumerate(
                        arg_role_labels[batch_idx][trigger_idx]):
                    if role_label == NEGATIVE_ARGUMENT_LABEL:
                        continue
                    arg_span = output_dict['entity_spans'][batch_idx][
                        entity_idx]
                    arg_start = arg_span[0].item()
                    arg_end = arg_span[1].item() + 1
                    if arg_start < 0:
                        continue
                    argument = {
                        'text': " ".join(words[arg_start:arg_end]),
                        'start': arg_start,
                        'end': arg_end,
                        'role': role_label
                    }
                    event['arguments'].append(argument)
                # Only events with at least one argument are reported.
                if len(event['arguments']) > 0:
                    batch_events.append(event)
            events.append(batch_events)
        output_dict['events'] = events

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return accuracy and the 'fscore' entry of the F-measure for triggers and roles."""
        return {
            'trigger_acc': self.trigger_accuracy.get_metric(reset=reset),
            'trigger_f1': self.trigger_f1.get_metric(reset=reset)['fscore'],
            'role_acc': self.role_accuracy.get_metric(reset=reset),
            'role_f1': self.role_f1.get_metric(reset=reset)['fscore']
        }

    @staticmethod
    def _assert_target_shape(logits, target):
        """
        Return `target` shaped like `logits` without its last dimension.

        If the shapes already match, `target` is returned unchanged. Otherwise
        a new tensor filled with -1 is created and `target` is copied into its
        leading corner, so the added positions carry the -1 value that
        `forward` masks out. This is needed since some batches are not
        completely filled.
        """
        expected_shape = logits.shape[:-1]
        if target.shape == expected_shape:
            return target
        else:
            new_target = torch.full(size=expected_shape,
                                    fill_value=-1,
                                    dtype=target.dtype,
                                    device=target.device)
            batch_size, triggers_len, arguments_len = target.shape
            new_target[:, :triggers_len, :arguments_len] = target
            return new_target
Exemple #31
0
class DecomposableAttention(Model):
    """
    Decomposable Attention model for natural language inference, following [A Decomposable
    Attention Model for Natural Language Inference](
    https://www.semanticscholar.org/paper/A-Decomposable-Attention-Model-for-Natural-Languag-Parikh-T%C3%A4ckstr%C3%B6m/07a9478e87a8304fc3267fa16e83e9f3bbd98b27)
    by Parikh et al., 2016, with optional enhancements applied before the attention step.
    The original model allowed an "intra-sentence" attention ahead of the entailment step;
    here that is generalized to any
    [`Seq2SeqEncoder`](../modules/seq2seq_encoders/seq2seq_encoder.md) applied to the
    premise and/or the hypothesis.

    The model embeds every word of premise and hypothesis, aligns the words of the two
    sentences, compares the aligned phrases, and takes the final entailment decision from the
    aggregated comparison. A feedforward network transforms the representation at each step.

    Registered as a `Model` with name "decomposable_attention".

    # Parameters

    vocab : `Vocabulary`
    text_field_embedder : `TextFieldEmbedder`
        Embeds the `premise` and `hypothesis` `TextFields` given to the model.
    attend_feedforward : `FeedForward`
        Applied to the encoded sentence representations before the premise/hypothesis
        similarity matrix is computed.
    matrix_attention : `MatrixAttention`
        Attention function that produces the similarity matrix between premise words and
        hypothesis words.
    compare_feedforward : `FeedForward`
        Applied separately to the aligned premise and hypothesis representations.
    aggregate_feedforward : `FeedForward`
        Applied to the concatenation of the summed `compare_feedforward` outputs; its output
        is used as the entailment class logits.
    premise_encoder : `Seq2SeqEncoder`, optional (default=`None`)
        Optional encoder run on the embedded premise. Nothing is done if this is `None`.
    hypothesis_encoder : `Seq2SeqEncoder`, optional (default=`None`)
        Optional encoder run on the embedded hypothesis. If `None`, the `premise_encoder` is
        used instead (which again means nothing is done if that is `None` as well).
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Applied to the model to initialize its parameters.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        attend_feedforward: FeedForward,
        matrix_attention: MatrixAttention,
        compare_feedforward: FeedForward,
        aggregate_feedforward: FeedForward,
        premise_encoder: Optional[Seq2SeqEncoder] = None,
        hypothesis_encoder: Optional[Seq2SeqEncoder] = None,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        # Sub-modules are assigned in a fixed order; the per-token feedforwards
        # are wrapped so they run on every time step.
        self._text_field_embedder = text_field_embedder
        self._attend_feedforward = TimeDistributed(attend_feedforward)
        self._matrix_attention = matrix_attention
        self._compare_feedforward = TimeDistributed(compare_feedforward)
        self._aggregate_feedforward = aggregate_feedforward
        self._premise_encoder = premise_encoder
        # Without a dedicated hypothesis encoder the premise encoder is reused.
        if hypothesis_encoder:
            self._hypothesis_encoder = hypothesis_encoder
        else:
            self._hypothesis_encoder = premise_encoder

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        embedding_dim = text_field_embedder.get_output_dim()
        attend_input_dim = attend_feedforward.get_input_dim()
        check_dimensions_match(
            embedding_dim,
            attend_input_dim,
            "text field embedding dim",
            "attend feedforward input dim",
        )
        output_dim = aggregate_feedforward.get_output_dim()
        check_dimensions_match(
            output_dim,
            self._num_labels,
            "final output dimension",
            "number of labels",
        )

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    def forward(  # type: ignore
        self,
        premise: TextFieldTensors,
        hypothesis: TextFieldTensors,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        premise : `TextFieldTensors`
            Output of a `TextField`.
        hypothesis : `TextFieldTensors`
            Output of a `TextField`.
        label : `torch.IntTensor`, optional (default = `None`)
            Output of a `LabelField`.
        metadata : `List[Dict[str, Any]]`, optional (default = `None`)
            Per-instance metadata holding the original tokenization under the keys
            'premise_tokens' and 'hypothesis_tokens'.

        # Returns

        An output dictionary consisting of:

        label_logits : `torch.FloatTensor`
            Shape `(batch_size, num_labels)`; unnormalised log probabilities of the
            entailment label.
        label_probs : `torch.FloatTensor`
            Shape `(batch_size, num_labels)`; probabilities of the entailment label.
        h2p_attention, p2h_attention : `torch.FloatTensor`
            The hypothesis-to-premise and premise-to-hypothesis attention weights.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised; present only when `label` is given.
        premise_tokens, hypothesis_tokens : optional
            Copied from `metadata` when it is given.
        """
        premise_repr = self._text_field_embedder(premise)
        hypothesis_repr = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise)
        hypothesis_mask = get_text_field_mask(hypothesis)

        if self._premise_encoder:
            premise_repr = self._premise_encoder(premise_repr, premise_mask)
        if self._hypothesis_encoder:
            hypothesis_repr = self._hypothesis_encoder(hypothesis_repr,
                                                       hypothesis_mask)

        # Attend: score every premise word against every hypothesis word.
        attend_premise = self._attend_feedforward(premise_repr)
        attend_hypothesis = self._attend_feedforward(hypothesis_repr)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity = self._matrix_attention(attend_premise, attend_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        hypothesis_for_premise = weighted_sum(hypothesis_repr, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        similarity_t = similarity.transpose(1, 2).contiguous()
        h2p_attention = masked_softmax(similarity_t, premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        premise_for_hypothesis = weighted_sum(premise_repr, h2p_attention)

        # Compare: pair every word with its soft-aligned counterpart.
        premise_pairs = torch.cat([premise_repr, hypothesis_for_premise],
                                  dim=-1)
        hypothesis_pairs = torch.cat(
            [hypothesis_repr, premise_for_hypothesis], dim=-1)

        # Padded positions are zeroed before summing over the sequence.
        premise_vectors = self._compare_feedforward(premise_pairs)
        premise_vectors = premise_vectors * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        premise_summary = premise_vectors.sum(dim=1)

        hypothesis_vectors = self._compare_feedforward(hypothesis_pairs)
        hypothesis_vectors = hypothesis_vectors * hypothesis_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        hypothesis_summary = hypothesis_vectors.sum(dim=1)

        # Aggregate: classify from the concatenated sentence summaries.
        label_logits = self._aggregate_feedforward(
            torch.cat([premise_summary, hypothesis_summary], dim=-1))
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "h2p_attention": h2p_attention,
            "p2h_attention": p2h_attention,
        }

        if label is not None:
            gold = label.long().view(-1)
            loss = self._loss(label_logits, gold)
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if metadata is not None:
            premise_tokens = []
            hypothesis_tokens = []
            for entry in metadata:
                premise_tokens.append(entry["premise_tokens"])
            for entry in metadata:
                hypothesis_tokens.append(entry["hypothesis_tokens"])
            output_dict["premise_tokens"] = premise_tokens
            output_dict["hypothesis_tokens"] = hypothesis_tokens

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accumulated accuracy, forwarding `reset` to the metric."""
        accuracy = self._accuracy.get_metric(reset)
        return {"accuracy": accuracy}

    default_predictor = "textual_entailment"
Exemple #32
0
class EntityNLMDiscriminator(Model):
    """
    Implementation of the discriminative model from:
        https://arxiv.org/abs/1708.00781
    used to draw importance samples.

    For every token the model predicts (1) whether the token belongs to an entity mention,
    (2) which entity is mentioned and (3) how many tokens the mention spans. Predictions are
    only made for tokens that do not continue a mention started earlier.

    Parameters
    ----------
    vocab : ``Vocabulary``
        The model vocabulary.
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed tokens.
    encoder : ``Seq2SeqEncoder``
        Used to encode the sequence of token embeddings.
    embedding_dim : ``int``
        The dimension of entity / length embeddings. Should match the encoder output size.
    max_mention_length : ``int``
        Maximum entity mention length.
    max_embeddings : ``int``
        Maximum number of embeddings.
    variational_dropout_rate : ``float``, optional
        Dropout rate of variational dropout applied to input embeddings. Default: 0.0
    dropout_rate : ``float``, optional
        Dropout rate applied to hidden states. Default: 0.0
    initializer : ``InitializerApplicator``, optional
        Used to initialize model parameters.
    """
    # pylint: disable=line-too-long

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 embedding_dim: int,
                 max_mention_length: int,
                 max_embeddings: int,
                 variational_dropout_rate: float = 0.0,
                 dropout_rate: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        super(EntityNLMDiscriminator, self).__init__(vocab)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder
        self._embedding_dim = embedding_dim
        self._max_mention_length = max_mention_length
        self._max_embeddings = max_embeddings

        # Carries ``prev_mention_lengths`` from one call of ``_forward_loop`` to the next.
        # ``None`` until the first chunk has been processed and after ``reset_states``.
        self._state: Optional[StateDict] = None

        # Input variational dropout
        self._variational_dropout = InputVariationalDropout(
            variational_dropout_rate)
        self._dropout = torch.nn.Dropout(dropout_rate)

        # For entity type prediction (two classes: token is / is not part of a mention)
        self._entity_type_projection = torch.nn.Linear(in_features=embedding_dim,
                                                       out_features=2,
                                                       bias=False)
        self._dynamic_embeddings = DynamicEmbedding(embedding_dim=embedding_dim,
                                                    max_embeddings=max_embeddings)

        # For mention length prediction. The input is the hidden state concatenated with the
        # embedding of the mentioned entity, hence ``2 * embedding_dim``.
        self._mention_length_projection = torch.nn.Linear(in_features=2*embedding_dim,
                                                          out_features=max_mention_length)

        self._entity_type_accuracy = CategoricalAccuracy()
        self._entity_id_accuracy = CategoricalAccuracy()
        self._mention_length_accuracy = CategoricalAccuracy()

        initializer(self)

    @overrides
    def forward(self,  # pylint: disable=arguments-differ
                tokens: Dict[str, torch.Tensor],
                entity_types: Optional[torch.Tensor] = None,
                entity_ids: Optional[torch.Tensor] = None,
                mention_lengths: Optional[torch.Tensor] = None,
                reset: bool = False) -> Dict[str, torch.Tensor]:
        """
        Computes the loss during training / validation.

        Parameters
        ----------
        tokens : ``Dict[str, torch.Tensor]``
            A tensor of shape ``(batch_size, sequence_length)`` containing the sequence of
            tokens.
        entity_types : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not the
            corresponding token belongs to a mention.
        entity_ids : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the
            entities the corresponding token is mentioning.
        mention_lengths : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining
            tokens (including the current one) there are in the mention.
        reset : ``bool``
            Whether or not to reset the model's state. This should be done at the start of each
            new sequence.

        Returns
        -------
        The output dictionary of ``_forward_loop`` (``entity_type_loss``, ``entity_id_loss``,
        ``mention_length_loss`` and the combined ``loss``) when ``entity_types`` is given,
        otherwise an empty dictionary.
        """
        batch_size = tokens['tokens'].shape[0]

        # Either start from a clean state or keep the state but cut the gradient history.
        if reset:
            self.reset_states(batch_size)
        else:
            self.detach_states()

        if entity_types is not None:
            output_dict = self._forward_loop(tokens=tokens,
                                             entity_types=entity_types,
                                             entity_ids=entity_ids,
                                             mention_lengths=mention_lengths)
        else:
            output_dict = {}

        return output_dict

    def sample(self,
               tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Generates a sample from the discriminative model.

        WARNING: Unlike during training, this function expects the full (unsplit) sequence of
        tokens.

        Parameters
        ----------
        tokens : ``Dict[str, torch.Tensor]``
            A tensor of shape ``(batch_size, sequence_length)`` containing the sequence of
            tokens.

        Returns
        -------
        An output dictionary consisting of:
        logp : ``torch.Tensor``
            A tensor of shape ``(batch_size,)`` containing the total log-probability of each
            sample, accumulated over all timesteps.
        sample : ``Dict[str, torch.Tensor]``
            A nested dictionary holding the sampled annotations:

            entity_types : ``torch.Tensor``
                A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not
                the corresponding token belongs to a mention.
            entity_ids : ``torch.Tensor``
                A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the
                entities the corresponding token is mentioning.
            mention_lengths : ``torch.Tensor``
                A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining
                tokens (including the current one) there are in the mention.
        """
        batch_size, sequence_length = tokens['tokens'].shape

        # We will use a standard iterator during evaluation instead of a split iterator. Otherwise
        # it will be a pain to handle generating multiple samples for a sequence since there's no
        # way to get back to the first split.
        self.reset_states(batch_size)

        # Embed tokens and get RNN hidden state.
        mask = get_text_field_mask(tokens)
        embeddings = self._text_field_embedder(tokens)
        hidden = self._encoder(embeddings, mask)
        # A previous length of 1 means "no ongoing mention", so everything is predicted at the
        # first timestep.
        prev_mention_lengths = tokens['tokens'].new_ones(batch_size)

        # Initialize outputs
        # Track total logp for **each** generated sample
        logp = hidden.new_zeros(batch_size)
        entity_types = torch.zeros_like(tokens['tokens'], dtype=torch.uint8)
        entity_ids = torch.zeros_like(tokens['tokens'])
        mention_lengths = torch.ones_like(tokens['tokens'])

        # Generate outputs
        for timestep in range(sequence_length):

            current_hidden = hidden[:, timestep]

            # We only predict types / ids / lengths if the previous mention is terminated.
            predict_mask = prev_mention_lengths == 1
            predict_mask = predict_mask * mask[:, timestep].byte()
            if predict_mask.sum() > 0:

                # Predict entity types
                entity_type_logits = self._entity_type_projection(
                    current_hidden[predict_mask])
                entity_type_logp = F.log_softmax(entity_type_logits, dim=-1)
                entity_type_prediction_logp, entity_type_predictions = sample_from_logp(
                    entity_type_logp)
                entity_type_predictions = entity_type_predictions.byte()
                entity_types[predict_mask, timestep] = entity_type_predictions
                logp[predict_mask] += entity_type_prediction_logp

                # Only predict entity and mention lengths if we predicted that there was a mention
                predict_em = entity_types[:, timestep] * predict_mask
                if predict_em.sum() > 0:
                    # Predict entity ids
                    entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden,
                                                                            mask=predict_em)
                    entity_id_logits = entity_id_prediction_outputs['logits']
                    entity_id_logp = F.log_softmax(entity_id_logits, dim=-1)
                    entity_id_prediction_logp, entity_id_predictions = sample_from_logp(
                        entity_id_logp)

                    # Predict mention lengths - we do this before writing the
                    # entity id predictions since we'll need to reindex the new
                    # entities, but need the null embeddings here.
                    predicted_entity_embeddings = self._dynamic_embeddings.embeddings[
                        predict_em, entity_id_predictions]
                    concatenated = torch.cat(
                        (current_hidden[predict_em], predicted_entity_embeddings), dim=-1)
                    mention_length_logits = self._mention_length_projection(
                        concatenated)
                    mention_length_logp = F.log_softmax(
                        mention_length_logits, dim=-1)
                    mention_length_prediction_logp, mention_length_predictions = sample_from_logp(
                        mention_length_logp)

                    # Write predictions. A sampled id of 0 stands for "new entity" and is
                    # replaced by the next free id of the corresponding sequence.
                    new_entity_mask = entity_id_predictions == 0
                    new_entity_labels = self._dynamic_embeddings.num_embeddings[predict_em]
                    entity_id_predictions[new_entity_mask] = new_entity_labels[new_entity_mask]
                    entity_ids[predict_em, timestep] = entity_id_predictions
                    logp[predict_em] += entity_id_prediction_logp

                    mention_lengths[predict_em,
                                    timestep] = mention_length_predictions
                    logp[predict_em] += mention_length_prediction_logp

                # Add / update entity embeddings
                new_entities = entity_ids[:,
                                          timestep] == self._dynamic_embeddings.num_embeddings
                self._dynamic_embeddings.add_embeddings(timestep, new_entities)

                self._dynamic_embeddings.update_embeddings(hidden=current_hidden,
                                                           update_indices=entity_ids[:,
                                                                                     timestep],
                                                           timestep=timestep,
                                                           mask=predict_em)

            # If the previous mentions are ongoing, we assign the output deterministically. Mention
            # lengths decrease by 1, all other outputs are copied from the previous timestep. Do
            # not need to add anything to logp since these 'predictions' have probability 1 under
            # the model.
            deterministic_mask = prev_mention_lengths > 1
            deterministic_mask = deterministic_mask * mask[:, timestep].byte()
            # The threshold must be 0, not 1: with ``> 1`` a single ongoing mention in the batch
            # (always the case for batch size 1) was never continued and got truncated.
            if deterministic_mask.sum() > 0:
                entity_types[deterministic_mask,
                             timestep] = entity_types[deterministic_mask, timestep - 1]
                entity_ids[deterministic_mask,
                           timestep] = entity_ids[deterministic_mask, timestep - 1]
                mention_lengths[deterministic_mask,
                                timestep] = mention_lengths[deterministic_mask, timestep - 1] - 1

            # Update mention lengths for next timestep
            prev_mention_lengths = mention_lengths[:, timestep]

        return {
            'logp': logp,
            'sample': {
                'entity_types': entity_types,
                'entity_ids': entity_ids,
                'mention_lengths': mention_lengths
            }
        }

    def _forward_loop(self,
                      tokens: Dict[str, torch.Tensor],
                      entity_types: torch.Tensor,
                      entity_ids: torch.Tensor,
                      mention_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Performs the forward pass to calculate the loss on a chunk of training data.

        Parameters
        ----------
        tokens : ``Dict[str, torch.Tensor]``
            A tensor of shape ``(batch_size, sequence_length)`` containing the sequence of
            tokens.
        entity_types : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not the
            corresponding token belongs to a mention.
        entity_ids : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the
            entities the corresponding token is mentioning.
        mention_lengths : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining
            tokens (including the current one) there are in the mention.

        Returns
        -------
        An output dictionary consisting of:
        entity_type_loss : ``torch.Tensor``
            The loss of entity type predictions.
        entity_id_loss : ``torch.Tensor``
            The loss of entity id predictions.
        mention_length_loss : ``torch.Tensor``
            The loss of mention length predictions.
        loss : ``torch.Tensor``
            The combined loss.
        """
        batch_size, sequence_length = tokens['tokens'].shape

        # Need to track previous mention lengths in order to know when to measure loss.
        if self._state is None:
            prev_mention_lengths = mention_lengths.new_ones(batch_size)
        else:
            prev_mention_lengths = self._state['prev_mention_lengths']

        # Embed tokens and get RNN hidden state.
        mask = get_text_field_mask(tokens)
        embeddings = self._text_field_embedder(tokens)
        embeddings = self._variational_dropout(embeddings)
        hidden = self._encoder(embeddings, mask)

        # Initialize losses
        entity_type_loss = torch.tensor(
            0.0, requires_grad=True, device=hidden.device)
        entity_id_loss = torch.tensor(
            0.0, requires_grad=True, device=hidden.device)
        mention_length_loss = torch.tensor(
            0.0, requires_grad=True, device=hidden.device)

        for timestep in range(sequence_length):

            current_entity_types = entity_types[:, timestep]
            current_entity_ids = entity_ids[:, timestep]
            current_mention_lengths = mention_lengths[:, timestep]
            current_hidden = self._dropout(hidden[:, timestep])

            # We only predict types / ids / lengths if we are not currently in the process of
            # generating a mention (e.g. if the previous remaining mention length is 1). Indexing /
            # masking with ``predict_all`` makes it possible to do this in batch.
            predict_all = prev_mention_lengths == 1
            predict_all = predict_all * mask[:, timestep].byte()
            if predict_all.sum() > 0:

                # Equation 3 in the paper.
                entity_type_logits = self._entity_type_projection(
                    current_hidden[predict_all])
                entity_type_loss = entity_type_loss + F.cross_entropy(
                    entity_type_logits,
                    current_entity_types[predict_all].long(),
                    reduction='sum')
                self._entity_type_accuracy(predictions=entity_type_logits,
                                           gold_labels=current_entity_types[predict_all].long())

                # Only proceed to predict entity and mention length if there is in fact an entity.
                predict_em = current_entity_types * predict_all

                if predict_em.sum() > 0:
                    # Equation 4 in the paper. We want new entities to correspond to a prediction of
                    # zero, their embedding should be added after they've been predicted for the first
                    # time.
                    modified_entity_ids = current_entity_ids.clone()
                    modified_entity_ids[modified_entity_ids ==
                                        self._dynamic_embeddings.num_embeddings] = 0
                    entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden,
                                                                            target=modified_entity_ids,
                                                                            mask=predict_em)
                    entity_id_loss = entity_id_loss + \
                        entity_id_prediction_outputs['loss'].sum()
                    self._entity_id_accuracy(predictions=entity_id_prediction_outputs['logits'],
                                             gold_labels=modified_entity_ids[predict_em])

                    # Equation 5 in the paper.
                    predicted_entity_embeddings = self._dynamic_embeddings.embeddings[
                        predict_em, modified_entity_ids[predict_em]]
                    predicted_entity_embeddings = self._dropout(
                        predicted_entity_embeddings)
                    concatenated = torch.cat(
                        (current_hidden[predict_em], predicted_entity_embeddings), dim=-1)
                    mention_length_logits = self._mention_length_projection(
                        concatenated)
                    # NOTE(review): unlike the entity type loss above, this call relies on the
                    # default reduction of ``F.cross_entropy`` instead of ``reduction='sum'``,
                    # although the result is also divided by ``mask.sum()`` below. This looks
                    # unintentional, but changing it rescales the training loss -- confirm
                    # before touching it.
                    mention_length_loss = mention_length_loss + F.cross_entropy(
                        mention_length_logits,
                        current_mention_lengths[predict_em])
                    self._mention_length_accuracy(predictions=mention_length_logits,
                                                  gold_labels=current_mention_lengths[predict_em])

            # We add new entities to any sequence where the current entity id matches the number of
            # embeddings that currently exist for that sequence (this means we need a new one since
            # there is an additional dummy embedding).
            new_entities = current_entity_ids == self._dynamic_embeddings.num_embeddings
            self._dynamic_embeddings.add_embeddings(timestep, new_entities)

            # We also perform updates of the currently observed entities.
            self._dynamic_embeddings.update_embeddings(hidden=current_hidden,
                                                       update_indices=current_entity_ids,
                                                       timestep=timestep,
                                                       mask=current_entity_types)

            prev_mention_lengths = current_mention_lengths

        # Normalize the losses
        entity_type_loss = entity_type_loss / mask.sum()
        entity_id_loss = entity_id_loss / mask.sum()
        mention_length_loss = mention_length_loss / mask.sum()
        total_loss = entity_type_loss + entity_id_loss + mention_length_loss

        output_dict = {
            'entity_type_loss': entity_type_loss,
            'entity_id_loss': entity_id_loss,
            'mention_length_loss': mention_length_loss,
            'loss': total_loss
        }

        # Update state so the next chunk knows whether it starts in the middle of a mention.
        self._state = {
            'prev_mention_lengths': prev_mention_lengths.detach()
        }

        return output_dict

    def reset_states(self, batch_size: int) -> None:
        """Resets the model's internals. Should be called at the start of a new batch."""
        self._encoder.reset_states()
        self._dynamic_embeddings.reset_states(batch_size)
        self._state = None

    def detach_states(self) -> None:
        """Detaches the model's state to enforce truncated backpropagation."""
        self._dynamic_embeddings.detach_states()

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns the accuracies of the three prediction heads.

        Parameters
        ----------
        reset : ``bool``
            Forwarded to ``get_metric`` of every tracked accuracy.

        Returns
        -------
        A dictionary with the keys ``et_acc`` (entity type), ``eid_acc`` (entity id) and
        ``ml_acc`` (mention length).
        """
        return {
            'et_acc': self._entity_type_accuracy.get_metric(reset),
            'eid_acc': self._entity_id_accuracy.get_metric(reset),
            'ml_acc': self._mention_length_accuracy.get_metric(reset)
        }
class TransformerClassificationTT(Model):
    """
    This class implements a classification patterned after the proposed model in
    [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al)]
    (https://api.semanticscholar.org/CorpusID:198953378).

    Parameters
    ----------
    vocab : ``Vocabulary``
    transformer_model : ``str``, optional (default=``"roberta-large"``)
        This model chooses the embedder according to this setting. You probably want to make sure this matches the
        setting in the reader.
    num_labels : ``int``, optional (default=``None``)
        Size of the output layer. When ``None``, the number of entries in the ``label_namespace``
        vocabulary is used.
    label_namespace : ``str``, optional (default=``"labels"``)
        Vocabulary namespace the label names are read from.
    override_weights_file : ``str``, optional (default=``None``)
        Passed as ``weights_path`` to ``from_pretrained_module`` of the embeddings, the transformer
        stack and the pooler.
    **kwargs
        Forwarded to the ``Model`` constructor.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model: str = "roberta-large",
        num_labels: Optional[int] = None,
        label_namespace: str = "labels",
        override_weights_file: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        # The same settings are used to load all three pretrained components.
        transformer_kwargs = {
            "model_name": transformer_model,
            "weights_path": override_weights_file,
        }
        self.embeddings = TransformerEmbeddings.from_pretrained_module(
            **transformer_kwargs)
        self.transformer_stack = TransformerStack.from_pretrained_module(
            **transformer_kwargs)
        self.pooler = TransformerPooler.from_pretrained_module(
            **transformer_kwargs)
        self.pooler_dropout = Dropout(p=0.1)

        # Mapping from class index to label name, used to name the per-class metrics.
        self.label_tokens = vocab.get_index_to_token_vocabulary(
            label_namespace)
        if num_labels is None:
            num_labels = len(self.label_tokens)
        self.linear_layer = torch.nn.Linear(self.pooler.get_output_dim(),
                                            num_labels)
        self.linear_layer.weight.data.normal_(mean=0.0, std=0.02)
        self.linear_layer.bias.data.zero_()

        from allennlp.training.metrics import CategoricalAccuracy, FBetaMeasure

        self.loss = torch.nn.CrossEntropyLoss()
        self.acc = CategoricalAccuracy()
        self.f1 = FBetaMeasure()

    def forward(  # type: ignore
        self,
        text: Dict[str, torch.Tensor],
        label: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``
            From a ``TensorTextField``. Contains the text to be classified. All entries are
            passed to the embeddings as keyword arguments; the ``"attention_mask"`` entry is
            additionally passed to the transformer stack.
        label : ``Optional[torch.LongTensor]``
            From a ``LabelField``, specifies the true class of the instance

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. This is only returned when ``label`` is not ``None``.
        logits : ``torch.FloatTensor``
            The logits for every possible class
        answers : ``torch.LongTensor``
            The index of the highest scoring class for every instance.
        """
        embedded_text = self.embeddings(**text)
        encoded_text = self.transformer_stack(embedded_text,
                                              text["attention_mask"])
        pooled = self.pooler(encoded_text.final_hidden_states)
        pooled = self.pooler_dropout(pooled)
        logits = self.linear_layer(pooled)

        result = {"logits": logits, "answers": logits.argmax(1)}

        # Loss and metrics can only be computed when gold labels are available.
        if label is not None:
            result["loss"] = self.loss(logits, label)
            self.acc(logits, label)
            self.f1(logits, label)

        return result

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns the accuracy under ``"acc"`` and every per-class value reported by the F-beta
        metric under ``"<label>-<metric name>"``.

        Parameters
        ----------
        reset : ``bool``
            Forwarded to ``get_metric`` of both metrics.
        """
        result = {"acc": self.acc.get_metric(reset)}
        for metric_name, metrics_per_class in self.f1.get_metric(
                reset).items():
            for class_index, value in enumerate(metrics_per_class):
                # An explicit ``num_labels`` may exceed the size of the label vocabulary; fall
                # back to the class index instead of failing with a ``KeyError``.
                label_name = self.label_tokens.get(class_index,
                                                   str(class_index))
                result[f"{label_name}-{metric_name}"] = value
        return result
class SpanConstituencyParser(Model):
    """
    This ``SpanConstituencyParser`` simply encodes a sequence of text
    with a stacked ``Seq2SeqEncoder``, extracts span representations using a
    ``SpanExtractor``, and then predicts a label for each span in the sequence.
    These labels are non-terminal nodes in a constituency parse tree, which we then
    greedily reconstruct.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    span_extractor : ``SpanExtractor``, required.
        The method used to extract the spans from the encoded sequence.
    encoder : ``Seq2SeqEncoder``, required.
        The encoder that we will use in between embedding tokens and
        generating span representations.
    feedforward_layer : ``FeedForward``, required.
        The FeedForward layer that we will use in between the encoder and the linear
        projection to a distribution over span labels.
    pos_tag_embedding : ``Embedding``, optional.
        Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 span_extractor: SpanExtractor,
                 encoder: Seq2SeqEncoder,
                 feedforward_layer: Optional[FeedForward] = None,
                 pos_tag_embedding: Optional[Embedding] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 evalb_directory_path: Optional[str] = None) -> None:
        """
        Builds the parser's layers and checks that the dimensions of the supplied modules fit
        together.

        Parameters
        ----------
        evalb_directory_path : ``str``, optional (default=``None``)
            When given, an ``EvalbBracketingScorer`` is created with this path; otherwise
            ``self._evalb_score`` is ``None``.

        The remaining parameters are stored or wrapped as shown below.
        """
        super(SpanConstituencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.span_extractor = span_extractor
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.encoder = encoder
        self.feedforward_layer = TimeDistributed(feedforward_layer) if feedforward_layer else None
        self.pos_tag_embedding = pos_tag_embedding or None
        # The projection to the label space consumes the feedforward output when a
        # feedforward layer is used and the raw span representation otherwise.
        if feedforward_layer is not None:
            output_dim = feedforward_layer.get_output_dim()
        else:
            output_dim = span_extractor.get_output_dim()

        self.tag_projection_layer = TimeDistributed(Linear(output_dim, self.num_classes))

        # The encoder input is the token embedding, optionally extended by the POS tag embedding.
        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        check_dimensions_match(representation_dim,
                               encoder.get_input_dim(),
                               "representation dim (tokens + optional POS tags)",
                               "encoder input dim")
        # The first value is the encoder *output* dim, so the label must say so; the previous
        # label ("encoder input dim") produced a misleading error message.
        check_dimensions_match(encoder.get_output_dim(),
                               span_extractor.get_input_dim(),
                               "encoder output dim",
                               "span extractor input dim")
        if feedforward_layer is not None:
            check_dimensions_match(span_extractor.get_output_dim(),
                                   feedforward_layer.get_input_dim(),
                                   "span extractor output dim",
                                   "feedforward input dim")

        self.tag_accuracy = CategoricalAccuracy()

        if evalb_directory_path is not None:
            self._evalb_score = EvalbBracketingScorer(evalb_directory_path)
        else:
            self._evalb_score = None
        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                spans: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                pos_tags: Dict[str, torch.LongTensor] = None,
                span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)
        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
                "class_probabilities": class_probabilities,
                "spans": spans,
                "tokens": [meta["tokens"] for meta in metadata],
                "pos_tags": [meta.get("pos_tags") for meta in metadata],
                "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
                                              for tree in batch_gold_trees]
            predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                                   spans.cpu().data,
                                                   num_spans.data,
                                                   output_dict["tokens"],
                                                   gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Turns the scored spans of a batch into NLTK ``Tree`` objects, stored under
        ``output_dict["trees"]``. Tree construction is delegated to ``construct_trees``,
        which works with exclusive span ends.

        Because batched predictions are padded, ``output_dict["spans"]`` and
        ``output_dict["class_probabilities"]`` are also replaced by per-instance lists
        which only keep the first ``num_spans[i]`` entries of each batch element.
        POS tags are only forwarded to ``construct_trees`` when every batch element
        has them; otherwise ``None`` is passed.
        """
        probabilities = output_dict['class_probabilities'].cpu().data
        span_indices = output_dict["spans"].cpu().data

        sentences = output_dict["tokens"]
        pos_tags = output_dict["pos_tags"]
        if not all(pos_tags):
            # At least one instance has no tags, so build every tree without them.
            pos_tags = None
        span_counts = output_dict["num_spans"].data

        trees = self.construct_trees(probabilities,
                                     span_indices,
                                     span_counts,
                                     sentences,
                                     pos_tags)

        # Strip the padding from each batch element.
        unpadded_spans = []
        unpadded_probabilities = []
        for batch_index in range(probabilities.size(0)):
            count = span_counts[batch_index]
            unpadded_spans.append(span_indices[batch_index, :count])
            unpadded_probabilities.append(probabilities[batch_index, :count, :])

        output_dict["spans"] = unpadded_spans
        output_dict["class_probabilities"] = unpadded_probabilities
        output_dict["trees"] = trees
        return output_dict

    def construct_trees(self,
                        predictions: torch.FloatTensor,
                        all_spans: torch.LongTensor,
                        num_spans: torch.LongTensor,
                        sentences: List[List[str]],
                        pos_tags: List[List[str]] = None) -> List[Tree]:
        """
        Builds one ``nltk.Tree`` per batch element. For every sentence, the spans whose
        most likely label is not ``NO-LABEL`` (plus the span covering the whole sentence)
        are collected, crossing spans are resolved greedily, and the surviving spans are
        nested into a tree. Trees use exclusive end indices, unlike the ``all_spans``
        tensor that is passed in.

        Parameters
        ----------
        predictions : ``torch.FloatTensor``, required.
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            holding a distribution over the label classes for each span.
        all_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_spans, 2) with the indices of the
            spans that were scored. It is cloned, not modified.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size) giving the number of non-padded spans
            for each batch element.
        sentences : ``List[List[str]]``, required.
            The tokens of each sentence in the batch.
        pos_tags : ``List[List[str]]``, optional (default = None).
            The POS tags of each sentence in the batch.

        Returns
        -------
        A ``List[Tree]`` with one decoded tree per batch element.
        """
        # Work on a copy so that the caller's tensor keeps its original end indices.
        exclusive_end_spans = all_spans.clone()
        exclusive_end_spans[:, :, -1] += 1
        no_label_id = self.vocab.get_token_index("NO-LABEL", "labels")

        trees: List[Tree] = []
        batch = zip(predictions, exclusive_end_spans, sentences)
        for batch_index, (span_scores, sentence_spans, sentence) in enumerate(batch):
            valid_count = num_spans[batch_index]
            sentence_length = len(sentence)

            candidates = []
            for distribution, (start, end) in zip(span_scores[:valid_count],
                                                  sentence_spans[:valid_count]):
                best_prob, best_label = torch.max(distribution, -1)

                # Skip spans predicted as NO-LABEL, except for the root span,
                # which is always a candidate.
                if int(best_label) == no_label_id and not (start == 0 and end == sentence_length):
                    continue

                # TODO(Mark): Remove the pylint pragma once pylint sorts out named tuples.
                # https://github.com/PyCQA/pylint/issues/1418
                candidates.append(SpanInformation(start=int(start), # pylint: disable=no-value-for-parameter
                                                  end=int(end),
                                                  label_prob=float(best_prob),
                                                  no_label_prob=float(distribution[no_label_id]),
                                                  label_index=int(best_label)))

            # Crossing spans cannot be nested into a tree, so drop one span of each crossing pair.
            nestable_spans = self.resolve_overlap_conflicts_greedily(candidates)

            spans_to_labels = {}
            for kept in nestable_spans:
                label = self.vocab.get_token_from_index(kept.label_index, "labels")
                spans_to_labels[(kept.start, kept.end)] = label

            sentence_pos = None if pos_tags is None else pos_tags[batch_index]
            trees.append(self.construct_tree_from_spans(spans_to_labels, sentence, sentence_pos))

        return trees

    @staticmethod
    def resolve_overlap_conflicts_greedily(spans: List[SpanInformation]) -> List[SpanInformation]:
        """
        Removes crossing spans from ``spans`` until no two remaining spans partially
        overlap. For each crossing pair, the two hypotheses "first span labeled, second
        span unlabeled" and "second span labeled, first span unlabeled" are compared and
        the span that loses is dropped. Nested and adjacent spans are not conflicts.

        The worst case time complexity is ``O(k * n^4)`` where ``n`` is the length of the
        sentence the spans were enumerated from (i.e. ``k * m^2`` in the number of spans
        ``m``) and ``k`` is the number of conflicts. In practice conflicts are rare.

        This function modifies ``spans`` in place and returns the same list.

        Parameters
        ----------
        spans: ``List[SpanInformation]``, required.
            A list of spans, where each span is a ``namedtuple`` of which the
            following attributes are read:

        start : ``int``
            The start index of the span.
        end : ``int``
            The exclusive end index of the span.
        no_label_prob : ``float``
            The probability of this span being assigned the ``NO-LABEL`` label.
        label_prob : ``float``
            The probability of the most likely label.

        Returns
        -------
        The ``spans`` list itself, with one span of every crossing pair removed.
        """
        while True:
            removed_a_span = False
            for first_index, first in enumerate(spans):
                # The slice is a copy of the tail, so the inner loop walks a snapshot
                # taken before any removal triggered by ``first``.
                for second_index, second in enumerate(spans[first_index + 1:], first_index + 1):
                    crossing = (first.start < second.start < first.end < second.end or
                                second.start < first.start < second.end < first.end)
                    if not crossing:
                        continue

                    removed_a_span = True
                    # Evidence that ``second`` is the real constituent (``first`` unlabeled)
                    # versus evidence that ``first`` is (``second`` unlabeled). The span
                    # with less evidence is removed; on a tie ``first`` is removed.
                    favours_second = first.no_label_prob + second.label_prob
                    favours_first = second.no_label_prob + first.label_prob
                    loser_index = second_index if favours_second < favours_first else first_index
                    spans.pop(loser_index)
                    # Indices after the removed element have shifted; stop comparing
                    # against ``first`` and let the next full pass pick up the rest.
                    break
            if not removed_a_span:
                return spans

    @staticmethod
    def construct_tree_from_spans(spans_to_labels: Dict[Tuple[int, int], str],
                                  sentence: List[str],
                                  pos_tags: List[str] = None) -> Tree:
        """
        Recursively nests the labelled spans of one sentence into an ``nltk.Tree``.

        Parameters
        ----------
        spans_to_labels : ``Dict[Tuple[int, int], str]``, required.
            Maps ``(start, exclusive_end)`` spans to constituency labels. A label
            such as ``S-VP`` stands for a unary chain and is expanded to ``(S (VP ...))``.
        sentence : ``List[str]``, required.
            The tokens of the sentence to be parsed.
        pos_tags : ``List[str]``, optional (default = None)
            The POS tags of the words in the sentence, if they were either
            predicted or taken as input to the model.

        Returns
        -------
        An ``nltk.Tree`` built from the labelled spans. Note that only the first
        subtree produced for the span ``(0, len(sentence))`` is returned.
        """
        def assemble_subtree(start: int, end: int):
            label_text = spans_to_labels.get((start, end))
            # Unary chains are stored as a single "-"-joined label; the outermost label comes first.
            labels = None if label_text is None else label_text.split("-")

            if end - start == 1:
                # A single word: build the pre-terminal node first.
                word = sentence[start]
                if pos_tags is not None:
                    # Explicit POS tags become the pre-terminal, with every label stacked on top.
                    leaf = Tree(pos_tags[start], [word])
                    wrappers = labels if labels is not None else []
                elif labels is not None:
                    # Without POS tags the innermost label takes the pre-terminal's place.
                    *wrappers, innermost = labels
                    leaf = Tree(innermost, [word])
                else:
                    leaf = Tree("XX", [word])
                    wrappers = []
                for label in reversed(wrappers):
                    leaf = Tree(label, [leaf])
                return [leaf]

            # Split at the longest labelled prefix of this span; if there is none,
            # split off the first word.
            split_point = next((candidate for candidate in range(end - 1, start, -1)
                                if (start, candidate) in spans_to_labels),
                               start + 1)

            children = assemble_subtree(start, split_point) + assemble_subtree(split_point, end)
            # An unlabelled span contributes no node of its own: its children are
            # handed to the caller as siblings.
            for label in reversed(labels or []):
                children = [Tree(label, children)]
            return children

        return assemble_subtree(0, len(sentence))[0]

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns the span tag accuracy under ``"tag_accuracy"``, merged with whatever
        the evalb scorer reports when one is configured. ``reset`` is forwarded to
        every metric.
        """
        all_metrics = {"tag_accuracy": self.tag_accuracy.get_metric(reset=reset)}
        if self._evalb_score is None:
            return all_metrics
        all_metrics.update(self._evalb_score.get_metric(reset=reset))
        return all_metrics

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'SpanConstituencyParser':
        """
        Builds the parser from a ``Params`` configuration. ``text_field_embedder``,
        ``span_extractor`` and ``encoder`` are required keys; ``feedforward``,
        ``pos_tag_embedding``, ``initializer``, ``regularizer`` and
        ``evalb_directory_path`` are optional. ``params.assert_empty`` is called once
        all of these keys have been popped.
        """
        text_field_embedder = TextFieldEmbedder.from_params(vocab, params.pop("text_field_embedder"))
        span_extractor = SpanExtractor.from_params(params.pop("span_extractor"))
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))

        # Optional components stay ``None`` when their configuration is absent.
        feedforward_layer = None
        feed_forward_params = params.pop("feedforward", None)
        if feed_forward_params is not None:
            feedforward_layer = FeedForward.from_params(feed_forward_params)

        pos_tag_embedding = None
        pos_tag_embedding_params = params.pop("pos_tag_embedding", None)
        if pos_tag_embedding_params is not None:
            pos_tag_embedding = Embedding.from_params(vocab, pos_tag_embedding_params)

        initializer_params = params.pop('initializer', [])
        initializer = InitializerApplicator.from_params(initializer_params)
        regularizer_params = params.pop('regularizer', [])
        regularizer = RegularizerApplicator.from_params(regularizer_params)
        evalb_directory_path = params.pop("evalb_directory_path", None)
        params.assert_empty(cls.__name__)

        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   span_extractor=span_extractor,
                   encoder=encoder,
                   feedforward_layer=feedforward_layer,
                   pos_tag_embedding=pos_tag_embedding,
                   initializer=initializer,
                   regularizer=regularizer,
                   evalb_directory_path=evalb_directory_path)
Exemple #35
0
class BasicClassifier(Model):
    """
    A basic text classification `Model`. The input text is embedded, optionally
    contextualised by a `Seq2SeqEncoder`, pooled into a single vector by a
    `Seq2VecEncoder`, optionally passed through dropout and a `FeedForward` layer, and
    finally projected into the label space by a linear classification layer. Without a
    `Seq2SeqEncoder` the embedded text goes straight to the `Seq2VecEncoder`.

    Registered as a `Model` with name "basic_classifier".

    # Parameters

    vocab : `Vocabulary`
    text_field_embedder : `TextFieldEmbedder`
        Used to embed the input text into a `TextField`
    seq2vec_encoder : `Seq2VecEncoder`
        Required Seq2Vec encoder layer. It pools the output of `seq2seq_encoder` when
        one is given, and the output of the `text_field_embedder` otherwise.
    seq2seq_encoder : `Seq2SeqEncoder`, optional (default=`None`)
        Optional Seq2Seq encoder layer for the input text.
    feedforward : `FeedForward`, optional, (default = None).
        An optional feedforward layer to apply after the seq2vec_encoder.
    dropout : `float`, optional (default = `None`)
        Dropout probability handed to `torch.nn.Dropout`. No dropout layer is created
        when this is `None` or `0`.
    num_labels : `int`, optional (default = `None`)
        Number of labels to project to in the classification layer. When not given, the
        size of the vocabulary namespace `label_namespace` is used.
    label_namespace : `str`, optional (default = "labels")
        Vocabulary namespace corresponding to labels.
    namespace : `str`, optional (default = "tokens")
        Vocabulary namespace used by `make_output_human_readable` to turn token ids
        back into token strings.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Applied to the model once all of its layers have been created.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        seq2vec_encoder: Seq2VecEncoder,
        seq2seq_encoder: Seq2SeqEncoder = None,
        feedforward: Optional[FeedForward] = None,
        dropout: float = None,
        num_labels: int = None,
        label_namespace: str = "labels",
        namespace: str = "tokens",
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:

        super().__init__(vocab, **kwargs)
        # NOTE: submodules are assigned in a fixed order on purpose; the order decides
        # how parameters are registered on the module.
        self._text_field_embedder = text_field_embedder
        self._seq2seq_encoder = seq2seq_encoder or None
        self._seq2vec_encoder = seq2vec_encoder
        self._feedforward = feedforward

        # The classifier consumes whichever layer runs last before it.
        last_layer = seq2vec_encoder if feedforward is None else feedforward
        self._classifier_input_dim = last_layer.get_output_dim()

        self._dropout = torch.nn.Dropout(dropout) if dropout else None
        self._label_namespace = label_namespace
        self._namespace = namespace

        self._num_labels = num_labels or vocab.get_vocab_size(namespace=self._label_namespace)
        self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels)
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)

    def forward(  # type: ignore
        self, tokens: TextFieldTensors, label: torch.IntTensor = None
    ) -> Dict[str, torch.Tensor]:

        """
        # Parameters

        tokens : TextFieldTensors
            From a `TextField`
        label : torch.IntTensor, optional (default = None)
            From a `LabelField`. When given, the loss is computed and the accuracy
            metric is updated.

        # Returns

        An output dictionary consisting of:

            - `logits` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                unnormalized log probabilities of the label.
            - `probs` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                probabilities of the label.
            - `token_ids` :
                The token ids extracted from `tokens`, used later by
                `make_output_human_readable`.
            - `loss` : (`torch.FloatTensor`, optional) :
                A scalar loss to be optimised.
        """
        encoded = self._text_field_embedder(tokens)
        token_mask = get_text_field_mask(tokens)

        if self._seq2seq_encoder:
            encoded = self._seq2seq_encoder(encoded, mask=token_mask)

        pooled = self._seq2vec_encoder(encoded, mask=token_mask)
        if self._dropout:
            pooled = self._dropout(pooled)
        if self._feedforward is not None:
            pooled = self._feedforward(pooled)

        logits = self._classification_layer(pooled)
        output_dict = {
            "logits": logits,
            "probs": torch.nn.functional.softmax(logits, dim=-1),
            "token_ids": util.get_token_ids_from_text_field_tensors(tokens),
        }
        if label is None:
            return output_dict

        output_dict["loss"] = self._loss(logits, label.long().view(-1))
        self._accuracy(logits, label)
        return output_dict

    @overrides
    def make_output_human_readable(
        self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Takes the argmax over the probabilities of every instance, maps the index to
        its string label (falling back to the stringified index when the label
        vocabulary has no entry for it) and stores the result under `"label"`.
        The token ids under `"token_ids"` are also converted back to strings and
        stored under `"tokens"`.
        """
        probs = output_dict["probs"]
        if probs.dim() == 2:
            per_instance_probs = [probs[row] for row in range(probs.shape[0])]
        else:
            # A single, unbatched distribution.
            per_instance_probs = [probs]

        def to_label(instance_probs) -> str:
            label_idx = instance_probs.argmax(dim=-1).item()
            index_to_label = self.vocab.get_index_to_token_vocabulary(self._label_namespace)
            return index_to_label.get(label_idx, str(label_idx))

        output_dict["label"] = [to_label(instance_probs) for instance_probs in per_instance_probs]
        output_dict["tokens"] = [
            [
                self.vocab.get_token_from_index(token_id.item(), namespace=self._namespace)
                for token_id in instance_token_ids
            ]
            for instance_token_ids in output_dict["token_ids"]
        ]
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Returns the classification accuracy under the `"accuracy"` key."""
        return {"accuracy": self._accuracy.get_metric(reset)}
class DecomposableAttention(Model):
    """
    An implementation of the Decomposable Attention model from `"A Decomposable
    Attention Model for Natural Language Inference"
    <https://www.semanticscholar.org/paper/A-Decomposable-Attention-Model-for-Natural-Languag-Parikh-T%C3%A4ckstr%C3%B6m/07a9478e87a8304fc3267fa16e83e9f3bbd98b27>`_
    by Parikh et al., 2016, with some optional enhancements that run before the decomposable
    attention itself.  The original model allowed an "intra-sentence" attention ahead of the
    entailment step; here that is generalized to any :class:`Seq2SeqEncoder` applied to the
    premise and/or the hypothesis.

    The model embeds each word of the premise and hypothesis, aligns words between the two,
    compares the aligned phrases, and makes a final entailment decision from the aggregated
    comparison.  Each of these steps uses a feedforward network to modify the representation.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the
        model.
    attend_feedforward : ``FeedForward``
        Applied to the encoded sentence representations before the similarity matrix between
        premise words and hypothesis words is computed.
    similarity_function : ``SimilarityFunction``
        The similarity function used to compute the similarity matrix between words in the
        premise and words in the hypothesis.
    compare_feedforward : ``FeedForward``
        Applied to the aligned premise and hypothesis representations, individually.
    aggregate_feedforward : ``FeedForward``
        Applied to the concatenated, summed result of the ``compare_feedforward`` network; its
        output is used as the entailment class logits.
    premise_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        An optional encoder applied to the embedded premise.  Nothing is done if this is
        ``None``.
    hypothesis_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        An optional encoder applied to the embedded hypothesis.  If this is ``None``, the
        ``premise_encoder`` is used instead (and nothing is done if that is ``None`` too).
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 attend_feedforward: FeedForward,
                 similarity_function: SimilarityFunction,
                 compare_feedforward: FeedForward,
                 aggregate_feedforward: FeedForward,
                 premise_encoder: Optional[Seq2SeqEncoder] = None,
                 hypothesis_encoder: Optional[Seq2SeqEncoder] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        # NOTE: submodules are assigned in a fixed order on purpose; the order decides
        # how parameters are registered on the module.
        self._text_field_embedder = text_field_embedder
        self._attend_feedforward = TimeDistributed(attend_feedforward)
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._compare_feedforward = TimeDistributed(compare_feedforward)
        self._aggregate_feedforward = aggregate_feedforward
        self._premise_encoder = premise_encoder
        # The hypothesis shares the premise encoder unless it has one of its own.
        self._hypothesis_encoder = hypothesis_encoder or premise_encoder

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        # Fail early on configurations whose layer sizes cannot fit together.
        embedding_dim = text_field_embedder.get_output_dim()
        check_dimensions_match(embedding_dim, attend_feedforward.get_input_dim(),
                               "text field embedding dim", "attend feedforward input dim")
        output_dim = aggregate_feedforward.get_output_dim()
        check_dimensions_match(output_dim, self._num_labels,
                               "final output dimension", "number of labels")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        h2p_attention : torch.FloatTensor
            The hypothesis-to-premise attention weights.
        p2h_attention : torch.FloatTensor
            The premise-to-hypothesis attention weights.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised. Only present when ``label`` is given.
        premise_tokens, hypothesis_tokens : optional
            The tokens copied from ``metadata``. Only present when ``metadata`` is given.
        """
        premise_repr = self._text_field_embedder(premise)
        hypothesis_repr = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            premise_repr = self._premise_encoder(premise_repr, premise_mask)
        if self._hypothesis_encoder:
            hypothesis_repr = self._hypothesis_encoder(hypothesis_repr, hypothesis_mask)

        # Attend. Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(self._attend_feedforward(premise_repr),
                                                   self._attend_feedforward(hypothesis_repr))

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(hypothesis_repr, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(premise_repr, h2p_attention)

        def compare_and_pool(words, aligned, word_mask):
            # Compare every word with its aligned phrase, zero out the padding
            # positions and sum over the sequence. Shape: (batch_size, compare_dim)
            compared = self._compare_feedforward(torch.cat([words, aligned], dim=-1))
            return (compared * word_mask.unsqueeze(-1)).sum(dim=1)

        compared_premise = compare_and_pool(premise_repr, attended_hypothesis, premise_mask)
        compared_hypothesis = compare_and_pool(hypothesis_repr, attended_premise, hypothesis_mask)

        # Aggregate both summaries into the entailment decision.
        label_logits = self._aggregate_feedforward(
                torch.cat([compared_premise, compared_hypothesis], dim=-1))

        output_dict = {"label_logits": label_logits,
                       "label_probs": torch.nn.functional.softmax(label_logits, dim=-1),
                       "h2p_attention": h2p_attention,
                       "p2h_attention": p2h_attention}

        if label is not None:
            output_dict["loss"] = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)

        if metadata is not None:
            for key in ("premise_tokens", "hypothesis_tokens"):
                output_dict[key] = [instance_metadata[key] for instance_metadata in metadata]

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Returns the entailment accuracy under the ``'accuracy'`` key."""
        accuracy = self._accuracy.get_metric(reset)
        return {'accuracy': accuracy}
class QuarelSemanticParser(Model):
    """
    A ``QuarelSemanticParser`` is a variant of ``WikiTablesSemanticParser`` with various
    tweaks and changes.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=0)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The ``KnowledgeGraphField`` default is to use all eight defined features,
        but the default here is 0, in which case no linking feature parameters are created and
        only the question-entity similarity contributes to the linking score.
    use_entities : ``bool``, optional (default=False)
        Whether dynamic entities are part of the action space
    num_entity_bits : ``int``, optional (default=0)
        Whether any bits are added to encoder input/output to represent tagged entities
    entity_bits_output : ``bool``, optional (default=False)
        Whether entity bits are added to the encoder output or input
    denotation_only : ``bool``, optional (default=False)
        Whether to only predict the target denotation, skipping the whole logical form decoder
    entity_similarity_mode : ``str``, optional (default="dot_product")
        How to compute vector similarity between question and entity tokens, can take values
        "dot_product" or "weighted_dot_product" (learned weights on each dimension)
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 dropout: float = 0.0,
                 num_linking_features: int = 0,
                 num_entity_bits: int = 0,
                 entity_bits_output: bool = True,
                 use_entities: bool = False,
                 denotation_only: bool = False,
                 # Deprecated parameter to load older models
                 entity_encoder: Seq2VecEncoder = None,  # pylint: disable=unused-argument
                 entity_similarity_mode: str = "dot_product",
                 rule_namespace: str = 'rule_labels') -> None:
        """
        Store the question embedder, encoder and beam search, create the metrics, and (unless
        ``denotation_only`` is set) build the action embeddings and the transition function
        used for decoding.  ``entity_encoder`` is accepted but never read.

        Raises
        ------
        ValueError
            If ``entity_similarity_mode`` is neither ``"dot_product"`` nor
            ``"weighted_dot_product"``.
        """
        super().__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        # With no dropout requested we fall back to an identity function so that callers can
        # apply ``self._dropout`` unconditionally.
        self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
        self._rule_namespace = rule_namespace

        # Running averages reported by ``get_metrics``.
        self._denotation_accuracy = Average()
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._embedding_dim = question_embedder.get_output_dim()
        self._use_entities = use_entities

        # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
        # entity_type stuff is irrelevant
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 1 # Hardcoded until we feed lf syntax into the model
        self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)

        self._entity_similarity_mode = entity_similarity_mode
        if entity_similarity_mode == "weighted_dot_product":
            weighted_layer = TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False))
            # Shift the freshly initialised weights by one so the layer starts out close to a
            # plain (unweighted) dot product.
            weighted_layer._module.weight.data += 1  # pylint: disable=protected-access
            self._entity_similarity_layer = weighted_layer
        elif entity_similarity_mode == "dot_product":
            # The plain dot product needs no learned parameters.
            self._entity_similarity_layer = None
        else:
            raise ValueError("Invalid entity_similarity_mode: {}".format(self._entity_similarity_mode))

        # A learned projection of the linking features exists only when features are provided.
        self._linking_params = (torch.nn.Linear(num_linking_features, 1)
                                if num_linking_features > 0 else None)

        self._decoder_trainer = MaximumMarginalLikelihood()

        # Entity bits widen the encoder output only when they are appended to the output
        # rather than to the input.
        extra_output_dims = num_entity_bits if entity_bits_output else 0
        self._encoder_output_dim = self._encoder.get_output_dim() + extra_output_dims
        self._entity_bits_output = entity_bits_output

        self._debug_count = 10

        self._num_denotation_cats = 2  # Hardcoded for simplicity
        self._denotation_only = denotation_only
        if denotation_only:
            self._denotation_accuracy_cat = CategoricalAccuracy()
            self._denotation_classifier = torch.nn.Linear(self._encoder_output_dim,
                                                          self._num_denotation_cats)
            # Nothing below is needed when we classify the denotation directly and never
            # decode an action sequence.
            return

        self._action_padding_index = -1  # the padding value used by IndexField
        action_vocab_size = vocab.get_vocab_size(self._rule_namespace)
        self._num_actions = action_vocab_size
        self._action_embedder = Embedding(num_embeddings=action_vocab_size,
                                          embedding_dim=action_embedding_dim)
        # Input and output action embeddings share the same weights.
        self._output_action_embedder = self._action_embedder
        self._add_action_bias = add_action_bias
        if add_action_bias:
            self._action_biases = Embedding(num_embeddings=action_vocab_size, embedding_dim=1)

        # Learned stand-ins fed to the decoder on its first step, when there is neither a
        # previous action nor a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(self._encoder_output_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        self._decoder_step = LinkingTransitionFunction(
                encoder_output_dim=self._encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=attention,
                num_start_types=self._num_start_types,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                mixture_feedforward=mixture_feedforward,
                dropout=dropout)

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[QuarelWorld],
                actions: List[List[ProductionRule]],
                entity_bits: torch.Tensor = None,
                denotation_target: torch.Tensor = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``,
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        entity_bits : ``torch.Tensor``, optional (default=None)
            If given, concatenated along dimension 2 to the encoder output (when the model was
            built with ``entity_bits_output=True``) or to the encoder input (otherwise, and only
            when entities are not used).
        denotation_target : ``torch.Tensor``, optional (default=None)
            Target categories for the denotation classifier.  Only read in denotation-only
            mode, where it is required.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        metadata : ``List[Dict[str, Any]]``, optional (default=None)
            Per-instance dictionaries.  The keys ``question_tokens``, ``world_extractions`` and
            ``answer_index`` are read when present.

        Returns
        -------
        In denotation-only mode, ``{"loss": loss}``.  During training, whatever the decoder
        trainer returns.  Otherwise a dictionary with the decoded action sequences, logical
        forms, scores and per-instance accuracies.
        """

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: tensor of entity type ids with shape (batch_size, num_entities)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity. Need to add a small value
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                           num_entities * num_entity_tokens,
                                                                           self._embedding_dim),
                                                       torch.transpose(embedded_question, 1, 2))

                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                # Same normalisation as above, but the element-wise product is passed through a
                # learned linear layer instead of being summed.
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(batch_size,
                                       num_question_tokens*num_entities*num_entity_tokens,
                                       self._embedding_dim)
                question_entity_similarity = self._entity_similarity_layer(product)
                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table['linking']

            if self._linking_params is not None:
                feature_scores = self._linking_params(linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                    question_mask, entity_type_dict)
            encoder_input = embedded_question
        else:
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input = torch.cat([embedded_question, entity_bits], 2)
            else:
                encoder_input = embedded_question

            # Fake linking_scores added for downstream code to not object
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        # For predicting a categorical denotation directly
        if self._denotation_only:
            # NOTE(review): ``denotation_target`` is dereferenced unconditionally here, so this
            # mode cannot be run without a target.
            denotation_logits = self._denotation_classifier(final_encoder_output)
            loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim)

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))

        initial_grammar_state = [self._create_grammar_state(world[i], actions[i],
                                                            linking_scores[i], entity_types[i])
                                 for i in range(batch_size)]

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          extras=None,
                                          debug_info=None)

        if self.training:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
            return outputs

        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            # ``feature_scores`` and ``linking_features`` are only computed on the entity-linking
            # path, so they can only be reported when entities are in use.
            if self._use_entities and self._linking_params is not None:
                outputs['linking_scores'] = linking_scores
                outputs['feature_scores'] = feature_scores
                outputs['linking_features'] = linking_features
            if self._use_entities:
                outputs['linking_probabilities'] = linking_probabilities
            if entity_bits is not None:
                outputs['entity_bits'] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            outputs['denotation_acc'] = []
            outputs['score'] = []
            outputs['parse_acc'] = []
            outputs['answer_index'] = []
            if metadata is not None:
                outputs['question_tokens'] = []
                outputs['world_extractions'] = []
            for i in range(batch_size):
                if metadata is not None:
                    outputs['question_tokens'].append(metadata[i].get('question_tokens', []))
                    outputs['world_extractions'].append(metadata[i].get('world_extractions', {}))
                outputs['entities'].append(world[i].table_graph.entities)
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    # Record exactly one value per instance: success only once the logical form
                    # has actually been produced, failure if parsing raised.
                    try:
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    else:
                        self._has_logical_form(1.0)
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and 'answer_index' in metadata[i]:
                        answer_index = metadata[i]['answer_index']
                        denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index)
                        self._denotation_accuracy(denotation_accuracy)
                    score = math.exp(best_final_states[i][0].score[0].data.cpu().item())
                    outputs['answer_index'].append(predicted_answer_index)
                    outputs['score'].append(score)
                    outputs['parse_acc'].append(sequence_in_targets)
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['denotation_acc'].append(denotation_accuracy)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    # No finished state for this instance: emit placeholders so every output
                    # list stays aligned with the batch.
                    outputs['parse_acc'].append(0)
                    outputs['logical_form'].append('')
                    outputs['denotation_acc'].append(0)
                    outputs['score'].append(0)
                    outputs['answer_index'].append(-1)
                    outputs['best_action_sequence'].append([])
                    outputs['debug_info'].append([])
                    self._has_logical_form(0.0)
            return outputs

    @staticmethod
    def _get_type_vector(worlds: List[QuarelWorld],
                         num_entities: int,
                         tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Assigns an integer type id to every entity of every world in the batch, and returns the
        result both as a tensor and as a dictionary keyed by flattened entity index.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
            One world per batch instance; the entities are read from ``world.table_graph.entities``.
        num_entities : ``int``
            Length each instance's type list is padded to (with type 0), and the stride used
            when flattening ``(batch_index, entity_index)`` into a single key.
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` built from the padded per-instance type lists (presumably of shape
        ``(batch_size, num_entities)``).
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        # Checked in this order; anything matching none of the prefixes gets type 0.  The
        # numbering matters because entities are sorted, and a later split-by-type and merge
        # relies on that ordering.
        prefix_types = (('fb:cell', 1), ('fb:part', 2), ('fb:row', 3))

        flat_index_to_type = {}
        padded_rows = []
        for batch_index, world in enumerate(worlds):
            # Keys are flattened because that is how the linking scores are laid out.
            offset = batch_index * num_entities
            instance_types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                entity_type = next((type_id for prefix, type_id in prefix_types
                                    if entity.startswith(prefix)),
                                   0)
                instance_types.append(entity_type)
                flat_index_to_type[offset + entity_index] = entity_type
            padded_rows.append(pad_sequence_to_length(instance_types, num_entities, lambda: 0))
        return tensor.new_tensor(padded_rows, dtype=torch.long), flat_index_to_type

    def _get_linking_probabilities(self,
                                   worlds: List[QuarelWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[QuarelWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(num_question_tokens,
                                                 num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        """
        Checks whether ``predicted`` equals the leading portion of any row of ``targets``.

        Parameters
        ----------
        predicted : ``List[int]``
            The predicted sequence of action indices.
        targets : ``torch.LongTensor``
            Indexed here as ``(num_targets, target_length)``: one candidate sequence per row.

        Returns
        -------
        ``0`` if the prediction is longer than the target rows.  Otherwise a truthy value when
        at least one row of ``targets``, cut down to ``len(predicted)`` columns, is element-wise
        equal to ``predicted``, and a falsy value when no row is.
        """
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        prediction_length = len(predicted)
        # The targets must be long enough to cover the prediction (including start/end symbols).
        if targets.size(1) < prediction_length:
            return 0
        # Only the first `prediction_length` positions of each target take part in the comparison.
        target_prefixes = targets[:, :prediction_length]
        position_matches = target_prefixes.eq(targets.new_tensor(predicted))
        # A row matches only when every one of its positions matches (minimum over the row) ...
        row_matches, _ = torch.min(position_matches, dim=1)
        # ... and the prediction counts as correct when any row matches (maximum over rows).
        return torch.max(row_matches).item()

    def _denotation_match(self, predicted_answer_index: int, target_answer_index: int) -> float:
        """
        Scores a predicted answer index against the target answer index.

        A negative ``predicted_answer_index`` is treated as a logical form that did not resolve
        to an answer; it earns the credit of a random guess, ``1 / self._num_denotation_cats``.
        Any other prediction scores ``1.0`` when it equals ``target_answer_index`` and ``0.0``
        when it does not.
        """
        if predicted_answer_index >= 0:
            return 1.0 if predicted_answer_index == target_answer_index else 0.0
        # Logical form doesn't properly resolve, so we fall back to random-guess credit.
        return 1.0/self._num_denotation_cats

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Reports the metrics tracked by this model.  ``reset`` is forwarded unchanged to every
        underlying ``get_metric`` call.

        When ``self._denotation_only`` is set, the only entry is ``denotation_acc``, read from
        ``self._denotation_accuracy_cat``.  Otherwise three entries are returned:

            1. parse_acc, which is the percentage of the time that our best output action sequence
            corresponds to a correct logical form

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation, including spurious correct answers using the wrong logical form

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        if self._denotation_only:
            return {'denotation_acc': self._denotation_accuracy_cat.get_metric(reset)}

        parse_accuracy = self._action_sequence_accuracy.get_metric(reset)
        denotation_accuracy = self._denotation_accuracy.get_metric(reset)
        logical_form_rate = self._has_logical_form.get_metric(reset)
        return {
                'parse_acc': parse_accuracy,
                'denotation_acc': denotation_accuracy,
                'lf_percent': logical_form_rate,
        }

    def _create_grammar_state(self,
                              world: QuarelWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        Builds the ``GrammarStatelet`` that decoding starts from.  Most of the work is assembling
        the ``valid_actions`` dictionary passed to it, which holds, for every non-terminal, the
        representations of the productions that can expand it.

        Everything here concerns a `single instance in the batch`; none of the tensors created
        in this method are batched.  Global productions are represented by embedding the global
        action ids carried by the input ``ProductionRules``; non-global (linked) productions are
        represented by the matching rows of ``linking_scores``.

        Parameters
        ----------
        world : ``QuarelWorld``
            From the input to ``forward`` for a single batch instance.  Supplies the valid
            actions per non-terminal and the entities of its table graph.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).

        Returns
        -------
        A ``GrammarStatelet`` whose non-terminal stack holds only ``START_SYMBOL``.  Each entry of
        its valid-action dictionary has a ``'global'`` triple of (input embeddings, output
        embeddings, action indices) and, when the non-terminal has linked productions, a
        ``'linked'`` triple of (linking scores, entity type embeddings, action indices).
        """
        # Production rule string -> position of that rule in `possible_actions`.
        rule_to_index = {rule[0]: index for index, rule in enumerate(possible_actions)}
        # Entity -> position of that entity in the world's table graph.
        entity_to_index = {entity: index
                           for index, entity in enumerate(world.table_graph.entities)}

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for nonterminal, productions in world.get_valid_actions().items():
            # `nonterminal` comes from the grammar and `productions` are all of the rule strings
            # that may expand it.
            translated = {}
            translated_valid_actions[nonterminal] = translated

            indexed_rules = []
            for production in productions:
                action_index = rule_to_index[production]
                indexed_rules.append((possible_actions[action_index], action_index))

            # Split the rules by kind: global rules carry an id tensor in field 2, while linked
            # rules are identified by their rule string in field 0.
            global_rules = []
            linked_rules = []
            for rule, action_index in indexed_rules:
                if rule[1]:
                    global_rules.append((rule[2], action_index))
                else:
                    linked_rules.append((rule[0], action_index))

            # Embedded representations of the global actions.
            global_id_tensors, global_indices = zip(*global_rules)
            global_ids = torch.cat(global_id_tensors, dim=0)
            input_embeddings = self._action_embedder(global_ids)
            if self._add_action_bias:
                # The bias is appended to the input embedding along the last dimension.
                action_biases = self._action_biases(global_ids)
                input_embeddings = torch.cat([input_embeddings, action_biases], dim=-1)
            output_embeddings = self._output_action_embedder(global_ids)
            translated['global'] = (input_embeddings, output_embeddings, list(global_indices))

            # Representations of the linked actions, if this non-terminal has any.
            if linked_rules:
                rule_strings, linked_indices = zip(*linked_rules)
                # The linked entity is the right-hand side of the production string.
                linked_entities = [rule_string.split(' -> ')[1] for rule_string in rule_strings]
                entity_ids = [entity_to_index[entity] for entity in linked_entities]
                # (num_linked_actions, num_question_tokens)
                linked_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                linked_types = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                linked_type_embeddings = self._entity_type_decoder_embedding(linked_types)
                translated['linked'] = (linked_scores, linked_type_embeddings, list(linked_indices))

        return GrammarStatelet([START_SYMBOL],
                               translated_valid_actions,
                               type_declaration.is_nonterminal)

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder".

        For every step of every instance's best action sequence, this collects the actions that
        were considered at that step, translates them through ``output_dict['action_mapping']``
        (keyed by ``(batch_index, action)``), pairs them with their probabilities, and sorts the
        pairs.  The result is stored under a new ``predicted_actions`` key.

        Parameters
        ----------
        output_dict : ``Dict[str, torch.Tensor]``
            Must contain ``action_mapping``, ``best_action_sequence`` and ``debug_info``.  Each
            per-step debug entry must contain ``considered_actions`` and ``probabilities``, and
            may contain ``question_attention``.

        Returns
        -------
        The same ``output_dict``, with ``predicted_actions`` added: one list per instance, holding
        one dictionary per step with the keys ``predicted_action``, ``considered_actions``,
        ``action_probabilities`` and ``question_attention`` (``[]`` when the debug entry has none).
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                # Entries equal to -1 are skipped (presumably padding -- they have no mapping).
                actions = sorted(
                        (action_mapping[(batch_index, action)], probability)
                        for action, probability in zip(action_debug_info['considered_actions'],
                                                       action_debug_info['probabilities'])
                        if action != -1)
                if actions:
                    considered_actions, probabilities = zip(*actions)
                else:
                    # `zip(*[])` yields nothing to unpack, so a step without any real action
                    # is reported with empty tuples instead of raising a ValueError.
                    considered_actions, probabilities = (), ()
                instance_action_info.append({
                        'predicted_action': predicted_action,
                        'considered_actions': considered_actions,
                        'action_probabilities': probabilities,
                        'question_attention': action_debug_info.get('question_attention', []),
                })
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
Exemple #38
0
class AclClassifier(Model):
    """
    Classifies a paper into one of the classes of the ``labels`` vocabulary namespace, using an
    embedding of its abstract, an embedding of its ``paper_id``, or the concatenation of both.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Provides the ``labels`` namespace that defines the output classes.
    text_field_embedder : ``TextFieldEmbedder``
        Embeds the ``abstract`` field; only the vector at sequence position 0 is used.
    node_embedder : ``TokenEmbedder``
        Embeds ``paper_id``.
    verbose_metrics : ``bool``
        If true, ``get_metrics`` also reports precision, recall and F1 for every label.
    classifier_feedforward : ``FeedForward``
        Maps the embedding (after dropout) to the class logits.
    use_node_vector : ``bool``, optional (default=True)
        Whether to use the ``paper_id`` embedding alongside the abstract.
    use_abstract : ``bool``, optional (default=True)
        Whether to use the abstract embedding.  When this is false the ``paper_id`` embedding is
        always used, whatever the value of ``use_node_vector``.
    dropout : ``float``, optional (default=0.2)
        Dropout probability applied to the embedding before classification.
    initializer : ``InitializerApplicator``, optional
        Called on the model at the end of construction.
    regularizer : ``RegularizerApplicator``, optional (default=None)
        Passed through to the ``Model`` constructor.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 node_embedder: TokenEmbedder,
                 verbose_metrics: bool,
                 classifier_feedforward: FeedForward,
                 use_node_vector: bool = True,
                 use_abstract: bool = True,
                 dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(AclClassifier, self).__init__(vocab, regularizer)

        self.node_embedder = node_embedder
        self.text_field_embedder = text_field_embedder
        self.use_node_vector = use_node_vector
        self.use_abstract = use_abstract
        self.dropout = torch.nn.Dropout(dropout)
        self.num_classes = self.vocab.get_vocab_size("labels")

        self.classifier_feedforward = classifier_feedforward

        self.label_accuracy = CategoricalAccuracy()
        self.verbose_metrics = verbose_metrics

        # One F1 measure per label, keyed by the label's string form.
        self.label_f1_metrics = {}
        for i in range(self.num_classes):
            label_name = vocab.get_token_from_index(index=i, namespace="labels")
            self.label_f1_metrics[label_name] = F1Measure(positive_label=i)

        self.loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    @overrides
    def forward(self,
                abstract: Dict[str, torch.LongTensor],
                paper_id: torch.LongTensor,
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        Computes the class logits and, when ``label`` is given, the loss and the metrics.

        Parameters
        ----------
        abstract : ``Dict[str, torch.LongTensor]``
            Input to ``text_field_embedder``.
        paper_id : ``torch.LongTensor``
            Input to ``node_embedder``.
        label : ``torch.LongTensor``, optional (default=None)
            Gold class indices.  Without it no loss is computed and no metric is updated.

        Returns
        -------
        A dictionary with ``logits`` and, when ``label`` is given, ``label`` and ``loss``.
        """
        if self.use_abstract:
            # Only the vector at the first sequence position represents the abstract.
            embedding = self.text_field_embedder(abstract)[:, 0, :]
            if self.use_node_vector:
                embedding = torch.cat([embedding, self.node_embedder(paper_id)], dim=-1)
        else:
            # With the abstract switched off the node vector is the only input left, so it is
            # used even when ``use_node_vector`` is false.
            embedding = self.node_embedder(paper_id)

        logits = self.classifier_feedforward(self.dropout(embedding))
        output_dict = {"logits": logits}

        if label is not None:
            output_dict["label"] = label
            output_dict["loss"] = self.loss(logits, label)
            # Metrics need gold labels, so they are only updated when labels are available;
            # updating them unconditionally breaks label-free prediction.
            class_probs = F.softmax(logits, dim=1)
            for metric in self.label_f1_metrics.values():
                metric(class_probs, label)
            self.label_accuracy(logits, label)

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Adds ``pred_label``, the label string of the highest-probability class of every instance,
        and, when gold labels are present, replaces ``label`` by the matching label strings.
        """
        class_probs = F.softmax(output_dict['logits'], dim=-1)
        output_dict['pred_label'] = [
            self.vocab.get_token_from_index(index=int(np.argmax(probs)), namespace="labels")
            for probs in class_probs.cpu()
        ]
        # ``forward`` only stores ``label`` when gold labels were supplied.
        if 'label' in output_dict:
            output_dict['label'] = [
                self.vocab.get_token_from_index(index=int(label), namespace="labels")
                for label in output_dict['label'].cpu()
            ]
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns ``average_F1`` (the unweighted mean of the per-label F1 scores, ``0.0`` when
        there are no labels) and ``accuracy``.  With ``verbose_metrics`` set, each label also
        contributes ``<label>_P``, ``<label>_R`` and ``<label>_F1``.
        """
        metric_dict = {}

        sum_f1 = 0.0
        for name, metric in self.label_f1_metrics.items():
            # F1Measure yields (precision, recall, f1).
            metric_val = metric.get_metric(reset)
            if self.verbose_metrics:
                metric_dict[name + '_P'] = metric_val[0]
                metric_dict[name + '_R'] = metric_val[1]
                metric_dict[name + '_F1'] = metric_val[2]
            sum_f1 += metric_val[2]

        total_len = len(self.label_f1_metrics)
        metric_dict['average_F1'] = sum_f1 / total_len if total_len > 0 else 0.0
        metric_dict['accuracy'] = self.label_accuracy.get_metric(reset)
        return metric_dict
class DialogQA(Model):
    """
    This class implements modified version of BiDAF
    (with self attention and residual layer, from Clark and Gardner ACL 17 paper) model as used in
    Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].
    In this set-up, a single instance is a dialog, list of question answer pairs.
    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    max_span_length: ``int``, optional (default=0)
        Maximum token length of the output span.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 25) -> None:
        """
        Stores the supplied encoders, creates the attention, merge, marker-embedding and
        prediction layers, checks that the phrase layer's input size fits the embeddings, runs
        ``initializer`` on the model and finally creates the metrics.
        """
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()
        # Number of rows in the question-turn embedding table created below.
        # NOTE(review): ``forward`` looks up turn indices 0..max_qa_count-1 in that table, so a
        # dialog with more than 15 questions would index past its end -- confirm the data never
        # exceeds this.
        max_turn_length = 15

        # Similarity between passage and question encodings ('x,y,x*y' combination).
        self._matrix_attention = LinearMatrixAttention(self._encoding_dim,
                                                       self._encoding_dim,
                                                       'x,y,x*y')
        # Projects the four concatenated attention views of the passage back to encoding_dim.
        self._merge_atten = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        # Marker embeddings exist only when previous answers are fed to the model.
        if num_context_answers > 0:
            # Embeds the index of the question turn within the dialog.
            self._question_num_marker = torch.nn.Embedding(
                max_turn_length, marker_embedding_dim * num_context_answers)
            # Embeds the per-token previous-answer marker.
            # Presumably four marker tags per previous answer plus one "outside" tag --
            # TODO confirm against the dataset reader.
            self._prev_ans_marker = torch.nn.Embedding(
                (num_context_answers * 4) + 1, marker_embedding_dim)

        # Passage-to-passage similarity used for the self-attention layer.
        self._self_attention = LinearMatrixAttention(self._encoding_dim,
                                                     self._encoding_dim,
                                                     'x,y,x*y')

        # Three-way follow-up classifier; wrapped in TimeDistributed further down.
        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        # Projects the three concatenated self-attention views back to encoding_dim.
        self._merge_self_attention = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        # Per-token heads: one logit each for span start and span end, three each for the
        # yes/no and follow-up classifications.
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        # The phrase layer must accept the token embedding plus the marker dimensions.
        check_dimensions_match(
            phrase_layer.get_input_dim(),
            text_field_embedder.get_output_dim() +
            marker_embedding_dim * num_context_answers,
            "phrase layer input dim",
            "embedding dim + marker dim * num context answers")

        initializer(self)

        # Metrics are created after the initializer has run.
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage token will have assigned 'O', except the passage tokens belongs to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.
        Returns
        -------
        An output dictionary consisting of the followings.
        Each of the followings is a nested list because first iterates over dialog, then questions in dialog.
        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y :yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y :yes, x: not a yes/no question, n: np)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question[
            'token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question,
                                                      num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
            self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(
            self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question,
                                                 num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(
                max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(
                question_num_ind)
            embedded_question = torch.cat(
                [embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count,
                                                     passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat(
                [repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(
                    total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat(
                    [repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(
                        total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(
                        p3_answer_marker)
                    repeated_embedded_passage = torch.cat(
                        [repeated_embedded_passage, p3_answer_marker_emb],
                        dim=-1)

            repeated_encoded_passage = self._variational_dropout(
                self._phrase_layer(repeated_embedded_passage,
                                   repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(
                self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(
                total_qa_count, passage_length, self._encoding_dim)

        encoded_question = self._variational_dropout(
            self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(
            repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(total_qa_count, passage_length, self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            repeated_encoded_passage, passage_question_vectors,
            repeated_encoded_passage * passage_question_vectors,
            repeated_encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(
            self._residual_encoder(final_merged_passage,
                                   repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer,
                                                     residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs,
                                           residual_layer)
        self_attention_vecs = torch.cat([
            self_attention_vecs, residual_layer,
            residual_layer * self_attention_vecs
        ],
                                        dim=-1)
        residual_layer = F.relu(
            self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage,
                                             repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(
            torch.cat([final_merged_passage, start_rep], dim=-1),
            repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(
            -1)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits,
                                                       span_end_logits,
                                                       span_yesno_logits,
                                                       span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
                                                    repeated_passage_mask),
                            span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask),
                             span_end.view(-1),
                             ignore_index=-1)
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end],
                                            -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().cpu().detach(
            ).numpy().reshape(total_qa_count)
            # print("span_end = {}, type(span_end)={} total_qa_count = {}".format(span_end, type(span_end), total_qa_count))
            print("span_end.shape = {}".format(span_end.shape))
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
                # print("i = {}, gold_span_end_loc = {}".format(i, gold_span_end_loc))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1),
                             followup_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)
            self._span_followup_accuracy(_followup,
                                         followup_list.view(-1),
                                         mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad_eval.metric_max_over_ground_truths(
                                    squad_eval.f1_score, best_span_string,
                                    refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score, best_span_string,
                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Turn the predicted ``yesno`` and ``followup`` label indices into their string tokens.

        Both entries are popped from ``output_dict`` and iterated as a list of lists of
        indices.  Every index is looked up with ``self.vocab.get_token_from_index`` in the
        ``yesno_labels`` / ``followup_labels`` namespace respectively, and the nested lists of
        tokens are stored back under the same two keys.

        Parameters
        ----------
        output_dict : ``Dict[str, torch.Tensor]``
            Must contain the keys ``"yesno"`` and ``"followup"``.  It is modified in place.

        Returns
        -------
        The same ``output_dict`` object, with the two entries replaced by token lists.
        """
        def tokens_for(key, namespace):
            # Pop before decoding so the raw indices are gone from the dictionary
            # by the time the decoded tokens are written back.
            decoded = []
            for index_list in output_dict.pop(key):
                decoded.append([
                    self.vocab.get_token_from_index(index, namespace=namespace)
                    for index in index_list
                ])
            return decoded

        yesno_tags = tokens_for("yesno", "yesno_labels")
        followup_tags = tokens_for("followup", "followup_labels")

        output_dict['yesno'] = yesno_tags
        output_dict['followup'] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Report the current value of every metric this model tracks.

        Parameters
        ----------
        reset : ``bool``, optional (default = False)
            Forwarded unchanged to each metric's ``get_metric`` call.

        Returns
        -------
        A dictionary with the keys ``start_acc``, ``end_acc``, ``span_acc``, ``yesno``,
        ``followup`` and ``f1``, each holding whatever the corresponding metric returned.
        """
        # (reported name, attribute holding the metric), queried in this order.
        reported_metrics = (
            ('start_acc', '_span_start_accuracy'),
            ('end_acc', '_span_end_accuracy'),
            ('span_acc', '_span_accuracy'),
            ('yesno', '_span_yesno_accuracy'),
            ('followup', '_span_followup_accuracy'),
            ('f1', '_official_f1'),
        )
        return {
            label: getattr(self, attribute).get_metric(reset)
            for label, attribute in reported_metrics
        }

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      span_followup_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        # Returns the index of highest-scoring span that is not longer than 30 tokens, as well as
        # yesno prediction bit and followup prediction bit from the predicted span end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()
        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
class Baseline(Model):
    """
    Classifier over four text fields (article, outcome, intervention, comparator).

    Every field is embedded with the same ``word_embeddings`` and encoded into a single
    vector by one shared bidirectional LSTM; the four vectors are concatenated and passed
    through a linear layer that produces one logit per entry of the ``labels`` vocabulary
    namespace.  Accuracy and the F1 of label indices 0, 1 and 2 are tracked.
    """
    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.word_embeddings = word_embeddings

        # Half-size hidden state per direction for the shared bidirectional encoder.
        embedding_dim = word_embeddings.get_output_dim()
        shared_lstm = LSTM(embedding_dim,
                           int(embedding_dim / 2),
                           batch_first=True,
                           bidirectional=True)
        self.text_seq_encoder = PytorchSeq2VecWrapper(shared_lstm)

        # Four encoded fields are concatenated before classification.
        self.out = torch.nn.Linear(
            in_features=self.text_seq_encoder.get_output_dim() * 4,
            out_features=vocab.get_vocab_size('labels')
        )

        self.accuracy = CategoricalAccuracy()
        self.f_score_0 = F1Measure(positive_label=0)
        self.f_score_1 = F1Measure(positive_label=1)
        self.f_score_2 = F1Measure(positive_label=2)
        self.loss = CrossEntropyLoss()

    def _encode_field(self, text_field: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Mask, embed and encode one text field with the shared encoder."""
        field_mask = get_text_field_mask(text_field)
        field_embeddings = self.word_embeddings(text_field)
        return self.text_seq_encoder(field_embeddings, field_mask)

    def forward(self,
                article: Dict[str, torch.Tensor],
                outcome: Dict[str, torch.Tensor],
                intervention: Dict[str, torch.Tensor],
                comparator: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Compute label logits and, when ``labels`` is given, the loss.

        Returns a dictionary with ``logits`` and, only if ``labels`` is not ``None``,
        ``loss``.  Supplying ``labels`` also updates the accuracy and F1 metrics.
        """
        # Fields are encoded in a fixed order: article, outcome, intervention, comparator.
        encoded_fields = [
            self._encode_field(text_field)
            for text_field in (article, outcome, intervention, comparator)
        ]
        logits = self.out(torch.cat(encoded_fields, dim=1))

        output = {'logits': logits}
        if labels is None:
            return output

        self.accuracy(logits, labels)
        for f_score in (self.f_score_0, self.f_score_1, self.f_score_2):
            f_score(logits, labels)
        output['loss'] = self.loss(logits, labels)
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return ``accuracy`` and ``f-score``, the latter being the unweighted mean of the
        F1 values reported for label indices 0, 1 and 2.
        """
        f_scores = []
        for f_score_metric in (self.f_score_0, self.f_score_1, self.f_score_2):
            # Each metric yields a triple; only its last element (the F1) is kept.
            _, _, f_score = f_score_metric.get_metric(reset)
            f_scores.append(f_score)
        return {'accuracy': self.accuracy.get_metric(reset),
                'f-score': np.mean(f_scores)}
Exemple #41
0
class ESIM(Model):
    """
    An implementation of the ESIM sequence model from `"Enhanced LSTM for Natural Language Inference"
    <https://www.semanticscholar.org/paper/Enhanced-LSTM-for-Natural-Language-Inference-Chen-Zhu/83e7654d545fbbaaf2328df365a781fb67b841b4>`_
    (Chen et al., 2017).

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Embeds the ``premise`` and ``hypothesis`` ``TextFields`` given to the model.
    encoder : ``Seq2SeqEncoder``
        Encodes both the premise and the hypothesis.
    similarity_function : ``SimilarityFunction``
        Similarity function behind the similarity matrix between encoded premise words and
        encoded hypothesis words.
    projection_feedforward : ``FeedForward``
        Projects the enhanced premise and hypothesis representations back down.
    inference_encoder : ``Seq2SeqEncoder``
        Encodes the projected premise and hypothesis for prediction.
    output_feedforward : ``FeedForward``
        Prepares the concatenated pooled premise and hypothesis for prediction.
    output_logit : ``FeedForward``
        Computes the output logits.
    dropout : ``float``, optional (default=0.5)
        Dropout percentage to use.  A falsy value disables dropout altogether.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Applied to the model once all sub-modules have been set.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForward,
                 output_logit: FeedForward,
                 dropout: float = 0.5,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        # Sub-modules are assigned in pipeline order.
        self._text_field_embedder = text_field_embedder
        self._encoder = encoder
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._projection_feedforward = projection_feedforward
        self._inference_encoder = inference_encoder

        # Both dropout layers share the same rate, or are both absent.
        self.dropout = torch.nn.Dropout(dropout) if dropout else None
        self.rnn_input_dropout = InputVariationalDropout(dropout) if dropout else None

        self._output_feedforward = output_feedforward
        self._output_logit = output_logit

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        # Fail early when consecutive stages disagree on their dimensions.
        check_dimensions_match(text_field_embedder.get_output_dim(),
                               encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        check_dimensions_match(encoder.get_output_dim() * 4,
                               projection_feedforward.get_input_dim(),
                               "encoder output dim", "projection feedforward input")
        check_dimensions_match(projection_feedforward.get_output_dim(),
                               inference_encoder.get_input_dim(),
                               "proj feedforward output dim", "inference lstm input dim")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    @staticmethod
    def _enhance(encoded: torch.Tensor, attended: torch.Tensor) -> torch.Tensor:
        """Concatenate ``encoded``, ``attended``, their difference and their product."""
        return torch.cat(
                [encoded, attended,
                 encoded - attended,
                 encoded * attended],
                dim=-1
        )

    @staticmethod
    def _masked_avg_and_max(encoded: torch.Tensor, mask: torch.Tensor):
        """Pool ``encoded`` over dimension 1, returning ``(average, maximum)``."""
        expanded_mask = mask.unsqueeze(-1)
        # Masked positions are pushed to -1e7 so they cannot win the max.
        max_pooled, _ = replace_masked_values(
                encoded, expanded_mask, -1e7
        ).max(dim=1)
        # Masked positions contribute zero; the divisor is the sum of the mask.
        avg_pooled = torch.sum(encoded * expanded_mask, dim=1) / torch.sum(
                mask, 1, keepdim=True
        )
        return avg_pooled, max_pooled

    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
            It is not used by this method.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.  Only present when ``label`` is given.
        """
        premise_embedded = self._text_field_embedder(premise)
        hypothesis_embedded = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # Dropout on the encoder inputs.
        if self.rnn_input_dropout:
            premise_embedded = self.rnn_input_dropout(premise_embedded)
            hypothesis_embedded = self.rnn_input_dropout(hypothesis_embedded)

        # Encode both sentences with the shared encoder.
        premise_encoded = self._encoder(premise_embedded, premise_mask)
        hypothesis_encoded = self._encoder(hypothesis_embedded, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity = self._matrix_attention(premise_encoded, hypothesis_encoded)

        # Premise-to-hypothesis attention.
        # Shape: (batch_size, premise_length, hypothesis_length)
        premise_to_hypothesis = last_dim_softmax(similarity, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        hypothesis_attended = weighted_sum(hypothesis_encoded, premise_to_hypothesis)

        # Hypothesis-to-premise attention.
        # Shape: (batch_size, hypothesis_length, premise_length)
        hypothesis_to_premise = last_dim_softmax(
                similarity.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        premise_attended = weighted_sum(premise_encoded, hypothesis_to_premise)

        # The "enhancement" layer.
        premise_enhanced = self._enhance(premise_encoded, hypothesis_attended)
        hypothesis_enhanced = self._enhance(hypothesis_encoded, premise_attended)

        # Project down to the model dimension; no dropout is applied before this step.
        premise_projected = self._projection_feedforward(premise_enhanced)
        hypothesis_projected = self._projection_feedforward(hypothesis_enhanced)

        # Inference layer, with dropout on its inputs.
        if self.rnn_input_dropout:
            premise_projected = self.rnn_input_dropout(premise_projected)
            hypothesis_projected = self.rnn_input_dropout(hypothesis_projected)
        premise_inferred = self._inference_encoder(premise_projected, premise_mask)
        hypothesis_inferred = self._inference_encoder(hypothesis_projected, hypothesis_mask)

        # The pooling layer -- avg and max pooling.
        # (batch_size, model_dim)
        premise_avg, premise_max = self._masked_avg_and_max(premise_inferred, premise_mask)
        hypothesis_avg, hypothesis_max = self._masked_avg_and_max(hypothesis_inferred,
                                                                  hypothesis_mask)

        # Concatenate the four pooled vectors.
        pooled = torch.cat([premise_avg, premise_max, hypothesis_avg, hypothesis_max], dim=1)

        # Final MLP: dropout on its input, then the two output networks.
        if self.dropout:
            pooled = self.dropout(pooled)

        label_logits = self._output_logit(self._output_feedforward(pooled))
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}
        if label is None:
            return output_dict

        output_dict["loss"] = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label)
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running accuracy under the key ``accuracy``."""
        accuracy = self._accuracy.get_metric(reset)
        return {'accuracy': accuracy}
Exemple #42
0
class MultiGranularityHierarchicalAttentionFusionNetworks(Model):
    def __init__(
            self,
            vocab: Vocabulary,
            elmo_embedder: TextFieldEmbedder,
            tokens_embedder: TextFieldEmbedder,
            features_embedder: TextFieldEmbedder,
            phrase_layer: Seq2SeqEncoder,
            projected_layer: Seq2SeqEncoder,
            contextual_passage: Seq2SeqEncoder,
            contextual_question: Seq2SeqEncoder,
            dropout: float = 0.2,
            regularizer: Optional[RegularizerApplicator] = None,
            initializer: InitializerApplicator = InitializerApplicator(),
    ):

        super(MultiGranularityHierarchicalAttentionFusionNetworks,
              self).__init__(vocab, regularizer)
        self.elmo_embedder = elmo_embedder
        self.tokens_embedder = tokens_embedder
        self.features_embedder = features_embedder
        self._phrase_layer = phrase_layer
        self._encoding_dim = self._phrase_layer.get_output_dim()
        self.projected_layer = torch.nn.Linear(self._encoding_dim + 1024,
                                               self._encoding_dim)
        self.fuse_p = FusionLayer(self._encoding_dim)
        self.fuse_q = FusionLayer(self._encoding_dim)
        self.fuse_s = FusionLayer(self._encoding_dim)
        self.projected_lstm = projected_layer
        self.contextual_layer_p = contextual_passage
        self.contextual_layer_q = contextual_question
        self.linear_self_align = torch.nn.Linear(self._encoding_dim, 1)
        # self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._self_attention = BilinearMatrixAttention(self._encoding_dim,
                                                       self._encoding_dim)
        self.bilinear_layer_s = BilinearSeqAtt(self._encoding_dim,
                                               self._encoding_dim)
        self.bilinear_layer_e = BilinearSeqAtt(self._encoding_dim,
                                               self._encoding_dim)
        self.yesno_predictor = FeedForward(self._encoding_dim,
                                           self._encoding_dim, 3)
        self.relu = torch.nn.ReLU()

        self._max_span_length = 30

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)

    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, max_qa_count, max_q_len, _ = question[
            'token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(yesno_list, 0).view(total_qa_count)

        # GloVe and simple cnn char embedding, embedding dim = 100 + 100 = 200
        word_emb_ques = self.tokens_embedder(
            question,
            num_wrapping_dims=1).reshape(total_qa_count, max_q_len,
                                         self.tokens_embedder.get_output_dim())
        word_emb_pass = self.tokens_embedder(passage)

        # Elmo embedding, embedding dim = 1024
        elmo_ques = self.elmo_embedder(question, num_wrapping_dims=1).reshape(
            total_qa_count, max_q_len, self.elmo_embedder.get_output_dim())
        elmo_pass = self.elmo_embedder(passage)

        # Passage features embedding, embedding dim = 20 + 20 = 40
        pass_feat = self.features_embedder(passage)

        # GloVe + cnn + Elmo
        embedded_question = self._variational_dropout(
            torch.cat([word_emb_ques, elmo_ques], dim=2))
        embedded_passage = self._variational_dropout(
            torch.cat([word_emb_pass, elmo_pass], dim=2))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question,
                                                 num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        # Concatenate Elmo after encoded passage
        encode_passage = self._phrase_layer(embedded_passage, passage_mask)
        projected_passage = self.relu(
            self.projected_layer(torch.cat([encode_passage, elmo_pass],
                                           dim=2)))

        # Concatenate Elmo after encoded question
        encode_question = self._phrase_layer(embedded_question, question_mask)
        projected_question = self.relu(
            self.projected_layer(torch.cat([encode_question, elmo_ques],
                                           dim=2)))

        encoded_passage = self._variational_dropout(projected_passage)
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
            1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(
            total_qa_count, passage_length, self._encoding_dim)
        repeated_pass_feat = (pass_feat.unsqueeze(1).repeat(
            1, max_qa_count, 1, 1)).view(total_qa_count, passage_length, 40)
        encoded_question = self._variational_dropout(projected_question)

        # total_qa_count * max_q_len * passage_length
        # cnt * m * n
        s = torch.bmm(encoded_question,
                      repeated_encoded_passage.transpose(2, 1))
        alpha = util.masked_softmax(s,
                                    question_mask.unsqueeze(2).expand(
                                        s.size()),
                                    dim=1)
        # cnt * n * h
        aligned_p = torch.bmm(alpha.transpose(2, 1), encoded_question)

        # cnt * m * n
        beta = util.masked_softmax(s,
                                   repeated_passage_mask.unsqueeze(1).expand(
                                       s.size()),
                                   dim=2)
        # cnt * m * h
        aligned_q = torch.bmm(beta, repeated_encoded_passage)

        fused_p = self.fuse_p(repeated_encoded_passage, aligned_p)
        fused_q = self.fuse_q(encoded_question, aligned_q)

        # add manual features here
        q_aware_p = self._variational_dropout(
            self.projected_lstm(
                torch.cat([fused_p, repeated_pass_feat], dim=2),
                repeated_passage_mask))

        # cnt * n * n
        # self_p = torch.bmm(q_aware_p, q_aware_p.transpose(2, 1))
        # self_p = self.bilinear_self_align(q_aware_p)
        self_p = self._self_attention(q_aware_p, q_aware_p)
        mask = repeated_passage_mask.reshape(
            total_qa_count, passage_length, 1) * repeated_passage_mask.reshape(
                total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              device=self_p.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        lamb = util.masked_softmax(self_p, mask, dim=2)
        # lamb = util.masked_softmax(self_p, repeated_passage_mask, dim=2)
        # cnt * n * h
        self_aligned_p = torch.bmm(lamb, q_aware_p)

        # cnt * n * h
        fused_self_p = self.fuse_s(q_aware_p, self_aligned_p)
        contextual_p = self._variational_dropout(
            self.contextual_layer_p(fused_self_p, repeated_passage_mask))
        # contextual_p = self.contextual_layer_p(fused_self_p, repeated_passage_mask)

        contextual_q = self._variational_dropout(
            self.contextual_layer_q(fused_q, question_mask))
        # contextual_q = self.contextual_layer_q(fused_q, question_mask)
        # cnt * m
        gamma = util.masked_softmax(
            self.linear_self_align(contextual_q).squeeze(2),
            question_mask,
            dim=1)
        # cnt * h
        weighted_q = torch.bmm(gamma.unsqueeze(1), contextual_q).squeeze(1)

        span_start_logits = self.bilinear_layer_s(weighted_q, contextual_p)
        span_end_logits = self.bilinear_layer_e(weighted_q, contextual_p)

        # cnt * n * 1  cnt * 1 * h
        span_yesno_logits = self.yesno_predictor(
            torch.bmm(span_end_logits.unsqueeze(2), weighted_q.unsqueeze(1)))
        # span_yesno_logits = self.yesno_predictor(contextual_p)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits,
                                                       span_end_logits,
                                                       span_yesno_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss for training

        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
                                                    repeated_passage_mask),
                            span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask),
                             span_end.view(-1),
                             ignore_index=-1)
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end],
                                            -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(torch.nn.functional.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)

            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                yesno_pred = predicted_span[2]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_query_id_list.append(iid)
                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad_eval.metric_max_over_ground_truths(
                                    squad_eval.f1_score, best_span_string,
                                    refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score, best_span_string,
                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
        return output_dict

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Convert the predicted yes/no label indices into their string tokens.

        The ``"yesno"`` entry of ``output_dict`` (a list of per-dialog lists of
        label indices) is removed and re-added as a list of lists of tokens
        looked up in the ``"yesno_labels"`` vocabulary namespace.  The same
        dictionary object is returned.
        """
        index_lists = output_dict.pop("yesno")
        decoded = []
        for index_list in index_lists:
            tokens = []
            for label_index in index_list:
                tokens.append(
                    self.vocab.get_token_from_index(label_index,
                                                    namespace="yesno_labels"))
            decoded.append(tokens)
        output_dict['yesno'] = decoded
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the current value of every metric tracked by this model.

        ``reset`` is forwarded unchanged to each metric's ``get_metric``.  The
        returned dictionary has the keys ``start_acc``, ``end_acc``,
        ``span_acc``, ``yesno`` and ``f1``.
        """
        metrics = {}
        metrics['start_acc'] = self._span_start_accuracy.get_metric(reset)
        metrics['end_acc'] = self._span_end_accuracy.get_metric(reset)
        metrics['span_acc'] = self._span_accuracy.get_metric(reset)
        metrics['yesno'] = self._span_yesno_accuracy.get_metric(reset)
        metrics['f1'] = self._official_f1.get_metric(reset)
        return metrics

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        """
        Greedily pick, for every row, the highest scoring ``(start, end)`` span
        together with the yes/no class predicted at the chosen end position.

        Parameters
        ----------
        span_start_logits : ``torch.Tensor``
            Two-dimensional ``(batch_size, passage_length)`` start scores.
        span_end_logits : ``torch.Tensor``
            Two-dimensional ``(batch_size, passage_length)`` end scores.
        span_yesno_logits : ``torch.Tensor``
            Scores indexed as ``[row, position]``; the argmax of that entry is
            taken as the yes/no prediction.
        max_span_length : ``int``
            A candidate whose ``end - start`` exceeds this value is rejected.

        Returns
        -------
        A ``(batch_size, 3)`` long tensor (created with ``new_zeros`` from
        ``span_start_logits``) holding ``[start, end, yesno]`` per row.

        Raises
        ------
        ValueError
            If either the start or the end logits are not two-dimensional.
        """
        if not (span_start_logits.dim() == 2 and span_end_logits.dim() == 2):
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        best_word_span = span_start_logits.new_zeros((batch_size, 3),
                                                     dtype=torch.long)
        start_scores = span_start_logits.data.cpu().numpy()
        end_scores = span_end_logits.data.cpu().numpy()
        yesno_scores = span_yesno_logits.data.cpu().numpy()

        for row in range(batch_size):
            best_score = -1e20
            # Index of the best start score seen among positions <= end.
            running_start = 0
            chosen_start, chosen_end = 0, 0
            for end in range(passage_length):
                if start_scores[row, running_start] < start_scores[row, end]:
                    running_start = end
                candidate = start_scores[row, running_start] + end_scores[row, end]
                # Over-long candidates are skipped without touching the
                # current best span or its score.
                if candidate > best_score and end - running_start <= max_span_length:
                    chosen_start, chosen_end = running_start, end
                    best_score = candidate
            best_word_span[row, 0] = chosen_start
            best_word_span[row, 1] = chosen_end
            best_word_span[row, 2] = int(np.argmax(yesno_scores[row, chosen_end]))
        return best_word_span
Exemple #43
0
class IntentParamClassifier(Model):
    """
    This ``Model`` jointly performs intent classification and parameter (slot) tagging.

    Intent task: we assume we're given a user_utterance, a prev_user_utterance and a
    prev_sys_utterance, and we predict some output label for intent.  We embed the three
    utterances and encode each of them with a separate Seq2VecEncoder, getting a single vector
    representing the content of each.  We then concatenate those three vectors and pass the
    result through a feedforward network, the output of which we use as our scores for each label.

    Param task: ``tokens`` are embedded with the same embedder, encoded with a Seq2SeqEncoder and
    projected to one score per tag for every token.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed every ``TextField`` we get as input to the model.
    user_utterance_encoder : ``Seq2VecEncoder``
        The encoder that we will use to convert the user_utterance to a vector.
    prev_user_utterance_encoder : ``Seq2VecEncoder``
        The encoder that we will use to convert the prev_user_utterance to a vector.
    prev_sys_utterance_encoder : ``Seq2VecEncoder``
        The encoder that we will use to convert the prev_sys_utterance to a vector.
    classifier_feedforward : ``FeedForward``
        Maps the concatenated utterance vectors to intent label scores.
    encoder : ``Seq2SeqEncoder``
        The encoder applied to the embedded ``tokens`` before the tag projection.
    calculate_span_f1 : ``bool``, optional (default=``None``)
        Requires ``tag_encoding`` to be set when true.
    tag_encoding : ``str``, optional (default=``None``)
        Passed to the ``SpanBasedF1Measure`` stored in ``_f1_metric``.
    tag_namespace : ``str``, optional (default=``tags``)
        The vocabulary namespace of the tag labels.
    verbose_metrics : ``bool``, optional (default = False)
        Stored on the model; ``get_metrics`` currently always reports per-class metrics.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 user_utterance_encoder: Seq2VecEncoder,
                 prev_user_utterance_encoder: Seq2VecEncoder,
                 prev_sys_utterance_encoder: Seq2VecEncoder,
                 classifier_feedforward: FeedForward,
                 encoder: Seq2SeqEncoder,
                 calculate_span_f1: Optional[bool] = None,
                 tag_encoding: Optional[str] = None,
                 tag_namespace: str = "tags",
                 verbose_metrics: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(IntentParamClassifier, self).__init__(vocab, regularizer)

        # Intent task
        self.text_field_embedder = text_field_embedder
        self.label_num_classes = self.vocab.get_vocab_size("labels")
        self.user_utterance_encoder = user_utterance_encoder
        self.prev_user_utterance_encoder = prev_user_utterance_encoder
        self.prev_sys_utterance_encoder = prev_sys_utterance_encoder
        self.classifier_feedforward = classifier_feedforward

        # All three utterance encoders consume the output of the shared embedder.
        for encoder_name, utterance_encoder in (
                ("user_utterance_encoder", user_utterance_encoder),
                ("prev_user_utterance_encoder", prev_user_utterance_encoder),
                ("prev_sys_utterance_encoder", prev_sys_utterance_encoder)):
            if text_field_embedder.get_output_dim() != utterance_encoder.get_input_dim():
                raise ConfigurationError("The output dimension of the text_field_embedder must match the "
                                         "input dimension of the {}. Found {} and {}, "
                                         "respectively.".format(encoder_name,
                                                                text_field_embedder.get_output_dim(),
                                                                utterance_encoder.get_input_dim()))

        self.label_accuracy = CategoricalAccuracy()

        # One F1 tracker per intent label, keyed by the label's token.
        self.label_f1_metrics = {
            vocab.get_token_from_index(index=i, namespace="labels"): F1Measure(positive_label=i)
            for i in range(self.label_num_classes)
        }

        self.loss = torch.nn.CrossEntropyLoss()

        # Param task
        self.tag_namespace = tag_namespace
        self.tag_num_classes = self.vocab.get_vocab_size(tag_namespace)
        self.encoder = encoder
        self._verbose_metrics = verbose_metrics
        self.tag_projection_layer = TimeDistributed(Linear(self.encoder.get_output_dim(),
                                                           self.tag_num_classes))

        check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        # We keep calculate_span_f1 as a constructor argument for API consistency with
        # the CrfTagger, even though it is redundant in this class
        # (tag_encoding serves the same purpose).
        if calculate_span_f1 and not tag_encoding:
            raise ConfigurationError("calculate_span_f1 is True, but "
                                     "no tag_encoding was specified.")

        self.tag_accuracy = CategoricalAccuracy()

        # NOTE: the span-based F1 objects below are constructed but are neither updated in
        # ``forward`` nor reported by ``get_metrics`` in this class.
        if calculate_span_f1 or tag_encoding:
            self._f1_metric = SpanBasedF1Measure(vocab,
                                                 tag_namespace=tag_namespace,
                                                 tag_encoding=tag_encoding)
        else:
            self._f1_metric = None

        self.f1 = SpanBasedF1Measure(vocab, tag_namespace=tag_namespace)

        # One F1 tracker per tag, keyed by the tag's token in ``tag_namespace``.
        self.tag_f1_metrics = {
            vocab.get_token_from_index(index=k, namespace=tag_namespace): F1Measure(positive_label=k)
            for k in range(self.tag_num_classes)
        }

        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                user_utterance: Dict[str, torch.LongTensor],
                prev_user_utterance: Dict[str, torch.LongTensor],
                prev_sys_utterance: Dict[str, torch.LongTensor],
                tokens: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None,
                tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        user_utterance : Dict[str, Variable], required
            The output of ``TextField.as_array()``.
        prev_user_utterance : Dict[str, Variable], required
            The output of ``TextField.as_array()``.
        prev_sys_utterance : Dict[str, Variable], required
            The output of ``TextField.as_array()``.
        tokens : Dict[str, Variable], required
            The text field that is tagged token by token.
        label : Variable, optional (default = None)
            A variable representing the intent label for each instance in the batch.
        tags : Variable, optional (default = None)
            The gold tag index of every token in ``tokens``.
        metadata : List[Dict[str, Any]], optional (default = None)
            When given, the ``"words"`` entry of every element is copied to the output.

        Returns
        -------
        An output dictionary consisting of:
        label_logits : torch.FloatTensor
            The unnormalised intent scores produced by ``classifier_feedforward``.
        label_class_probs : torch.FloatTensor
            ``label_logits`` normalised with a softmax over dimension 1.
        tag_logits : torch.FloatTensor
            The unnormalised per-token tag scores.
        tag_class_probs : torch.FloatTensor
            ``tag_logits`` normalised with a softmax over the last dimension.
        loss : torch.FloatTensor, optional
            The sum of the intent and tagging losses; only present when both ``label`` and
            ``tags`` are given.
        words : List, optional
            Only present when ``metadata`` is given.
        """

        # Intent task
        embedded_user_utterance = self.text_field_embedder(user_utterance)
        user_utterance_mask = util.get_text_field_mask(user_utterance)
        encoded_user_utterance = self.user_utterance_encoder(embedded_user_utterance, user_utterance_mask)

        embedded_prev_user_utterance = self.text_field_embedder(prev_user_utterance)
        prev_user_utterance_mask = util.get_text_field_mask(prev_user_utterance)
        encoded_prev_user_utterance = self.prev_user_utterance_encoder(embedded_prev_user_utterance,
                                                                       prev_user_utterance_mask)

        embedded_prev_sys_utterance = self.text_field_embedder(prev_sys_utterance)
        prev_sys_utterance_mask = util.get_text_field_mask(prev_sys_utterance)
        encoded_prev_sys_utterance = self.prev_sys_utterance_encoder(embedded_prev_sys_utterance,
                                                                     prev_sys_utterance_mask)

        # Param task
        embedded_text_input = self.text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)
        encoded_text = self.encoder(embedded_text_input, mask)

        label_logits = self.classifier_feedforward(torch.cat([encoded_user_utterance,
                                                              encoded_prev_user_utterance,
                                                              encoded_prev_sys_utterance], dim=-1))
        label_class_probs = F.softmax(label_logits, dim=1)
        output_dict = {"label_logits": label_logits, "label_class_probs": label_class_probs}

        tag_logits = self.tag_projection_layer(encoded_text)
        # Softmax over the tag dimension; no reshape is needed for that.
        tag_class_probs = F.softmax(tag_logits, dim=-1)

        output_dict["tag_logits"] = tag_logits
        output_dict["tag_class_probs"] = tag_class_probs

        # The loss and all metrics are only computed when gold data for both tasks is present.
        if label is not None and tags is not None:
            loss = self.loss(label_logits, label) + sequence_cross_entropy_with_logits(tag_logits, tags, mask)
            output_dict["loss"] = loss

            # compute intent F1 per label
            for metric in self.label_f1_metrics.values():
                metric(label_class_probs, label)
            self.label_accuracy(label_logits, label)

            # compute param F1 per tag; iterating the trackers directly avoids looking them up
            # through a namespace that may differ from the configured ``tag_namespace``.
            for metric in self.tag_f1_metrics.values():
                metric(tag_class_probs, tags, mask.float())
            self.tag_accuracy(tag_logits, tags, mask.float())

        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple argmax over the class probabilities of both tasks and converts the indices
        to strings.  Adds a ``"label"`` key (one intent label per instance) and a ``"tags"`` key
        (one list of tags per instance) to the dictionary.
        """
        # Intent task
        label_predictions = output_dict["label_class_probs"].cpu().data.numpy()
        label_argmax_indices = numpy.argmax(label_predictions, axis=-1)
        output_dict["label"] = [self.vocab.get_token_from_index(x, namespace="labels")
                                for x in label_argmax_indices]

        # Param task
        tag_all_predictions = output_dict["tag_class_probs"].cpu().data.numpy()
        if tag_all_predictions.ndim == 3:
            tag_predictions_list = [tag_all_predictions[i] for i in range(tag_all_predictions.shape[0])]
        else:
            # A single un-batched instance.
            tag_predictions_list = [tag_all_predictions]
        all_tags = []
        for tag_predictions in tag_predictions_list:
            tag_argmax_indices = numpy.argmax(tag_predictions, axis=-1)
            # Use the configured namespace so that decoding agrees with the tag vocabulary the
            # projection layer was sized with.
            all_tags.append([self.vocab.get_token_from_index(y, namespace=self.tag_namespace)
                             for y in tag_argmax_indices])
        output_dict["tags"] = all_tags
        return output_dict

    @staticmethod
    def _add_f1_metrics(f1_metrics: Dict[str, Any],
                        reset: bool,
                        metric_dict: Dict[str, float]) -> float:
        """
        Writes ``<name>_P``, ``<name>_R`` and ``<name>_F1`` for every tracker into ``metric_dict``
        and returns the mean F1 over the trackers whose F1 is non-zero (``0.0`` if there are none).
        """
        nonzero_f1 = []
        for name, metric in f1_metrics.items():
            values = metric.get_metric(reset)
            metric_dict[name + "_P"] = values[0]
            metric_dict[name + "_R"] = values[1]
            metric_dict[name + "_F1"] = values[2]
            if values[2]:
                nonzero_f1.append(values[2])
        if not nonzero_f1:
            return 0.0
        return sum(nonzero_f1, 0.0) / len(nonzero_f1)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns per-class precision / recall / F1 for both tasks, the average of the non-zero
        F1 scores of each task and the two accuracies.
        """
        metric_dict = {}

        # intent task
        metric_dict["intent_average_F1"] = self._add_f1_metrics(self.label_f1_metrics, reset, metric_dict)
        metric_dict["label_accuracy"] = self.label_accuracy.get_metric(reset)

        # param task
        metric_dict["param_average_F1"] = self._add_f1_metrics(self.tag_f1_metrics, reset, metric_dict)
        metric_dict["tag_accuracy"] = self.tag_accuracy.get_metric(reset)

        return metric_dict
Exemple #44
0
class LstmSwag(Model):
    """
    This model scores four candidate hypotheses and predicts which one is correct (the class
    name suggests it targets the SWAG multiple-choice task).

    Every hypothesis is embedded, passed through a shared ``Seq2SeqEncoder``, max-pooled over
    the sequence dimension and projected to a single score.  The four scores are concatenated
    and trained with a cross entropy loss against the index of the correct hypothesis.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed each hypothesis ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and pooling them into a single vector.
    embedding_dropout : ``float``, optional (default = 0.0)
        Dropout probability applied to the embedded tokens.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            encoder: Seq2SeqEncoder,
            embedding_dropout: float = 0.0,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(LstmSwag, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.embedding_dropout = Dropout(p=embedding_dropout)
        # Maps the pooled encoding of one hypothesis to a single score.
        self.output_prediction = Linear(self.encoder.get_output_dim(),
                                        1,
                                        bias=False)

        check_dimensions_match(text_field_embedder.get_output_dim(),
                               encoder.get_input_dim(), "text embedding dim",
                               "eq encoder input dim")
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    def forward(
        self,  # type: ignore
        hypothesis0: Dict[str, torch.LongTensor],
        hypothesis1: Dict[str, torch.LongTensor],
        hypothesis2: Dict[str, torch.LongTensor],
        hypothesis3: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        hypothesis0, hypothesis1, hypothesis2, hypothesis3 : Dict[str, torch.LongTensor], required
            The four candidate text fields that are scored against each other.
        label : torch.IntTensor, optional (default = None)
            The index of the correct hypothesis for every instance.

        Returns
        -------
        An output dictionary consisting of:
        label_logits : torch.FloatTensor
            The four hypothesis scores concatenated along the last dimension.
        label_probs : torch.FloatTensor
            ``label_logits`` normalised with a softmax over the last dimension.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        hypothesis_logits = []
        for tokens in [hypothesis0, hypothesis1, hypothesis2, hypothesis3]:
            if isinstance(self.text_field_embedder, ElmoTokenEmbedder):
                # Clear the ELMo LSTM state before embedding each hypothesis.
                self.text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states(
                )

            embedded_text_input = self.embedding_dropout(
                self.text_field_embedder(tokens))
            mask = get_text_field_mask(tokens)

            encoded_text = self.encoder(embedded_text_input, mask)

            # Max-pool over dimension 1, then project to one score.
            hypothesis_logits.append(
                self.output_prediction(encoded_text.max(1)[0]))

        logits = torch.cat(hypothesis_logits, -1)
        class_probabilities = F.softmax(logits, dim=-1)
        output_dict = {
            "label_logits": logits,
            "label_probs": class_probabilities
        }

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            self._accuracy(logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Returns the accuracy accumulated over the calls to ``forward`` that had a label."""
        return {
            'accuracy': self._accuracy.get_metric(reset),
        }

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'LstmSwag':
        """
        Builds the model from ``params``.  Recognised keys are ``text_field_embedder``,
        ``encoder``, ``embedding_dropout``, ``initializer`` and ``regularizer``; any other
        key makes ``params.assert_empty`` fail.
        """
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(
            vocab, embedder_params)
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        # Consume the dropout setting so that configuring it neither trips
        # ``assert_empty`` nor is silently ignored.
        embedding_dropout = params.pop("embedding_dropout", 0.0)

        initializer = InitializerApplicator.from_params(
            params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(
            params.pop('regularizer', []))
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   encoder=encoder,
                   embedding_dropout=embedding_dropout,
                   initializer=initializer,
                   regularizer=regularizer)
Exemple #45
0
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).
    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Unlike the reference model, the passage/question similarity is not configurable here: the
    constructor always uses a ``CosineMatrixAttention``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``RegularizerApplicator()``)
        Passed to the ``Model`` base class constructor.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = RegularizerApplicator()):
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = CosineMatrixAttention()
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        # The start predictor sees the merged passage (4 * encoding_dim) concatenated with the
        # modeling layer output.
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        # The end predictor sees the merged passage concatenated with the span end encoder output.
        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(),
                               4 * encoding_dim, "modeling layer input dim",
                               "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(),
                               phrase_layer.get_input_dim(),
                               "text field embedder output dim",
                               "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(),
                               4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim",
                               "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so call sites do not need to special-case "no dropout".
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary
            (``span_end`` must then be given as well).
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.
            An optional ``answer_texts`` key, when non-empty, is used to update the SQuAD metrics.
        Returns
        -------
        An output dictionary consisting of:
        passage_question_attention : torch.FloatTensor
            The masked softmax of the passage/question similarity matrix.
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        question_tokens, passage_tokens : List
            Copied from ``metadata`` when it is provided.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                encoded_passage,
                passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat(
            [
                final_merged_passage,
                modeled_passage,
                tiled_start_representation,
                modeled_passage * tiled_start_representation,
            ],
            dim=-1)
        # Shape: (batch_size, passage_length, span_end_encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        # Push padded positions to a very negative score so they can never win the span search.
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        # Go through ``self`` so the static method defined on this class is always the one used;
        # a bare ``get_best_span`` name is not resolved through the class scope.
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Each offsets entry is indexed as (start, end): take the start of the first
                # predicted token and the end of the last one.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the span accuracies and the SQuAD exact-match / F1 scores.

        ``reset`` is forwarded unchanged to every underlying metric.  The returned dictionary has
        the keys ``start_acc``, ``end_acc``, ``span_acc``, ``em`` and ``f1``.
        """
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """
        Find, for every batch element, the highest scoring span whose end is not before its start.

        Parameters
        ----------
        span_start_logits : ``torch.Tensor``
            Shape ``(batch_size, passage_length)``.
        span_end_logits : ``torch.Tensor``
            Shape ``(batch_size, passage_length)``.

        Returns
        -------
        A tensor of shape ``(batch_size, 2)`` holding the ``(start, end)`` token indices.

        Raises
        ------
        ValueError
            If either input is not 2-dimensional.
        """
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.  Taking the log of the 0/1 triangle gives 0 for valid
        # spans and -inf for invalid ones, so adding it removes the invalid spans from the argmax.
        span_log_mask = torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0)
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
Exemple #46
0
class ModelSQUAD(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 feed_forward: FeedForward,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """
        Build the layers, the self-attention weights and the metrics of the model.

        Parameters
        ----------
        vocab : ``Vocabulary``
        text_field_embedder : ``TextFieldEmbedder``
            Embeds the ``question`` and ``passage`` inputs.
        num_highway_layers : ``int``
            Number of highway layers applied to the embeddings.
        phrase_layer : ``Seq2SeqEncoder``
            Encoder applied to the embedded question and passage.
        attention_similarity_function : ``SimilarityFunction``
            Similarity used by the passage/question ``MatrixAttention``.
        residual_encoder : ``Seq2SeqEncoder``
            Encoder applied to the question-attended passage before self-attention.
        span_start_encoder : ``Seq2SeqEncoder``
            Encoder whose output feeds the span start predictor.
        span_end_encoder : ``Seq2SeqEncoder``
            Encoder whose output feeds the span end predictor.
        feed_forward : ``FeedForward``
            Produces the no-answer score from the pooled start/end/no-answer vectors.
        dropout : ``float``, optional (default=0.2)
            Dropout probability; when not greater than 0 an identity function is used instead.
        mask_lstms : ``bool``, optional (default=True)
            Whether the padding masks are passed to the encoders.
        initializer : ``InitializerApplicator``, optional
            Applied to the whole model once all modules are registered.
        regularizer : ``RegularizerApplicator``, optional (default=``None``)
            Passed to the ``Model`` base class constructor.
        """
        super(ModelSQUAD, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._residual_encoder = residual_encoder
        self._span_end_encoder = span_end_encoder
        self._span_start_encoder = span_start_encoder
        self._feed_forward = feed_forward

        encoding_dim = phrase_layer.get_output_dim()

        # Each span predictor consumes the output of its own encoder in ``forward``, so its input
        # size must come from that encoder rather than from the phrase layer.  When all encoders
        # share the phrase layer's output size this is identical to sizing them from
        # ``encoding_dim``.
        span_start_encoding_dim = span_start_encoder.get_output_dim()
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_encoding_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_encoding_dim, 1))

        # The no-answer predictor reads the self-attended passage, which is produced by
        # ``_residual_linear_layer`` below and therefore has ``encoding_dim`` features.
        self._no_answer_predictor = TimeDistributed(
            torch.nn.Linear(encoding_dim, 1))

        self._self_matrix_attention = MatrixAttention(
            attention_similarity_function)
        self._linear_layer = TimeDistributed(
            torch.nn.Linear(4 * encoding_dim, encoding_dim))
        self._residual_linear_layer = TimeDistributed(
            torch.nn.Linear(3 * encoding_dim, encoding_dim))

        # Weights of the passage self-similarity used in ``forward``
        # (dot_similarity + x_similarity + y_similarity), drawn uniformly from [-std, std].
        self._w_x = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_y = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_xy = torch.nn.Parameter(torch.Tensor(encoding_dim))
        std = math.sqrt(6 / (encoding_dim * 3 + 1))
        self._w_x.data.uniform_(-std, std)
        self._w_y.data.uniform_(-std, std)
        self._w_xy.data.uniform_(-std, std)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so call sites do not need to special-case "no dropout".
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.LongTensor = None,
            span_end: torch.LongTensor = None,
            spans=None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Score answer spans (plus a "no answer" option) for a question over its passages.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            Text field tensors for the question.
        passage : Dict[str, torch.LongTensor]
            Text field tensors for the passage(s); the embedded passage is indexed with a
            passage-length axis at position 2 and flattened over its leading axes.
        span_start : ``torch.LongTensor``, optional
            Gold start indices.  When given (together with ``span_end``), the loss is computed and
            the span accuracy metrics are updated.  Only the first gold span of each instance
            (``span_start[b, 0]`` / ``span_end[b, 0]``) is used as the training target; a negative
            combined target is mapped to the extra "no answer" class.
        span_end : ``torch.LongTensor``, optional
            Gold end indices, see ``span_start``.
        spans : optional
            Not used by this method.
        metadata : ``List[Dict[str, Any]]``, optional
            Per-instance dictionaries with the keys ``question_tokens``, ``passage_tokens``,
            ``original_passage`` and ``token_offsets``, and optionally ``answer_texts``.

        Returns
        -------
        An output dictionary with ``passage_question_attention``, ``span_start_logits``,
        ``span_start_probs``, ``span_end_logits``, ``span_end_probs`` and ``best_span``; ``loss``
        when gold spans are given; ``best_span_str``, ``question_tokens`` and ``passage_tokens``
        when ``metadata`` is given.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        # Shape: (batch_size, 4, passage_length, embedding_dim)
        embedded_passage = self._text_field_embedder(passage)

        # NOTE(review): ``batch_size`` is the question batch size, taken before the question is
        # tiled 4 times below, yet it is later used to expand/resize tensors derived from the
        # flattened passages - confirm the intended shapes against the data reader.
        (batch_size, q_length, embedding_dim) = embedded_question.size()
        passage_length = embedded_passage.size(2)

        # reshape: (batch_size*4, -1, embedding_dim)
        embedded_passage = embedded_passage.view(-1, passage_length,
                                                 embedding_dim)
        embedded_passage = self._highway_layer(embedded_passage)

        # Tile the question (and its mask) 4 times so it lines up with the flattened passages.
        embedded_question = embedded_question.unsqueeze(0).expand(
            4, -1, -1, -1).contiguous().view(-1, q_length, embedding_dim)
        question_mask = util.get_text_field_mask(question).float()
        question_mask = question_mask.unsqueeze(0).expand(
            4, -1, -1).contiguous().view(-1, q_length)

        passage_mask = util.get_text_field_mask(passage, 1).float()
        passage_mask = passage_mask.view(-1, passage_length)

        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Device used for the helper tensors created further down.
        cuda_device = encoded_question.get_device()

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                encoded_passage,
                passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        question_attended_passage = relu(
            self._linear_layer(final_merged_passage))

        # TODO: attach residual self-attention layer
        # Shape: (batch_size, passage_length, encoding_dim)
        residual_passage = self._dropout(
            self._residual_encoder(self._dropout(question_attended_passage),
                                   passage_lstm_mask))
        # Pairwise padding mask with the diagonal zeroed out, so that a token never attends to
        # itself in the self-attention below.
        mask = passage_mask.resize(batch_size, passage_length,
                                   1) * passage_mask.resize(
                                       batch_size, 1, passage_length)
        self_mask = Variable(
            torch.eye(passage_length,
                      passage_length).cuda(cuda_device)).resize(
                          1, passage_length, passage_length)
        mask = mask * (1 - self_mask)
        # Shape: (batch_size, passage_length, passage_length)
        x_similarity = torch.matmul(residual_passage, self._w_x).unsqueeze(2)
        y_similarity = torch.matmul(residual_passage, self._w_y).unsqueeze(1)
        dot_similarity = torch.bmm(residual_passage * self._w_xy,
                                   residual_passage.transpose(1, 2))
        passage_self_similarity = dot_similarity + x_similarity + y_similarity
        # Shape: (batch_size, passage_length, passage_length)
        passage_self_attention = util.last_dim_softmax(passage_self_similarity,
                                                       mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_vectors = util.weighted_sum(residual_passage,
                                            passage_self_attention)
        # Shape: (batch_size, passage_length, encoding_dim * 3)
        merged_passage = torch.cat(
            [
                residual_passage,
                passage_vectors,
                residual_passage * passage_vectors,
            ],
            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        self_attended_passage = relu(
            self._residual_linear_layer(merged_passage))

        # Shape: (batch_size, passage_length, encoding_dim)
        mixed_passage = question_attended_passage + self_attended_passage

        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_start = self._dropout(
            self._span_start_encoder(mixed_passage, passage_lstm_mask))
        span_start_logits = self._span_start_predictor(
            encoded_span_start).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, passage_length, encoding_dim * 2)
        concatenated_passage = torch.cat([mixed_passage, encoded_span_start],
                                         dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(concatenated_passage, passage_lstm_mask))
        span_end_logits = self._span_end_predictor(encoded_span_end).squeeze(
            -1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        # Shape: (batch_size, encoding_dim)
        v_1 = util.weighted_sum(encoded_span_start, span_start_probs)
        v_2 = util.weighted_sum(encoded_span_end, span_end_probs)

        no_span_logits = self._no_answer_predictor(
            self_attended_passage).squeeze(-1)
        no_span_probs = util.masked_softmax(no_span_logits, passage_mask)
        v_3 = util.weighted_sum(self_attended_passage, no_span_probs)
        # The no-answer score, computed from the three pooled vectors.
        # Shape: (batch_size, 1)
        z_score = self._feed_forward(torch.cat([v_1, v_2, v_3], dim=-1))

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # The loss and the accuracy metrics need gold spans; without them (pure inference) only
        # the predictions are returned instead of failing on a ``None`` label.
        if span_start is not None:
            # Create the target tensor including the no-answer label: every (start, end) pair is
            # flattened to ``start * passage_length + end`` and a negative value is mapped to the
            # extra class ``passage_length ** 2``.
            span_target = Variable(
                torch.ones(batch_size).long()).cuda(cuda_device)
            for b in range(batch_size):
                span_target[b].data[0] = span_start[
                    b, 0].data[0] * passage_length + span_end[b, 0].data[0]
            span_target[span_target < 0] = passage_length**2

            # Shape: (batch_size, passage_length, passage_length)
            span_start_logits_tiled = span_start_logits.unsqueeze(1).expand(
                batch_size, passage_length, passage_length)
            span_end_logits_tiled = span_end_logits.unsqueeze(-1).expand(
                batch_size, passage_length, passage_length)
            span_logits = (span_start_logits_tiled +
                           span_end_logits_tiled).view(batch_size, -1)
            answer_mask = torch.bmm(passage_mask.unsqueeze(-1),
                                    passage_mask.unsqueeze(1)).view(
                                        batch_size, -1)
            # The no-answer class is never masked out.
            no_answer_mask = Variable(torch.ones(batch_size,
                                                 1)).cuda(cuda_device)
            combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
            # Shape: (batch_size, passage_length**2 + 1)
            all_logits = torch.cat([span_logits, z_score], dim=-1)
            loss = nll_loss(util.masked_log_softmax(all_logits, combined_mask),
                            span_target)
            output_dict["loss"] = loss

            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                # Each offsets entry is indexed as (start, end): take the start of the first
                # predicted token and the end of the last one.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the current value of every metric tracked by this model.

        Parameters
        ----------
        reset : ``bool``, optional (default=False)
            Forwarded unchanged to each metric's ``get_metric`` call.

        Returns
        -------
        A dictionary with the keys ``start_acc``, ``end_acc``, ``span_acc``, ``em`` and ``f1``.
        """
        # The SQuAD metric yields an (exact match, F1) pair and is queried first.
        squad_em, squad_f1 = self._squad_metrics.get_metric(reset)

        metrics: Dict[str, float] = {}
        metrics['start_acc'] = self._span_start_accuracy.get_metric(reset)
        metrics['end_acc'] = self._span_end_accuracy.get_metric(reset)
        metrics['span_acc'] = self._span_accuracy.get_metric(reset)
        metrics['em'] = squad_em
        metrics['f1'] = squad_f1
        return metrics

    @staticmethod
    def get_best_span(span_start_logits: Variable,
                      span_end_logits: Variable) -> Variable:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = Variable(span_start_logits.data.new().resize_(
            batch_size, 2).fill_(0)).long()

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'ModelSQUAD':
        """
        Build the model from a ``Params`` configuration.

        Every sub-module is constructed from the entry of the same name in ``params``.
        ``dropout`` defaults to 0.2, ``mask_lstms`` to ``True``, and ``initializer`` /
        ``regularizer`` to an empty list.  After all known keys have been popped,
        ``params.assert_empty`` is called with the class name.
        """
        embedder_config = params.pop("text_field_embedder")

        # Dict entries are evaluated top to bottom, so the keys are consumed from ``params``
        # in the order they are listed here.
        constructor_kwargs = {
            "vocab": vocab,
            "text_field_embedder": TextFieldEmbedder.from_params(vocab, embedder_config),
            "num_highway_layers": params.pop_int("num_highway_layers"),
            "phrase_layer": Seq2SeqEncoder.from_params(params.pop("phrase_layer")),
            "attention_similarity_function": SimilarityFunction.from_params(
                params.pop("similarity_function")),
            "residual_encoder": Seq2SeqEncoder.from_params(params.pop("residual_encoder")),
            "span_start_encoder": Seq2SeqEncoder.from_params(params.pop("span_start_encoder")),
            "span_end_encoder": Seq2SeqEncoder.from_params(params.pop("span_end_encoder")),
            "feed_forward": FeedForward.from_params(params.pop("feed_forward")),
            "dropout": params.pop_float('dropout', 0.2),
            "initializer": InitializerApplicator.from_params(params.pop('initializer', [])),
            "regularizer": RegularizerApplicator.from_params(params.pop('regularizer', [])),
            "mask_lstms": params.pop_bool('mask_lstms', True),
        }

        params.assert_empty(cls.__name__)
        return cls(**constructor_kwargs)
Exemple #47
0
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so call sites can apply ``self._dropout`` unconditionally.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, the length of this list should be the batch size.  Each dictionary must
            have the keys ``question_tokens``, ``passage_tokens``, ``original_passage`` and
            ``token_offsets``; they are used to recover the best span string from the original
            passage text and to echo the tokenized input in the output.  If a dictionary also has
            a non-empty ``answer_texts`` entry, the SQuAD EM / F1 metrics are updated for that
            instance.

        Returns
        -------
        An output dictionary consisting of:
        passage_question_attention : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length, question_length)`` with the masked
            softmax of the passage-question similarities over the question tokens.
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        question_tokens : List
            Only when ``metadata`` is given: the ``question_tokens`` entry of every instance.
        passage_tokens : List
            Only when ``metadata`` is given: the ``passage_tokens`` entry of every instance.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        # ``max`` over the last dimension already removes the question axis.  We deliberately do
        # not ``squeeze(-1)`` afterwards: that would also drop the passage axis whenever
        # ``passage_length == 1`` and break the alignment with ``passage_mask`` below.
        question_passage_similarity = masked_similarity.max(dim=-1)[0]
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, span_end_encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Push the logits of masked positions far down so they are never chosen as span
        # boundaries by ``get_best_span``.
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Slice from the first offset of the start token to the second offset of the
                # end token of the predicted span.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the current span start / end / full-span accuracies together with the SQuAD exact
        match and F1 scores.  ``reset`` is forwarded to every underlying metric.
        """
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        """
        Find, for each batch row, the span ``(start, end)`` with ``start <= end`` that maximises
        ``span_start_logits[start] + span_end_logits[end]``.

        Parameters
        ----------
        span_start_logits : ``torch.Tensor``
            Shape ``(batch_size, passage_length)``.
        span_end_logits : ``torch.Tensor``
            Indexed the same way as ``span_start_logits``.

        Returns
        -------
        A long tensor of shape ``(batch_size, 2)`` holding ``[start, end]`` token indices.

        Raises
        ------
        ValueError
            If either input is not two-dimensional.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Keep track of the best start position seen at or before ``j``; pairing it
                # with end position ``j`` guarantees start <= end.
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                # Strict comparison: on ties the earliest span found is kept.
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
 def test_does_not_divide_by_zero_with_no_count(self):
     """Check that a ``CategoricalAccuracy`` that was never called reports approximately 0.0."""
     # No predictions are fed to the metric before reading it; per the test name this guards
     # against a division by zero when nothing has been counted.
     accuracy = CategoricalAccuracy()
     self.assertAlmostEqual(accuracy.get_metric(), 0.0)