def evaluate(self, ques_rep: torch.Tensor, sampled_actions: List[str],
             slot_dic: Dict, target_list: List,
             world: WikiTablesVariableFreeWorld) -> bool:
    """Check whether the best-scoring completion of a sampled sketch is correct.

    ``self.get_all_sequences`` expands the sketch into candidate
    ``(action_sequence, score)`` pairs. The candidate with the largest score is
    turned into a logical form and executed against ``target_list``.

    Args:
        ques_rep: question representation. It is unpacked as 2-D, and its second
            dimension must equal ``self.token_rnn_feat_size``.
        sampled_actions: the sampled sketch, passed to ``get_all_sequences``.
        slot_dic: slot information, passed to ``get_all_sequences``.
        target_list: expected denotation, passed to the world's executor.
        world: provides the table context, valid actions, logical-form
            construction and execution.

    Returns:
        ``True`` if the highest-scoring candidate's logical form evaluates to
        ``target_list``. ``False`` otherwise, including when there are no
        candidates at all.

    Raises:
        AssertionError: if the feature size of ``ques_rep`` does not match.
    """
    _, token_rnn_feat_size = ques_rep.size()
    assert self.token_rnn_feat_size == token_rnn_feat_size
    _, column2id, _, column_reps = self.collect_column_reps(world.table_context)
    filtered_actions = self.filter_functions(world.get_valid_actions())
    possible_paths = self.get_all_sequences(ques_rep, column2id, column_reps,
                                            sampled_actions, filtered_actions,
                                            slot_dic, world)
    if not possible_paths:
        # Nothing to execute, so nothing can be correct.
        return False
    # ``max`` returns the first of several equal maxima, i.e. the same tie-break
    # as a strict ``>`` scan that keeps its running best score up to date.
    best_path, _ = max(possible_paths, key=lambda path_and_score: path_and_score[1])
    logical_form = world.get_logical_form(best_path)
    return bool(world._executor.evaluate_logical_form(logical_form, target_list))
def forward(self, ques_rep: torch.Tensor, sampled_actions: List[str],
            slot_dic: Dict, target_list: List,
            world: WikiTablesVariableFreeWorld) -> "Union[int, torch.Tensor]":
    """Reward a sampled sketch by the probability mass of its correct completions.

    It takes in a sampled path and finishes the selection part based on
    alignments to the question, the table and the filter/same_as functions.
    Operations for selecting one row: filter_eq, filter_in
    Operations for selecting multiple rows: all filters and all_rows

    Every candidate ``(action_sequence, score)`` returned by
    ``self.get_all_sequences`` is turned into a logical form and executed
    against ``target_list``. The scores are softmax-normalised over all
    candidates. The result is the summed probability of the candidates that
    execute correctly.

    Args:
        ques_rep: question representation. It is unpacked as 2-D, and its second
            dimension must equal ``self.token_rnn_feat_size``.
        sampled_actions: the sampled sketch, passed to ``get_all_sequences``.
        slot_dic: slot information, passed to ``get_all_sequences``.
        target_list: expected denotation, passed to the world's executor.
        world: provides the table context, valid actions, logical-form
            construction and execution.

    Returns:
        The plain integer ``0`` when no candidate is correct (including when
        there are no candidates). Otherwise, a tensor holding the summed
        softmax probability of the correct candidates.

    Raises:
        AssertionError: if the feature size of ``ques_rep`` does not match.
    """
    _, token_rnn_feat_size = ques_rep.size()
    assert self.token_rnn_feat_size == token_rnn_feat_size
    _, column2id, _, column_reps = self.collect_column_reps(world.table_context)
    filtered_actions = self.filter_functions(world.get_valid_actions())
    possible_paths = self.get_all_sequences(ques_rep, column2id, column_reps,
                                            sampled_actions, filtered_actions,
                                            slot_dic, world)
    candidate_scores = []
    gold_ids = []
    for candidate_path, candidate_score in possible_paths:
        logical_form = world.get_logical_form(candidate_path)
        candidate_scores.append(candidate_score)
        is_correct = world._executor.evaluate_logical_form(logical_form, target_list)
        gold_ids.append(1 if is_correct else 0)

    if not any(gold_ids):
        return 0

    # NOTE(review): assumes each candidate_score is a 0-d tensor, so score_v is
    # (num_candidates,). A (1,)-shaped score would make the product below
    # broadcast to (n, n) — TODO confirm.
    score_v = torch.stack(candidate_scores, 0)
    # Build the 0/1 mask on the scores' own device and dtype. A CPU FloatTensor
    # cannot be multiplied with scores living on a CUDA device.
    gold_id_v = torch.tensor(gold_ids, dtype=score_v.dtype, device=score_v.device)
    score_prob = F.softmax(score_v, 0)
    return torch.sum(gold_id_v * score_prob, 0)
class TestWikiTablesVariableFreeWorld(AllenNlpTestCase):
    """Tests for ``WikiTablesVariableFreeWorld`` built over the ``sample_table.tsv`` fixture."""

    # Logical form shared by the parsing and action-sequence round-trip tests.
    _FILTER_IN_LOGICAL_FORM = ("(select (filter_in all_rows fb:cell.usl_a_league "
                               "fb:row.row.league) fb:row.row.year)")

    def setUp(self):
        super().setUp()
        question_tokens = [
            Token(x) for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']
        ]
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tsv'
        self.table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, question_tokens)
        self.world = WikiTablesVariableFreeWorld(self.table_kg)

    def test_get_valid_actions_returns_correct_set(self):
        # This test is long, but worth it. These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.
        valid_actions = self.world.get_valid_actions()
        assert set(valid_actions.keys()) == {
            "<r,<l,s>>",
            "<r,<n,<l,r>>>",
            "<r,<l,r>>",
            "<r,<r,<l,n>>>",
            "<r,<s,<l,r>>>",
            "<n,<n,<n,d>>>",
            "<r,<d,<l,r>>>",
            "<r,<l,n>>",
            "<r,r>",
            "<r,n>",
            "d",
            "n",
            "s",
            "l",
            "r",
            "@start@",
        }

        check_productions_match(valid_actions['<r,<l,s>>'], ['mode', 'select'])
        check_productions_match(valid_actions['<r,<n,<l,r>>>'], [
            'filter_number_equals', 'filter_number_greater',
            'filter_number_greater_equals', 'filter_number_lesser',
            'filter_number_lesser_equals', 'filter_number_not_equals'
        ])
        check_productions_match(valid_actions['<r,<l,r>>'], ['argmax', 'argmin', 'same_as'])
        check_productions_match(valid_actions['<r,<r,<l,n>>>'], ['diff'])
        check_productions_match(valid_actions['<r,<s,<l,r>>>'], ['filter_in', 'filter_not_in'])
        check_productions_match(valid_actions['<n,<n,<n,d>>>'], ['date'])
        check_productions_match(valid_actions['<r,<d,<l,r>>>'], [
            'filter_date_equals', 'filter_date_greater',
            'filter_date_greater_equals', 'filter_date_lesser',
            'filter_date_lesser_equals', 'filter_date_not_equals'
        ])
        check_productions_match(valid_actions['<r,<l,n>>'], ['average', 'max', 'min', 'sum'])
        check_productions_match(valid_actions['<r,r>'], ['first', 'last', 'next', 'previous'])
        check_productions_match(valid_actions['<r,n>'], ['count'])
        # These are the columns of the table, so they are instance specific.
        check_productions_match(valid_actions['l'], [
            'fb:row.row.year', 'fb:row.row.league', 'fb:row.row.avg_attendance',
            'fb:row.row.division', 'fb:row.row.regular_season',
            'fb:row.row.playoffs', 'fb:row.row.open_cup'
        ])
        check_productions_match(valid_actions['@start@'], ['d', 'n', 's'])
        # We merged cells and parts in SEMPRE to strings in this grammar.
        check_productions_match(valid_actions['s'], [
            'fb:cell.2', 'fb:cell.2001', 'fb:cell.2005', 'fb:cell.4th_round',
            'fb:cell.4th_western', 'fb:cell.5th', 'fb:cell.6_028',
            'fb:cell.7_169', 'fb:cell.did_not_qualify',
            'fb:cell.quarterfinals', 'fb:cell.usl_a_league',
            'fb:cell.usl_first_division', 'fb:part.4th', 'fb:part.western',
            'fb:part.5th', '[<r,<l,s>>, r, l]'
        ])
        check_productions_match(valid_actions['d'], ['[<n,<n,<n,d>>>, n, n, n]'])
        check_productions_match(valid_actions['n'], [
            '-1', '0', '1', '2013', '[<r,<l,n>>, r, l]',
            '[<r,<r,<l,n>>>, r, r, l]', '[<r,n>, r]'
        ])
        check_productions_match(valid_actions['r'], [
            'all_rows', '[<r,<d,<l,r>>>, r, d, l]', '[<r,<l,r>>, r, l]',
            '[<r,<n,<l,r>>>, r, n, l]', '[<r,<s,<l,r>>>, r, s, l]',
            '[<r,r>, r]'
        ])

    def test_world_processes_logical_forms_correctly(self):
        expression = self.world.parse_logical_form(self._FILTER_IN_LOGICAL_FORM)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F30(R,string:usl_a_league,C2),C6)"

    def test_world_gets_correct_actions(self):
        expression = self.world.parse_logical_form(self._FILTER_IN_LOGICAL_FORM)
        expected_sequence = [
            '@start@ -> s', 's -> [<r,<l,s>>, r, l]', '<r,<l,s>> -> select',
            'r -> [<r,<s,<l,r>>>, r, s, l]', '<r,<s,<l,r>>> -> filter_in',
            'r -> all_rows', 's -> fb:cell.usl_a_league',
            'l -> fb:row.row.league', 'l -> fb:row.row.year'
        ]
        assert self.world.get_action_sequence(expression) == expected_sequence

    def test_world_gets_logical_form_from_actions(self):
        logical_form = self._FILTER_IN_LOGICAL_FORM
        expression = self.world.parse_logical_form(logical_form)
        action_sequence = self.world.get_action_sequence(expression)
        reconstructed_logical_form = self.world.get_logical_form(
            action_sequence)
        assert logical_form == reconstructed_logical_form

    def test_world_processes_logical_forms_with_number_correctly(self):
        logical_form = "(select (filter_number_greater all_rows 2013 fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Numbers in the logical form get mapped to ``num:`` constants.
        assert str(expression) == "S0(F10(R,num:2013,C6),C6)"

    def test_world_processes_logical_forms_with_date_correctly(self):
        logical_form = "(select (filter_date_greater all_rows (date 2013 -1 -1) fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Date components become ``num:`` constants; negative numbers are
        # rendered with ``~`` (so -1 becomes num:~1).
        assert str(expression) == "S0(F20(R,T0(num:2013,num:~1,num:~1),C6),C6)"

    def _get_world_with_question_tokens(
            self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
        """Build a world over the fixture table, linked to the given question tokens."""
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, tokens)
        world = WikiTablesVariableFreeWorld(table_kg)
        return world

    def test_get_agenda(self):
        tokens = [
            Token(x) for x in ['what', 'was', 'the', 'last', 'year', '2000', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'n -> 2000', 'l -> fb:row.row.year', '<r,<l,r>> -> argmax'
        }

        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'difference', 'in', 'attendance',
                'between', 'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains strings here instead of numbers because 2001 and 2005 actually link to
        # entities in the table, whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            '<r,<r,<l,n>>> -> diff'
        }

        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'total', 'avg.', 'attendance', 'in',
                'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # As in the previous case, 2001 and 2005 link to table cells, so the
        # agenda contains strings for them rather than numbers.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> sum'
        }

        # The next two questions differ only in the leading words ("when was"
        # vs "what is"); the expected agenda switches from argmin to min.
        tokens = [
            Token(x)
            for x in ['when', 'was', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,r>> -> argmin'
        }

        tokens = [
            Token(x)
            for x in ['what', 'is', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> min'
        }
def get_all_skethch_lf(self, actions: Dict, prod_score_dic: Dict,
                       world: WikiTablesVariableFreeWorld) -> List:
    """Render every scored sketch as a ``(logical_form, score)`` pair.

    Args:
        actions: passed unchanged to ``self.get_all_sketches``.
        prod_score_dic: passed unchanged to ``self.get_all_sketches``.
        world: forwarded to ``self.get_all_sketches``; also used to turn each
            sketch's action sequence into a logical form.

    Returns:
        A list of ``(logical_form, score)`` tuples, in the order produced by
        ``self.get_all_sketches``.
    """
    scored_sketches = self.get_all_sketches(actions, prod_score_dic, world)
    rendered = []
    for sketch_actions, sketch_score in scored_sketches:
        rendered.append((world.get_logical_form(sketch_actions), sketch_score))
    return rendered