Exemple #1
0
 def test_world_modifies_unconstrained_grammar_correctly(self):
     """Check the ``table_name`` and ``column_name`` entries of the base grammar.

     Builds a ``Text2SqlWorld`` from ``self.schema`` and compares both
     entries of its ``base_grammar_dictionary`` against fixed, ordered lists.
     """
     expected_tables = ['"RESTAURANT"', '"LOCATION"', '"GEOGRAPHIC"']
     expected_columns = [
         '"STREET_NAME"', '"RESTAURANT_ID"', '"REGION"',
         '"RATING"', '"NAME"', '"HOUSE_NUMBER"',
         '"FOOD_TYPE"', '"COUNTY"', '"CITY_NAME"',
     ]

     productions_by_rule = Text2SqlWorld(self.schema).base_grammar_dictionary

     assert productions_by_rule["table_name"] == expected_tables
     assert productions_by_rule["column_name"] == expected_columns
Exemple #2
0
    def test_grammar_from_world_can_produce_entities_as_values(self):
        """Check that pre-linked entity placeholders show up as ``string`` rules.

        The tokenised query references the placeholders ``'city_name0'`` and
        ``'name0'``; both must appear as ``string -> [...]`` productions in the
        returned action sequence and in the full list of actions.
        """
        linked_entities = {
            "city_name0": {"text": "San fran", "type": "location"},
            "name0": {"text": "Matt Gardinios Pizza", "type": "restaurant"},
        }
        query_tokens = [
            'SELECT', 'COUNT', '(', '*', ')',
            'FROM', 'LOCATION', ',', 'RESTAURANT',
            'WHERE', 'LOCATION', '.', 'CITY_NAME', '=', "'city_name0'",
            'AND', 'RESTAURANT', '.', 'NAME', '=', 'LOCATION', '.', 'RESTAURANT_ID',
            'AND', 'RESTAURANT', '.', 'NAME', '=', "'name0'",
            ';',
        ]
        expected_rules = (
            'string -> ["\'city_name0\'"]',
            'string -> ["\'name0\'"]',
        )

        world = Text2SqlWorld(self.schema)
        action_sequence, actions = world.get_action_sequence_and_all_actions(
            query_tokens, linked_entities)

        # Same check order as before: first the sequence, then all actions.
        for rule_collection in (action_sequence, actions):
            for rule in expected_rules:
                assert rule in rule_collection
Exemple #3
0
    def __init__(self,
                 schema_path: str,
                 database_file: str,
                 use_all_sql: bool = False,
                 remove_unneeded_aliases: bool = True,
                 use_prelinked_entities: bool = True,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 cross_validation_split_to_exclude: int = None,
                 lazy: bool = False) -> None:
        """Store the reader configuration and build its ``Text2SqlWorld``.

        Parameters
        ----------
        schema_path : path handed to ``Text2SqlWorld`` and kept on the instance.
        database_file : resolved through ``cached_path`` and opened with sqlite3.
        use_all_sql : stored as ``_use_all_sql``.
        remove_unneeded_aliases : stored as ``_remove_unneeded_aliases``.
        use_prelinked_entities : must be truthy; forwarded to the world.
        token_indexers : falls back to a single ``SingleIdTokenIndexer`` under
            the key ``'tokens'`` when falsy.
        cross_validation_split_to_exclude : stored as its ``str()`` form, so
            the default ``None`` is kept as the string ``"None"``.
        lazy : forwarded to the parent constructor.

        Raises
        ------
        ConfigurationError
            If ``use_prelinked_entities`` is falsy.
        """
        super().__init__(lazy)

        if token_indexers:
            self._token_indexers = token_indexers
        else:
            self._token_indexers = {'tokens': SingleIdTokenIndexer()}

        self._use_all_sql = use_all_sql
        self._remove_unneeded_aliases = remove_unneeded_aliases
        self._use_prelinked_entities = use_prelinked_entities

        # Reject unsupported configurations before touching the database.
        if not use_prelinked_entities:
            message = ("The grammar based text2sql dataset reader "
                       "currently requires the use of entity pre-linking.")
            raise ConfigurationError(message)

        self._cross_validation_split_to_exclude = str(
            cross_validation_split_to_exclude)

        # The connection and cursor are kept on the instance; the cursor is
        # also handed to the world below.
        local_database = cached_path(database_file)
        self._database_file = local_database
        self._connection = sqlite3.connect(local_database)
        self._cursor = self._connection.cursor()

        self._schema_path = schema_path
        self._world = Text2SqlWorld(
            schema_path,
            self._cursor,
            use_prelinked_entities=use_prelinked_entities)
