def _get_initial_state(
        self,
        encoder_outputs: torch.Tensor,
        utterance_mask: torch.Tensor,
        actions: List[List[ProductionRule]],
) -> GrammarBasedState:
    """
    Build the initial decoder state for a batch.

    The decoder hidden state starts from the final encoder state, the memory cell starts at
    zero, and the first attended sentence vector is obtained by attending from the final
    encoder state over all encoder outputs.

    Parameters
    ----------
    encoder_outputs : ``torch.Tensor``
        Encoder outputs, ``(batch_size, utterance_length, encoder_output_dim)``. The memory cell
        width is taken from the last dimension.
    utterance_mask : ``torch.Tensor``
        Mask over utterance tokens, used to pick the final encoder states and for attention.
    actions : ``List[List[ProductionRule]]``
        The possible actions for each batch element (``actions[i]`` builds instance ``i``'s
        grammar state).

    Returns
    -------
    ``GrammarBasedState``
        One group element per batch instance, with zero scores, empty action histories and
        debug tracking enabled (an empty list per instance).
    """
    batch_size = encoder_outputs.size(0)
    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, encoder_outputs.shape[-1])
    # `new_zeros` already returns a fresh tensor outside the autograd graph; `.data` is not needed.
    initial_score = encoder_outputs.new_zeros(batch_size)
    attended_sentence, _ = self._transition_function.attend_on_question(
            final_encoder_output, encoder_outputs, utterance_mask)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]

    initial_rnn_state = []
    for i in range(batch_size):
        if self._decoder_num_layers > 1:
            # Stack one copy of the initial hidden state / cell per decoder layer.
            hidden = final_encoder_output[i].repeat(self._decoder_num_layers, 1)
            cell = memory_cell[i].repeat(self._decoder_num_layers, 1)
        else:
            hidden = final_encoder_output[i]
            cell = memory_cell[i]
        initial_rnn_state.append(
                RnnStatelet(
                        hidden,
                        cell,
                        self._first_action_embedding,
                        attended_sentence[i],
                        encoder_output_list,
                        utterance_mask_list,
                ))

    initial_grammar_state = [
            self._create_grammar_state(actions[i]) for i in range(batch_size)
    ]
    initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=[[] for _ in range(batch_size)],
    )
    return initial_state
def _get_initial_state(self,
                       utterance: Dict[str, torch.LongTensor],
                       worlds: List[AtisWorld],
                       actions: List[List[ProductionRuleArray]],
                       linking_scores: torch.Tensor) -> GrammarBasedState:
    """
    Embed and encode the utterance, then assemble the decoder's starting ``GrammarBasedState``.

    Parameters
    ----------
    utterance : ``Dict[str, torch.LongTensor]``
        Text-field tensors for the utterance. They are embedded with
        ``self._utterance_embedder`` and masked with ``util.get_text_field_mask``.
    worlds : ``List[AtisWorld]``
        One world per batch element. The largest ``len(world.entities)`` fixes the padded
        entity count.
    actions : ``List[List[ProductionRuleArray]]``
        Possible actions for each batch element.
    linking_scores : ``torch.Tensor``
        Indexed per batch element and handed to ``_create_grammar_state``.

    Returns
    -------
    ``GrammarBasedState``
        Zero scores and empty histories; ``debug_info`` is ``None``.
    """
    embedded_utterance = self._utterance_embedder(utterance)
    utterance_mask = util.get_text_field_mask(utterance).float()
    batch_size = embedded_utterance.size(0)

    # Pad the entity dimension up to the world with the most entities in this batch.
    num_entities = max(len(world.entities) for world in worlds)
    # entity_types: tensor with shape (batch_size, num_entities)
    entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance)

    # The embedded utterance is fed straight to the encoder:
    # (batch_size, utterance_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(embedded_utterance, utterance_mask))

    # Decoder LSTM starts from the final encoder state and an all-zero memory cell.
    final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
    initial_score = embedded_utterance.data.new_zeros(batch_size)

    # Split every batched tensor into a per-instance list so the decoder can group states
    # with plain `torch.cat()`s instead of index selects.
    score_per_instance = [initial_score[b] for b in range(batch_size)]
    encoder_output_list = [encoder_outputs[b] for b in range(batch_size)]
    utterance_mask_list = [utterance_mask[b] for b in range(batch_size)]

    rnn_states = [
            RnnStatelet(final_encoder_output[b],
                        memory_cell[b],
                        self._first_action_embedding,
                        self._first_attended_utterance,
                        encoder_output_list,
                        utterance_mask_list)
            for b in range(batch_size)
    ]
    grammar_states = [
            self._create_grammar_state(worlds[b], actions[b], linking_scores[b], entity_types[b])
            for b in range(batch_size)
    ]

    return GrammarBasedState(batch_indices=list(range(batch_size)),
                             action_history=[[] for _ in range(batch_size)],
                             score=score_per_instance,
                             rnn_state=rnn_states,
                             grammar_state=grammar_states,
                             possible_actions=actions,
                             debug_info=None)
def _check_state_denotations(self, state: GrammarBasedState, worlds: List[NlvrWorld]) -> List[bool]:
    """
    Return whether the action history in ``state`` evaluates to the correct denotations over
    all ``worlds``.

    Only defined for finished states. The label strings come from ``state.extras`` at the
    state's batch index, and the comparison is delegated to ``self._check_denotation``.
    """
    assert state.is_finished(), "Cannot compute denotations for unfinished states!"
    # A finished state has a group size of 1, so every per-group field is read at index 0.
    label_strings_for_instance = state.extras[state.batch_indices[0]]
    candidate_actions = state.possible_actions[0]
    # Map each chosen action index to element 0 of its production rule.
    chosen_rules = [candidate_actions[index][0] for index in state.action_history[0]]
    return self._check_denotation(chosen_rules, label_strings_for_instance, worlds)
def _check_state_denotations(self,
                             state: GrammarBasedState,
                             worlds: List[NlvrWorld]) -> List[bool]:
    """
    Evaluate a finished state's action history against every world.

    The result of ``self._check_denotation`` says whether the history produces the correct
    denotations over all ``worlds``; the expected labels are ``state.extras`` at the state's
    batch index. Calling this on an unfinished state fails the assertion below.
    """
    assert state.is_finished(
    ), "Cannot compute denotations for unfinished states!"
    # Finished states hold exactly one group element, hence the [0] lookups.
    group_history = state.action_history[0]
    group_actions = state.possible_actions[0]
    expected_labels = state.extras[state.batch_indices[0]]

    rule_sequence = []
    for action_index in group_history:
        rule_sequence.append(group_actions[action_index][0])
    return self._check_denotation(rule_sequence, expected_labels, worlds)
