Exemplo n.º 1
0
    def _get_predicted_embedding_addition(
        self,
        checklist_state: ChecklistStatelet,
        action_ids: List[int],
        action_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Builds a bias vector for the decoder's predicted action embedding.

        The vector is the sum of the embeddings of those current candidate actions that are
        terminal actions still owed by the checklist, each weighted by its balance clamped at
        zero.  The decoder adds it to its predicted embedding so that prediction is pulled
        towards agenda actions that have not been produced yet.

        Parameters
        ----------
        checklist_state : ``ChecklistStatelet``
            Supplies ``get_balance()`` (shape ``(num_terminal_actions, 1)``) and
            ``terminal_actions`` (shape ``(num_terminal_actions, 1)``).
        action_ids : ``List[int]``
            Ids of the actions currently available to the decoder, one per row of
            ``action_embeddings``.
        action_embeddings : ``torch.Tensor``
            Embeddings of the current actions, shape
            ``(num_current_actions, action_embedding_dim)``.

        Returns
        -------
        ``torch.Tensor``
            A vector of shape ``(action_embedding_dim,)``.  When ``self._add_action_bias`` is
            set, its last entry is zeroed.
        """
        # Positive part of the checklist balance, shape (num_terminal_actions, 1): how much of
        # each terminal action the checklist still wants to see.
        owed = checklist_state.get_balance().clamp(min=0)

        # Candidate ids laid out as a row, shape (1, num_current_actions), so comparing against
        # the (num_terminal_actions, 1) column of terminal actions broadcasts to a full match grid.
        candidate_row = owed.new(action_ids).long().unsqueeze(0)
        match_grid = checklist_state.terminal_actions.eq(candidate_row).float()

        # Collapse the terminal-action axis.  Neither set contains duplicates, so each column has
        # at most one match and the result is the owed amount for that candidate (0 if none).
        per_action_weight = (match_grid * owed).sum(dim=0)

        # Weighted sum over candidates -> one vector of size action_embedding_dim.
        bias_vector = (action_embeddings * per_action_weight.unsqueeze(1)).sum(dim=0)

        if self._add_action_bias:
            # The final embedding slot holds the action-bias weight; keep this addition from
            # touching it (TODO(mattg): or should it?).
            bias_vector[-1] = 0
        return bias_vector
    @staticmethod
    def _get_linked_logits_addition(
        checklist_state: ChecklistStatelet,
        action_ids: List[int],
        action_logits: torch.Tensor,
    ) -> torch.Tensor:
        """
        Gets the logits of desired terminal actions yet to be produced by the decoder, and
        returns them for the decoder to add to the prior action logits, biasing the model towards
        predicting missing linked actions.

        This is a ``staticmethod``: it takes no ``self``, so it must be declared static for
        ``self._get_linked_logits_addition(state, ids, logits)`` to bind its arguments correctly.

        Parameters
        ----------
        checklist_state : ``ChecklistStatelet``
            Supplies ``get_balance()`` and ``terminal_actions``, both of shape
            ``(num_terminal_actions, 1)``.
        action_ids : ``List[int]``
            Ids of the linked actions currently available, aligned with ``action_logits``.
        action_logits : ``torch.Tensor``
            Logits of the current linked actions, shape ``(num_current_actions,)``.

        Returns
        -------
        ``torch.Tensor``
            Shape ``(num_current_actions,)``: each logit multiplied by the (clamped) checklist
            balance of its action, i.e. zero for actions the checklist no longer asks for.
        """
        # Our basic approach: figure out which current actions we want to bias with some fancy
        # indexing, then multiply the action logits by a mask for those actions.

        # Shape: (num_terminal_actions, 1).  Positive where we still want to predict something
        # on the checklist, and 0 otherwise.
        checklist_balance = checklist_state.get_balance().clamp(min=0)

        # Shape: (num_terminal_actions, 1)
        actions_in_agenda = checklist_state.terminal_actions
        # Shape: (1, num_current_actions)
        action_id_tensor = checklist_balance.new(action_ids).long().unsqueeze(0)
        # Shape: (num_terminal_actions, num_current_actions).  1 where terminal action i is
        # current action j, 0 otherwise.  Both sets are duplicate-free, so there is at most one
        # non-zero value per current action and per terminal action.
        current_agenda_actions = (actions_in_agenda == action_id_tensor).float()

        # Shape: (num_current_actions,).  The multiplication drops agenda actions with no
        # remaining balance; summing over the terminal-action axis (at most one match per
        # column) yields the mask of current actions to encourage.
        actions_to_encourage = torch.sum(current_agenda_actions * checklist_balance, dim=0)

        # Shape: (num_current_actions,).  Elementwise-masked logits of the actions we want the
        # model to prefer.
        logit_addition = action_logits * actions_to_encourage
        return logit_addition
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        table: Dict[str, torch.LongTensor],
        world: List[WikiTablesLanguage],
        actions: List[List[ProductionRule]],
        agenda: torch.LongTensor,
        target_values: List[List[str]] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesLanguage]``
            We use a ``MetadataField`` to get the ``WikiTablesLanguage`` object for each input
            instance.  Because of how ``MetadataField`` works, this gets passed to us as a
            ``List[WikiTablesLanguage]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``world`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        agenda : ``torch.LongTensor``
            Agenda vectors that the checklist vectors will be compared against to compute the checklist
            cost.
        target_values : ``List[List[str]]``, optional (default = None)
            For each instance, a list of target values taken from the example lisp string. We pass
            this list to the evaluator along with logical forms to compute denotation accuracy.
            When given, the model decodes with the decoder trainer; otherwise it runs beam search
            and computes validation outputs and agenda coverage.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' field.
        """
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = list(agenda)
        all_terminal_productions = [
            set(instance_world.terminal_productions.values()) for instance_world in world
        ]
        # Checklists are padded to the largest terminal-production set in the batch.
        max_num_terminals = max(len(terminals) for terminals in all_terminal_productions)
        checklist_states = []
        for instance_actions, instance_agenda, terminal_productions in zip(
            actions, agenda_list, all_terminal_productions
        ):
            checklist_target, terminal_actions, checklist_mask = self._get_checklist_info(
                instance_agenda, instance_actions, terminal_productions, max_num_terminals
            )
            checklist_states.append(
                ChecklistStatelet(
                    terminal_actions=terminal_actions,
                    checklist_target=checklist_target,
                    checklist_mask=checklist_mask,
                    # Nothing has been produced yet, so the running checklist starts at zero.
                    checklist=checklist_target.new_zeros(checklist_target.size()),
                )
            )
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(
            question, table, world, actions, outputs
        )

        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),  # type: ignore
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            checklist_state=checklist_states,
            possible_actions=actions,
            extras=target_values,
            debug_info=None,
        )

        if target_values is not None:
            # Debug-level, lazily formatted: this runs on every training batch.
            logger.debug("TARGET VALUES: %s", target_values)
            trainer_outputs = self._decoder_trainer.decode(  # type: ignore
                initial_state, self._decoder_step, partial(self._get_state_cost, world)
            )
            outputs.update(trainer_outputs)
        else:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False,
            )
            self._update_agenda_coverage(actions, agenda, best_final_states)
            self._compute_validation_outputs(
                actions, best_final_states, world, target_values, metadata, outputs
            )
        return outputs

    def _update_agenda_coverage(
        self,
        actions: List[List[ProductionRule]],
        agenda: torch.LongTensor,
        best_final_states: Dict[int, List[CoverageState]],
    ) -> None:
        """
        Feeds ``self._agenda_coverage`` one value per instance: the fraction of that instance's
        agenda actions that appear in its best decoded action sequence.

        Instances without a finished decoding, or with an empty agenda, contribute 0.0.
        """
        for i, instance_possible_actions in enumerate(actions):
            in_agenda_ratio = 0.0
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                action_sequence = best_final_states[i][0].action_history[0]
                decoded_actions = {
                    instance_possible_actions[action_index][0] for action_index in action_sequence
                }
                # Agenda entries equal to -1 are skipped (presumably padding).
                agenda_actions = [
                    instance_possible_actions[int(rule_id)][0]
                    for rule_id in agenda[i][:, 0].cpu().data
                    if int(rule_id) != -1
                ]
                if agenda_actions:
                    # Note: This means that when there are no actions on agenda, agenda coverage
                    # will be 0, not 1.
                    num_covered = sum(action in decoded_actions for action in agenda_actions)
                    in_agenda_ratio = num_covered / len(agenda_actions)
            self._agenda_coverage(in_agenda_ratio)
