Exemple #1
0
 def _create_grammar_state(
         world: WikiTablesWorld,
         possible_actions: List[ProductionRuleArray]) -> GrammarState:
     """
     Build the ``GrammarState`` for a single instance.

     The valid actions reported by ``world`` are production-rule strings grouped by key; they are
     translated into integer indices into ``possible_actions`` (matched on the first element of
     each entry) before being handed to ``GrammarState``.

     Parameters
     ----------
     world : ``WikiTablesWorld``
         Supplies the valid actions through ``get_valid_actions()``.
     possible_actions : ``List[ProductionRuleArray]``
         All actions for this instance; element ``0`` of each entry is the rule string.

     Raises
     ------
     KeyError
         If ``world`` reports a valid action string that is not in ``possible_actions``.
     """
     rules_by_key = world.get_valid_actions()
     # Rule string -> position in ``possible_actions``.
     rule_to_index = {rule[0]: position
                      for position, rule in enumerate(possible_actions)}
     # Same grouping as ``rules_by_key``, but with indices instead of rule strings.
     indexed_rules_by_key = {
         key: [rule_to_index[rule_string] for rule_string in rule_strings]
         for key, rule_strings in rules_by_key.items()
     }
     return GrammarState([START_SYMBOL],
                         {},
                         indexed_rules_by_key,
                         rule_to_index,
                         type_declaration.is_nonterminal)
 def _create_grammar_state(world: WikiTablesWorld,
                           possible_actions: List[ProductionRuleArray]) -> GrammarState:
     """
     Create the ``GrammarState`` used for one instance.

     Every valid action string returned by ``world.get_valid_actions()`` is replaced by its index
     in ``possible_actions`` (looked up through the first element of each entry).  A ``KeyError``
     is raised if a valid action string has no entry in ``possible_actions``.
     """
     valid_actions = world.get_valid_actions()

     # Index of every action, keyed by its rule string.
     index_of_action = {}
     for index, production_rule in enumerate(possible_actions):
         index_of_action[production_rule[0]] = index

     # Replace rule strings with their indices, keeping the original grouping.
     valid_action_indices = {}
     for key, rule_strings in valid_actions.items():
         valid_action_indices[key] = list(map(index_of_action.__getitem__, rule_strings))

     return GrammarState([START_SYMBOL], {}, valid_action_indices,
                         index_of_action, type_declaration.is_nonterminal)
    def test_world_adds_numbers_from_question(self):
        """Numbers appearing in the question must be offered as productions of ``n``."""
        words = ['what', '2007', '2,107', '0.2', '1800s', '1950s', '?']
        question_tokens = [Token(word) for word in words]
        table_path = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_table.tsv"
        table_kg = TableQuestionKnowledgeGraph.read_from_file(table_path, question_tokens)
        number_productions = WikiTablesWorld(table_kg).get_valid_actions()['n']

        expected_productions = (
                'n -> 2007',
                'n -> 2107',
                # It appears that sempre normalizes floating point numbers.
                'n -> 0.200',
                # We want to add the end-points to things like "1800s": 1800 and 1900.
                'n -> 1800',
                'n -> 1900',
                'n -> 1950',
                'n -> 1960',
        )
        for production in expected_productions:
            assert production in number_productions
Exemple #4
0
    def test_world_adds_numbers_from_question(self):
        """Check that numbers found in the question become productions of the ``n`` type."""
        question_words = ['what', '2007', '2,107', '0.2', '1800s', '1950s', '?']
        question_tokens = [Token(word) for word in question_words]
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
                "tests/fixtures/data/wikitables/sample_table.tsv", question_tokens)
        productions_of_n = WikiTablesWorld(table_kg).get_valid_actions()['n']

        # Plain numbers, including one written with a thousands separator.
        for production in ('n -> 2007', 'n -> 2107'):
            assert production in productions_of_n

        # It appears that sempre normalizes floating point numbers.
        assert 'n -> 0.200' in productions_of_n

        # We want to add the end-points to things like "1800s": 1800 and 1900.
        for production in ('n -> 1800', 'n -> 1900', 'n -> 1950', 'n -> 1960'):
            assert production in productions_of_n
