def _get_predicted_embedding_addition( self, checklist_state: ChecklistState, action_ids: List[int], action_embeddings: torch.Tensor) -> torch.Tensor: """ Gets the embeddings of desired terminal actions yet to be produced by the decoder, and returns their sum for the decoder to add it to the predicted embedding to bias the prediction towards missing actions. """ # Our basic approach here will be to figure out which actions we want to bias, by doing # some fancy indexing work, then multiply the action embeddings by a mask for those # actions, and return the sum of the result. # Shape: (num_terminal_actions, 1). This is 1 if we still want to predict something on the # checklist, and 0 otherwise. checklist_balance = checklist_state.get_balance().clamp(min=0) # (num_terminal_actions, 1) actions_in_agenda = checklist_state.terminal_actions # (1, num_current_actions) action_id_tensor = checklist_balance.new(action_ids).long().unsqueeze( 0) # Shape: (num_terminal_actions, num_current_actions). Will have a value of 1 if the # terminal action i is our current action j, and a value of 0 otherwise. Because both sets # of actions are free of duplicates, there will be at most one non-zero value per current # action, and per terminal action. current_agenda_actions = ( actions_in_agenda == action_id_tensor).float() # Shape: (num_current_actions,). With the inner multiplication, we remove any current # agenda actions that are not in our checklist balance, then we sum over the terminal # action dimension, which will have a sum of at most one. So this will be a 0/1 tensor, # where a 1 means to encourage the current action in that position. actions_to_encourage = torch.sum(current_agenda_actions * checklist_balance, dim=0) # Shape: (action_embedding_dim,). This is the sum of the action embeddings that we want # the model to prefer. 
embedding_addition = torch.sum(action_embeddings * actions_to_encourage.unsqueeze(1), dim=0, keepdim=False) if self._add_action_bias: # If we're adding an action bias, the last dimension of the action embedding is a bias # weight. We don't want this addition to affect the bias (TODO(mattg): or do we?), so # we zero out that dimension here. embedding_addition[-1] = 0 return embedding_addition
def _get_linked_logits_addition( checklist_state: ChecklistState, action_ids: List[int], action_logits: torch.Tensor) -> torch.Tensor: """ Gets the logits of desired terminal actions yet to be produced by the decoder, and returns them for the decoder to add to the prior action logits, biasing the model towards predicting missing linked actions. """ # Our basic approach here will be to figure out which actions we want to bias, by doing # some fancy indexing work, then multiply the action embeddings by a mask for those # actions, and return the sum of the result. # Shape: (num_terminal_actions, 1). This is 1 if we still want to predict something on the # checklist, and 0 otherwise. checklist_balance = checklist_state.get_balance().clamp(min=0) # (num_terminal_actions, 1) actions_in_agenda = checklist_state.terminal_actions # (1, num_current_actions) action_id_tensor = checklist_balance.new(action_ids).long().unsqueeze( 0) # Shape: (num_terminal_actions, num_current_actions). Will have a value of 1 if the # terminal action i is our current action j, and a value of 0 otherwise. Because both sets # of actions are free of duplicates, there will be at most one non-zero value per current # action, and per terminal action. current_agenda_actions = ( actions_in_agenda == action_id_tensor).float() # Shape: (num_current_actions,). With the inner multiplication, we remove any current # agenda actions that are not in our checklist balance, then we sum over the terminal # action dimension, which will have a sum of at most one. So this will be a 0/1 tensor, # where a 1 means to encourage the current action in that position. actions_to_encourage = torch.sum(current_agenda_actions * checklist_balance, dim=0) # Shape: (num_current_actions,). This is the sum of the action embeddings that we want # the model to prefer. logit_addition = action_logits * actions_to_encourage return logit_addition
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRuleArray]],
            agenda: torch.LongTensor,
            example_lisp_string: List[str]) -> Dict[str, torch.Tensor]:  # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
       The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
       be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
        get embeddings for each entity.
    world : ``List[WikiTablesWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
        how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``,
    actions : ``List[List[ProductionRuleArray]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    agenda : ``torch.LongTensor``
        For each instance, the indices (into that instance's ``actions``) of the productions
        we would like the decoded sequence to contain.  Each instance's agenda is of size
        ``(agenda_size, 1)``; entries equal to ``-1`` are skipped when computing coverage.
    example_lisp_string : ``List[str]``
        The example (lisp-formatted) string corresponding to the given input.  This comes
        directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
        when evaluating denotation accuracy; it is otherwise unused.

    Returns
    -------
    Dict[str, torch.Tensor]
        Whatever ``self._decoder_trainer.decode`` returns.  When the model is not in training
        mode, the keys ``action_mapping``, ``entities``, ``linking_scores``,
        ``feature_scores`` (only if not ``None``), ``similarity_scores``, ``logical_form``,
        ``best_action_sequence`` and ``debug_info`` are added, and the logical-form,
        denotation-accuracy and agenda-coverage metrics are updated.
    """
    batch_size = list(question.values())[0].size(0)
    # Each instance's agenda is of size (agenda_size, 1)
    agenda_list = [agenda[i] for i in range(batch_size)]
    checklist_states = []
    all_terminal_productions = [set(instance_world.terminal_productions.values())
                                for instance_world in world]
    max_num_terminals = max(len(terminals) for terminals in all_terminal_productions)
    for instance_actions, instance_agenda, terminal_productions in zip(actions,
                                                                       agenda_list,
                                                                       all_terminal_productions):
        checklist_info = self._get_checklist_info(instance_agenda,
                                                  instance_actions,
                                                  terminal_productions,
                                                  max_num_terminals)
        checklist_target, terminal_actions, checklist_mask = checklist_info
        # Nothing has been produced yet, so every checklist starts out as all zeros.
        initial_checklist = checklist_target.new_zeros(checklist_target.size())
        checklist_states.append(ChecklistState(terminal_actions=terminal_actions,
                                               checklist_target=checklist_target,
                                               checklist_mask=checklist_mask,
                                               checklist=initial_checklist))
    initial_info = self._get_initial_state_and_scores(question=question,
                                                      table=table,
                                                      world=world,
                                                      actions=actions,
                                                      example_lisp_string=example_lisp_string,
                                                      add_world_to_initial_state=True,
                                                      checklist_states=checklist_states)
    initial_state = initial_info["initial_state"]
    # TODO(pradeep): Keep track of debug info. It's not straightforward currently because the
    # ERM's decode does not return the best states.
    outputs = self._decoder_trainer.decode(initial_state,
                                           self._decoder_step,
                                           self._get_state_cost)

    if not self.training:
        # TODO(pradeep): Can move most of this block to super class.
        linking_scores = initial_info["linking_scores"]
        feature_scores = initial_info["feature_scores"]
        similarity_scores = initial_info["similarity_scores"]
        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        outputs['action_mapping'] = action_mapping
        outputs['entities'] = []
        outputs['linking_scores'] = linking_scores
        if feature_scores is not None:
            outputs['feature_scores'] = feature_scores
        outputs['similarity_scores'] = similarity_scores
        outputs['logical_form'] = []
        best_action_sequences = outputs['best_action_sequences']
        outputs["best_action_sequence"] = []
        outputs['debug_info'] = []
        agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda]
        for i in range(batch_size):
            in_agenda_ratio = 0.0
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            outputs['logical_form'].append([])
            if i in best_action_sequences:
                # Action strings of the first (best) decoded sequence.  Agenda coverage is
                # measured against these, like the other metrics below, rather than against
                # whichever sequence the loop happened to visit last.
                best_action_strings = []
                for j, action_sequence in enumerate(best_action_sequences[i]):
                    action_strings = [action_mapping[(i, action_index)]
                                      for action_index in action_sequence]
                    try:
                        logical_form = world[i].get_logical_form(action_strings,
                                                                 add_var_function=False)
                        outputs['logical_form'][-1].append(logical_form)
                    except ParsingError:
                        logical_form = "Error producing logical form"
                    if j == 0:
                        best_action_strings = action_strings
                        # Updating denotation accuracy and has_logical_form only based on the
                        # first logical form.
                        if logical_form.startswith("Error"):
                            self._has_logical_form(0.0)
                        else:
                            self._has_logical_form(1.0)
                        if example_lisp_string:
                            self._denotation_accuracy(logical_form, example_lisp_string[i])
                        outputs['best_action_sequence'].append(action_strings)
                outputs['entities'].append(world[i].table_graph.entities)
                instance_possible_actions = actions[i]
                agenda_actions = []
                for rule_id in agenda_indices[i]:
                    rule_id = int(rule_id)
                    if rule_id == -1:
                        continue
                    action_string = instance_possible_actions[rule_id][0]
                    agenda_actions.append(action_string)
                actions_in_agenda = [action in best_action_strings for action in agenda_actions]
                if actions_in_agenda:
                    # Note: This means that when there are no actions on agenda, agenda coverage
                    # will be 0, not 1.
                    in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda)
            else:
                outputs['best_action_sequence'].append([])
                outputs['logical_form'][-1].append('')
                self._has_logical_form(0.0)
                if example_lisp_string:
                    self._denotation_accuracy(None, example_lisp_string[i])
            self._agenda_coverage(in_agenda_ratio)
    return outputs
def forward(self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRuleArray]],
            agenda: torch.LongTensor,
            identifier: List[str] = None,
            labels: torch.LongTensor = None,
            epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:  # pylint: disable=arguments-differ
    """
    Decoder logic for producing type constrained target sequences that maximize coverage of
    their respective agendas, and minimize a denotation based loss.

    Parameters
    ----------
    sentence : ``Dict[str, torch.LongTensor]``
        Input used to build the initial RNN state; its first value is also used to create
        the initial score tensors.
    worlds : ``List[List[NlvrWorld]]``
        The worlds for each instance; the batch size is ``len(worlds)``.  Only the first
        world of each instance is used to create the grammar state.
    actions : ``List[List[ProductionRuleArray]]``
        All possible actions for each instance.
    agenda : ``torch.LongTensor``
        Per-instance agenda; each instance's agenda is of size ``(agenda_size, 1)``.
    identifier : ``List[str]``, optional
        If given, copied to the output under the ``identifier`` key.
    labels : ``torch.LongTensor``, optional
        If given, we are training or validating and metrics are updated.  If ``None``, we
        are testing and ``best_action_strings`` and ``denotations`` are added to the output.
    epoch_num : ``List[int]``, optional
        Its first element is taken as the current epoch.  Required during training when a
        dynamic cost rate is configured.

    Raises
    ------
    RuntimeError
        If a dynamic cost rate is configured, the model is in training mode, and no epoch
        number was provided.
    """
    # We look at the epoch number and adjust the checklist cost weight if needed here.
    instance_epoch_num = epoch_num[0] if epoch_num is not None else None
    if self._dynamic_cost_rate is not None:
        if self.training and instance_epoch_num is None:
            raise RuntimeError("If you want a dynamic cost weight, use the "
                               "EpochTrackingBucketIterator!")
        # Outside of training the epoch number may legitimately be missing.  In that case
        # there is nothing to update, and comparing ``None`` with the wait-epochs threshold
        # would raise a ``TypeError``, so the bookkeeping is skipped entirely.
        if instance_epoch_num is not None and instance_epoch_num != self._last_epoch_in_forward:
            if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                self._checklist_cost_weight -= decrement
                logger.info("Checklist cost weight is now %f", self._checklist_cost_weight)
            self._last_epoch_in_forward = instance_epoch_num
    batch_size = len(worlds)
    action_embeddings, action_indices = self._embed_actions(actions)

    initial_rnn_state = self._get_initial_rnn_state(sentence)
    initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                          for _ in range(batch_size)]
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i])
                             for i in range(batch_size)]

    label_strings = self._get_label_strings(labels) if labels is not None else None
    # Each instance's agenda is of size (agenda_size, 1)
    agenda_list = [agenda[i] for i in range(batch_size)]
    initial_checklist_states = []
    for instance_actions, instance_agenda in zip(actions, agenda_list):
        checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
        checklist_target, terminal_actions, checklist_mask = checklist_info
        # Nothing has been produced yet, so every checklist starts out as all zeros.
        initial_checklist = checklist_target.new_zeros(checklist_target.size())
        initial_checklist_states.append(ChecklistState(terminal_actions=terminal_actions,
                                                       checklist_target=checklist_target,
                                                       checklist_mask=checklist_mask,
                                                       checklist=initial_checklist))
    initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                     action_history=[[] for _ in range(batch_size)],
                                     score=initial_score_list,
                                     rnn_state=initial_rnn_state,
                                     grammar_state=initial_grammar_state,
                                     action_embeddings=action_embeddings,
                                     action_indices=action_indices,
                                     possible_actions=actions,
                                     worlds=worlds,
                                     label_strings=label_strings,
                                     checklist_state=initial_checklist_states)

    agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
    outputs = self._decoder_trainer.decode(initial_state,
                                           self._decoder_step,
                                           self._get_state_cost)
    if identifier is not None:
        outputs['identifier'] = identifier
    best_action_sequences = outputs['best_action_sequences']
    batch_action_strings = self._get_action_strings(actions, best_action_sequences)
    # NOTE(review): the denotations are only consumed in the testing branch below; this call
    # is left where it was because `_get_denotations` is not visible from here.
    batch_denotations = self._get_denotations(batch_action_strings, worlds)
    if labels is not None:
        # We're either training or validating.
        self._update_metrics(action_strings=batch_action_strings,
                             worlds=worlds,
                             label_strings=label_strings,
                             possible_actions=actions,
                             agenda_data=agenda_data)
    else:
        # We're testing.
        outputs["best_action_strings"] = batch_action_strings
        outputs["denotations"] = batch_denotations
    return outputs
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRuleArray]],
            agenda: torch.LongTensor,
            example_lisp_string: List[str],
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:  # pylint: disable=arguments-differ
    """
    Builds one checklist state per instance, runs the decoder trainer over the resulting
    initial state, and, when not training, updates the agenda-coverage metric and computes
    the validation outputs.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
       The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
       be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
        get embeddings for each entity.
    world : ``List[WikiTablesWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
        how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``,
    actions : ``List[List[ProductionRuleArray]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    agenda : ``torch.LongTensor``
        For each instance, the indices (into that instance's ``actions``) of the productions
        we would like the decoded sequence to contain.  Each instance's agenda is of size
        ``(agenda_size, 1)``; entries equal to ``-1`` are skipped when computing coverage.
    example_lisp_string : ``List[str]``
        The example (lisp-formatted) string corresponding to the given input.  This comes
        directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
        when evaluating denotation accuracy; it is otherwise unused.
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Metadata containing the original tokenized question within a 'question_tokens' key.
    """
    num_instances = list(question.values())[0].size(0)
    # Each instance's agenda is of size (agenda_size, 1)
    per_instance_agenda = [agenda[index] for index in range(num_instances)]
    terminal_sets = [set(instance_world.terminal_productions.values())
                     for instance_world in world]
    max_num_terminals = max(map(len, terminal_sets))

    checklist_states = []
    for instance_actions, instance_agenda, terminals in zip(actions,
                                                            per_instance_agenda,
                                                            terminal_sets):
        target, terminal_actions, mask = self._get_checklist_info(instance_agenda,
                                                                  instance_actions,
                                                                  terminals,
                                                                  max_num_terminals)
        # Nothing has been produced yet, so every checklist starts out as all zeros.
        empty_checklist = target.new_zeros(target.size())
        checklist_states.append(ChecklistState(terminal_actions=terminal_actions,
                                               checklist_target=target,
                                               checklist_mask=mask,
                                               checklist=empty_checklist))

    outputs: Dict[str, Any] = {}
    rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                       table,
                                                                       world,
                                                                       actions,
                                                                       outputs)
    state_count = len(rnn_state)
    zero_scores = rnn_state[0].hidden_state.new_zeros(state_count)
    initial_state = CoverageDecoderState(batch_indices=list(range(state_count)),
                                         action_history=[[] for _ in range(state_count)],
                                         score=[zero_scores[index] for index in range(state_count)],
                                         rnn_state=rnn_state,
                                         grammar_state=grammar_state,
                                         checklist_state=checklist_states,
                                         possible_actions=actions,
                                         extras=example_lisp_string,
                                         debug_info=None)

    evaluating = not self.training
    if evaluating:
        initial_state.debug_info = [[] for _ in range(state_count)]

    # NOTE(review): `outputs` is rebound to the decoder's result here, so anything that
    # `_get_initial_rnn_and_grammar_state` stored in the dict passed above is not part of
    # what this method returns -- confirm that this is intended.
    outputs = self._decoder_trainer.decode(initial_state,  # type: ignore
                                           self._decoder_step,
                                           partial(self._get_state_cost, world))
    best_final_states = outputs['best_final_states']

    if evaluating:
        agenda_indices = [instance_agenda[:, 0].cpu().data for instance_agenda in agenda]
        # Maps (batch index, action index) to the string form of that action.
        action_mapping = {(batch_index, action_index): action[0]
                          for batch_index, batch_actions in enumerate(actions)
                          for action_index, action in enumerate(batch_actions)}
        for i in range(len(actions)):
            coverage = 0.0
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_history = best_final_states[i][0].action_history[0]
                produced = [action_mapping[(i, action_index)] for action_index in best_history]
                # Agenda entries of -1 are skipped; the rest index into this instance's actions.
                wanted = [actions[i][int(rule_id)][0]
                          for rule_id in agenda_indices[i]
                          if int(rule_id) != -1]
                hits = [action in produced for action in wanted]
                if hits:
                    # Note: This means that when there are no actions on agenda, agenda coverage
                    # will be 0, not 1.
                    coverage = sum(hits) / len(hits)
            self._agenda_coverage(coverage)

        self._compute_validation_outputs(actions,
                                         best_final_states,
                                         world,
                                         example_lisp_string,
                                         metadata,
                                         outputs)
    return outputs