Exemple #4
0
 def test_world_modifies_grammar_with_global_values_for_dataset(self):
     """Check the ``value`` entry of the world's base grammar dictionary."""
     # "2.5" is expected at the front because it is a global value
     # for the restaurants dataset.
     expected_value_productions = [
         '"2.5"', 'parenval', '"YEAR(CURDATE())"',
         'number', 'boolean', 'function', 'col_ref', 'string',
     ]

     base_grammar = Text2SqlWorld(self.schema).base_grammar_dictionary

     assert base_grammar["value"] == expected_value_productions
Exemple #5
0
    def test_variable_free_world_cannot_parse_as_statements(self):
        """Check that the world's grammar rejects aliased (``AS``) queries.

        Three things are verified:

        1. No production string in the base grammar dictionary contains the
           quoted ``"AS"`` terminal.
        2. Parsing a query that aliases a table with ``AS`` raises
           ``ParseError``.
        3. The same query without the alias parses without raising.
        """
        world = Text2SqlWorld(self.schema)
        grammar_dictionary = world.base_grammar_dictionary
        # Inspect every production string. Iterating ``.items()`` directly
        # yields ``(key, productions)`` tuples, for which a membership test
        # against "AS" can never fail. The terminal is matched with its quotes
        # so that other keywords such as "ASC" are not mistaken for it.
        for nonterminal, productions in grammar_dictionary.items():
            for production in productions:
                assert '"AS"' not in production, nonterminal

        sql_with_as = [
            'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', 'AS',
            'LOCATIONalias0', ',', 'RESTAURANT', 'WHERE', 'LOCATION', '.',
            'CITY_NAME', '=', "'city_name0'", 'AND', 'RESTAURANT', '.', 'NAME',
            '=', 'LOCATION', '.', 'RESTAURANT_ID', 'AND', 'RESTAURANT', '.',
            'NAME', '=', "'name0'", ';'
        ]

        grammar = Grammar(format_grammar_string(world.base_grammar_dictionary))
        sql_visitor = SqlVisitor(grammar)

        with self.assertRaises(ParseError):
            sql_visitor.parse(" ".join(sql_with_as))

        sql = [
            'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',',
            'RESTAURANT', 'WHERE', 'LOCATION', '.', 'CITY_NAME', '=',
            "'city_name0'", 'AND', 'RESTAURANT', '.', 'NAME', '=', 'LOCATION',
            '.', 'RESTAURANT_ID', 'AND', 'RESTAURANT', '.', 'NAME', '=',
            "'name0'", ';'
        ]

        # Without the AS we should still be able to parse it.
        sql_visitor = SqlVisitor(grammar)
        sql_visitor.parse(" ".join(sql))
Exemple #6
0
    def test_grammar_from_world_can_parse_statements(self):
        """Check that the world's base grammar parses a two-table query.

        The test passes as long as ``SqlVisitor.parse`` does not raise.
        """
        query_tokens = [
            'SELECT', 'COUNT', '(', '*', ')',
            'FROM', 'LOCATION', ',', 'RESTAURANT',
            'WHERE', 'LOCATION', '.', 'CITY_NAME', '=', "'city_name0'",
            'AND', 'RESTAURANT', '.', 'NAME', '=', 'LOCATION', '.', 'RESTAURANT_ID',
            'AND', 'RESTAURANT', '.', 'NAME', '=', "'name0'",
            ';',
        ]
        query_text = " ".join(query_tokens)

        world = Text2SqlWorld(self.schema)
        grammar_text = format_grammar_string(world.base_grammar_dictionary)
        visitor = SqlVisitor(Grammar(grammar_text))
        visitor.parse(query_text)