def _get_initial_state(self,
                       encoder_outputs: torch.Tensor,
                       mask: torch.Tensor,
                       actions: List[List[ProductionRule]]) -> GrammarBasedState:
    """
    Turn batched encoder outputs into the decoder's starting ``GrammarBasedState``.

    Each instance gets the final encoder state as its initial hidden state, a zero memory cell
    sized by ``self._encoder.get_output_dim()``, and ``self._first_attended_utterance`` as its
    initial attended input. Scores start at zero and ``debug_info`` is ``None``.

    Parameters
    ----------
    encoder_outputs : ``torch.Tensor``
        ``(batch_size, utterance_length, encoder_output_dim)`` encoder outputs.
    mask : ``torch.Tensor``
        Utterance mask aligned with ``encoder_outputs``.
    actions : ``List[List[ProductionRule]]``
        Possible actions per batch element.
    """
    batch_size = encoder_outputs.size(0)
    bidirectional = self._encoder.is_bidirectional()
    final_encoder_output = util.get_final_encoder_states(encoder_outputs, mask, bidirectional)
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
    zero_scores = encoder_outputs.data.new_zeros(batch_size)

    # Per-instance lists let the decoder regroup states with `torch.cat()` rather than
    # index selects.
    instances = range(batch_size)
    encoder_output_list = [encoder_outputs[b] for b in instances]
    utterance_mask_list = [mask[b] for b in instances]

    rnn_states = [
            RnnStatelet(
                    final_encoder_output[b],
                    memory_cell[b],
                    self._first_action_embedding,
                    self._first_attended_utterance,
                    encoder_output_list,
                    utterance_mask_list,
            )
            for b in instances
    ]
    grammar_states = [self._create_grammar_state(actions[b]) for b in instances]

    return GrammarBasedState(
            batch_indices=list(instances),
            action_history=[[] for _ in instances],
            score=[zero_scores[b] for b in instances],
            rnn_state=rnn_states,
            grammar_state=grammar_states,
            possible_actions=actions,
            debug_info=None,
    )
def _compute_action_probabilities(
        self,
        state: GrammarBasedState,
        hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    """
    Score the valid "global" actions of every group element against its predicted action
    embedding.

    ``hidden_state`` and ``attention_weights`` are unused here but kept for subclasses.
    Each group element is handled separately (no batching or action masks needed).

    Returns a mapping from batch index to a list of
    ``(group_index, total_log_probs, step_log_probs, output_action_embeddings, action_ids)``,
    where ``total_log_probs`` adds the state's running score. The beam sorts on that total,
    not on the current step's score alone.
    """
    # pylint: disable=unused-argument,no-self-use
    valid_actions = state.get_valid_actions()
    batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index, batch_index in enumerate(state.batch_indices):
        action_embeddings, output_action_embeddings, action_ids = valid_actions[group_index]['global']
        # (num_actions, embedding_dim) @ (embedding_dim, 1) -> (num_actions,)
        query = predicted_action_embeddings[group_index].unsqueeze(-1)
        action_logits = action_embeddings.mm(query).squeeze(-1)
        step_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)
        total_log_probs = state.score[group_index] + step_log_probs
        batch_results[batch_index].append(
                (group_index, total_log_probs, step_log_probs, output_action_embeddings, action_ids))
    return batch_results
