Exemple #1
0
    def get_action_sequence_and_all_actions(self,
                                            query: List[str] = None,
                                            prelinked_entities: Dict[str, Dict[str, str]] = None) -> Tuple[List[str], List[str]]:  # pylint: disable=line-too-long
        """
        Build a grammar for this query and return its parse plus every action.

        A private copy of ``self.base_grammar_dictionary`` is extended with
        the prelinked entities, compiled into a ``Grammar``, and used to parse
        the space-joined ``query`` tokens.

        Parameters
        ----------
        query : ``List[str]``, optional
            Tokens of the SQL query. When empty or ``None`` nothing is parsed.
        prelinked_entities : ``Dict[str, Dict[str, str]]``, optional
            Entities to add to the grammar. Must be ``None`` when
            ``self.use_prelinked_entities`` is false.

        Returns
        -------
        A tuple ``(action_sequence, sorted_actions)``: the result of parsing
        the query (``[]`` when there is no query) and the sorted list of all
        actions found in the grammar's valid actions.

        Raises
        ------
        ConfigurationError
            If prelinked entities are passed but this world was configured not
            to use them.
        """
        # Work on a copy so that the base grammar is never modified.
        contextual_grammar = deepcopy(self.base_grammar_dictionary)

        if not self.use_prelinked_entities and prelinked_entities is not None:
            raise ConfigurationError(
                "The Text2SqlWorld was specified to not use prelinked "
                "entities, but prelinked entities were passed.")
        if not prelinked_entities:
            prelinked_entities = {}

        if not self.use_untyped_entities:
            update_grammar_numbers_and_strings_with_variables(
                contextual_grammar, prelinked_entities, self.columns)
        else:
            update_grammar_values_with_variables(contextual_grammar,
                                                 prelinked_entities)

        grammar = Grammar(format_grammar_string(contextual_grammar))

        # Flatten the per-nonterminal action lists into one de-duplicated,
        # sorted list.
        actions_by_nonterminal = initialize_valid_actions(grammar)
        sorted_actions = sorted({action
                                 for actions in actions_by_nonterminal.values()
                                 for action in actions})

        visitor = SqlVisitor(grammar)
        if query:
            action_sequence = visitor.parse(" ".join(query))
        else:
            action_sequence = []
        return action_sequence, sorted_actions
Exemple #2
0
    def get_action_sequence_and_all_actions(self,
                                            query: List[str] = None,
                                            prelinked_entities: Dict[str, Dict[str, str]] = None) -> Tuple[List[str], List[str]]: # pylint: disable=line-too-long
        """
        Build a grammar for this query and return its parse plus every action.

        A private copy of ``self.base_grammar_dictionary`` is extended with
        the prelinked entities and with the query tokens, compiled into a
        ``Grammar``, and used to parse the space-joined ``query`` tokens.

        Parameters
        ----------
        query : ``List[str]``, optional
            Tokens of the SQL query. It is also handed to
            ``update_grammar_with_tokens``. When empty or ``None`` nothing is
            parsed.
        prelinked_entities : ``Dict[str, Dict[str, str]]``, optional
            Entities to add to the grammar. Must be ``None`` when
            ``self.use_prelinked_entities`` is false.

        Returns
        -------
        A tuple ``(action_sequence, sorted_actions)``. ``action_sequence`` is
        ``[]`` when there is no query and ``None`` when parsing fails with a
        ``ParseError`` or ``RecursionError`` (details are printed in that
        case). ``sorted_actions`` is the sorted list of all actions found in
        the grammar's valid actions.

        Raises
        ------
        ConfigurationError
            If prelinked entities are passed but this world was configured not
            to use them.
        """
        # Work on a copy so that the base grammar is never modified.
        contextual_grammar = deepcopy(self.base_grammar_dictionary)

        if not self.use_prelinked_entities and prelinked_entities is not None:
            raise ConfigurationError("The Text2SqlNoGrammarWorld was specified to not use prelinked "
                                     "entities, but prelinked entities were passed.")
        if not prelinked_entities:
            prelinked_entities = {}

        update_grammar_numbers_and_strings_with_variables(contextual_grammar,
                                                          prelinked_entities,
                                                          self.columns)
        update_grammar_with_tokens(contextual_grammar, query)

        grammar = Grammar(format_grammar_string(contextual_grammar))

        # Flatten the per-nonterminal action lists into one de-duplicated,
        # sorted list.
        actions_by_nonterminal = initialize_valid_actions(grammar)
        sorted_actions = sorted({action
                                 for actions in actions_by_nonterminal.values()
                                 for action in actions})

        visitor = SqlVisitor(grammar)
        action_sequence = []
        if query:
            try:
                action_sequence = visitor.parse(" ".join(query))
            except ParseError as error:
                # Report where parsing stopped and signal failure with None.
                print("\nParse Error - details:\n", error.pos, '\n', error.expr, '\n', error.text)
                action_sequence = None
            except RecursionError:
                # Report the query and the terminals that were in play.
                print("\nParse recursion error - details:\n", " ".join(query), '\n', contextual_grammar['terminal'])
                action_sequence = None

        return action_sequence, sorted_actions
