Example #1
    def get_action_sequence_and_all_actions(self,
                                            query: Optional[List[str]] = None,
                                            prelinked_entities: Optional[Dict[str, Dict[str, str]]] = None) -> Tuple[Optional[List[str]], List[str]]:  # pylint: disable=line-too-long
        # Copy the base grammar so per-query entity updates do not mutate it.
        grammar_with_context = deepcopy(self.base_grammar_dictionary)

        if not self.use_prelinked_entities and prelinked_entities is not None:
            raise ConfigurationError(
                "The Text2SqlWorld was specified to not use prelinked "
                "entities, but prelinked entities were passed.")
        prelinked_entities = prelinked_entities or {}

        if self.use_untyped_entities:
            # Add the prelinked entities to the grammar as untyped values.
            update_grammar_values_with_variables(grammar_with_context,
                                                 prelinked_entities)
        else:
            # Add them as typed number/string productions, using the column
            # metadata to decide which nonterminal each entity belongs to.
            update_grammar_numbers_and_strings_with_variables(
                grammar_with_context, prelinked_entities, self.columns)

        grammar = Grammar(format_grammar_string(grammar_with_context))

        valid_actions = initialize_valid_actions(grammar)
        # Flatten the per-nonterminal action lists into one sorted global list.
        all_actions = set()
        for action_list in valid_actions.values():
            all_actions.update(action_list)
        sorted_actions = sorted(all_actions)

        sql_visitor = SqlVisitor(grammar)
        try:
            # Parse the whitespace-joined query into its action sequence;
            # an empty query yields an empty sequence.
            action_sequence = sql_visitor.parse(" ".join(query)) if query else []
        except ParseError:
            # The query is not derivable under the grammar.
            action_sequence = None

        return action_sequence, sorted_actions
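
Below, a minimal usage sketch (not part of the source): parse_query is a hypothetical helper, and its world argument stands in for any Text2SqlWorld-style object that exposes the method above.

from typing import List

def parse_query(world, sql_tokens: List[str]) -> None:
    # `world` is assumed to expose get_action_sequence_and_all_actions
    # as defined in Example #1; sql_tokens is a pre-tokenized SQL query.
    action_sequence, all_actions = world.get_action_sequence_and_all_actions(
        query=sql_tokens)
    if action_sequence is None:
        print("query did not parse under the grammar")
    else:
        print(f"{len(action_sequence)} actions, {len(all_actions)} productions")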
Example #2
    def __init__(self,
                 utterances: List[str],
                 tokenizer: Optional[Tokenizer] = None) -> None:
        if AtisWorld.sql_table_context is None:
            # Build the shared SQL table context lazily, once per class.
            AtisWorld.sql_table_context = AtisSqlTableContext(
                ALL_TABLES, TABLES_WITH_STRINGS, AtisWorld.database_file)
        self.utterances: List[str] = utterances
        self.tokenizer = tokenizer if tokenizer else WordTokenizer()
        self.tokenized_utterances = [
            self.tokenizer.tokenize(utterance) for utterance in self.utterances
        ]
        self.dates = self._get_dates()
        # Link spans of the utterances to database entities.
        self.linked_entities = self._get_linked_entities()

        entities, linking_scores = self._flatten_entities()
        # This has shape (num_entities, num_utterance_tokens).
        self.linking_scores: numpy.ndarray = linking_scores
        self.entities: List[str] = entities
        self.grammar: Grammar = self._update_grammar()
        self.valid_actions = initialize_valid_actions(self.grammar, KEYWORDS)
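
A hedged usage sketch: the utterance is invented, and it assumes AtisWorld.database_file already points at a reachable copy of the ATIS sqlite database so the shared AtisSqlTableContext can be built.

# Hypothetical single-utterance interaction.
world = AtisWorld(["show me flights from denver to boston"])

# linking_scores has shape (num_entities, num_utterance_tokens).
print(world.linking_scores.shape)
print(world.entities[:5])        # first few linked entity strings
print(len(world.valid_actions))  # nonterminals with at least one production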
Example #3
    def __init__(self,
                 all_tables: Optional[Dict[str, List[str]]] = None,
                 tables_with_strings: Optional[Dict[str, List[str]]] = None,
                 database_file: Optional[str] = None) -> None:
        self.all_tables = all_tables
        self.tables_with_strings = tables_with_strings
        if database_file:
            # Resolve the path (downloading/caching if needed) and open it.
            self.database_file = cached_path(database_file)
            self.connection = sqlite3.connect(self.database_file)
            self.cursor = self.connection.cursor()

        grammar_dictionary, strings_list = self.create_grammar_dict_and_strings()
        self.grammar_dictionary: Dict[str, List[str]] = grammar_dictionary
        self.strings_list: List[Tuple[str, str]] = strings_list

        self.grammar_string: str = self.get_grammar_string()
        self.grammar: Grammar = Grammar(self.grammar_string)
        self.valid_actions: Dict[str, List[str]] = initialize_valid_actions(
            self.grammar, KEYWORDS)
        if database_file:
            # The grammar and strings list are built, so the connection
            # can be closed.
            self.connection.close()
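
Again a sketch rather than source code: "atis.db" is a placeholder path (cached_path also accepts URLs), and ALL_TABLES / TABLES_WITH_STRINGS are the table dictionaries this constructor expects.

# Placeholder database path; string values for TABLES_WITH_STRINGS are
# read from it while the grammar is being built.
context = AtisSqlTableContext(ALL_TABLES, TABLES_WITH_STRINGS, "atis.db")

print(len(context.strings_list))          # (string token, database value) pairs
print(sorted(context.valid_actions)[:3])  # a few nonterminals with productions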