Exemple #7
0
    def test_untyped_grammar_has_no_string_or_number_references(self):
        """Check that an untyped world never mentions ``number`` or ``string``.

        Neither name may be a key of the base grammar dictionary, and no
        production may contain any of the forbidden fragments below.
        """
        # "string" is not searched for on its own because ``string_set``
        # is a valid non-terminal; only delimited forms are rejected.
        forbidden_fragments = ("number", "string)", "string ", "(string ")

        untyped_world = Text2SqlWorld(self.schema, use_untyped_entities=True)
        base_grammar = untyped_world.base_grammar_dictionary

        for nonterminal, productions in base_grammar.items():
            assert nonterminal not in {"number", "string"}
            for fragment in forbidden_fragments:
                assert all(fragment not in production for production in productions)
    def test_grammar_statelet(self):
        """Replay a query's action sequence through a ``GrammarStatelet``.

        After every action of the sequence has been taken, starting from the
        ``statement`` non-terminal, the statelet's stack must be empty.
        """
        query_tokens = ["SELECT", "COUNT", "(", "*", ")", "FROM", "LOCATION", ",", "RESTAURANT", ";"]

        world = Text2SqlWorld(self.schema)
        action_sequence, valid_actions = world.get_action_sequence_and_all_actions(query_tokens)

        statelet = GrammarStatelet(
            ["statement"],
            valid_actions,
            Text2SqlParser.is_nonterminal,
            reverse_productions=True,
        )
        for production_rule in action_sequence:
            statelet = statelet.take_action(production_rule)

        assert statelet._nonterminal_stack == []
Exemple #9
0
    def test_world_adds_values_from_tables(self):
        """Check the ``number`` and ``string`` productions built with a cursor.

        The world is created with a cursor on ``self.database_path`` and
        ``use_prelinked_entities=False``; both entries of its base grammar
        dictionary are then compared against fixed, ordered lists
        (presumably the values found in the database tables).

        The connection is closed in a ``finally`` clause so that it is
        released even when an assertion fails.
        """
        connection = sqlite3.connect(self.database_path)
        try:
            cursor = connection.cursor()
            world = Text2SqlWorld(self.schema,
                                  cursor=cursor,
                                  use_prelinked_entities=False)
            assert world.base_grammar_dictionary["number"] == [
                '"229"',
                '"228"',
                '"227"',
                '"226"',
                '"225"',
                '"5"',
                '"4"',
                '"3"',
                '"2"',
                '"1"',
                '"833"',
                '"430"',
                '"242"',
                '"135"',
                '"1103"',
            ]

            assert world.base_grammar_dictionary["string"] == [
                '"tommy\'s"',
                '"rod\'s hickory pit restaurant"',
                '"lyons restaurant"',
                '"jamerican cuisine"',
                '"denny\'s restaurant"',
                '"american"',
                '"vallejo"',
                '"w. el camino real"',
                '"el camino real"',
                '"e. el camino real"',
                '"church st"',
                '"broadway"',
                '"sunnyvale"',
                '"san francisco"',
                '"san carlos"',
                '"american canyon"',
                '"alviso"',
                '"albany"',
                '"alamo"',
                '"alameda"',
                '"unknown"',
                '"santa clara county"',
                '"contra costa county"',
                '"alameda county"',
                '"bay area"',
            ]
        finally:
            # Closing the connection also invalidates the cursor, so the
            # assertions above are kept inside the ``try`` body.
            connection.close()
