class TestWikiTablesVariableFreeWorld(AllenNlpTestCase):
    """Tests for ``WikiTablesVariableFreeWorld`` built over the sample WikiTables fixture.

    Covers: the full set of valid grammar actions (including instance-specific
    column/cell productions), parsing logical forms into expressions, converting
    expressions to and from action sequences, and agenda extraction from
    question tokens.
    """

    def setUp(self):
        super().setUp()
        # The question tokens determine which table entities get linked, so the
        # fixture world is built from this fixed question.
        question_tokens = [
            Token(x) for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']
        ]
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tsv'
        self.table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, question_tokens)
        self.world = WikiTablesVariableFreeWorld(self.table_kg)

    def test_get_valid_actions_returns_correct_set(self):
        # This test is long, but worth it. These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.
        valid_actions = self.world.get_valid_actions()
        # Every non-terminal type in the grammar must be present, no more, no less.
        assert set(valid_actions.keys()) == {
            "<r,<l,s>>",
            "<r,<n,<l,r>>>",
            "<r,<l,r>>",
            "<r,<r,<l,n>>>",
            "<r,<s,<l,r>>>",
            "<n,<n,<n,d>>>",
            "<r,<d,<l,r>>>",
            "<r,<l,n>>",
            "<r,r>",
            "<r,n>",
            "d",
            "n",
            "s",
            "l",
            "r",
            "@start@",
        }
        check_productions_match(valid_actions['<r,<l,s>>'],
                                ['mode', 'select'])
        check_productions_match(valid_actions['<r,<n,<l,r>>>'], [
            'filter_number_equals', 'filter_number_greater',
            'filter_number_greater_equals', 'filter_number_lesser',
            'filter_number_lesser_equals', 'filter_number_not_equals'
        ])
        check_productions_match(valid_actions['<r,<l,r>>'],
                                ['argmax', 'argmin', 'same_as'])
        check_productions_match(valid_actions['<r,<r,<l,n>>>'], ['diff'])
        check_productions_match(valid_actions['<r,<s,<l,r>>>'],
                                ['filter_in', 'filter_not_in'])
        check_productions_match(valid_actions['<n,<n,<n,d>>>'], ['date'])
        check_productions_match(valid_actions['<r,<d,<l,r>>>'], [
            'filter_date_equals', 'filter_date_greater',
            'filter_date_greater_equals', 'filter_date_lesser',
            'filter_date_lesser_equals', 'filter_date_not_equals'
        ])
        check_productions_match(valid_actions['<r,<l,n>>'],
                                ['average', 'max', 'min', 'sum'])
        check_productions_match(valid_actions['<r,r>'],
                                ['first', 'last', 'next', 'previous'])
        check_productions_match(valid_actions['<r,n>'], ['count'])
        # These are the columns in table, and are instance specific.
        check_productions_match(valid_actions['l'], [
            'fb:row.row.year', 'fb:row.row.league',
            'fb:row.row.avg_attendance', 'fb:row.row.division',
            'fb:row.row.regular_season', 'fb:row.row.playoffs',
            'fb:row.row.open_cup'
        ])
        check_productions_match(valid_actions['@start@'], ['d', 'n', 's'])
        # We merged cells and parts in SEMPRE to strings in this grammar.
        check_productions_match(valid_actions['s'], [
            'fb:cell.2', 'fb:cell.2001', 'fb:cell.2005', 'fb:cell.4th_round',
            'fb:cell.4th_western', 'fb:cell.5th', 'fb:cell.6_028',
            'fb:cell.7_169', 'fb:cell.did_not_qualify',
            'fb:cell.quarterfinals', 'fb:cell.usl_a_league',
            'fb:cell.usl_first_division', 'fb:part.4th', 'fb:part.western',
            'fb:part.5th', '[<r,<l,s>>, r, l]'
        ])
        check_productions_match(valid_actions['d'],
                                ['[<n,<n,<n,d>>>, n, n, n]'])
        check_productions_match(valid_actions['n'], [
            '-1', '0', '1', '2013', '[<r,<l,n>>, r, l]',
            '[<r,<r,<l,n>>>, r, r, l]', '[<r,n>, r]'
        ])
        check_productions_match(valid_actions['r'], [
            'all_rows', '[<r,<d,<l,r>>>, r, d, l]', '[<r,<l,r>>, r, l]',
            '[<r,<n,<l,r>>>, r, n, l]', '[<r,<s,<l,r>>>, r, s, l]',
            '[<r,r>, r]'
        ])

    def test_world_processes_logical_forms_correctly(self):
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F30(R,string:usl_a_league,C2),C6)"

    def test_world_gets_correct_actions(self):
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Pre-order traversal of the parse: start symbol first, then each
        # function application expanded left-to-right.
        expected_sequence = [
            '@start@ -> s', 's -> [<r,<l,s>>, r, l]', '<r,<l,s>> -> select',
            'r -> [<r,<s,<l,r>>>, r, s, l]', '<r,<s,<l,r>>> -> filter_in',
            'r -> all_rows', 's -> fb:cell.usl_a_league',
            'l -> fb:row.row.league', 'l -> fb:row.row.year'
        ]
        assert self.world.get_action_sequence(expression) == expected_sequence

    def test_world_gets_logical_form_from_actions(self):
        # Round trip: logical form -> expression -> action sequence -> logical form.
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        action_sequence = self.world.get_action_sequence(expression)
        reconstructed_logical_form = self.world.get_logical_form(
            action_sequence)
        assert logical_form == reconstructed_logical_form

    def test_world_processes_logical_forms_with_number_correctly(self):
        logical_form = "(select (filter_number_greater all_rows 2013 fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F10(R,num:2013,C6),C6)"

    def test_world_processes_logical_forms_with_date_correctly(self):
        logical_form = "(select (filter_date_greater all_rows (date 2013 -1 -1) fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F20(R,T0(num:2013,num:~1,num:~1),C6),C6)"

    def _get_world_with_question_tokens(
            self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
        # Helper: build a fresh world over the same fixture table but with
        # different question tokens (entity linking is question-dependent).
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, tokens)
        world = WikiTablesVariableFreeWorld(table_kg)
        return world

    def test_get_agenda(self):
        tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2000', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'n -> 2000', 'l -> fb:row.row.year', '<r,<l,r>> -> argmax'
        }
        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'difference', 'in', 'attendance',
                'between', 'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains strings here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            '<r,<r,<l,n>>> -> diff'
        }
        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'total', 'avg.', 'attendance', 'in',
                'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains cells here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> sum'
        }
        tokens = [
            Token(x)
            for x in ['when', 'was', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        # "when ... least" maps to argmin (a row), not min (a number).
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,r>> -> argmin'
        }
        tokens = [
            Token(x)
            for x in ['what', 'is', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        # "what ... least" maps to min (a number), contrasting with the case above.
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> min'
        }
def text_to_instance(self,  # type: ignore
                     question: str,
                     table_lines: List[List[str]],
                     target_values: List[str],
                     offline_search_output: List[str] = None) -> Instance:
    """
    Reads text inputs and makes an instance. WikitableQuestions dataset provides tables as TSV
    files pre-tagged using CoreNLP, which we use for training.

    Parameters
    ----------
    question : ``str``
        Input question
    table_lines : ``List[List[str]]``
        The table content preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines``
        for the expected format.
    target_values : ``List[str]``
        Stored as-is in a ``MetadataField`` under the key ``target_values``.
    offline_search_output : List[str], optional
        List of logical forms, produced by offline search. Not required during test.

    Returns
    -------
    An ``Instance`` with fields ``question``, ``table``, ``world``, ``actions`` and
    ``target_values``; plus ``target_action_sequences`` when ``offline_search_output``
    is given, and ``agenda`` when ``self._output_agendas`` is set.  Returns ``None``
    when offline logical forms were supplied but none could be converted to action
    sequences (the instance carries no usable supervision).
    """
    # NOTE(review): because of the ``return None`` branch below, the return annotation
    # should arguably be ``Optional[Instance]`` (and ``offline_search_output`` be
    # ``Optional[List[str]]``); left unchanged since the file's import block is not
    # visible here to confirm ``Optional`` is in scope.
    # pylint: disable=arguments-differ
    tokenized_question = self._tokenizer.tokenize(question.lower())
    question_field = TextField(tokenized_question,
                               self._question_token_indexers)
    # TODO(pradeep): We'll need a better way to input CoreNLP processed lines.
    table_context = TableQuestionContext.read_from_lines(
        table_lines, tokenized_question)
    target_values_field = MetadataField(target_values)
    # The world defines the grammar (set of valid actions) for this particular
    # table + question pair.
    world = WikiTablesVariableFreeWorld(table_context)
    world_field = MetadataField(world)
    # Note: Not passing any feature extractors when instantiating the field below. This will make
    # it use all the available extractors.
    table_field = KnowledgeGraphField(
        table_context.get_table_knowledge_graph(),
        tokenized_question,
        self._table_token_indexers,
        tokenizer=self._tokenizer,
        include_in_vocab=self._use_table_for_vocab,
        max_table_tokens=self._max_table_tokens)
    # One ProductionRuleField per grammar production; rules whose right-hand side
    # is an instance-specific entity (this table's cells/columns) are marked
    # non-global so they are embedded per-instance rather than from the vocab.
    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(' -> ')
        is_global_rule = not world.is_instance_specific_entity(
            rule_right_side)
        field = ProductionRuleField(production_rule,
                                    is_global_rule=is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)
    fields = {
        'question': question_field,
        'table': table_field,
        'world': world_field,
        'actions': action_field,
        'target_values': target_values_field
    }
    # We'll make each target action sequence a List[IndexField], where the index is into
    # the action list we made above. We need to ignore the type here because mypy doesn't
    # like `action.rule` - it's hard to tell mypy that the ListField is made up of
    # ProductionRuleFields.
    action_map = {
        action.rule: i
        for i, action in enumerate(action_field.field_list)
    }  # type: ignore
    if offline_search_output:
        action_sequence_fields: List[Field] = []
        for logical_form in offline_search_output:
            try:
                expression = world.parse_logical_form(logical_form)
            except ParsingError as error:
                # Unparseable logical forms are skipped (not fatal); log enough
                # context to reproduce the failure.
                logger.debug(
                    f'Parsing error: {error.message}, skipping logical form'
                )
                logger.debug(f'Question was: {question}')
                logger.debug(f'Logical form was: {logical_form}')
                logger.debug(f'Table info was: {table_lines}')
                continue
            except:
                # NOTE(review): bare ``except`` — deliberately broad so that the
                # offending logical form is logged before the exception is re-raised;
                # it does not swallow the error.
                logger.error(logical_form)
                raise
            action_sequence = world.get_action_sequence(expression)
            try:
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(
                        IndexField(action_map[production_rule],
                                   action_field))
                action_sequence_fields.append(ListField(index_fields))
            except KeyError as error:
                # The sequence referenced a rule that is not in this instance's
                # grammar; skip this logical form.
                logger.debug(
                    f'Missing production rule: {error.args}, skipping logical form'
                )
                logger.debug(f'Question was: {question}')
                logger.debug(f'Table info was: {table_lines}')
                logger.debug(f'Logical form was: {logical_form}')
                continue
            # Cap the amount of logical-form supervision kept per instance.
            if len(action_sequence_fields
                   ) >= self._max_offline_logical_forms:
                break
        if not action_sequence_fields:
            # This is not great, but we're only doing it when we're passed logical form
            # supervision, so we're expecting labeled logical forms, but we can't actually
            # produce the logical forms. We should skip this instance. Note that this affects
            # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
            # full test data.
            return None
        fields['target_action_sequences'] = ListField(
            action_sequence_fields)
    if self._output_agendas:
        agenda_index_fields: List[Field] = []
        for agenda_string in world.get_agenda():
            agenda_index_fields.append(
                IndexField(action_map[agenda_string], action_field))
        if not agenda_index_fields:
            # Empty agendas get a single -1 dummy index so the ListField is never empty.
            agenda_index_fields = [IndexField(-1, action_field)]
        fields['agenda'] = ListField(agenda_index_fields)
    return Instance(fields)