def _compute_action_probabilities(self,
                                  state: GrammarBasedState,
                                  hidden_state: torch.Tensor,
                                  attention_weights: torch.Tensor,
                                  predicted_action_embeddings: torch.Tensor
                                 ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    """
    Compute per-action log-probabilities for every group element using only its "global"
    actions: a dot product between each action embedding and the predicted action embedding,
    followed by a log-softmax.

    The extra ``hidden_state`` / ``attention_weights`` arguments are accepted for subclasses
    and ignored here. The result maps each batch index to tuples of
    ``(group_index, cumulative_log_probs, step_log_probs, output_action_embeddings, action_ids)``.
    ``cumulative_log_probs`` includes ``state.score`` because later sorting needs the full
    score.
    """
    # pylint: disable=unused-argument,no-self-use
    all_valid_actions = state.get_valid_actions()
    results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for position in range(len(state.batch_indices)):
        embeddings, next_step_inputs, ids = all_valid_actions[position]['global']
        logits = embeddings.mm(predicted_action_embeddings[position].unsqueeze(-1)).squeeze(-1)
        step_log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        entry = (position, state.score[position] + step_log_probs, step_log_probs,
                 next_step_inputs, ids)
        results[state.batch_indices[position]].append(entry)
    return results
def _compute_action_probabilities(
        self,
        state: GrammarBasedState,
        hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    """
    Score every valid action for each group element. Two kinds of action are combined.

    * "global" actions are scored by a dot product between their embeddings and the predicted
      action embedding.
    * "linked" actions are scored by multiplying their ``linking_scores`` with the group
      element's attention weights over question tokens.

    Group elements are processed one at a time, so no action masks are required.

    Returns a mapping from batch index to a list of
    ``(group_index, total_log_probs, step_log_probs, next_step_inputs, action_ids)``.

    * ``total_log_probs`` adds the state's running score, because the beam sorts on it.
    * ``next_step_inputs`` holds the global output embeddings followed by the linked
      entities' type embeddings, in the same order as ``action_ids``.

    NOTE(review): when an instance has only linked actions and ``self._mixture_feedforward`` is
    set, the step log-probabilities are ``log_softmax(linked) + log(mixture_weight)``. Unless
    the mixture weight is 1, they do not sum to one in probability space. Confirm this is
    intended.
    """
    log_softmax = torch.nn.functional.log_softmax
    valid_actions = state.get_valid_actions()
    batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index, batch_index in enumerate(state.batch_indices):
        instance_actions = valid_actions[group_index]
        predicted = predicted_action_embeddings[group_index]

        global_action_ids: List[int] = []
        global_logits = None
        next_step_inputs = None
        if 'global' in instance_actions:
            action_embeddings, next_step_inputs, global_action_ids = instance_actions['global']
            # (num_actions, embedding_dim) @ (embedding_dim, 1) -> (num_actions,)
            global_logits = action_embeddings.mm(predicted.unsqueeze(-1)).squeeze(-1)
            action_ids = global_action_ids

        if 'linked' not in instance_actions:
            step_log_probs = log_softmax(global_logits, dim=-1)
        else:
            linking_scores, type_embeddings, linked_action_ids = instance_actions['linked']
            action_ids = global_action_ids + linked_action_ids
            # (num_linked_actions, num_question_tokens) @ (num_question_tokens, 1)
            linked_logits = linking_scores.mm(
                    attention_weights[group_index].unsqueeze(-1)).squeeze(-1)

            # Linked actions have no action embedding; their entity type embedding is what
            # gets fed to the next decoder step instead.
            if next_step_inputs is None:
                next_step_inputs = type_embeddings
            else:
                next_step_inputs = torch.cat([next_step_inputs, type_embeddings], dim=0)

            if self._mixture_feedforward is None:
                if global_logits is None:
                    joint_logits = linked_logits
                else:
                    joint_logits = torch.cat([global_logits, linked_logits], dim=-1)
                step_log_probs = log_softmax(joint_logits, dim=-1)
            else:
                # Normalise global and linked actions separately and blend them with a
                # predicted mixture weight, so linked logits cannot dominate a joint softmax.
                mixture_weight = self._mixture_feedforward(hidden_state[group_index])
                linked_log_probs = log_softmax(linked_logits, dim=-1) + torch.log(mixture_weight)
                if global_logits is None:
                    step_log_probs = linked_log_probs
                else:
                    global_log_probs = (log_softmax(global_logits, dim=-1)
                                        + torch.log(1 - mixture_weight))
                    step_log_probs = torch.cat([global_log_probs, linked_log_probs], dim=-1)

        total_log_probs = state.score[group_index] + step_log_probs
        batch_results[batch_index].append(
                (group_index, total_log_probs, step_log_probs, next_step_inputs, action_ids))
    return batch_results
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[QuarelWorld],
            actions: List[List[ProductionRule]],
            entity_bits: torch.Tensor = None,
            denotation_target: torch.Tensor = None,
            target_action_sequences: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    # pylint: disable=unused-argument
    """
    In this method we encode the table entities, link them to words in the question, then
    encode the question. Then we set up the initial state for the decoder, and pass that state
    off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
    if we're not.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
       The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
       be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
        get embeddings for each entity.
    world : ``List[QuarelWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
        how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
    actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    entity_bits : ``torch.Tensor``, optional (default=None)
        Tensor encoding bits for the world entities.
    denotation_target : ``torch.Tensor``, optional (default=None)
        If model's field ``denotation_only`` is True, this is the tensor target denotation.
    target_action_sequences : torch.Tensor, optional (default=None)
       A list of possibly valid action sequences, where each action is an index into the list
       of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
       sequence_length)``.
    metadata : List[Dict[str, Any]], optional (default=None).
        A dictionary of metadata for each batch element which has keys:
            question_tokens : ``List[str]``, optional.
                The original string tokens in the question.
            world_extractions : ``nltk.Tree``, optional.
                Extracted worlds from the question.
            answer_index : ``List[str]``, optional.
                Index of the correct answer.

    Raises
    ------
    ValueError
        If entities are used and ``self._entity_similarity_mode`` is neither
        ``"dot_product"`` nor ``"weighted_dot_product"``.
    """
    table_text = table['text']

    self._debug_count -= 1

    # (batch_size, question_length, embedding_dim)
    embedded_question = self._question_embedder(question)
    question_mask = util.get_text_field_mask(question).float()
    num_question_tokens = embedded_question.size(1)

    # (batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)

    batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

    # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

    if self._use_entities:
        if self._entity_similarity_mode not in ("dot_product", "weighted_dot_product"):
            # Previously this surfaced later as an UnboundLocalError on `linking_scores`.
            raise ValueError(
                    f"Unknown entity_similarity_mode: {self._entity_similarity_mode!r}")

        # Both modes work on unit-normalised embeddings. Add a small value to the norm since
        # padding values would otherwise cause a divide by 0.
        embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
        embedded_question = embedded_question / (
                embedded_question.norm(dim=-1, keepdim=True) + 1e-13)

        if self._entity_similarity_mode == "dot_product":
            # Cosine similarity between every entity token and every question token.
            question_entity_similarity = torch.bmm(
                    embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                        self._embedding_dim),
                    torch.transpose(embedded_question, 1, 2))
        else:
            # "weighted_dot_product": form the elementwise product for every
            # (entity token, question token) pair and score it with
            # `self._entity_similarity_layer` instead of summing it.
            eqe = embedded_question.unsqueeze(1).expand(
                    -1, num_entities * num_entity_tokens, -1, -1)
            ete = embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                      self._embedding_dim)
            ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
            product = torch.mul(eqe, ete)
            product = product.view(batch_size,
                                   num_question_tokens * num_entities * num_entity_tokens,
                                   self._embedding_dim)
            question_entity_similarity = self._entity_similarity_layer(product)

        question_entity_similarity = question_entity_similarity.view(
                batch_size, num_entities, num_entity_tokens, num_question_tokens)
        # Best-matching entity token per (entity, question token):
        # (batch_size, num_entities, num_question_tokens)
        linking_scores, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
                world, linking_scores.transpose(1, 2), question_mask, entity_type_dict)
        encoder_input = embedded_question
    else:
        if entity_bits is not None and not self._entity_bits_output:
            encoder_input = torch.cat([embedded_question, entity_bits], 2)
        else:
            encoder_input = embedded_question

        # Fake linking_scores added for downstream code to not object
        linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
        linking_probabilities = None

    # (batch_size, question_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))
    if self._entity_bits_output and entity_bits is not None:
        encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())

    # For predicting a categorical denotation directly
    if self._denotation_only:
        denotation_logits = self._denotation_classifier(final_encoder_output)
        loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1))
        self._denotation_accuracy_cat(denotation_logits, denotation_target)
        return {"loss": loss}

    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim)

    if target_action_sequences is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    else:
        target_mask = None

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list.  For instance, the encoder outputs have shape
    # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
    # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    question_mask_list = [question_mask[i] for i in range(batch_size)]
    initial_rnn_state = [
            RnnStatelet(final_encoder_output[i],
                        memory_cell[i],
                        self._first_action_embedding,
                        self._first_attended_question,
                        encoder_output_list,
                        question_mask_list)
            for i in range(batch_size)
    ]
    initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i])
            for i in range(batch_size)
    ]

    initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=initial_rnn_state,
                                      grammar_state=initial_grammar_state,
                                      possible_actions=actions,
                                      extras=None,
                                      debug_info=None)

    if self.training:
        return self._decoder_trainer.decode(initial_state,
                                            self._decoder_step,
                                            (target_action_sequences, target_mask))

    action_mapping = {(batch_index, action_index): action[0]
                      for batch_index, batch_actions in enumerate(actions)
                      for action_index, action in enumerate(batch_actions)}
    outputs = {'action_mapping': action_mapping}
    if target_action_sequences is not None:
        outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                       self._decoder_step,
                                                       (target_action_sequences, target_mask))['loss']

    num_steps = self._max_decoding_steps
    # This tells the state to start keeping track of debug info, which we'll pass along in
    # our output dictionary.
    initial_state.debug_info = [[] for _ in range(batch_size)]
    best_final_states = self._beam_search.search(num_steps,
                                                 initial_state,
                                                 self._decoder_step,
                                                 keep_final_unfinished_states=False)
    outputs['best_action_sequence'] = []
    outputs['debug_info'] = []
    outputs['entities'] = []
    # `feature_scores` and `linking_features` only exist when entity linking ran; checking
    # `_linking_params` alone raised NameError when `_use_entities` was False.
    if self._use_entities and self._linking_params is not None:
        outputs['linking_scores'] = linking_scores
        outputs['feature_scores'] = feature_scores
        outputs['linking_features'] = linking_features
    if self._use_entities:
        outputs['linking_probabilities'] = linking_probabilities
    if entity_bits is not None:
        outputs['entity_bits'] = entity_bits
    outputs['logical_form'] = []
    outputs['denotation_acc'] = []
    outputs['score'] = []
    outputs['parse_acc'] = []
    outputs['answer_index'] = []
    if metadata is not None:
        outputs['question_tokens'] = []
        outputs['world_extractions'] = []

    for i in range(batch_size):
        if metadata is not None:
            outputs['question_tokens'].append(metadata[i].get('question_tokens', []))
            outputs['world_extractions'].append(metadata[i].get('world_extractions', {}))
        outputs['entities'].append(world[i].table_graph.entities)

        # Decoding may not have terminated with any completed logical forms, if `num_steps`
        # isn't long enough (or if the model is not trained enough and gets into an
        # infinite action loop).
        if i not in best_final_states:
            outputs['parse_acc'].append(0)
            outputs['logical_form'].append('')
            outputs['denotation_acc'].append(0)
            outputs['score'].append(0)
            outputs['answer_index'].append(-1)
            outputs['best_action_sequence'].append([])
            outputs['debug_info'].append([])
            self._has_logical_form(0.0)
            continue

        best_state = best_final_states[i][0]
        best_action_indices = best_state.action_history[0]
        sequence_in_targets = 0
        if target_action_sequences is not None:
            targets = target_action_sequences[i].data
            sequence_in_targets = self._action_history_match(best_action_indices, targets)
            self._action_sequence_accuracy(sequence_in_targets)

        action_strings = [action_mapping[(i, action_index)]
                          for action_index in best_action_indices]
        try:
            logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
        except ParsingError:
            self._has_logical_form(0.0)
            logical_form = 'Error producing logical form'
        else:
            # Record success only after the logical form was actually produced. Recording 1.0
            # before the call counted every parse failure as both a success and a failure.
            self._has_logical_form(1.0)

        denotation_accuracy = 0.0
        predicted_answer_index = world[i].execute(logical_form)
        if metadata is not None and 'answer_index' in metadata[i]:
            answer_index = metadata[i]['answer_index']
            denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index)
            self._denotation_accuracy(denotation_accuracy)
        score = math.exp(best_state.score[0].data.cpu().item())
        outputs['answer_index'].append(predicted_answer_index)
        outputs['score'].append(score)
        outputs['parse_acc'].append(sequence_in_targets)
        outputs['best_action_sequence'].append(action_strings)
        outputs['logical_form'].append(logical_form)
        outputs['denotation_acc'].append(denotation_accuracy)
        outputs['debug_info'].append(best_state.debug_info[0])  # type: ignore
    return outputs