Exemple #5
0
    def _create_grammar_state(self, world: WikiTablesWorld,
                              possible_actions: List[ProductionRuleArray],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRuleArrays``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        For every non-terminal, the translated actions are stored under up to two keys:
        ``'global'`` maps to ``(input embeddings, output embeddings, action ids)`` and ``'linked'``
        maps to ``(linking scores, entity type embeddings, action ids)``.  A key is only present
        when the non-terminal has at least one action of that kind, so a non-terminal whose
        productions are all linked actions no longer makes this method fail.

        Parameters
        ----------
        world : ``WikiTablesWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRuleArray]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).

        Returns
        -------
        ``GrammarStatelet``
            Constructed from ``[START_SYMBOL]``, an empty dictionary, the translated valid
            actions, the context-specific actions and ``type_declaration.is_nonterminal``.

        Raises
        ------
        KeyError
            If ``world`` reports a valid action that is missing from ``possible_actions``, or a
            linked action refers to an entity that is not in ``world.table_graph.entities``.
        """
        action_map = {action[0]: action_index
                      for action_index, action in enumerate(possible_actions)}
        entity_map = {entity: entity_index
                      for entity_index, entity in enumerate(world.table_graph.entities)}

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            global_actions = []
            linked_actions = []
            for action_index in action_indices:
                production_rule_array = possible_actions[action_index]
                if production_rule_array[1]:
                    # Global rule: keep its id tensor so that it can be embedded below.
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    # Linked rule: keep the rule string so that the entity can be recovered.
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions, if there are any.
            # Without this guard, `zip(*[])` yields nothing to unpack and raises a ValueError for
            # non-terminals that only have linked productions.
            # NOTE(review): consumers of this dictionary must tolerate a missing 'global' entry
            # -- confirm against the transition function in use.
            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0)
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                if self._add_action_bias:
                    global_action_biases = self._action_biases(
                        global_action_tensor)
                    global_input_embeddings = torch.cat(
                        [global_input_embeddings, global_action_biases], dim=-1)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(
                    entity_type_tensor)
                translated_valid_actions[key]['linked'] = (
                    entity_linking_scores, entity_type_embeddings,
                    list(linked_action_ids))

        # Lastly, we need to also create embedded representations of context-specific actions.  In
        # this case, those are only variable productions, like "r -> x".  Note that our language
        # only permits one lambda at a time, so we don't need to worry about how nested lambdas
        # might impact this.
        context_actions = {}
        for action_id, action in enumerate(possible_actions):
            if action[0].endswith(" -> x"):
                input_embedding = self._action_embedder(action[2])
                if self._add_action_bias:
                    input_bias = self._action_biases(action[2])
                    input_embedding = torch.cat([input_embedding, input_bias],
                                                dim=-1)
                output_embedding = self._output_action_embedder(action[2])
                context_actions[action[0]] = (input_embedding,
                                              output_embedding, action_id)

        return GrammarStatelet([START_SYMBOL], {}, translated_valid_actions,
                               context_actions,
                               type_declaration.is_nonterminal)
Exemple #6
0
class TestWikiTablesWorld(AllenNlpTestCase):
    """
    Tests for ``WikiTablesWorld``: the set of valid actions it exposes, parsing of SEMPRE logical
    forms, the action sequences derived from parsed expressions, and the agenda it produces for
    different questions.  All tests except the (skipped) deeply nested one read the
    ``sample_table.tsv`` fixture.
    """
    def setUp(self):
        """
        Build the shared world from ``sample_table.tsv`` and the question
        "what was the last year 2000 ?".
        """
        super(TestWikiTablesWorld, self).setUp()
        question_tokens = [
            Token(x) for x in
            [u'what', u'was', u'the', u'last', u'year', u'2000', u'?']
        ]
        self.table_file = self.FIXTURES_ROOT / u'data' / u'wikitables' / u'sample_table.tsv'
        self.table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, question_tokens)
        self.world = WikiTablesWorld(self.table_kg)

    def test_get_valid_actions_returns_correct_set(self):
        """
        Check both the set of non-terminals returned by ``get_valid_actions`` and the productions
        listed for each of them.
        """
        # This test is long, but worth it.  These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.

        # This test checks that our valid actions for each type match  PNP's, except for the
        # terminal productions for type 'p'.
        valid_actions = self.world.get_valid_actions()
        assert set(valid_actions.keys()) == set([
            u'<#1,#1>',
            u'<#1,<#1,#1>>',
            u'<#1,n>',
            u'<<#1,#2>,<#2,#1>>',
            u'<c,d>',
            u'<c,n>',
            u'<c,p>',
            u'<c,r>',
            u'<d,c>',
            u'<d,d>',
            u'<d,n>',
            u'<d,r>',
            u'<n,<n,<#1,<<#2,#1>,#1>>>>',
            u'<n,<n,<n,d>>>',
            u'<n,<n,n>>',
            u'<n,c>',
            u'<n,d>',
            u'<n,n>',
            u'<n,p>',
            u'<n,r>',
            u'<nd,nd>',
            u'<p,c>',
            u'<p,n>',
            u'<r,c>',
            u'<r,d>',
            u'<r,n>',
            u'<r,p>',
            u'<r,r>',
            u'@start@',
            u'c',
            u'd',
            u'n',
            u'p',
            u'r',
        ])

        check_productions_match(valid_actions[u'<#1,#1>'], [u'!='])

        check_productions_match(valid_actions[u'<#1,<#1,#1>>'],
                                [u'and', u'or'])

        check_productions_match(valid_actions[u'<#1,n>'], [u'count'])

        check_productions_match(valid_actions[u'<<#1,#2>,<#2,#1>>'],
                                [u'reverse'])

        check_productions_match(
            valid_actions[u'<c,d>'],
            [u"['lambda x', d]", u'[<<#1,#2>,<#2,#1>>, <d,c>]'])

        check_productions_match(
            valid_actions[u'<c,n>'],
            [u"['lambda x', n]", u'[<<#1,#2>,<#2,#1>>, <n,c>]'])

        check_productions_match(valid_actions[u'<c,p>'],
                                [u'[<<#1,#2>,<#2,#1>>, <p,c>]'])

        # Most of these are instance-specific production rules.  These are the columns in the
        # table.  Remember that SEMPRE did things backwards: fb:row.row.division takes a cell ID
        # and returns the row that has that cell in its row.division column.  This is why we have
        # to reverse all of these functions to go from a row to the cell in a particular column.
        check_productions_match(
            valid_actions[u'<c,r>'],
            [
                u'fb:row.row.null',  # This one is global, representing an empty set.
                u'fb:row.row.year',
                u'fb:row.row.league',
                u'fb:row.row.avg_attendance',
                u'fb:row.row.division',
                u'fb:row.row.regular_season',
                u'fb:row.row.playoffs',
                u'fb:row.row.open_cup'
            ])

        # These might look backwards, but that's because SEMPRE chose to make them backwards.
        # fb:a.b is a function that takes b and returns a.  So fb:cell.cell.date takes cell.date
        # and returns cell and fb:row.row.index takes row.index and returns row.
        check_productions_match(
            valid_actions[u'<d,c>'],
            [u'fb:cell.cell.date', u'[<<#1,#2>,<#2,#1>>, <c,d>]'])

        check_productions_match(
            valid_actions[u'<d,d>'],
            [u"['lambda x', d]", u'[<<#1,#2>,<#2,#1>>, <d,d>]'])

        check_productions_match(
            valid_actions[u'<d,n>'],
            [u"['lambda x', n]", u'[<<#1,#2>,<#2,#1>>, <n,d>]'])

        check_productions_match(valid_actions[u'<d,r>'],
                                [u'[<<#1,#2>,<#2,#1>>, <r,d>]'])

        check_productions_match(valid_actions[u'<n,<n,<#1,<<#2,#1>,#1>>>>'],
                                [u'argmax', u'argmin'])

        # "date" is a function that takes three numbers: (date 2018 01 06).
        check_productions_match(valid_actions[u'<n,<n,<n,d>>>'], [u'date'])

        check_productions_match(valid_actions[u'<n,<n,n>>'], [u'-'])

        check_productions_match(valid_actions[u'<n,c>'], [
            u'fb:cell.cell.num2', u'fb:cell.cell.number',
            u'[<<#1,#2>,<#2,#1>>, <c,n>]'
        ])

        check_productions_match(
            valid_actions[u'<n,d>'],
            [u"['lambda x', d]", u'[<<#1,#2>,<#2,#1>>, <d,n>]'])

        check_productions_match(valid_actions[u'<n,n>'], [
            u'avg', u'sum', u'number', u"['lambda x', n]",
            u'[<<#1,#2>,<#2,#1>>, <n,n>]'
        ])

        check_productions_match(valid_actions[u'<n,p>'],
                                [u'[<<#1,#2>,<#2,#1>>, <p,n>]'])

        check_productions_match(
            valid_actions[u'<n,r>'],
            [u'fb:row.row.index', u'[<<#1,#2>,<#2,#1>>, <r,n>]'])

        check_productions_match(valid_actions[u'<nd,nd>'],
                                [u'<', u'<=', u'>', u'>=', u'min', u'max'])

        # PART_TYPE rules.  A cell part is for when a cell has text that can be split into multiple
        # parts.
        check_productions_match(valid_actions[u'<p,c>'],
                                [u'fb:cell.cell.part'])

        check_productions_match(valid_actions[u'<p,n>'], [u"['lambda x', n]"])

        check_productions_match(valid_actions[u'<r,c>'],
                                [u'[<<#1,#2>,<#2,#1>>, <c,r>]'])

        check_productions_match(valid_actions[u'<r,d>'], [u"['lambda x', d]"])

        check_productions_match(
            valid_actions[u'<r,n>'],
            [u"['lambda x', n]", u'[<<#1,#2>,<#2,#1>>, <n,r>]'])

        check_productions_match(
            valid_actions[u'<r,p>'],
            [u"['lambda x', p]", u'[<<#1,#2>,<#2,#1>>, <p,r>]'])

        check_productions_match(valid_actions[u'<r,r>'], [
            u'fb:row.row.next', u'fb:type.object.type',
            u'[<<#1,#2>,<#2,#1>>, <r,r>]'
        ])

        check_productions_match(valid_actions[u'@start@'],
                                [u'd', u'c', u'p', u'r', u'n'])

        check_productions_match(valid_actions[u'c'], [
            u'[<#1,#1>, c]', u'[<#1,<#1,#1>>, c, c]',
            u'[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, c, <n,c>]',
            u'[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, c, <d,c>]', u'[<d,c>, d]',
            u'[<n,c>, n]', u'[<p,c>, p]', u'[<r,c>, r]', u'fb:cell.null',
            u'fb:cell.2', u'fb:cell.2001', u'fb:cell.2005',
            u'fb:cell.4th_round', u'fb:cell.4th_western', u'fb:cell.5th',
            u'fb:cell.6_028', u'fb:cell.7_169', u'fb:cell.did_not_qualify',
            u'fb:cell.quarterfinals', u'fb:cell.usl_a_league',
            u'fb:cell.usl_first_division'
        ])

        check_productions_match(valid_actions[u'd'], [
            u'[<n,<n,<n,d>>>, n, n, n]', u'[<#1,#1>, d]',
            u'[<#1,<#1,#1>>, d, d]',
            u'[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, d, <d,d>]',
            u'[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, d, <n,d>]', u'[<c,d>, c]',
            u'[<nd,nd>, d]'
        ])

        check_productions_match(valid_actions[u'n'], [
            u'-1', u'0', u'1', u'2000', u'[<#1,#1>, n]',
            u'[<#1,<#1,#1>>, n, n]', u'[<#1,n>, c]', u'[<#1,n>, d]',
            u'[<#1,n>, n]', u'[<#1,n>, p]', u'[<#1,n>, r]', u'[<c,n>, c]',
            u'[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, n, <d,n>]',
            u'[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, n, <n,n>]',
            u'[<n,<n,n>>, n, n]', u'[<n,n>, n]', u'[<nd,nd>, n]', u'[<r,n>, r]'
        ])

        check_productions_match(valid_actions[u'p'], [
            u'[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, p, <n,p>]', u'[<#1,#1>, p]',
            u'[<c,p>, c]', u'[<#1,<#1,#1>>, p, p]', u'fb:part.4th',
            u'fb:part.5th', u'fb:part.western'
        ])

        check_productions_match(valid_actions[u'r'], [
            u'fb:type.row', u'[<#1,#1>, r]', u'[<#1,<#1,#1>>, r, r]',
            u'[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, r, <d,r>]',
            u'[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, r, <n,r>]', u'[<n,r>, n]',
            u'[<c,r>, c]', u'[<r,r>, r]'
        ])

    def test_world_processes_sempre_forms_correctly(self):
        """Parse a SEMPRE form with a reversed column and compare its string rendering."""
        sempre_form = u"((reverse fb:row.row.year) (fb:row.row.league fb:cell.usl_a_league))"
        expression = self.world.parse_logical_form(sempre_form)
        # We add columns to the name mapping in sorted order, so "league" and "year" end up as C2
        # and C6.
        assert unicode(expression) == u"R(C6,C2(cell:usl_a_league))"

    def test_world_parses_logical_forms_with_dates(self):
        """Parse a SEMPRE form containing a ``date`` call and compare its string rendering."""
        sempre_form = u"((reverse fb:row.row.league) (fb:row.row.year (fb:cell.cell.date (date 2000 -1 -1))))"
        expression = self.world.parse_logical_form(sempre_form)
        assert unicode(
            expression) == u"R(C2,C6(D1(D0(num:2000,num:~1,num:~1))))"

    def test_world_parses_logical_forms_with_decimals(self):
        """Parse a form with the decimal ``0.200`` in a world whose question is just ``0.2``."""
        question_tokens = [Token(x) for x in [u'0.2']]
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.FIXTURES_ROOT / u"data" / u"wikitables" / u"sample_table.tsv",
            question_tokens)
        world = WikiTablesWorld(table_kg)
        sempre_form = u"(fb:cell.cell.number (number 0.200))"
        expression = world.parse_logical_form(sempre_form)
        assert unicode(expression) == u"I1(I(num:0_200))"

    def test_get_action_sequence_removes_currying_for_all_wikitables_functions(
            self):
        """
        For ``-``, ``date``, ``argmax`` and ``and``, check that the action sequence contains a
        single production listing the function together with all of its arguments.
        """
        # minus
        logical_form = u"(- (number 0) (number 1))"
        parsed_logical_form = self.world.parse_logical_form(logical_form)
        action_sequence = self.world.get_action_sequence(parsed_logical_form)
        assert u'n -> [<n,<n,n>>, n, n]' in action_sequence

        # date
        logical_form = u"(count (fb:cell.cell.date (date 2000 -1 -1)))"
        parsed_logical_form = self.world.parse_logical_form(logical_form)
        action_sequence = self.world.get_action_sequence(parsed_logical_form)
        assert u'd -> [<n,<n,<n,d>>>, n, n, n]' in action_sequence

        # argmax
        logical_form = (
            u"(argmax (number 1) (number 1) (fb:row.row.division fb:cell.2) "
            u"(reverse (lambda x ((reverse fb:row.row.index) (var x))))")
        parsed_logical_form = self.world.parse_logical_form(logical_form)
        action_sequence = self.world.get_action_sequence(parsed_logical_form)
        assert u'r -> [<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, r, <n,r>]' in action_sequence

        # and
        logical_form = u"(and (number 1) (number 1))"
        parsed_logical_form = self.world.parse_logical_form(logical_form)
        action_sequence = self.world.get_action_sequence(parsed_logical_form)
        assert u'n -> [<#1,<#1,#1>>, n, n]' in action_sequence

    def test_parsing_logical_forms_fails_with_unmapped_names(self):
        """Parsing ``(number 20)`` in the shared world is expected to raise ``ParsingError``."""
        with pytest.raises(ParsingError):
            _ = self.world.parse_logical_form(u"(number 20)")

    def test_world_has_only_basic_numbers(self):
        """
        In the shared world (whose question only mentions 2000), ``-1``, ``0`` and ``1`` are
        productions of ``n`` while the other numbers checked here are not.
        """
        valid_actions = self.world.get_valid_actions()
        assert u'n -> -1' in valid_actions[u'n']
        assert u'n -> 0' in valid_actions[u'n']
        assert u'n -> 1' in valid_actions[u'n']
        assert u'n -> 17' not in valid_actions[u'n']
        assert u'n -> 231' not in valid_actions[u'n']
        assert u'n -> 2007' not in valid_actions[u'n']
        assert u'n -> 2107' not in valid_actions[u'n']
        assert u'n -> 1800' not in valid_actions[u'n']

    def test_world_adds_numbers_from_question(self):
        """Numbers mentioned in the question must appear among the productions of ``n``."""
        question_tokens = [
            Token(x) for x in
            [u'what', u'2007', u'2,107', u'0.2', u'1800s', u'1950s', u'?']
        ]
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.FIXTURES_ROOT / u"data" / u"wikitables" / u"sample_table.tsv",
            question_tokens)
        world = WikiTablesWorld(table_kg)
        valid_actions = world.get_valid_actions()
        assert u'n -> 2007' in valid_actions[u'n']
        assert u'n -> 2107' in valid_actions[u'n']

        # It appears that sempre normalizes floating point numbers.
        assert u'n -> 0.200' in valid_actions[u'n']

        # We want to add the end-points to things like "1800s": 1800 and 1900.
        assert u'n -> 1800' in valid_actions[u'n']
        assert u'n -> 1900' in valid_actions[u'n']
        assert u'n -> 1950' in valid_actions[u'n']
        assert u'n -> 1960' in valid_actions[u'n']

    def test_world_returns_correct_actions_with_reverse(self):
        """Compare the full action sequence of a form that uses ``reverse`` once."""
        sempre_form = u"((reverse fb:row.row.year) (fb:row.row.league fb:cell.usl_a_league))"
        expression = self.world.parse_logical_form(sempre_form)
        actions = self.world.get_action_sequence(expression)
        target_action_sequence = [
            u'@start@ -> c', u'c -> [<r,c>, r]',
            u'<r,c> -> [<<#1,#2>,<#2,#1>>, <c,r>]',
            u'<<#1,#2>,<#2,#1>> -> reverse', u'<c,r> -> fb:row.row.year',
            u'r -> [<c,r>, c]', u'<c,r> -> fb:row.row.league',
            u'c -> fb:cell.usl_a_league'
        ]
        assert actions == target_action_sequence

    def test_world_returns_correct_actions_with_two_reverses(self):
        """Compare the full action sequence of a form that uses ``reverse`` twice."""
        sempre_form = (
            u"(max ((reverse fb:cell.cell.date) ((reverse fb:row.row.year) "
            u"(fb:row.row.league fb:cell.usl_a_league))))")
        expression = self.world.parse_logical_form(sempre_form)
        actions = self.world.get_action_sequence(expression)
        target_action_sequence = [
            u'@start@ -> d', u'd -> [<nd,nd>, d]', u'<nd,nd> -> max',
            u'd -> [<c,d>, c]', u'<c,d> -> [<<#1,#2>,<#2,#1>>, <d,c>]',
            u'<<#1,#2>,<#2,#1>> -> reverse', u'<d,c> -> fb:cell.cell.date',
            u'c -> [<r,c>, r]', u'<r,c> -> [<<#1,#2>,<#2,#1>>, <c,r>]',
            u'<<#1,#2>,<#2,#1>> -> reverse', u'<c,r> -> fb:row.row.year',
            u'r -> [<c,r>, c]', u'<c,r> -> fb:row.row.league',
            u'c -> fb:cell.usl_a_league'
        ]
        assert actions == target_action_sequence

    def test_world_returns_correct_actions_with_lambda_with_var(self):
        """
        With ``remove_var_function=False`` the action sequence keeps the ``var`` production in
        addition to ``r -> x``.
        """
        sempre_form = (
            u"((reverse fb:cell.cell.date) ((reverse fb:row.row.year) (argmax (number 1) "
            u"(number 1) (fb:row.row.league fb:cell.usl_a_league) (reverse (lambda x "
            u"((reverse fb:row.row.index) (var x)))))))")
        expression = self.world.parse_logical_form(sempre_form,
                                                   remove_var_function=False)
        actions_with_var = self.world.get_action_sequence(expression)
        assert u'<#1,#1> -> var' in actions_with_var
        assert u'r -> x' in actions_with_var

    def test_world_returns_correct_actions_with_lambda_without_var(self):
        """
        When ``parse_logical_form`` is called without ``remove_var_function``, the ``var``
        production is absent from the action sequence while ``r -> x`` is still present.
        """
        sempre_form = (
            u"((reverse fb:cell.cell.date) ((reverse fb:row.row.year) (argmax (number 1) "
            u"(number 1) (fb:row.row.league fb:cell.usl_a_league) (reverse (lambda x "
            u"((reverse fb:row.row.index) (var x)))))))")
        expression = self.world.parse_logical_form(sempre_form)
        actions_without_var = self.world.get_action_sequence(expression)
        assert u'<#1,#1> -> var' not in actions_without_var
        assert u'r -> x' in actions_without_var

    @pytest.mark.skip(reason=u"fibonacci recursion currently going on here")
    def test_with_deeply_nested_logical_form(self):
        """
        Parse a long left-nested chain of ``or`` over cells of table ``109.tsv``.  The test has no
        assertion (it only calls the parser) and is currently skipped for the reason given in the
        decorator.
        """
        question_tokens = [
            Token(x) for x in [u'what', u'was', u'the', u'district', u'?']
        ]
        table_filename = self.FIXTURES_ROOT / u'data' / u'wikitables' / u'table' / u'109.tsv'
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            table_filename, question_tokens)
        world = WikiTablesWorld(table_kg)
        logical_form = (
            u"(count ((reverse fb:cell.cell.number) (or (or (or (or (or (or (or (or "
            u"(or (or (or (or (or (or (or (or (or (or (or (or (or fb:cell.virginia_1 "
            u"fb:cell.virginia_10) fb:cell.virginia_11) fb:cell.virginia_12) "
            u"fb:cell.virginia_13) fb:cell.virginia_14) fb:cell.virginia_15) "
            u"fb:cell.virginia_16) fb:cell.virginia_17) fb:cell.virginia_18) "
            u"fb:cell.virginia_19) fb:cell.virginia_2) fb:cell.virginia_20) "
            u"fb:cell.virginia_21) fb:cell.virginia_22) fb:cell.virginia_3) "
            u"fb:cell.virginia_4) fb:cell.virginia_5) fb:cell.virginia_6) "
            u"fb:cell.virginia_7) fb:cell.virginia_8) fb:cell.virginia_9)))")
        print(u"Parsing...")
        world.parse_logical_form(logical_form)

    def _get_world_with_question_tokens(self, tokens):
        """Return a ``WikiTablesWorld`` for ``self.table_file`` and the given question tokens."""
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, tokens)
        world = WikiTablesWorld(table_kg)
        return world

    def test_get_agenda(self):
        """Check the agenda returned by ``get_agenda`` for five different questions."""
        tokens = [
            Token(x) for x in
            [u'what', u'was', u'the', u'last', u'year', u'2000', u'?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == set([
            u'n -> 2000', u'<c,r> -> fb:row.row.year',
            u'<n,<n,<#1,<<#2,#1>,#1>>>> -> argmax'
        ])
        tokens = [
            Token(x) for x in [
                u'what', u'was', u'the', u'difference', u'in', u'attendance',
                u'between', u'years', u'2001', u'and', u'2005', u'?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains cells here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == set([
            u'c -> fb:cell.2001', u'c -> fb:cell.2005',
            u'<c,r> -> fb:row.row.year', u'<n,<n,n>> -> -'
        ])
        tokens = [
            Token(x) for x in [
                u'what', u'was', u'the', u'total', u'avg.', u'attendance',
                u'in', u'years', u'2001', u'and', u'2005', u'?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # As in the second case, the agenda contains cells instead of numbers because 2001 and
        # 2005 link to entities in the table, unlike 2000 in the first case.
        assert set(world.get_agenda()) == set([
            u'c -> fb:cell.2001', u'c -> fb:cell.2005',
            u'<c,r> -> fb:row.row.year', u'<c,r> -> fb:row.row.avg_attendance',
            u'<n,n> -> sum'
        ])
        # The next two questions differ only in their first two words ("when was" vs.
        # "what is"); the expected agendas differ in argmin vs. min.
        tokens = [
            Token(x) for x in
            [u'when', u'was', u'the', u'least', u'avg.', u'attendance', u'?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == set([
            u'<c,r> -> fb:row.row.avg_attendance',
            u'<n,<n,<#1,<<#2,#1>,#1>>>> -> argmin'
        ])
        tokens = [
            Token(x) for x in
            [u'what', u'is', u'the', u'least', u'avg.', u'attendance', u'?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == set(
            [u'<c,r> -> fb:row.row.avg_attendance', u'<nd,nd> -> min'])
    def _create_grammar_state(self,
                              world: WikiTablesWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> LambdaGrammarStatelet:
        """
        This method creates the LambdaGrammarStatelet object that's used for decoding.  Part of
        creating that is creating the `valid_actions` dictionary, which contains embedded
        representations of all of the valid actions.  So, we create that here as well.

        The way we represent the valid expansions is a little complicated: we use a
        dictionary of `action types`, where the key is the action type (like "global", "linked", or
        whatever your model is expecting), and the value is a tuple representing all actions of
        that type.  The tuple is (input tensor, output tensor, action id).  The input tensor has
        the representation that is used when `selecting` actions, for all actions of this type.
        The output tensor has the representation that is used when feeding the action to the next
        step of the decoder (this could just be the same as the input tensor).  The action ids are
        a list of indices into the main action list for each batch instance.

        An action type is only present for a non-terminal that has at least one action of that
        type; in particular, a non-terminal whose productions are all linked actions gets no
        "global" entry (previously this case raised a ``ValueError``).

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``WikiTablesWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).

        Returns
        -------
        ``LambdaGrammarStatelet``
            Constructed from ``[START_SYMBOL]``, an empty dictionary, the translated valid
            actions, the context-specific actions and ``type_declaration.is_nonterminal``.

        Raises
        ------
        KeyError
            If ``world`` reports a valid action that is missing from ``possible_actions``, or a
            linked action refers to an entity that is not in ``world.table_graph.entities``.
        """
        # TODO(mattg): Move the "valid_actions" construction to another method.
        action_map = {action[0]: action_index
                      for action_index, action in enumerate(possible_actions)}
        entity_map = {entity: entity_index
                      for entity_index, entity in enumerate(world.table_graph.entities)}

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [action_map[action_string] for action_string in action_strings]
            global_actions = []
            linked_actions = []
            for action_index in action_indices:
                production_rule = possible_actions[action_index]
                if production_rule[1]:
                    # Global rule: keep its id tensor so that it can be embedded below.
                    global_actions.append((production_rule[2], action_index))
                else:
                    # Linked rule: keep the rule string so that the entity can be recovered.
                    linked_actions.append((production_rule[0], action_index))

            # Then we get the embedded representations of the global actions, if there are any.
            # Without this guard, `zip(*[])` yields nothing to unpack and raises a ValueError for
            # non-terminals that only have linked productions.
            # NOTE(review): consumers of this dictionary must tolerate a missing 'global' entry
            # -- confirm against the transition function in use.
            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0)
                global_input_embeddings = self._action_embedder(global_action_tensor)
                if self._add_action_bias:
                    global_action_biases = self._action_biases(global_action_tensor)
                    global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases], dim=-1)
                global_output_embeddings = self._output_action_embedder(global_action_tensor)
                translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                           global_output_embeddings,
                                                           list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
                translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                           entity_type_embeddings,
                                                           list(linked_action_ids))

        # Lastly, we need to also create embedded representations of context-specific actions.  In
        # this case, those are only variable productions, like "r -> x".  Note that our language
        # only permits one lambda at a time, so we don't need to worry about how nested lambdas
        # might impact this.
        context_actions = {}
        for action_id, action in enumerate(possible_actions):
            if action[0].endswith(" -> x"):
                input_embedding = self._action_embedder(action[2])
                if self._add_action_bias:
                    input_bias = self._action_biases(action[2])
                    input_embedding = torch.cat([input_embedding, input_bias], dim=-1)
                output_embedding = self._output_action_embedder(action[2])
                context_actions[action[0]] = (input_embedding, output_embedding, action_id)

        return LambdaGrammarStatelet([START_SYMBOL],
                                     {},
                                     translated_valid_actions,
                                     context_actions,
                                     type_declaration.is_nonterminal)
class TestWikiTablesWorld(AllenNlpTestCase):
    def setUp(self):
        # Default fixture shared by most tests: a world built from the sample table together
        # with a question whose only numeric token is "2000".
        super().setUp()
        words = ['what', 'was', 'the', 'last', 'year', '2000', '?']
        tokens = [Token(word) for word in words]
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tsv'
        self.table_kg = TableQuestionKnowledgeGraph.read_from_file(self.table_file, tokens)
        self.world = WikiTablesWorld(self.table_kg)

    def test_get_valid_actions_returns_correct_set(self):
        # This test is long, but worth it.  The table below lists every valid action in the
        # grammar for the default world, keyed by the non-terminal being expanded, and we want
        # to be sure they are all what we expect.
        #
        # The valid actions for each type should match PNP's, except for the terminal
        # productions for type 'p'.
        expected_productions = {
                '<#1,#1>': ['!='],
                '<#1,<#1,#1>>': ['and', 'or'],
                '<#1,n>': ['count'],
                '<<#1,#2>,<#2,#1>>': ['reverse'],
                '<c,d>': ["['lambda x', d]", '[<<#1,#2>,<#2,#1>>, <d,c>]'],
                '<c,n>': ["['lambda x', n]", '[<<#1,#2>,<#2,#1>>, <n,c>]'],
                '<c,p>': ['[<<#1,#2>,<#2,#1>>, <p,c>]'],
                # Most of these are instance-specific production rules.  These are the columns
                # in the table.  Remember that SEMPRE did things backwards:
                # fb:row.row.division takes a cell ID and returns the row that has that cell in
                # its row.division column.  This is why we have to reverse all of these
                # functions to go from a row to the cell in a particular column.
                '<c,r>': [
                        'fb:row.row.null',  # This one is global, representing an empty set.
                        'fb:row.row.year',
                        'fb:row.row.league',
                        'fb:row.row.avg_attendance',
                        'fb:row.row.division',
                        'fb:row.row.regular_season',
                        'fb:row.row.playoffs',
                        'fb:row.row.open_cup',
                ],
                # These might look backwards, but that's because SEMPRE chose to make them
                # backwards.  fb:a.b is a function that takes b and returns a.  So
                # fb:cell.cell.date takes cell.date and returns cell and fb:row.row.index takes
                # row.index and returns row.
                '<d,c>': ['fb:cell.cell.date', '[<<#1,#2>,<#2,#1>>, <c,d>]'],
                '<d,d>': ["['lambda x', d]", '[<<#1,#2>,<#2,#1>>, <d,d>]'],
                '<d,n>': ["['lambda x', n]", '[<<#1,#2>,<#2,#1>>, <n,d>]'],
                '<d,r>': ['[<<#1,#2>,<#2,#1>>, <r,d>]'],
                '<n,<n,<#1,<<#2,#1>,#1>>>>': ['argmax', 'argmin'],
                # "date" is a function that takes three numbers: (date 2018 01 06).
                '<n,<n,<n,d>>>': ['date'],
                '<n,<n,n>>': ['-'],
                '<n,c>': [
                        'fb:cell.cell.num2',
                        'fb:cell.cell.number',
                        '[<<#1,#2>,<#2,#1>>, <c,n>]',
                ],
                '<n,d>': ["['lambda x', d]", '[<<#1,#2>,<#2,#1>>, <d,n>]'],
                '<n,n>': [
                        'avg',
                        'sum',
                        'number',
                        "['lambda x', n]",
                        '[<<#1,#2>,<#2,#1>>, <n,n>]',
                ],
                '<n,p>': ['[<<#1,#2>,<#2,#1>>, <p,n>]'],
                '<n,r>': ['fb:row.row.index', '[<<#1,#2>,<#2,#1>>, <r,n>]'],
                '<nd,nd>': ['<', '<=', '>', '>=', 'min', 'max'],
                # PART_TYPE rules.  A cell part is for when a cell has text that can be split
                # into multiple parts.
                '<p,c>': ['fb:cell.cell.part'],
                '<p,n>': ["['lambda x', n]"],
                '<r,c>': ['[<<#1,#2>,<#2,#1>>, <c,r>]'],
                '<r,d>': ["['lambda x', d]"],
                '<r,n>': ["['lambda x', n]", '[<<#1,#2>,<#2,#1>>, <n,r>]'],
                '<r,p>': ["['lambda x', p]", '[<<#1,#2>,<#2,#1>>, <p,r>]'],
                '<r,r>': [
                        'fb:row.row.next',
                        'fb:type.object.type',
                        '[<<#1,#2>,<#2,#1>>, <r,r>]',
                ],
                '@start@': ['d', 'c', 'p', 'r', 'n'],
                'c': [
                        '[<#1,#1>, c]',
                        '[<#1,<#1,#1>>, c, c]',
                        '[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, c, <n,c>]',
                        '[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, c, <d,c>]',
                        '[<d,c>, d]',
                        '[<n,c>, n]',
                        '[<p,c>, p]',
                        '[<r,c>, r]',
                        'fb:cell.null',
                        'fb:cell.2',
                        'fb:cell.2001',
                        'fb:cell.2005',
                        'fb:cell.4th_round',
                        'fb:cell.4th_western',
                        'fb:cell.5th',
                        'fb:cell.6_028',
                        'fb:cell.7_169',
                        'fb:cell.did_not_qualify',
                        'fb:cell.quarterfinals',
                        'fb:cell.usl_a_league',
                        'fb:cell.usl_first_division',
                ],
                'd': [
                        '[<n,<n,<n,d>>>, n, n, n]',
                        '[<#1,#1>, d]',
                        '[<#1,<#1,#1>>, d, d]',
                        '[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, d, <d,d>]',
                        '[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, d, <n,d>]',
                        '[<c,d>, c]',
                        '[<nd,nd>, d]',
                ],
                'n': [
                        '-1',
                        '0',
                        '1',
                        '2000',
                        '[<#1,#1>, n]',
                        '[<#1,<#1,#1>>, n, n]',
                        '[<#1,n>, c]',
                        '[<#1,n>, d]',
                        '[<#1,n>, n]',
                        '[<#1,n>, p]',
                        '[<#1,n>, r]',
                        '[<c,n>, c]',
                        '[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, n, <d,n>]',
                        '[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, n, <n,n>]',
                        '[<n,<n,n>>, n, n]',
                        '[<n,n>, n]',
                        '[<nd,nd>, n]',
                        '[<r,n>, r]',
                ],
                'p': [
                        '[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, p, <n,p>]',
                        '[<#1,#1>, p]',
                        '[<c,p>, c]',
                        '[<#1,<#1,#1>>, p, p]',
                        'fb:part.4th',
                        'fb:part.5th',
                        'fb:part.western',
                ],
                'r': [
                        'fb:type.row',
                        '[<#1,#1>, r]',
                        '[<#1,<#1,#1>>, r, r]',
                        '[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, r, <d,r>]',
                        '[<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, r, <n,r>]',
                        '[<n,r>, n]',
                        '[<c,r>, c]',
                        '[<r,r>, r]',
                ],
        }

        valid_actions = self.world.get_valid_actions()

        # First make sure the grammar has exactly the non-terminals listed in the table, then
        # compare the productions of each one, in the order the table lists them.
        assert set(valid_actions.keys()) == set(expected_productions)
        for non_terminal, right_sides in expected_productions.items():
            check_productions_match(valid_actions[non_terminal], right_sides)

    def test_world_processes_sempre_forms_correctly(self):
        # Parse a SEMPRE-style logical form and compare the string form of the result.
        parsed = self.world.parse_logical_form(
                "((reverse fb:row.row.year) (fb:row.row.league fb:cell.usl_a_league))")
        actual = str(parsed)
        # Columns are added to the name mapping in sorted order, which is why "league" and
        # "year" end up as C2 and C6.
        alias_of = types.name_mapper.get_alias
        expected = f"{alias_of('reverse')}(C6,C2(cell:usl_a_league))"
        assert actual == expected

    def test_world_parses_logical_forms_with_dates(self):
        # A `date` call with negative components must parse; in the expected string the -1
        # arguments show up as "num:~1".
        logical_form = ("((reverse fb:row.row.league) "
                        "(fb:row.row.year (fb:cell.cell.date (date 2000 -1 -1))))")
        actual = str(self.world.parse_logical_form(logical_form))
        alias_of = types.name_mapper.get_alias
        expected = (f"{alias_of('reverse')}(C2,C6({alias_of('fb:cell.cell.date')}"
                    f"({alias_of('date')}(num:2000,num:~1,num:~1))))")
        assert actual == expected

    def test_world_parses_logical_forms_with_decimals(self):
        # The question contains "0.2" while the logical form spells the number as 0.200; the
        # parsed expression is expected to render it as "num:0_200".
        #
        # The world is built through the shared helper so that the fixture path lives in one
        # place (``self.table_file``, set in ``setUp``) instead of being re-spelled here.
        question_tokens = [Token(x) for x in ['0.2']]
        world = self._get_world_with_question_tokens(question_tokens)
        sempre_form = "(fb:cell.cell.number (number 0.200))"
        expression = world.parse_logical_form(sempre_form)
        f = types.name_mapper.get_alias
        assert str(expression) == f"{f('fb:cell.cell.number')}({f('number')}(num:0_200))"

    def test_get_action_sequence_removes_currying_for_all_wikitables_functions(self):
        # Each case pairs a logical form with the production rule that must appear in its
        # action sequence: the multi-argument function applied to all of its arguments in a
        # single rule.
        cases = [
                # minus
                ("(- (number 0) (number 1))",
                 'n -> [<n,<n,n>>, n, n]'),
                # date
                ("(count (fb:cell.cell.date (date 2000 -1 -1)))",
                 'd -> [<n,<n,<n,d>>>, n, n, n]'),
                # argmax
                ("(argmax (number 1) (number 1) (fb:row.row.division fb:cell.2) "
                 "(reverse (lambda x ((reverse fb:row.row.index) (var x)))))",
                 'r -> [<n,<n,<#1,<<#2,#1>,#1>>>>, n, n, r, <n,r>]'),
                # and
                ("(and (number 1) (number 1))",
                 'n -> [<#1,<#1,#1>>, n, n]'),
        ]
        for logical_form, uncurried_rule in cases:
            parsed_logical_form = self.world.parse_logical_form(logical_form)
            action_sequence = self.world.get_action_sequence(parsed_logical_form)
            assert uncurried_rule in action_sequence

    def test_parsing_logical_forms_fails_with_unmapped_names(self):
        # The default question (built in setUp) has no "20" token, so presumably that number
        # has no name mapping in the world; parsing a form that uses it must raise.
        unmapped_form = "(number 20)"
        with pytest.raises(ParsingError):
            self.world.parse_logical_form(unmapped_form)

    def test_world_has_only_basic_numbers(self):
        # With the default question, -1, 0 and 1 must be available as number productions, while
        # numbers that only other questions mention must not be.
        valid_actions = self.world.get_valid_actions()
        always_present = ['n -> -1', 'n -> 0', 'n -> 1']
        never_present = ['n -> 17', 'n -> 231', 'n -> 2007', 'n -> 2107', 'n -> 1800']
        for rule in always_present:
            assert rule in valid_actions['n']
        for rule in never_present:
            assert rule not in valid_actions['n']

    def test_world_adds_numbers_from_question(self):
        # Numbers mentioned in the question must become number productions in the world.
        #
        # The world is built through the shared helper so that the fixture path lives in one
        # place (``self.table_file``, set in ``setUp``) instead of being re-spelled here.
        question_tokens = [Token(x) for x in ['what', '2007', '2,107', '0.2', '1800s', '1950s', '?']]
        world = self._get_world_with_question_tokens(question_tokens)
        valid_actions = world.get_valid_actions()
        assert 'n -> 2007' in valid_actions['n']
        # The token "2,107" is expected to show up without its thousands separator.
        assert 'n -> 2107' in valid_actions['n']

        # It appears that sempre normalizes floating point numbers.
        assert 'n -> 0.200' in valid_actions['n']

        # We want to add the end-points to things like "1800s" (1800 and 1900) and "1950s"
        # (1950 and 1960).
        assert 'n -> 1800' in valid_actions['n']
        assert 'n -> 1900' in valid_actions['n']
        assert 'n -> 1950' in valid_actions['n']
        assert 'n -> 1960' in valid_actions['n']

    def test_world_returns_correct_actions_with_reverse(self):
        # A logical form with a single `reverse` must yield exactly this sequence of rules.
        expected_actions = [
                '@start@ -> c',
                'c -> [<r,c>, r]',
                '<r,c> -> [<<#1,#2>,<#2,#1>>, <c,r>]',
                '<<#1,#2>,<#2,#1>> -> reverse',
                '<c,r> -> fb:row.row.year',
                'r -> [<c,r>, c]',
                '<c,r> -> fb:row.row.league',
                'c -> fb:cell.usl_a_league',
        ]
        parsed = self.world.parse_logical_form(
                "((reverse fb:row.row.year) (fb:row.row.league fb:cell.usl_a_league))")
        assert self.world.get_action_sequence(parsed) == expected_actions

    def test_world_returns_correct_actions_with_two_reverses(self):
        # Two nested `reverse` applications under `max` must yield exactly this sequence of
        # rules.
        expected_actions = [
                '@start@ -> d',
                'd -> [<nd,nd>, d]',
                '<nd,nd> -> max',
                'd -> [<c,d>, c]',
                '<c,d> -> [<<#1,#2>,<#2,#1>>, <d,c>]',
                '<<#1,#2>,<#2,#1>> -> reverse',
                '<d,c> -> fb:cell.cell.date',
                'c -> [<r,c>, r]',
                '<r,c> -> [<<#1,#2>,<#2,#1>>, <c,r>]',
                '<<#1,#2>,<#2,#1>> -> reverse',
                '<c,r> -> fb:row.row.year',
                'r -> [<c,r>, c]',
                '<c,r> -> fb:row.row.league',
                'c -> fb:cell.usl_a_league',
        ]
        logical_form = ("(max ((reverse fb:cell.cell.date) ((reverse fb:row.row.year) "
                        "(fb:row.row.league fb:cell.usl_a_league))))")
        parsed = self.world.parse_logical_form(logical_form)
        assert self.world.get_action_sequence(parsed) == expected_actions

    def test_world_returns_correct_actions_with_lambda_with_var(self):
        # When parsing with `remove_var_function=False`, the action sequence must contain both
        # the `var` production and the lambda-variable production.
        logical_form = ("((reverse fb:cell.cell.date) ((reverse fb:row.row.year) (argmax (number 1) "
                        "(number 1) (fb:row.row.league fb:cell.usl_a_league) (reverse (lambda x "
                        "((reverse fb:row.row.index) (var x)))))))")
        parsed = self.world.parse_logical_form(logical_form, remove_var_function=False)
        action_sequence = self.world.get_action_sequence(parsed)
        for required_rule in ('<#1,#1> -> var', 'r -> x'):
            assert required_rule in action_sequence

    def test_world_returns_correct_actions_with_lambda_without_var(self):
        # With the parser's default arguments, the `var` production must be absent from the
        # action sequence while the lambda-variable production is still there.
        logical_form = ("((reverse fb:cell.cell.date) ((reverse fb:row.row.year) (argmax (number 1) "
                        "(number 1) (fb:row.row.league fb:cell.usl_a_league) (reverse (lambda x "
                        "((reverse fb:row.row.index) (var x)))))))")
        parsed = self.world.parse_logical_form(logical_form)
        action_sequence = self.world.get_action_sequence(parsed)
        assert '<#1,#1> -> var' not in action_sequence
        assert 'r -> x' in action_sequence

    @pytest.mark.skip(reason="fibonacci recursion currently going on here")
    def test_with_deeply_nested_logical_form(self):
        # Currently skipped (see the skip reason above).  When enabled, this only checks that a
        # chain of 21 nested `or` applications over 22 cells can be parsed at all; nothing is
        # asserted about the result.
        nested_form = ("(count ((reverse fb:cell.cell.number) (or (or (or (or (or (or (or (or "
                       "(or (or (or (or (or (or (or (or (or (or (or (or (or fb:cell.virginia_1 "
                       "fb:cell.virginia_10) fb:cell.virginia_11) fb:cell.virginia_12) "
                       "fb:cell.virginia_13) fb:cell.virginia_14) fb:cell.virginia_15) "
                       "fb:cell.virginia_16) fb:cell.virginia_17) fb:cell.virginia_18) "
                       "fb:cell.virginia_19) fb:cell.virginia_2) fb:cell.virginia_20) "
                       "fb:cell.virginia_21) fb:cell.virginia_22) fb:cell.virginia_3) "
                       "fb:cell.virginia_4) fb:cell.virginia_5) fb:cell.virginia_6) "
                       "fb:cell.virginia_7) fb:cell.virginia_8) fb:cell.virginia_9)))")
        words = ['what', 'was', 'the', 'district', '?']
        tokens = [Token(word) for word in words]
        table_path = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'table' / '109.tsv'
        knowledge_graph = TableQuestionKnowledgeGraph.read_from_file(table_path, tokens)
        nested_world = WikiTablesWorld(knowledge_graph)
        print("Parsing...")
        nested_world.parse_logical_form(nested_form)

    def _get_world_with_question_tokens(self, tokens: List[Token]) -> WikiTablesWorld:
        """
        Returns a ``WikiTablesWorld`` built from the table at ``self.table_file`` and the given
        question ``tokens``.
        """
        return WikiTablesWorld(TableQuestionKnowledgeGraph.read_from_file(self.table_file, tokens))

    def test_get_agenda(self):
        # For several questions over the same table, build a fresh world and check the exact
        # set of production rules returned by `get_agenda()`.

        # Case 1: the question mentions the number 2000.
        tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year', '2000', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'n -> 2000',
                                           '<c,r> -> fb:row.row.year',
                                           '<n,<n,<#1,<<#2,#1>,#1>>>> -> argmax'}
        # Case 2: a "difference ... between years" question.
        tokens = [Token(x) for x in ['what', 'was', 'the', 'difference', 'in', 'attendance',
                                     'between', 'years', '2001', 'and', '2005', '?']]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains cells here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == {'c -> fb:cell.2001',
                                           'c -> fb:cell.2005',
                                           '<c,r> -> fb:row.row.year',
                                           '<n,<n,n>> -> -'}
        # Case 3: a "total avg. attendance" question over the same two years.
        tokens = [Token(x) for x in ['what', 'was', 'the', 'total', 'avg.', 'attendance', 'in',
                                     'years', '2001', 'and', '2005', '?']]
        world = self._get_world_with_question_tokens(tokens)
        # As in the previous case, the agenda contains cells instead of numbers because 2001 and
        # 2005 link to entities in the table.  This time the expected agenda also has the
        # avg_attendance column and `sum` (presumably triggered by "total") instead of `-`.
        assert set(world.get_agenda()) == {'c -> fb:cell.2001',
                                           'c -> fb:cell.2005',
                                           '<c,r> -> fb:row.row.year',
                                           '<c,r> -> fb:row.row.avg_attendance',
                                           '<n,n> -> sum'}
        # Cases 4 and 5: the two questions differ only in their opening words ("when was" vs.
        # "what is"); the expected agenda has `argmin` for the former and `min` for the latter.
        tokens = [Token(x) for x in ['when', 'was', 'the', 'least', 'avg.', 'attendance', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<c,r> -> fb:row.row.avg_attendance',
                                           '<n,<n,<#1,<<#2,#1>,#1>>>> -> argmin'
                                          }
        tokens = [Token(x) for x in ['what', 'is', 'the', 'least', 'avg.', 'attendance', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<c,r> -> fb:row.row.avg_attendance',
                                           '<nd,nd> -> min'
                                          }