def test_atis_grammar_statelet(self):
    """Replay a hand-written ATIS derivation and check that it closes every nonterminal.

    The ``AtisWorld`` is built from a single utterance and only its
    ``valid_actions`` are used here; the production rules themselves are spelled
    out explicitly. After applying them in order to a ``GrammarStatelet`` rooted
    at ``"statement"``, no nonterminal may remain on the statelet's stack.
    """
    utterance = ("give me all flights from boston to "
                 "philadelphia next week arriving after lunch")
    atis_world = AtisWorld([utterance])

    # One production rule per element, applied left to right. The second entry
    # is a single rule split across two adjacent string literals.
    derivation = (
        'statement -> [query, ";"]',
        'query -> ["(", "SELECT", distinct, select_results, "FROM", table_refs, '
        'where_clause, ")"]',
        'distinct -> ["DISTINCT"]',
        "select_results -> [col_refs]",
        'col_refs -> [col_ref, ",", col_refs]',
        'col_ref -> ["city", ".", "city_code"]',
        "col_refs -> [col_ref]",
        'col_ref -> ["city", ".", "city_name"]',
        "table_refs -> [table_name]",
        'table_name -> ["city"]',
        'where_clause -> ["WHERE", "(", conditions, ")"]',
        "conditions -> [condition]",
        "condition -> [biexpr]",
        'biexpr -> ["city", ".", "city_name", binaryop, city_city_name_string]',
        'binaryop -> ["="]',
        "city_city_name_string -> [\"'BOSTON'\"]",
    )

    statelet = GrammarStatelet(
        ["statement"], atis_world.valid_actions, AtisSemanticParser.is_nonterminal
    )
    for rule in derivation:
        statelet = statelet.take_action(rule)

    # A complete derivation leaves nothing left to expand.
    assert statelet._nonterminal_stack == []  # pylint: disable=protected-access
def test_grammar_statelet(self):
    """Check that the action sequence derived for a simple SQL query fully expands the grammar.

    ``Text2SqlWorld.get_action_sequence_and_all_actions`` yields both the action
    sequence for the tokenised query and the valid-action table. Replaying that
    sequence on a ``GrammarStatelet`` rooted at ``"statement"`` must leave no
    nonterminals pending.
    """
    world = Text2SqlWorld(self.schema)
    sql = ["SELECT", "COUNT", "(", "*", ")", "FROM", "LOCATION", ",", "RESTAURANT", ";"]
    # Both values come straight from the world, so no placeholder for
    # ``valid_actions`` is needed beforehand.
    action_sequence, valid_actions = world.get_action_sequence_and_all_actions(sql)
    # NOTE(review): reverse_productions=True presumably matches the order in which
    # Text2SqlWorld emits production rules -- confirm against GrammarStatelet.
    grammar_state = GrammarStatelet(
        ["statement"], valid_actions, Text2SqlParser.is_nonterminal, reverse_productions=True
    )
    for action in action_sequence:
        grammar_state = grammar_state.take_action(action)
    assert grammar_state._nonterminal_stack == []  # pylint: disable=protected-access
def test_grammar_statelet(self):
    """Check that the action sequence derived for a simple SQL query fully expands the grammar.

    ``Text2SqlWorld.get_action_sequence_and_all_actions`` yields both the action
    sequence for the tokenised query and the valid-action table. Replaying that
    sequence on a ``GrammarStatelet`` rooted at ``'statement'`` must leave no
    nonterminals pending.
    """
    world = Text2SqlWorld(self.schema)
    sql = [
        'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',', 'RESTAURANT', ';'
    ]
    # Both values come straight from the world, so no placeholder for
    # ``valid_actions`` is needed beforehand.
    action_sequence, valid_actions = world.get_action_sequence_and_all_actions(
        sql)
    # NOTE(review): reverse_productions=True presumably matches the order in which
    # Text2SqlWorld emits production rules -- confirm against GrammarStatelet.
    grammar_state = GrammarStatelet(['statement'],
                                    valid_actions,
                                    Text2SqlParser.is_nonterminal,
                                    reverse_productions=True)
    for action in action_sequence:
        grammar_state = grammar_state.take_action(action)
    assert grammar_state._nonterminal_stack == []  # pylint: disable=protected-access
def test_take_action_crashes_with_mismatched_types(self):
    """A rule whose left-hand side does not match the pending nonterminal is rejected.

    The statelet is seeded with the single nonterminal ``"s"``, while the rule
    ``"t -> identity"`` has left-hand side ``"t"``. ``take_action`` is expected
    to reject this mismatch with ``AssertionError``.
    """
    # Build the statelet outside ``pytest.raises`` so that only ``take_action``
    # can satisfy the expectation. An AssertionError raised during construction
    # would otherwise make this test pass for the wrong reason.
    state = GrammarStatelet(["s"], {}, is_nonterminal)
    with pytest.raises(AssertionError):
        state.take_action("t -> identity")