def _get_state_cost(
    self, batch_worlds: List[List[NlvrLanguage]], state: CoverageState
) -> torch.Tensor:
    """
    Return the cost of a finished state. Since it is a finished state, the group size will be
    1, and hence we'll return just one cost.

    The ``batch_worlds`` parameter here is because we need the world to check the denotation
    accuracy of the action sequence in the finished state. Instead of adding a field to the
    ``State`` object just for this method, we take the ``World`` as a parameter here.
    """
    if not state.is_finished():
        raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
    instance_worlds = batch_worlds[state.batch_indices[0]]
    # Our checklist cost is a sum of squared error from where we want to be, making sure we
    # take into account the mask.
    checklist_balance = state.checklist_state[0].get_balance()
    checklist_cost = torch.sum((checklist_balance) ** 2)

    # This is the number of items on the agenda that we want to see in the decoded sequence.
    # We use this as the denotation cost if the path is incorrect.
    # Note: If we are penalizing the model for producing non-agenda actions, this is not the
    # upper limit on the checklist cost. That would be the number of terminal actions.
    denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
    checklist_cost = self._checklist_cost_weight * checklist_cost
    # TODO (pradeep): The denotation-based cost below is strict. Maybe define a cost based on
    # how many worlds the logical form is correct in?
    # extras being None happens when we are testing. We do not care about the cost then.
    # TODO (pradeep): Make this cleaner.
    if state.extras is None or all(self._check_state_denotations(state, instance_worlds)):
        cost = checklist_cost
    else:
        cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
    return cost
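# --- Illustration (not part of the model code) -----------------------------
# A minimal, self-contained sketch of how the cost above is assembled: the
# checklist balance (masked difference between the agenda target and the
# actions actually produced) contributes a squared-error term, and the
# denotation cost is added only when the decoded program is wrong. The tensor
# values and the 0.6 weight below are made up for illustration.
import torch

def example_state_cost(checklist_cost_weight: float = 0.6) -> None:
    checklist_target = torch.tensor([1.0, 1.0, 0.0, 1.0])  # agenda items we want to see
    checklist = torch.tensor([1.0, 0.0, 0.0, 0.0])         # agenda items actually produced
    checklist_mask = torch.tensor([1.0, 1.0, 1.0, 1.0])    # no padded terminal actions here
    balance = checklist_mask * (checklist_target - checklist)
    checklist_cost = checklist_cost_weight * torch.sum(balance ** 2)
    denotation_cost = torch.sum(checklist_target)  # used only if the denotation is wrong
    for denotation_is_correct in (True, False):
        if denotation_is_correct:
            cost = checklist_cost
        else:
            cost = checklist_cost + (1 - checklist_cost_weight) * denotation_cost
        print(denotation_is_correct, float(cost))

# example_state_cost()  # gives costs of 1.2 and 2.4 (up to float precision)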
def _get_state_cost(self, worlds: List[WikiTablesWorld], state: CoverageState) -> torch.Tensor:
    if not state.is_finished():
        raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
    world = worlds[state.batch_indices[0]]

    # Our checklist cost is a sum of squared error from where we want to be, making sure we
    # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
    # penalizing agenda actions produced multiple times.
    checklist_balance = torch.clamp(state.checklist_state[0].get_balance(), min=0.0)
    checklist_cost = torch.sum((checklist_balance) ** 2)

    # This is the number of items on the agenda that we want to see in the decoded sequence.
    # We use this as the denotation cost if the path is incorrect.
    denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
    checklist_cost = self._checklist_cost_weight * checklist_cost
    action_history = state.action_history[0]
    batch_index = state.batch_indices[0]
    action_strings = [state.possible_actions[batch_index][i][0] for i in action_history]
    logical_form = world.get_logical_form(action_strings)
    lisp_string = state.extras[batch_index]
    if self._executor.evaluate_logical_form(logical_form, lisp_string):
        cost = checklist_cost
    else:
        cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
    return cost
def _get_state_cost(self, worlds: List[WikiTablesVariableFreeWorld], state: CoverageState) -> torch.Tensor:
    if not state.is_finished():
        raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
    world = worlds[state.batch_indices[0]]

    # Our checklist cost is a sum of squared error from where we want to be, making sure we
    # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
    # penalizing agenda actions produced multiple times.
    checklist_balance = torch.clamp(state.checklist_state[0].get_balance(), min=0.0)
    checklist_cost = torch.sum((checklist_balance) ** 2)

    # This is the number of items on the agenda that we want to see in the decoded sequence.
    # We use this as the denotation cost if the path is incorrect.
    denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
    checklist_cost = self._checklist_cost_weight * checklist_cost
    action_history = state.action_history[0]
    batch_index = state.batch_indices[0]
    action_strings = [state.possible_actions[batch_index][i][0] for i in action_history]
    logical_form = world.get_logical_form(action_strings)
    target_values = state.extras[batch_index]
    evaluation = False
    try:
        executor_logger = logging.getLogger(
            'weak_supervision.semparse.executors.wikitables_variable_free_executor')
        executor_logger.setLevel(logging.ERROR)
        evaluation = world.evaluate_logical_form(logical_form, target_values)
    except IndexError:
        # TODO(pradeep): This happens due to a bug in the "filter_in" and "filter_not_in"
        # functions. The value evaluation, if it is a list, could be an empty one. Fix it there!
        pass
    if evaluation:
        cost = checklist_cost
    else:
        cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
    return cost
def _get_state_cost(self, worlds: List[WikiTablesWorld], state: CoverageState) -> torch.Tensor:
    if not state.is_finished():
        raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
    world = worlds[state.batch_indices[0]]

    # Our checklist cost is a sum of squared error from where we want to be, making sure we
    # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
    # penalizing agenda actions produced multiple times.
    checklist_balance = torch.clamp(state.checklist_state[0].get_balance(), min=0.0)
    checklist_cost = torch.sum((checklist_balance) ** 2)

    # This is the number of items on the agenda that we want to see in the decoded sequence.
    # We use this as the denotation cost if the path is incorrect.
    denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
    checklist_cost = self._checklist_cost_weight * checklist_cost
    action_history = state.action_history[0]
    batch_index = state.batch_indices[0]
    action_strings = [state.possible_actions[batch_index][i][0] for i in action_history]
    logical_form = world.get_logical_form(action_strings)
    lisp_string = state.extras[batch_index]
    if self._executor.evaluate_logical_form(logical_form, lisp_string):
        cost = checklist_cost
    else:
        cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
    return cost
def _get_state_info(self, state: CoverageState,
                    batch_worlds: List[List[NlvrWorld]]) -> Dict[str, List]:
    """
    This method is here for debugging purposes, in case you want to look at what the model is
    learning. It may be inefficient to call it while training the model on real data.
    """
    if len(state.batch_indices) == 1 and state.is_finished():
        costs = [float(self._get_state_cost(batch_worlds, state).detach().cpu().numpy())]
    else:
        costs = []
    model_scores = [float(score.detach().cpu().numpy()) for score in state.score]
    all_actions = state.possible_actions[0]
    action_sequences = [[self._get_action_string(all_actions[action]) for action in history]
                        for history in state.action_history]
    agenda_sequences = []
    all_agenda_indices = []
    for checklist_state in state.checklist_state:
        agenda_indices = []
        for action, is_wanted in zip(checklist_state.terminal_actions,
                                     checklist_state.checklist_target):
            action_int = int(action.detach().cpu().numpy())
            is_wanted_int = int(is_wanted.detach().cpu().numpy())
            if is_wanted_int != 0:
                agenda_indices.append(action_int)
        agenda_sequences.append([self._get_action_string(all_actions[action])
                                 for action in agenda_indices])
        all_agenda_indices.append(agenda_indices)
    return {"agenda": agenda_sequences,
            "agenda_indices": all_agenda_indices,
            "history": action_sequences,
            "history_indices": state.action_history,
            "costs": costs,
            "scores": model_scores}
def _get_state_cost(self, batch_worlds: List[List[NlvrWorld]], state: CoverageState) -> torch.Tensor:
    """
    Return the cost of a finished state. Since it is a finished state, the group size will be
    1, and hence we'll return just one cost.

    The ``batch_worlds`` parameter here is because we need the world to check the denotation
    accuracy of the action sequence in the finished state. Instead of adding a field to the
    ``State`` object just for this method, we take the ``World`` as a parameter here.
    """
    if not state.is_finished():
        raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
    instance_worlds = batch_worlds[state.batch_indices[0]]
    # Our checklist cost is a sum of squared error from where we want to be, making sure we
    # take into account the mask.
    checklist_balance = state.checklist_state[0].get_balance()
    checklist_cost = torch.sum((checklist_balance) ** 2)
    # This is the number of items on the agenda that we want to see in the decoded sequence.
    # We use this as the denotation cost if the path is incorrect.
    # Note: If we are penalizing the model for producing non-agenda actions, this is not the
    # upper limit on the checklist cost. That would be the number of terminal actions.
    denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
    checklist_cost = self._checklist_cost_weight * checklist_cost
    # TODO (pradeep): The denotation-based cost below is strict. Maybe define a cost based on
    # how many worlds the logical form is correct in?
    # extras being None happens when we are testing. We do not care about the cost then.
    # TODO (pradeep): Make this cleaner.
    if state.extras is None or all(self._check_state_denotations(state, instance_worlds)):
        cost = checklist_cost
    else:
        cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
    return cost
def _compute_action_probabilities(
    self,  # type: ignore
    state: CoverageState,
    hidden_state: torch.Tensor,
    attention_weights: torch.Tensor,
    predicted_action_embeddings: torch.Tensor
) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    # In this section we take our predicted action embedding and compare it to the available
    # actions in our current state (which might be different for each group element). For
    # computing action scores, we'll forget about doing batched / grouped computation, as it
    # adds too much complexity and doesn't speed things up, anyway, with the operations we're
    # doing here. This means we don't need any action masks, as we'll only get the right
    # lengths for what we're computing.

    group_size = len(state.batch_indices)
    actions = state.get_valid_actions()

    batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index in range(group_size):
        instance_actions = actions[group_index]
        predicted_action_embedding = predicted_action_embeddings[group_index]
        action_embeddings, output_action_embeddings, action_ids = instance_actions['global']
        # This embedding addition is the only difference between the logic here and the
        # corresponding logic in the super class.
        embedding_addition = self._get_predicted_embedding_addition(
            state.checklist_state[group_index], action_ids, action_embeddings
        )
        addition = embedding_addition * self._checklist_multiplier
        predicted_action_embedding = predicted_action_embedding + addition

        # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
        # (embedding_dim, 1) matrix.
        action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
        current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

        # This is now the total score for each state after taking each action. We're going to
        # sort by this later, so it's important that this is the total score, not just the
        # score for the current action.
        log_probs = state.score[group_index] + current_log_probs
        batch_results[state.batch_indices[group_index]].append(
            (group_index, log_probs, current_log_probs, output_action_embeddings, action_ids)
        )
    return batch_results
def _compute_action_probabilities(self,  # type: ignore
                                  state: CoverageState,
                                  hidden_state: torch.Tensor,
                                  attention_weights: torch.Tensor,
                                  predicted_action_embeddings: torch.Tensor
                                 ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    # In this section we take our predicted action embedding and compare it to the available
    # actions in our current state (which might be different for each group element). For
    # computing action scores, we'll forget about doing batched / grouped computation, as it
    # adds too much complexity and doesn't speed things up, anyway, with the operations we're
    # doing here. This means we don't need any action masks, as we'll only get the right
    # lengths for what we're computing.
    group_size = len(state.batch_indices)
    actions = state.get_valid_actions()

    batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index in range(group_size):
        instance_actions = actions[group_index]
        predicted_action_embedding = predicted_action_embeddings[group_index]
        action_embeddings, output_action_embeddings, action_ids = instance_actions['global']
        # This embedding addition is the only difference between the logic here and the
        # corresponding logic in the super class.
        embedding_addition = self._get_predicted_embedding_addition(state.checklist_state[group_index],
                                                                    action_ids,
                                                                    action_embeddings)
        addition = embedding_addition * self._checklist_multiplier
        predicted_action_embedding = predicted_action_embedding + addition

        # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
        # (embedding_dim, 1) matrix.
        action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
        current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

        # This is now the total score for each state after taking each action. We're going to
        # sort by this later, so it's important that this is the total score, not just the
        # score for the current action.
        log_probs = state.score[group_index] + current_log_probs
        batch_results[state.batch_indices[group_index]].append((group_index,
                                                                log_probs,
                                                                current_log_probs,
                                                                output_action_embeddings,
                                                                action_ids))
    return batch_results
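# --- Illustration (not part of the model code) -----------------------------
# A small standalone sketch of the global-action scoring used above: each
# candidate action embedding is scored by a dot product with the predicted
# action embedding (a single matrix-vector product), log-softmaxed, and added
# to the group element's running score. Shapes and values are made up.
import torch

def example_global_action_scores() -> torch.Tensor:
    num_actions, embedding_dim = 5, 4
    action_embeddings = torch.randn(num_actions, embedding_dim)
    predicted_action_embedding = torch.randn(embedding_dim)
    state_score = torch.tensor(-1.3)  # log probability accumulated so far
    # (num_actions, embedding_dim) x (embedding_dim, 1) -> (num_actions,)
    action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
    current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)
    return state_score + current_log_probs  # total score after taking each action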
def _get_state_cost(self, worlds: List[WikiTablesLanguage], state: CoverageState) -> torch.Tensor:
    if not state.is_finished():
        raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
    world = worlds[state.batch_indices[0]]

    # Our checklist cost is a sum of squared error from where we want to be, making sure we
    # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
    # penalizing agenda actions produced multiple times.
    checklist_balance = torch.clamp(state.checklist_state[0].get_balance(), min=0.0)
    checklist_cost = torch.sum((checklist_balance) ** 2)

    # This is the number of items on the agenda that we want to see in the decoded sequence.
    # We use this as the denotation cost if the path is incorrect.
    denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
    checklist_cost = self._checklist_cost_weight * checklist_cost
    action_history = state.action_history[0]
    batch_index = state.batch_indices[0]
    action_strings = [state.possible_actions[batch_index][i][0] for i in action_history]
    target_values = state.extras[batch_index]

    # Silence the executor's logging while we evaluate the action sequence against the targets.
    executor_logger = logging.getLogger('allennlp.semparse.domain_languages.wikitables_language')
    executor_logger.setLevel(logging.ERROR)
    evaluation = world.evaluate_action_sequence(action_strings, target_values)
    if evaluation:
        cost = checklist_cost
    else:
        cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
    return cost
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesLanguage],
            actions: List[List[ProductionRule]],
            agenda: torch.LongTensor,
            target_values: List[List[str]] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
        be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
        get embeddings for each entity.
    world : ``List[WikiTablesLanguage]``
        We use a ``MetadataField`` to get the ``WikiTablesLanguage`` object for each input
        instance. Because of how ``MetadataField`` works, this gets passed to us as a
        ``List[WikiTablesLanguage]``.
    actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``world`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and use
        the embeddings to determine which action to take at each timestep in the decoder.
    agenda : ``torch.LongTensor``
        Agenda vectors that the checklist vectors will be compared against to compute the
        checklist cost.
    target_values : ``List[List[str]]``, optional (default = None)
        For each instance, a list of target values taken from the example lisp string. We pass
        this list to the evaluator along with logical forms to compute denotation accuracy.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Metadata containing the original tokenized question within a 'question_tokens' field.
""" batch_size = list(question.values())[0].size(0) # Each instance's agenda is of size (agenda_size, 1) agenda_list = [agenda[i] for i in range(batch_size)] checklist_states = [] all_terminal_productions = [ set(instance_world.terminal_productions.values()) for instance_world in world ] max_num_terminals = max( [len(terminals) for terminals in all_terminal_productions]) for instance_actions, instance_agenda, terminal_productions in zip( actions, agenda_list, all_terminal_productions): checklist_info = self._get_checklist_info(instance_agenda, instance_actions, terminal_productions, max_num_terminals) checklist_target, terminal_actions, checklist_mask = checklist_info initial_checklist = checklist_target.new_zeros( checklist_target.size()) checklist_states.append( ChecklistStatelet(terminal_actions=terminal_actions, checklist_target=checklist_target, checklist_mask=checklist_mask, checklist=initial_checklist)) outputs: Dict[str, Any] = {} rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state( question, table, world, actions, outputs) batch_size = len(rnn_state) initial_score = rnn_state[0].hidden_state.new_zeros(batch_size) initial_score_list = [initial_score[i] for i in range(batch_size)] initial_state = CoverageState( batch_indices=list(range(batch_size)), # type: ignore action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=rnn_state, grammar_state=grammar_state, checklist_state=checklist_states, possible_actions=actions, extras=target_values, debug_info=None) if target_values is not None: logger.warning(f"TARGET VALUES: {target_values}") trainer_outputs = self._decoder_trainer.decode( initial_state, # type: ignore self._decoder_step, partial(self._get_state_cost, world)) outputs.update(trainer_outputs) else: initial_state.debug_info = [[] for _ in range(batch_size)] batch_size = len(actions) agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda] action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] best_final_states = self._beam_search.search( self._max_decoding_steps, initial_state, self._decoder_step, keep_final_unfinished_states=False) for i in range(batch_size): in_agenda_ratio = 0.0 # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i in best_final_states: action_sequence = best_final_states[i][0].action_history[0] action_strings = [ action_mapping[(i, action_index)] for action_index in action_sequence ] instance_possible_actions = actions[i] agenda_actions = [] for rule_id in agenda_indices[i]: rule_id = int(rule_id) if rule_id == -1: continue action_string = instance_possible_actions[rule_id][0] agenda_actions.append(action_string) actions_in_agenda = [ action in action_strings for action in agenda_actions ] if actions_in_agenda: # Note: This means that when there are no actions on agenda, agenda coverage # will be 0, not 1. in_agenda_ratio = sum(actions_in_agenda) / len( actions_in_agenda) self._agenda_coverage(in_agenda_ratio) self._compute_validation_outputs(actions, best_final_states, world, target_values, metadata, outputs) return outputs
def forward(self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRule]],
            agenda: torch.LongTensor,
            identifier: List[str] = None,
            labels: torch.LongTensor = None,
            epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing type-constrained target sequences that maximize coverage of
    their respective agendas, and minimize a denotation-based loss.
    """
    # We look at the epoch number and adjust the checklist cost weight if needed here.
    instance_epoch_num = epoch_num[0] if epoch_num is not None else None
    if self._dynamic_cost_rate is not None:
        if self.training and instance_epoch_num is None:
            raise RuntimeError("If you want a dynamic cost weight, use the "
                               "EpochTrackingBucketIterator!")
        if instance_epoch_num != self._last_epoch_in_forward:
            if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                self._checklist_cost_weight -= decrement
                logger.info("Checklist cost weight is now %f", self._checklist_cost_weight)
            self._last_epoch_in_forward = instance_epoch_num
    batch_size = len(worlds)

    initial_rnn_state = self._get_initial_rnn_state(sentence)
    initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                          for i in range(batch_size)]
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i])
                             for i in range(batch_size)]

    label_strings = self._get_label_strings(labels) if labels is not None else None
    # Each instance's agenda is of size (agenda_size, 1)
    # TODO(mattg): It looks like the agenda is only ever used on the CPU. In that case, it's a
    # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
    agenda_list = [agenda[i] for i in range(batch_size)]
    initial_checklist_states = []
    for instance_actions, instance_agenda in zip(actions, agenda_list):
        checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
        checklist_target, terminal_actions, checklist_mask = checklist_info
        initial_checklist = checklist_target.new_zeros(checklist_target.size())
        initial_checklist_states.append(ChecklistStatelet(terminal_actions=terminal_actions,
                                                          checklist_target=checklist_target,
                                                          checklist_mask=checklist_mask,
                                                          checklist=initial_checklist))
    initial_state = CoverageState(batch_indices=list(range(batch_size)),
                                  action_history=[[] for _ in range(batch_size)],
                                  score=initial_score_list,
                                  rnn_state=initial_rnn_state,
                                  grammar_state=initial_grammar_state,
                                  possible_actions=actions,
                                  extras=label_strings,
                                  checklist_state=initial_checklist_states)

    agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
    outputs = self._decoder_trainer.decode(initial_state,  # type: ignore
                                           self._decoder_step,
                                           partial(self._get_state_cost, worlds))
    if identifier is not None:
        outputs['identifier'] = identifier
    best_final_states = outputs['best_final_states']
    best_action_sequences = {}
    for batch_index, states in best_final_states.items():
        best_action_sequences[batch_index] = [state.action_history[0] for state in states]
    batch_action_strings = self._get_action_strings(actions, best_action_sequences)
    batch_denotations = self._get_denotations(batch_action_strings, worlds)
    if labels is not None:
        # We're either training or validating.
        self._update_metrics(action_strings=batch_action_strings,
                             worlds=worlds,
                             label_strings=label_strings,
                             possible_actions=actions,
                             agenda_data=agenda_data)
    else:
        # We're testing.
        outputs["best_action_strings"] = batch_action_strings
        outputs["denotations"] = batch_denotations
    return outputs
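# --- Illustration (not part of the model code) -----------------------------
# A sketch of the dynamic checklist-cost-weight schedule applied at the top of
# forward() above: after `wait_epochs` epochs, the weight is decayed by a fixed
# rate once per epoch, gradually shifting the cost from agenda coverage toward
# the denotation term. The starting weight and rate below are made up.
def example_cost_weight_schedule(num_epochs: int = 10,
                                 initial_weight: float = 0.8,
                                 rate: float = 0.1,
                                 wait_epochs: int = 3) -> None:
    weight = initial_weight
    for epoch in range(num_epochs):
        if epoch >= wait_epochs:
            weight -= weight * rate  # same decrement rule as in forward()
        print(epoch, round(weight, 4))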
def _compute_action_probabilities(self,  # type: ignore
                                  state: CoverageState,
                                  hidden_state: torch.Tensor,
                                  attention_weights: torch.Tensor,
                                  predicted_action_embeddings: torch.Tensor
                                 ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    # In this section we take our predicted action embedding and compare it to the available
    # actions in our current state (which might be different for each group element). For
    # computing action scores, we'll forget about doing batched / grouped computation, as it
    # adds too much complexity and doesn't speed things up, anyway, with the operations we're
    # doing here. This means we don't need any action masks, as we'll only get the right
    # lengths for what we're computing.
    group_size = len(state.batch_indices)
    actions = state.get_valid_actions()

    batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index in range(group_size):
        instance_actions = actions[group_index]
        predicted_action_embedding = predicted_action_embeddings[group_index]
        action_ids: List[int] = []
        if "global" in instance_actions:
            action_embeddings, output_action_embeddings, embedded_actions = instance_actions['global']
            # This embedding addition is the only difference between the logic here and the
            # corresponding logic in the super class.
            embedding_addition = self._get_predicted_embedding_addition(state.checklist_state[group_index],
                                                                        embedded_actions,
                                                                        action_embeddings)
            addition = embedding_addition * self._checklist_multiplier
            predicted_action_embedding = predicted_action_embedding + addition

            # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
            # (embedding_dim, 1) matrix.
            embedded_action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
            action_ids += embedded_actions
        else:
            embedded_action_logits = None
            output_action_embeddings = None

        if 'linked' in instance_actions:
            linking_scores, type_embeddings, linked_actions = instance_actions['linked']
            action_ids += linked_actions
            # (num_question_tokens, 1)
            linked_action_logits = linking_scores.mm(attention_weights[group_index].unsqueeze(-1)).squeeze(-1)
            linked_logits_addition = self._get_linked_logits_addition(state.checklist_state[group_index],
                                                                      linked_actions,
                                                                      linked_action_logits)
            addition = linked_logits_addition * self._linked_checklist_multiplier
            linked_action_logits = linked_action_logits + addition

            # The `output_action_embeddings` tensor gets used later as the input to the next
            # decoder step. For linked actions, we don't have any action embedding, so we use
            # the entity type instead.
            if output_action_embeddings is None:
                output_action_embeddings = type_embeddings
            else:
                output_action_embeddings = torch.cat([output_action_embeddings, type_embeddings], dim=0)

            if self._mixture_feedforward is not None:
                # The linked and global logits are combined with a mixture weight to prevent the
                # linked_action_logits from dominating the embedded_action_logits if a softmax
                # was applied on both together.
                mixture_weight = self._mixture_feedforward(hidden_state[group_index])
                mix1 = torch.log(mixture_weight)
                mix2 = torch.log(1 - mixture_weight)

                entity_action_probs = torch.nn.functional.log_softmax(linked_action_logits, dim=-1) + mix1
                if embedded_action_logits is None:
                    current_log_probs = entity_action_probs
                else:
                    embedded_action_probs = torch.nn.functional.log_softmax(embedded_action_logits,
                                                                            dim=-1) + mix2
                    current_log_probs = torch.cat([embedded_action_probs, entity_action_probs], dim=-1)
            else:
                if embedded_action_logits is None:
                    current_log_probs = torch.nn.functional.log_softmax(linked_action_logits, dim=-1)
                else:
                    action_logits = torch.cat([embedded_action_logits, linked_action_logits], dim=-1)
                    current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)
        else:
            current_log_probs = torch.nn.functional.log_softmax(embedded_action_logits, dim=-1)

        # This is now the total score for each state after taking each action. We're going to
        # sort by this later, so it's important that this is the total score, not just the
        # score for the current action.
        log_probs = state.score[group_index] + current_log_probs
        batch_results[state.batch_indices[group_index]].append((group_index,
                                                                log_probs,
                                                                current_log_probs,
                                                                output_action_embeddings,
                                                                action_ids))
    return batch_results
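# --- Illustration (not part of the model code) -----------------------------
# A standalone sketch of the mixture used above when both global and linked
# actions are present: the two logit sets are normalized separately, then
# offset by log(w) and log(1 - w) so the concatenated result is still a valid
# log distribution over all actions. Here the mixture weight is a constant
# rather than the output of a feedforward network.
import torch

def example_mixture_log_probs(mixture_weight: float = 0.3) -> torch.Tensor:
    embedded_action_logits = torch.randn(4)  # scores for global (embedded) actions
    linked_action_logits = torch.randn(3)    # scores for linked (entity) actions
    mix1 = torch.log(torch.tensor(mixture_weight))
    mix2 = torch.log(torch.tensor(1 - mixture_weight))
    entity_action_probs = torch.nn.functional.log_softmax(linked_action_logits, dim=-1) + mix1
    embedded_action_probs = torch.nn.functional.log_softmax(embedded_action_logits, dim=-1) + mix2
    current_log_probs = torch.cat([embedded_action_probs, entity_action_probs], dim=-1)
    # The exponentiated probabilities sum to (1 - w) + w = 1.
    assert torch.isclose(current_log_probs.exp().sum(), torch.tensor(1.0))
    return current_log_probs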
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRule]],
            agenda: torch.LongTensor,
            example_lisp_string: List[str],
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
        be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
        get embeddings for each entity.
    world : ``List[WikiTablesWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance. Because of how
        ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
    actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and use
        the embeddings to determine which action to take at each timestep in the decoder.
    agenda : ``torch.LongTensor``
        Agenda vectors that the checklist vectors will be compared against to compute the
        checklist cost.
    example_lisp_string : ``List[str]``
        The example (lisp-formatted) string corresponding to the given input. This comes
        directly from the ``.examples`` file provided with the dataset. We pass this to SEMPRE
        when evaluating denotation accuracy; it is otherwise unused.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Metadata containing the original tokenized question within a 'question_tokens' key.
""" batch_size = list(question.values())[0].size(0) # Each instance's agenda is of size (agenda_size, 1) agenda_list = [agenda[i] for i in range(batch_size)] checklist_states = [] all_terminal_productions = [set(instance_world.terminal_productions.values()) for instance_world in world] max_num_terminals = max([len(terminals) for terminals in all_terminal_productions]) for instance_actions, instance_agenda, terminal_productions in zip(actions, agenda_list, all_terminal_productions): checklist_info = self._get_checklist_info(instance_agenda, instance_actions, terminal_productions, max_num_terminals) checklist_target, terminal_actions, checklist_mask = checklist_info initial_checklist = checklist_target.new_zeros(checklist_target.size()) checklist_states.append(ChecklistStatelet(terminal_actions=terminal_actions, checklist_target=checklist_target, checklist_mask=checklist_mask, checklist=initial_checklist)) outputs: Dict[str, Any] = {} rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question, table, world, actions, outputs) batch_size = len(rnn_state) initial_score = rnn_state[0].hidden_state.new_zeros(batch_size) initial_score_list = [initial_score[i] for i in range(batch_size)] initial_state = CoverageState(batch_indices=list(range(batch_size)), # type: ignore action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=rnn_state, grammar_state=grammar_state, checklist_state=checklist_states, possible_actions=actions, extras=example_lisp_string, debug_info=None) if not self.training: initial_state.debug_info = [[] for _ in range(batch_size)] outputs = self._decoder_trainer.decode(initial_state, # type: ignore self._decoder_step, partial(self._get_state_cost, world)) best_final_states = outputs['best_final_states'] if not self.training: batch_size = len(actions) agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda] action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] for i in range(batch_size): in_agenda_ratio = 0.0 # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i in best_final_states: action_sequence = best_final_states[i][0].action_history[0] action_strings = [action_mapping[(i, action_index)] for action_index in action_sequence] instance_possible_actions = actions[i] agenda_actions = [] for rule_id in agenda_indices[i]: rule_id = int(rule_id) if rule_id == -1: continue action_string = instance_possible_actions[rule_id][0] agenda_actions.append(action_string) actions_in_agenda = [action in action_strings for action in agenda_actions] if actions_in_agenda: # Note: This means that when there are no actions on agenda, agenda coverage # will be 0, not 1. in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda) self._agenda_coverage(in_agenda_ratio) self._compute_validation_outputs(actions, best_final_states, world, example_lisp_string, metadata, outputs) return outputs
def _compute_action_probabilities(self,  # type: ignore
                                  state: CoverageState,
                                  hidden_state: torch.Tensor,
                                  attention_weights: torch.Tensor,
                                  predicted_action_embeddings: torch.Tensor
                                 ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    # In this section we take our predicted action embedding and compare it to the available
    # actions in our current state (which might be different for each group element). For
    # computing action scores, we'll forget about doing batched / grouped computation, as it
    # adds too much complexity and doesn't speed things up, anyway, with the operations we're
    # doing here. This means we don't need any action masks, as we'll only get the right
    # lengths for what we're computing.
    group_size = len(state.batch_indices)
    actions = state.get_valid_actions()

    batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index in range(group_size):
        instance_actions = actions[group_index]
        predicted_action_embedding = predicted_action_embeddings[group_index]
        action_embeddings, output_action_embeddings, embedded_actions = instance_actions['global']
        # This embedding addition is the only difference between the logic here and the
        # corresponding logic in the super class.
        embedding_addition = self._get_predicted_embedding_addition(state.checklist_state[group_index],
                                                                    embedded_actions,
                                                                    action_embeddings)
        addition = embedding_addition * self._checklist_multiplier
        predicted_action_embedding = predicted_action_embedding + addition

        # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
        # (embedding_dim, 1) matrix.
        embedded_action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
        action_ids = embedded_actions

        if 'linked' in instance_actions:
            linking_scores, type_embeddings, linked_actions = instance_actions['linked']
            action_ids = embedded_actions + linked_actions
            # (num_question_tokens, 1)
            linked_action_logits = linking_scores.mm(attention_weights[group_index].unsqueeze(-1)).squeeze(-1)
            linked_logits_addition = self._get_linked_logits_addition(state.checklist_state[group_index],
                                                                      linked_actions,
                                                                      linked_action_logits)
            addition = linked_logits_addition * self._linked_checklist_multiplier
            linked_action_logits = linked_action_logits + addition

            # The `output_action_embeddings` tensor gets used later as the input to the next
            # decoder step. For linked actions, we don't have any action embedding, so we use
            # the entity type instead.
            output_action_embeddings = torch.cat([output_action_embeddings, type_embeddings], dim=0)

            if self._mixture_feedforward is not None:
                # The linked and global logits are combined with a mixture weight to prevent the
                # linked_action_logits from dominating the embedded_action_logits if a softmax
                # was applied on both together.
                mixture_weight = self._mixture_feedforward(hidden_state[group_index])
                mix1 = torch.log(mixture_weight)
                mix2 = torch.log(1 - mixture_weight)

                entity_action_probs = torch.nn.functional.log_softmax(linked_action_logits, dim=-1) + mix1
                embedded_action_probs = torch.nn.functional.log_softmax(embedded_action_logits, dim=-1) + mix2
                current_log_probs = torch.cat([embedded_action_probs, entity_action_probs], dim=-1)
            else:
                action_logits = torch.cat([embedded_action_logits, linked_action_logits], dim=-1)
                current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)
        else:
            action_logits = embedded_action_logits
            current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

        # This is now the total score for each state after taking each action. We're going to
        # sort by this later, so it's important that this is the total score, not just the
        # score for the current action.
        log_probs = state.score[group_index] + current_log_probs
        batch_results[state.batch_indices[group_index]].append((group_index,
                                                                log_probs,
                                                                current_log_probs,
                                                                output_action_embeddings,
                                                                action_ids))
    return batch_results
def forward(self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRule]],
            agenda: torch.LongTensor,
            identifier: List[str] = None,
            labels: torch.LongTensor = None,
            epoch_num: List[int] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing type-constrained target sequences that maximize coverage of
    their respective agendas, and minimize a denotation-based loss.
    """
    # We look at the epoch number and adjust the checklist cost weight if needed here.
    instance_epoch_num = epoch_num[0] if epoch_num is not None else None
    if self._dynamic_cost_rate is not None:
        if self.training and instance_epoch_num is None:
            raise RuntimeError("If you want a dynamic cost weight, use the "
                               "BucketIterator with track_epoch=True.")
        if instance_epoch_num != self._last_epoch_in_forward:
            if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                self._checklist_cost_weight -= decrement
                logger.info("Checklist cost weight is now %f", self._checklist_cost_weight)
            self._last_epoch_in_forward = instance_epoch_num
    batch_size = len(worlds)

    initial_rnn_state = self._get_initial_rnn_state(sentence)
    initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                          for i in range(batch_size)]
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i])
                             for i in range(batch_size)]

    label_strings = self._get_label_strings(labels) if labels is not None else None
    # Each instance's agenda is of size (agenda_size, 1)
    # TODO(mattg): It looks like the agenda is only ever used on the CPU. In that case, it's a
    # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
    agenda_list = [agenda[i] for i in range(batch_size)]
    initial_checklist_states = []
    for instance_actions, instance_agenda in zip(actions, agenda_list):
        checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
        checklist_target, terminal_actions, checklist_mask = checklist_info
        initial_checklist = checklist_target.new_zeros(checklist_target.size())
        initial_checklist_states.append(ChecklistStatelet(terminal_actions=terminal_actions,
                                                          checklist_target=checklist_target,
                                                          checklist_mask=checklist_mask,
                                                          checklist=initial_checklist))
    initial_state = CoverageState(batch_indices=list(range(batch_size)),
                                  action_history=[[] for _ in range(batch_size)],
                                  score=initial_score_list,
                                  rnn_state=initial_rnn_state,
                                  grammar_state=initial_grammar_state,
                                  possible_actions=actions,
                                  extras=label_strings,
                                  checklist_state=initial_checklist_states)
    if not self.training:
        initial_state.debug_info = [[] for _ in range(batch_size)]

    agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
    outputs = self._decoder_trainer.decode(initial_state,  # type: ignore
                                           self._decoder_step,
                                           partial(self._get_state_cost, worlds))
    if identifier is not None:
        outputs['identifier'] = identifier
    best_final_states = outputs['best_final_states']
    best_action_sequences = {}
    for batch_index, states in best_final_states.items():
        best_action_sequences[batch_index] = [state.action_history[0] for state in states]
    batch_action_strings = self._get_action_strings(actions, best_action_sequences)
    batch_denotations = self._get_denotations(batch_action_strings, worlds)
    if labels is not None:
        # We're either training or validating.
        self._update_metrics(action_strings=batch_action_strings,
                             worlds=worlds,
                             label_strings=label_strings,
                             possible_actions=actions,
                             agenda_data=agenda_data)
    else:
        # We're testing.
        if metadata is not None:
            outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
        outputs['debug_info'] = []
        for i in range(batch_size):
            outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
        outputs["best_action_strings"] = batch_action_strings
        outputs["denotations"] = batch_denotations
        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        outputs['action_mapping'] = action_mapping
    return outputs