Exemple #10
0
    def test_grammar_from_world_can_produce_entities_as_values(self):
        """Check that pre-linked entity placeholders show up as ``string`` rules.

        The tokenised query references the placeholders ``'city_name0'`` and
        ``'name0'``; both must appear as ``string -> [...]`` productions in the
        returned action sequence and in the full list of actions.
        """
        linked_entities = {
            "city_name0": {"text": "San fran", "type": "location"},
            "name0": {"text": "Matt Gardinios Pizza", "type": "restaurant"},
        }
        query_tokens = [
            "SELECT", "COUNT", "(", "*", ")",
            "FROM", "LOCATION", ",", "RESTAURANT",
            "WHERE", "LOCATION", ".", "CITY_NAME", "=", "'city_name0'",
            "AND", "RESTAURANT", ".", "NAME", "=", "LOCATION", ".", "RESTAURANT_ID",
            "AND", "RESTAURANT", ".", "NAME", "=", "'name0'",
            ";",
        ]
        expected_rules = (
            "string -> [\"'city_name0'\"]",
            "string -> [\"'name0'\"]",
        )

        world = Text2SqlWorld(self.schema)
        action_sequence, actions = world.get_action_sequence_and_all_actions(
            query_tokens, linked_entities)

        # Same check order as before: first the sequence, then all actions.
        for rule_collection in (action_sequence, actions):
            for rule in expected_rules:
                assert rule in rule_collection
Exemple #11
0
    def __init__(self,
                 schema_path: str,
                 use_all_sql: bool = False,
                 remove_unneeded_aliases: bool = True,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 cross_validation_split_to_exclude: int = None,
                 lazy: bool = False) -> None:
        """Store the reader configuration and build its ``Text2SqlWorld``.

        Parameters
        ----------
        schema_path : used to build the ``Text2SqlTableContext`` for the world.
        use_all_sql : stored as ``_use_all_sql``.
        remove_unneeded_aliases : stored as ``_remove_unneeded_aliases``.
        token_indexers : falls back to a single ``SingleIdTokenIndexer`` under
            the key ``'tokens'`` when falsy.
        cross_validation_split_to_exclude : stored as its ``str()`` form, so
            the default ``None`` is kept as the string ``"None"``.
        lazy : forwarded to the parent constructor.
        """
        super().__init__(lazy)

        if token_indexers:
            self._token_indexers = token_indexers
        else:
            self._token_indexers = {'tokens': SingleIdTokenIndexer()}

        self._use_all_sql = use_all_sql
        self._remove_unneeded_aliases = remove_unneeded_aliases
        self._cross_validation_split_to_exclude = str(
            cross_validation_split_to_exclude)

        # The table context is kept on the instance and shared with the world.
        table_context = Text2SqlTableContext(schema_path)
        self._sql_table_context = table_context
        self._world = Text2SqlWorld(table_context)
    def test_grammar_statelet(self):
        """Replay a query's action sequence through a ``GrammarStatelet``.

        After every action of the sequence has been taken, starting from the
        ``statement`` non-terminal, the statelet's stack must be empty.
        """
        query_tokens = ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',', 'RESTAURANT', ';']

        world = Text2SqlWorld(self.schema)
        action_sequence, valid_actions = world.get_action_sequence_and_all_actions(query_tokens)

        statelet = GrammarStatelet(
            ['statement'],
            valid_actions,
            Text2SqlParser.is_nonterminal,
            reverse_productions=True,
        )
        for production_rule in action_sequence:
            statelet = statelet.take_action(production_rule)

        assert statelet._nonterminal_stack == []  # pylint: disable=protected-access