def forward(self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRule]],
            identifier: List[str] = None,
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing type constrained target sequences, trained to maximize marginal
    likelihood over a set of approximate logical forms.

    When ``target_action_sequences`` is given, the decoder trainer's outputs (e.g. ``loss``) are
    merged into the returned dict. ``identifier`` is always preserved when provided. Outside
    training, beam search is run as well. With targets, metrics are updated; without them,
    the predicted action strings, denotations, per-instance debug info and the action mapping
    are returned.
    """
    batch_size = len(worlds)

    initial_rnn_state = self._get_initial_rnn_state(sentence)
    initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                          for _ in range(batch_size)]
    label_strings = self._get_label_strings(labels) if labels is not None else None
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in
                             range(batch_size)]

    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=initial_rnn_state,
                                      grammar_state=initial_grammar_state,
                                      possible_actions=actions,
                                      extras=label_strings)

    if target_action_sequences is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    else:
        target_mask = None

    outputs: Dict[str, torch.Tensor] = {}
    if identifier is not None:
        outputs["identifier"] = identifier
    if target_action_sequences is not None:
        # Merge rather than rebind, so the "identifier" entry set above is not discarded.
        outputs.update(self._decoder_trainer.decode(initial_state,
                                                    self._decoder_step,
                                                    (target_action_sequences, target_mask)))
    if not self.training:
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                             initial_state,
                                                             self._decoder_step,
                                                             keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = [best_final_states[i][0].action_history[0]]
                best_action_sequences[i] = best_action_indices
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            if metadata is not None:
                outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
            # Instances without a finished state get empty debug info instead of a KeyError,
            # mirroring the `i in best_final_states` guard above.
            outputs['debug_info'] = [
                    best_final_states[i][0].debug_info[0]  # type: ignore
                    if i in best_final_states else []
                    for i in range(batch_size)
            ]
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs['action_mapping'] = action_mapping
    return outputs
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRule]],
            example_lisp_string: List[str] = None,
            target_action_sequences: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    In this method we encode the table entities, link them to words in the question, then
    encode the question. Then we set up the initial state for the decoder, and pass that
    state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
    if we're not.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
       The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
       be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
        get embeddings for each entity.
    world : ``List[WikiTablesWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
        how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
    actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    example_lisp_string : ``List[str]``, optional (default = None)
        The example (lisp-formatted) string corresponding to the given input.  This comes
        directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
        when evaluating denotation accuracy; it is otherwise unused.
    target_action_sequences : torch.Tensor, optional (default = None)
       A list of possibly valid action sequences, where each action is an index into the list
       of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
       sequence_length)``.
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Metadata containing the original tokenized question within a 'question_tokens' key.

    Returns
    -------
    ``Dict[str, torch.Tensor]``
        During training, the decoder trainer's output. Otherwise the ``outputs`` dictionary that
        is handed to ``_get_initial_rnn_and_grammar_state`` and ``_compute_validation_outputs``
        (presumably populated by them), with a ``'loss'`` entry when targets are given.
    """
    outputs: Dict[str, Any] = {}
    rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                       table,
                                                                       world,
                                                                       actions,
                                                                       outputs)
    batch_size = len(rnn_state)
    initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),  # type: ignore
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=rnn_state,
                                      grammar_state=grammar_state,
                                      possible_actions=actions,
                                      extras=example_lisp_string,
                                      debug_info=None)

    if target_action_sequences is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    else:
        target_mask = None

    if self.training:
        return self._decoder_trainer.decode(initial_state,
                                            self._decoder_step,
                                            (target_action_sequences, target_mask))

    if target_action_sequences is not None:
        outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                       self._decoder_step,
                                                       (target_action_sequences, target_mask))['loss']
    num_steps = self._max_decoding_steps
    # This tells the state to start keeping track of debug info, which we'll pass along in
    # our output dictionary.
    initial_state.debug_info = [[] for _ in range(batch_size)]
    best_final_states = self._beam_search.search(num_steps,
                                                 initial_state,
                                                 self._decoder_step,
                                                 keep_final_unfinished_states=False)
    if target_action_sequences is not None:
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop); such instances are not fed to the accuracy metric.
            if i not in best_final_states:
                continue
            best_action_indices = best_final_states[i][0].action_history[0]
            # ``.data`` gives a detached view of the targets.
            targets = target_action_sequences[i].data
            sequence_in_targets = self._action_history_match(best_action_indices, targets)
            self._action_sequence_accuracy(sequence_in_targets)
    self._compute_validation_outputs(actions,
                                     best_final_states,
                                     world,
                                     example_lisp_string,
                                     metadata,
                                     outputs)
    return outputs
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesVariableFreeWorld],
            actions: List[List[ProductionRuleArray]],
            target_values: List[List[str]] = None,
            target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    In this method we encode the table entities, link them to words in the question, then
    encode the question. Then we set up the initial state for the decoder, and pass that
    state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
    if we're not.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
       The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
       be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
        get embeddings for each entity.
    world : ``List[WikiTablesVariableFreeWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
        how ``MetadataField`` works, this gets passed to us as a
        ``List[WikiTablesVariableFreeWorld]``.
    actions : ``List[List[ProductionRuleArray]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    target_values : ``List[List[str]]``, optional (default = None)
        For each instance, a list of target values taken from the example lisp string. We pass
        this list to the evaluator along with logical forms to compute denotation accuracy.
    target_action_sequences : torch.Tensor, optional (default = None)
       A list of possibly valid action sequences, where each action is an index into the list
       of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
       sequence_length)``.

    Returns
    -------
    ``Dict[str, torch.Tensor]``
        During training, the decoder trainer's output. Otherwise the ``outputs`` dictionary that
        is handed to ``_get_initial_rnn_and_grammar_state`` and ``_compute_validation_outputs``
        (presumably populated by them), with a ``'loss'`` entry when targets are given.
    """
    outputs: Dict[str, Any] = {}
    rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                       table,
                                                                       world,
                                                                       actions,
                                                                       outputs)
    batch_size = len(rnn_state)
    initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),  # type: ignore
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=rnn_state,
                                      grammar_state=grammar_state,
                                      possible_actions=actions,
                                      extras=target_values,
                                      debug_info=None)

    if target_action_sequences is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    else:
        target_mask = None

    if self.training:
        return self._decoder_trainer.decode(initial_state,
                                            self._decoder_step,
                                            (target_action_sequences, target_mask))

    if target_action_sequences is not None:
        outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                       self._decoder_step,
                                                       (target_action_sequences, target_mask))['loss']
    num_steps = self._max_decoding_steps
    # This tells the state to start keeping track of debug info, which we'll pass along in
    # our output dictionary.
    initial_state.debug_info = [[] for _ in range(batch_size)]
    best_final_states = self._beam_search.search(num_steps,
                                                 initial_state,
                                                 self._decoder_step,
                                                 keep_final_unfinished_states=False)
    if target_action_sequences is not None:
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop); such instances are not fed to the accuracy metric.
            if i not in best_final_states:
                continue
            best_action_indices = best_final_states[i][0].action_history[0]
            # ``.data`` gives a detached view of the targets.
            targets = target_action_sequences[i].data
            sequence_in_targets = self._action_history_match(best_action_indices, targets)
            self._action_sequence_accuracy(sequence_in_targets)
    # This model receives no metadata, so None is passed in its place.
    metadata = None
    self._compute_validation_outputs(actions,
                                     best_final_states,
                                     world,
                                     target_values,
                                     metadata,
                                     outputs)
    return outputs
def _take_first_step(self,
                     state: GrammarBasedState,
                     allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:
    """
    Score the start-type actions for every group element of ``state`` and expand them.

    The current hidden state (initialised from the final encoder output) is projected by
    ``self._start_type_predictor`` onto the start types and log-softmax normalised; those log
    probabilities become the action scores. Some logic from ``_compute_new_states`` is
    duplicated here because this first step is handled slightly differently.

    Parameters
    ----------
    state : ``GrammarBasedState``
        The initial (grouped) decoder state.
    allowed_actions : ``List[Set[int]]``, optional
        Per group element, the action ids a decoder trainer wants evaluated. Other actions are
        skipped, because building a new state can be expensive.

    Returns
    -------
    ``List[GrammarBasedState]``
        One new state per kept (group element, start action) pair, ordered by batch index, then
        group index, then descending log probability.

    Raises
    ------
    RuntimeError
        If the first group element's global valid actions do not number
        ``self._num_start_types``.
    """
    # (group_size, hidden_dim)
    stacked_hidden = torch.stack([statelet.hidden_state for statelet in state.rnn_state])
    # (group_size, num_start_type)
    logits = self._start_type_predictor(stacked_hidden)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    ranked_log_probs, ranked_positions = log_probs.sort(dim=-1, descending=True)
    ranked_positions = ranked_positions.detach().cpu().numpy().tolist()

    # Probabilities are only materialised on the CPU when debug info is being tracked.
    probs_cpu = (log_probs.exp().detach().cpu().numpy().tolist()
                 if state.debug_info is not None
                 else [None] * len(state.batch_indices))

    # get_valid_actions() is consistently sorted, so as long as the set of valid start actions
    # never changes, column k of `log_probs` lines up with position k of these id lists.
    start_action_ids = [valid['global'][2] for valid in state.get_valid_actions()]
    num_found = len(start_action_ids[0])
    if num_found != self._num_start_types:
        raise RuntimeError("Calculated wrong number of initial actions. Expected "
                           f"{self._num_start_types}, found {num_found}.")

    candidates_by_batch: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
    for group_index, (batch_index, positions) in enumerate(zip(state.batch_indices,
                                                               ranked_positions)):
        for rank, position in enumerate(positions):
            # `position` indexes a column of `log_probs`; translate it into the real action id.
            action_id = start_action_ids[group_index][position]
            if allowed_actions is not None and action_id not in allowed_actions[group_index]:
                # The decoder trainer only wants certain (likely gold) actions evaluated here.
                continue
            candidates_by_batch[batch_index].append((group_index, rank, action_id))

    new_states = []
    for batch_index in sorted(candidates_by_batch):
        for group_index, rank, action_id in candidates_by_batch[batch_index]:
            # Each emitted state has a group size of 1; the search regroups them later.
            score = state.score[group_index] + ranked_log_probs[group_index, rank]
            # The RNN state is passed through unchanged: predicting the start type was not
            # part of the decoder RNN in the original model.
            new_states.append(state.new_state_from_group_index(group_index,
                                                               action_id,
                                                               score,
                                                               state.rnn_state[group_index],
                                                               start_action_ids[group_index],
                                                               probs_cpu[group_index],
                                                               None))
    return new_states
def setUp(self):
    """
    Build the shared fixture: a ``BasicTransitionFunction`` with 2-d encoder outputs and action
    embeddings (three start types, no action bias), plus a three-element ``GrammarBasedState``
    over batch indices ``[0, 1, 0]``.
    """
    super().setUp()
    self.decoder_step = BasicTransitionFunction(
        encoder_output_dim=2,
        action_embedding_dim=2,
        input_attention=Attention.by_name('dot_product')(),
        num_start_types=3,
        add_action_bias=False)

    batch_indices = [0, 1, 0]
    group_size = len(batch_indices)
    action_history = [[1], [3, 4], []]
    score = [torch.FloatTensor([value]) for value in (.1, 1.1, 2.2)]

    def _row_index_tensor():
        # Row k is [k, k]; every call returns a fresh, independent tensor.
        return torch.FloatTensor([[row, row] for row in range(group_size)])

    hidden_state = _row_index_tensor()
    memory_cell = _row_index_tensor()
    previous_action_embedding = _row_index_tensor()
    attended_question = _row_index_tensor()

    # Valid actions per non-terminal, grouped by type: "global" actions come from the global
    # grammar, "linked" ones are instance-specific. Each entry is a tuple of
    # (input representation, output representation, action ids).
    global_e = (torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                [0, 1, 2])
    linked_e = (torch.FloatTensor([[.1, .2, .3], [.4, .5, .6]]),
                torch.FloatTensor([[3, 3], [4, 4]]),
                [3, 4])
    global_d = (torch.FloatTensor([[0, 0]]),
                torch.FloatTensor([[-1, -1]]),
                [0])
    linked_d = (torch.FloatTensor([[-.1, -.2, -.3], [-.4, -.5, -.6], [-.7, -.8, -.9]]),
                torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                [1, 2, 3])
    valid_actions = {
        'e': {'global': global_e, 'linked': linked_e},
        'd': {'global': global_d, 'linked': linked_d},
    }
    # All statelets share the same valid_actions dict but get fresh empty dicts otherwise.
    grammar_state = [GrammarStatelet([nonterminal], {}, valid_actions, {}, is_nonterminal)
                     for nonterminal in ('e', 'd', 'e')]

    self.encoder_outputs = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]],
                                              [[10, 11], [12, 13], [14, 15]]])
    self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])
    first_instance_actions = [('e -> f', False, None),
                              ('e -> g', True, None),
                              ('e -> h', True, None),
                              ('e -> i', True, None),
                              ('e -> j', True, None)]
    second_instance_actions = [('d -> q', True, None),
                               ('d -> g', True, None),
                               ('d -> h', True, None),
                               ('d -> i', True, None)]
    self.possible_actions = [first_instance_actions, second_instance_actions]

    rnn_state = [RnnStatelet(hidden_state[k],
                             memory_cell[k],
                             previous_action_embedding[k],
                             attended_question[k],
                             self.encoder_outputs,
                             self.encoder_output_mask)
                 for k in range(group_size)]
    self.state = GrammarBasedState(batch_indices=batch_indices,
                                   action_history=action_history,
                                   score=score,
                                   rnn_state=rnn_state,
                                   grammar_state=grammar_state,
                                   possible_actions=self.possible_actions)
def forward(
        self,  # type: ignore
        sentence: Dict[str, torch.LongTensor],
        worlds: List[List[NlvrLanguage]],
        actions: List[List[ProductionRule]],
        identifier: List[str] = None,
        target_action_sequences: torch.LongTensor = None,
        labels: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing type constrained target sequences, trained to maximize marginal
    likelihood over a set of approximate logical forms.

    Parameters
    ----------
    sentence : ``Dict[str, torch.LongTensor]``
        Indexed sentence tokens; used to build the initial RNN state.
    worlds : ``List[List[NlvrLanguage]]``
        For each instance, a list of worlds. Only the first world of each instance is used to
        build its grammar state; the full lists are used for denotations and metrics.
    actions : ``List[List[ProductionRule]]``
        All possible actions for each instance. Element ``[0]`` of each rule is its string form.
    identifier : ``List[str]``, optional
        Copied into the output under ``"identifier"`` when given.
    target_action_sequences : ``torch.LongTensor``, optional
        Candidate action sequences, padded with ``self._action_padding_index``. When given, the
        decoder trainer's output (e.g. the loss) is merged into the result.
    labels : ``torch.LongTensor``, optional
        Converted with ``self._get_label_strings``; attached to the initial state and used when
        updating metrics.
    metadata : ``List[Dict[str, Any]]``, optional
        When given and there are no targets, each entry's ``"sentence_tokens"`` is returned.

    Returns
    -------
    ``Dict[str, torch.Tensor]``
        The identifier (if given), the decoder trainer's output (if targets are given) and,
        outside training, the best action strings, denotations and action mapping (plus debug
        info when there are no targets).
    """
    batch_size = len(worlds)
    initial_rnn_state = self._get_initial_rnn_state(sentence)
    initial_score_list = [
        next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
        for _ in range(batch_size)
    ]
    label_strings = self._get_label_strings(
        labels) if labels is not None else None
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = [
        self._create_grammar_state(worlds[i][0], actions[i])
        for i in range(batch_size)
    ]
    initial_state = GrammarBasedState(
        batch_indices=list(range(batch_size)),
        action_history=[[] for _ in range(batch_size)],
        score=initial_score_list,
        rnn_state=initial_rnn_state,
        grammar_state=initial_grammar_state,
        possible_actions=actions,
        extras=label_strings)

    if target_action_sequences is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    else:
        target_mask = None

    outputs: Dict[str, Any] = {}
    if identifier is not None:
        outputs["identifier"] = identifier
    if target_action_sequences is not None:
        # Merge instead of rebinding so the "identifier" entry above is kept.
        outputs.update(
            self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask)))

    if not self.training:
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._decoder_beam_search.search(
            self._max_decoding_steps,
            initial_state,
            self._decoder_step,
            keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_sequences[i] = [
                    best_final_states[i][0].action_history[0]
                ]
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            if metadata is not None:
                outputs["sentence_tokens"] = [
                    x["sentence_tokens"] for x in metadata
                ]
            # Instances without a finished decode get an empty list instead of a KeyError.
            outputs['debug_info'] = [
                best_final_states[i][0].debug_info[0]  # type: ignore
                if i in best_final_states else []
                for i in range(batch_size)
            ]
        outputs["best_action_strings"] = batch_action_strings
        outputs["denotations"] = batch_denotations
        outputs['action_mapping'] = {
            (batch_index, action_index): action[0]
            for batch_index, batch_actions in enumerate(actions)
            for action_index, action in enumerate(batch_actions)
        }
    return outputs
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[QuarelWorld],
            actions: List[List[ProductionRule]],
            entity_bits: torch.Tensor = None,
            denotation_target: torch.Tensor = None,
            target_action_sequences: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    # pylint: disable=unused-argument
    """
    In this method we encode the table entities, link them to words in the question, then
    encode the question. Then we set up the initial state for the decoder, and pass that
    state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
    if we're not.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
       The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
       be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
        get embeddings for each entity. Its ``'text'`` entry is embedded; its ``'linking'``
        entry supplies linking features when entities are used.
    world : ``List[QuarelWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
        how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
    actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    entity_bits : ``torch.Tensor``, optional (default=None)
        Extra per-token features. Without entity linking, and when ``self._entity_bits_output``
        is false, they are concatenated to the question embedding. When
        ``self._entity_bits_output`` is true they are concatenated to the encoder outputs
        instead. They are also echoed in the validation outputs.
    denotation_target : ``torch.Tensor``, optional (default=None)
        Only used when ``self._denotation_only`` is set: the cross-entropy target for the
        direct denotation classifier.
    target_action_sequences : torch.Tensor, optional (default=None)
       A list of possibly valid action sequences, where each action is an index into the list
       of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
       sequence_length)``.
    metadata : ``List[Dict[str, Any]]``, optional (default=None)
        Per-instance metadata. ``'question_tokens'`` and ``'world_extractions'`` are copied to
        the validation outputs. When present, ``'answer_index'`` is compared with the executed
        logical form to compute denotation accuracy.

    Returns
    -------
    ``Dict[str, torch.Tensor]``
        ``{"loss": ...}`` in denotation-only mode, and the decoder trainer's output during
        training. Otherwise, per-instance decoding results (action sequences, logical forms,
        scores, accuracies, debug info) together with the action mapping.

    Raises
    ------
    ValueError
        If entities are used and ``self._entity_similarity_mode`` is neither
        ``"dot_product"`` nor ``"weighted_dot_product"``.
    """
    table_text = table['text']
    self._debug_count -= 1

    # (batch_size, question_length, embedding_dim)
    embedded_question = self._question_embedder(question)
    question_mask = util.get_text_field_mask(question).float()
    num_question_tokens = embedded_question.size(1)

    # (batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
    batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

    # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

    if self._use_entities:
        if self._entity_similarity_mode not in ("dot_product", "weighted_dot_product"):
            raise ValueError(f"Unknown entity similarity mode: {self._entity_similarity_mode}")
        # Normalise to compute cosine similarity. The small constant avoids dividing by zero
        # on padding positions.
        embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
        embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
        flat_table = embedded_table.view(batch_size,
                                         num_entities * num_entity_tokens,
                                         self._embedding_dim)
        if self._entity_similarity_mode == "dot_product":
            question_entity_similarity = torch.bmm(flat_table,
                                                   torch.transpose(embedded_question, 1, 2))
        else:
            eqe = embedded_question.unsqueeze(1).expand(-1, num_entities * num_entity_tokens, -1, -1)
            ete = flat_table.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
            product = torch.mul(eqe, ete)
            product = product.view(batch_size,
                                   num_question_tokens * num_entities * num_entity_tokens,
                                   self._embedding_dim)
            question_entity_similarity = self._entity_similarity_layer(product)
        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)
        # (batch_size, num_entities, num_question_tokens)
        linking_scores, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world,
                                                                linking_scores.transpose(1, 2),
                                                                question_mask,
                                                                entity_type_dict)
        encoder_input = embedded_question
    else:
        if entity_bits is not None and not self._entity_bits_output:
            encoder_input = torch.cat([embedded_question, entity_bits], 2)
        else:
            encoder_input = embedded_question
        # Fake linking_scores added for downstream code to not object
        linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
        linking_probabilities = None

    # (batch_size, question_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))
    if self._entity_bits_output and entity_bits is not None:
        encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                         question_mask,
                                                         self._encoder.is_bidirectional())
    # For predicting a categorical denotation directly
    if self._denotation_only:
        denotation_logits = self._denotation_classifier(final_encoder_output)
        loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1))
        self._denotation_accuracy_cat(denotation_logits, denotation_target)
        return {"loss": loss}

    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim)

    if target_action_sequences is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    else:
        target_mask = None

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    question_mask_list = [question_mask[i] for i in range(batch_size)]
    initial_rnn_state = [RnnStatelet(final_encoder_output[i],
                                     memory_cell[i],
                                     self._first_action_embedding,
                                     self._first_attended_question,
                                     encoder_output_list,
                                     question_mask_list)
                         for i in range(batch_size)]
    initial_grammar_state = [self._create_grammar_state(world[i],
                                                        actions[i],
                                                        linking_scores[i],
                                                        entity_types[i])
                             for i in range(batch_size)]
    initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=initial_rnn_state,
                                      grammar_state=initial_grammar_state,
                                      possible_actions=actions,
                                      extras=None,
                                      debug_info=None)

    if self.training:
        return self._decoder_trainer.decode(initial_state,
                                            self._decoder_step,
                                            (target_action_sequences, target_mask))

    action_mapping = {(batch_index, action_index): action[0]
                      for batch_index, batch_actions in enumerate(actions)
                      for action_index, action in enumerate(batch_actions)}
    outputs: Dict[str, Any] = {'action_mapping': action_mapping}
    if target_action_sequences is not None:
        outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                       self._decoder_step,
                                                       (target_action_sequences, target_mask))['loss']
    num_steps = self._max_decoding_steps
    # This tells the state to start keeping track of debug info, which we'll pass along in
    # our output dictionary.
    initial_state.debug_info = [[] for _ in range(batch_size)]
    best_final_states = self._beam_search.search(num_steps,
                                                 initial_state,
                                                 self._decoder_step,
                                                 keep_final_unfinished_states=False)
    outputs['best_action_sequence'] = []
    outputs['debug_info'] = []
    outputs['entities'] = []
    # `feature_scores` and `linking_features` are only bound on the entity-linking path,
    # so both flags must hold before they are reported.
    if self._use_entities and self._linking_params is not None:
        outputs['linking_scores'] = linking_scores
        outputs['feature_scores'] = feature_scores
        outputs['linking_features'] = linking_features
    if self._use_entities:
        outputs['linking_probabilities'] = linking_probabilities
    if entity_bits is not None:
        outputs['entity_bits'] = entity_bits
    outputs['logical_form'] = []
    outputs['denotation_acc'] = []
    outputs['score'] = []
    outputs['parse_acc'] = []
    outputs['answer_index'] = []
    if metadata is not None:
        outputs['question_tokens'] = []
        outputs['world_extractions'] = []

    for i in range(batch_size):
        if metadata is not None:
            outputs['question_tokens'].append(metadata[i].get('question_tokens', []))
            outputs['world_extractions'].append(metadata[i].get('world_extractions', {}))
        outputs['entities'].append(world[i].table_graph.entities)
        # Decoding may not have terminated with any completed logical forms, if `num_steps`
        # isn't long enough (or if the model is not trained enough and gets into an
        # infinite action loop).
        if i not in best_final_states:
            outputs['parse_acc'].append(0)
            outputs['logical_form'].append('')
            outputs['denotation_acc'].append(0)
            outputs['score'].append(0)
            outputs['answer_index'].append(-1)
            outputs['best_action_sequence'].append([])
            outputs['debug_info'].append([])
            self._has_logical_form(0.0)
            continue

        best_state = best_final_states[i][0]
        best_action_indices = best_state.action_history[0]
        sequence_in_targets = 0
        if target_action_sequences is not None:
            targets = target_action_sequences[i].data
            sequence_in_targets = self._action_history_match(best_action_indices, targets)
            self._action_sequence_accuracy(sequence_in_targets)
        action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
        # Record exactly one has-logical-form observation per instance: 1.0 only when parsing
        # succeeds.
        try:
            logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
        except ParsingError:
            self._has_logical_form(0.0)
            logical_form = 'Error producing logical form'
        else:
            self._has_logical_form(1.0)
        denotation_accuracy = 0.0
        predicted_answer_index = world[i].execute(logical_form)
        if metadata is not None and 'answer_index' in metadata[i]:
            answer_index = metadata[i]['answer_index']
            denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index)
            self._denotation_accuracy(denotation_accuracy)
        # The beam score is a log probability; report it as a probability.
        score = math.exp(best_state.score[0].data.cpu().item())
        outputs['answer_index'].append(predicted_answer_index)
        outputs['score'].append(score)
        outputs['parse_acc'].append(sequence_in_targets)
        outputs['best_action_sequence'].append(action_strings)
        outputs['logical_form'].append(logical_form)
        outputs['denotation_acc'].append(denotation_accuracy)
        outputs['debug_info'].append(best_state.debug_info[0])  # type: ignore
    return outputs
