Example #1
def __init__(self,
             vocab: Vocabulary,
             question_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             entity_encoder: Seq2VecEncoder,
             mixture_feedforward: FeedForward,
             decoder_beam_search: BeamSearch,
             max_decoding_steps: int,
             attention_function: SimilarityFunction,
             use_neighbor_similarity_for_linking: bool = False,
             dropout: float = 0.0,
             num_linking_features: int = 10,
             rule_namespace: str = 'rule_labels',
             tables_directory: str = '/wikitables/') -> None:
    use_similarity = use_neighbor_similarity_for_linking
    super(WikiTablesMmlSemanticParser, self).__init__(vocab=vocab,
                                                      question_embedder=question_embedder,
                                                      action_embedding_dim=action_embedding_dim,
                                                      encoder=encoder,
                                                      entity_encoder=entity_encoder,
                                                      mixture_feedforward=mixture_feedforward,
                                                      max_decoding_steps=max_decoding_steps,
                                                      attention_function=attention_function,
                                                      use_neighbor_similarity_for_linking=use_similarity,
                                                      dropout=dropout,
                                                      num_linking_features=num_linking_features,
                                                      rule_namespace=rule_namespace,
                                                      tables_directory=tables_directory)
    self._beam_search = decoder_beam_search
    self._decoder_trainer = MaximumMarginalLikelihood()
class TestMaximumMarginalLikelihood(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.initial_state = SimpleDecoderState(
            [0, 1], [[], []],
            [Variable(torch.Tensor([0.0])),
             Variable(torch.Tensor([0.0]))], [0, 1])
        self.decoder_step = SimpleDecoderStep()
        self.targets = torch.autograd.Variable(
            torch.Tensor([[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                          [[3, 4, 0], [2, 3, 4], [0, 0, 0]]]))
        self.target_mask = torch.autograd.Variable(
            torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                          [[1, 1, 0], [1, 1, 1], [0, 0, 0]]]))

        self.supervision = (self.targets, self.target_mask)
        # High beam size ensures exhaustive search.
        self.trainer = MaximumMarginalLikelihood()

    def test_decode(self):
        decoded_info = self.trainer.decode(self.initial_state,
                                           self.decoder_step, self.supervision)

        # Our loss is the negative log sum of the scores from each target sequence.  The score for
        # each sequence in our simple transition system is just `-sequence_length`.
        instance0_loss = math.log(math.exp(-3) *
                                  3)  # all three sequences have length 3
        instance1_loss = math.log(
            math.exp(-2) + math.exp(-3))  # one has length 2, one has length 3
        expected_loss = -(instance0_loss + instance1_loss) / 2
        assert_almost_equal(decoded_info['loss'].data.numpy(), expected_loss)

    def test_create_allowed_transitions(self):
        result = self.trainer._create_allowed_transitions(
            self.targets, self.target_mask)
        # There were two instances in this batch.
        assert len(result) == 2

        # The first instance had six valid action sequence prefixes.
        assert len(result[0]) == 6
        assert result[0][()] == {1, 2}
        assert result[0][(1, )] == {2, 3}
        assert result[0][(1, 2)] == {4}
        assert result[0][(1, 3)] == {4}
        assert result[0][(2, )] == {3}
        assert result[0][(2, 3)] == {4}

        # The second instance had four valid action sequence prefixes.
        assert len(result[1]) == 4
        assert result[1][()] == {2, 3}
        assert result[1][(2, )] == {3}
        assert result[1][(2, 3)] == {4}
        assert result[1][(3, )] == {4}

    def test_get_allowed_actions(self):
        state = DecoderState([0, 1, 0], [[1], [0], []], [])
        allowed_transitions = [{(1, ): {2}, (): {3}}, {(0, ): {4, 5}}]
        allowed_actions = self.trainer._get_allowed_actions(
            state, allowed_transitions)
        assert allowed_actions == [{2}, {4, 5}, {3}]
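The prefix map checked by ``test_create_allowed_transitions`` is easy to reproduce by hand. The sketch below is an illustrative reimplementation that operates on plain nested lists rather than tensors; it is not AllenNLP's actual ``_create_allowed_transitions``, but it yields exactly the dictionaries asserted above: every valid (unpadded) prefix of a target sequence maps to the set of actions allowed to follow it.

from collections import defaultdict
from typing import Dict, List, Set, Tuple

def create_allowed_transitions(targets: List[List[List[int]]],
                               target_mask: List[List[List[int]]]) -> List[Dict[Tuple[int, ...], Set[int]]]:
    allowed_transitions = []
    for instance_targets, instance_mask in zip(targets, target_mask):
        prefix_map: Dict[Tuple[int, ...], Set[int]] = defaultdict(set)
        for sequence, mask in zip(instance_targets, instance_mask):
            # Keep only the unpadded portion of each candidate sequence.
            actions = [action for action, keep in zip(sequence, mask) if keep]
            for i, action in enumerate(actions):
                prefix_map[tuple(actions[:i])].add(action)
        allowed_transitions.append(dict(prefix_map))
    return allowed_transitions

result = create_allowed_transitions([[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                                     [[3, 4, 0], [2, 3, 4], [0, 0, 0]]],
                                    [[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                                     [[1, 1, 0], [1, 1, 1], [0, 0, 0]]])
assert result[0][(1,)] == {2, 3} and result[1][()] == {2, 3}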
Example #3
def __init__(self,
             vocab: Vocabulary,
             sentence_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             attention: Attention,
             decoder_beam_search: BeamSearch,
             max_decoding_steps: int,
             dropout: float = 0.0) -> None:
    super(NlvrDirectSemanticParser,
          self).__init__(vocab=vocab,
                         sentence_embedder=sentence_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         dropout=dropout)
    self._decoder_trainer = MaximumMarginalLikelihood()
    self._decoder_step = BasicTransitionFunction(
        encoder_output_dim=self._encoder.get_output_dim(),
        action_embedding_dim=action_embedding_dim,
        input_attention=attention,
        num_start_types=1,
        activation=Activation.by_name('tanh')(),
        predict_start_type_separately=False,
        add_action_bias=False,
        dropout=dropout)
    self._decoder_beam_search = decoder_beam_search
    self._max_decoding_steps = max_decoding_steps
    self._action_padding_index = -1
class TestMaximumMarginalLikelihood(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.initial_state = SimpleDecoderState([0, 1],
                                                [[], []],
                                                [torch.Tensor([0.0]), torch.Tensor([0.0])],
                                                [0, 1])
        self.decoder_step = SimpleDecoderStep()
        self.targets = torch.Tensor([[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                                     [[3, 4, 0], [2, 3, 4], [0, 0, 0]]])
        self.target_mask = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                                         [[1, 1, 0], [1, 1, 1], [0, 0, 0]]])

        self.supervision = (self.targets, self.target_mask)
        # High beam size ensures exhaustive search.
        self.trainer = MaximumMarginalLikelihood()

    def test_decode(self):
        decoded_info = self.trainer.decode(self.initial_state, self.decoder_step, self.supervision)

        # Our loss is the negative log sum of the scores from each target sequence.  The score for
        # each sequence in our simple transition system is just `-sequence_length`.
        instance0_loss = math.log(math.exp(-3) * 3)  # all three sequences have length 3
        instance1_loss = math.log(math.exp(-2) + math.exp(-3))  # one has length 2, one has length 3
        expected_loss = -(instance0_loss + instance1_loss) / 2
        assert_almost_equal(decoded_info['loss'].data.numpy(), expected_loss)
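To make the arithmetic in those comments concrete, the expected value can be checked independently with ``torch.logsumexp``; this is a minimal sketch, not part of the test file:

import torch

# Per-sequence scores in the toy transition system are just -sequence_length.
instance0_scores = torch.tensor([-3.0, -3.0, -3.0])  # three length-3 sequences
instance1_scores = torch.tensor([-2.0, -3.0])        # lengths 2 and 3

# MML marginalizes over the target sequences in log space.
instance0_loss = torch.logsumexp(instance0_scores, dim=0)  # log(3 * exp(-3))
instance1_loss = torch.logsumexp(instance1_scores, dim=0)  # log(exp(-2) + exp(-3))
expected_loss = -(instance0_loss + instance1_loss) / 2
print(expected_loss.item())  # ~1.794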
Example #6
class TestMaximumMarginalLikelihood(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.initial_state = SimpleDecoderState(
            [0, 1], [[], []],
            [Variable(torch.Tensor([0.0])),
             Variable(torch.Tensor([0.0]))], [0, 1])
        self.decoder_step = SimpleDecoderStep()
        self.targets = torch.autograd.Variable(
            torch.Tensor([[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                          [[3, 4, 0], [2, 3, 4], [0, 0, 0]]]))
        self.target_mask = torch.autograd.Variable(
            torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                          [[1, 1, 0], [1, 1, 1], [0, 0, 0]]]))

        self.supervision = (self.targets, self.target_mask)
        # High beam size ensures exhaustive search.
        self.trainer = MaximumMarginalLikelihood()

    def test_decode(self):
        decoded_info = self.trainer.decode(self.initial_state,
                                           self.decoder_step, self.supervision)

        # Our loss is the negative log sum of the scores from each target sequence.  The score for
        # each sequence in our simple transition system is just `-sequence_length`.
        instance0_loss = math.log(math.exp(-3) *
                                  3)  # all three sequences have length 3
        instance1_loss = math.log(
            math.exp(-2) + math.exp(-3))  # one has length 2, one has length 3
        expected_loss = -(instance0_loss + instance1_loss) / 2
        assert_almost_equal(decoded_info['loss'].data.numpy(), expected_loss)
def __init__(self, vocab: Vocabulary, sentence_embedder: TextFieldEmbedder,
             action_embedding_dim: int, encoder: Seq2SeqEncoder,
             attention_function: SimilarityFunction,
             decoder_beam_search: BeamSearch,
             max_decoding_steps: int) -> None:
    super(NlvrDirectSemanticParser,
          self).__init__(vocab=vocab,
                         sentence_embedder=sentence_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder)
    self._decoder_trainer = MaximumMarginalLikelihood()
    self._decoder_step = NlvrDecoderStep(
        encoder_output_dim=self._encoder.get_output_dim(),
        action_embedding_dim=action_embedding_dim,
        attention_function=attention_function)
    self._decoder_beam_search = decoder_beam_search
    self._max_decoding_steps = max_decoding_steps
    self._action_padding_index = -1
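Note that the snippets on this page come from different AllenNLP and PyTorch releases, which accounts for the drift between otherwise identical code: the older examples wrap tensors in ``torch.autograd.Variable`` (a deprecated no-op since PyTorch 0.4) and pass an ``attention_function: SimilarityFunction``, while the newer ones build tensors directly, pass an ``attention: Attention`` module through an ``input_attention`` argument, and use the transition-function classes (``BasicTransitionFunction``, ``LinkingTransitionFunction``) that replaced the earlier decoder steps (``NlvrDecoderStep``, ``WikiTablesDecoderStep``).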
Example #8
def __init__(self,
             vocab: Vocabulary,
             question_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             entity_encoder: Seq2VecEncoder,
             decoder_beam_search: BeamSearch,
             max_decoding_steps: int,
             attention: Attention,
             mixture_feedforward: FeedForward = None,
             add_action_bias: bool = True,
             training_beam_size: int = None,
             use_neighbor_similarity_for_linking: bool = False,
             dropout: float = 0.0,
             num_linking_features: int = 10,
             rule_namespace: str = 'rule_labels',
             tables_directory: str = '/wikitables/') -> None:
    use_similarity = use_neighbor_similarity_for_linking
    super().__init__(vocab=vocab,
                     question_embedder=question_embedder,
                     action_embedding_dim=action_embedding_dim,
                     encoder=encoder,
                     entity_encoder=entity_encoder,
                     max_decoding_steps=max_decoding_steps,
                     add_action_bias=add_action_bias,
                     use_neighbor_similarity_for_linking=use_similarity,
                     dropout=dropout,
                     num_linking_features=num_linking_features,
                     rule_namespace=rule_namespace,
                     tables_directory=tables_directory)
    self._beam_search = decoder_beam_search
    self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
    self._decoder_step = LinkingTransitionFunction(
        encoder_output_dim=self._encoder.get_output_dim(),
        action_embedding_dim=action_embedding_dim,
        input_attention=attention,
        num_start_types=self._num_start_types,
        predict_start_type_separately=True,
        add_action_bias=self._add_action_bias,
        mixture_feedforward=mixture_feedforward,
        dropout=dropout)
def __init__(self,
             vocab,
             question_embedder,
             action_embedding_dim,
             encoder,
             entity_encoder,
             decoder_beam_search,
             max_decoding_steps,
             attention,
             mixture_feedforward=None,
             training_beam_size=None,
             use_neighbor_similarity_for_linking=False,
             dropout=0.0,
             num_linking_features=10,
             rule_namespace=u'rule_labels',
             tables_directory=u'/wikitables/'):
    use_similarity = use_neighbor_similarity_for_linking
    super(WikiTablesMmlSemanticParser, self).__init__(
        vocab=vocab,
        question_embedder=question_embedder,
        action_embedding_dim=action_embedding_dim,
        encoder=encoder,
        entity_encoder=entity_encoder,
        max_decoding_steps=max_decoding_steps,
        use_neighbor_similarity_for_linking=use_similarity,
        dropout=dropout,
        num_linking_features=num_linking_features,
        rule_namespace=rule_namespace,
        tables_directory=tables_directory)
    self._beam_search = decoder_beam_search
    self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
    self._decoder_step = WikiTablesDecoderStep(
        encoder_output_dim=self._encoder.get_output_dim(),
        action_embedding_dim=action_embedding_dim,
        input_attention=attention,
        num_start_types=self._num_start_types,
        num_entity_types=self._num_entity_types,
        mixture_feedforward=mixture_feedforward,
        dropout=dropout)
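Unlike the earlier examples, these constructors pass a ``training_beam_size`` to ``MaximumMarginalLikelihood``; as the Example #16 docstring explains, a non-``None`` value restricts the marginalization to the top-scoring action sequences found by a constrained beam search, while ``None`` sums over all provided sequences. The objective itself is a masked log-sum-exp over per-sequence log-probabilities. The function below is a sketch of that math, not the library's implementation:

import torch

def mml_loss(sequence_log_probs: torch.Tensor,
             sequence_mask: torch.Tensor) -> torch.Tensor:
    # sequence_log_probs: (batch_size, num_sequences), the total log-probability
    # of each candidate target sequence under the model.
    # sequence_mask: same shape; 0 marks padding rows such as the all-zero
    # sequence in the tests above.
    masked = sequence_log_probs.masked_fill(sequence_mask == 0, float('-inf'))
    log_marginal = torch.logsumexp(masked, dim=-1)  # marginalize per instance
    return -log_marginal.mean()                     # average over the batch

scores = torch.tensor([[-3.0, -3.0, -3.0],
                       [-2.0, -3.0,  0.0]])
mask = torch.tensor([[1, 1, 1],
                     [1, 1, 0]])
print(mml_loss(scores, mask).item())  # ~1.794, matching test_decode above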
Example #10
def __init__(self,
             vocab,
             sentence_embedder,
             action_embedding_dim,
             encoder,
             attention,
             decoder_beam_search,
             max_decoding_steps,
             dropout=0.0):
    super(NlvrDirectSemanticParser,
          self).__init__(vocab=vocab,
                         sentence_embedder=sentence_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         dropout=dropout)
    self._decoder_trainer = MaximumMarginalLikelihood()
    self._decoder_step = NlvrDecoderStep(
        encoder_output_dim=self._encoder.get_output_dim(),
        action_embedding_dim=action_embedding_dim,
        input_attention=attention,
        dropout=dropout)
    self._decoder_beam_search = decoder_beam_search
    self._max_decoding_steps = max_decoding_steps
    self._action_padding_index = -1
def __init__(self,
             vocab: Vocabulary,
             sentence_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             attention: Attention,
             decoder_beam_search: BeamSearch,
             max_decoding_steps: int,
             dropout: float = 0.0) -> None:
    super(NlvrDirectSemanticParser, self).__init__(vocab=vocab,
                                                   sentence_embedder=sentence_embedder,
                                                   action_embedding_dim=action_embedding_dim,
                                                   encoder=encoder,
                                                   dropout=dropout)
    self._decoder_trainer = MaximumMarginalLikelihood()
    self._decoder_step = NlvrDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                         action_embedding_dim=action_embedding_dim,
                                         input_attention=attention,
                                         dropout=dropout)
    self._decoder_beam_search = decoder_beam_search
    self._max_decoding_steps = max_decoding_steps
    self._action_padding_index = -1
Example #13
class WikiTablesMmlSemanticParser(WikiTablesSemanticParser):
    """
    A ``WikiTablesMmlSemanticParser`` is a :class:`WikiTablesSemanticParser` which is trained to
    maximize the marginal likelihood of an approximate set of logical forms which give the correct
    denotation. This is a re-implementation of the model used for the paper `Neural Semantic Parsing with Type
    Constraints for Semi-Structured Tables
    <https://www.semanticscholar.org/paper/Neural-Semantic-Parsing-with-Type-Constraints-for-Krishnamurthy-Dasigi/8c6f58ed0ebf379858c0bbe02c53ee51b3eb398a>`_,
    by Jayant Krishnamurthy, Pradeep Dasigi, and Matt Gardner (EMNLP 2017).

    WORK STILL IN PROGRESS.  We'll iteratively improve it until we've reproduced the performance of
    the original parser.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity. Passed to super class.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training. Passed to super class.
    attention_function : ``SimilarityFunction``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  This is the similarity function we use for that
        attention. Passed to super class.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to super
        class.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables. Passed to super class.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 mixture_feedforward: FeedForward,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention_function: SimilarityFunction,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super(WikiTablesMmlSemanticParser, self).__init__(vocab=vocab,
                                                          question_embedder=question_embedder,
                                                          action_embedding_dim=action_embedding_dim,
                                                          encoder=encoder,
                                                          entity_encoder=entity_encoder,
                                                          mixture_feedforward=mixture_feedforward,
                                                          max_decoding_steps=max_decoding_steps,
                                                          attention_function=attention_function,
                                                          use_neighbor_similarity_for_linking=use_similarity,
                                                          dropout=dropout,
                                                          num_linking_features=num_linking_features,
                                                          rule_namespace=rule_namespace,
                                                          tables_directory=tables_directory)
        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood()

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRuleArray]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
        how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default=None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        """
        initial_info = self._get_initial_state_and_scores(question, table, world, actions)
        initial_state = initial_info["initial_state"]
        linking_scores = initial_info["linking_scores"]
        feature_scores = initial_info["feature_scores"]
        similarity_scores = initial_info["similarity_scores"]
        batch_size = list(question.values())[0].size(0)
        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            # TODO(pradeep): Most of the functionality in this block can be moved to the super
            # class.
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['linking_scores'] = linking_scores
            if feature_scores is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = similarity_scores
            outputs['logical_form'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    if example_lisp_string:
                        self._denotation_accuracy(logical_form, example_lisp_string[i])
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                    outputs['entities'].append(world[i].table_graph.entities)
                else:
                    outputs['logical_form'].append('')
                    self._has_logical_form(0.0)
                    if example_lisp_string:
                        self._denotation_accuracy(None, example_lisp_string[i])
            return outputs

    @classmethod
    def from_params(cls, vocab, params: Params) -> 'WikiTablesMmlSemanticParser':
        question_embedder = TextFieldEmbedder.from_params(vocab, params.pop("question_embedder"))
        action_embedding_dim = params.pop_int("action_embedding_dim")
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        entity_encoder = Seq2VecEncoder.from_params(params.pop('entity_encoder'))
        max_decoding_steps = params.pop_int("max_decoding_steps")
        mixture_feedforward_type = params.pop('mixture_feedforward', None)
        if mixture_feedforward_type is not None:
            mixture_feedforward = FeedForward.from_params(mixture_feedforward_type)
        else:
            mixture_feedforward = None
        decoder_beam_search = BeamSearch.from_params(params.pop("decoder_beam_search"))
        # If no attention function is specified, we should not use attention, not attention with
        # default similarity function.
        attention_function_type = params.pop("attention_function", None)
        if attention_function_type is not None:
            attention_function = SimilarityFunction.from_params(attention_function_type)
        else:
            attention_function = None
        use_neighbor_similarity_for_linking = params.pop_bool('use_neighbor_similarity_for_linking', False)
        dropout = params.pop_float('dropout', 0.0)
        num_linking_features = params.pop_int('num_linking_features', 10)
        tables_directory = params.pop('tables_directory', '/wikitables/')
        rule_namespace = params.pop('rule_namespace', 'rule_labels')
        params.assert_empty(cls.__name__)
        return cls(vocab,
                   question_embedder=question_embedder,
                   action_embedding_dim=action_embedding_dim,
                   encoder=encoder,
                   entity_encoder=entity_encoder,
                   mixture_feedforward=mixture_feedforward,
                   decoder_beam_search=decoder_beam_search,
                   max_decoding_steps=max_decoding_steps,
                   attention_function=attention_function,
                   use_neighbor_similarity_for_linking=use_neighbor_similarity_for_linking,
                   dropout=dropout,
                   num_linking_features=num_linking_features,
                   tables_directory=tables_directory,
                   rule_namespace=rule_namespace)
class NlvrDirectSemanticParser(NlvrSemanticParser):
    """
    ``NlvrDirectSemanticParser`` is an ``NlvrSemanticParser`` that gets around the lack of
    logical form annotations by maximizing the marginal likelihood of an approximate set of target
    sequences that yield the correct denotation. The main difference between this parser and
    ``NlvrCoverageSemanticParser`` is that while this parser takes the output of an offline search
    process as the set of target sequences for training, the latter performs search during training.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention_function : ``SimilarityFunction``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  This is the similarity function we use for that
        attention.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        Maximum number of steps for beam search after training.
    """
    def __init__(self, vocab: Vocabulary, sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int, encoder: Seq2SeqEncoder,
                 attention_function: SimilarityFunction,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int) -> None:
        super(NlvrDirectSemanticParser,
              self).__init__(vocab=vocab,
                             sentence_embedder=sentence_embedder,
                             action_embedding_dim=action_embedding_dim,
                             encoder=encoder)
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._decoder_step = NlvrDecoderStep(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            attention_function=attention_function)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        self._action_padding_index = -1

    @overrides
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRuleArray]],
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            util.new_variable_with_data(
                list(sentence.values())[0], torch.Tensor([0.0]))
            for i in range(batch_size)
        ]
        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]
        worlds_list = [worlds[i] for i in range(batch_size)]

        initial_state = NlvrDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            action_indices=action_indices,
            possible_actions=actions,
            worlds=worlds_list,
            label_strings=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
        best_final_states = self._decoder_beam_search.search(
            self._max_decoding_steps,
            initial_state,
            self._decoder_step,
            keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = [
                    best_final_states[i][0].action_history[0]
                ]
                best_action_sequences[i] = best_action_indices
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs

    def _update_metrics(self, action_strings: List[List[List[str]]],
                        worlds: List[List[NlvrWorld]],
                        label_strings: List[List[str]]) -> None:
        # TODO(pradeep): Move this to the base class.
        # TODO(pradeep): Using only the best decoded sequence. Define metrics for top-k sequences?
        batch_size = len(worlds)
        for i in range(batch_size):
            instance_action_strings = action_strings[i]
            sequence_is_correct = [False]
            if instance_action_strings:
                instance_label_strings = label_strings[i]
                instance_worlds = worlds[i]
                # Taking only the best sequence.
                sequence_is_correct = self._check_denotation(
                    instance_action_strings[0], instance_label_strings,
                    instance_worlds)
            for correct_in_world in sequence_is_correct:
                self._denotation_accuracy(1 if correct_in_world else 0)
            self._consistency(1 if all(sequence_is_correct) else 0)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'denotation_accuracy': self._denotation_accuracy.get_metric(reset),
            'consistency': self._consistency.get_metric(reset)
        }

    @classmethod
    def from_params(cls, vocab, params: Params) -> 'NlvrDirectSemanticParser':
        sentence_embedder_params = params.pop("sentence_embedder")
        sentence_embedder = TextFieldEmbedder.from_params(
            vocab, sentence_embedder_params)
        action_embedding_dim = params.pop_int('action_embedding_dim')
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        attention_function_type = params.pop("attention_function", None)
        if attention_function_type is not None:
            attention_function = SimilarityFunction.from_params(
                attention_function_type)
        else:
            attention_function = None
        decoder_beam_search = BeamSearch.from_params(
            params.pop("decoder_beam_search"))
        max_decoding_steps = params.pop_int("max_decoding_steps")
        params.assert_empty(cls.__name__)
        return cls(vocab,
                   sentence_embedder=sentence_embedder,
                   action_embedding_dim=action_embedding_dim,
                   encoder=encoder,
                   attention_function=attention_function,
                   decoder_beam_search=decoder_beam_search,
                   max_decoding_steps=max_decoding_steps)
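The two ``NlvrDirectSemanticParser`` versions reproduced in this example differ in the same way as the constructor snippets above: the second postdates the first, taking an ``Attention`` module (passed to ``NlvrDecoderStep`` as ``input_attention``) instead of a ``SimilarityFunction``, adding ``dropout`` and an optional ``identifier`` input to ``forward``, and omitting the explicit ``from_params`` method (later AllenNLP versions construct models from configuration automatically).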
class NlvrDirectSemanticParser(NlvrSemanticParser):
    """
    ``NlvrDirectSemanticParser`` is an ``NlvrSemanticParser`` that gets around the lack of
    logical form annotations by maximizing the marginal likelihood of an approximate set of target
    sequences that yield the correct denotation. The main difference between this parser and
    ``NlvrCoverageSemanticParser`` is that while this parser takes the output of an offline search
    process as the set of target sequences for training, the latter performs search during training.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the DecoderStep.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        Maximum number of steps for beam search after training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted actions.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 dropout: float = 0.0) -> None:
        super(NlvrDirectSemanticParser, self).__init__(vocab=vocab,
                                                       sentence_embedder=sentence_embedder,
                                                       action_embedding_dim=action_embedding_dim,
                                                       encoder=encoder,
                                                       dropout=dropout)
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._decoder_step = NlvrDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                             action_embedding_dim=action_embedding_dim,
                                             input_attention=attention,
                                             dropout=dropout)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        self._action_padding_index = -1

    @overrides
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRuleArray]],
                identifier: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                              for i in range(batch_size)]
        label_strings = self._get_label_strings(labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in
                                 range(batch_size)]
        worlds_list = [worlds[i] for i in range(batch_size)]

        initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                         action_history=[[] for _ in range(batch_size)],
                                         score=initial_score_list,
                                         rnn_state=initial_rnn_state,
                                         grammar_state=initial_grammar_state,
                                         action_embeddings=action_embeddings,
                                         action_indices=action_indices,
                                         possible_actions=actions,
                                         worlds=worlds_list,
                                         label_strings=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if identifier is not None:
            outputs["identifier"] = identifier
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
        best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                             initial_state,
                                                             self._decoder_step,
                                                             keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = [best_final_states[i][0].action_history[0]]
                best_action_sequences[i] = best_action_indices
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs

    def _update_metrics(self,
                        action_strings: List[List[List[str]]],
                        worlds: List[List[NlvrWorld]],
                        label_strings: List[List[str]]) -> None:
        # TODO(pradeep): Move this to the base class.
        # TODO(pradeep): Using only the best decoded sequence. Define metrics for top-k sequences?
        batch_size = len(worlds)
        for i in range(batch_size):
            instance_action_strings = action_strings[i]
            sequence_is_correct = [False]
            if instance_action_strings:
                instance_label_strings = label_strings[i]
                instance_worlds = worlds[i]
                # Taking only the best sequence.
                sequence_is_correct = self._check_denotation(instance_action_strings[0],
                                                             instance_label_strings,
                                                             instance_worlds)
            for correct_in_world in sequence_is_correct:
                self._denotation_accuracy(1 if correct_in_world else 0)
            self._consistency(1 if all(sequence_is_correct) else 0)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
                'denotation_accuracy': self._denotation_accuracy.get_metric(reset),
                'consistency': self._consistency.get_metric(reset)
        }
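The two metrics returned by ``get_metrics`` interact as follows: ``_update_metrics`` counts denotation accuracy once per world, but counts an instance as consistent only when its best sequence is correct in every one of its worlds. A standalone illustration with hypothetical per-world results:

# Each inner list: per-world correctness of the best decoded sequence.
sequence_is_correct_per_instance = [
    [True, True, True, True],   # consistent: correct in all four worlds
    [True, False, True, True],  # inconsistent: wrong in one world
    [False],                    # nothing decoded: the default [False]
]

denotation_hits = sum(sum(worlds) for worlds in sequence_is_correct_per_instance)
denotation_total = sum(len(worlds) for worlds in sequence_is_correct_per_instance)
consistency_hits = sum(all(worlds) for worlds in sequence_is_correct_per_instance)

print(denotation_hits / denotation_total)  # 7/9, about 0.78
print(consistency_hits / len(sequence_is_correct_per_instance))  # 1/3, about 0.33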
Example #16
class WikiTablesMmlSemanticParser(WikiTablesSemanticParser):
    """
    A ``WikiTablesMmlSemanticParser`` is a :class:`WikiTablesSemanticParser` which is trained to
    maximize the marginal likelihood of an approximate set of logical forms which give the correct
    denotation. This is a re-implementation of the model used for the paper `Neural Semantic Parsing with Type
    Constraints for Semi-Structured Tables
    <https://www.semanticscholar.org/paper/Neural-Semantic-Parsing-with-Type-Constraints-for-Krishnamurthy-Dasigi/8c6f58ed0ebf379858c0bbe02c53ee51b3eb398a>`_,
    by Jayant Krishnamurthy, Pradeep Dasigi, and Matt Gardner (EMNLP 2017).

    WORK STILL IN PROGRESS.  We'll iteratively improve it until we've reproduced the performance of
    the original parser.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity. Passed to super class.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training. Passed to super class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    mixture_feedforward : ``FeedForward``, optional (default=None)
        If given, we'll use this to compute a mixture probability between global actions and linked
        actions given the hidden state at every timestep of decoding, instead of concatenating the
        logits for both (where the logits may not be compatible with each other).  Passed to
        the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.  Passed to super class.
    training_beam_size : ``int``, optional (default=None)
        If given, we will use a constrained beam search of this size during training, so that we
        use only the top ``training_beam_size`` action sequences according to the model in the MML
        computation.  If this is ``None``, we will use all of the provided action sequences in the
        MML computation.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to super
        class.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables. Passed to super class.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 training_beam_size: int = None,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super().__init__(vocab=vocab,
                         question_embedder=question_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         entity_encoder=entity_encoder,
                         max_decoding_steps=max_decoding_steps,
                         add_action_bias=add_action_bias,
                         use_neighbor_similarity_for_linking=use_similarity,
                         dropout=dropout,
                         num_linking_features=num_linking_features,
                         rule_namespace=rule_namespace,
                         tables_directory=tables_directory)
        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._decoder_step = LinkingTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            num_start_types=self._num_start_types,
            predict_start_type_separately=True,
            add_action_bias=self._add_action_bias,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout)

    @overrides
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRuleArray]],
            example_lisp_string: List[str] = None,
            target_action_sequences: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default = None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : ``torch.Tensor``, optional (default=None)
            A list of possibly valid action sequences, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
            sequence_length)``.
        metadata : ``List[Dict[str, Any]]``, optional (default=None)
            Metadata containing the original tokenized question within a 'question_tokens' key.
        """
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(
            question, table, world, actions, outputs)
        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            possible_actions=actions,
            extras=example_lisp_string,
            debug_info=None)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            return self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
        else:
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(
                    initial_state, self._decoder_step,
                    (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][
                        0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(
                            best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)

            self._compute_validation_outputs(actions, best_final_states, world,
                                             example_lisp_string, metadata,
                                             outputs)

            return outputs
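
# A minimal sketch (not AllenNLP's implementation) of the objective that
# ``MaximumMarginalLikelihood`` optimizes in all of the parsers on this page:
# the loss marginalizes over every candidate target sequence, so it reduces to
# a masked logsumexp of per-sequence log scores.  The shapes here are
# assumptions for illustration only.
import torch

def marginal_likelihood_loss(sequence_log_scores: torch.Tensor,
                             sequence_mask: torch.Tensor) -> torch.Tensor:
    # sequence_log_scores: (batch_size, num_sequences), the summed log
    # probability of each candidate target action sequence.
    # sequence_mask: (batch_size, num_sequences), 1 for real candidate
    # sequences and 0 for padding.
    masked_scores = sequence_log_scores.masked_fill(sequence_mask == 0, float('-inf'))
    # log p(correct denotation) = log sum_k exp(log p(sequence_k))
    log_marginal = torch.logsumexp(masked_scores, dim=-1)
    return -log_marginal.mean()

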
class WikiTablesSemanticParser(Model):
    """
    A ``WikiTablesSemanticParser`` is a :class:`Model` which takes as input a table and a question,
    and produces a logical form that answers the question when executed over the table.  The
    logical form is generated by a `type-constrained`, `transition-based` parser.  This is a
    re-implementation of the model used for the paper `Neural Semantic Parsing with Type
    Constraints for Semi-Structured Tables
    <https://www.semanticscholar.org/paper/Neural-Semantic-Parsing-with-Type-Constraints-for-Krishnamurthy-Dasigi/8c6f58ed0ebf379858c0bbe02c53ee51b3eb398a>`_,
    by Jayant Krishnamurthy, Pradeep Dasigi, and Matt Gardner (EMNLP 2017).

    WORK STILL IN PROGRESS.  We'll iteratively improve it until we've reproduced the performance of
    the original parser.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity.
    mixture_feedforward : ``FeedForward``
        If given, we'll use this to compute a mixture probability between global actions and
        linked actions given the hidden state at every timestep of decoding.  Passed to the
        decoder step.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    attention_function : ``SimilarityFunction``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  This is the similarity function we use for that
        attention.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=8)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 8 here matches the default in the ``KnowledgeGraphField``,
        which is to use all eight defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    table_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 mixture_feedforward: FeedForward,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention_function: SimilarityFunction,
                 dropout: float = 0.0,
                 num_linking_features: int = 8,
                 rule_namespace: str = 'rule_labels',
                 table_directory: str = '/wikitables/') -> None:
        super(WikiTablesSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = WikiTablesAccuracy(table_directory)
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        self._action_embedder = Embedding(num_embeddings=vocab.get_vocab_size(self._rule_namespace),
                                          embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal(self._first_action_embedding)
        torch.nn.init.normal(self._first_attended_question)

        check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(),
                               "entity word average embedding dim", "question embedding dim")

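        # The four entity types are numbers, cells, parts, and rows (see _get_type_vector).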
        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._type_params = torch.nn.Linear(self._num_entity_types, self._embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)
        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None
            self._question_entity_params = torch.nn.Linear(1, 1)
            self._question_neighbor_params = torch.nn.Linear(1, 1)

        self._decoder_trainer = MaximumMarginalLikelihood()

        self._decoder_step = WikiTablesDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                                   action_embedding_dim=action_embedding_dim,
                                                   attention_function=attention_function,
                                                   num_start_types=self._num_start_types,
                                                   num_entity_types=self._num_entity_types,
                                                   mixture_feedforward=mixture_feedforward,
                                                   dropout=dropout)

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRuleArray]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default=None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : ``torch.Tensor``, optional (default=None)
            A list of possibly valid action sequences, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
            sequence_length)``.
        """

        table_text = table['text']

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.nn.functional.tanh(entity_type_embeddings + projected_neighbor_embeddings)

        # Compute entity and question word cosine similarity.  We add a small value to
        # the norms, since padding values would otherwise cause a divide by zero.
        embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
        embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
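        # (batch_size, num_entities * num_entity_tokens, num_question_tokens)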
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = question_entity_similarity_max_score + feature_scores
        else:
            # The linking score is computed as a linear projection of two terms. The first is the maximum
            # similarity score over the entity's words and the question token. The second is the maximum
            # similarity over the words in the entity's neighbors and the question token.
            #   The second term, projected_question_neighbor_similarity, is useful when
            # a column needs to be selected. For example, the question token might have no similarity
            # with the column name, but is similar with the cells in the column.
            #   Note that projected_question_neighbor_similarity is intended to capture the same information
            # as the related_column feature.
            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
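        # Each question token is now represented by its word embedding concatenated with a
        # soft entity embedding, weighted by the linking probabilities computed above.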
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = Variable(encoder_outputs.data.new(batch_size, self._encoder.get_output_dim()).fill_(0))

        initial_score = Variable(embedded_question.data.new(batch_size).fill_(0))

        action_embeddings, action_indices = self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnState(final_encoder_output[i],
                                              memory_cell[i],
                                              self._first_action_embedding,
                                              self._first_attended_question,
                                              encoder_output_list,
                                              question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               debug_info=None)
        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['linking_scores'] = linking_scores
            if self._linking_params is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    if example_lisp_string:
                        self._denotation_accuracy(logical_form, example_lisp_string[i])
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                    outputs['entities'].append(world[i].table_graph.entities)
                else:
                    outputs['logical_form'].append('')
                    self._has_logical_form(0.0)
                    if example_lisp_string:
                        self._denotation_accuracy(None, example_lisp_string[i])
            return outputs

    @staticmethod
    def _get_neighbor_indices(worlds: List[WikiTablesWorld],
                              num_entities: int,
                              tensor: Variable) -> torch.LongTensor:
        """
        This method returns the indices of each entity's neighbors. A tensor
        is accepted as a parameter for copying purposes.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``Variable``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
        with -1 instead of 0, since 0 is a valid neighbor index.
        """

        num_neighbors = 0
        for world in worlds:
            for entity in world.table_graph.entities:
                if len(world.table_graph.neighbors[entity]) > num_neighbors:
                    num_neighbors = len(world.table_graph.neighbors[entity])

        batch_neighbors = []
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.table_graph.entities
            entity2index = {entity: i for i, entity in enumerate(entities)}
            entity2neighbors = world.table_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [entity2index[n] for n in entity2neighbors[entity]]
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(neighbor_indexes,
                                                      num_entities,
                                                      lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        return Variable(tensor.data.new(batch_neighbors)).long()

    @staticmethod
    def _get_type_vector(worlds: List[WikiTablesWorld],
                         num_entities: int,
                         tensor: Variable) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces a one-hot encoding for each entity's type.  In addition, a map from flattened
        entity index to type id is returned, so that both representations can be built in a
        single pass.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                one_hot_vectors = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    entity_type = 0
                types.append(one_hot_vectors[entity_type])

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: [0, 0, 0, 0])
            batch_types.append(padded)
        return Variable(tensor.data.new(batch_types)), entity_types

    def _get_linking_probabilities(self,
                                   worlds: List[WikiTablesWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = Variable(linking_scores.data.new(entity_indices)).long()
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = Variable(linking_scores.data.new(num_question_tokens,
                                                         num_entities - num_entities_in_instance).fill_(0))
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
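        # The min over the sequence dimension is 1 only if every position matches; the max
        # over the target dimension is 1 if any full target sequence matches.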
        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0])

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track three metrics here:

            1. dpd_acc, which is the percentage of the time that our best output action sequence is
            in the set of action sequences provided by DPD.  This is an easy-to-compute lower bound
            on denotation accuracy for the set of examples where we actually have DPD output.  We
            only score dpd_acc on that subset.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that has DPD output (make sure
            you pass "keep_if_no_dpd=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        return {
                'dpd_acc': self._action_sequence_accuracy.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'lf_percent': self._has_logical_form.get_metric(reset),
                }

    @staticmethod
    def _create_grammar_state(world: WikiTablesWorld,
                              possible_actions: List[ProductionRuleArray]) -> GrammarState:
        valid_actions = world.get_valid_actions()
        action_mapping = {}
        for i, action in enumerate(possible_actions):
            action_string = action[0]
            action_mapping[action_string] = i
        translated_valid_actions = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = [action_mapping[action_string]
                                             for action_string in action_strings]
        return GrammarState([START_SYMBOL],
                            {},
                            translated_valid_actions,
                            action_mapping,
                            type_declaration.is_nonterminal)

    def _embed_actions(self, actions: List[List[ProductionRuleArray]]) -> Tuple[torch.Tensor,
                                                                                Dict[Tuple[int, int], int]]:
        """
        Given all of the possible actions for all batch instances, produce an embedding for them.
        There will be significant overlap in this list, as the production rules from the grammar
        are shared across all batch instances.  Our returned tensor has an embedding for each
        `unique` action, so we also need to return a mapping from the original ``(batch_index,
        action_index)`` to our new ``global_action_index``, so that we can get the right action
        embedding during decoding.

        Returns
        -------
        action_embeddings : ``torch.Tensor``
            Has shape ``(num_unique_actions, action_embedding_dim)``.
        action_map : ``Dict[Tuple[int, int], int]``
            Maps ``(batch_index, action_index)`` in the input action list to ``action_index`` in
            the ``action_embeddings`` tensor.  All non-embeddable actions get mapped to `-1` here.
        """
        # TODO(mattg): This whole action pipeline might be a whole lot more complicated than it
        # needs to be.  We used to embed actions differently (using some crazy ideas about
        # embedding the LHS and RHS separately); we could probably get away with simplifying things
        # further now that we're just doing a simple embedding for global actions.  But I'm leaving
        # it like this for now to have a minimal change to go from the LHS/RHS embedding to a
        # single action embedding.
        embedded_actions = self._action_embedder.weight

        # Now we just need to make a map from `(batch_index, action_index)` to
        # `global_action_index`.  global_action_ids has the list of all unique actions; here we're
        # going over all of the actions for each batch instance so we can map them to the global
        # action ids.
        action_vocab = self.vocab.get_token_to_index_vocabulary(self._rule_namespace)
        action_map: Dict[Tuple[int, int], int] = {}
        for batch_index, instance_actions in enumerate(actions):
            for action_index, action in enumerate(instance_actions):
                if not action[0]:
                    # This rule is padding.
                    continue
                global_action_id = action_vocab.get(action[0], -1)
                action_map[(batch_index, action_index)] = global_action_id
        return embedded_actions, action_map

    @staticmethod
    def _map_entity_productions(linking_scores: torch.FloatTensor,
                                worlds: List[WikiTablesWorld],
                                actions: List[List[ProductionRuleArray]]) -> Tuple[torch.Tensor,
                                                                                   Dict[Tuple[int, int], int]]:
        """
        Constructs a map from ``(batch_index, action_index)`` to a flattened entity index,
        ``(batch_index * num_entities) + entity_index``.
        That is, some actions correspond to terminal productions of entities from our table.  We
        need to find those actions and map them to their corresponding entity indices, where the
        entity index is its position in the list of entities returned by the ``world``.  This list
        is what defines the second dimension of the ``linking_scores`` tensor, so we can use this
        index to look up linking scores for each action in that tensor.

        For easier processing later, the mapping that we return is `flattened` - we really want to
        map ``(batch_index, action_index)`` to ``(batch_index, entity_index)``, but we are going to
        have to use the result of this mapping to do ``index_selects`` on the ``linking_scores``
        tensor.  You can't do ``index_select`` with tuples, so we flatten ``linking_scores`` to
        have shape ``(batch_size * num_entities, num_question_tokens)``, and return shifted indices
        into this flattened tensor.

        Parameters
        ----------
        linking_scores : ``torch.Tensor``
            A tensor representing linking scores between each table entity and each question token.
            Has shape ``(batch_size, num_entities, num_question_tokens)``.
        worlds : ``List[WikiTablesWorld]``
            The ``World`` for each batch instance.  The ``World`` contains a reference to the
            ``TableKnowledgeGraph`` that defines the set of entities in the linking.
        actions : ``List[List[ProductionRuleArray]]``
            The list of possible actions for each batch instance.  Our action indices are defined
            in terms of this list, so we'll find entity productions in this list and map them to
            entity indices from the entity list we get from the ``World``.

        Returns
        -------
        flattened_linking_scores : ``torch.Tensor``
            A flattened version of ``linking_scores``, with shape ``(batch_size * num_entities,
            num_question_tokens)``.
        actions_to_entities : ``Dict[Tuple[int, int], int]``
            A mapping from ``(batch_index, action_index)`` to ``(batch_index * num_entities) +
            entity_index``, representing which action indices correspond to which rows in the
            returned ``flattened_linking_scores`` tensor.
        """
        batch_size, num_entities, num_question_tokens = linking_scores.size()
        entity_map: Dict[Tuple[int, str], int] = {}
        for batch_index, world in enumerate(worlds):
            for entity_index, entity in enumerate(world.table_graph.entities):
                entity_map[(batch_index, entity)] = batch_index * num_entities + entity_index
        actions_to_entities: Dict[Tuple[int, int], int] = {}
        for batch_index, action_list in enumerate(actions):
            for action_index, action in enumerate(action_list):
                if not action[0]:
                    # This action is padding.
                    continue
                _, production = action[0].split(' -> ')
                entity_index = entity_map.get((batch_index, production), None)
                if entity_index is not None:
                    actions_to_entities[(batch_index, action_index)] = entity_index
        flattened_linking_scores = linking_scores.view(batch_size * num_entities, num_question_tokens)
        return flattened_linking_scores, actions_to_entities

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``WikiTablesDecoderStep``.

        This method uses the stored action mapping and per-step debug information to turn
        predicted action indices into readable action strings and probabilities, and adds a
        field called ``predicted_actions`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
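                # Sort alphabetically by action string for a deterministic output order.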
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info['question_attention']
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict

    @classmethod
    def from_params(cls, vocab, params: Params) -> 'WikiTablesSemanticParser':
        question_embedder = TextFieldEmbedder.from_params(vocab, params.pop("question_embedder"))
        action_embedding_dim = params.pop_int("action_embedding_dim")
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        entity_encoder = Seq2VecEncoder.from_params(params.pop('entity_encoder'))
        max_decoding_steps = params.pop_int("max_decoding_steps")
        mixture_feedforward_type = params.pop('mixture_feedforward', None)
        if mixture_feedforward_type is not None:
            mixture_feedforward = FeedForward.from_params(mixture_feedforward_type)
        else:
            mixture_feedforward = None
        decoder_beam_search = BeamSearch.from_params(params.pop("decoder_beam_search"))
        # If no attention function is specified, we should not use attention, not attention with
        # default similarity function.
        attention_function_type = params.pop("attention_function", None)
        if attention_function_type is not None:
            attention_function = SimilarityFunction.from_params(attention_function_type)
        else:
            attention_function = None
        dropout = params.pop_float('dropout', 0.0)
        num_linking_features = params.pop_int('num_linking_features', 8)
        rule_namespace = params.pop('rule_namespace', 'rule_labels')
        params.assert_empty(cls.__name__)
        return cls(vocab,
                   question_embedder=question_embedder,
                   action_embedding_dim=action_embedding_dim,
                   encoder=encoder,
                   entity_encoder=entity_encoder,
                   mixture_feedforward=mixture_feedforward,
                   decoder_beam_search=decoder_beam_search,
                   max_decoding_steps=max_decoding_steps,
                   attention_function=attention_function,
                   dropout=dropout,
                   num_linking_features=num_linking_features,
                   rule_namespace=rule_namespace)
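
# A standalone sketch of the per-type normalization performed by
# `_get_linking_probabilities` above: entities compete only against entities
# of the same type, plus an implicit score-0 "null" entity that lets a
# question word link to nothing.  The shapes and the `type_ids` layout here
# are assumptions for illustration, not the class's actual interface.
import torch

def per_type_linking_probabilities(scores: torch.FloatTensor,
                                   type_ids: torch.LongTensor,
                                   num_types: int) -> torch.FloatTensor:
    # scores: (num_question_tokens, num_entities) linking scores for one
    # instance; type_ids: (num_entities,) type id of each entity.
    probabilities = torch.zeros_like(scores)
    for type_index in range(num_types):
        columns = (type_ids == type_index).nonzero().squeeze(-1)
        if columns.numel() == 0:
            continue
        type_scores = scores[:, columns]
        # Prepend a zero score for the null entity of this type.
        null_scores = type_scores.new_zeros(type_scores.size(0), 1)
        type_probabilities = torch.nn.functional.softmax(
                torch.cat([null_scores, type_scores], dim=1), dim=1)
        # Drop the null entity's column; the mass it absorbs simply lowers
        # every real entity's probability.
        probabilities[:, columns] = type_probabilities[:, 1:]
    return probabilities
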
Example #18
class NlvrDirectSemanticParser(NlvrSemanticParser):
    u"""
    ``NlvrDirectSemanticParser`` is an ``NlvrSemanticParser`` that gets around the problem of lack
    of logical form annotations by maximizing the marginal likelihood of an approximate set of target
    sequences that yield the correct denotation. The main difference between this parser and
    ``NlvrCoverageSemanticParser`` is that while this parser takes the output of an offline search
    process as the set of target sequences for training, the latter performs search during training.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the DecoderStep.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        Maximum number of steps for beam search after training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted actions.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 dropout: float = 0.0) -> None:
        super(NlvrDirectSemanticParser,
              self).__init__(vocab=vocab,
                             sentence_embedder=sentence_embedder,
                             action_embedding_dim=action_embedding_dim,
                             encoder=encoder,
                             dropout=dropout)
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._decoder_step = NlvrDecoderStep(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            dropout=dropout)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        self._action_padding_index = -1

    @overrides
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRuleArray]],
                identifier: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        u"""
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihod over a set of approximate logical forms.
        """
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
            for _ in range(batch_size)
        ]
        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]
        worlds_list = [worlds[i] for i in range(batch_size)]

        initial_state = NlvrDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            action_indices=action_indices,
            possible_actions=actions,
            worlds=worlds_list,
            label_strings=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs = {}
        if identifier is not None:
            outputs["identifier"] = identifier
        if target_action_sequences is not None:
            outputs.update(self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask)))
        best_final_states = self._decoder_beam_search.search(
            self._max_decoding_steps,
            initial_state,
            self._decoder_step,
            keep_final_unfinished_states=False)
        best_action_sequences = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = [best_final_states[i][0].action_history[0]]
                best_action_sequences[i] = best_action_indices
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            outputs[u"best_action_strings"] = batch_action_strings
            outputs[u"denotations"] = batch_denotations
        return outputs

    def _update_metrics(self, action_strings, worlds, label_strings):
        # TODO(pradeep): Move this to the base class.
        # TODO(pradeep): Using only the best decoded sequence. Define metrics for top-k sequences?
        batch_size = len(worlds)
        for i in range(batch_size):
            instance_action_strings = action_strings[i]
            sequence_is_correct = [False]
            if instance_action_strings:
                instance_label_strings = label_strings[i]
                instance_worlds = worlds[i]
                # Taking only the best sequence.
                sequence_is_correct = self._check_denotation(
                    instance_action_strings[0], instance_label_strings,
                    instance_worlds)
            for correct_in_world in sequence_is_correct:
                self._denotation_accuracy(1 if correct_in_world else 0)
            self._consistency(1 if all(sequence_is_correct) else 0)

    @overrides
    def get_metrics(self, reset=False):
        return {
            'denotation_accuracy': self._denotation_accuracy.get_metric(reset),
            'consistency': self._consistency.get_metric(reset)
        }
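
For orientation before the next example: the loss that MaximumMarginalLikelihood produces is the negative log of the summed probabilities of each instance's candidate action sequences, averaged over the batch. Below is a minimal sketch of that computation, assuming the per-sequence log scores have already been gathered during search; the function name and its `sequence_log_scores` argument are hypothetical stand-ins, not part of the library API.

import torch

def mml_loss(sequence_log_scores):
    """Negative log of the summed sequence probabilities, averaged over instances.

    `sequence_log_scores` is one 1-D tensor of per-sequence log scores for
    each instance in the batch (a hypothetical stand-in for the scores the
    decoder accumulates during search).
    """
    # logsumexp marginalizes over the candidate sequences in log space,
    # which stays numerically stable even for very negative scores.
    instance_losses = [-torch.logsumexp(scores, dim=0)
                       for scores in sequence_log_scores]
    return torch.stack(instance_losses).mean()

# Two instances: one with candidate scores -1 and -2, one with a single
# candidate scored -2:  loss = -(log(e^-1 + e^-2) + log(e^-2)) / 2
loss = mml_loss([torch.tensor([-1.0, -2.0]), torch.tensor([-2.0])])

Summing in log space is what lets the trainer marginalize over many approximate logical forms without underflow.
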
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 mixture_feedforward: FeedForward,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention_function: SimilarityFunction,
                 dropout: float = 0.0,
                 num_linking_features: int = 8,
                 rule_namespace: str = 'rule_labels',
                 table_directory: str = '/wikitables/') -> None:
        super(WikiTablesSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
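        # When dropout is zero, fall back to an identity function so the rest
        # of the model can apply self._dropout unconditionally.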
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = WikiTablesAccuracy(table_directory)
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        self._action_embedder = Embedding(num_embeddings=vocab.get_vocab_size(self._rule_namespace),
                                          embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(),
                               "entity word average embedding dim", "question embedding dim")

        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._type_params = torch.nn.Linear(self._num_entity_types, self._embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)
        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None
            self._question_entity_params = torch.nn.Linear(1, 1)
            self._question_neighbor_params = torch.nn.Linear(1, 1)

        self._decoder_trainer = MaximumMarginalLikelihood()

        self._decoder_step = WikiTablesDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                                   action_embedding_dim=action_embedding_dim,
                                                   attention_function=attention_function,
                                                   num_start_types=self._num_start_types,
                                                   num_entity_types=self._num_entity_types,
                                                   mixture_feedforward=mixture_feedforward,
                                                   dropout=dropout)
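
A closing note on the `_action_padding_index = -1` convention shared by both parsers: because IndexField pads with -1, the target mask built in forward() is just an elementwise comparison against the padding value. A minimal sketch with hypothetical data:

import torch

# Hypothetical batch of two instances, each with two candidate action
# sequences, padded to length 3 with the IndexField padding value -1.
target_action_sequences = torch.tensor([[[4, 7, 2],
                                         [4, 9, -1]],
                                        [[5, 1, -1],
                                         [-1, -1, -1]]])

# Padded positions become False, real actions True; the boolean mask has
# the same shape as the targets.
target_mask = target_action_sequences != -1

This `(target_action_sequences, target_mask)` pair is exactly the supervision tuple handed to MaximumMarginalLikelihood.decode in the forward pass above.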