def _get_predicted_embedding_addition(self,
                                      checklist_state: ChecklistStatelet,
                                      action_ids: List[int],
                                      action_embeddings: torch.Tensor) -> torch.Tensor:
    """
    Gets the embeddings of desired terminal actions that are yet to be produced by the decoder,
    and returns their sum for the decoder to add to the predicted embedding, biasing the
    prediction towards missing actions.
    """
    # Our basic approach here will be to figure out which actions we want to bias, by doing
    # some fancy indexing work, then multiply the action embeddings by a mask for those
    # actions, and return the sum of the result.

    # Shape: (num_terminal_actions, 1).  This is 1 if we still want to predict something on the
    # checklist, and 0 otherwise.
    checklist_balance = checklist_state.get_balance().clamp(min=0)

    # Shape: (num_terminal_actions, 1)
    actions_in_agenda = checklist_state.terminal_actions
    # Shape: (1, num_current_actions)
    action_id_tensor = checklist_balance.new(action_ids).long().unsqueeze(0)
    # Shape: (num_terminal_actions, num_current_actions).  Will have a value of 1 if the
    # terminal action i is our current action j, and a value of 0 otherwise.  Because both sets
    # of actions are free of duplicates, there will be at most one non-zero value per current
    # action, and per terminal action.
    current_agenda_actions = (actions_in_agenda == action_id_tensor).float()

    # Shape: (num_current_actions,).  With the inner multiplication, we remove any current
    # agenda actions that are not in our checklist balance, then we sum over the terminal
    # action dimension, which will have a sum of at most one.  So this will be a 0/1 tensor,
    # where a 1 means to encourage the current action in that position.
    actions_to_encourage = torch.sum(current_agenda_actions * checklist_balance, dim=0)

    # Shape: (action_embedding_dim,).  This is the sum of the action embeddings that we want
    # the model to prefer.
    embedding_addition = torch.sum(action_embeddings * actions_to_encourage.unsqueeze(1),
                                   dim=0,
                                   keepdim=False)

    if self._add_action_bias:
        # If we're adding an action bias, the last dimension of the action embedding is a bias
        # weight.  We don't want this addition to affect the bias (TODO(mattg): or do we?), so
        # we zero out that dimension here.
        embedding_addition[-1] = 0

    return embedding_addition
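# The following is a minimal, self-contained sketch of the masking logic in
# `_get_predicted_embedding_addition`, using toy tensors instead of a real `ChecklistStatelet`.
# All shapes, ids, and values below are illustrative assumptions, not taken from the model.
import torch

# Three terminal actions with global ids 4, 7, and 9; the balance says actions 4 and 9 are
# still unproduced (1), while 7 has already been generated (0).
terminal_actions = torch.tensor([[4], [7], [9]])          # (num_terminal_actions, 1)
checklist_balance = torch.tensor([[1.0], [0.0], [1.0]])   # (num_terminal_actions, 1)

# The actions currently available to the decoder, with random embeddings.
action_ids = [2, 7, 9]                                    # num_current_actions = 3
action_embeddings = torch.randn(3, 5)                     # (num_current_actions, embedding_dim)

action_id_tensor = torch.tensor(action_ids).unsqueeze(0)  # (1, num_current_actions)
current_agenda_actions = (terminal_actions == action_id_tensor).float()
actions_to_encourage = torch.sum(current_agenda_actions * checklist_balance, dim=0)
# tensor([0., 0., 1.]) -- only action 9 is both currently available and still on the checklist;
# action 7 is available but already produced, so it is not encouraged.

embedding_addition = torch.sum(action_embeddings * actions_to_encourage.unsqueeze(1), dim=0)
# `embedding_addition` equals action_embeddings[2], the embedding of action 9.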
def _get_linked_logits_addition(checklist_state: ChecklistStatelet,
                                action_ids: List[int],
                                action_logits: torch.Tensor) -> torch.Tensor:
    """
    Gets the logits of desired terminal actions that are yet to be produced by the decoder, and
    returns them for the decoder to add to the prior action logits, biasing the model towards
    predicting missing linked actions.
    """
    # Our basic approach here will be to figure out which actions we want to bias, by doing
    # some fancy indexing work, then multiply the action logits by a mask for those actions,
    # and return the result.

    # Shape: (num_terminal_actions, 1).  This is 1 if we still want to predict something on the
    # checklist, and 0 otherwise.
    checklist_balance = checklist_state.get_balance().clamp(min=0)

    # Shape: (num_terminal_actions, 1)
    actions_in_agenda = checklist_state.terminal_actions
    # Shape: (1, num_current_actions)
    action_id_tensor = checklist_balance.new(action_ids).long().unsqueeze(0)
    # Shape: (num_terminal_actions, num_current_actions).  Will have a value of 1 if the
    # terminal action i is our current action j, and a value of 0 otherwise.  Because both sets
    # of actions are free of duplicates, there will be at most one non-zero value per current
    # action, and per terminal action.
    current_agenda_actions = (actions_in_agenda == action_id_tensor).float()

    # Shape: (num_current_actions,).  With the inner multiplication, we remove any current
    # agenda actions that are not in our checklist balance, then we sum over the terminal
    # action dimension, which will have a sum of at most one.  So this will be a 0/1 tensor,
    # where a 1 means to encourage the current action in that position.
    actions_to_encourage = torch.sum(current_agenda_actions * checklist_balance, dim=0)

    # Shape: (num_current_actions,).  These are the logits of the actions that we want the
    # model to prefer; all other entries are zero.
    logit_addition = action_logits * actions_to_encourage
    return logit_addition
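# A similarly assumed toy sketch for the logit variant above.  Unlike the embedding version,
# nothing is summed over the current-action dimension: the result has one entry per current
# action and is zero for every action that is not still needed on the checklist.
import torch

actions_to_encourage = torch.tensor([0.0, 0.0, 1.0])    # produced by the same masking as above
action_logits = torch.tensor([0.2, -1.3, 0.8])          # (num_current_actions,)
logit_addition = action_logits * actions_to_encourage   # only the still-needed action's logit survives
# The caller can then add (a possibly weighted copy of) this vector back onto its linked action
# logits to bias decoding towards the missing agenda actions.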
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesLanguage],
            actions: List[List[ProductionRule]],
            agenda: torch.LongTensor,
            target_values: List[List[str]] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : ``Dict[str, torch.LongTensor]``
        The output of ``TextField.as_array()`` applied on the question ``TextField``.  This
        will be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder``
        to get embeddings for each entity.
    world : ``List[WikiTablesLanguage]``
        We use a ``MetadataField`` to get the ``WikiTablesLanguage`` object for each input
        instance.  Because of how ``MetadataField`` works, this gets passed to us as a
        ``List[WikiTablesLanguage]``.
    actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``world`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these and
        use the embeddings to determine which action to take at each timestep in the decoder.
    agenda : ``torch.LongTensor``
        Agenda vectors that the checklist vectors will be compared against to compute the
        checklist cost.
    target_values : ``List[List[str]]``, optional (default = None)
        For each instance, a list of target values taken from the example lisp string.  We
        pass this list to the evaluator along with logical forms to compute denotation
        accuracy.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Metadata containing the original tokenized question within a 'question_tokens' field.
    """
    batch_size = list(question.values())[0].size(0)
    # Each instance's agenda is of size (agenda_size, 1).
    agenda_list = [agenda[i] for i in range(batch_size)]
    checklist_states = []
    all_terminal_productions = [set(instance_world.terminal_productions.values())
                                for instance_world in world]
    max_num_terminals = max([len(terminals) for terminals in all_terminal_productions])
    for instance_actions, instance_agenda, terminal_productions in zip(actions,
                                                                       agenda_list,
                                                                       all_terminal_productions):
        checklist_info = self._get_checklist_info(instance_agenda,
                                                  instance_actions,
                                                  terminal_productions,
                                                  max_num_terminals)
        checklist_target, terminal_actions, checklist_mask = checklist_info
        initial_checklist = checklist_target.new_zeros(checklist_target.size())
        checklist_states.append(ChecklistStatelet(terminal_actions=terminal_actions,
                                                  checklist_target=checklist_target,
                                                  checklist_mask=checklist_mask,
                                                  checklist=initial_checklist))
    outputs: Dict[str, Any] = {}
    rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                       table,
                                                                       world,
                                                                       actions,
                                                                       outputs)
    batch_size = len(rnn_state)
    initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
    initial_score_list = [initial_score[i] for i in range(batch_size)]

    initial_state = CoverageState(batch_indices=list(range(batch_size)),  # type: ignore
                                  action_history=[[] for _ in range(batch_size)],
                                  score=initial_score_list,
                                  rnn_state=rnn_state,
                                  grammar_state=grammar_state,
                                  checklist_state=checklist_states,
                                  possible_actions=actions,
                                  extras=target_values,
                                  debug_info=None)

    if target_values is not None:
        logger.warning(f"TARGET VALUES: {target_values}")
        trainer_outputs = self._decoder_trainer.decode(initial_state,  # type: ignore
                                                       self._decoder_step,
                                                       partial(self._get_state_cost, world))
        outputs.update(trainer_outputs)
    else:
        initial_state.debug_info = [[] for _ in range(batch_size)]
        batch_size = len(actions)
        agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda]
        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                     initial_state,
                                                     self._decoder_step,
                                                     keep_final_unfinished_states=False)
        for i in range(batch_size):
            in_agenda_ratio = 0.0
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                action_sequence = best_final_states[i][0].action_history[0]
                action_strings = [action_mapping[(i, action_index)]
                                  for action_index in action_sequence]
                instance_possible_actions = actions[i]
                agenda_actions = []
                for rule_id in agenda_indices[i]:
                    rule_id = int(rule_id)
                    if rule_id == -1:
                        continue
                    action_string = instance_possible_actions[rule_id][0]
                    agenda_actions.append(action_string)
                actions_in_agenda = [action in action_strings for action in agenda_actions]
                if actions_in_agenda:
                    # Note: This means that when there are no actions on the agenda, agenda
                    # coverage will be 0, not 1.
                    in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda)
            self._agenda_coverage(in_agenda_ratio)

        self._compute_validation_outputs(actions,
                                         best_final_states,
                                         world,
                                         target_values,
                                         metadata,
                                         outputs)
    return outputs
def forward(self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRule]],
            agenda: torch.LongTensor,
            identifier: List[str] = None,
            labels: torch.LongTensor = None,
            epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing type-constrained target sequences that maximize coverage of
    their respective agendas, and minimize a denotation-based loss.
    """
    # We look at the epoch number and adjust the checklist cost weight if needed here.
    instance_epoch_num = epoch_num[0] if epoch_num is not None else None
    if self._dynamic_cost_rate is not None:
        if self.training and instance_epoch_num is None:
            raise RuntimeError("If you want a dynamic cost weight, use the "
                               "EpochTrackingBucketIterator!")
        if instance_epoch_num != self._last_epoch_in_forward:
            if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                self._checklist_cost_weight -= decrement
                logger.info("Checklist cost weight is now %f", self._checklist_cost_weight)
            self._last_epoch_in_forward = instance_epoch_num

    batch_size = len(worlds)
    initial_rnn_state = self._get_initial_rnn_state(sentence)
    initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                          for i in range(batch_size)]
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i])
                             for i in range(batch_size)]

    label_strings = self._get_label_strings(labels) if labels is not None else None
    # Each instance's agenda is of size (agenda_size, 1).
    # TODO(mattg): It looks like the agenda is only ever used on the CPU.  In that case, it's a
    # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
    agenda_list = [agenda[i] for i in range(batch_size)]
    initial_checklist_states = []
    for instance_actions, instance_agenda in zip(actions, agenda_list):
        checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
        checklist_target, terminal_actions, checklist_mask = checklist_info
        initial_checklist = checklist_target.new_zeros(checklist_target.size())
        initial_checklist_states.append(ChecklistStatelet(terminal_actions=terminal_actions,
                                                          checklist_target=checklist_target,
                                                          checklist_mask=checklist_mask,
                                                          checklist=initial_checklist))
    initial_state = CoverageState(batch_indices=list(range(batch_size)),
                                  action_history=[[] for _ in range(batch_size)],
                                  score=initial_score_list,
                                  rnn_state=initial_rnn_state,
                                  grammar_state=initial_grammar_state,
                                  possible_actions=actions,
                                  extras=label_strings,
                                  checklist_state=initial_checklist_states)

    agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
    outputs = self._decoder_trainer.decode(initial_state,  # type: ignore
                                           self._decoder_step,
                                           partial(self._get_state_cost, worlds))
    if identifier is not None:
        outputs['identifier'] = identifier
    best_final_states = outputs['best_final_states']
    best_action_sequences = {}
    for batch_index, states in best_final_states.items():
        best_action_sequences[batch_index] = [state.action_history[0] for state in states]
    batch_action_strings = self._get_action_strings(actions, best_action_sequences)
    batch_denotations = self._get_denotations(batch_action_strings, worlds)
    if labels is not None:
        # We're either training or validating.
        self._update_metrics(action_strings=batch_action_strings,
                             worlds=worlds,
                             label_strings=label_strings,
                             possible_actions=actions,
                             agenda_data=agenda_data)
    else:
        # We're testing.
        outputs["best_action_strings"] = batch_action_strings
        outputs["denotations"] = batch_denotations
    return outputs
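# A rough sketch (assumed numbers) of the dynamic checklist-cost-weight schedule at the top of
# `forward` above: once `dynamic_cost_wait_epochs` have passed, the weight decays geometrically
# by `dynamic_cost_rate` at the start of each new epoch.
checklist_cost_weight = 0.8   # assumed initial value
dynamic_cost_rate = 0.1       # assumed decay rate
dynamic_cost_wait_epochs = 2  # assumed wait period

for epoch in range(6):
    if epoch >= dynamic_cost_wait_epochs:
        checklist_cost_weight -= checklist_cost_weight * dynamic_cost_rate
    print(epoch, round(checklist_cost_weight, 4))
# Weights for epochs 0-5: 0.8, 0.8, 0.72, 0.648, 0.5832, 0.5249 (approximately).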