def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]):
    embedded_input = self._sentence_embedder(sentence)
    # (batch_size, sentence_length)
    sentence_mask = util.get_text_field_mask(sentence).float()

    batch_size = embedded_input.size(0)

    # (batch_size, sentence_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(embedded_input, sentence_mask))

    final_encoder_output = util.get_final_encoder_states(
        encoder_outputs, sentence_mask, self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
    attended_sentence, _ = self._decoder_step.attend_on_question(
        final_encoder_output, encoder_outputs, sentence_mask)
    encoder_outputs_list = [encoder_outputs[i] for i in range(batch_size)]
    sentence_mask_list = [sentence_mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(
            RnnStatelet(final_encoder_output[i],
                        memory_cell[i],
                        self._first_action_embedding,
                        attended_sentence[i],
                        encoder_outputs_list,
                        sentence_mask_list))
    return initial_rnn_state
def make_state(group_index: int,
               action: int,
               new_score: torch.Tensor,
               action_embedding: torch.Tensor) -> GrammarBasedState:
    new_rnn_state = RnnStatelet(
        hidden_state[group_index],
        memory_cell[group_index],
        action_embedding,
        attended_question[group_index],
        state.rnn_state[group_index].encoder_outputs,
        state.rnn_state[group_index].encoder_output_mask,
    )
    batch_index = state.batch_indices[group_index]
    for i, _, current_log_probs, _, actions in batch_action_probs[batch_index]:
        if i == group_index:
            considered_actions = actions
            probabilities = current_log_probs.exp().cpu()
            break
    return state.new_state_from_group_index(
        group_index,
        action,
        new_score,
        new_rnn_state,
        considered_actions,
        probabilities,
        updated_rnn_state["attention_weights"],
    )
def _get_initial_state(
        self,
        encoder_outputs: torch.Tensor,
        mask: torch.Tensor,
        actions: List[List[ProductionRule]]) -> GrammarBasedState:
    batch_size = encoder_outputs.size(0)

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(
        encoder_outputs, mask, self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
    initial_score = encoder_outputs.data.new_zeros(batch_size)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    utterance_mask_list = [mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(
            RnnStatelet(
                final_encoder_output[i],
                memory_cell[i],
                self._first_action_embedding,
                self._first_attended_utterance,
                encoder_output_list,
                utterance_mask_list,
            ))
    initial_grammar_state = [
        self._create_grammar_state(actions[i]) for i in range(batch_size)
    ]
    initial_state = GrammarBasedState(
        batch_indices=list(range(batch_size)),
        action_history=[[] for _ in range(batch_size)],
        score=initial_score_list,
        rnn_state=initial_rnn_state,
        grammar_state=initial_grammar_state,
        possible_actions=actions,
        debug_info=None,
    )
    return initial_state
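# A minimal, self-contained sketch (toy tensors and a hypothetical `group_indices`,
# not part of the parser above) illustrating why the batch dimension is unrolled into a
# Python list: once each instance owns its own (utterance_length, encoder_output_dim)
# tensor, an arbitrary "group" of decoder states can be reassembled with a plain
# torch.stack, with no index_select bookkeeping.
import torch

batch_size, utterance_length, encoder_output_dim = 4, 5, 3
encoder_outputs = torch.randn(batch_size, utterance_length, encoder_output_dim)
encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]

# Hypothetical group of instances still alive in the beam after pruning.
group_indices = [2, 0, 3]
grouped = torch.stack([encoder_output_list[i] for i in group_indices])
assert grouped.shape == (len(group_indices), utterance_length, encoder_output_dim)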
def _get_initial_rnn_state(
        self,
        question_encoded: torch.FloatTensor,
        question_mask: torch.Tensor,
        question_encoded_finalstate: torch.FloatTensor,
        question_encoded_aslist: List[torch.Tensor],
        question_mask_aslist: List[torch.Tensor],
):
    """
    Get the initial RnnStatelet for the decoder based on the question encoding.

    Parameters
    ----------
    question_encoded : ``(B, question_length, D)``
    question_mask : ``(B, question_length)``
    question_encoded_finalstate : ``(B, D)``
    question_encoded_aslist : List of length B of ``(question_length, D)`` tensors
    question_mask_aslist : List of length B of ``(question_length,)`` tensors
    """
    batch_size = question_encoded_finalstate.size(0)
    ques_encoded_dim = question_encoded_finalstate.size()[-1]

    # Shape: (B, D)
    memory_cell = question_encoded_finalstate.new_zeros(batch_size, ques_encoded_dim)
    # TODO(nitish): Why does WikiTablesParser use '_first_attended_question' embedding and not this
    attended_sentence, _ = self._decoder_step.attend_on_question(
        question_encoded_finalstate, question_encoded, question_mask)
    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(
            RnnStatelet(
                question_encoded_finalstate[i],
                memory_cell[i],
                self._first_action_embedding,
                attended_sentence[i],
                question_encoded_aslist,
                question_mask_aslist,
            ))
    return initial_rnn_state
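# Conceptual sketch only: this is NOT the TransitionFunction's attend_on_question
# implementation, just a toy illustration of what such a call produces — masked
# attention weights over the question tokens and their weighted sum as the attended
# representation. All tensors below are invented for the example.
import torch

query = torch.randn(2, 6)                 # (batch, D) final encoder state
question_encoded = torch.randn(2, 5, 6)   # (batch, question_length, D)
question_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)

# (batch, question_length) dot-product scores, with padded positions masked out.
scores = torch.bmm(question_encoded, query.unsqueeze(-1)).squeeze(-1)
scores = scores.masked_fill(~question_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)

# (batch, D) attended question representation.
attended = torch.bmm(weights.unsqueeze(1), question_encoded).squeeze(1)
assert attended.shape == (2, 6)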
def _get_initial_rnn_and_grammar_state(
        self,
        question: Dict[str, torch.LongTensor],
        table: Dict[str, torch.LongTensor],
        world: List[WikiTablesLanguage],
        actions: List[List[ProductionRuleArray]],
        outputs: Dict[str, Any],
) -> Tuple[List[RnnStatelet], List[GrammarStatelet]]:
    """
    Encodes the question and table, computes a linking between the two, and constructs an
    initial RnnStatelet and GrammarStatelet for each batch instance to pass to the decoder.

    We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
    visualize in a demo.
    """
    table_text = table["text"]
    # (batch_size, question_length, embedding_dim)
    embedded_question = self._question_embedder(question)
    question_mask = util.get_text_field_mask(question)
    # (batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
    table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1)

    batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
    num_question_tokens = embedded_question.size(1)

    # (batch_size, num_entities, embedding_dim)
    encoded_table = self._entity_encoder(embedded_table, table_mask)

    # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
    # entity's type id.
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

    entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)

    # (batch_size, num_entities, num_neighbors) or None
    neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

    if neighbor_indices is not None:
        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(
            encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask(
            {"ignored": {"ignored": neighbor_indices + 1}}, num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(
            BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())

        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)
    else:
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings)

    # Compute entity and question word similarity. We tried using cosine distance here, but
    # because this similarity is the main mechanism that the model can use to push apart logit
    # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
    # output range than [-1, 1].
    question_entity_similarity = torch.bmm(
        embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim),
        torch.transpose(embedded_question, 1, 2),
    )

    question_entity_similarity = question_entity_similarity.view(
        batch_size, num_entities, num_entity_tokens, num_question_tokens)

    # (batch_size, num_entities, num_question_tokens)
    question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

    # (batch_size, num_entities, num_question_tokens, num_features)
    linking_features = table["linking"]

    linking_scores = question_entity_similarity_max_score

    if self._use_neighbor_similarity_for_linking:
        # The linking score is computed as a linear projection of two terms. The first is the
        # maximum similarity score over the entity's words and the question token. The second
        # is the maximum similarity over the words in the entity's neighbors and the question
        # token.
        #
        # The second term, projected_question_neighbor_similarity, is useful when a column
        # needs to be selected. For example, the question token might have no similarity with
        # the column name, but is similar with the cells in the column.
        #
        # Note that projected_question_neighbor_similarity is intended to capture the same
        # information as the related_column feature.
        #
        # Also note that this block needs to be _before_ the `linking_params` block, because
        # we're overwriting `linking_scores`, not adding to it.

        # (batch_size, num_entities, num_neighbors, num_question_tokens)
        question_neighbor_similarity = util.batched_index_select(
            question_entity_similarity_max_score, torch.abs(neighbor_indices))
        # (batch_size, num_entities, num_question_tokens)
        question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
        projected_question_entity_similarity = self._question_entity_params(
            question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        projected_question_neighbor_similarity = self._question_neighbor_params(
            question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        linking_scores = (projected_question_entity_similarity
                          + projected_question_neighbor_similarity)

    feature_scores = None
    if self._linking_params is not None:
        feature_scores = self._linking_params(linking_features).squeeze(3)
        linking_scores = linking_scores + feature_scores

    # (batch_size, num_question_tokens, num_entities)
    linking_probabilities = self._get_linking_probabilities(
        world, linking_scores.transpose(1, 2), question_mask, entity_type_dict)

    # (batch_size, num_question_tokens, embedding_dim)
    link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
    encoder_input = torch.cat([link_embedding, embedded_question], 2)

    # (batch_size, question_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(
        encoder_outputs, question_mask, self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    question_mask_list = [question_mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(
            RnnStatelet(
                final_encoder_output[i],
                memory_cell[i],
                self._first_action_embedding,
                self._first_attended_question,
                encoder_output_list,
                question_mask_list,
            ))

    initial_grammar_state = [
        self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i])
        for i in range(batch_size)
    ]

    if not self.training:
        # We add a few things to the outputs that will be returned from `forward` at evaluation
        # time, for visualization in a demo.
        outputs["linking_scores"] = linking_scores
        if feature_scores is not None:
            outputs["feature_scores"] = feature_scores
        outputs["similarity_scores"] = question_entity_similarity_max_score

    return initial_rnn_state, initial_grammar_state
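# A toy, shape-only sketch (tensors invented here, not the trained WikiTables model)
# of the question-entity similarity computed with torch.bmm above: entity tokens are
# flattened into one axis, multiplied against the question embeddings, and the max
# over each entity's tokens gives one score per (entity, question token).
import torch

batch_size, num_entities, num_entity_tokens, num_question_tokens, dim = 2, 3, 4, 5, 6
embedded_table = torch.randn(batch_size, num_entities, num_entity_tokens, dim)
embedded_question = torch.randn(batch_size, num_question_tokens, dim)

similarity = torch.bmm(
    embedded_table.view(batch_size, num_entities * num_entity_tokens, dim),
    embedded_question.transpose(1, 2),
).view(batch_size, num_entities, num_entity_tokens, num_question_tokens)

# (batch_size, num_entities, num_question_tokens)
max_score, _ = similarity.max(dim=2)
assert max_score.shape == (batch_size, num_entities, num_question_tokens)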
def _get_initial_state(
        self,
        utterance: Dict[str, torch.LongTensor],
        worlds: List[AtisWorld],
        actions: List[List[ProductionRule]],
        linking_scores: torch.Tensor,
) -> GrammarBasedState:
    embedded_utterance = self._utterance_embedder(utterance)
    utterance_mask = util.get_text_field_mask(utterance)

    batch_size = embedded_utterance.size(0)
    num_entities = max([len(world.entities) for world in worlds])

    # entity_types: tensor with shape (batch_size, num_entities)
    entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance)

    # (batch_size, num_utterance_tokens, embedding_dim)
    encoder_input = embedded_utterance

    # (batch_size, utterance_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(
        encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
    initial_score = embedded_utterance.data.new_zeros(batch_size)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        if self._decoder_num_layers > 1:
            initial_rnn_state.append(
                RnnStatelet(
                    final_encoder_output[i].repeat(self._decoder_num_layers, 1),
                    memory_cell[i].repeat(self._decoder_num_layers, 1),
                    self._first_action_embedding,
                    self._first_attended_utterance,
                    encoder_output_list,
                    utterance_mask_list,
                ))
        else:
            initial_rnn_state.append(
                RnnStatelet(
                    final_encoder_output[i],
                    memory_cell[i],
                    self._first_action_embedding,
                    self._first_attended_utterance,
                    encoder_output_list,
                    utterance_mask_list,
                ))
    initial_grammar_state = [
        self._create_grammar_state(worlds[i], actions[i], linking_scores[i], entity_types[i])
        for i in range(batch_size)
    ]
    initial_state = GrammarBasedState(
        batch_indices=list(range(batch_size)),
        action_history=[[] for _ in range(batch_size)],
        score=initial_score_list,
        rnn_state=initial_rnn_state,
        grammar_state=initial_grammar_state,
        possible_actions=actions,
        debug_info=None,
    )
    return initial_state
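# Small sketch (assumed toy sizes, not the ATIS model itself) of why the final encoder
# state is repeated when the decoder has more than one layer: each decoder layer needs
# its own initial hidden state and memory cell, so a (hidden_dim,) vector is tiled
# into (decoder_num_layers, hidden_dim).
import torch

hidden_dim, decoder_num_layers = 8, 2
final_encoder_output_i = torch.randn(hidden_dim)

initial_hidden = final_encoder_output_i.repeat(decoder_num_layers, 1)
assert initial_hidden.shape == (decoder_num_layers, hidden_dim)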
def setUp(self):
    super().setUp()
    self.decoder_step = BasicTransitionFunction(
        encoder_output_dim=2,
        action_embedding_dim=2,
        input_attention=Attention.by_name("dot_product")(),
        add_action_bias=False,
    )

    batch_indices = [0, 1, 0]
    action_history = [[1], [3, 4], []]
    score = [torch.FloatTensor([x]) for x in [0.1, 1.1, 2.2]]
    hidden_state = torch.FloatTensor([[i, i] for i in range(len(batch_indices))])
    memory_cell = torch.FloatTensor([[i, i] for i in range(len(batch_indices))])
    previous_action_embedding = torch.FloatTensor([[i, i] for i in range(len(batch_indices))])
    attended_question = torch.FloatTensor([[i, i] for i in range(len(batch_indices))])

    # This maps non-terminals to valid actions, where the valid actions are grouped by _type_.
    # We have "global" actions, which are from the global grammar, and "linked" actions, which
    # are instance-specific and are generated based on question attention. Each action type
    # has a tuple which is (input representation, output representation, action ids).
    valid_actions = {
        "e": {
            "global": (
                torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                [0, 1, 2],
            ),
            "linked": (
                torch.FloatTensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
                torch.FloatTensor([[3, 3], [4, 4]]),
                [3, 4],
            ),
        },
        "d": {
            "global": (
                torch.FloatTensor([[0, 0]]),
                torch.FloatTensor([[-1, -1]]),
                [0],
            ),
            "linked": (
                torch.FloatTensor([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6], [-0.7, -0.8, -0.9]]),
                torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                [1, 2, 3],
            ),
        },
    }
    grammar_state = [
        GrammarStatelet([nonterminal], valid_actions, is_nonterminal)
        for _, nonterminal in zip(batch_indices, ["e", "d", "e"])
    ]
    self.encoder_outputs = torch.FloatTensor(
        [[[1, 2], [3, 4], [5, 6]], [[10, 11], [12, 13], [14, 15]]])
    self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])
    self.possible_actions = [
        [
            ("e -> f", False, None),
            ("e -> g", True, None),
            ("e -> h", True, None),
            ("e -> i", True, None),
            ("e -> j", True, None),
        ],
        [
            ("d -> q", True, None),
            ("d -> g", True, None),
            ("d -> h", True, None),
            ("d -> i", True, None),
        ],
    ]
    rnn_state = []
    for i in range(len(batch_indices)):
        rnn_state.append(
            RnnStatelet(
                hidden_state[i],
                memory_cell[i],
                previous_action_embedding[i],
                attended_question[i],
                self.encoder_outputs,
                self.encoder_output_mask,
            ))
    self.state = GrammarBasedState(
        batch_indices=batch_indices,
        action_history=action_history,
        score=score,
        rnn_state=rnn_state,
        grammar_state=grammar_state,
        possible_actions=self.possible_actions,
    )
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[QuarelWorld],
            actions: List[List[ProductionRule]],
            entity_bits: torch.Tensor = None,
            denotation_target: torch.Tensor = None,
            target_action_sequences: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    # pylint: disable=unused-argument
    """
    In this method we encode the table entities, link them to words in the question, then
    encode the question. Then we set up the initial state for the decoder, and pass that
    state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
    if we're not.

    Parameters
    ----------
    question : ``Dict[str, torch.LongTensor]``
        The output of ``TextField.as_array()`` applied on the question ``TextField``. This
        will be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder``
        to get embeddings for each entity.
    world : ``List[QuarelWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance. Because of
        how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
    actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and
        use the embeddings to determine which action to take at each timestep in the decoder.
    entity_bits : ``torch.Tensor``, optional (default=None)
        Tensor encoding bits for the world entities.
    denotation_target : ``torch.Tensor``, optional (default=None)
        If model's field ``denotation_only`` is True, this is the tensor target denotation.
    target_action_sequences : ``torch.Tensor``, optional (default=None)
        A list of possibly valid action sequences, where each action is an index into the
        list of possible actions. This tensor has shape ``(batch_size, num_action_sequences,
        sequence_length)``.
    metadata : ``List[Dict[str, Any]]``, optional (default=None)
        A dictionary of metadata for each batch element which has keys:
            question_tokens : ``List[str]``, optional.
                The original string tokens in the question.
            world_extractions : ``nltk.Tree``, optional.
                Extracted worlds from the question.
            answer_index : ``List[str]``, optional.
                Index of the correct answer.
    """
    table_text = table['text']

    self._debug_count -= 1

    # (batch_size, question_length, embedding_dim)
    embedded_question = self._question_embedder(question)
    question_mask = util.get_text_field_mask(question).float()
    num_question_tokens = embedded_question.size(1)

    # (batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
    batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

    # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

    if self._use_entities:
        if self._entity_similarity_mode == "dot_product":
            # Compute entity and question word cosine similarity. Need to add a small value
            # to the table norm since there are padding values which cause a divide by 0.
            embedded_table = embedded_table / (
                embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
            embedded_question = embedded_question / (
                embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
            question_entity_similarity = torch.bmm(
                embedded_table.view(batch_size,
                                    num_entities * num_entity_tokens,
                                    self._embedding_dim),
                torch.transpose(embedded_question, 1, 2))
            question_entity_similarity = question_entity_similarity.view(
                batch_size, num_entities, num_entity_tokens, num_question_tokens)
            # (batch_size, num_entities, num_question_tokens)
            question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
            linking_scores = question_entity_similarity_max_score
        elif self._entity_similarity_mode == "weighted_dot_product":
            embedded_table = embedded_table / (
                embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
            embedded_question = embedded_question / (
                embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
            eqe = embedded_question.unsqueeze(1).expand(
                -1, num_entities * num_entity_tokens, -1, -1)
            ete = embedded_table.view(batch_size,
                                      num_entities * num_entity_tokens,
                                      self._embedding_dim)
            ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
            product = torch.mul(eqe, ete)
            product = product.view(
                batch_size,
                num_question_tokens * num_entities * num_entity_tokens,
                self._embedding_dim)
            question_entity_similarity = self._entity_similarity_layer(product)
            question_entity_similarity = question_entity_similarity.view(
                batch_size, num_entities, num_entity_tokens, num_question_tokens)
            # (batch_size, num_entities, num_question_tokens)
            question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
            linking_scores = question_entity_similarity_max_score

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask, entity_type_dict)
        encoder_input = embedded_question
    else:
        if entity_bits is not None and not self._entity_bits_output:
            encoder_input = torch.cat([embedded_question, entity_bits], 2)
        else:
            encoder_input = embedded_question

        # Fake linking_scores added so that downstream code does not object.
        linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
        linking_probabilities = None

    # (batch_size, question_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

    if self._entity_bits_output and entity_bits is not None:
        encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(
        encoder_outputs, question_mask, self._encoder.is_bidirectional())

    # For predicting a categorical denotation directly
    if self._denotation_only:
        denotation_logits = self._denotation_classifier(final_encoder_output)
        loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1))
        self._denotation_accuracy_cat(denotation_logits, denotation_target)
        return {"loss": loss}

    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim)

    _, num_entities, num_question_tokens = linking_scores.size()

    if target_action_sequences is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    else:
        target_mask = None

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    question_mask_list = [question_mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                             memory_cell[i],
                                             self._first_action_embedding,
                                             self._first_attended_question,
                                             encoder_output_list,
                                             question_mask_list))

    initial_grammar_state = [self._create_grammar_state(world[i],
                                                        actions[i],
                                                        linking_scores[i],
                                                        entity_types[i])
                             for i in range(batch_size)]

    initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=initial_rnn_state,
                                      grammar_state=initial_grammar_state,
                                      possible_actions=actions,
                                      extras=None,
                                      debug_info=None)

    if self.training:
        outputs = self._decoder_trainer.decode(initial_state,
                                               self._decoder_step,
                                               (target_action_sequences, target_mask))
        return outputs
    else:
        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        outputs = {'action_mapping': action_mapping}
        if target_action_sequences is not None:
            outputs['loss'] = self._decoder_trainer.decode(
                initial_state,
                self._decoder_step,
                (target_action_sequences, target_mask))['loss']

        num_steps = self._max_decoding_steps
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._beam_search.search(num_steps,
                                                     initial_state,
                                                     self._decoder_step,
                                                     keep_final_unfinished_states=False)
        outputs['best_action_sequence'] = []
        outputs['debug_info'] = []
        outputs['entities'] = []
        if self._linking_params is not None:
            outputs['linking_scores'] = linking_scores
            outputs['feature_scores'] = feature_scores
            outputs['linking_features'] = linking_features
        if self._use_entities:
            outputs['linking_probabilities'] = linking_probabilities
        if entity_bits is not None:
            outputs['entity_bits'] = entity_bits
        # outputs['similarity_scores'] = question_entity_similarity_max_score
        outputs['logical_form'] = []
        outputs['denotation_acc'] = []
        outputs['score'] = []
        outputs['parse_acc'] = []
        outputs['answer_index'] = []
        if metadata is not None:
            outputs['question_tokens'] = []
            outputs['world_extractions'] = []
        for i in range(batch_size):
            if metadata is not None:
                outputs['question_tokens'].append(metadata[i].get('question_tokens', []))
            if metadata is not None:
                outputs['world_extractions'].append(metadata[i].get('world_extractions', {}))
            outputs['entities'].append(world[i].table_graph.entities)
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = best_final_states[i][0].action_history[0]
                sequence_in_targets = 0
                if target_action_sequences is not None:
                    targets = target_action_sequences[i].data
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._action_sequence_accuracy(sequence_in_targets)
                action_strings = [action_mapping[(i, action_index)]
                                  for action_index in best_action_indices]
                try:
                    self._has_logical_form(1.0)
                    logical_form = world[i].get_logical_form(action_strings,
                                                             add_var_function=False)
                except ParsingError:
                    self._has_logical_form(0.0)
                    logical_form = 'Error producing logical form'
                denotation_accuracy = 0.0
                predicted_answer_index = world[i].execute(logical_form)
                if metadata is not None and 'answer_index' in metadata[i]:
                    answer_index = metadata[i]['answer_index']
                    denotation_accuracy = self._denotation_match(predicted_answer_index,
                                                                 answer_index)
                    self._denotation_accuracy(denotation_accuracy)
                score = math.exp(best_final_states[i][0].score[0].data.cpu().item())
                outputs['answer_index'].append(predicted_answer_index)
                outputs['score'].append(score)
                outputs['parse_acc'].append(sequence_in_targets)
                outputs['best_action_sequence'].append(action_strings)
                outputs['logical_form'].append(logical_form)
                outputs['denotation_acc'].append(denotation_accuracy)
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
            else:
                outputs['parse_acc'].append(0)
                outputs['logical_form'].append('')
                outputs['denotation_acc'].append(0)
                outputs['score'].append(0)
                outputs['answer_index'].append(-1)
                outputs['best_action_sequence'].append([])
                outputs['debug_info'].append([])
                self._has_logical_form(0.0)
        return outputs
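# Hedged sketch of the "dot_product" entity-similarity mode above, with toy tensors
# rather than the model's embeddings: normalizing both sides by their L2 norm (plus a
# small epsilon guarding all-zero padding rows) turns the batched dot product into a
# cosine similarity bounded in [-1, 1].
import torch

table = torch.randn(2, 3, 4)      # (batch, entity_tokens, dim), possibly with padding rows
question = torch.randn(2, 5, 4)   # (batch, question_tokens, dim)

table = table / (table.norm(dim=-1, keepdim=True) + 1e-13)
question = question / (question.norm(dim=-1, keepdim=True) + 1e-13)

# (batch, entity_tokens, question_tokens) of values in [-1, 1]
cosine = torch.bmm(table, question.transpose(1, 2))
assert cosine.abs().max() <= 1.0 + 1e-5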