def test_variable_free_world_cannot_parse_as_statements(self):
    """Check that the world's grammar rejects SQL that aliases a table with AS.

    Three things are verified: no production in the base grammar dictionary
    mentions the ``"AS"`` terminal, a query containing ``AS`` raises
    ``ParseError``, and the same query without the alias parses cleanly.
    """
    world = Text2SqlWorld(self.schema)
    grammar_dictionary = world.base_grammar_dictionary
    # Inspect the production strings themselves. Iterating ``.items()``
    # would yield ``(key, productions)`` tuples, and a membership test on
    # such a tuple never looks inside the individual productions.
    for productions in grammar_dictionary.values():
        assert all('"AS"' not in production for production in productions)

    sql_with_as = [
        'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', 'AS',
        'LOCATIONalias0', ',', 'RESTAURANT', 'WHERE', 'LOCATION', '.',
        'CITY_NAME', '=', "'city_name0'", 'AND', 'RESTAURANT', '.', 'NAME',
        '=', 'LOCATION', '.', 'RESTAURANT_ID', 'AND', 'RESTAURANT', '.',
        'NAME', '=', "'name0'", ';',
    ]
    grammar = Grammar(format_grammar_string(world.base_grammar_dictionary))
    sql_visitor = SqlVisitor(grammar)
    with self.assertRaises(ParseError):
        sql_visitor.parse(" ".join(sql_with_as))

    sql = [
        'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',',
        'RESTAURANT', 'WHERE', 'LOCATION', '.', 'CITY_NAME', '=',
        "'city_name0'", 'AND', 'RESTAURANT', '.', 'NAME', '=', 'LOCATION',
        '.', 'RESTAURANT_ID', 'AND', 'RESTAURANT', '.', 'NAME', '=',
        "'name0'", ';',
    ]
    # Without the AS we should still be able to parse it.
    sql_visitor = SqlVisitor(grammar)
    sql_visitor.parse(" ".join(sql))
def test_grammar_from_world_can_produce_entities_as_values(self):
    """Entity placeholders in the SQL must show up as ``string`` productions.

    Both the action sequence for the query and the full set of actions
    returned by the world are expected to contain a rule for each entity.
    """
    query_tokens = [
        'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',',
        'RESTAURANT', 'WHERE', 'LOCATION', '.', 'CITY_NAME', '=',
        "'city_name0'", 'AND', 'RESTAURANT', '.', 'NAME', '=', 'LOCATION',
        '.', 'RESTAURANT_ID', 'AND', 'RESTAURANT', '.', 'NAME', '=',
        "'name0'", ';',
    ]
    linked_entities = {
        "city_name0": {"text": "San fran", "type": "location"},
        "name0": {"text": "Matt Gardinios Pizza", "type": "restaurant"},
    }
    expected_rules = (
        'string -> ["\'city_name0\'"]',
        'string -> ["\'name0\'"]',
    )

    world = Text2SqlWorld(self.schema)
    action_sequence, actions = world.get_action_sequence_and_all_actions(
        query_tokens, linked_entities
    )

    # Check the query's own derivation first, then the global action set.
    for collection in (action_sequence, actions):
        for rule in expected_rules:
            assert rule in collection
def test_world_modifies_grammar_with_global_values_for_dataset(self):
    """The ``value`` non-terminal must list the dataset's global value first."""
    # Should have added 2.5 because it is a global value
    # for the restaurants dataset.
    expected_value_productions = [
        '"2.5"',
        'parenval',
        '"YEAR(CURDATE())"',
        'number',
        'boolean',
        'function',
        'col_ref',
        'string',
    ]
    grammar = Text2SqlWorld(self.schema).base_grammar_dictionary
    assert grammar["value"] == expected_value_productions
def test_world_modifies_unconstrained_grammar_correctly(self):
    """Table and column name productions must match the expected lists exactly."""
    expected_tables = ['"RESTAURANT"', '"LOCATION"', '"GEOGRAPHIC"']
    expected_columns = [
        '"STREET_NAME"',
        '"RESTAURANT_ID"',
        '"REGION"',
        '"RATING"',
        '"NAME"',
        '"HOUSE_NUMBER"',
        '"FOOD_TYPE"',
        '"COUNTY"',
        '"CITY_NAME"',
    ]

    grammar = Text2SqlWorld(self.schema).base_grammar_dictionary

    # Order matters: these are list comparisons, not set comparisons.
    assert grammar["table_name"] == expected_tables
    assert grammar["column_name"] == expected_columns
