class TestMaximumMarginalLikelihood(AllenNlpTestCase):
    """Checks the loss produced by ``MaximumMarginalLikelihood.decode`` on a toy batch."""

    def setUp(self):
        super().setUp()
        # One zero starting score per batch instance (two distinct Variables).
        start_scores = [Variable(torch.Tensor([0.0])) for _ in range(2)]
        self.initial_state = SimpleDecoderState([0, 1], [[], []], start_scores, [0, 1])
        self.decoder_step = SimpleDecoderStep()
        # Two instances with three candidate action sequences each; masked-out positions hold 0.
        raw_targets = [[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                       [[3, 4, 0], [2, 3, 4], [0, 0, 0]]]
        raw_mask = [[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                    [[1, 1, 0], [1, 1, 1], [0, 0, 0]]]
        self.targets = torch.autograd.Variable(torch.Tensor(raw_targets))
        self.target_mask = torch.autograd.Variable(torch.Tensor(raw_mask))
        self.supervision = (self.targets, self.target_mask)
        # No beam size is passed, so every unmasked target sequence should enter the MML sum.
        self.trainer = MaximumMarginalLikelihood()

    def test_decode(self):
        decoded_info = self.trainer.decode(self.initial_state,
                                           self.decoder_step,
                                           self.supervision)
        # In the simple transition system a sequence scores `-sequence_length`.  Per instance,
        # the log-likelihood is the log of the summed exponentiated sequence scores; the loss is
        # the negated mean of those over the batch.
        first_log_likelihood = math.log(math.exp(-3) * 3)  # three sequences of length 3
        second_log_likelihood = math.log(math.exp(-2) + math.exp(-3))  # lengths 2 and 3
        expected_loss = -(first_log_likelihood + second_log_likelihood) / 2
        assert_almost_equal(decoded_info['loss'].data.numpy(), expected_loss)
class TestMaximumMarginalLikelihood(AllenNlpTestCase):
    """Tests for ``MaximumMarginalLikelihood``: loss value and allowed-transition bookkeeping."""

    def setUp(self):
        super().setUp()
        # One zero starting score per batch instance (two distinct Variables).
        start_scores = [Variable(torch.Tensor([0.0])) for _ in range(2)]
        self.initial_state = SimpleDecoderState([0, 1], [[], []], start_scores, [0, 1])
        self.decoder_step = SimpleDecoderStep()
        # Two instances with three candidate action sequences each; masked-out positions hold 0.
        raw_targets = [[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                       [[3, 4, 0], [2, 3, 4], [0, 0, 0]]]
        raw_mask = [[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                    [[1, 1, 0], [1, 1, 1], [0, 0, 0]]]
        self.targets = torch.autograd.Variable(torch.Tensor(raw_targets))
        self.target_mask = torch.autograd.Variable(torch.Tensor(raw_mask))
        self.supervision = (self.targets, self.target_mask)
        # No beam size is passed, so every unmasked target sequence should enter the MML sum.
        self.trainer = MaximumMarginalLikelihood()

    def test_decode(self):
        decoded_info = self.trainer.decode(self.initial_state,
                                           self.decoder_step,
                                           self.supervision)
        # In the simple transition system a sequence scores `-sequence_length`.  Per instance,
        # the log-likelihood is the log of the summed exponentiated sequence scores; the loss is
        # the negated mean of those over the batch.
        first_log_likelihood = math.log(math.exp(-3) * 3)  # three sequences of length 3
        second_log_likelihood = math.log(math.exp(-2) + math.exp(-3))  # lengths 2 and 3
        expected_loss = -(first_log_likelihood + second_log_likelihood) / 2
        assert_almost_equal(decoded_info['loss'].data.numpy(), expected_loss)

    def test_create_allowed_transitions(self):
        result = self.trainer._create_allowed_transitions(self.targets, self.target_mask)
        # One prefix -> next-actions mapping per batch instance.
        assert len(result) == 2

        # Instance 0: three fully unmasked sequences yield six distinct proper prefixes.
        first_expected = {
            (): {1, 2},
            (1, ): {2, 3},
            (1, 2): {4},
            (1, 3): {4},
            (2, ): {3},
            (2, 3): {4},
        }
        assert len(result[0]) == len(first_expected)
        for prefix, next_actions in first_expected.items():
            assert result[0][prefix] == next_actions

        # Instance 1: one length-2 sequence, one length-3 sequence, one fully masked row,
        # giving four distinct proper prefixes.
        second_expected = {
            (): {2, 3},
            (2, ): {3},
            (2, 3): {4},
            (3, ): {4},
        }
        assert len(result[1]) == len(second_expected)
        for prefix, next_actions in second_expected.items():
            assert result[1][prefix] == next_actions

    def test_get_allowed_actions(self):
        # Three groups; the first argument presumably gives each group's batch index (0, 1, 0),
        # and the second its action history so far.
        state = DecoderState([0, 1, 0], [[1], [0], []], [])
        allowed_transitions = [{(1, ): {2}, (): {3}},
                               {(0, ): {4, 5}}]
        allowed_actions = self.trainer._get_allowed_actions(state, allowed_transitions)
        assert allowed_actions == [{2}, {4, 5}, {3}]
class TestMaximumMarginalLikelihood(AllenNlpTestCase):
    """Checks the loss produced by ``MaximumMarginalLikelihood.decode`` on a toy batch."""

    def setUp(self):
        super().setUp()
        # One zero starting score per batch instance (two distinct tensors).
        start_scores = [torch.Tensor([0.0]) for _ in range(2)]
        self.initial_state = SimpleDecoderState([0, 1], [[], []], start_scores, [0, 1])
        self.decoder_step = SimpleDecoderStep()
        # Two instances with three candidate action sequences each; masked-out positions hold 0.
        raw_targets = [[[2, 3, 4], [1, 3, 4], [1, 2, 4]],
                       [[3, 4, 0], [2, 3, 4], [0, 0, 0]]]
        raw_mask = [[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                    [[1, 1, 0], [1, 1, 1], [0, 0, 0]]]
        self.targets = torch.Tensor(raw_targets)
        self.target_mask = torch.Tensor(raw_mask)
        self.supervision = (self.targets, self.target_mask)
        # No beam size is passed, so every unmasked target sequence should enter the MML sum.
        self.trainer = MaximumMarginalLikelihood()

    def test_decode(self):
        decoded_info = self.trainer.decode(self.initial_state,
                                           self.decoder_step,
                                           self.supervision)
        # In the simple transition system a sequence scores `-sequence_length`.  Per instance,
        # the log-likelihood is the log of the summed exponentiated sequence scores; the loss is
        # the negated mean of those over the batch.
        first_log_likelihood = math.log(math.exp(-3) * 3)  # three sequences of length 3
        second_log_likelihood = math.log(math.exp(-2) + math.exp(-3))  # lengths 2 and 3
        expected_loss = -(first_log_likelihood + second_log_likelihood) / 2
        assert_almost_equal(decoded_info['loss'].data.numpy(), expected_loss)
class NlvrDirectSemanticParser(NlvrSemanticParser):
    """
    ``NlvrDirectSemanticParser`` is an ``NlvrSemanticParser`` that gets around the problem of lack
    of logical form annotations by maximizing the marginal likelihood of an approximate set of target
    sequences that yield the correct denotation. The main difference between this parser and
    ``NlvrCoverageSemanticParser`` is that while this parser takes the output of an offline search
    process as the set of target sequences for training, the latter performs search during training.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention_function : ``SimilarityFunction``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  This is the similarity function we use for that
        attention.  ``from_params`` passes ``None`` when the configuration omits it.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        Maximum number of steps for beam search after training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention_function: SimilarityFunction,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int) -> None:
        super(NlvrDirectSemanticParser, self).__init__(vocab=vocab,
                                                       sentence_embedder=sentence_embedder,
                                                       action_embedding_dim=action_embedding_dim,
                                                       encoder=encoder)
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._decoder_step = NlvrDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                             action_embedding_dim=action_embedding_dim,
                                             attention_function=attention_function)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        # Target action sequences equal to this value are treated as padding and masked out.
        self._action_padding_index = -1

    @overrides
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRuleArray]],
                target_action_sequences: torch.LongTensor = None,
                labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.

        Parameters
        ----------
        sentence : ``Dict[str, torch.LongTensor]``
            Indexed sentence tokens.  The first tensor is also used as a template when creating
            the initial (zero) decoder scores.
        worlds : ``List[List[NlvrWorld]]``
            The worlds for each instance.  Only the first world of each instance is used to build
            the grammar state.
        actions : ``List[List[ProductionRuleArray]]``
            All possible actions for each instance.
        target_action_sequences : ``torch.LongTensor``, optional (default=None)
            Approximate target action sequences.  A trailing singleton dimension is squeezed, and
            entries equal to ``self._action_padding_index`` are masked.  When given, the MML
            trainer computes the loss.
        labels : ``torch.LongTensor``, optional (default=None)
            Denotation labels.  When given together with ``target_action_sequences``, the best
            decoded sequence is scored against them to update the metrics.

        Returns
        -------
        ``Dict[str, torch.Tensor]``
            With ``target_action_sequences``: the decoder trainer's output (including ``"loss"``).
            Without it: ``"best_action_strings"`` and ``"denotations"``.
        """
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        # Any input tensor works as a template for the new score Variables (device/type).
        template_tensor = list(sentence.values())[0]
        initial_score_list = [util.new_variable_with_data(template_tensor, torch.Tensor([0.0]))
                              for _ in range(batch_size)]
        label_strings = self._get_label_strings(labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i])
                                 for i in range(batch_size)]
        worlds_list = [worlds[i] for i in range(batch_size)]

        initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                         action_history=[[] for _ in range(batch_size)],
                                         score=initial_score_list,
                                         rnn_state=initial_rnn_state,
                                         grammar_state=initial_grammar_state,
                                         action_embeddings=action_embeddings,
                                         action_indices=action_indices,
                                         possible_actions=actions,
                                         worlds=worlds_list,
                                         label_strings=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))

        best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                             initial_state,
                                                             self._decoder_step,
                                                             keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_sequences[i] = [best_final_states[i][0].action_history[0]]
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)

        if target_action_sequences is not None:
            # The metrics compare decoded denotations against the labels, so they can only be
            # updated when labels were provided; otherwise `label_strings[i]` would fail on None.
            if label_strings is not None:
                self._update_metrics(action_strings=batch_action_strings,
                                     worlds=worlds,
                                     label_strings=label_strings)
        else:
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs

    def _update_metrics(self,
                        action_strings: List[List[List[str]]],
                        worlds: List[List[NlvrWorld]],
                        label_strings: List[List[str]]) -> None:
        """
        Score the top decoded sequence of each instance against its labels.

        Records one ``_denotation_accuracy`` entry per entry returned by ``_check_denotation``
        and one ``_consistency`` entry per instance (1 only if all are correct).  An instance
        with no decoded sequence contributes a single incorrect denotation and an
        inconsistent result.
        """
        # TODO(pradeep): Move this to the base class.
        # TODO(pradeep): Using only the best decoded sequence. Define metrics for top-k sequences?
        batch_size = len(worlds)
        for i in range(batch_size):
            instance_action_strings = action_strings[i]
            sequence_is_correct = [False]
            if instance_action_strings:
                instance_label_strings = label_strings[i]
                instance_worlds = worlds[i]
                # Taking only the best sequence.
                sequence_is_correct = self._check_denotation(instance_action_strings[0],
                                                             instance_label_strings,
                                                             instance_worlds)
            for correct_in_world in sequence_is_correct:
                self._denotation_accuracy(1 if correct_in_world else 0)
            self._consistency(1 if all(sequence_is_correct) else 0)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return denotation accuracy and consistency, optionally resetting both."""
        return {
                'denotation_accuracy': self._denotation_accuracy.get_metric(reset),
                'consistency': self._consistency.get_metric(reset)
        }

    @classmethod
    def from_params(cls, vocab, params: Params) -> 'NlvrDirectSemanticParser':
        """Build the parser from configuration ``params``; ``attention_function`` is optional."""
        sentence_embedder_params = params.pop("sentence_embedder")
        sentence_embedder = TextFieldEmbedder.from_params(vocab, sentence_embedder_params)
        action_embedding_dim = params.pop_int('action_embedding_dim')
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        attention_function_type = params.pop("attention_function", None)
        if attention_function_type is not None:
            attention_function = SimilarityFunction.from_params(attention_function_type)
        else:
            attention_function = None
        decoder_beam_search = BeamSearch.from_params(params.pop("decoder_beam_search"))
        max_decoding_steps = params.pop_int("max_decoding_steps")
        params.assert_empty(cls.__name__)
        return cls(vocab,
                   sentence_embedder=sentence_embedder,
                   action_embedding_dim=action_embedding_dim,
                   encoder=encoder,
                   attention_function=attention_function,
                   decoder_beam_search=decoder_beam_search,
                   max_decoding_steps=max_decoding_steps)
class WikiTablesMmlSemanticParser(WikiTablesSemanticParser):
    """
    A ``WikiTablesMmlSemanticParser`` is a :class:`WikiTablesSemanticParser` which is trained to
    maximize the marginal likelihood of an approximate set of logical forms which give the correct
    denotation. This is a re-implementation of the model used for the paper `Neural Semantic Parsing with Type
    Constraints for Semi-Structured Tables
    <https://www.semanticscholar.org/paper/Neural-Semantic-Parsing-with-Type-Constraints-for-Krishnamurthy-Dasigi/8c6f58ed0ebf379858c0bbe02c53ee51b3eb398a>`_,
    by Jayant Krishnamurthy, Pradeep Dasigi, and Matt Gardner (EMNLP 2017).

    WORK STILL IN PROGRESS.  We'll iteratively improve it until we've reproduced the performance of
    the original parser.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder to use for averaging the words of an entity. Passed to super class.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training. Passed to super class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    mixture_feedforward : ``FeedForward``, optional (default=None)
        If given, we'll use this to compute a mixture probability between global actions and linked
        actions given the hidden state at every timestep of decoding, instead of concatenating the
        logits for both (where the logits may not be compatible with each other).  Passed to
        the transition function.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.  Passed to super class.
    training_beam_size : ``int``, optional (default=None)
        If given, we will use a constrained beam search of this size during training, so that we
        use only the top ``training_beam_size`` action sequences according to the model in the MML
        computation.  If this is ``None``, we will use all of the provided action sequences in the
        MML computation.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to super
        class.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables. Passed to super class.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention: Attention,
                 mixture_feedforward: FeedForward = None,
                 add_action_bias: bool = True,
                 training_beam_size: int = None,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super().__init__(vocab=vocab,
                         question_embedder=question_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         entity_encoder=entity_encoder,
                         max_decoding_steps=max_decoding_steps,
                         add_action_bias=add_action_bias,
                         use_neighbor_similarity_for_linking=use_similarity,
                         dropout=dropout,
                         num_linking_features=num_linking_features,
                         rule_namespace=rule_namespace,
                         tables_directory=tables_directory)
        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                       action_embedding_dim=action_embedding_dim,
                                                       input_attention=attention,
                                                       num_start_types=self._num_start_types,
                                                       predict_start_type_separately=True,
                                                       add_action_bias=self._add_action_bias,
                                                       mixture_feedforward=mixture_feedforward,
                                                       dropout=dropout)

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRuleArray]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default = None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default = None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' key.

        Returns
        -------
        In training mode, the decoder trainer's output.  Otherwise, a dictionary filled by the
        initial-state helper and ``_compute_validation_outputs``, plus ``'loss'`` when targets
        are given.
        """
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                           table,
                                                                           world,
                                                                           actions,
                                                                           outputs)
        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedDecoderState(batch_indices=list(range(batch_size)),
                                                 action_history=[[] for _ in range(batch_size)],
                                                 score=initial_score_list,
                                                 rnn_state=rnn_state,
                                                 grammar_state=grammar_state,
                                                 possible_actions=actions,
                                                 extras=example_lisp_string,
                                                 debug_info=None)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            if target_action_sequences is not None:
                for i in range(batch_size):
                    # Decoding may not have terminated with any completed logical forms, if
                    # `num_steps` isn't long enough (or if the model is not trained enough and
                    # gets into an infinite action loop).
                    if i not in best_final_states:
                        continue
                    best_action_indices = best_final_states[i][0].action_history[0]
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = target_action_sequences[i].data
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._action_sequence_accuracy(sequence_in_targets)

            self._compute_validation_outputs(actions,
                                             best_final_states,
                                             world,
                                             example_lisp_string,
                                             metadata,
                                             outputs)
            return outputs
class WikiTablesMmlSemanticParser(WikiTablesSemanticParser):
    """
    A ``WikiTablesMmlSemanticParser`` is a :class:`WikiTablesSemanticParser` which is trained to
    maximize the marginal likelihood of an approximate set of logical forms which give the correct
    denotation. This is a re-implementation of the model used for the paper `Neural Semantic Parsing with Type
    Constraints for Semi-Structured Tables
    <https://www.semanticscholar.org/paper/Neural-Semantic-Parsing-with-Type-Constraints-for-Krishnamurthy-Dasigi/8c6f58ed0ebf379858c0bbe02c53ee51b3eb398a>`_,
    by Jayant Krishnamurthy, Pradeep Dasigi, and Matt Gardner (EMNLP 2017).

    WORK STILL IN PROGRESS.  We'll iteratively improve it until we've reproduced the performance of
    the original parser.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder to use for averaging the words of an entity. Passed to super class.
    mixture_feedforward : ``FeedForward``
        Passed to super class.  ``from_params`` passes ``None`` when the configuration omits it.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training. Passed to super class.
    attention_function : ``SimilarityFunction``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  This is the similarity function we use for that
        attention. Passed to super class.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the `neighbors`
        of an entity as a component of the linking scores.  This is meant to capture the same kind
        of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the ``KnowledgeGraphField``,
        which is to use all ten defined features. If this is 0, another term will be added to the
        linking score. This term contains the maximum similarity value from the entity's neighbors
        and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to super
        class.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables. Passed to super class.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 mixture_feedforward: FeedForward,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention_function: SimilarityFunction,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super(WikiTablesMmlSemanticParser, self).__init__(vocab=vocab,
                                                          question_embedder=question_embedder,
                                                          action_embedding_dim=action_embedding_dim,
                                                          encoder=encoder,
                                                          entity_encoder=entity_encoder,
                                                          mixture_feedforward=mixture_feedforward,
                                                          max_decoding_steps=max_decoding_steps,
                                                          attention_function=attention_function,
                                                          use_neighbor_similarity_for_linking=use_similarity,
                                                          dropout=dropout,
                                                          num_linking_features=num_linking_features,
                                                          rule_namespace=rule_namespace,
                                                          tables_directory=tables_directory)
        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood()

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRuleArray]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default=None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.

        Returns
        -------
        In training mode, the decoder trainer's output.  Otherwise a dictionary with
        ``'action_mapping'``, ``'best_action_sequence'``, ``'debug_info'``, ``'entities'``,
        ``'linking_scores'``, ``'similarity_scores'``, ``'logical_form'``, plus
        ``'feature_scores'`` when available and ``'loss'`` when targets are given.
        """
        initial_info = self._get_initial_state_and_scores(question, table, world, actions)
        initial_state = initial_info["initial_state"]
        linking_scores = initial_info["linking_scores"]
        feature_scores = initial_info["feature_scores"]
        similarity_scores = initial_info["similarity_scores"]
        batch_size = list(question.values())[0].size(0)
        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            # TODO(pradeep): Most of the functionality in this block can be moved to the super
            # class.
            # Maps (batch index, action index) to the action's first element (its rule string).
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['linking_scores'] = linking_scores
            if feature_scores is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = similarity_scores
            outputs['logical_form'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)]
                                      for action_index in best_action_indices]
                    try:
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    else:
                        # Record success only once parsing has actually succeeded; recording it
                        # before the call double-counted failed instances (1.0 and then 0.0).
                        self._has_logical_form(1.0)
                    if example_lisp_string:
                        self._denotation_accuracy(logical_form, example_lisp_string[i])
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                    outputs['entities'].append(world[i].table_graph.entities)
                else:
                    outputs['logical_form'].append('')
                    self._has_logical_form(0.0)
                    if example_lisp_string:
                        self._denotation_accuracy(None, example_lisp_string[i])
            return outputs

    @classmethod
    def from_params(cls, vocab, params: Params) -> 'WikiTablesMmlSemanticParser':
        """Build the parser from configuration ``params``; optional components default to None."""
        question_embedder = TextFieldEmbedder.from_params(vocab, params.pop("question_embedder"))
        action_embedding_dim = params.pop_int("action_embedding_dim")
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        entity_encoder = Seq2VecEncoder.from_params(params.pop('entity_encoder'))
        max_decoding_steps = params.pop_int("max_decoding_steps")
        mixture_feedforward_type = params.pop('mixture_feedforward', None)
        if mixture_feedforward_type is not None:
            mixture_feedforward = FeedForward.from_params(mixture_feedforward_type)
        else:
            mixture_feedforward = None
        decoder_beam_search = BeamSearch.from_params(params.pop("decoder_beam_search"))
        # If no attention function is specified, we should not use attention, not attention with
        # default similarity function.
        attention_function_type = params.pop("attention_function", None)
        if attention_function_type is not None:
            attention_function = SimilarityFunction.from_params(attention_function_type)
        else:
            attention_function = None
        use_neighbor_similarity_for_linking = params.pop_bool('use_neighbor_similarity_for_linking', False)
        dropout = params.pop_float('dropout', 0.0)
        num_linking_features = params.pop_int('num_linking_features', 10)
        tables_directory = params.pop('tables_directory', '/wikitables/')
        rule_namespace = params.pop('rule_namespace', 'rule_labels')
        params.assert_empty(cls.__name__)
        return cls(vocab,
                   question_embedder=question_embedder,
                   action_embedding_dim=action_embedding_dim,
                   encoder=encoder,
                   entity_encoder=entity_encoder,
                   mixture_feedforward=mixture_feedforward,
                   decoder_beam_search=decoder_beam_search,
                   max_decoding_steps=max_decoding_steps,
                   attention_function=attention_function,
                   use_neighbor_similarity_for_linking=use_neighbor_similarity_for_linking,
                   dropout=dropout,
                   num_linking_features=num_linking_features,
                   tables_directory=tables_directory,
                   rule_namespace=rule_namespace)
class WikiTablesSemanticParser(Model):
    """
    A ``WikiTablesSemanticParser`` is a :class:`Model` which takes as input a table and a question,
    and produces a logical form that answers the question when executed over the table.  The
    logical form is generated by a `type-constrained`, `transition-based` parser.  This is a
    re-implementation of the model used for the paper `Neural Semantic Parsing with Type
    Constraints for Semi-Structured Tables
    <https://www.semanticscholar.org/paper/Neural-Semantic-Parsing-with-Type-Constraints-for-Krishnamurthy-Dasigi/8c6f58ed0ebf379858c0bbe02c53ee51b3eb398a>`_,
    by Jayant Krishnamurthy, Pradeep Dasigi, and Matt Gardner (EMNLP 2017).

    WORK STILL IN PROGRESS.  We'll iteratively improve it until we've reproduced the performance
    of the original parser.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.  It is also used to embed the words of table entities.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    entity_encoder : ``Seq2VecEncoder``
        The encoder to use for averaging the words of an entity.
    mixture_feedforward : ``FeedForward``
        Passed directly to the ``WikiTablesDecoderStep``.  ``from_params`` passes ``None`` when no
        ``mixture_feedforward`` is configured.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    attention_function : ``SimilarityFunction``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  This is the similarity function we use for that
        attention.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=8)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 8 here matches the default in the ``KnowledgeGraphField``,
        which is to use all eight defined features.  If this is 0, another term will be added to the
        linking score.  This term contains the maximum similarity value from the entity's neighbors
        and the question.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    table_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE to
        evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This tells
        SEMPRE where to find the tables.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 mixture_feedforward: FeedForward,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 attention_function: SimilarityFunction,
                 dropout: float = 0.0,
                 num_linking_features: int = 8,
                 rule_namespace: str = 'rule_labels',
                 table_directory: str = '/wikitables/') -> None:
        super(WikiTablesSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = WikiTablesAccuracy(table_directory)
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        self._action_embedder = Embedding(num_embeddings=vocab.get_vocab_size(self._rule_namespace),
                                          embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal(self._first_action_embedding)
        torch.nn.init.normal(self._first_attended_question)

        # Entity words are embedded with the question embedder and then averaged by the entity
        # encoder, so the two output dimensions have to agree.
        check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(),
                               "entity word average embedding dim", "question embedding dim")

        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._type_params = torch.nn.Linear(self._num_entity_types, self._embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)

        # With no linking features, ``forward`` falls back to a learned combination of the
        # entity/question and neighbor/question similarity scores instead.
        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        self._question_entity_params = torch.nn.Linear(1, 1)
        self._question_neighbor_params = torch.nn.Linear(1, 1)

        self._decoder_trainer = MaximumMarginalLikelihood()

        self._decoder_step = WikiTablesDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                                   action_embedding_dim=action_embedding_dim,
                                                   attention_function=attention_function,
                                                   num_start_types=self._num_start_types,
                                                   num_entity_types=self._num_entity_types,
                                                   mixture_feedforward=mixture_feedforward,
                                                   dropout=dropout)

    @overrides
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRuleArray]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default=None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.

        Returns
        -------
        When training, the output of the decoder trainer's ``decode`` call.  Otherwise a dictionary
        with keys ``action_mapping``, ``loss`` (only when ``target_action_sequences`` is given),
        ``best_action_sequence``, ``debug_info``, ``entities``, ``linking_scores``,
        ``feature_scores`` (only when linking features are used), ``similarity_scores`` and
        ``logical_form``.  The per-instance lists (``best_action_sequence``, ``debug_info``,
        ``entities``, ``logical_form``) have exactly one entry per batch instance.
        """
        table_text = table['text']

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.nn.functional.tanh(entity_type_embeddings + projected_neighbor_embeddings)

        # Compute entity and question word cosine similarity. Need to add a small value to
        # the table norm since there are padding values which cause a divide by 0.
        embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
        embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = question_entity_similarity_max_score + feature_scores
        else:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # NOTE(review): padded neighbor slots (index -1) are selected via abs() as entity 1
            # and are not masked before the max below, so they can contribute to the score.
            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = Variable(encoder_outputs.data.new(batch_size, self._encoder.get_output_dim()).fill_(0))

        initial_score = Variable(embedded_question.data.new(batch_size).fill_(0))

        action_embeddings, action_indices = self._embed_actions(actions)

        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)
        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = [RnnState(final_encoder_output[i],
                                      memory_cell[i],
                                      self._first_action_embedding,
                                      self._first_attended_question,
                                      encoder_output_list,
                                      question_mask_list)
                             for i in range(batch_size)]
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               debug_info=None)
        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))

        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        outputs: Dict[str, Any] = {'action_mapping': action_mapping}
        if target_action_sequences is not None:
            outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                           self._decoder_step,
                                                           (target_action_sequences, target_mask))['loss']
        num_steps = self._max_decoding_steps
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._beam_search.search(num_steps,
                                                     initial_state,
                                                     self._decoder_step,
                                                     keep_final_unfinished_states=False)
        outputs['best_action_sequence'] = []
        outputs['debug_info'] = []
        outputs['entities'] = []
        outputs['linking_scores'] = linking_scores
        if self._linking_params is not None:
            outputs['feature_scores'] = feature_scores
        outputs['similarity_scores'] = question_entity_similarity_max_score
        outputs['logical_form'] = []
        for i in range(batch_size):
            # The per-instance lists always get exactly one entry per batch instance, so that
            # position ``i`` refers to batch instance ``i``; ``decode`` relies on this alignment
            # when it looks up ``action_mapping[(batch_index, action_index)]``.
            outputs['entities'].append(world[i].table_graph.entities)
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = best_final_states[i][0].action_history[0]
                if target_action_sequences is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = target_action_sequences[i].data
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._action_sequence_accuracy(sequence_in_targets)
                action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                try:
                    logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                except ParsingError:
                    logical_form = 'Error producing logical form'
                    self._has_logical_form(0.0)
                else:
                    # Record success only once parsing has actually succeeded, so that a failed
                    # parse is counted exactly once (as 0.0) in ``lf_percent``.
                    self._has_logical_form(1.0)
                if example_lisp_string:
                    self._denotation_accuracy(logical_form, example_lisp_string[i])
                outputs['best_action_sequence'].append(action_strings)
                outputs['logical_form'].append(logical_form)
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
            else:
                outputs['best_action_sequence'].append([])
                outputs['debug_info'].append([])
                outputs['logical_form'].append('')
                self._has_logical_form(0.0)
                if example_lisp_string:
                    self._denotation_accuracy(None, example_lisp_string[i])
        return outputs

    @staticmethod
    def _get_neighbor_indices(worlds: List[WikiTablesWorld],
                              num_entities: int,
                              tensor: Variable) -> torch.LongTensor:
        """
        This method returns the indices of each entity's neighbors. A tensor
        is accepted as a parameter for copying purposes.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``Variable``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
        with -1 instead of 0, since 0 is a valid neighbor index.
        """
        # The neighbor dimension is as wide as the largest neighbor list anywhere in the batch.
        num_neighbors = max((len(world.table_graph.neighbors[entity])
                             for world in worlds
                             for entity in world.table_graph.entities),
                            default=0)

        batch_neighbors = []
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.table_graph.entities
            entity2index = {entity: i for i, entity in enumerate(entities)}
            entity2neighbors = world.table_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [entity2index[n] for n in entity2neighbors[entity]]
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(neighbor_indexes,
                                                      num_entities,
                                                      lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        return Variable(tensor.data.new(batch_neighbors)).long()

    @staticmethod
    def _get_type_vector(worlds: List[WikiTablesWorld],
                         num_entities: int,
                         tensor: Variable) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the one hot encoding for each entity's type. In addition,
        a map from a flattened entity index to type is returned to combine
        entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        # Loop-invariant: one one-hot row per entity type id (0..3).
        one_hot_vectors = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    entity_type = 0
                types.append(one_hot_vectors[entity_type])

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: [0, 0, 0, 0])
            batch_types.append(padded)
        return Variable(tensor.data.new(batch_types)), entity_types

    def _get_linking_probabilities(self,
                                   worlds: List[WikiTablesWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where
                # a word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = Variable(linking_scores.data.new(entity_indices)).long()
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = Variable(linking_scores.data.new(num_question_tokens,
                                                         num_entities - num_entities_in_instance).fill_(0))
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        """
        Checks whether ``predicted`` equals the first ``len(predicted)`` entries of any row of
        ``targets`` (a 2D tensor with one candidate action sequence per row).  Returns 0 when
        ``predicted`` is longer than the target rows.
        """
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0])

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track three metrics here:

            1. dpd_acc, which is the percentage of the time that our best output action sequence is
            in the set of action sequences provided by DPD.  This is an easy-to-compute lower bound
            on denotation accuracy for the set of examples where we actually have DPD output.  We
            only score dpd_acc on that subset.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that has DPD output (make sure
            you pass "keep_if_no_dpd=True" to the dataset reader, which we do for validation data,
            but not training data).

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        return {
                'dpd_acc': self._action_sequence_accuracy.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'lf_percent': self._has_logical_form.get_metric(reset),
                }

    @staticmethod
    def _create_grammar_state(world: WikiTablesWorld,
                              possible_actions: List[ProductionRuleArray]) -> GrammarState:
        """
        Builds the initial ``GrammarState`` for one instance, translating the world's valid-action
        strings into indices into ``possible_actions``.
        """
        valid_actions = world.get_valid_actions()
        # Maps an action string to its index in ``possible_actions``.
        action_mapping = {action[0]: i for i, action in enumerate(possible_actions)}
        translated_valid_actions = {key: [action_mapping[action_string] for action_string in action_strings]
                                    for key, action_strings in valid_actions.items()}
        return GrammarState([START_SYMBOL],
                            {},
                            translated_valid_actions,
                            action_mapping,
                            type_declaration.is_nonterminal)

    def _embed_actions(self, actions: List[List[ProductionRuleArray]]) -> Tuple[torch.Tensor,
                                                                                Dict[Tuple[int, int], int]]:
        """
        Given all of the possible actions for all batch instances, produce an embedding for them.
        There will be significant overlap in this list, as the production rules from the grammar
        are shared across all batch instances.  Our returned tensor has an embedding for each
        `unique` action, so we also need to return a mapping from the original ``(batch_index,
        action_index)`` to our new ``global_action_index``, so that we can get the right action
        embedding during decoding.

        Returns
        -------
        action_embeddings : ``torch.Tensor``
            Has shape ``(num_unique_actions, action_embedding_dim)``.
        action_map : ``Dict[Tuple[int, int], int]``
            Maps ``(batch_index, action_index)`` in the input action list to ``action_index`` in
            the ``action_embeddings`` tensor.  Actions not found in the rule vocabulary get mapped
            to `-1` here; padding actions (empty rule strings) are left out of the map.
        """
        # TODO(mattg): This whole action pipeline might be a whole lot more complicated than it
        # needs to be.  We used to embed actions differently (using some crazy ideas about
        # embedding the LHS and RHS separately); we could probably get away with simplifying things
        # further now that we're just doing a simple embedding for global actions.  But I'm leaving
        # it like this for now to have a minimal change to go from the LHS/RHS embedding to a
        # single action embedding.
        embedded_actions = self._action_embedder.weight

        # Now we just need to make a map from `(batch_index, action_index)` to
        # `global_action_index`.  The rule vocabulary has the list of all unique actions; here
        # we're going over all of the actions for each batch instance so we can map them to the
        # global action ids.
        action_vocab = self.vocab.get_token_to_index_vocabulary(self._rule_namespace)
        action_map: Dict[Tuple[int, int], int] = {}
        for batch_index, instance_actions in enumerate(actions):
            for action_index, action in enumerate(instance_actions):
                if not action[0]:
                    # This rule is padding.
                    continue
                global_action_id = action_vocab.get(action[0], -1)
                action_map[(batch_index, action_index)] = global_action_id
        return embedded_actions, action_map

    @staticmethod
    def _map_entity_productions(linking_scores: torch.FloatTensor,
                                worlds: List[WikiTablesWorld],
                                actions: List[List[ProductionRuleArray]]) -> Tuple[torch.Tensor,
                                                                                   Dict[Tuple[int, int], int]]:
        """
        Constructs a map from ``(batch_index, action_index)`` to ``batch_index * num_entities +
        entity_index``.  That is, some actions correspond to terminal productions of entities from
        our table.  We need to find those actions and map them to their corresponding entity
        indices, where the entity index is its position in the list of entities returned by the
        ``world``.  This list is what defines the second dimension of the ``linking_scores``
        tensor, so we can use this index to look up linking scores for each action in that tensor.

        For easier processing later, the mapping that we return is `flattened` - we really want to
        map ``(batch_index, action_index)`` to ``(batch_index, entity_index)``, but we are going to
        have to use the result of this mapping to do ``index_selects`` on the ``linking_scores``
        tensor.  You can't do ``index_select`` with tuples, so we flatten ``linking_scores`` to
        have shape ``(batch_size * num_entities, num_question_tokens)``, and return shifted indices
        into this flattened tensor.

        Parameters
        ----------
        linking_scores : ``torch.Tensor``
            A tensor representing linking scores between each table entity and each question token.
            Has shape ``(batch_size, num_entities, num_question_tokens)``.
        worlds : ``List[WikiTablesWorld]``
            The ``World`` for each batch instance.  The ``World`` contains a reference to the
            ``TableKnowledgeGraph`` that defines the set of entities in the linking.
        actions : ``List[List[ProductionRuleArray]]``
            The list of possible actions for each batch instance.  Our action indices are defined in
            terms of this list, so we'll find entity productions in this list and map them to
            entity indices from the entity list we get from the ``World``.

        Returns
        -------
        flattened_linking_scores : ``torch.Tensor``
            A flattened version of ``linking_scores``, with shape ``(batch_size * num_entities,
            num_question_tokens)``.
        actions_to_entities : ``Dict[Tuple[int, int], int]``
            A mapping from ``(batch_index, action_index)`` to ``batch_index * num_entities +
            entity_index``, representing which action indices correspond to which rows of the
            returned ``flattened_linking_scores`` tensor.
        """
        batch_size, num_entities, num_question_tokens = linking_scores.size()
        entity_map: Dict[Tuple[int, str], int] = {}
        for batch_index, world in enumerate(worlds):
            for entity_index, entity in enumerate(world.table_graph.entities):
                entity_map[(batch_index, entity)] = batch_index * num_entities + entity_index
        actions_to_entities: Dict[Tuple[int, int], int] = {}
        for batch_index, action_list in enumerate(actions):
            for action_index, action in enumerate(action_list):
                if not action[0]:
                    # This action is padding.
                    continue
                _, production = action[0].split(' -> ')
                entity_index = entity_map.get((batch_index, production), None)
                if entity_index is not None:
                    actions_to_entities[(batch_index, action_index)] = entity_index
        flattened_linking_scores = linking_scores.view(batch_size * num_entities, num_question_tokens)
        return flattened_linking_scores, actions_to_entities

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.  This is (confusingly) a separate notion from the "decoder"
        in "encoder/decoder", where that decoder logic lives in ``WikiTablesDecoderStep``.

        This method adds a field called ``predicted_actions`` to the ``output_dict``: for each
        batch instance, a list with one dict per predicted action, holding the
        ``predicted_action``, the ``considered_actions`` (as rule strings, sorted), their
        ``action_probabilities`` and the ``question_attention`` from the debug info.  Position
        ``i`` of ``best_action_sequence`` and ``debug_info`` is assumed to be batch instance ``i``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    # -1 marks a padded (non-existent) action.
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
                actions.sort()
                if actions:
                    considered_actions, probabilities = zip(*actions)
                else:
                    # zip(*[]) cannot be unpacked into two names.
                    considered_actions, probabilities = (), ()
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info['question_attention']
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict

    @classmethod
    def from_params(cls, vocab, params: Params) -> 'WikiTablesSemanticParser':
        """
        Builds the parser from configuration.  The table directory is read from the
        ``tables_directory`` key (the same key ``WikiTablesMmlSemanticParser`` uses) and passed
        to the constructor as ``table_directory``.
        """
        question_embedder = TextFieldEmbedder.from_params(vocab, params.pop("question_embedder"))
        action_embedding_dim = params.pop_int("action_embedding_dim")
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        entity_encoder = Seq2VecEncoder.from_params(params.pop('entity_encoder'))
        max_decoding_steps = params.pop_int("max_decoding_steps")
        mixture_feedforward_type = params.pop('mixture_feedforward', None)
        if mixture_feedforward_type is not None:
            mixture_feedforward = FeedForward.from_params(mixture_feedforward_type)
        else:
            mixture_feedforward = None
        decoder_beam_search = BeamSearch.from_params(params.pop("decoder_beam_search"))
        # If no attention function is specified, we should not use attention, not attention with
        # default similarity function.
        attention_function_type = params.pop("attention_function", None)
        if attention_function_type is not None:
            attention_function = SimilarityFunction.from_params(attention_function_type)
        else:
            attention_function = None
        dropout = params.pop_float('dropout', 0.0)
        num_linking_features = params.pop_int('num_linking_features', 8)
        rule_namespace = params.pop('rule_namespace', 'rule_labels')
        table_directory = params.pop('tables_directory', '/wikitables/')
        params.assert_empty(cls.__name__)
        return cls(vocab,
                   question_embedder=question_embedder,
                   action_embedding_dim=action_embedding_dim,
                   encoder=encoder,
                   entity_encoder=entity_encoder,
                   mixture_feedforward=mixture_feedforward,
                   decoder_beam_search=decoder_beam_search,
                   max_decoding_steps=max_decoding_steps,
                   attention_function=attention_function,
                   dropout=dropout,
                   num_linking_features=num_linking_features,
                   rule_namespace=rule_namespace,
                   table_directory=table_directory)
class NlvrDirectSemanticParser(NlvrSemanticParser):
    """
    ``NlvrDirectSemanticParser`` is an ``NlvrSemanticParser`` that gets around the problem of lack
    of logical form annotations by maximizing the marginal likelihood of an approximate set of target
    sequences that yield the correct denotation. The main difference between this parser and
    ``NlvrCoverageSemanticParser`` is that while this parser takes the output of an offline search
    process as the set of target sequences for training, the latter performs search during training.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the DecoderStep.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        Maximum number of steps for beam search after training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted actions.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 decoder_beam_search: BeamSearch,
                 max_decoding_steps: int,
                 dropout: float = 0.0) -> None:
        super(NlvrDirectSemanticParser, self).__init__(vocab=vocab,
                                                       sentence_embedder=sentence_embedder,
                                                       action_embedding_dim=action_embedding_dim,
                                                       encoder=encoder,
                                                       dropout=dropout)
        # Training objective: marginal likelihood over the supplied target action sequences.
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._decoder_step = NlvrDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                             action_embedding_dim=action_embedding_dim,
                                             input_attention=attention,
                                             dropout=dropout)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        # Padding value in ``target_action_sequences``; positions holding it are masked out.
        self._action_padding_index = -1

    @overrides
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRuleArray]],
                identifier: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.

        When ``target_action_sequences`` is given, the returned dict is the output of the
        ``MaximumMarginalLikelihood`` trainer, and metrics are updated from the best beam-search
        sequence. Otherwise the dict holds ``"best_action_strings"`` and ``"denotations"``. In both
        cases ``identifier``, when given, is included under ``"identifier"``.
        """
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        # One zero score per instance, created via ``new_zeros`` on an input tensor so it shares
        # that tensor's device.
        initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                              for i in range(batch_size)]
        label_strings = self._get_label_strings(labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i])
                                 for i in range(batch_size)]
        worlds_list = [worlds[i] for i in range(batch_size)]

        initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                         action_history=[[] for _ in range(batch_size)],
                                         score=initial_score_list,
                                         rnn_state=initial_rnn_state,
                                         grammar_state=initial_grammar_state,
                                         action_embeddings=action_embeddings,
                                         action_indices=action_indices,
                                         possible_actions=actions,
                                         worlds=worlds_list,
                                         label_strings=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
        # Added only after the trainer output is in place. Setting it earlier would be lost when
        # ``outputs`` is replaced by the trainer's dict above.
        if identifier is not None:
            outputs["identifier"] = identifier

        best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                             initial_state,
                                                             self._decoder_step,
                                                             keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = [best_final_states[i][0].action_history[0]]
                best_action_sequences[i] = best_action_indices
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs

    def _update_metrics(self,
                        action_strings: List[List[List[str]]],
                        worlds: List[List[NlvrWorld]],
                        label_strings: List[List[str]]) -> None:
        """
        Updates denotation accuracy and consistency using only the top-ranked decoded sequence of
        each instance.

        Denotation accuracy gets one update per world checked. Consistency is 1 only if the
        sequence is correct in every world. An instance with no decoded sequence counts as one
        incorrect denotation and one inconsistent instance.
        """
        # TODO(pradeep): Move this to the base class.
        # TODO(pradeep): Using only the best decoded sequence. Define metrics for top-k sequences?
        batch_size = len(worlds)
        for i in range(batch_size):
            instance_action_strings = action_strings[i]
            sequence_is_correct = [False]
            if instance_action_strings:
                instance_label_strings = label_strings[i]
                instance_worlds = worlds[i]
                # Taking only the best sequence.
                sequence_is_correct = self._check_denotation(instance_action_strings[0],
                                                             instance_label_strings,
                                                             instance_worlds)
            for correct_in_world in sequence_is_correct:
                self._denotation_accuracy(1 if correct_in_world else 0)
            self._consistency(1 if all(sequence_is_correct) else 0)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Returns denotation accuracy and consistency, optionally resetting both metrics."""
        return {
                'denotation_accuracy': self._denotation_accuracy.get_metric(reset),
                'consistency': self._consistency.get_metric(reset)
        }