Exemple #3
0
    def get_action_sequence_and_all_actions(self,
                                            query: List[str] = None,
                                            prelinked_entities: Dict[str, Dict[str, str]] = None) -> Tuple[List[str], List[str]]: # pylint: disable=line-too-long
        """
        Build a grammar for this query and return its parse plus every action.

        A private copy of ``self.base_grammar_dictionary`` is extended with
        the prelinked entities, compiled into a ``Grammar``, and used to parse
        the space-joined ``query`` tokens.

        Parameters
        ----------
        query : ``List[str]``, optional
            Tokens of the SQL query. When empty or ``None`` nothing is parsed.
        prelinked_entities : ``Dict[str, Dict[str, str]]``, optional
            Entities to add to the grammar. Must be ``None`` when
            ``self.use_prelinked_entities`` is false.

        Returns
        -------
        A tuple ``(action_sequence, sorted_actions)``. ``action_sequence`` is
        ``[]`` when there is no query and ``None`` when the query cannot be
        parsed (``ParseError``). ``sorted_actions`` is the sorted list of all
        actions found in the grammar's valid actions.

        Raises
        ------
        ConfigurationError
            If prelinked entities are passed but this world was configured not
            to use them.
        """
        # Work on a copy so that the base grammar is never modified.
        contextual_grammar = deepcopy(self.base_grammar_dictionary)

        if not self.use_prelinked_entities and prelinked_entities is not None:
            raise ConfigurationError("The Text2SqlWorld was specified to not use prelinked "
                                     "entities, but prelinked entities were passed.")
        if not prelinked_entities:
            prelinked_entities = {}

        if not self.use_untyped_entities:
            update_grammar_numbers_and_strings_with_variables(contextual_grammar,
                                                              prelinked_entities,
                                                              self.columns)
        else:
            update_grammar_values_with_variables(contextual_grammar, prelinked_entities)

        grammar = Grammar(format_grammar_string(contextual_grammar))

        # Flatten the per-nonterminal action lists into one de-duplicated,
        # sorted list.
        actions_by_nonterminal = initialize_valid_actions(grammar)
        sorted_actions = sorted({action
                                 for actions in actions_by_nonterminal.values()
                                 for action in actions})

        visitor = SqlVisitor(grammar)
        action_sequence = []
        if query:
            try:
                action_sequence = visitor.parse(" ".join(query))
            except ParseError:
                # An unparseable query is reported to the caller as None.
                action_sequence = None

        return action_sequence, sorted_actions
 def __init__(self,
              schema_path: str = None) -> None:
     """
     Load a dataset schema and build the grammar and valid actions from it.

     Parameters
     ----------
     schema_path : ``str``, optional
         Passed unchanged to ``read_dataset_schema``.
     """
     # Copy the module-level grammar so this instance owns its dictionary.
     self.grammar_dictionary = deepcopy(GRAMMAR_DICTIONARY)

     # For every schema entry keep only the first element of each item.
     first_elements_by_key = {}
     for key, items in read_dataset_schema(schema_path).items():
         first_elements_by_key[key] = [item[0] for item in items]
     self.all_tables = first_elements_by_key

     grammar_text = self.initialize_grammar_str()
     self.grammar_str: str = grammar_text
     self.grammar: Grammar = Grammar(grammar_text)
     self.valid_actions: Dict[str, List[str]] = initialize_valid_actions(self.grammar)
Exemple #5
0
    def get_action_sequence_and_all_actions(
            self, query: List[str]) -> Tuple[List[str], List[str]]:
        """
        Parse ``query`` with the base grammar and list every grammar action.

        Parameters
        ----------
        query : ``List[str]``
            Tokens of the SQL query. When empty nothing is parsed.

        Returns
        -------
        A tuple ``(action_sequence, sorted_actions)``: the result of parsing
        the space-joined query (``[]`` when the query is empty) and the sorted
        list of all actions found in the grammar's valid actions.
        """
        # TODO(Mark): Add in modifications here
        contextual_grammar = deepcopy(self.base_grammar_dictionary)
        grammar = Grammar(format_grammar_string(contextual_grammar))

        # Flatten the per-nonterminal action lists into one de-duplicated,
        # sorted list.
        actions_by_nonterminal = initialize_valid_actions(grammar)
        sorted_actions = sorted({action
                                 for actions in actions_by_nonterminal.values()
                                 for action in actions})

        visitor = SqlVisitor(grammar)
        if query:
            action_sequence = visitor.parse(" ".join(query))
        else:
            action_sequence = []
        return action_sequence, sorted_actions
Exemple #6
0
    def __init__(self,
                 all_tables: Dict[str, List[str]] = None,
                 tables_with_strings: Dict[str, List[str]] = None,
                 database_file: str = None) -> None:
        """
        Build the grammar string, ``Grammar`` and valid actions for the tables.

        Parameters
        ----------
        all_tables : ``Dict[str, List[str]]``, optional
            Stored as ``self.all_tables``.
        tables_with_strings : ``Dict[str, List[str]]``, optional
            Stored as ``self.tables_with_strings``.
        database_file : ``str``, optional
            When given, it is resolved with ``cached_path`` and opened with
            ``sqlite3`` for the duration of grammar construction. The
            connection is closed before this method returns, so
            ``self.cursor`` must not be used afterwards.
        """
        self.grammar_dictionary = deepcopy(GRAMMAR_DICTIONARY)
        self.all_tables = all_tables
        self.tables_with_strings = tables_with_strings
        if database_file:
            self.database_file = cached_path(database_file)
            self.connection = sqlite3.connect(self.database_file)
            self.cursor = self.connection.cursor()

        try:
            self.grammar_str: str = self.initialize_grammar_str()
            self.grammar: Grammar = Grammar(self.grammar_str)
            self.valid_actions: Dict[str, List[str]] = initialize_valid_actions(self.grammar, KEYWORDS)
        finally:
            # Close the connection even when building the grammar fails, so
            # that a failed construction does not leak the database handle.
            if database_file:
                self.connection.close()
Exemple #7
0
    def __init__(self,
                 utterances: List[str],
                 tokenizer: Tokenizer = None) -> None:
        """
        Tokenize the utterances, link entities and build the grammar.

        Parameters
        ----------
        utterances : ``List[str]``
            The utterances to process; stored as ``self.utterances``.
        tokenizer : ``Tokenizer``, optional
            Used to tokenize every utterance. A ``WordTokenizer`` is created
            when none (or a falsy value) is supplied.
        """
        # The table context is stored on the class and only built once.
        if AtisWorld.sql_table_context is None:
            AtisWorld.sql_table_context = AtisSqlTableContext(ALL_TABLES,
                                                              TABLES_WITH_STRINGS,
                                                              AtisWorld.database_file)
        self.utterances: List[str] = utterances
        self.tokenizer = tokenizer or WordTokenizer()

        tokenized = []
        for text in self.utterances:
            tokenized.append(self.tokenizer.tokenize(text))
        self.tokenized_utterances = tokenized

        self.dates = self._get_dates()
        self.linked_entities = self._get_linked_entities()

        flat_entities, flat_scores = self._flatten_entities()
        # This has shape (num_entities, num_utterance_tokens).
        self.linking_scores: numpy.ndarray = flat_scores
        self.entities: List[str] = flat_entities

        updated_grammar = self._update_grammar()
        self.grammar: Grammar = updated_grammar
        self.valid_actions = initialize_valid_actions(self.grammar, KEYWORDS)
    def __init__(self,
                 all_tables: Dict[str, List[str]] = None,
                 tables_with_strings: Dict[str, List[str]] = None,
                 database_file: str = None) -> None:
        """
        Build the grammar dictionary, ``Grammar`` and valid actions for the tables.

        Parameters
        ----------
        all_tables : ``Dict[str, List[str]]``, optional
            Stored as ``self.all_tables``.
        tables_with_strings : ``Dict[str, List[str]]``, optional
            Stored as ``self.tables_with_strings``.
        database_file : ``str``, optional
            When given, it is resolved with ``cached_path`` and opened with
            ``sqlite3`` for the duration of grammar construction. The
            connection is closed before this method returns, so
            ``self.cursor`` must not be used afterwards.
        """
        self.all_tables = all_tables
        self.tables_with_strings = tables_with_strings
        if database_file:
            self.database_file = cached_path(database_file)
            self.connection = sqlite3.connect(self.database_file)
            self.cursor = self.connection.cursor()

        try:
            grammar_dictionary, strings_list = self.create_grammar_dict_and_strings()
            self.grammar_dictionary: Dict[str, List[str]] = grammar_dictionary
            self.strings_list: List[Tuple[str, str]] = strings_list

            self.grammar_string: str = self.get_grammar_string()
            self.grammar: Grammar = Grammar(self.grammar_string)
            self.valid_actions: Dict[str, List[str]] = initialize_valid_actions(self.grammar, KEYWORDS)
        finally:
            # Close the connection even when building the grammar fails, so
            # that a failed construction does not leak the database handle.
            if database_file:
                self.connection.close()
    def __init__(self,
                 all_tables: Dict[str, List[str]] = None,
                 tables_with_strings: Dict[str, List[str]] = None,
                 database_file: str = None) -> None:
        """
        Build the grammar dictionary, ``Grammar`` and valid actions for the tables.

        Parameters
        ----------
        all_tables : ``Dict[str, List[str]]``, optional
            Stored as ``self.all_tables``.
        tables_with_strings : ``Dict[str, List[str]]``, optional
            Stored as ``self.tables_with_strings``.
        database_file : ``str``, optional
            When given, it is resolved with ``cached_path`` and opened with
            ``sqlite3`` for the duration of grammar construction. The
            connection is closed before this method returns, so
            ``self.cursor`` must not be used afterwards.
        """
        self.all_tables = all_tables
        self.tables_with_strings = tables_with_strings
        if database_file:
            self.database_file = cached_path(database_file)
            self.connection = sqlite3.connect(self.database_file)
            self.cursor = self.connection.cursor()

        try:
            grammar_dictionary, strings_list = self.create_grammar_dict_and_strings()
            self.grammar_dictionary: Dict[str, List[str]] = grammar_dictionary
            self.strings_list: List[Tuple[str, str]] = strings_list

            self.grammar_string: str = self.get_grammar_string()
            self.grammar: Grammar = Grammar(self.grammar_string)
            self.valid_actions: Dict[str, List[str]] = initialize_valid_actions(self.grammar, KEYWORDS)
        finally:
            # Close the connection even when building the grammar fails, so
            # that a failed construction does not leak the database handle.
            if database_file:
                self.connection.close()
Exemple #10
0
    def get_action_sequence_and_all_actions(self,
                                            query: List[str] = None,
                                            prelinked_entities: Dict[str, Dict[str, str]] = None) -> Tuple[List[str], List[str]]:  # pylint: disable=line-too-long
        """
        Build a grammar for this query and return its parse plus every action.

        A private copy of ``self.base_grammar_dictionary`` is extended with
        one quoted terminal per prelinked entity, compiled into a ``Grammar``,
        and used to parse the space-joined ``query`` tokens. Each entity is
        prepended to the ``"number"``, ``"string"`` or ``"value"`` rule.

        Parameters
        ----------
        query : ``List[str]``, optional
            Tokens of the SQL query. When empty or ``None`` nothing is parsed.
        prelinked_entities : ``Dict[str, Dict[str, str]]``, optional
            Maps a variable name to a dictionary from which the ``"type"`` and
            ``"text"`` keys are read. Must be ``None`` when
            ``self.use_prelinked_entities`` is false.

        Returns
        -------
        A tuple ``(action_sequence, sorted_actions)``: the result of parsing
        the query (``[]`` when there is no query) and the sorted list of all
        actions found in the grammar's valid actions.

        Raises
        ------
        ConfigurationError
            If prelinked entities are passed but this world was configured not
            to use them.
        """
        # Work on a copy so that the base grammar is never modified.
        contextual_grammar = deepcopy(self.base_grammar_dictionary)

        if not self.use_prelinked_entities and prelinked_entities is not None:
            raise ConfigurationError(
                "The Text2SqlWorld was specified to not use prelinked "
                "entities, but prelinked entities were passed.")
        if not prelinked_entities:
            prelinked_entities = {}

        for variable, info in prelinked_entities.items():
            column = self.columns.get(info["type"].upper(), None)

            if column is not None:
                # Try to infer the variable's type by matching it to a column
                # in the database. If we can't, we just add it as a value.
                if column_has_numeric_type(column):
                    rule = "number"
                elif column_has_string_type(column):
                    rule = "string"
                else:
                    rule = "value"
            else:
                # Otherwise, try to infer by looking at the actual value.
                # This is what happens if you try and do type inference
                # in a grammar which parses _strings_ in _Python_.
                # We're just seeing if the python interpreter can convert
                # the text to a float - if it can, we assume it's a number.
                try:
                    float(info["text"])
                except ValueError:
                    if info["text"].replace(" ", "").isalpha():
                        rule = "string"
                    else:
                        rule = "value"
                else:
                    rule = "number"

            # Prepend the quoted variable to the productions of the chosen rule.
            contextual_grammar[rule] = [f'"\'{variable}\'"'] + contextual_grammar[rule]

        grammar = Grammar(format_grammar_string(contextual_grammar))

        # Flatten the per-nonterminal action lists into one de-duplicated,
        # sorted list.
        actions_by_nonterminal = initialize_valid_actions(grammar)
        sorted_actions = sorted({action
                                 for actions in actions_by_nonterminal.values()
                                 for action in actions})

        visitor = SqlVisitor(grammar)
        if query:
            action_sequence = visitor.parse(" ".join(query))
        else:
            action_sequence = []
        return action_sequence, sorted_actions