Example #1
    def _get_checklist_info(self,
                            agenda: torch.LongTensor,
                            all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
                                                                             torch.Tensor,
                                                                             torch.Tensor]:
        """
        Takes an agenda and a list of all actions and returns a target checklist against which the
        checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
        and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
        checklist loss computation. If ``self._penalize_non_agenda_actions`` is set to ``True``,
        ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set
        to ``False``, all terminal actions that are not in the agenda will be masked out.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRuleArray]``
            All actions for one instance.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = set([int(x) for x in agenda.squeeze(-1).data.cpu().numpy()])
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRuleArray, a tuple where the first item is the production
            # rule string.
            if action[0] in self._terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        # (num_terminals, 1)
        terminal_actions = util.new_variable_with_data(agenda,
                                                       torch.Tensor(terminal_indices))
        # (num_terminals, 1)
        target_checklist = util.new_variable_with_data(agenda,
                                                       torch.Tensor(target_checklist_list))
        if self._penalize_non_agenda_actions:
            # All terminal actions are relevant
            checklist_mask = torch.ones_like(target_checklist)
        else:
            checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
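The bookkeeping above can be reproduced on toy tensors without the surrounding model. The snippet below is a minimal sketch (the terminal actions and agenda indices are made up) of how the target checklist and mask come out when non-agenda actions are not penalized:

import torch

# Hypothetical toy setup: five terminal actions, of which actions 1 and 3 are on the agenda.
terminal_actions = torch.tensor([[0], [1], [2], [3], [4]])  # (num_terminals, 1)
agenda_indices_set = {1, 3}
target_checklist = torch.tensor([[1.0] if int(index) in agenda_indices_set else [0.0]
                                 for index in terminal_actions])  # (num_terminals, 1)
# With penalization disabled, only agenda actions are relevant to the checklist loss.
checklist_mask = (target_checklist != 0).float()
print(target_checklist.squeeze(-1))  # tensor([0., 1., 0., 1., 0.])
print(checklist_mask.squeeze(-1))    # tensor([0., 1., 0., 1., 0.])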
Example #2
    def _get_model_scores_by_batch(self, states: List[StateType]) -> Dict[int, List[Variable]]:
        batch_scores: Dict[int, List[Variable]] = defaultdict(list)
        for state in states:
            for batch_index, model_score, history in zip(state.batch_indices,
                                                         state.score,
                                                         state.action_history):
                if self._normalize_by_length:
                    path_length = nn_util.new_variable_with_data(model_score,
                                                                 torch.Tensor([len(history)]))
                    model_score = model_score / path_length
                batch_scores[batch_index].append(model_score)
        return batch_scores
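What the grouping above produces can be seen on plain tensors, with tuples standing in for decoder states. The values below are made up; this is only a sketch of bucketing finished-state scores by batch index with optional length normalization:

from collections import defaultdict
import torch

# Hypothetical finished states: (batch_index, log-probability score, action history).
finished = [(0, torch.tensor([-1.2]), [3, 7, 9]),
            (0, torch.tensor([-2.0]), [3, 7]),
            (1, torch.tensor([-0.5]), [4, 4, 4, 4])]
normalize_by_length = True

batch_scores = defaultdict(list)
for batch_index, score, history in finished:
    if normalize_by_length:
        score = score / len(history)  # average per-step log-probability
    batch_scores[batch_index].append(score)
# batch_scores[0] now holds two normalized scores and batch_scores[1] holds one.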
Example #3
    def decode(self,
               initial_state: DecoderState,
               decode_step: DecoderStep,
               supervision: Callable[[StateType], torch.Tensor]) -> Dict[str, torch.Tensor]:
        cost_function = supervision
        finished_states = self._get_finished_states(initial_state, decode_step)
        loss = nn_util.new_variable_with_data(initial_state.score[0], torch.Tensor([0.0]))
        finished_model_scores = self._get_model_scores_by_batch(finished_states)
        finished_costs = self._get_costs_by_batch(finished_states, cost_function)
        for batch_index in finished_model_scores:
            # Finished model scores are log-probabilities of the predicted sequences. We convert
            # log probabilities into probabilities and re-normalize them to compute expected cost
            # under the distribution approximated by the beam search.
            costs = torch.cat(finished_costs[batch_index])
            logprobs = torch.cat(finished_model_scores[batch_index])
            # Unmasked softmax of log probabilities will convert them into probabilities and
            # renormalize them.
            renormalized_probs = nn_util.masked_softmax(logprobs, None)
            loss += renormalized_probs.dot(costs)
        mean_loss = loss / len(finished_model_scores)
        return {'loss': mean_loss,
                'best_action_sequences': self._get_best_action_sequences(finished_states)}
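
# A minimal standalone sketch (not from the original source) of the per-instance expected-cost
# loss that ``decode`` computes above: the finished sequences' log-probabilities are renormalized
# with a softmax and dotted with their costs. The numbers below are made up.
import torch

logprobs = torch.tensor([-0.7, -1.5, -2.3])    # model scores of three finished sequences
costs = torch.tensor([0.0, 1.0, 1.0])          # e.g. 0 for a correct denotation, 1 otherwise
renormalized_probs = torch.softmax(logprobs, dim=-1)
expected_cost = renormalized_probs.dot(costs)  # this instance's contribution to the loss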
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRuleArray]],
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences, trained to maximize the
        marginal likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            util.new_variable_with_data(
                list(sentence.values())[0], torch.Tensor([0.0]))
            for i in range(batch_size)
        ]
        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]
        worlds_list = [worlds[i] for i in range(batch_size)]

        initial_state = NlvrDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            action_indices=action_indices,
            possible_actions=actions,
            worlds=worlds_list,
            label_strings=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
        best_final_states = self._decoder_beam_search.search(
            self._max_decoding_steps,
            initial_state,
            self._decoder_step,
            keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = [
                    best_final_states[i][0].action_history[0]
                ]
                best_action_sequences[i] = best_action_indices
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs
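
# A small illustration (toy data, not from the original source) of the target-mask construction
# in ``forward`` above: the trailing dimension added by ListField[ListField[IndexField]] is
# dropped and padded positions are excluded from the mask.
import torch

action_padding_index = -1
# Shape (batch_size=1, num_target_sequences=2, sequence_length=3, 1).
target_action_sequences = torch.tensor([[[[2], [5], [-1]],
                                         [[2], [7], [5]]]])
target_action_sequences = target_action_sequences.squeeze(-1)
target_mask = target_action_sequences != action_padding_index
# target_mask[0, 0] is [True, True, False]: the padded final step is ignored.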
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRuleArray]],
            agenda: torch.LongTensor,
            labels: torch.LongTensor = None,
            epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences that maximize coverage of
        their respective agendas and minimize a denotation-based loss.
        """
        # We look at the epoch number and adjust the checklist cost weight if needed here.
        instance_epoch_num = epoch_num[0] if epoch_num is not None else None
        if self._dynamic_cost_rate is not None:
            if self.training and instance_epoch_num is None:
                raise RuntimeError(
                    "If you want a dynamic cost weight, use the "
                    "EpochTrackingBucketIterator!")
            if instance_epoch_num != self._last_epoch_in_forward:
                if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                    decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                    self._checklist_cost_weight -= decrement
                    logger.info("Checklist cost weight is now %f",
                                self._checklist_cost_weight)
                self._last_epoch_in_forward = instance_epoch_num
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            util.new_variable_with_data(
                list(sentence.values())[0], torch.Tensor([0.0]))
            for i in range(batch_size)
        ]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]

        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info
            initial_checklist = util.new_variable_with_size(
                checklist_target, checklist_target.size(), 0)
            initial_checklist_states.append(
                ChecklistState(terminal_actions=terminal_actions,
                               checklist_target=checklist_target,
                               checklist_mask=checklist_mask,
                               checklist=initial_checklist))
        initial_state = NlvrDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            action_indices=action_indices,
            possible_actions=actions,
            worlds=worlds,
            label_strings=label_strings,
            checklist_state=initial_checklist_states)

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(initial_state,
                                               self._decoder_step,
                                               self._get_state_cost)
        best_action_sequences = outputs['best_action_sequences']
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings,
                                 possible_actions=actions,
                                 agenda_data=agenda_data)
        else:
            # We're testing.
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs
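
The dynamic cost weight adjustment at the top of this ``forward`` amounts to a simple schedule: once the wait period has passed, the checklist cost weight decays by ``rate`` of its current value once per epoch, shifting the training signal from agenda coverage toward the denotation-based cost. The helper below is an illustrative restatement of that schedule, not code from the parser, and the configuration values in the example call are made up:

def decayed_checklist_weight(initial_weight: float,
                             rate: float,
                             wait_epochs: int,
                             epoch_num: int) -> float:
    """Weight after applying ``weight -= weight * rate`` once per epoch past the wait period."""
    weight = initial_weight
    for epoch in range(epoch_num + 1):
        if epoch >= wait_epochs:
            weight -= weight * rate
    return weight

# For example, decayed_checklist_weight(0.8, 0.1, wait_epochs=2, epoch_num=5)
# applies four decrements: 0.8 * 0.9 ** 4 ≈ 0.52.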