def test_untyped_grammar_has_no_string_or_number_references(self):
    """With untyped entities, no rule may define or reference number/string."""
    # We don't check for "string" directly here because
    # string_set is a valid non-terminal; instead we look for "string"
    # followed by a delimiter.
    forbidden_fragments = ("number", "string)", "string ", "(string ")

    untyped_world = Text2SqlWorld(self.schema, use_untyped_entities=True)
    for nonterminal, productions in untyped_world.base_grammar_dictionary.items():
        assert nonterminal not in {"number", "string"}
        for fragment in forbidden_fragments:
            assert not any(fragment in production for production in productions)
def test_grammar_from_world_can_parse_statements(self):
    """A grammar built from the world must parse a representative query.

    The test passes as long as parsing does not raise.
    """
    query_tokens = [
        'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',',
        'RESTAURANT', 'WHERE', 'LOCATION', '.', 'CITY_NAME', '=',
        "'city_name0'", 'AND', 'RESTAURANT', '.', 'NAME', '=', 'LOCATION',
        '.', 'RESTAURANT_ID', 'AND', 'RESTAURANT', '.', 'NAME', '=',
        "'name0'", ';',
    ]
    query_text = " ".join(query_tokens)

    world = Text2SqlWorld(self.schema)
    grammar_string = format_grammar_string(world.base_grammar_dictionary)
    SqlVisitor(Grammar(grammar_string)).parse(query_text)
def test_grammar_statelet(self):
    """Replaying a query's action sequence must empty the non-terminal stack.

    A ``GrammarStatelet`` is started at ``statement`` with the world's valid
    actions; every action of the sequence is then applied in order.
    """
    world = Text2SqlWorld(self.schema)
    sql = ["SELECT", "COUNT", "(", "*", ")", "FROM", "LOCATION", ",", "RESTAURANT", ";"]
    # No pre-initialisation of ``valid_actions`` is needed: it is bound here
    # before its first use.
    action_sequence, valid_actions = world.get_action_sequence_and_all_actions(sql)

    grammar_state = GrammarStatelet(
        ["statement"],
        valid_actions,
        Text2SqlParser.is_nonterminal,
        reverse_productions=True,
    )
    # ``take_action`` returns a new state, so rebind on every step.
    for action in action_sequence:
        grammar_state = grammar_state.take_action(action)
    assert grammar_state._nonterminal_stack == []
def test_world_adds_values_from_tables(self):
    """With a database cursor and no pre-linking, table values enter the grammar.

    The ``number`` and ``string`` non-terminals of the base grammar dictionary
    are compared against the exact expected production lists.
    """
    connection = sqlite3.connect(self.database_path)
    # Close the connection even when an assertion fails, so the test does
    # not leak an open database handle.
    try:
        cursor = connection.cursor()
        world = Text2SqlWorld(self.schema, cursor=cursor, use_prelinked_entities=False)
        assert world.base_grammar_dictionary["number"] == [
            '"229"',
            '"228"',
            '"227"',
            '"226"',
            '"225"',
            '"5"',
            '"4"',
            '"3"',
            '"2"',
            '"1"',
            '"833"',
            '"430"',
            '"242"',
            '"135"',
            '"1103"',
        ]
        assert world.base_grammar_dictionary["string"] == [
            '"tommy\'s"',
            '"rod\'s hickory pit restaurant"',
            '"lyons restaurant"',
            '"jamerican cuisine"',
            '"denny\'s restaurant"',
            '"american"',
            '"vallejo"',
            '"w. el camino real"',
            '"el camino real"',
            '"e. el camino real"',
            '"church st"',
            '"broadway"',
            '"sunnyvale"',
            '"san francisco"',
            '"san carlos"',
            '"american canyon"',
            '"alviso"',
            '"albany"',
            '"alamo"',
            '"alameda"',
            '"unknown"',
            '"santa clara county"',
            '"contra costa county"',
            '"alameda county"',
            '"bay area"',
        ]
    finally:
        connection.close()
