Example No. 1
    def _get_predicted_embedding_addition(
            self, checklist_state: ChecklistState, action_ids: List[int],
            action_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Gets the embeddings of desired terminal actions yet to be produced by the decoder, and
        returns their sum for the decoder to add it to the predicted embedding to bias the
        prediction towards missing actions.
        """
        # Our basic approach here will be to figure out which actions we want to bias, by doing
        # some fancy indexing work, then multiply the action embeddings by a mask for those
        # actions, and return the sum of the result.

        # Shape: (num_terminal_actions, 1).  This is 1 if we still want to predict something on the
        # checklist, and 0 otherwise.
        checklist_balance = checklist_state.get_balance().clamp(min=0)

        # (num_terminal_actions, 1)
        actions_in_agenda = checklist_state.terminal_actions
        # (1, num_current_actions)
        action_id_tensor = checklist_balance.new(action_ids).long().unsqueeze(
            0)
        # Shape: (num_terminal_actions, num_current_actions).  Will have a value of 1 if the
        # terminal action i is our current action j, and a value of 0 otherwise.  Because both sets
        # of actions are free of duplicates, there will be at most one non-zero value per current
        # action, and per terminal action.
        current_agenda_actions = (
            actions_in_agenda == action_id_tensor).float()

        # Shape: (num_current_actions,).  With the inner multiplication, we remove any current
        # agenda actions that are not in our checklist balance, then we sum over the terminal
        # action dimension, which will have a sum of at most one.  So this will be a 0/1 tensor,
        # where a 1 means to encourage the current action in that position.
        actions_to_encourage = torch.sum(current_agenda_actions *
                                         checklist_balance,
                                         dim=0)

        # Shape: (action_embedding_dim,).  This is the sum of the action embeddings that we want
        # the model to prefer.
        embedding_addition = torch.sum(action_embeddings *
                                       actions_to_encourage.unsqueeze(1),
                                       dim=0,
                                       keepdim=False)

        if self._add_action_bias:
            # If we're adding an action bias, the last dimension of the action embedding is a bias
            # weight.  We don't want this addition to affect the bias (TODO(mattg): or do we?), so
            # we zero out that dimension here.
            embedding_addition[-1] = 0

        return embedding_addition

    @staticmethod
    def _get_linked_logits_addition(
            checklist_state: ChecklistState, action_ids: List[int],
            action_logits: torch.Tensor) -> torch.Tensor:
        """
        Gets the logits of desired terminal actions yet to be produced by the decoder, and
        returns them for the decoder to add to the prior action logits, biasing the model towards
        predicting missing linked actions.
        """
        # Our basic approach here will be to figure out which actions we want to bias, by doing
        # some fancy indexing work, then multiply the action logits by a mask for those
        # actions, and return the result.

        # Shape: (num_terminal_actions, 1).  This is 1 if we still want to predict something on the
        # checklist, and 0 otherwise.
        checklist_balance = checklist_state.get_balance().clamp(min=0)

        # (num_terminal_actions, 1)
        actions_in_agenda = checklist_state.terminal_actions
        # (1, num_current_actions)
        action_id_tensor = checklist_balance.new(action_ids).long().unsqueeze(
            0)
        # Shape: (num_terminal_actions, num_current_actions).  Will have a value of 1 if the
        # terminal action i is our current action j, and a value of 0 otherwise.  Because both sets
        # of actions are free of duplicates, there will be at most one non-zero value per current
        # action, and per terminal action.
        current_agenda_actions = (
            actions_in_agenda == action_id_tensor).float()

        # Shape: (num_current_actions,).  With the inner multiplication, we remove any current
        # agenda actions that are not in our checklist balance, then we sum over the terminal
        # action dimension, which will have a sum of at most one.  So this will be a 0/1 tensor,
        # where a 1 means to encourage the current action in that position.
        actions_to_encourage = torch.sum(current_agenda_actions *
                                         checklist_balance,
                                         dim=0)

        # Shape: (num_current_actions,).  These are the logits of the actions that we want
        # the model to prefer.
        logit_addition = action_logits * actions_to_encourage
        return logit_addition
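Both helpers above rely on the same broadcasting trick: compare the column of terminal action ids against the row of currently valid action ids, zero out agenda items that have already been produced, and collapse the result into a per-current-action 0/1 vector. Below is a minimal standalone sketch of that arithmetic; the ids and shapes are made up for illustration and are not taken from any dataset.

import torch

terminal_actions = torch.tensor([[3], [7], [11]])         # (num_terminal_actions, 1)
checklist_balance = torch.tensor([[1.0], [0.0], [1.0]])   # 1 = agenda item not yet produced
action_ids = [7, 11, 2]                                    # actions that are currently valid

action_id_tensor = torch.tensor(action_ids).unsqueeze(0)   # (1, num_current_actions)
# (num_terminal_actions, num_current_actions): 1 where terminal action i == current action j.
current_agenda_actions = (terminal_actions == action_id_tensor).float()
# (num_current_actions,): 1 for current actions that are still wanted on the checklist.
actions_to_encourage = torch.sum(current_agenda_actions * checklist_balance, dim=0)
print(actions_to_encourage)   # tensor([0., 1., 0.]) -> only action id 11 gets a boost
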
 def forward(
         self,  # type: ignore
         question: Dict[str, torch.LongTensor],
         table: Dict[str, torch.LongTensor],
         world: List[WikiTablesWorld],
         actions: List[List[ProductionRuleArray]],
         agenda: torch.LongTensor,
         example_lisp_string: List[str]) -> Dict[str, torch.Tensor]:
     # pylint: disable=arguments-differ
     """
     Parameters
     ----------
      question : ``Dict[str, torch.LongTensor]``
          The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
          be passed through a ``TextFieldEmbedder`` and then through an encoder.
     table : ``Dict[str, torch.LongTensor]``
         The output of ``KnowledgeGraphField.as_array()`` applied on the table
         ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
         entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
         get embeddings for each entity.
     world : ``List[WikiTablesWorld]``
         We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
          how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
     actions : ``List[List[ProductionRuleArray]]``
         A list of all possible actions for each ``World`` in the batch, indexed into a
         ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
         and use the embeddings to determine which action to take at each timestep in the
         decoder.
     example_lisp_string : ``List[str]``
         The example (lisp-formatted) string corresponding to the given input.  This comes
         directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
         when evaluating denotation accuracy; it is otherwise unused.
     """
     batch_size = list(question.values())[0].size(0)
     # Each instance's agenda is of size (agenda_size, 1)
     agenda_list = [agenda[i] for i in range(batch_size)]
     checklist_states = []
     all_terminal_productions = [
         set(instance_world.terminal_productions.values())
         for instance_world in world
     ]
     max_num_terminals = max(
         [len(terminals) for terminals in all_terminal_productions])
     for instance_actions, instance_agenda, terminal_productions in zip(
             actions, agenda_list, all_terminal_productions):
         checklist_info = self._get_checklist_info(instance_agenda,
                                                   instance_actions,
                                                   terminal_productions,
                                                   max_num_terminals)
         checklist_target, terminal_actions, checklist_mask = checklist_info
         initial_checklist = checklist_target.new_zeros(
             checklist_target.size())
         checklist_states.append(
             ChecklistState(terminal_actions=terminal_actions,
                            checklist_target=checklist_target,
                            checklist_mask=checklist_mask,
                            checklist=initial_checklist))
     initial_info = self._get_initial_state_and_scores(
         question=question,
         table=table,
         world=world,
         actions=actions,
         example_lisp_string=example_lisp_string,
         add_world_to_initial_state=True,
         checklist_states=checklist_states)
     initial_state = initial_info["initial_state"]
     # TODO(pradeep): Keep track of debug info. It's not straightforward currently because the
     # ERM's decode does not return the best states.
     outputs = self._decoder_trainer.decode(initial_state,
                                            self._decoder_step,
                                            self._get_state_cost)
     if not self.training:
         # TODO(pradeep): Can move most of this block to super class.
         linking_scores = initial_info["linking_scores"]
         feature_scores = initial_info["feature_scores"]
         similarity_scores = initial_info["similarity_scores"]
         batch_size = list(question.values())[0].size(0)
         action_mapping = {}
         for batch_index, batch_actions in enumerate(actions):
             for action_index, action in enumerate(batch_actions):
                 action_mapping[(batch_index, action_index)] = action[0]
         outputs['action_mapping'] = action_mapping
         outputs['entities'] = []
         outputs['linking_scores'] = linking_scores
         if feature_scores is not None:
             outputs['feature_scores'] = feature_scores
         outputs['similarity_scores'] = similarity_scores
         outputs['logical_form'] = []
         best_action_sequences = outputs['best_action_sequences']
         outputs["best_action_sequence"] = []
         outputs['debug_info'] = []
          agenda_indices = [agenda_[:, 0].cpu().data for agenda_ in agenda]
         for i in range(batch_size):
             in_agenda_ratio = 0.0
             # Decoding may not have terminated with any completed logical forms, if `num_steps`
             # isn't long enough (or if the model is not trained enough and gets into an
             # infinite action loop).
             outputs['logical_form'].append([])
             if i in best_action_sequences:
                 for j, action_sequence in enumerate(
                         best_action_sequences[i]):
                     action_strings = [
                         action_mapping[(i, action_index)]
                         for action_index in action_sequence
                     ]
                     try:
                         logical_form = world[i].get_logical_form(
                             action_strings, add_var_function=False)
                         outputs['logical_form'][-1].append(logical_form)
                     except ParsingError:
                         logical_form = "Error producing logical form"
                     if j == 0:
                         # Updating denotation accuracy and has_logical_form only based on the
                         # first logical form.
                         if logical_form.startswith("Error"):
                             self._has_logical_form(0.0)
                         else:
                             self._has_logical_form(1.0)
                         if example_lisp_string:
                             self._denotation_accuracy(
                                 logical_form, example_lisp_string[i])
                         outputs['best_action_sequence'].append(
                             action_strings)
                 outputs['entities'].append(world[i].table_graph.entities)
                 instance_possible_actions = actions[i]
                 agenda_actions = []
                 for rule_id in agenda_indices[i]:
                     rule_id = int(rule_id)
                     if rule_id == -1:
                         continue
                     action_string = instance_possible_actions[rule_id][0]
                     agenda_actions.append(action_string)
                 actions_in_agenda = [
                     action in action_strings for action in agenda_actions
                 ]
                 if actions_in_agenda:
                      # Note: This means that when there are no actions on the agenda, agenda
                      # coverage will be 0, not 1.
                     in_agenda_ratio = sum(actions_in_agenda) / len(
                         actions_in_agenda)
             else:
                 outputs['best_action_sequence'].append([])
                 outputs['logical_form'][-1].append('')
                 self._has_logical_form(0.0)
                 if example_lisp_string:
                     self._denotation_accuracy(None, example_lisp_string[i])
             self._agenda_coverage(in_agenda_ratio)
     return outputs
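The validation block above reports agenda coverage as the fraction of an instance's agenda actions that appear in its best decoded action sequence. Here is a tiny standalone sketch of that ratio; the production-rule strings are invented for illustration.

action_strings = ["r -> select", "c -> fb:cell.usa", "d -> date"]   # best decoded sequence
agenda_actions = ["c -> fb:cell.usa", "n -> max"]                   # this instance's agenda

actions_in_agenda = [action in action_strings for action in agenda_actions]
in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda) if actions_in_agenda else 0.0
print(in_agenda_ratio)   # 0.5 -> one of the two agenda actions was covered
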
Example No. 4
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRuleArray]],
                agenda: torch.LongTensor,
                identifier: List[str] = None,
                labels: torch.LongTensor = None,
                epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences that maximize coverage of
        their respective agendas and minimize a denotation-based loss.
        """
        # We look at the epoch number and adjust the checklist cost weight if needed here.
        instance_epoch_num = epoch_num[0] if epoch_num is not None else None
        if self._dynamic_cost_rate is not None:
            if self.training and instance_epoch_num is None:
                raise RuntimeError("If you want a dynamic cost weight, use the "
                                   "EpochTrackingBucketIterator!")
            if instance_epoch_num != self._last_epoch_in_forward:
                if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                    decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                    self._checklist_cost_weight -= decrement
                    logger.info("Checklist cost weight is now %f", self._checklist_cost_weight)
                self._last_epoch_in_forward = instance_epoch_num
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                              for i in range(batch_size)]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in
                                 range(batch_size)]

        label_strings = self._get_label_strings(labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(checklist_target.size())
            initial_checklist_states.append(ChecklistState(terminal_actions=terminal_actions,
                                                           checklist_target=checklist_target,
                                                           checklist_mask=checklist_mask,
                                                           checklist=initial_checklist))
        initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                         action_history=[[] for _ in range(batch_size)],
                                         score=initial_score_list,
                                         rnn_state=initial_rnn_state,
                                         grammar_state=initial_grammar_state,
                                         action_embeddings=action_embeddings,
                                         action_indices=action_indices,
                                         possible_actions=actions,
                                         worlds=worlds,
                                         label_strings=label_strings,
                                         checklist_state=initial_checklist_states)

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(initial_state,
                                               self._decoder_step,
                                               self._get_state_cost)
        if identifier is not None:
            outputs['identifier'] = identifier
        best_action_sequences = outputs['best_action_sequences']
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings,
                                 possible_actions=actions,
                                 agenda_data=agenda_data)
        else:
            # We're testing.
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs
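The epoch-tracking logic at the top of this forward method shrinks the checklist cost weight once training has run past the wait period, gradually shifting the cost from agenda coverage towards the denotation-based term. A rough standalone sketch of that schedule follows; the hyperparameter values are invented, not the model's defaults.

checklist_cost_weight = 0.8      # weight on the coverage (checklist) part of the cost
dynamic_cost_rate = 0.05         # fraction of the weight removed per epoch once decay starts
dynamic_cost_wait_epochs = 2     # epochs to wait before the decay kicks in

for epoch_num in range(6):
    if epoch_num >= dynamic_cost_wait_epochs:
        checklist_cost_weight -= checklist_cost_weight * dynamic_cost_rate
    print(epoch_num, round(checklist_cost_weight, 4))
# Prints 0.8 for epochs 0-1, then 0.76, 0.722, 0.6859, 0.6516: a geometric decay after the wait.
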
Example No. 5
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRuleArray]],
            agenda: torch.LongTensor,
            example_lisp_string: List[str],
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' key.
        """
        batch_size = list(question.values())[0].size(0)
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [agenda[i] for i in range(batch_size)]
        checklist_states = []
        all_terminal_productions = [
            set(instance_world.terminal_productions.values())
            for instance_world in world
        ]
        max_num_terminals = max(
            [len(terminals) for terminals in all_terminal_productions])
        for instance_actions, instance_agenda, terminal_productions in zip(
                actions, agenda_list, all_terminal_productions):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions,
                                                      terminal_productions,
                                                      max_num_terminals)
            checklist_target, terminal_actions, checklist_mask = checklist_info
            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            checklist_states.append(
                ChecklistState(terminal_actions=terminal_actions,
                               checklist_target=checklist_target,
                               checklist_mask=checklist_mask,
                               checklist=initial_checklist))
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(
            question, table, world, actions, outputs)

        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = CoverageDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            checklist_state=checklist_states,
            possible_actions=actions,
            extras=example_lisp_string,
            debug_info=None)

        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]

        outputs = self._decoder_trainer.decode(
            initial_state,  # type: ignore
            self._decoder_step,
            partial(self._get_state_cost, world))
        best_final_states = outputs['best_final_states']

        if not self.training:
            batch_size = len(actions)
            agenda_indices = [agenda_[:, 0].cpu().data for agenda_ in agenda]
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            for i in range(batch_size):
                in_agenda_ratio = 0.0
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    action_sequence = best_final_states[i][0].action_history[0]
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in action_sequence
                    ]
                    instance_possible_actions = actions[i]
                    agenda_actions = []
                    for rule_id in agenda_indices[i]:
                        rule_id = int(rule_id)
                        if rule_id == -1:
                            continue
                        action_string = instance_possible_actions[rule_id][0]
                        agenda_actions.append(action_string)
                    actions_in_agenda = [
                        action in action_strings for action in agenda_actions
                    ]
                    if actions_in_agenda:
                        # Note: This means that when there are no actions on the agenda, agenda
                        # coverage will be 0, not 1.
                        in_agenda_ratio = sum(actions_in_agenda) / len(
                            actions_in_agenda)
                self._agenda_coverage(in_agenda_ratio)

            self._compute_validation_outputs(actions, best_final_states, world,
                                             example_lisp_string, metadata,
                                             outputs)
        return outputs
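Both coverage parsers build one ChecklistState per instance before decoding starts: a target vector marking which terminal actions are on the agenda, a mask for padding, and a running checklist that starts at zero. The sketch below illustrates that setup on made-up ids; the balance computation is a simplification of what ChecklistState.get_balance() does, not its exact internals.

import torch

terminal_actions = torch.tensor([[3], [7], [11], [-1]])          # padded with -1, as above
checklist_target = torch.tensor([[1.0], [0.0], [1.0], [0.0]])    # 1 = action is on the agenda
checklist_mask = (terminal_actions != -1).float()                 # ignore the padding entry
initial_checklist = checklist_target.new_zeros(checklist_target.size())

# As the decoder produces terminal actions, the checklist is incremented; the clamped
# difference below is the "balance" that the embedding/logit addition helpers consume.
checklist_balance = (checklist_mask * (checklist_target - initial_checklist)).clamp(min=0)
print(checklist_balance.squeeze(-1))   # tensor([1., 0., 1., 0.])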