Exemple #13
0
    def test_grammar_from_world_can_parse_statements(self):
        """Check that the world's base grammar parses a two-table query.

        The test passes as long as ``SqlVisitor.parse`` does not raise.
        """
        query_tokens = [
            "SELECT", "COUNT", "(", "*", ")",
            "FROM", "LOCATION", ",", "RESTAURANT",
            "WHERE", "LOCATION", ".", "CITY_NAME", "=", "'city_name0'",
            "AND", "RESTAURANT", ".", "NAME", "=", "LOCATION", ".", "RESTAURANT_ID",
            "AND", "RESTAURANT", ".", "NAME", "=", "'name0'",
            ";",
        ]
        query_text = " ".join(query_tokens)

        world = Text2SqlWorld(self.schema)
        grammar_text = format_grammar_string(world.base_grammar_dictionary)
        visitor = SqlVisitor(Grammar(grammar_text))
        visitor.parse(query_text)
Exemple #14
0
 def test_world_identifies_non_global_rules(self):
     """Check that a rule producing an entity placeholder is not global."""
     placeholder_rule = 'value -> ["\'food_type0\'"]'
     schema_world = Text2SqlWorld(self.schema)
     is_global = schema_world.is_global_rule(placeholder_rule)
     assert not is_global
Exemple #15
0
 def setUp(self):
     """Resolve the restaurants schema fixture and build a world from it."""
     super().setUp()
     schema_file = self.FIXTURES_ROOT / 'data' / 'text2sql' / 'restaurants-schema.csv'
     # The schema is kept as a plain string path on the test case.
     self.schema = str(schema_file)
     self.world = Text2SqlWorld(Text2SqlTableContext(self.schema))
Exemple #16
0
    def test_variable_free_world_cannot_parse_as_statements(self):
        """Check that the world's grammar rejects aliased (``AS``) queries.

        Three things are verified:

        1. No production string in the base grammar dictionary contains the
           quoted ``"AS"`` terminal.
        2. Parsing a query that aliases a table with ``AS`` raises
           ``ParseError``.
        3. The same query without the alias parses without raising.
        """
        world = Text2SqlWorld(self.schema)
        grammar_dictionary = world.base_grammar_dictionary
        # Inspect every production string. Iterating ``.items()`` directly
        # yields ``(key, productions)`` tuples, for which a membership test
        # against "AS" can never fail. The terminal is matched with its quotes
        # so that other keywords such as "ASC" are not mistaken for it.
        for nonterminal, productions in grammar_dictionary.items():
            for production in productions:
                assert '"AS"' not in production, nonterminal

        sql_with_as = [
            "SELECT", "COUNT", "(", "*", ")", "FROM", "LOCATION", "AS",
            "LOCATIONalias0", ",", "RESTAURANT", "WHERE", "LOCATION", ".",
            "CITY_NAME", "=", "'city_name0'", "AND", "RESTAURANT", ".", "NAME",
            "=", "LOCATION", ".", "RESTAURANT_ID", "AND", "RESTAURANT", ".",
            "NAME", "=", "'name0'", ";",
        ]

        grammar = Grammar(format_grammar_string(world.base_grammar_dictionary))
        sql_visitor = SqlVisitor(grammar)

        with self.assertRaises(ParseError):
            sql_visitor.parse(" ".join(sql_with_as))

        sql = [
            "SELECT", "COUNT", "(", "*", ")", "FROM", "LOCATION", ",",
            "RESTAURANT", "WHERE", "LOCATION", ".", "CITY_NAME", "=",
            "'city_name0'", "AND", "RESTAURANT", ".", "NAME", "=", "LOCATION",
            ".", "RESTAURANT_ID", "AND", "RESTAURANT", ".", "NAME", "=",
            "'name0'", ";",
        ]

        # Without the AS we should still be able to parse it.
        sql_visitor = SqlVisitor(grammar)
        sql_visitor.parse(" ".join(sql))
Exemple #17
0
 def test_world_identifies_non_global_rules(self):
     """Check that a rule producing an entity placeholder is not global."""
     placeholder_rule = "value -> [\"'food_type0'\"]"
     schema_world = Text2SqlWorld(self.schema)
     is_global = schema_world.is_global_rule(placeholder_rule)
     assert not is_global