def test_grammar_from_world_can_produce_entities_as_values(self):
    """Each linked entity must yield a ``string`` rule in both action collections."""
    tokens = [
        "SELECT", "COUNT", "(", "*", ")",
        "FROM", "LOCATION", ",", "RESTAURANT",
        "WHERE", "LOCATION", ".", "CITY_NAME", "=", "'city_name0'",
        "AND", "RESTAURANT", ".", "NAME", "=", "LOCATION", ".", "RESTAURANT_ID",
        "AND", "RESTAURANT", ".", "NAME", "=", "'name0'",
        ";",
    ]
    city_entity = {"text": "San fran", "type": "location"}
    restaurant_entity = {"text": "Matt Gardinios Pizza", "type": "restaurant"}
    entity_map = {"city_name0": city_entity, "name0": restaurant_entity}

    world = Text2SqlWorld(self.schema)
    derivation, every_action = world.get_action_sequence_and_all_actions(tokens, entity_map)

    city_rule = "string -> [\"'city_name0'\"]"
    name_rule = "string -> [\"'name0'\"]"
    # The rules must be used by this query's derivation ...
    assert city_rule in derivation
    assert name_rule in derivation
    # ... and be present in the full action set.
    assert city_rule in every_action
    assert name_rule in every_action
def test_grammar_statelet(self):
    """Applying all actions of a query must leave no pending non-terminals.

    The statelet starts from ``statement`` with the valid actions produced by
    the world, and each action in the sequence is taken in turn.
    """
    world = Text2SqlWorld(self.schema)
    sql = [
        'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',',
        'RESTAURANT', ';',
    ]
    # ``valid_actions`` is first bound here; no placeholder assignment is
    # required beforehand.
    action_sequence, valid_actions = world.get_action_sequence_and_all_actions(sql)
    grammar_state = GrammarStatelet(['statement'],
                                    valid_actions,
                                    Text2SqlParser.is_nonterminal,
                                    reverse_productions=True)
    # Each step returns a fresh state object, so keep the latest one.
    for action in action_sequence:
        grammar_state = grammar_state.take_action(action)
    assert grammar_state._nonterminal_stack == []  # pylint: disable=protected-access
def __init__(
    self,
    schema_path: str,
    database_file: str = None,
    use_all_sql: bool = False,
    remove_unneeded_aliases: bool = True,
    use_prelinked_entities: bool = True,
    use_untyped_entities: bool = True,
    token_indexers: Dict[str, TokenIndexer] = None,
    cross_validation_split_to_exclude: int = None,
    keep_if_unparseable: bool = True,
    **kwargs,
) -> None:
    """Configure the reader and build its ``Text2SqlWorld``.

    Parameters
    ----------
    schema_path : path handed to ``Text2SqlWorld`` and kept on the instance.
    database_file : optional database location; when given it is resolved
        with ``cached_path`` and opened with ``sqlite3``, and the resulting
        cursor is passed to the world.
    use_all_sql, remove_unneeded_aliases : stored on the instance for use
        elsewhere in the class.
    use_prelinked_entities : must be truthy, otherwise construction fails.
    use_untyped_entities : forwarded to ``Text2SqlWorld``.
    token_indexers : defaults to a single ``SingleIdTokenIndexer`` under
        the key ``"tokens"`` when falsy.
    cross_validation_split_to_exclude : stored as its ``str()`` form.
    keep_if_unparseable : stored as ``self._keep_if_unparsable``.
    **kwargs : forwarded unchanged to the base class constructor.

    Raises
    ------
    ConfigurationError
        If ``use_prelinked_entities`` is falsy.
    """
    super().__init__(**kwargs)

    self._token_indexers = (
        token_indexers if token_indexers else {"tokens": SingleIdTokenIndexer()}
    )
    self._use_all_sql = use_all_sql
    self._remove_unneeded_aliases = remove_unneeded_aliases
    self._use_prelinked_entities = use_prelinked_entities
    self._keep_if_unparsable = keep_if_unparseable

    if not self._use_prelinked_entities:
        raise ConfigurationError(
            "The grammar based text2sql dataset reader "
            "currently requires the use of entity pre-linking."
        )

    # Stored as a string, so a value of None becomes the text "None".
    self._cross_validation_split_to_exclude = str(cross_validation_split_to_exclude)

    # A cursor is only available when a database file was supplied.
    self._cursor = None
    if database_file is not None:
        local_database = cached_path(database_file)
        self._cursor = sqlite3.connect(local_database).cursor()

    self._schema_path = schema_path
    self._world = Text2SqlWorld(
        schema_path,
        self._cursor,
        use_prelinked_entities=use_prelinked_entities,
        use_untyped_entities=use_untyped_entities,
    )
