Example #1
    def _get_state_cost(self, worlds: List[WikiTablesWorld],
                        state: CoverageDecoderState) -> torch.Tensor:
        if not state.is_finished():
            raise RuntimeError(
                "_get_state_cost() is not defined for unfinished states!")
        world = worlds[state.batch_indices[0]]

        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
        # penalizing agenda actions produced multiple times.
        checklist_balance = torch.clamp(state.checklist_state[0].get_balance(),
                                        min=0.0)
        checklist_cost = torch.sum(checklist_balance ** 2)

        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        denotation_cost = torch.sum(
            state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        action_history = state.action_history[0]
        batch_index = state.batch_indices[0]
        action_strings = [
            state.possible_actions[batch_index][i][0] for i in action_history
        ]
        logical_form = world.get_logical_form(action_strings)
        lisp_string = state.extras[batch_index]
        if self._denotation_accuracy.evaluate_logical_form(
                logical_form, lisp_string):
            cost = checklist_cost
        else:
            cost = checklist_cost + (
                1 - self._checklist_cost_weight) * denotation_cost
        return cost
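A minimal sketch of what this cost computes, using made-up tensors in place of the real checklist state; the weight value and the denotation outcome below are assumptions for illustration only.

import torch

# Hypothetical numbers, only to illustrate the cost above; real values come
# from a ChecklistStatelet and from the denotation check.
checklist_cost_weight = 0.8
# balance = checklist_target - checklist; negative entries mean an agenda
# action was produced more often than requested, which we choose not to penalize.
balance = torch.tensor([1.0, 0.0, -2.0, 1.0])
checklist_cost = checklist_cost_weight * torch.sum(torch.clamp(balance, min=0.0) ** 2)

# Number of agenda items we wanted to see; used as the cost when the denotation is wrong.
checklist_target = torch.tensor([1.0, 1.0, 1.0, 0.0])
denotation_cost = torch.sum(checklist_target)

denotation_is_correct = False  # pretend the logical form evaluated incorrectly
if denotation_is_correct:
    cost = checklist_cost
else:
    cost = checklist_cost + (1 - checklist_cost_weight) * denotation_cost
print(float(cost))  # 0.8 * 2 + 0.2 * 3 = 2.2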
Example #2
    def _compute_action_probabilities(
        self,  # type: ignore
        state: CoverageDecoderState,
        hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
    ) -> Dict[int, List[Tuple[int, Any, Any, List[int]]]]:
        # In this section we take our predicted action embedding and compare it to the available
        # actions in our current state (which might be different for each group element).  For
        # computing action scores, we'll forget about doing batched / grouped computation, as it
        # adds too much complexity and doesn't speed things up, anyway, with the operations we're
        # doing here.  This means we don't need any action masks, as we'll only get the right
        # lengths for what we're computing.

        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, torch.Tensor, torch.Tensor,
                                            List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            predicted_action_embedding = predicted_action_embeddings[
                group_index]
            action_embeddings, output_action_embeddings, action_ids = instance_actions[
                'global']

            # This embedding addition is the only difference between the logic here and the
            # corresponding logic in the superclass.
            embedding_addition = self._get_predicted_embedding_addition(
                state.checklist_state[group_index], action_ids,
                action_embeddings)
            addition = embedding_addition * self._checklist_multiplier
            predicted_action_embedding = predicted_action_embedding + addition

            # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
            # (embedding_dim, 1) matrix.
            action_logits = action_embeddings.mm(
                predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
            current_log_probs = torch.nn.functional.log_softmax(action_logits,
                                                                dim=-1)

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append(
                (group_index, log_probs, output_action_embeddings, action_ids))
        return batch_results
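The scoring inside the loop is just a dot product of every candidate action embedding with the checklist-adjusted predicted embedding, followed by a log-softmax that is added to the running state score. A self-contained sketch with random tensors (shapes are illustrative, not taken from the model):

import torch

torch.manual_seed(0)
num_actions, embedding_dim = 3, 4
action_embeddings = torch.randn(num_actions, embedding_dim)
predicted_action_embedding = torch.randn(embedding_dim)

# (num_actions, embedding_dim) @ (embedding_dim, 1) -> (num_actions,)
action_logits = action_embeddings.mm(
    predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

# Adding the running score of the partial sequence keeps this a total score,
# which is what the later sorting over states relies on.
state_score = torch.tensor(-1.5)
log_probs = state_score + current_log_probs
print(log_probs)  # three total log-scores, one per candidate action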
    def _get_state_info(
            self, state: CoverageDecoderState,
            batch_worlds: List[List[NlvrWorld]]) -> Dict[str, List]:
        """
        This method is here for debugging purposes, in case you want to look at what the model
        is learning. It may be inefficient to call it while training the model on real data.
        """
        if len(state.batch_indices) == 1 and state.is_finished():
            costs = [
                float(
                    self._get_state_cost(batch_worlds,
                                         state).detach().cpu().numpy())
            ]
        else:
            costs = []
        model_scores = [
            float(score.detach().cpu().numpy()) for score in state.score
        ]
        all_actions = state.possible_actions[0]
        action_sequences = [[
            self._get_action_string(all_actions[action]) for action in history
        ] for history in state.action_history]
        agenda_sequences = []
        all_agenda_indices = []
        for checklist_state in state.checklist_state:
            agenda_indices = []
            for action, is_wanted in zip(checklist_state.terminal_actions,
                                         checklist_state.checklist_target):
                action_int = int(action.detach().cpu().numpy())
                is_wanted_int = int(is_wanted.detach().cpu().numpy())
                if is_wanted_int != 0:
                    agenda_indices.append(action_int)
            agenda_sequences.append([
                self._get_action_string(all_actions[action])
                for action in agenda_indices
            ])
            all_agenda_indices.append(agenda_indices)
        return {
            "agenda": agenda_sequences,
            "agenda_indices": all_agenda_indices,
            "history": action_sequences,
            "history_indices": state.action_history,
            "costs": costs,
            "scores": model_scores
        }
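The agenda reconstruction above only keeps terminal actions whose checklist target is non-zero. A rough, hypothetical illustration of that filtering, with toy tensors standing in for one ChecklistState:

import torch

terminal_actions = torch.tensor([[3], [7], [11]])        # global action indices
checklist_target = torch.tensor([[1.0], [0.0], [1.0]])   # 1.0 where the action is on the agenda

agenda_indices = []
for action, is_wanted in zip(terminal_actions, checklist_target):
    if int(is_wanted) != 0:
        agenda_indices.append(int(action))
print(agenda_indices)  # [3, 11]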
    def _get_state_cost(self, batch_worlds: List[List[NlvrWorld]],
                        state: CoverageDecoderState) -> torch.Tensor:
        """
        Return the cost of a finished state. Since it is a finished state, the group size will be
        1, and hence we'll return just one cost.

        The ``batch_worlds`` parameter here is because we need the world to check the denotation
        accuracy of the action sequence in the finished state.  Instead of adding a field to the
        ``State`` object just for this method, we take the ``World`` as a parameter here.
        """
        if not state.is_finished():
            raise RuntimeError(
                "_get_state_cost() is not defined for unfinished states!")
        instance_worlds = batch_worlds[state.batch_indices[0]]
        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask.
        checklist_balance = state.checklist_state[0].get_balance()
        checklist_cost = torch.sum(checklist_balance ** 2)

        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        # Note: If we are penalizing the model for producing non-agenda actions, this is not the
        # upper limit on the checklist cost. That would be the number of terminal actions.
        denotation_cost = torch.sum(
            state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        # TODO (pradeep): The denotation-based cost below is strict. Maybe define a cost based on
        # how many worlds the logical form is correct in?
        # extras being None happens when we are testing. We do not care about the cost
        # then.  TODO (pradeep): Make this cleaner.
        if state.extras is None or all(
                self._check_state_denotations(state, instance_worlds)):
            cost = checklist_cost
        else:
            cost = checklist_cost + (
                1 - self._checklist_cost_weight) * denotation_cost
        return cost
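Note the one substantive difference from Example #1: the balance is not clamped here, so producing an agenda action more often than requested is also penalized. A tiny comparison with a made-up balance:

import torch

balance = torch.tensor([1.0, -2.0, 0.0])
# This (unclamped) cost penalizes the overshoot in the -2.0 entry:
print(float(torch.sum(balance ** 2)))                        # 5.0
# The clamped variant from Example #1 would ignore it:
print(float(torch.sum(torch.clamp(balance, min=0.0) ** 2)))  # 1.0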
Example #5
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRuleArray]],
            agenda: torch.LongTensor,
            example_lisp_string: List[str],
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' key.
        """
        batch_size = list(question.values())[0].size(0)
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [agenda[i] for i in range(batch_size)]
        checklist_states = []
        all_terminal_productions = [
            set(instance_world.terminal_productions.values())
            for instance_world in world
        ]
        max_num_terminals = max(
            [len(terminals) for terminals in all_terminal_productions])
        for instance_actions, instance_agenda, terminal_productions in zip(
                actions, agenda_list, all_terminal_productions):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions,
                                                      terminal_productions,
                                                      max_num_terminals)
            checklist_target, terminal_actions, checklist_mask = checklist_info
            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            checklist_states.append(
                ChecklistState(terminal_actions=terminal_actions,
                               checklist_target=checklist_target,
                               checklist_mask=checklist_mask,
                               checklist=initial_checklist))
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(
            question, table, world, actions, outputs)

        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = CoverageDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            checklist_state=checklist_states,
            possible_actions=actions,
            extras=example_lisp_string,
            debug_info=None)

        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]

        outputs = self._decoder_trainer.decode(
            initial_state,  # type: ignore
            self._decoder_step,
            partial(self._get_state_cost, world))
        best_final_states = outputs['best_final_states']

        if not self.training:
            batch_size = len(actions)
            agenda_indices = [agenda_[:, 0].cpu().data for agenda_ in agenda]
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            for i in range(batch_size):
                in_agenda_ratio = 0.0
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    action_sequence = best_final_states[i][0].action_history[0]
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in action_sequence
                    ]
                    instance_possible_actions = actions[i]
                    agenda_actions = []
                    for rule_id in agenda_indices[i]:
                        rule_id = int(rule_id)
                        if rule_id == -1:
                            continue
                        action_string = instance_possible_actions[rule_id][0]
                        agenda_actions.append(action_string)
                    actions_in_agenda = [
                        action in action_strings for action in agenda_actions
                    ]
                    if actions_in_agenda:
                        # Note: This means that when there are no actions on agenda, agenda coverage
                        # will be 0, not 1.
                        in_agenda_ratio = sum(actions_in_agenda) / len(
                            actions_in_agenda)
                self._agenda_coverage(in_agenda_ratio)

            self._compute_validation_outputs(actions, best_final_states, world,
                                             example_lisp_string, metadata,
                                             outputs)
        return outputs
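A sketch of the per-instance checklist setup that the loop near the top of this forward builds before decoding starts. The exact tensors come from self._get_checklist_info, which is not shown here, so the shapes and the mask convention below are assumptions for illustration:

import torch

# Suppose one instance has four terminal productions, two of which are on its agenda.
terminal_actions = torch.tensor([[0], [3], [5], [9]])          # (num_terminals, 1) action ids
checklist_target = torch.tensor([[1.0], [0.0], [1.0], [0.0]])  # 1.0 for agenda actions
checklist_mask = (checklist_target != 0).float()               # assumed: only agenda items counted
initial_checklist = checklist_target.new_zeros(checklist_target.size())

# During decoding the checklist is incremented whenever a terminal action is taken;
# the balance that drives the cost in the earlier examples is target - checklist.
balance = (checklist_target - initial_checklist) * checklist_mask
print(balance.squeeze(-1))  # tensor([1., 0., 1., 0.])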
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRuleArray]],
            agenda: torch.LongTensor,
            identifier: List[str] = None,
            labels: torch.LongTensor = None,
            epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences that maximize coverage of
        their respective agendas, and minimize a denotation based loss.
        """
        # We look at the epoch number and adjust the checklist cost weight if needed here.
        instance_epoch_num = epoch_num[0] if epoch_num is not None else None
        if self._dynamic_cost_rate is not None:
            if self.training and instance_epoch_num is None:
                raise RuntimeError(
                    "If you want a dynamic cost weight, use the "
                    "EpochTrackingBucketIterator!")
            if instance_epoch_num != self._last_epoch_in_forward:
                if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                    decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                    self._checklist_cost_weight -= decrement
                    logger.info("Checklist cost weight is now %f",
                                self._checklist_cost_weight)
                self._last_epoch_in_forward = instance_epoch_num
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
            for i in range(batch_size)
        ]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]

        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        # TODO(mattg): It looks like the agenda is only ever used on the CPU.  In that case, it's a
        # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            initial_checklist_states.append(
                ChecklistState(terminal_actions=terminal_actions,
                               checklist_target=checklist_target,
                               checklist_mask=checklist_mask,
                               checklist=initial_checklist))
        initial_state = CoverageDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings,
            checklist_state=initial_checklist_states)

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(
            initial_state,  # type: ignore
            self._decoder_step,
            partial(self._get_state_cost, worlds))
        if identifier is not None:
            outputs['identifier'] = identifier
        best_final_states = outputs['best_final_states']
        best_action_sequences = {}
        for batch_index, states in best_final_states.items():
            best_action_sequences[batch_index] = [
                state.action_history[0] for state in states
            ]
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings,
                                 possible_actions=actions,
                                 agenda_data=agenda_data)
        else:
            # We're testing.
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs
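The dynamic cost schedule at the top of this forward simply shrinks the checklist cost weight by a fixed fraction once per epoch after a warm-up period. A minimal sketch of that decay with made-up hyperparameters:

checklist_cost_weight = 0.8   # illustrative starting value
dynamic_cost_rate = 0.1
dynamic_cost_wait_epochs = 2

for epoch in range(5):
    if epoch >= dynamic_cost_wait_epochs:
        checklist_cost_weight -= checklist_cost_weight * dynamic_cost_rate
    print(epoch, round(checklist_cost_weight, 4))
# 0 0.8, 1 0.8, 2 0.72, 3 0.648, 4 0.5832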
    def _compute_action_probabilities(
        self,  # type: ignore
        state: CoverageDecoderState,
        hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
    ) -> Dict[int, List[Tuple[int, Any, Any, List[int]]]]:
        # In this section we take our predicted action embedding and compare it to the available
        # actions in our current state (which might be different for each group element).  For
        # computing action scores, we'll forget about doing batched / grouped computation, as it
        # adds too much complexity and doesn't speed things up, anyway, with the operations we're
        # doing here.  This means we don't need any action masks, as we'll only get the right
        # lengths for what we're computing.

        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, torch.Tensor, torch.Tensor,
                                            List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            predicted_action_embedding = predicted_action_embeddings[
                group_index]
            action_embeddings, output_action_embeddings, embedded_actions = instance_actions[
                'global']

            # This embedding addition is the only difference between the logic here and the
            # corresponding logic in the superclass.
            embedding_addition = self._get_predicted_embedding_addition(
                state.checklist_state[group_index], embedded_actions,
                action_embeddings)
            addition = embedding_addition * self._checklist_multiplier
            predicted_action_embedding = predicted_action_embedding + addition

            # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
            # (embedding_dim, 1) matrix.
            embedded_action_logits = action_embeddings.mm(
                predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
            action_ids = embedded_actions

            if 'linked' in instance_actions:
                linking_scores, type_embeddings, linked_actions = instance_actions[
                    'linked']
                action_ids = embedded_actions + linked_actions
                # (num_question_tokens, 1)
                linked_action_logits = linking_scores.mm(
                    attention_weights[group_index].unsqueeze(-1)).squeeze(-1)

                linked_logits_addition = self._get_linked_logits_addition(
                    state.checklist_state[group_index], linked_actions,
                    linked_action_logits)

                addition = linked_logits_addition * self._linked_checklist_multiplier
                linked_action_logits = linked_action_logits + addition

                # The `output_action_embeddings` tensor gets used later as the input to the next
                # decoder step.  For linked actions, we don't have any action embedding, so we use
                # the entity type instead.
                output_action_embeddings = torch.cat(
                    [output_action_embeddings, type_embeddings], dim=0)

                if self._mixture_feedforward is not None:
                    # The linked and global logits are combined with a mixture weight to prevent the
                    # linked_action_logits from dominating the embedded_action_logits if a softmax
                    # was applied on both together.
                    mixture_weight = self._mixture_feedforward(
                        hidden_state[group_index])
                    mix1 = torch.log(mixture_weight)
                    mix2 = torch.log(1 - mixture_weight)

                    entity_action_probs = torch.nn.functional.log_softmax(
                        linked_action_logits, dim=-1) + mix1
                    embedded_action_probs = torch.nn.functional.log_softmax(
                        embedded_action_logits, dim=-1) + mix2
                    current_log_probs = torch.cat(
                        [embedded_action_probs, entity_action_probs], dim=-1)
                else:
                    action_logits = torch.cat(
                        [embedded_action_logits, linked_action_logits], dim=-1)
                    current_log_probs = torch.nn.functional.log_softmax(
                        action_logits, dim=-1)
            else:
                action_logits = embedded_action_logits
                current_log_probs = torch.nn.functional.log_softmax(
                    action_logits, dim=-1)

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append(
                (group_index, log_probs, output_action_embeddings, action_ids))
        return batch_results
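When a mixture feedforward is present, the two groups of logits are combined into one proper distribution by weighting each softmax in log space. A hedged sketch with random logits and a fixed weight standing in for self._mixture_feedforward(hidden_state):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
embedded_action_logits = torch.randn(4)  # global (embedded) actions
linked_action_logits = torch.randn(3)    # actions linked to question entities
mixture_weight = torch.tensor(0.3)       # assumed output of the mixture feedforward

entity_log_probs = F.log_softmax(linked_action_logits, dim=-1) + torch.log(mixture_weight)
embedded_log_probs = F.log_softmax(embedded_action_logits, dim=-1) + torch.log(1 - mixture_weight)
current_log_probs = torch.cat([embedded_log_probs, entity_log_probs], dim=-1)

# The concatenation still sums to 1 in probability space: 0.7 + 0.3.
print(float(current_log_probs.exp().sum()))  # ~1.0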