    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        table: Dict[str, torch.LongTensor],
        world: List[WikiTablesLanguage],
        actions: List[List[ProductionRule]],
        agenda: torch.LongTensor,
        target_values: List[List[str]] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesLanguage]``
            We use a ``MetadataField`` to get the ``WikiTablesLanguage`` object for each input instance.
            Because of how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesLanguage]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``world`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        agenda : ``torch.LongTensor``
            Agenda vectors that the checklist vectors will be compared against to compute the checklist
            cost.
        target_values : ``List[List[str]]``, optional (default = None)
            For each instance, a list of target values taken from the example lisp string. We pass
            this list to the evaluator along with logical forms to compute denotation accuracy.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' field.
        """
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [a for a in agenda]
        checklist_states = []
        all_terminal_productions = [
            set(instance_world.terminal_productions.values())
            for instance_world in world
        ]
        max_num_terminals = max(
            [len(terminals) for terminals in all_terminal_productions])
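        # Build one ChecklistStatelet per instance: ``terminal_actions`` indexes the instance's
        # terminal productions (padded to ``max_num_terminals``), ``checklist_target`` marks which
        # of them appear on the agenda, and the checklist itself starts at zero and is filled in
        # as actions are taken during decoding.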
        for instance_actions, instance_agenda, terminal_productions in zip(
                actions, agenda_list, all_terminal_productions):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions,
                                                      terminal_productions,
                                                      max_num_terminals)
            checklist_target, terminal_actions, checklist_mask = checklist_info
            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            checklist_states.append(
                ChecklistStatelet(
                    terminal_actions=terminal_actions,
                    checklist_target=checklist_target,
                    checklist_mask=checklist_mask,
                    checklist=initial_checklist,
                ))
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(
            question, table, world, actions, outputs)

        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),  # type: ignore
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            checklist_state=checklist_states,
            possible_actions=actions,
            extras=target_values,
            debug_info=None,
        )

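        # With target values available we train through the decoder trainer, using
        # ``self._get_state_cost`` (partially applied over the worlds) as the cost function;
        # otherwise we run beam search and compute agenda coverage and validation outputs.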
        if target_values is not None:
            logger.debug("Target values: %s", target_values)
            trainer_outputs = self._decoder_trainer.decode(  # type: ignore
                initial_state, self._decoder_step,
                partial(self._get_state_cost, world))
            outputs.update(trainer_outputs)
        else:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            batch_size = len(actions)
            agenda_indices = [agenda_[:, 0].cpu().data for agenda_ in agenda]
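            # Map (batch_index, action_index) pairs back to production rule strings so that
            # decoded action sequences can be read off as productions.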
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            best_final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False,
            )
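            # For each instance, record agenda coverage: the fraction of agenda actions that show
            # up in the best decoded action sequence.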
            for i in range(batch_size):
                in_agenda_ratio = 0.0
                # Decoding may not have terminated with any completed logical forms, if
                # `max_decoding_steps` isn't large enough (or if the model is not trained enough
                # and gets into an infinite action loop).
                if i in best_final_states:
                    action_sequence = best_final_states[i][0].action_history[0]
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in action_sequence
                    ]
                    instance_possible_actions = actions[i]
                    agenda_actions = []
                    for rule_id in agenda_indices[i]:
                        rule_id = int(rule_id)
                        if rule_id == -1:
                            continue
                        action_string = instance_possible_actions[rule_id][0]
                        agenda_actions.append(action_string)
                    actions_in_agenda = [
                        action in action_strings for action in agenda_actions
                    ]
                    if actions_in_agenda:
                        # Note: This means that when there are no actions on agenda, agenda coverage
                        # will be 0, not 1.
                        in_agenda_ratio = sum(actions_in_agenda) / len(
                            actions_in_agenda)
                self._agenda_coverage(in_agenda_ratio)

            self._compute_validation_outputs(actions, best_final_states, world,
                                             target_values, metadata, outputs)
        return outputs
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrLanguage]],
            actions: List[List[ProductionRule]],
            agenda: torch.LongTensor,
            identifier: List[str] = None,
            labels: torch.LongTensor = None,
            epoch_num: List[int] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences that maximize coverage of
        their respective agendas and minimize a denotation-based loss.
        """
        # We look at the epoch number and adjust the checklist cost weight if needed here.
        instance_epoch_num = epoch_num[0] if epoch_num is not None else None
        if self._dynamic_cost_rate is not None:
            if self.training and instance_epoch_num is None:
                raise RuntimeError(
                    "If you want a dynamic cost weight, use the "
                    "BucketIterator with track_epoch=True.")
            if instance_epoch_num != self._last_epoch_in_forward:
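                # Once past the wait epochs, decay the checklist cost weight geometrically
                # (by ``dynamic_cost_rate``) at most once per epoch.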
                if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                    decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                    self._checklist_cost_weight -= decrement
                    logger.info("Checklist cost weight is now %f",
                                self._checklist_cost_weight)
                self._last_epoch_in_forward = instance_epoch_num
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
            for i in range(batch_size)
        ]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]

        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        # TODO(mattg): It looks like the agenda is only ever used on the CPU.  In that case, it's a
        # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
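        # One ChecklistStatelet per instance: the target says which terminal actions are on the
        # agenda, the mask ignores padding, and the checklist (initially zero) counts how many
        # times each terminal action has been produced so far.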
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            initial_checklist_states.append(
                ChecklistStatelet(terminal_actions=terminal_actions,
                                  checklist_target=checklist_target,
                                  checklist_mask=checklist_mask,
                                  checklist=initial_checklist))
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings,
            checklist_state=initial_checklist_states)
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
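        # Decode from the initial state, scoring candidate action sequences with the cost
        # computed by ``self._get_state_cost`` over each instance's worlds.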
        outputs = self._decoder_trainer.decode(
            initial_state,  # type: ignore
            self._decoder_step,
            partial(self._get_state_cost, worlds))
        if identifier is not None:
            outputs['identifier'] = identifier
        best_final_states = outputs['best_final_states']
        best_action_sequences = {}
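        # Collect the action histories of the finished states returned for each instance.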
        for batch_index, states in best_final_states.items():
            best_action_sequences[batch_index] = [
                state.action_history[0] for state in states
            ]
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings,
                                 possible_actions=actions,
                                 agenda_data=agenda_data)
        else:
            # We're testing.
            if metadata is not None:
                outputs["sentence_tokens"] = [
                    x["sentence_tokens"] for x in metadata
                ]
            outputs['debug_info'] = []
            for i in range(batch_size):
                outputs['debug_info'].append(
                    best_final_states[i][0].debug_info[0])  # type: ignore
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs['action_mapping'] = action_mapping
        return outputs
    def forward(
        self,  # type: ignore
        sentence: Dict[str, torch.LongTensor],
        worlds: List[List[NlvrLanguage]],
        actions: List[List[ProductionRule]],
        agenda: torch.LongTensor,
        identifier: List[str] = None,
        labels: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Decoder logic for producing type-constrained target sequences that maximize coverage of
        their respective agendas and minimize a denotation-based loss.
        """
        if self._dynamic_cost_rate is not None:
            # This could be added back pretty easily with an EpochCallback passed to the Trainer (it
            # just has to set the epoch number on the model, which could then be queried in here).
            logger.warning(
                "Dynamic cost rate functionality was removed in AllenNLP 1.0. If you want this, "
                "use version 0.9.  We will just use the static checklist cost weight."
            )
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
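        # ``agenda.new_zeros`` is used here only so the initial (zero) scores live on the same
        # device as the input tensors.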
        initial_score_list = [agenda.new_zeros(1, dtype=torch.float) for i in range(batch_size)]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i]) for i in range(batch_size)
        ]

        label_strings = self._get_label_strings(labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        # TODO(mattg): It looks like the agenda is only ever used on the CPU.  In that case, it's a
        # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(checklist_target.size())
            initial_checklist_states.append(
                ChecklistStatelet(
                    terminal_actions=terminal_actions,
                    checklist_target=checklist_target,
                    checklist_mask=checklist_mask,
                    checklist=initial_checklist,
                )
            )
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings,
            checklist_state=initial_checklist_states,
        )
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]

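        # Keep just the agenda action indices (on the CPU); they are used by
        # ``self._update_metrics`` below to compute agenda coverage.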
        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(  # type: ignore
            initial_state, self._decoder_step, partial(self._get_state_cost, worlds)
        )
        if identifier is not None:
            outputs["identifier"] = identifier
        best_final_states = outputs["best_final_states"]
        best_action_sequences = {}
        for batch_index, states in best_final_states.items():
            best_action_sequences[batch_index] = [state.action_history[0] for state in states]
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(
                action_strings=batch_action_strings,
                worlds=worlds,
                label_strings=label_strings,
                possible_actions=actions,
                agenda_data=agenda_data,
            )
        else:
            # We're testing.
            if metadata is not None:
                outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
            outputs["debug_info"] = []
            for i in range(batch_size):
                outputs["debug_info"].append(best_final_states[i][0].debug_info[0])  # type: ignore
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs["action_mapping"] = action_mapping
        return outputs