def forward(
        self,
        question: Dict[str, torch.LongTensor],
        question_predicates,
        world: List[LCQuADLanguage],
        actions: List[List[ProductionRule]],
        question_entities=None,
        target_action_sequences: torch.LongTensor = None,
        labels: torch.LongTensor = None,
        logical_forms=None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing type constrained target sequences, trained to maximize marginal
    likelihood over a set of approximate logical forms.

    Parameters
    ----------
    question : ``Dict[str, torch.LongTensor]``
        Indexed question tokens; must contain a ``'tokens'`` entry whose first dimension is
        the batch size.
    question_predicates, question_entities, labels
        Accepted for interface compatibility; unused in this method.
    world : ``List[LCQuADLanguage]``
        One world per instance, used to build the grammar states and for metrics.
    actions : ``List[List[ProductionRule]]``
        All possible actions for each instance. Element ``[0]`` of each rule is its string form.
    target_action_sequences : ``torch.LongTensor``
        Required. Candidate action sequences, padded with ``self._action_padding_index``.
    logical_forms : optional
        Gold logical forms passed to ``self._update_seq_metrics``.

    Returns
    -------
    ``Dict[str, torch.Tensor]``
        The decoder trainer's output. Outside training it also contains
        ``"predicted queries"`` (best action strings) and ``"predicted_actions"``, which holds
        per-step considered actions, probabilities and question attention.

    Raises
    ------
    ValueError
        If ``target_action_sequences`` is None.
    """
    if target_action_sequences is None:
        raise ValueError("target_action_sequences is required")
    batch_size = question['tokens'].size()[0]
    # Remove the trailing dimension (from ListField[ListField[IndexField]]).
    target_action_sequences = target_action_sequences.squeeze(-1)
    target_mask = target_action_sequences != self._action_padding_index

    initial_rnn_state = self._get_initial_rnn_state(question)
    initial_score_list = [
        next(iter(question.values())).new_zeros(1, dtype=torch.float)
        for _ in range(batch_size)
    ]
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_statelet = [
        self._create_grammar_state(world[i], actions[i])
        for i in range(batch_size)
    ]
    initial_state = GrammarBasedState(
        batch_indices=list(range(batch_size)),
        action_history=[[] for _ in range(batch_size)],
        score=initial_score_list,
        rnn_state=initial_rnn_state,
        grammar_state=initial_grammar_statelet,
        possible_actions=actions)

    outputs = self._decoder_trainer.decode(
        initial_state, self._decoder_step,
        (target_action_sequences, target_mask))

    if not self.training:
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._decoder_beam_search.search(
            self._max_decoding_steps,
            initial_state,
            self._decoder_step,
            keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        debug_infos = []
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop). Such instances get empty debug info.
            if i in best_final_states:
                best_state = best_final_states[i][0]
                best_action_sequences[i] = [best_state.action_history[0]]
                debug_infos.append(best_state.debug_info[0])
            else:
                debug_infos.append([])
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)

        action_mapping = {
            (batch_index, action_index): action[0]
            for batch_index, batch_actions in enumerate(actions)
            for action_index, action in enumerate(batch_actions)
        }
        self._update_seq_metrics(action_strings=batch_action_strings,
                                 worlds=world,
                                 gold_logical_forms=logical_forms,
                                 train=self.training)
        outputs["predicted queries"] = batch_action_strings

        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(batch_action_strings, debug_infos)):
            # An instance without a finished decode has no predicted sequence.
            best_sequence = predicted_actions[0] if predicted_actions else []
            instance_action_info = []
            for predicted_action, action_debug_info in zip(best_sequence, debug_info):
                # Map considered action ids to strings, dropping padding ids (-1), and sort
                # the (action string, probability) pairs.
                ranked = sorted(
                    (action_mapping[(batch_index, action_id)], probability)
                    for action_id, probability in zip(
                        action_debug_info['considered_actions'],
                        action_debug_info['probabilities'])
                    if action_id != -1)
                considered_actions, probabilities = zip(*ranked) if ranked else ((), ())
                instance_action_info.append({
                    'predicted_action': predicted_action,
                    'considered_actions': considered_actions,
                    'action_probabilities': probabilities,
                    'question_attention': action_debug_info.get('question_attention', []),
                })
            batch_action_info.append(instance_action_info)
        outputs["predicted_actions"] = batch_action_info
    return outputs
def setUp(self):
    """
    Build the shared fixtures: a ``BasicTransitionFunction`` (``self.decoder_step``) and a
    grouped ``GrammarBasedState`` (``self.state``) holding three partial derivations drawn
    from a batch of two instances.
    """
    super().setUp()
    self.decoder_step = BasicTransitionFunction(
        encoder_output_dim=2,
        action_embedding_dim=2,
        input_attention=Attention.by_name("dot_product")(),
        add_action_bias=False,
    )

    # Group elements 0 and 2 come from batch instance 0; element 1 comes from instance 1.
    batch_indices = [0, 1, 0]
    action_history = [[1], [3, 4], []]
    score = [torch.FloatTensor([value]) for value in [0.1, 1.1, 2.2]]
    group_size = len(batch_indices)

    def index_rows():
        # (group_size, 2) tensor whose k-th row is [k, k]; a new tensor on every call.
        return torch.FloatTensor([[k, k] for k in range(group_size)])

    hidden_state = index_rows()
    memory_cell = index_rows()
    previous_action_embedding = index_rows()
    attended_question = index_rows()

    # Valid actions per non-terminal, grouped by _type_. "global" actions come from the global
    # grammar; "linked" actions are instance-specific and are generated based on question
    # attention. Every entry is a tuple of (input representation, output representation,
    # action ids).
    e_actions = {
        "global": (
            torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
            torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
            [0, 1, 2],
        ),
        "linked": (
            torch.FloatTensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
            torch.FloatTensor([[3, 3], [4, 4]]),
            [3, 4],
        ),
    }
    d_actions = {
        "global": (
            torch.FloatTensor([[0, 0]]),
            torch.FloatTensor([[-1, -1]]),
            [0],
        ),
        "linked": (
            torch.FloatTensor([[-0.1, -0.2, -0.3],
                               [-0.4, -0.5, -0.6],
                               [-0.7, -0.8, -0.9]]),
            torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
            [1, 2, 3],
        ),
    }
    valid_actions = {"e": e_actions, "d": d_actions}

    # One grammar statelet per group element, each starting from a single non-terminal.
    grammar_state = [
        GrammarStatelet([start_symbol], valid_actions, is_nonterminal)
        for start_symbol in ["e", "d", "e"]
    ]

    # (batch_size=2, utterance_length=3, dim=2); the mask hides the last token of instance 1.
    self.encoder_outputs = torch.FloatTensor([
        [[1, 2], [3, 4], [5, 6]],
        [[10, 11], [12, 13], [14, 15]],
    ])
    self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])

    instance_0_rules = [
        ("e -> f", False, None),
        ("e -> g", True, None),
        ("e -> h", True, None),
        ("e -> i", True, None),
        ("e -> j", True, None),
    ]
    instance_1_rules = [
        ("d -> q", True, None),
        ("d -> g", True, None),
        ("d -> h", True, None),
        ("d -> i", True, None),
    ]
    self.possible_actions = [instance_0_rules, instance_1_rules]

    # Every RNN statelet shares the same encoder outputs and mask objects.
    rnn_state = [
        RnnStatelet(
            hidden_state[k],
            memory_cell[k],
            previous_action_embedding[k],
            attended_question[k],
            self.encoder_outputs,
            self.encoder_output_mask,
        )
        for k in range(group_size)
    ]

    self.state = GrammarBasedState(
        batch_indices=batch_indices,
        action_history=action_history,
        score=score,
        rnn_state=rnn_state,
        grammar_state=grammar_state,
        possible_actions=self.possible_actions,
    )
def _take_first_step(self,
                     state: GrammarBasedState,
                     allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:
    """
    Expand the first decoding step, in which a start type is chosen.

    Instead of running the decoder RNN, the stacked hidden states of the group are passed
    through ``self._start_type_predictor`` and normalised with a log-softmax. Each candidate
    start action that is allowed becomes its own new state with a group size of 1. This
    duplicates a little of ``_compute_new_states``, but copying the parts we need is simpler
    than trying to re-use that code.

    Parameters
    ----------
    state : GrammarBasedState
        The grouped state to expand.
    allowed_actions : List[Set[int]], optional
        When given, ``allowed_actions[g]`` contains the action IDs that may be emitted for
        group element ``g``. Other candidates are skipped, e.g. when the trainer only wants
        gold actions evaluated, because constructing a state can be expensive.

    Returns
    -------
    List[GrammarBasedState]
        One state per emitted (group element, start action) pair. States are ordered by batch
        index; within a batch index they follow group order, then descending log-probability.

    Raises
    ------
    RuntimeError
        If the number of global actions valid for the first group element differs from
        ``self._num_start_types``.
    """
    # (group_size, hidden_dim)
    stacked_hidden = torch.stack([statelet.hidden_state for statelet in state.rnn_state])
    # (group_size, num_start_types)
    start_logits = self._start_type_predictor(stacked_hidden)
    log_probs = torch.nn.functional.log_softmax(start_logits, dim=-1)
    sorted_log_probs, sorted_positions = log_probs.sort(dim=-1, descending=True)
    ranked_positions = sorted_positions.detach().cpu().numpy().tolist()

    # Unsorted probabilities (same column order as `considered_actions`) are only materialised
    # on the CPU when debug info is being collected.
    if state.debug_info is None:
        probs_cpu = [None] * len(state.batch_indices)
    else:
        probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()

    # `get_valid_actions()` returns a consistently sorted list. As long as the set of valid
    # start actions never changes, column j of `log_probs` lines up with entry j of each
    # group element's global action IDs.
    considered_actions = [valid['global'][2] for valid in state.get_valid_actions()]
    found = len(considered_actions[0])
    if found != self._num_start_types:
        raise RuntimeError("Calculated wrong number of initial actions. Expected "
                           f"{self._num_start_types}, found {found}.")

    # batch index -> [(group index, rank in sorted order, action ID), ...]
    candidates: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
    for group_index, (batch_index, ranked) in enumerate(zip(state.batch_indices,
                                                            ranked_positions)):
        for rank, position in enumerate(ranked):
            # `position` indexes a column of `log_probs`, not an action ID; translate it
            # through `considered_actions`.
            action_id = considered_actions[group_index][position]
            if allowed_actions is None or action_id in allowed_actions[group_index]:
                candidates[batch_index].append((group_index, rank, action_id))

    new_states = []
    for _, picks in sorted(candidates.items()):
        for group_index, rank, action_id in picks:
            # Each emitted state has a group size of 1, so the search/trainer can decide how
            # many to keep and regroup them cheaply. Unlike `_compute_new_states`, the RNN
            # state is passed through unchanged: predicting the start type was not part of the
            # decoder RNN in the original model.
            new_states.append(
                state.new_state_from_group_index(
                    group_index,
                    action_id,
                    state.score[group_index] + sorted_log_probs[group_index, rank],
                    state.rnn_state[group_index],
                    considered_actions[group_index],
                    probs_cpu[group_index],
                    None,
                )
            )
    return new_states