Exemplo n.º 1
0
    def test_action_sequence_to_sql(self):
        """Render a hand-written grammar derivation and compare it with the expected SQL text.

        The derivation below selects ``city_code`` and ``city_name`` from ``city`` for the
        city named BOSTON; ``action_sequence_to_sql`` must join its terminals with single spaces.
        """
        productions = [
            'statement -> [query, ";"]',
            'query -> ["(", "SELECT", distinct, select_results, '
            '"FROM", table_refs, where_clause, ")"]',
            'distinct -> ["DISTINCT"]',
            "select_results -> [col_refs]",
            # Two selected columns: a recursive col_refs followed by the terminal col_refs.
            'col_refs -> [col_ref, ",", col_refs]',
            'col_ref -> ["city", ".", "city_code"]',
            "col_refs -> [col_ref]",
            'col_ref -> ["city", ".", "city_name"]',
            "table_refs -> [table_name]",
            'table_name -> ["city"]',
            # WHERE city.city_name = 'BOSTON'
            'where_clause -> ["WHERE", "(", conditions, ")"]',
            "conditions -> [condition]",
            "condition -> [biexpr]",
            'biexpr -> ["city", ".", "city_name", binaryop, city_city_name_string]',
            'binaryop -> ["="]',
            "city_city_name_string -> [\"'BOSTON'\"]",
        ]
        expected_sql = (
            "( SELECT DISTINCT city . city_code , city . city_name "
            "FROM city WHERE ( city . city_name = 'BOSTON' ) ) ;"
        )

        rendered = action_sequence_to_sql(productions)
        assert rendered == expected_sql
