def _get_checklist_info(self,
                        agenda: torch.LongTensor,
                        all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
                                                                         torch.Tensor,
                                                                         torch.Tensor]:
    """
    Takes an agenda and a list of all actions and returns a target checklist against which the
    checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
    and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
    checklist loss computation. If ``self._penalize_non_agenda_actions`` is set to ``True``,
    ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set
    to ``False``, indices of all terminals that are not in the agenda will be masked.

    Parameters
    ----------
    ``agenda`` : ``torch.LongTensor``
        Agenda of one instance of size ``(agenda_size, 1)``.
    ``all_actions`` : ``List[ProductionRuleArray]``
        All actions for one instance.
    """
    terminal_indices = []
    target_checklist_list = []
    agenda_indices_set = set([int(x) for x in agenda.squeeze(0).data.cpu().numpy()])
    for index, action in enumerate(all_actions):
        # Each action is a ProductionRuleArray, a tuple where the first item is the production
        # rule string.
        if action[0] in self._terminal_productions:
            terminal_indices.append([index])
            if index in agenda_indices_set:
                target_checklist_list.append([1])
            else:
                target_checklist_list.append([0])
    # We want to return checklist target and terminal actions that are column vectors to make
    # computing softmax over the difference between checklist and target easier.
    # (num_terminals, 1)
    terminal_actions = util.new_variable_with_data(agenda, torch.Tensor(terminal_indices))
    # (num_terminals, 1)
    target_checklist = util.new_variable_with_data(agenda, torch.Tensor(target_checklist_list))
    if self._penalize_non_agenda_actions:
        # All terminal actions are relevant.
        checklist_mask = torch.ones_like(target_checklist)
    else:
        checklist_mask = (target_checklist != 0).float()
    return target_checklist, terminal_actions, checklist_mask
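
# The sketch below is illustrative and not part of the model: it reproduces the
# checklist-target logic above on toy tensors, using plain PyTorch in place of
# ``util.new_variable_with_data``. The toy actions, terminal set, and agenda
# indices are all made up.
def _demo_checklist_target() -> None:
    import torch
    toy_actions = ["S -> [A, B]", "A -> box_exists", "B -> count_equals", "A -> black"]
    toy_terminals = {"A -> box_exists", "B -> count_equals", "A -> black"}
    agenda_indices = {1, 3}  # indices (into ``toy_actions``) of agenda actions
    terminal_indices = []
    target = []
    for index, rule in enumerate(toy_actions):
        if rule in toy_terminals:
            terminal_indices.append([index])
            target.append([1] if index in agenda_indices else [0])
    target_checklist = torch.Tensor(target)           # (num_terminals, 1)
    # This is the ``penalize_non_agenda_actions=False`` branch; with ``True``
    # the mask would instead be all ones.
    checklist_mask = (target_checklist != 0).float()  # masks terminals off the agenda
    print(target_checklist.squeeze(-1).tolist())  # [1.0, 0.0, 1.0]
    print(checklist_mask.squeeze(-1).tolist())    # [1.0, 0.0, 1.0]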
def _get_model_scores_by_batch(self, states: List[StateType]) -> Dict[int, List[Variable]]:
    batch_scores: Dict[int, List[Variable]] = defaultdict(list)
    for state in states:
        for batch_index, model_score, history in zip(state.batch_indices,
                                                     state.score,
                                                     state.action_history):
            if self._normalize_by_length:
                path_length = nn_util.new_variable_with_data(model_score,
                                                             torch.Tensor([len(history)]))
                model_score = model_score / path_length
            batch_scores[batch_index].append(model_score)
    return batch_scores
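
# Illustrative only: a minimal sketch of the length normalization above, using plain
# PyTorch tensors instead of ``nn_util.new_variable_with_data``. The score and history
# below are made-up values.
def _demo_length_normalized_score() -> None:
    import torch
    model_score = torch.Tensor([-6.0])  # log-probability of a finished action sequence
    history = [0, 4, 7]                 # the actions taken to reach the finished state
    # Longer sequences accumulate more (negative) log-probability, so dividing by the
    # path length keeps scores of different-length sequences comparable.
    normalized = model_score / torch.Tensor([float(len(history))])
    print(normalized.item())  # -2.0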
def decode(self,
           initial_state: DecoderState,
           decode_step: DecoderStep,
           supervision: Callable[[StateType], torch.Tensor]) -> Dict[str, torch.Tensor]:
    cost_function = supervision
    finished_states = self._get_finished_states(initial_state, decode_step)
    loss = nn_util.new_variable_with_data(initial_state.score[0], torch.Tensor([0.0]))
    finished_model_scores = self._get_model_scores_by_batch(finished_states)
    finished_costs = self._get_costs_by_batch(finished_states, cost_function)
    for batch_index in finished_model_scores:
        # Finished model scores are log-probabilities of the predicted sequences. We convert
        # the log probabilities into probabilities and re-normalize them to compute the
        # expected cost under the distribution approximated by the beam search.
        costs = torch.cat(finished_costs[batch_index])
        logprobs = torch.cat(finished_model_scores[batch_index])
        # An unmasked softmax of the log probabilities will convert them into probabilities
        # and renormalize them.
        renormalized_probs = nn_util.masked_softmax(logprobs, None)
        loss += renormalized_probs.dot(costs)
    mean_loss = loss / len(finished_model_scores)
    return {'loss': mean_loss,
            'best_action_sequences': self._get_best_action_sequences(finished_states)}
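
# Illustrative only: the expected-cost computation above on made-up numbers. A softmax
# of log-probabilities exponentiates and renormalizes them (``masked_softmax`` with a
# ``None`` mask is a plain softmax), so the per-instance loss is the expected cost under
# the renormalized beam distribution.
def _demo_expected_cost() -> None:
    import torch
    logprobs = torch.Tensor([-1.0, -2.0, -3.0])  # scores of three finished sequences
    costs = torch.Tensor([0.0, 1.0, 1.0])        # e.g., 0 if the denotation is correct
    probs = torch.nn.functional.softmax(logprobs, dim=-1)
    expected_cost = probs.dot(costs)
    print(round(expected_cost.item(), 3))  # ~0.335; the low-cost sequence dominates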
def forward(self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRuleArray]],
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing type-constrained target sequences, trained to maximize the
    marginal likelihood over a set of approximate logical forms.
    """
    batch_size = len(worlds)
    action_embeddings, action_indices = self._embed_actions(actions)
    initial_rnn_state = self._get_initial_rnn_state(sentence)
    initial_score_list = [util.new_variable_with_data(list(sentence.values())[0],
                                                      torch.Tensor([0.0]))
                          for i in range(batch_size)]
    label_strings = self._get_label_strings(labels) if labels is not None else None
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i])
                             for i in range(batch_size)]
    worlds_list = [worlds[i] for i in range(batch_size)]
    initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                     action_history=[[] for _ in range(batch_size)],
                                     score=initial_score_list,
                                     rnn_state=initial_rnn_state,
                                     grammar_state=initial_grammar_state,
                                     action_embeddings=action_embeddings,
                                     action_indices=action_indices,
                                     possible_actions=actions,
                                     worlds=worlds_list,
                                     label_strings=label_strings)
    if target_action_sequences is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    else:
        target_mask = None
    outputs: Dict[str, torch.Tensor] = {}
    if target_action_sequences is not None:
        outputs = self._decoder_trainer.decode(initial_state,
                                               self._decoder_step,
                                               (target_action_sequences, target_mask))
    best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
    best_action_sequences: Dict[int, List[List[int]]] = {}
    for i in range(batch_size):
        # Decoding may not have terminated with any completed logical forms, if `num_steps`
        # isn't long enough (or if the model is not trained enough and gets into an
        # infinite action loop).
        if i in best_final_states:
            best_action_indices = [best_final_states[i][0].action_history[0]]
            best_action_sequences[i] = best_action_indices
    batch_action_strings = self._get_action_strings(actions, best_action_sequences)
    batch_denotations = self._get_denotations(batch_action_strings, worlds)
    if target_action_sequences is not None:
        self._update_metrics(action_strings=batch_action_strings,
                             worlds=worlds,
                             label_strings=label_strings)
    else:
        outputs["best_action_strings"] = batch_action_strings
        outputs["denotations"] = batch_denotations
    return outputs
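
# Illustrative only: how the padding mask above is built. The padding index here is a
# made-up value; the real one comes from the model's vocabulary/configuration.
def _demo_target_mask() -> None:
    import torch
    action_padding_index = -1  # assumption for this sketch
    # Two target action sequences padded to length 4: [3, 5, 9] and [2, 7].
    target_action_sequences = torch.LongTensor([[3, 5, 9, -1],
                                                [2, 7, -1, -1]])
    # Positions holding real actions get 1; padded positions get 0, so the trainer
    # ignores them when scoring against the targets.
    target_mask = target_action_sequences != action_padding_index
    print(target_mask.long().tolist())  # [[1, 1, 1, 0], [1, 1, 0, 0]]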
def forward(self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRuleArray]],
            agenda: torch.LongTensor,
            labels: torch.LongTensor = None,
            epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing type-constrained target sequences that maximize coverage of
    their respective agendas, and minimize a denotation-based loss.
    """
    # We look at the epoch number and adjust the checklist cost weight if needed here.
    instance_epoch_num = epoch_num[0] if epoch_num is not None else None
    if self._dynamic_cost_rate is not None:
        if self.training and instance_epoch_num is None:
            raise RuntimeError("If you want a dynamic cost weight, use the "
                               "EpochTrackingBucketIterator!")
        if instance_epoch_num != self._last_epoch_in_forward:
            if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                self._checklist_cost_weight -= decrement
                logger.info("Checklist cost weight is now %f", self._checklist_cost_weight)
            self._last_epoch_in_forward = instance_epoch_num
    batch_size = len(worlds)
    action_embeddings, action_indices = self._embed_actions(actions)
    initial_rnn_state = self._get_initial_rnn_state(sentence)
    initial_score_list = [util.new_variable_with_data(list(sentence.values())[0],
                                                      torch.Tensor([0.0]))
                          for i in range(batch_size)]
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i])
                             for i in range(batch_size)]
    label_strings = self._get_label_strings(labels) if labels is not None else None
    # Each instance's agenda is of size (agenda_size, 1).
    agenda_list = [agenda[i] for i in range(batch_size)]
    initial_checklist_states = []
    for instance_actions, instance_agenda in zip(actions, agenda_list):
        checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
        checklist_target, terminal_actions, checklist_mask = checklist_info
        initial_checklist = util.new_variable_with_size(checklist_target,
                                                        checklist_target.size(),
                                                        0)
        initial_checklist_states.append(ChecklistState(terminal_actions=terminal_actions,
                                                       checklist_target=checklist_target,
                                                       checklist_mask=checklist_mask,
                                                       checklist=initial_checklist))
    initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                     action_history=[[] for _ in range(batch_size)],
                                     score=initial_score_list,
                                     rnn_state=initial_rnn_state,
                                     grammar_state=initial_grammar_state,
                                     action_embeddings=action_embeddings,
                                     action_indices=action_indices,
                                     possible_actions=actions,
                                     worlds=worlds,
                                     label_strings=label_strings,
                                     checklist_state=initial_checklist_states)
    agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
    outputs = self._decoder_trainer.decode(initial_state,
                                           self._decoder_step,
                                           self._get_state_cost)
    best_action_sequences = outputs['best_action_sequences']
    batch_action_strings = self._get_action_strings(actions, best_action_sequences)
    batch_denotations = self._get_denotations(batch_action_strings, worlds)
    if labels is not None:
        # We're either training or validating.
        self._update_metrics(action_strings=batch_action_strings,
                             worlds=worlds,
                             label_strings=label_strings,
                             possible_actions=actions,
                             agenda_data=agenda_data)
    else:
        # We're testing.
        outputs["best_action_strings"] = batch_action_strings
        outputs["denotations"] = batch_denotations
    return outputs
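
# Illustrative only: the dynamic checklist cost weight schedule above, with made-up
# hyperparameter values. After ``wait_epochs``, the weight decays geometrically by
# ``rate`` once per epoch, gradually shifting the loss from agenda coverage toward
# denotation accuracy.
def _demo_dynamic_cost_weight() -> None:
    weight, rate, wait_epochs = 0.8, 0.1, 2  # assumptions for this sketch
    for epoch in range(6):
        if epoch >= wait_epochs:
            weight -= weight * rate
        print(epoch, round(weight, 4))
    # Epochs 0-1 keep 0.8; then 0.72, 0.648, 0.5832, 0.5249.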