def test_grammar_from_world_can_parse_statements(self):
    """Parsing a sample query with the world's grammar must not raise."""
    select_clause = ["SELECT", "COUNT", "(", "*", ")"]
    from_clause = ["FROM", "LOCATION", ",", "RESTAURANT"]
    where_clause = [
        "WHERE", "LOCATION", ".", "CITY_NAME", "=", "'city_name0'",
        "AND", "RESTAURANT", ".", "NAME", "=", "LOCATION", ".", "RESTAURANT_ID",
        "AND", "RESTAURANT", ".", "NAME", "=", "'name0'",
    ]
    # Concatenating the clauses gives the full token list of the query.
    sql = select_clause + from_clause + where_clause + [";"]

    base_grammar = Text2SqlWorld(self.schema).base_grammar_dictionary
    parser = SqlVisitor(Grammar(format_grammar_string(base_grammar)))
    parser.parse(" ".join(sql))
def test_variable_free_world_cannot_parse_as_statements(self):
    """Verify that table aliasing via AS is absent from the world's grammar.

    No production may contain the ``"AS"`` terminal, a query using ``AS``
    must fail with ``ParseError``, and the alias-free query must parse.
    """
    world = Text2SqlWorld(self.schema)
    grammar_dictionary = world.base_grammar_dictionary
    # Walk the production lists, not ``.items()``: a membership test on a
    # ``(key, productions)`` tuple never examines the production strings.
    for productions in grammar_dictionary.values():
        assert all('"AS"' not in production for production in productions)

    sql_with_as = [
        "SELECT",
        "COUNT",
        "(",
        "*",
        ")",
        "FROM",
        "LOCATION",
        "AS",
        "LOCATIONalias0",
        ",",
        "RESTAURANT",
        "WHERE",
        "LOCATION",
        ".",
        "CITY_NAME",
        "=",
        "'city_name0'",
        "AND",
        "RESTAURANT",
        ".",
        "NAME",
        "=",
        "LOCATION",
        ".",
        "RESTAURANT_ID",
        "AND",
        "RESTAURANT",
        ".",
        "NAME",
        "=",
        "'name0'",
        ";",
    ]
    grammar = Grammar(format_grammar_string(world.base_grammar_dictionary))
    sql_visitor = SqlVisitor(grammar)
    with self.assertRaises(ParseError):
        sql_visitor.parse(" ".join(sql_with_as))

    sql = [
        "SELECT",
        "COUNT",
        "(",
        "*",
        ")",
        "FROM",
        "LOCATION",
        ",",
        "RESTAURANT",
        "WHERE",
        "LOCATION",
        ".",
        "CITY_NAME",
        "=",
        "'city_name0'",
        "AND",
        "RESTAURANT",
        ".",
        "NAME",
        "=",
        "LOCATION",
        ".",
        "RESTAURANT_ID",
        "AND",
        "RESTAURANT",
        ".",
        "NAME",
        "=",
        "'name0'",
        ";",
    ]
    # Without the AS we should still be able to parse it.
    sql_visitor = SqlVisitor(grammar)
    sql_visitor.parse(" ".join(sql))
def test_world_identifies_non_global_rules(self):
    """A rule producing an entity placeholder must not be reported as global."""
    entity_rule = "value -> [\"'food_type0'\"]"
    world = Text2SqlWorld(self.schema)
    is_global = world.is_global_rule(entity_rule)
    assert not is_global
def test_world_identifies_non_global_rules(self):
    """``is_global_rule`` must be falsy for a rule that yields an entity value."""
    placeholder_rule = 'value -> ["\'food_type0\'"]'
    result = Text2SqlWorld(self.schema).is_global_rule(placeholder_rule)
    assert not result