class TestWikiTablesVariableFreeWorld(AllenNlpTestCase):
    """Tests for ``WikiTablesVariableFreeWorld``.

    The world under test is built in ``setUp`` from the ``sample_table.tsv``
    fixture and a fixed tokenized question, so all instance-specific
    productions asserted below (``fb:cell.*``, ``fb:part.*``, ``fb:row.row.*``)
    are tied to that fixture, and the number ``2013`` appears in the grammar
    because it occurs in the question.
    """

    def setUp(self):
        # Build a table knowledge graph linked against a question that
        # contains the number token '2013', then wrap it in a world.
        super().setUp()
        question_tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']
        ]
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tsv'
        self.table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, question_tokens)
        self.world = WikiTablesVariableFreeWorld(self.table_kg)

    def test_get_valid_actions_returns_correct_set(self):
        """Exhaustively checks every production rule the grammar exposes,
        grouped by the type of the rule's left-hand side."""
        # This test is long, but worth it.  These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.

        valid_actions = self.world.get_valid_actions()
        # Types: r = row set, l = column (list), s = string/cell, n = number,
        # d = date.  NOTE(review): inferred from the productions below —
        # confirm against the world's type declarations.
        assert set(valid_actions.keys()) == {
            "<r,<l,s>>",
            "<r,<n,<l,r>>>",
            "<r,<l,r>>",
            "<r,<r,<l,n>>>",
            "<r,<s,<l,r>>>",
            "<n,<n,<n,d>>>",
            "<r,<d,<l,r>>>",
            "<r,<l,n>>",
            "<r,r>",
            "<r,n>",
            "d",
            "n",
            "s",
            "l",
            "r",
            "@start@",
        }

        check_productions_match(valid_actions['<r,<l,s>>'], ['mode', 'select'])

        check_productions_match(valid_actions['<r,<n,<l,r>>>'], [
            'filter_number_equals', 'filter_number_greater',
            'filter_number_greater_equals', 'filter_number_lesser',
            'filter_number_lesser_equals', 'filter_number_not_equals'
        ])

        check_productions_match(valid_actions['<r,<l,r>>'],
                                ['argmax', 'argmin', 'same_as'])

        check_productions_match(valid_actions['<r,<r,<l,n>>>'], ['diff'])

        check_productions_match(valid_actions['<r,<s,<l,r>>>'],
                                ['filter_in', 'filter_not_in'])

        check_productions_match(valid_actions['<n,<n,<n,d>>>'], ['date'])

        check_productions_match(valid_actions['<r,<d,<l,r>>>'], [
            'filter_date_equals', 'filter_date_greater',
            'filter_date_greater_equals', 'filter_date_lesser',
            'filter_date_lesser_equals', 'filter_date_not_equals'
        ])

        check_productions_match(valid_actions['<r,<l,n>>'],
                                ['average', 'max', 'min', 'sum'])

        check_productions_match(valid_actions['<r,r>'],
                                ['first', 'last', 'next', 'previous'])

        check_productions_match(valid_actions['<r,n>'], ['count'])

        # These are the columns in the fixture table, and are instance specific.
        check_productions_match(valid_actions['l'], [
            'fb:row.row.year', 'fb:row.row.league',
            'fb:row.row.avg_attendance', 'fb:row.row.division',
            'fb:row.row.regular_season', 'fb:row.row.playoffs',
            'fb:row.row.open_cup'
        ])

        check_productions_match(valid_actions['@start@'], ['d', 'n', 's'])

        # We merged cells and parts in SEMPRE to strings in this grammar.
        check_productions_match(valid_actions['s'], [
            'fb:cell.2', 'fb:cell.2001', 'fb:cell.2005', 'fb:cell.4th_round',
            'fb:cell.4th_western', 'fb:cell.5th', 'fb:cell.6_028',
            'fb:cell.7_169', 'fb:cell.did_not_qualify',
            'fb:cell.quarterfinals', 'fb:cell.usl_a_league',
            'fb:cell.usl_first_division', 'fb:part.4th', 'fb:part.western',
            'fb:part.5th', '[<r,<l,s>>, r, l]'
        ])

        check_productions_match(valid_actions['d'],
                                ['[<n,<n,<n,d>>>, n, n, n]'])

        # '2013' is here because it appears in the question tokens in setUp.
        check_productions_match(valid_actions['n'], [
            '-1', '0', '1', '2013', '[<r,<l,n>>, r, l]',
            '[<r,<r,<l,n>>>, r, r, l]', '[<r,n>, r]'
        ])

        check_productions_match(valid_actions['r'], [
            'all_rows', '[<r,<d,<l,r>>>, r, d, l]', '[<r,<l,r>>, r, l]',
            '[<r,<n,<l,r>>>, r, n, l]', '[<r,<s,<l,r>>>, r, s, l]',
            '[<r,r>, r]'
        ])

    def test_world_processes_logical_forms_correctly(self):
        """A logical form should parse to the expected NLTK expression."""
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F30(R,string:usl_a_league,C2),C6)"

    def test_world_gets_correct_actions(self):
        """A parsed expression should linearize to the expected top-down
        action sequence."""
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        expected_sequence = [
            '@start@ -> s', 's -> [<r,<l,s>>, r, l]', '<r,<l,s>> -> select',
            'r -> [<r,<s,<l,r>>>, r, s, l]', '<r,<s,<l,r>>> -> filter_in',
            'r -> all_rows', 's -> fb:cell.usl_a_league',
            'l -> fb:row.row.league', 'l -> fb:row.row.year'
        ]
        assert self.world.get_action_sequence(expression) == expected_sequence

    def test_world_gets_logical_form_from_actions(self):
        """Round trip: logical form -> action sequence -> logical form should
        reproduce the original string exactly."""
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        action_sequence = self.world.get_action_sequence(expression)
        reconstructed_logical_form = self.world.get_logical_form(
            action_sequence)
        assert logical_form == reconstructed_logical_form

    def test_world_processes_logical_forms_with_number_correctly(self):
        """Numbers in logical forms map to ``num:`` terminals."""
        logical_form = "(select (filter_number_greater all_rows 2013 fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F10(R,num:2013,C6),C6)"

    def test_world_processes_logical_forms_with_date_correctly(self):
        """Date constructors parse; note -1 (unknown date part) is rendered
        as ``num:~1`` in the expression string."""
        logical_form = "(select (filter_date_greater all_rows (date 2013 -1 -1) fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F20(R,T0(num:2013,num:~1,num:~1),C6),C6)"

    def _get_world_with_question_tokens(
            self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
        # Helper: build a fresh world over the same fixture table but with
        # different question tokens (entity linking is question-dependent).
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, tokens)
        world = WikiTablesVariableFreeWorld(table_kg)
        return world

    def test_get_agenda(self):
        """The agenda (question-triggered productions) should reflect trigger
        words ('last' -> argmax, 'difference' -> diff, 'total' -> sum, and
        'least' -> argmin for 'when' questions vs. min otherwise), plus any
        linked entities."""
        tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2000', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'n -> 2000', 'l -> fb:row.row.year', '<r,<l,r>> -> argmax'
        }
        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'difference', 'in', 'attendance',
                'between', 'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains strings here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            '<r,<r,<l,n>>> -> diff'
        }
        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'total', 'avg.', 'attendance', 'in',
                'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains cells here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> sum'
        }
        tokens = [
            Token(x) for x in
            ['when', 'was', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,r>> -> argmin'
        }
        tokens = [
            Token(x)
            for x in ['what', 'is', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> min'
        }
# ---- Example #2 ----
    def text_to_instance(
            self,  # type: ignore
            question: str,
            table_lines: List[List[str]],
            target_values: List[str],
            offline_search_output: List[str] = None) -> Instance:
        """
        Reads text inputs and makes an instance. WikitableQuestions dataset provides tables as
        TSV files pre-tagged using CoreNLP, which we use for training.

        Returns ``None`` (so the instance is skipped) when ``offline_search_output`` is given
        but none of its logical forms can be converted into a valid action sequence.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[List[str]]``
            The table content preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines``
            for the expected format.
        target_values : ``List[str]``
            Target denotations, stored as metadata for evaluation (not indexed).
        offline_search_output : ``List[str]``, optional
            List of logical forms, produced by offline search. Not required during test.
        """
        # pylint: disable=arguments-differ
        tokenized_question = self._tokenizer.tokenize(question.lower())
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)
        # TODO(pradeep): We'll need a better way to input CoreNLP processed lines.
        table_context = TableQuestionContext.read_from_lines(
            table_lines, tokenized_question)
        target_values_field = MetadataField(target_values)
        world = WikiTablesVariableFreeWorld(table_context)
        world_field = MetadataField(world)
        # Note: Not passing any feature extractors when instantiating the field below. This will
        # make it use all the available extractors.
        table_field = KnowledgeGraphField(
            table_context.get_table_knowledge_graph(),
            tokenized_question,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        # One ProductionRuleField per grammar rule; instance-specific entities
        # (cells, columns) are marked non-global so they get linked, not embedded.
        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_instance_specific_entity(
                rule_right_side)
            field = ProductionRuleField(production_rule,
                                        is_global_rule=is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'table': table_field,
            'world': world_field,
            'actions': action_field,
            'target_values': target_values_field
        }

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {
            action.rule: i
            for i, action in enumerate(action_field.field_list)
        }  # type: ignore
        if offline_search_output:
            action_sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    expression = world.parse_logical_form(logical_form)
                except ParsingError as error:
                    # Offline search output can contain stale or malformed
                    # logical forms; skip them instead of failing the read.
                    logger.debug('Parsing error: %s, skipping logical form',
                                 error.message)
                    logger.debug('Question was: %s', question)
                    logger.debug('Logical form was: %s', logical_form)
                    logger.debug('Table info was: %s', table_lines)
                    continue
                except Exception:
                    # Unexpected failure: log the offending logical form for
                    # debugging, then let the error propagate.
                    logger.error(logical_form)
                    raise
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.debug('Missing production rule: %s, skipping logical form',
                                 error.args)
                    logger.debug('Question was: %s', question)
                    logger.debug('Table info was: %s', table_lines)
                    logger.debug('Logical form was: %s', logical_form)
                    continue
                # Cap the amount of logical-form supervision per instance.
                if len(action_sequence_fields
                       ) >= self._max_offline_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda():
                agenda_index_fields.append(
                    IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                # Empty agendas break downstream padding, so emit a single
                # padding index instead.
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)