class NlvrDirectSemanticParser(NlvrSemanticParser):
    u"""
    ``NlvrDirectSemanticParser`` is an ``NlvrSemanticParser`` that gets around the problem of lack
    of logical form annotations by maximizing the marginal likelihood of an approximate set of target
    sequences that yield the correct denotation. The main difference between this parser and
    ``NlvrCoverageSemanticParser`` is that while this parser takes the output of an offline search
    process as the set of target sequences for training, the latter performs search during training.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the DecoderStep.
    decoder_beam_search : ``BeamSearch``
        Beam search used to retrieve best sequences after training.
    max_decoding_steps : ``int``
        Maximum number of steps for beam search after training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted actions.
    """
    def __init__(self,
                 vocab,
                 sentence_embedder,
                 action_embedding_dim,
                 encoder,
                 attention,
                 decoder_beam_search,
                 max_decoding_steps,
                 dropout=0.0):
        super(NlvrDirectSemanticParser, self).__init__(vocab=vocab,
                                                       sentence_embedder=sentence_embedder,
                                                       action_embedding_dim=action_embedding_dim,
                                                       encoder=encoder,
                                                       dropout=dropout)
        # Training objective: marginal likelihood over the supplied target action sequences.
        self._decoder_trainer = MaximumMarginalLikelihood()
        self._decoder_step = NlvrDecoderStep(
                encoder_output_dim=self._encoder.get_output_dim(),
                action_embedding_dim=action_embedding_dim,
                input_attention=attention,
                dropout=dropout)
        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        # Padding value in ``target_action_sequences``; positions holding it are masked out.
        self._action_padding_index = -1

    # @overrides  (decorator left out in this variant)
    def forward(self,  # type: ignore
                sentence,
                worlds,
                actions,
                identifier=None,
                target_action_sequences=None,
                labels=None):
        # pylint: disable=arguments-differ
        u"""
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.

        When ``target_action_sequences`` is given, the returned dict is the output of the
        ``MaximumMarginalLikelihood`` trainer, and metrics are updated from the best beam-search
        sequence. Otherwise the dict holds ``"best_action_strings"`` and ``"denotations"``. In both
        cases ``identifier``, when given, is included under ``"identifier"``.
        """
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        # The builtin ``next`` works on Python 2 and 3. The ``iterator.next()`` method used
        # previously only exists on Python 2 and raises AttributeError on Python 3.
        initial_score_list = [
                next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                for i in range(batch_size)
        ]
        label_strings = self._get_label_strings(
                labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
                self._create_grammar_state(worlds[i][0], actions[i])
                for i in range(batch_size)
        ]
        worlds_list = [worlds[i] for i in range(batch_size)]

        # ``list(...)`` so batch indices are a concrete list on Python 3 as well (``range`` is lazy
        # there), matching the Python 3 variant of this parser.
        initial_state = NlvrDecoderState(
                batch_indices=list(range(batch_size)),
                action_history=[[] for _ in range(batch_size)],
                score=initial_score_list,
                rnn_state=initial_rnn_state,
                grammar_state=initial_grammar_state,
                action_embeddings=action_embeddings,
                action_indices=action_indices,
                possible_actions=actions,
                worlds=worlds_list,
                label_strings=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs = {}
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(
                    initial_state, self._decoder_step,
                    (target_action_sequences, target_mask))
        # Added only after the trainer output is in place. Setting it earlier would be lost when
        # ``outputs`` is replaced by the trainer's dict above.
        if identifier is not None:
            outputs[u"identifier"] = identifier

        best_final_states = self._decoder_beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
        best_action_sequences = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = [
                        best_final_states[i][0].action_history[0]
                ]
                best_action_sequences[i] = best_action_indices
        batch_action_strings = self._get_action_strings(
                actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            outputs[u"best_action_strings"] = batch_action_strings
            outputs[u"denotations"] = batch_denotations
        return outputs

    def _update_metrics(self, action_strings, worlds, label_strings):
        u"""
        Updates denotation accuracy and consistency using only the top-ranked decoded sequence of
        each instance.

        Denotation accuracy gets one update per world checked. Consistency is 1 only if the
        sequence is correct in every world. An instance with no decoded sequence counts as one
        incorrect denotation and one inconsistent instance.
        """
        # TODO(pradeep): Move this to the base class.
        # TODO(pradeep): Using only the best decoded sequence. Define metrics for top-k sequences?
        batch_size = len(worlds)
        for i in range(batch_size):
            instance_action_strings = action_strings[i]
            sequence_is_correct = [False]
            if instance_action_strings:
                instance_label_strings = label_strings[i]
                instance_worlds = worlds[i]
                # Taking only the best sequence.
                sequence_is_correct = self._check_denotation(
                        instance_action_strings[0], instance_label_strings,
                        instance_worlds)
            for correct_in_world in sequence_is_correct:
                self._denotation_accuracy(1 if correct_in_world else 0)
            self._consistency(1 if all(sequence_is_correct) else 0)

    # @overrides  (decorator left out in this variant)
    def get_metrics(self, reset=False):
        u"""Returns denotation accuracy and consistency, optionally resetting both metrics."""
        return {
                u'denotation_accuracy':
                self._denotation_accuracy.get_metric(reset),
                u'consistency': self._consistency.get_metric(reset)
        }