Exemplo n.º 4
0
    def forward(
        self,  # type: ignore
        sentence: Dict[str, torch.LongTensor],
        worlds: List[List[NlvrLanguage]],
        actions: List[List[ProductionRule]],
        agenda: torch.LongTensor,
        identifier: List[str] = None,
        labels: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Decoder logic for producing type constrained target sequences that maximize coverage of
        their respective agendas, and minimize a denotation based loss.

        Parameters
        ----------
        sentence : ``Dict[str, torch.LongTensor]``
            The tokenized sentence, passed to ``self._get_initial_rnn_state``.
        worlds : ``List[List[NlvrLanguage]]``
            Several worlds per instance.  Only the first world of each instance is used to build
            the grammar state (all are assumed to give the same valid actions); the full list is
            passed to the state-cost function, denotation computation and metrics.
        actions : ``List[List[ProductionRule]]``
            All possible actions for each instance.
        agenda : ``torch.LongTensor``
            One agenda per instance, used to build the checklist states and (via column 0) the
            agenda data given to the metrics.
        identifier : ``List[str]``, optional (default = None)
            Copied into ``outputs["identifier"]`` when given.
        labels : ``torch.LongTensor``, optional (default = None)
            When given (training / validation), converted to label strings and used for metrics.
            When ``None`` (testing), decoded strings, denotations and debug info are added to the
            outputs instead.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            When testing, each entry's ``"sentence_tokens"`` is copied into the outputs.
        """
        if self._dynamic_cost_rate is not None:
            # This could be added back pretty easily with an EpochCallback passed to the Trainer (it
            # just has to set the epoch number on the model, which could then be queried in here).
            logger.warning(
                "Dynamic cost rate functionality was removed in AllenNLP 1.0. If you want this, "
                "use version 0.9.  We will just use the static checklist cost weight."
            )
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [agenda.new_zeros(1, dtype=torch.float) for _ in range(batch_size)]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i]) for i in range(batch_size)
        ]

        label_strings = self._get_label_strings(labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        # TODO(mattg): It looks like the agenda is only ever used on the CPU.  In that case, it's a
        # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            # Nothing has been produced yet, so the running checklist starts at zero.
            initial_checklist = checklist_target.new_zeros(checklist_target.size())
            initial_checklist_states.append(
                ChecklistStatelet(
                    terminal_actions=terminal_actions,
                    checklist_target=checklist_target,
                    checklist_mask=checklist_mask,
                    checklist=initial_checklist,
                )
            )
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings,
            checklist_state=initial_checklist_states,
        )
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(  # type: ignore
            initial_state, self._decoder_step, partial(self._get_state_cost, worlds)
        )
        if identifier is not None:
            outputs["identifier"] = identifier
        best_final_states = outputs["best_final_states"]
        best_action_sequences = {}
        for batch_index, states in best_final_states.items():
            best_action_sequences[batch_index] = [state.action_history[0] for state in states]
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(
                action_strings=batch_action_strings,
                worlds=worlds,
                label_strings=label_strings,
                possible_actions=actions,
                agenda_data=agenda_data,
            )
        else:
            # We're testing.
            if metadata is not None:
                outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
            outputs["debug_info"] = []
            for i in range(batch_size):
                # An instance whose decoding produced no finished state has no entry in
                # ``best_final_states``; record empty debug info for it instead of raising
                # ``KeyError``.
                if i in best_final_states:
                    outputs["debug_info"].append(best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    outputs["debug_info"].append([])
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs["action_mapping"] = action_mapping
        return outputs