Exemplo n.º 2
0
    def forward(
        self,  # type: ignore
        utterance: Dict[str, torch.LongTensor],
        world: List[AtisWorld],
        actions: List[List[ProductionRule]],
        linking_scores: torch.Tensor,
        target_action_sequence: torch.LongTensor = None,
        sql_queries: List[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        utterance : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        world : ``List[AtisWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        linking_scores: ``torch.Tensor``
            A matrix linking the utterance tokens and the entities. This is a binary matrix that
            is deterministically generated where each entry indicates whether a token generated an entity.
            This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
        target_action_sequence : torch.Tensor, optional (default=None)
            The correct action sequence, where each action is an index into the list of possible
            actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension. Required in training mode.
        sql_queries : List[List[str]], optional (default=None)
            A list of the SQL queries that are given during training or validation.

        Returns
        -------
        In training mode, the dictionary returned by ``self._decoder_trainer.decode``.  Otherwise a
        dictionary with the keys ``action_mapping``, ``linking_scores``, ``best_action_sequence``,
        ``debug_info``, ``entities``, ``predicted_sql_query``, ``sql_queries``, ``utterance`` and
        ``tokenized_utterance``, plus ``loss`` when ``target_action_sequence`` is given.

        Raises
        ------
        ValueError
            If the model is in training mode and ``target_action_sequence`` is ``None``.
        """
        initial_state = self._get_initial_state(utterance, world, actions,
                                                linking_scores)
        batch_size = linking_scores.shape[0]
        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            # Positions equal to the padding index are masked out.
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            if target_action_sequence is None:
                # Previously this surfaced as an AttributeError on ``None.unsqueeze``.
                raise ValueError(
                    "target_action_sequence is required when the model is in training mode.")
            # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we
            # unsqueeze it for the MML trainer.
            return self._decoder_trainer.decode(
                initial_state,
                self._transition_function,
                (target_action_sequence.unsqueeze(1),
                 target_mask.unsqueeze(1)),
            )

        # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
        # Maps (batch index, action index) to the action's first element, which is what
        # action_sequence_to_sql receives as the action string below.
        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        outputs: Dict[str, Any] = {"action_mapping": action_mapping}
        outputs["linking_scores"] = linking_scores
        if target_action_sequence is not None:
            # Validation loss, computed with the same trainer used during training.
            outputs["loss"] = self._decoder_trainer.decode(
                initial_state,
                self._transition_function,
                (target_action_sequence.unsqueeze(1),
                 target_mask.unsqueeze(1)),
            )["loss"]
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._beam_search.search(
            self._max_decoding_steps,
            initial_state,
            self._transition_function,
            keep_final_unfinished_states=False,
        )
        outputs["best_action_sequence"] = []
        outputs["debug_info"] = []
        outputs["entities"] = []
        outputs["predicted_sql_query"] = []
        outputs["sql_queries"] = []
        outputs["utterance"] = []
        outputs["tokenized_utterance"] = []

        for i in range(batch_size):
            # Decoding may not have terminated with any completed valid SQL queries, if
            # `self._max_decoding_steps` isn't long enough (or if the model is not trained enough
            # and gets into an infinite action loop).
            if i not in best_final_states:
                self._exact_match(0)
                self._denotation_accuracy(0)
                # NOTE(review): _valid_sql_query is only ever updated here (with 0); the success
                # path below never records a 1 — confirm this is intended.
                self._valid_sql_query(0)
                self._action_similarity(0)
                # NOTE(review): only predicted_sql_query gets a placeholder; the other per-instance
                # lists are not appended, so their positions can drift from batch indices.
                outputs["predicted_sql_query"].append("")
                continue

            best_action_indices = best_final_states[i][0].action_history[0]

            action_strings = [
                action_mapping[(i, action_index)]
                for action_index in best_action_indices
            ]
            predicted_sql_query = action_sequence_to_sql(action_strings)

            if target_action_sequence is not None:
                # Use a Tensor, not a Variable, to avoid a memory leak.
                targets = target_action_sequence[i].data
                self._exact_match(
                    self._action_history_match(best_action_indices, targets))

                similarity = difflib.SequenceMatcher(
                    None, best_action_indices, targets)
                self._action_similarity(similarity.ratio())

            if sql_queries and sql_queries[i]:
                denotation_correct = self._executor.evaluate_sql_query(
                    predicted_sql_query, sql_queries[i])
                self._denotation_accuracy(denotation_correct)
                outputs["sql_queries"].append(sql_queries[i])

            outputs["utterance"].append(world[i].utterances[-1])
            outputs["tokenized_utterance"].append([
                token.text for token in world[i].tokenized_utterances[-1]
            ])
            outputs["entities"].append(world[i].entities)
            outputs["best_action_sequence"].append(action_strings)
            outputs["predicted_sql_query"].append(
                sqlparse.format(predicted_sql_query, reindent=True))
            outputs["debug_info"].append(
                best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs
Exemplo n.º 3
0
    def forward(
        self,  # type: ignore
        tokens: Dict[str, torch.LongTensor],
        valid_actions: List[List[ProductionRule]],
        action_sequence: torch.LongTensor = None,
    ) -> Dict[str, torch.Tensor]:
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : torch.Tensor, optional (default=None)
            The correct action sequence, where each action is an index into the list of possible
            actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension.

        Returns
        -------
        A dictionary.  When ``action_sequence`` is given it contains everything returned by
        ``self._decoder_trainer.decode``.  Outside training mode it additionally contains
        ``action_mapping``, ``best_action_sequence``, ``debug_info``, ``predicted_sql_query`` and
        ``sql_queries`` (the last is always an empty list in this method).
        """
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        batch_size = embedded_utterance.size(0)

        # (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
        initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            # Positions equal to the padding index are masked out.
            target_mask = action_sequence != self._action_padding_index
            # action_sequence is of shape (batch_size, 1, target_sequence_length)
            # here after we unsqueeze it for the MML trainer.
            loss_output = self._decoder_trainer.decode(
                initial_state,
                self._transition_function,
                (action_sequence.unsqueeze(1), target_mask.unsqueeze(1)),
            )
            outputs.update(loss_output)

        if self.training:
            return outputs

        # One dict per batch instance, mapping action index to the action's first element, which
        # is what action_sequence_to_sql receives as the action string below.
        action_mapping = [
            {action_index: action[0] for action_index, action in enumerate(batch_actions)}
            for batch_actions in valid_actions
        ]

        outputs["action_mapping"] = action_mapping
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._beam_search.search(
            self._max_decoding_steps,
            initial_state,
            self._transition_function,
            keep_final_unfinished_states=True,
        )
        outputs["best_action_sequence"] = []
        outputs["debug_info"] = []
        outputs["predicted_sql_query"] = []
        # Never populated by this method; kept so the output keys stay stable for consumers.
        outputs["sql_queries"] = []
        for i in range(batch_size):
            # Decoding may not have terminated with any completed valid SQL queries, if
            # `self._max_decoding_steps` isn't long enough (or if the model is not trained enough
            # and gets into an infinite action loop).
            if i not in best_final_states:
                self._exact_match(0)
                self._denotation_accuracy(0)
                self._valid_sql_query(0)
                self._action_similarity(0)
                # NOTE(review): only predicted_sql_query gets a placeholder; the other per-instance
                # lists are not appended, so their positions can drift from batch indices.
                outputs["predicted_sql_query"].append("")
                continue

            best_action_indices = best_final_states[i][0].action_history[0]

            action_strings = [
                action_mapping[i][action_index]
                for action_index in best_action_indices
            ]

            predicted_sql_query = action_sequence_to_sql(action_strings)
            if action_sequence is not None:
                # Use a Tensor, not a Variable, to avoid a memory leak.
                targets = action_sequence[i].data
                self._exact_match(
                    self._action_history_match(best_action_indices, targets))

                similarity = difflib.SequenceMatcher(
                    None, best_action_indices, targets)
                self._action_similarity(similarity.ratio())

            outputs["best_action_sequence"].append(action_strings)
            outputs["predicted_sql_query"].append(
                sqlparse.format(predicted_sql_query, reindent=True))
            outputs["debug_info"].append(
                best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs