Ejemplo n.º 1
0
    def get_action_sequence_and_all_actions(self,
                                            query: List[str] = None,
                                            prelinked_entities: Dict[str, Dict[str, str]] = None) -> Tuple[List[str], List[str]]:  # pylint: disable=line-too-long
        """
        Build a query-specific grammar, list its actions and parse ``query`` with it.

        A deep copy of ``self.base_grammar_dictionary`` is extended with the
        prelinked entity variables. It is then compiled into a ``Grammar`` used both to
        enumerate every valid action and to parse the query.

        Parameters
        ----------
        query : ``List[str]``, optional
            SQL tokens. When empty or ``None`` no parse is attempted.
        prelinked_entities : ``Dict[str, Dict[str, str]]``, optional
            Entity variables to add to the grammar. Only allowed when
            ``self.use_prelinked_entities`` is set.

        Returns
        -------
        ``(action_sequence, sorted_actions)``. ``action_sequence`` holds the
        production rules from parsing ``query`` (``[]`` without a query).
        ``sorted_actions`` is the sorted union of all valid actions.

        Raises
        ------
        ConfigurationError
            If entities are passed while ``use_prelinked_entities`` is disabled.
        """
        grammar_dictionary = deepcopy(self.base_grammar_dictionary)

        if not self.use_prelinked_entities and prelinked_entities is not None:
            raise ConfigurationError(
                "The Text2SqlWorld was specified to not use prelinked "
                "entities, but prelinked entities were passed.")
        entities = prelinked_entities if prelinked_entities else {}

        # Typed mode splits variables into number/string productions using the
        # schema columns; untyped mode adds them as plain values.
        if not self.use_untyped_entities:
            update_grammar_numbers_and_strings_with_variables(grammar_dictionary, entities, self.columns)
        else:
            update_grammar_values_with_variables(grammar_dictionary, entities)

        grammar = Grammar(format_grammar_string(grammar_dictionary))

        valid_actions = initialize_valid_actions(grammar)
        sorted_actions = sorted({action
                                 for productions in valid_actions.values()
                                 for action in productions})

        visitor = SqlVisitor(grammar)
        if query:
            action_sequence = visitor.parse(" ".join(query))
        else:
            action_sequence = []
        return action_sequence, sorted_actions
Ejemplo n.º 2
0
    def test_variable_free_world_cannot_parse_as_statements(self):
        """
        The base grammar of a ``Text2SqlWorld`` rejects a query that aliases a
        table with ``AS``. It still parses the same query without the alias.
        """
        world = Text2SqlWorld(self.schema)
        grammar_dictionary = world.base_grammar_dictionary
        # Iterate the production lists themselves. Iterating ``.items()`` yields
        # ``(nonterminal, productions)`` tuples, so the membership check never
        # looked inside the lists.
        # NOTE(review): this is an exact-element check for the string "AS".
        for productions in grammar_dictionary.values():
            assert "AS" not in productions

        sql_with_as = ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', 'AS', 'LOCATIONalias0', ',',
                       'RESTAURANT', 'WHERE', 'LOCATION', '.', 'CITY_NAME', '=',
                       "'city_name0'", 'AND', 'RESTAURANT', '.', 'NAME', '=', 'LOCATION',
                       '.', 'RESTAURANT_ID', 'AND', 'RESTAURANT', '.', 'NAME', '=', "'name0'", ';']

        grammar = Grammar(format_grammar_string(world.base_grammar_dictionary))
        sql_visitor = SqlVisitor(grammar)

        with self.assertRaises(ParseError):
            sql_visitor.parse(" ".join(sql_with_as))

        sql = ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',',
               'RESTAURANT', 'WHERE', 'LOCATION', '.', 'CITY_NAME', '=',
               "'city_name0'", 'AND', 'RESTAURANT', '.', 'NAME', '=', 'LOCATION',
               '.', 'RESTAURANT_ID', 'AND', 'RESTAURANT', '.', 'NAME', '=', "'name0'", ';']

        # Without the AS we should still be able to parse it. A fresh visitor is
        # used after the failed parse above.
        sql_visitor = SqlVisitor(grammar)
        sql_visitor.parse(" ".join(sql))
Ejemplo n.º 3
0
    def test_variable_free_world_cannot_parse_as_statements(self):
        """
        The base grammar of a ``Text2SqlWorld`` rejects a query that aliases a
        table with ``AS``. It still parses the same query without the alias.
        """
        world = Text2SqlWorld(self.schema)
        grammar_dictionary = world.base_grammar_dictionary
        # Iterate the production lists themselves. Iterating ``.items()`` yields
        # ``(nonterminal, productions)`` tuples, so the membership check never
        # looked inside the lists.
        # NOTE(review): this is an exact-element check for the string "AS".
        for productions in grammar_dictionary.values():
            assert "AS" not in productions

        sql_with_as = [
            'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', 'AS',
            'LOCATIONalias0', ',', 'RESTAURANT', 'WHERE', 'LOCATION', '.',
            'CITY_NAME', '=', "'city_name0'", 'AND', 'RESTAURANT', '.', 'NAME',
            '=', 'LOCATION', '.', 'RESTAURANT_ID', 'AND', 'RESTAURANT', '.',
            'NAME', '=', "'name0'", ';'
        ]

        grammar = Grammar(format_grammar_string(world.base_grammar_dictionary))
        sql_visitor = SqlVisitor(grammar)

        with self.assertRaises(ParseError):
            sql_visitor.parse(" ".join(sql_with_as))

        sql = [
            'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',',
            'RESTAURANT', 'WHERE', 'LOCATION', '.', 'CITY_NAME', '=',
            "'city_name0'", 'AND', 'RESTAURANT', '.', 'NAME', '=', 'LOCATION',
            '.', 'RESTAURANT_ID', 'AND', 'RESTAURANT', '.', 'NAME', '=',
            "'name0'", ';'
        ]

        # Without the AS we should still be able to parse it. A fresh visitor is
        # used after the failed parse above.
        sql_visitor = SqlVisitor(grammar)
        sql_visitor.parse(" ".join(sql))
Ejemplo n.º 4
0
    def get_action_sequence_and_all_actions(self,
                                            query: List[str] = None,
                                            prelinked_entities: Dict[str, Dict[str, str]] = None) -> Tuple[List[str], List[str]]: # pylint: disable=line-too-long
        """
        Build a grammar extended with entity variables and query tokens, then parse ``query``.

        Parameters
        ----------
        query : ``List[str]``, optional
            SQL tokens. They are passed to ``update_grammar_with_tokens`` and, when
            non-empty, joined with spaces and parsed.
        prelinked_entities : ``Dict[str, Dict[str, str]]``, optional
            Entity variables to add to the grammar. Only allowed when
            ``self.use_prelinked_entities`` is set.

        Returns
        -------
        ``(action_sequence, sorted_actions)``. ``action_sequence`` is the parse
        result (``[]`` without a query). It is ``None`` if parsing raised
        ``ParseError`` or ``RecursionError``; details are then printed to stdout.
        ``sorted_actions`` is the sorted union of all valid actions.

        Raises
        ------
        ConfigurationError
            If entities are passed while ``use_prelinked_entities`` is disabled.
        """
        grammar_with_context = deepcopy(self.base_grammar_dictionary)

        if not self.use_prelinked_entities and prelinked_entities is not None:
            raise ConfigurationError("The Text2SqlNoGrammarWorld was specified to not use prelinked "
                                     "entities, but prelinked entities were passed.")
        prelinked_entities = prelinked_entities or {}

        update_grammar_numbers_and_strings_with_variables(grammar_with_context,
                                                          prelinked_entities,
                                                          self.columns)
        update_grammar_with_tokens(grammar_with_context, query)

        grammar = Grammar(format_grammar_string(grammar_with_context))

        valid_actions = initialize_valid_actions(grammar)
        all_actions = set()
        for action_list in valid_actions.values():
            all_actions.update(action_list)
        sorted_actions = sorted(all_actions)

        sql_visitor = SqlVisitor(grammar)
        try:
            action_sequence = sql_visitor.parse(" ".join(query)) if query else []
        except ParseError as error:
            print("\nParse Error - details:\n", error.pos, '\n', error.expr, '\n', error.text)
            action_sequence = None
        except RecursionError:
            # ``.get`` so a missing key cannot raise KeyError from inside this
            # handler and hide the original recursion failure.
            print("\nParse recursion error - details:\n", " ".join(query), '\n',
                  grammar_with_context.get('terminal'))
            action_sequence = None

        return action_sequence, sorted_actions
Ejemplo n.º 5
0
    def get_action_sequence_and_all_actions(self,
                                            query: List[str] = None,
                                            prelinked_entities: Dict[str, Dict[str, str]] = None) -> Tuple[List[str], List[str]]: # pylint: disable=line-too-long
        """
        Build a query-specific grammar, list its actions and try to parse ``query``.

        Parameters
        ----------
        query : ``List[str]``, optional
            SQL tokens. When empty or ``None`` no parse is attempted.
        prelinked_entities : ``Dict[str, Dict[str, str]]``, optional
            Entity variables to add to the grammar. Only allowed when
            ``self.use_prelinked_entities`` is set.

        Returns
        -------
        ``(action_sequence, sorted_actions)``. ``action_sequence`` is the parse
        result, ``[]`` without a query, or ``None`` when the grammar cannot parse
        the query (``ParseError``). ``sorted_actions`` is the sorted union of all
        valid actions.

        Raises
        ------
        ConfigurationError
            If entities are passed while ``use_prelinked_entities`` is disabled.
        """
        grammar_dictionary = deepcopy(self.base_grammar_dictionary)

        if not self.use_prelinked_entities and prelinked_entities is not None:
            raise ConfigurationError("The Text2SqlWorld was specified to not use prelinked "
                                     "entities, but prelinked entities were passed.")
        entities = prelinked_entities if prelinked_entities else {}

        if not self.use_untyped_entities:
            update_grammar_numbers_and_strings_with_variables(grammar_dictionary, entities, self.columns)
        else:
            update_grammar_values_with_variables(grammar_dictionary, entities)

        grammar = Grammar(format_grammar_string(grammar_dictionary))

        valid_actions = initialize_valid_actions(grammar)
        sorted_actions = sorted({action
                                 for productions in valid_actions.values()
                                 for action in productions})

        visitor = SqlVisitor(grammar)
        try:
            if query:
                action_sequence = visitor.parse(" ".join(query))
            else:
                action_sequence = []
        except ParseError:
            # Unparseable queries are signalled to the caller with ``None``.
            action_sequence = None

        return action_sequence, sorted_actions
Ejemplo n.º 6
0
    def test_grammar_from_world_can_parse_statements(self):
        """The grammar built from a world's base dictionary parses a two-table query."""
        world = Text2SqlWorld(self.schema)
        query = ("SELECT COUNT ( * ) FROM LOCATION , RESTAURANT "
                 "WHERE LOCATION . CITY_NAME = 'city_name0' "
                 "AND RESTAURANT . NAME = LOCATION . RESTAURANT_ID "
                 "AND RESTAURANT . NAME = 'name0' ;")

        grammar_text = format_grammar_string(world.base_grammar_dictionary)
        visitor = SqlVisitor(Grammar(grammar_text))
        visitor.parse(query)
Ejemplo n.º 7
0
    def test_grammar_from_world_can_parse_statements(self):
        """The grammar built from a world's base dictionary parses a two-table query."""
        world = Text2SqlWorld(self.schema)
        query = ("SELECT COUNT ( * ) FROM LOCATION , RESTAURANT "
                 "WHERE LOCATION . CITY_NAME = 'city_name0' "
                 "AND RESTAURANT . NAME = LOCATION . RESTAURANT_ID "
                 "AND RESTAURANT . NAME = 'name0' ;")

        grammar_text = format_grammar_string(world.base_grammar_dictionary)
        visitor = SqlVisitor(Grammar(grammar_text))
        visitor.parse(query)
Ejemplo n.º 8
0
    def initialize_grammar_str(self):
        """
        Add the table and column names to the grammar and return it as a string.

        When ``self.all_tables`` is non-empty, ``self.grammar_dictionary`` is
        mutated in place. ``'table_name'`` becomes the double-quoted table names
        (the mapping's keys). ``'column_name'`` becomes the double-quoted union of
        all column collections (the mapping's values). Both lists are in reverse
        sorted order.

        Returns
        -------
        The grammar dictionary rendered by ``format_grammar_string``.
        """
        if self.all_tables:
            # Reverse sorting lists a name ahead of any name it is a prefix of,
            # e.g. '"RESTAURANT_ID"' before '"RESTAURANT"'. Presumably this is so
            # the grammar's ordered choice tries the longer name first.
            self.grammar_dictionary['table_name'] = sorted(
                (f'"{table}"' for table in self.all_tables), reverse=True)

            all_columns = set()
            for columns in self.all_tables.values():
                all_columns.update(columns)
            self.grammar_dictionary['column_name'] = sorted(
                (f'"{column}"' for column in all_columns), reverse=True)

        return format_grammar_string(self.grammar_dictionary)
Ejemplo n.º 9
0
    def get_action_sequence_and_all_actions(
            self, query: List[str]) -> Tuple[List[str], List[str]]:
        """
        Compile the base grammar, list its actions and parse ``query``.

        Returns
        -------
        ``(action_sequence, sorted_actions)``. ``action_sequence`` holds the
        production rules from parsing the space-joined ``query`` (``[]`` when
        ``query`` is empty). ``sorted_actions`` is the sorted union of all valid
        actions in the grammar.
        """
        # TODO(Mark): Add in modifications here
        context_grammar = deepcopy(self.base_grammar_dictionary)
        grammar = Grammar(format_grammar_string(context_grammar))

        valid_actions = initialize_valid_actions(grammar)
        sorted_actions = sorted({action
                                 for productions in valid_actions.values()
                                 for action in productions})

        visitor = SqlVisitor(grammar)
        if query:
            action_sequence = visitor.parse(" ".join(query))
        else:
            action_sequence = []
        return action_sequence, sorted_actions
Ejemplo n.º 10
0
    def test_grammar_from_world_can_parse_statements(self):
        """The grammar built from a world's base dictionary parses a two-table query."""
        world = Text2SqlWorld(self.schema)
        query = ("SELECT COUNT ( * ) FROM LOCATION , RESTAURANT "
                 "WHERE LOCATION . CITY_NAME = 'city_name0' "
                 "AND RESTAURANT . NAME = LOCATION . RESTAURANT_ID "
                 "AND RESTAURANT . NAME = 'name0' ;")

        grammar_text = format_grammar_string(world.base_grammar_dictionary)
        visitor = SqlVisitor(Grammar(grammar_text))
        visitor.parse(query)
Ejemplo n.º 11
0
def parse_dataset(filename: str, filter_by: str = None, verbose: bool = False):
    """
    Try to parse every SQL query in a dataset file with ``GRAMMAR_DICTIONARY``.

    Parameters
    ----------
    filename : ``str``
        Path to a JSON file. Its decoded contents are passed to ``process_sql_data``.
    filter_by : ``str``, optional
        Substring of the SQL string. Failures whose SQL contains it are counted
        in ``filtered_errors`` and are not printed. ``None`` or ``""`` disables
        filtering.
    verbose : ``bool``, optional (default = False)
        Print the error, question text and SQL for each failure not matched by
        ``filter_by``.

    Returns
    -------
    ``(num_parsed, num_queries, filtered_errors, non_basic_as_aliases, as_count,
    queries_with_weird_as)``. The last three count, respectively:

    - ``AS`` aliases whose alias (minus its last six characters) differs from the preceding token;
    - all ``AS`` tokens that have a preceding and a following token;
    - queries containing at least one such non-basic alias.
    """
    grammar_string = format_grammar_string(GRAMMAR_DICTIONARY)
    grammar = Grammar(grammar_string)

    # Close the file instead of leaking the handle.
    with open(filename) as data_file:
        data = json.load(data_file)

    num_queries = 0
    num_parsed = 0
    filtered_errors = 0

    non_basic_as_aliases = 0
    as_count = 0
    queries_with_weird_as = 0

    for i, sql_data in enumerate(process_sql_data(data)):
        sql_visitor = SqlVisitor(grammar)
        tokens = sql_data.sql

        # NOTE: DATA hack alert - the geography dataset doesn't alias derived tables consistently,
        # so we fix the data a bit here instead of completely re-working the grammar.
        # An "AS" is inserted before each DERIVED... token that directly follows ")".
        # Without such tokens this is a no-op, so it is applied to every query.
        # The former ``any([...] for ...)`` guard tested non-empty lists and was
        # therefore always true for non-empty queries anyway.
        sql_to_use = []
        for j, token in enumerate(tokens):
            # ``j > 0`` keeps the first token from being compared with tokens[-1].
            if token[:7] == "DERIVED" and j > 0 and tokens[j - 1] == ")":
                sql_to_use.append("AS")
            sql_to_use.append(token)

        # Inspect every "<previous> AS <alias>" triple.
        query_has_weird_as = False
        for previous_token, token, alias in zip(sql_to_use, sql_to_use[1:], sql_to_use[2:]):
            if token == "AS":
                as_count += 1
                # Aliases look like "LOCATIONalias0"; dropping the last six
                # characters recovers the aliased name for a plain table alias.
                # NOTE(review): assumes a single-digit alias suffix.
                if alias[:-6] != previous_token:
                    non_basic_as_aliases += 1
                    query_has_weird_as = True
        if query_has_weird_as:
            queries_with_weird_as += 1

        sql_string = " ".join(sql_to_use)
        num_queries += 1
        try:
            sql_visitor.parse(sql_string)
            num_parsed += 1
        except Exception as error:  # pylint: disable=broad-except
            # Deliberately broad: any parse failure just counts as unparsed.
            if filter_by and filter_by in sql_string:
                filtered_errors += 1
            elif verbose:
                print()
                print(error)
                print(" ".join(sql_data.text))
                print(sql_data.sql)
                try:
                    # Optional pretty-printer; fall back to the raw string.
                    import sqlparse
                    print(sqlparse.format(sql_string, reindent=True))
                except Exception:  # pylint: disable=broad-except
                    print(sql_string)

        if (i + 1) % 500 == 0:
            print(f"\tProcessed {i + 1} queries.")

    return num_parsed, num_queries, filtered_errors, non_basic_as_aliases, as_count, queries_with_weird_as
 def get_grammar_string(self):
     """Render ``self.grammar_dictionary`` with ``format_grammar_string`` and return the result."""
     grammar_dictionary = self.grammar_dictionary
     grammar_string = format_grammar_string(grammar_dictionary)
     return grammar_string
Ejemplo n.º 13
0
 def get_grammar_string(self):
     """Render ``self.grammar_dictionary`` with ``format_grammar_string`` and return the result."""
     grammar_dictionary = self.grammar_dictionary
     grammar_string = format_grammar_string(grammar_dictionary)
     return grammar_string
Ejemplo n.º 14
0
def parse_dataset(filename: str, filter_by: str = None, verbose: bool = False):
    """
    Try to parse every SQL query in a dataset file with ``GRAMMAR_DICTIONARY``.

    Parameters
    ----------
    filename : ``str``
        Path to a JSON file. Its decoded contents are passed to ``process_sql_data``.
    filter_by : ``str``, optional
        Substring of the SQL string. Failures whose SQL contains it are counted
        in ``filtered_errors`` and are not printed. ``None`` or ``""`` disables
        filtering.
    verbose : ``bool``, optional (default = False)
        Print the error, question text and SQL for each failure not matched by
        ``filter_by``.

    Returns
    -------
    ``(num_parsed, num_queries, filtered_errors, non_basic_as_aliases, as_count,
    queries_with_weird_as)``. The last three count, respectively:

    - ``AS`` aliases whose alias (minus its last six characters) differs from the preceding token;
    - all ``AS`` tokens that have a preceding and a following token;
    - queries containing at least one such non-basic alias.
    """
    grammar_string = format_grammar_string(GRAMMAR_DICTIONARY)
    grammar = Grammar(grammar_string)

    # Close the file instead of leaking the handle.
    with open(filename) as data_file:
        data = json.load(data_file)

    num_queries = 0
    num_parsed = 0
    filtered_errors = 0

    non_basic_as_aliases = 0
    as_count = 0
    queries_with_weird_as = 0

    for i, sql_data in enumerate(process_sql_data(data)):
        sql_visitor = SqlVisitor(grammar)
        tokens = sql_data.sql

        # NOTE: DATA hack alert - the geography dataset doesn't alias derived tables consistently,
        # so we fix the data a bit here instead of completely re-working the grammar.
        # An "AS" is inserted before each DERIVED... token that directly follows ")".
        # Without such tokens this is a no-op, so it is applied to every query.
        # The former ``any([...] for ...)`` guard tested non-empty lists and was
        # therefore always true for non-empty queries anyway.
        sql_to_use = []
        for j, token in enumerate(tokens):
            # ``j > 0`` keeps the first token from being compared with tokens[-1].
            if token[:7] == "DERIVED" and j > 0 and tokens[j - 1] == ")":
                sql_to_use.append("AS")
            sql_to_use.append(token)

        # Inspect every "<previous> AS <alias>" triple.
        query_has_weird_as = False
        for previous_token, token, alias in zip(sql_to_use, sql_to_use[1:], sql_to_use[2:]):
            if token == "AS":
                as_count += 1
                # Aliases look like "LOCATIONalias0"; dropping the last six
                # characters recovers the aliased name for a plain table alias.
                # NOTE(review): assumes a single-digit alias suffix.
                if alias[:-6] != previous_token:
                    non_basic_as_aliases += 1
                    query_has_weird_as = True
        if query_has_weird_as:
            queries_with_weird_as += 1

        sql_string = " ".join(sql_to_use)
        num_queries += 1
        try:
            sql_visitor.parse(sql_string)
            num_parsed += 1
        except Exception as error:  # pylint: disable=broad-except
            # Deliberately broad: any parse failure just counts as unparsed.
            if filter_by and filter_by in sql_string:
                filtered_errors += 1
            elif verbose:
                print()
                print(error)
                print(" ".join(sql_data.text))
                print(sql_data.sql)
                try:
                    # Optional pretty-printer; fall back to the raw string.
                    import sqlparse
                    print(sqlparse.format(sql_string, reindent=True))
                except Exception:  # pylint: disable=broad-except
                    print(sql_string)

        if (i + 1) % 500 == 0:
            print(f"\tProcessed {i + 1} queries.")

    return num_parsed, num_queries, filtered_errors, non_basic_as_aliases, as_count, queries_with_weird_as
Ejemplo n.º 15
0
    def get_action_sequence_and_all_actions(self,
                                            query: List[str] = None,
                                            prelinked_entities: Dict[str, Dict[str, str]] = None) -> Tuple[List[str], List[str]]:  # pylint: disable=line-too-long
        """
        Build a grammar containing the prelinked entity variables and parse ``query``.

        Each entity variable becomes a production ``"'<variable>'"`` prepended to
        one of the ``number``, ``string`` or ``value`` nonterminals. The target is
        chosen as follows:

        - If ``info["type"]`` (upper-cased) names a column in ``self.columns``,
          the target comes from that column's type.
        - Otherwise ``info["text"]`` is inspected: float-convertible text goes to
          ``number``, alphabetic text (ignoring spaces) goes to ``string``, and
          anything else goes to ``value``.

        Parameters
        ----------
        query : ``List[str]``, optional
            SQL tokens. When empty or ``None`` no parse is attempted.
        prelinked_entities : ``Dict[str, Dict[str, str]]``, optional
            Maps variable names to dicts with at least ``"type"`` and ``"text"``.
            Only allowed when ``self.use_prelinked_entities`` is set.

        Returns
        -------
        ``(action_sequence, sorted_actions)``. ``action_sequence`` is the parse
        result (``[]`` without a query). ``sorted_actions`` is the sorted union of
        all valid actions.

        Raises
        ------
        ConfigurationError
            If entities are passed while ``use_prelinked_entities`` is disabled.
        """
        grammar_with_context = deepcopy(self.base_grammar_dictionary)

        if not self.use_prelinked_entities and prelinked_entities is not None:
            raise ConfigurationError(
                "The Text2SqlWorld was specified to not use prelinked "
                "entities, but prelinked entities were passed.")
        prelinked_entities = prelinked_entities or {}

        for variable, info in prelinked_entities.items():
            variable_column = info["type"].upper()
            matched_column = self.columns.get(variable_column, None)

            if matched_column is not None:
                # Try to infer the variable's type by matching it to a column in
                # the database. If we can't, we just add it as a value.
                if column_has_numeric_type(matched_column):
                    nonterminal = "number"
                elif column_has_string_type(matched_column):
                    nonterminal = "string"
                else:
                    nonterminal = "value"
            else:
                # Otherwise, try to infer by looking at the actual value.
                # This is what happens if you try and do type inference
                # in a grammar which parses _strings_ in _Python_.
                # We're just seeing if the python interpreter can convert
                # it to a float - if it can, we assume it's a number.
                try:
                    float(info["text"])
                    is_numeric = True
                except ValueError:
                    is_numeric = False
                if is_numeric:
                    nonterminal = "number"
                elif info["text"].replace(" ", "").isalpha():
                    nonterminal = "string"
                else:
                    nonterminal = "value"

            # Prepend the quoted variable ahead of the existing productions,
            # presumably so the grammar's ordered choice tries it first.
            grammar_with_context[nonterminal] = [
                f'"\'{variable}\'"'
            ] + grammar_with_context[nonterminal]

        grammar = Grammar(format_grammar_string(grammar_with_context))

        valid_actions = initialize_valid_actions(grammar)
        all_actions = set()
        for action_list in valid_actions.values():
            all_actions.update(action_list)
        sorted_actions = sorted(all_actions)

        sql_visitor = SqlVisitor(grammar)
        action_sequence = sql_visitor.parse(" ".join(query)) if query else []
        return action_sequence, sorted_actions
Ejemplo n.º 16
0
    def test_variable_free_world_cannot_parse_as_statements(self):
        """
        The base grammar of a ``Text2SqlWorld`` rejects a query that aliases a
        table with ``AS``. It still parses the same query without the alias.
        """
        world = Text2SqlWorld(self.schema)
        grammar_dictionary = world.base_grammar_dictionary
        # Iterate the production lists themselves. Iterating ``.items()`` yields
        # ``(nonterminal, productions)`` tuples, so the membership check never
        # looked inside the lists.
        # NOTE(review): this is an exact-element check for the string "AS".
        for productions in grammar_dictionary.values():
            assert "AS" not in productions

        sql_with_as = [
            "SELECT", "COUNT", "(", "*", ")", "FROM", "LOCATION", "AS",
            "LOCATIONalias0", ",", "RESTAURANT", "WHERE", "LOCATION", ".",
            "CITY_NAME", "=", "'city_name0'", "AND", "RESTAURANT", ".", "NAME",
            "=", "LOCATION", ".", "RESTAURANT_ID", "AND", "RESTAURANT", ".",
            "NAME", "=", "'name0'", ";",
        ]

        grammar = Grammar(format_grammar_string(world.base_grammar_dictionary))
        sql_visitor = SqlVisitor(grammar)

        with self.assertRaises(ParseError):
            sql_visitor.parse(" ".join(sql_with_as))

        sql = [
            "SELECT", "COUNT", "(", "*", ")", "FROM", "LOCATION", ",",
            "RESTAURANT", "WHERE", "LOCATION", ".", "CITY_NAME", "=",
            "'city_name0'", "AND", "RESTAURANT", ".", "NAME", "=", "LOCATION",
            ".", "RESTAURANT_ID", "AND", "RESTAURANT", ".", "NAME", "=",
            "'name0'", ";",
        ]

        # Without the AS we should still be able to parse it. A fresh visitor is
        # used after the failed parse above.
        sql_visitor = SqlVisitor(grammar)
        sql_visitor.parse(" ".join(sql))