Example #1
0
    def evaluate(self, ques_rep: torch.Tensor, sampled_actions: List[str],
                 slot_dic: Dict, target_list: List,
                 world: WikiTablesVariableFreeWorld) -> bool:
        """
        Pick the highest-scoring completion of a sampled sketch and check
        whether it executes to the target answer.

        ``self.get_all_sequences`` expands ``sampled_actions`` into
        ``(path, score)`` candidates. The candidate with the largest score is
        converted to a logical form with ``world.get_logical_form``. That form
        is then executed against ``target_list``.

        Returns ``True`` if the best candidate's logical form evaluates to the
        target. Returns ``False`` otherwise, including when no candidate path
        was produced.
        """
        # Unpacking into two names also requires ques_rep to be 2-D; its last
        # dimension must match the configured token RNN feature size.
        _, _token_rnn_feat_size = ques_rep.size()
        assert self.token_rnn_feat_size == _token_rnn_feat_size
        _, column2id, _, column_reps = self.collect_column_reps(
            world.table_context)

        actions = world.get_valid_actions()
        filtered_actions = self.filter_functions(actions)

        possible_paths = self.get_all_sequences(ques_rep, column2id, column_reps,
                                                sampled_actions, filtered_actions,
                                                slot_dic, world)
        if not possible_paths:
            # No candidate program could be built, so none can be correct.
            return False

        # Argmax over candidate scores. The running maximum must be updated
        # together with the path; otherwise every candidate is compared against
        # the first score only. Strict ``>`` keeps the earliest path on ties.
        max_path, max_score = possible_paths[0]
        for candidate_path, candidate_score in possible_paths[1:]:
            if candidate_score > max_score:
                max_path, max_score = candidate_path, candidate_score

        lf = world.get_logical_form(max_path)
        return bool(world._executor.evaluate_logical_form(lf, target_list))
Example #2
0
    def forward(self, ques_rep: torch.Tensor, sampled_actions: List[str],
                slot_dic: Dict, target_list: List,
                world: WikiTablesVariableFreeWorld):
        """
        Expected-reward objective over all completions of a sampled sketch.

        It takes in a sampled path and finishes the selection part
        based on alignments to the question, table and filter/same_as function.

        Operations for selecting one row: filter_eq, filter_in
        Operations for selecting multiple rows: all filters and all_rows

        Each candidate ``(path, score)`` from ``self.get_all_sequences`` is
        converted to a logical form and executed against ``target_list``. The
        candidate scores are normalised with a softmax over dimension 0.

        Returns the integer ``0`` when no candidate executes correctly,
        including when there are no candidates at all. Otherwise returns a
        tensor: the softmax probability summed over the correct candidates.
        """
        # Unpacking into two names also requires ques_rep to be 2-D; its last
        # dimension must match the configured token RNN feature size.
        _, _token_rnn_feat_size = ques_rep.size()
        assert self.token_rnn_feat_size == _token_rnn_feat_size
        _, column2id, _, column_reps = self.collect_column_reps(
            world.table_context)

        actions = world.get_valid_actions()
        filtered_actions = self.filter_functions(actions)

        possible_paths = self.get_all_sequences(ques_rep, column2id, column_reps,
                                                sampled_actions, filtered_actions,
                                                slot_dic, world)

        candidate_scores = []
        gold_ids = []
        for candidate_path, candidate_score in possible_paths:
            lf = world.get_logical_form(candidate_path)
            candidate_scores.append(candidate_score)
            is_correct = world._executor.evaluate_logical_form(lf, target_list)
            gold_ids.append(1 if is_correct else 0)

        if not any(gold_ids):
            # No correct program (or no candidates). Keep returning a plain 0
            # for existing callers; this also avoids torch.stack on an empty list.
            return 0

        score_v = torch.stack(candidate_scores, 0)
        score_prob = F.softmax(score_v, 0)
        # Build the 0/1 mask on the same device as the scores.
        # torch.FloatTensor always allocated on the CPU, which makes the product
        # below fail when the model runs on a GPU.
        gold_id_v = torch.tensor(gold_ids, dtype=torch.float,
                                 device=score_prob.device)
        reward_v = gold_id_v * score_prob
        return torch.sum(reward_v, 0)
