Example #1
0
    def test_atis_grammar_statelet(self):
        """Apply a fixed ATIS production sequence and check the stack ends up empty.

        A ``GrammarStatelet`` is started at ``"statement"`` with the valid
        actions of an ``AtisWorld`` built from a single utterance. Each
        production is applied in order via ``take_action``.
        """
        utterance = ("give me all flights from boston to "
                     "philadelphia next week arriving after lunch")
        atis_world = AtisWorld([utterance])

        # Productions are applied in exactly this order.
        productions = [
            'statement -> [query, ";"]',
            'query -> ["(", "SELECT", distinct, select_results, "FROM", table_refs, '
            'where_clause, ")"]',
            'distinct -> ["DISTINCT"]',
            "select_results -> [col_refs]",
            'col_refs -> [col_ref, ",", col_refs]',
            'col_ref -> ["city", ".", "city_code"]',
            "col_refs -> [col_ref]",
            'col_ref -> ["city", ".", "city_name"]',
            "table_refs -> [table_name]",
            'table_name -> ["city"]',
            'where_clause -> ["WHERE", "(", conditions, ")"]',
            "conditions -> [condition]",
            "condition -> [biexpr]",
            'biexpr -> ["city", ".", "city_name", binaryop, city_city_name_string]',
            'binaryop -> ["="]',
            "city_city_name_string -> [\"'BOSTON'\"]",
        ]

        statelet = GrammarStatelet(["statement"], atis_world.valid_actions,
                                   AtisSemanticParser.is_nonterminal)
        # take_action returns a new statelet, so rebind on every step.
        for production in productions:
            statelet = statelet.take_action(production)

        # The nonterminal stack must be empty after the last production.
        assert statelet._nonterminal_stack == []
Example #2
0
    def test_grammar_statelet(self):
        """Replay the action sequence for a small SQL query and check the stack empties.

        The action sequence and the valid-action table both come from
        ``Text2SqlWorld.get_action_sequence_and_all_actions``. The statelet
        is built with ``reverse_productions=True`` and starts at
        ``"statement"``.
        """
        world = Text2SqlWorld(self.schema)

        sql = ["SELECT", "COUNT", "(", "*", ")", "FROM", "LOCATION", ",", "RESTAURANT", ";"]
        action_sequence, valid_actions = world.get_action_sequence_and_all_actions(sql)

        grammar_state = GrammarStatelet(
            ["statement"], valid_actions, Text2SqlParser.is_nonterminal, reverse_productions=True
        )
        # take_action returns a new statelet, so rebind on every step.
        for action in action_sequence:
            grammar_state = grammar_state.take_action(action)
        # The nonterminal stack must be empty after the last action.
        assert grammar_state._nonterminal_stack == []
Example #3
0
    def test_grammar_statelet(self):
        """Replay the action sequence for a small SQL query and check the stack empties.

        The action sequence and the valid-action table both come from
        ``Text2SqlWorld.get_action_sequence_and_all_actions``. The statelet
        is built with ``reverse_productions=True`` and starts at
        ``'statement'``.
        """
        world = Text2SqlWorld(self.schema)

        sql = [
            'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',',
            'RESTAURANT', ';'
        ]
        action_sequence, valid_actions = world.get_action_sequence_and_all_actions(
            sql)

        grammar_state = GrammarStatelet(['statement'],
                                        valid_actions,
                                        Text2SqlParser.is_nonterminal,
                                        reverse_productions=True)
        # take_action returns a new statelet, so rebind on every step.
        for action in action_sequence:
            grammar_state = grammar_state.take_action(action)
        # The nonterminal stack must be empty after the last action.
        assert grammar_state._nonterminal_stack == []  # pylint: disable=protected-access
Example #4
0
 def test_take_action_crashes_with_mismatched_types(self):
     """Expect ``take_action`` to raise ``AssertionError`` on a mismatched action.

     The statelet's initial stack holds ``"s"`` while the action applied
     expands ``"t"``.
     """
     # Construct outside the ``raises`` block so that an AssertionError from
     # the constructor cannot make this test pass by accident.
     state = GrammarStatelet(["s"], {}, is_nonterminal)
     with pytest.raises(AssertionError):
         state.take_action("t -> identity")