class TestWikiTablesVariableFreeWorld(AllenNlpTestCase):
    """Tests for WikiTablesVariableFreeWorld built over the sample WikiTables table fixture."""

    # Logical form shared by the parse / action-sequence round-trip tests.
    _SELECT_LOGICAL_FORM = ("(select (filter_in all_rows fb:cell.usl_a_league "
                            "fb:row.row.league) fb:row.row.year)")

    @staticmethod
    def _tokenize(sentence: str) -> List[Token]:
        """Split a space-separated question into a list of Token objects."""
        return [Token(word) for word in sentence.split()]

    def setUp(self):
        super().setUp()
        question_tokens = self._tokenize('what was the last year 2013 ?')
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tsv'
        self.table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, question_tokens)
        self.world = WikiTablesVariableFreeWorld(self.table_kg)

    def test_get_valid_actions_returns_correct_set(self):
        # This test is long, but worth it.  These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.

        valid_actions = self.world.get_valid_actions()
        assert set(valid_actions.keys()) == {
            "<r,<l,s>>",
            "<r,<n,<l,r>>>",
            "<r,<l,r>>",
            "<r,<r,<l,n>>>",
            "<r,<s,<l,r>>>",
            "<n,<n,<n,d>>>",
            "<r,<d,<l,r>>>",
            "<r,<l,n>>",
            "<r,r>",
            "<r,n>",
            "d",
            "n",
            "s",
            "l",
            "r",
            "@start@",
        }

        check_productions_match(valid_actions['<r,<l,s>>'], ['mode', 'select'])

        check_productions_match(valid_actions['<r,<n,<l,r>>>'], [
            'filter_number_equals', 'filter_number_greater',
            'filter_number_greater_equals', 'filter_number_lesser',
            'filter_number_lesser_equals', 'filter_number_not_equals'
        ])

        check_productions_match(valid_actions['<r,<l,r>>'],
                                ['argmax', 'argmin', 'same_as'])

        check_productions_match(valid_actions['<r,<r,<l,n>>>'], ['diff'])

        check_productions_match(valid_actions['<r,<s,<l,r>>>'],
                                ['filter_in', 'filter_not_in'])

        check_productions_match(valid_actions['<n,<n,<n,d>>>'], ['date'])

        check_productions_match(valid_actions['<r,<d,<l,r>>>'], [
            'filter_date_equals', 'filter_date_greater',
            'filter_date_greater_equals', 'filter_date_lesser',
            'filter_date_lesser_equals', 'filter_date_not_equals'
        ])

        check_productions_match(valid_actions['<r,<l,n>>'],
                                ['average', 'max', 'min', 'sum'])

        check_productions_match(valid_actions['<r,r>'],
                                ['first', 'last', 'next', 'previous'])

        check_productions_match(valid_actions['<r,n>'], ['count'])

        # These are the columns in table, and are instance specific.
        check_productions_match(valid_actions['l'], [
            'fb:row.row.year', 'fb:row.row.league',
            'fb:row.row.avg_attendance', 'fb:row.row.division',
            'fb:row.row.regular_season', 'fb:row.row.playoffs',
            'fb:row.row.open_cup'
        ])

        check_productions_match(valid_actions['@start@'], ['d', 'n', 's'])

        # We merged cells and parts in SEMPRE to strings in this grammar.
        check_productions_match(valid_actions['s'], [
            'fb:cell.2', 'fb:cell.2001', 'fb:cell.2005', 'fb:cell.4th_round',
            'fb:cell.4th_western', 'fb:cell.5th', 'fb:cell.6_028',
            'fb:cell.7_169', 'fb:cell.did_not_qualify',
            'fb:cell.quarterfinals', 'fb:cell.usl_a_league',
            'fb:cell.usl_first_division', 'fb:part.4th', 'fb:part.western',
            'fb:part.5th', '[<r,<l,s>>, r, l]'
        ])

        check_productions_match(valid_actions['d'],
                                ['[<n,<n,<n,d>>>, n, n, n]'])

        check_productions_match(valid_actions['n'], [
            '-1', '0', '1', '2013', '[<r,<l,n>>, r, l]',
            '[<r,<r,<l,n>>>, r, r, l]', '[<r,n>, r]'
        ])

        check_productions_match(valid_actions['r'], [
            'all_rows', '[<r,<d,<l,r>>>, r, d, l]', '[<r,<l,r>>, r, l]',
            '[<r,<n,<l,r>>>, r, n, l]', '[<r,<s,<l,r>>>, r, s, l]',
            '[<r,r>, r]'
        ])

    def test_world_processes_logical_forms_correctly(self):
        expression = self.world.parse_logical_form(self._SELECT_LOGICAL_FORM)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F30(R,string:usl_a_league,C2),C6)"

    def test_world_gets_correct_actions(self):
        expression = self.world.parse_logical_form(self._SELECT_LOGICAL_FORM)
        expected_sequence = [
            '@start@ -> s', 's -> [<r,<l,s>>, r, l]', '<r,<l,s>> -> select',
            'r -> [<r,<s,<l,r>>>, r, s, l]', '<r,<s,<l,r>>> -> filter_in',
            'r -> all_rows', 's -> fb:cell.usl_a_league',
            'l -> fb:row.row.league', 'l -> fb:row.row.year'
        ]
        assert self.world.get_action_sequence(expression) == expected_sequence

    def test_world_gets_logical_form_from_actions(self):
        # Round trip: logical form -> action sequence -> logical form.
        expression = self.world.parse_logical_form(self._SELECT_LOGICAL_FORM)
        action_sequence = self.world.get_action_sequence(expression)
        reconstructed_logical_form = self.world.get_logical_form(
            action_sequence)
        assert self._SELECT_LOGICAL_FORM == reconstructed_logical_form

    def test_world_processes_logical_forms_with_number_correctly(self):
        logical_form = "(select (filter_number_greater all_rows 2013 fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Numeric literals in the logical form become ``num:`` constants.
        assert str(expression) == "S0(F10(R,num:2013,C6),C6)"

    def test_world_processes_logical_forms_with_date_correctly(self):
        logical_form = "(select (filter_date_greater all_rows (date 2013 -1 -1) fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Date components become ``num:`` constants; negative numbers such as -1
        # are rendered with a ``~`` prefix.
        assert str(expression) == "S0(F20(R,T0(num:2013,num:~1,num:~1),C6),C6)"

    def _get_world_with_question_tokens(
            self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
        """Build a world over the fixture table for the given question tokens."""
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, tokens)
        world = WikiTablesVariableFreeWorld(table_kg)
        return world

    def test_get_agenda(self):
        world = self._get_world_with_question_tokens(
            self._tokenize('what was the last year 2000 ?'))
        assert set(world.get_agenda()) == {
            'n -> 2000', 'l -> fb:row.row.year', '<r,<l,r>> -> argmax'
        }
        world = self._get_world_with_question_tokens(self._tokenize(
            'what was the difference in attendance between years 2001 and 2005 ?'))
        # The agenda contains strings here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the first case) does not.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            '<r,<r,<l,n>>> -> diff'
        }
        world = self._get_world_with_question_tokens(self._tokenize(
            'what was the total avg. attendance in years 2001 and 2005 ?'))
        # As in the previous case, 2001 and 2005 link to table cells, so they appear
        # as string productions rather than numbers.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> sum'
        }
        world = self._get_world_with_question_tokens(
            self._tokenize('when was the least avg. attendance ?'))
        # This phrasing yields argmin (type <r,<l,r>>, returns rows) ...
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,r>> -> argmin'
        }
        world = self._get_world_with_question_tokens(
            self._tokenize('what is the least avg. attendance ?'))
        # ... while this one yields min (type <r,<l,n>>, returns a number).
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> min'
        }
Example #4
0
 def get_all_skethch_lf(self, actions: Dict, prod_score_dic: Dict,
                        world: WikiTablesVariableFreeWorld) -> List:
     """
     Enumerate sketches and render each one as a logical form.

     Delegates to ``self.get_all_sketches`` for ``(path, score)`` pairs. Each
     path is converted with ``world.get_logical_form``, preserving the original
     order. Returns a list of ``(logical_form, score)`` tuples with the scores
     passed through unchanged.
     """
     scored_sketches = self.get_all_sketches(actions, prod_score_dic, world)
     rendered = []
     for sketch, sketch_score in scored_sketches:
         rendered.append((world.get_logical_form(sketch), sketch_score))
     return rendered