class TestWikiTablesVariableFreeWorld(AllenNlpTestCase):
    """Tests for ``WikiTablesVariableFreeWorld`` driven by a
    ``TableQuestionKnowledgeGraph`` built from a TSV fixture table.

    NOTE(review): a class with this exact name is re-defined later in this
    module; at import time only the last definition is visible to the test
    collector, so these tests are shadowed — confirm whether this file is a
    concatenation of several revisions.
    """

    def setUp(self):
        # Build one world over the sample table with a fixed question; several
        # tests below reuse ``self.world`` directly.
        super().setUp()
        question_tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']
        ]
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tsv'
        self.table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, question_tokens)
        self.world = WikiTablesVariableFreeWorld(self.table_kg)

    def test_get_valid_actions_returns_correct_set(self):
        """The grammar should contain exactly the expected types and productions."""
        # This test is long, but worth it. These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.
        valid_actions = self.world.get_valid_actions()
        assert set(valid_actions.keys()) == {
            "<r,<l,s>>",
            "<r,<n,<l,r>>>",
            "<r,<l,r>>",
            "<r,<r,<l,n>>>",
            "<r,<s,<l,r>>>",
            "<n,<n,<n,d>>>",
            "<r,<d,<l,r>>>",
            "<r,<l,n>>",
            "<r,r>",
            "<r,n>",
            "d",
            "n",
            "s",
            "l",
            "r",
            "@start@",
        }
        check_productions_match(valid_actions['<r,<l,s>>'], ['mode', 'select'])
        check_productions_match(valid_actions['<r,<n,<l,r>>>'], [
            'filter_number_equals', 'filter_number_greater',
            'filter_number_greater_equals', 'filter_number_lesser',
            'filter_number_lesser_equals', 'filter_number_not_equals'
        ])
        check_productions_match(valid_actions['<r,<l,r>>'],
                                ['argmax', 'argmin', 'same_as'])
        check_productions_match(valid_actions['<r,<r,<l,n>>>'], ['diff'])
        check_productions_match(valid_actions['<r,<s,<l,r>>>'],
                                ['filter_in', 'filter_not_in'])
        check_productions_match(valid_actions['<n,<n,<n,d>>>'], ['date'])
        check_productions_match(valid_actions['<r,<d,<l,r>>>'], [
            'filter_date_equals', 'filter_date_greater',
            'filter_date_greater_equals', 'filter_date_lesser',
            'filter_date_lesser_equals', 'filter_date_not_equals'
        ])
        check_productions_match(valid_actions['<r,<l,n>>'],
                                ['average', 'max', 'min', 'sum'])
        check_productions_match(valid_actions['<r,r>'],
                                ['first', 'last', 'next', 'previous'])
        check_productions_match(valid_actions['<r,n>'], ['count'])
        # These are the columns in table, and are instance specific.
        check_productions_match(valid_actions['l'], [
            'fb:row.row.year', 'fb:row.row.league', 'fb:row.row.avg_attendance',
            'fb:row.row.division', 'fb:row.row.regular_season',
            'fb:row.row.playoffs', 'fb:row.row.open_cup'
        ])
        check_productions_match(valid_actions['@start@'], ['d', 'n', 's'])
        # We merged cells and parts in SEMPRE to strings in this grammar.
        check_productions_match(valid_actions['s'], [
            'fb:cell.2', 'fb:cell.2001', 'fb:cell.2005', 'fb:cell.4th_round',
            'fb:cell.4th_western', 'fb:cell.5th', 'fb:cell.6_028',
            'fb:cell.7_169', 'fb:cell.did_not_qualify', 'fb:cell.quarterfinals',
            'fb:cell.usl_a_league', 'fb:cell.usl_first_division',
            'fb:part.4th', 'fb:part.western', 'fb:part.5th',
            '[<r,<l,s>>, r, l]'
        ])
        check_productions_match(valid_actions['d'], ['[<n,<n,<n,d>>>, n, n, n]'])
        check_productions_match(valid_actions['n'], [
            '-1', '0', '1', '2013', '[<r,<l,n>>, r, l]',
            '[<r,<r,<l,n>>>, r, r, l]', '[<r,n>, r]'
        ])
        check_productions_match(valid_actions['r'], [
            'all_rows', '[<r,<d,<l,r>>>, r, d, l]', '[<r,<l,r>>, r, l]',
            '[<r,<n,<l,r>>>, r, n, l]', '[<r,<s,<l,r>>>, r, s, l]',
            '[<r,r>, r]'
        ])

    def test_world_processes_logical_forms_correctly(self):
        """A parsed logical form should render with mapped (aliased) names."""
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F30(R,string:usl_a_league,C2),C6)"

    def test_world_gets_correct_actions(self):
        """Parsing a logical form should yield the expected action sequence."""
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        expected_sequence = [
            '@start@ -> s', 's -> [<r,<l,s>>, r, l]', '<r,<l,s>> -> select',
            'r -> [<r,<s,<l,r>>>, r, s, l]', '<r,<s,<l,r>>> -> filter_in',
            'r -> all_rows', 's -> fb:cell.usl_a_league',
            'l -> fb:row.row.league', 'l -> fb:row.row.year'
        ]
        assert self.world.get_action_sequence(expression) == expected_sequence

    def test_world_gets_logical_form_from_actions(self):
        """parse -> actions -> logical form should round-trip exactly."""
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        action_sequence = self.world.get_action_sequence(expression)
        reconstructed_logical_form = self.world.get_logical_form(
            action_sequence)
        assert logical_form == reconstructed_logical_form

    def test_world_processes_logical_forms_with_number_correctly(self):
        """Numbers in logical forms should map to ``num:`` prefixed names."""
        logical_form = "(select (filter_number_greater all_rows 2013 fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F10(R,num:2013,C6),C6)"

    def test_world_processes_logical_forms_with_date_correctly(self):
        """Date constructors should map negatives to ``num:~`` names."""
        logical_form = "(select (filter_date_greater all_rows (date 2013 -1 -1) fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F20(R,T0(num:2013,num:~1,num:~1),C6),C6)"

    def _get_world_with_question_tokens(
            self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
        # Helper: build a fresh world over the same fixture table for a new
        # question, since entity linking (and hence the grammar/agenda) is
        # question-specific.
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, tokens)
        world = WikiTablesVariableFreeWorld(table_kg)
        return world

    def test_get_agenda(self):
        """The agenda should reflect question tokens and linked entities."""
        tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2000', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'n -> 2000', 'l -> fb:row.row.year', '<r,<l,r>> -> argmax'
        }
        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'difference', 'in', 'attendance',
                'between', 'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains strings here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            '<r,<r,<l,n>>> -> diff'
        }
        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'total', 'avg.', 'attendance', 'in',
                'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains cells here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> sum'
        }
        tokens = [
            Token(x)
            for x in ['when', 'was', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,r>> -> argmin'
        }
        tokens = [
            Token(x)
            for x in ['what', 'is', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> min'
        }
class TestWikiTablesVariableFreeWorld(AllenNlpTestCase):
    """Tests for ``WikiTablesVariableFreeWorld`` built over a
    ``TableQuestionContext`` read from a CoreNLP-tagged fixture table.

    Two worlds are prepared in ``setUp``: one whose question contains the
    number 2013, and one whose question mentions "usl a league", since entity
    linking (and therefore the grammar) is question-specific.

    NOTE(review): a class with this exact name appears more than once in this
    module; only the last definition survives import — confirm whether this
    file is a concatenation of several revisions.
    """

    def setUp(self):
        super().setUp()
        question_tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']
        ]
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tagged'
        self.table_context = TableQuestionContext.read_from_file(
            self.table_file, question_tokens)
        self.world_with_2013 = WikiTablesVariableFreeWorld(self.table_context)
        usl_league_tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'last', 'year', 'with', 'usl', 'a',
                'league', '?'
            ]
        ]
        self.world_with_usl_a_league = self._get_world_with_question_tokens(
            usl_league_tokens)

    def _get_world_with_question_tokens(
            self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
        """Build a fresh world over the fixture table for ``tokens``."""
        table_context = TableQuestionContext.read_from_file(
            self.table_file, tokens)
        world = WikiTablesVariableFreeWorld(table_context)
        return world

    def test_get_valid_actions_returns_correct_set(self):
        """The grammar should contain exactly the expected types and productions."""
        # This test is long, but worth it. These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.
        valid_actions = self.world_with_2013.get_valid_actions()
        assert set(valid_actions.keys()) == {
            "<r,<g,s>>",
            "<r,<f,<n,r>>>",
            "<r,<c,r>>",
            "<r,<g,r>>",
            "<r,<r,<f,n>>>",
            "<r,<t,<s,r>>>",
            "<n,<n,<n,d>>>",
            "<r,<m,<d,r>>>",
            "<r,<f,n>>",
            "<r,r>",
            "<r,n>",
            "d",
            "n",
            "s",
            "m",
            "t",
            "f",
            "r",
            "@start@",
        }
        check_productions_match(valid_actions['<r,<g,s>>'], ['mode', 'select'])
        check_productions_match(valid_actions['<r,<f,<n,r>>>'], [
            'filter_number_equals', 'filter_number_greater',
            'filter_number_greater_equals', 'filter_number_lesser',
            'filter_number_lesser_equals', 'filter_number_not_equals'
        ])
        check_productions_match(valid_actions['<r,<c,r>>'],
                                ['argmax', 'argmin'])
        check_productions_match(valid_actions['<r,<g,r>>'], ['same_as'])
        check_productions_match(valid_actions['<r,<r,<f,n>>>'], ['diff'])
        check_productions_match(valid_actions['<r,<t,<s,r>>>'],
                                ['filter_in', 'filter_not_in'])
        check_productions_match(valid_actions['<n,<n,<n,d>>>'], ['date'])
        check_productions_match(valid_actions['<r,<m,<d,r>>>'], [
            'filter_date_equals', 'filter_date_greater',
            'filter_date_greater_equals', 'filter_date_lesser',
            'filter_date_lesser_equals', 'filter_date_not_equals'
        ])
        check_productions_match(valid_actions['<r,<f,n>>'],
                                ['average', 'max', 'min', 'sum'])
        check_productions_match(valid_actions['<r,r>'],
                                ['first', 'last', 'next', 'previous'])
        check_productions_match(valid_actions['<r,n>'], ['count'])
        # These are the columns in table, and are instance specific.
        check_productions_match(valid_actions['m'], ['date_column:year'])
        check_productions_match(
            valid_actions['f'],
            ['number_column:avg_attendance', 'number_column:division'])
        check_productions_match(valid_actions['t'], [
            'string_column:league', 'string_column:playoffs',
            'string_column:open_cup', 'string_column:regular_season'
        ])
        check_productions_match(valid_actions['@start@'], ['d', 'n', 's'])
        # The question does not produce any strings. It produces just a number.
        check_productions_match(valid_actions['s'], ['[<r,<g,s>>, r, g]'])
        check_productions_match(valid_actions['d'],
                                ['[<n,<n,<n,d>>>, n, n, n]'])
        check_productions_match(valid_actions['n'], [
            '2013', '-1', '[<r,<f,n>>, r, f]', '[<r,<r,<f,n>>>, r, r, f]',
            '[<r,n>, r]'
        ])
        check_productions_match(valid_actions['r'], [
            'all_rows', '[<r,<m,<d,r>>>, r, m, d]', '[<r,<g,r>>, r, g]',
            '[<r,<c,r>>, r, c]', '[<r,<f,<n,r>>>, r, f, n]',
            '[<r,<t,<s,r>>>, r, t, s]', '[<r,r>, r]'
        ])

    def test_parsing_logical_form_with_string_not_in_question_fails(self):
        """Logical forms with entities not linked to the question must not parse."""
        logical_form_with_usl_a_league = """(select (filter_in all_rows string_column:league usl_a_league) date_column:year)"""
        logical_form_with_2013 = """(select (filter_date_greater all_rows date_column:year (date 2013 -1 -1)) date_column:year)"""
        # Fixed: each parse gets its own ``assertRaises`` context. Previously
        # both calls shared one context, so the second call never executed once
        # the first raised, leaving it silently untested.
        with self.assertRaises(ParsingError):
            self.world_with_2013.parse_logical_form(
                logical_form_with_usl_a_league)
        with self.assertRaises(ParsingError):
            self.world_with_usl_a_league.parse_logical_form(
                logical_form_with_2013)

    def test_world_processes_logical_forms_correctly(self):
        """A parsed logical form should render with aliased global names."""
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(
            logical_form)
        f = types.name_mapper.get_alias
        # Cells (and parts) get mapped to strings.
        # Column names are mapped in local name mapping. For the global names, we can get their
        # aliases from the name mapper.
        assert str(
            expression
        ) == f"{f('select')}({f('filter_in')}({f('all_rows')},C2,string:usl_a_league),C0)"

    def test_world_gets_correct_actions(self):
        """Parsing a logical form should yield the expected action sequence."""
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(
            logical_form)
        expected_sequence = [
            '@start@ -> s', 's -> [<r,<g,s>>, r, m]', '<r,<g,s>> -> select',
            'r -> [<r,<t,<s,r>>>, r, t, s]', '<r,<t,<s,r>>> -> filter_in',
            'r -> all_rows', 't -> string_column:league',
            's -> string:usl_a_league', 'm -> date_column:year'
        ]
        assert self.world_with_usl_a_league.get_action_sequence(
            expression) == expected_sequence

    def test_world_gets_logical_form_from_actions(self):
        """parse -> actions -> logical form should round-trip exactly."""
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(
            logical_form)
        action_sequence = self.world_with_usl_a_league.get_action_sequence(
            expression)
        reconstructed_logical_form = self.world_with_usl_a_league.get_logical_form(
            action_sequence)
        assert logical_form == reconstructed_logical_form

    def test_world_processes_logical_forms_with_number_correctly(self):
        """Numbers in logical forms should map to ``num:`` prefixed names."""
        tokens = [
            Token(x) for x in [
                'when', 'was', 'the', 'attendance', 'higher', 'than', '3000',
                '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        logical_form = """(select (filter_number_greater all_rows number_column:avg_attendance 3000) date_column:year)"""
        expression = world.parse_logical_form(logical_form)
        f = types.name_mapper.get_alias
        # Cells (and parts) get mapped to strings.
        # Column names are mapped in local name mapping. For the global names, we can get their
        # aliases from the name mapper.
        assert str(
            expression
        ) == f"{f('select')}({f('filter_number_greater')}({f('all_rows')},C6,num:3000),C0)"

    def test_world_processes_logical_forms_with_date_correctly(self):
        """Date constructors should map negatives to ``num:~`` names."""
        logical_form = """(select (filter_date_greater all_rows date_column:year (date 2013 -1 -1)) date_column:year)"""
        expression = self.world_with_2013.parse_logical_form(logical_form)
        f = types.name_mapper.get_alias
        # Cells (and parts) get mapped to strings.
        # Column names are mapped in local name mapping. For the global names, we can get their
        # aliases from the name mapper.
        assert str(expression) == \
            f"{f('select')}({f('filter_date_greater')}({f('all_rows')},C0,{f('date')}(num:2013,num:~1,num:~1)),C0)"

    def test_get_agenda(self):
        """The agenda should reflect question tokens and linked entities."""
        tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2000', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'n -> 2000', '<r,r> -> last', 'm -> date_column:year'
        }
        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'difference', 'in', 'attendance',
                'between', 'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # "year" column does not match because "years" occurs in the question.
        assert set(world.get_agenda()) == {
            'n -> 2001', 'n -> 2005', '<r,<r,<f,n>>> -> diff'
        }
        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'total', 'avg.', 'attendance', 'in',
                'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'n -> 2001', 'n -> 2005', '<r,<f,n>> -> sum',
            'f -> number_column:avg_attendance'
        }
        tokens = [
            Token(x)
            for x in ['when', 'was', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            '<r,<c,r>> -> argmin', 'f -> number_column:avg_attendance'
        }
        tokens = [
            Token(x)
            for x in ['what', 'is', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            '<r,<f,n>> -> min', 'f -> number_column:avg_attendance'
        }
        tokens = [
            Token(x)
            for x in ['when', 'did', 'the', 'team', 'not', 'qualify', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'s -> string:qualify'}
        tokens = [
            Token(x) for x in [
                'when', 'was', 'the', 'avg.', 'attendance', 'at', 'least',
                '7000', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            '<r,<f,<n,r>>> -> filter_number_greater_equals',
            'f -> number_column:avg_attendance', 'n -> 7000'
        }
        tokens = [
            Token(x) for x in [
                'when', 'was', 'the', 'avg.', 'attendance', 'more', 'than',
                '7000', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            '<r,<f,<n,r>>> -> filter_number_greater',
            'f -> number_column:avg_attendance', 'n -> 7000'
        }
        tokens = [
            Token(x) for x in [
                'when', 'was', 'the', 'avg.', 'attendance', 'at', 'most',
                '7000', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            '<r,<f,<n,r>>> -> filter_number_lesser_equals',
            'f -> number_column:avg_attendance', 'n -> 7000'
        }
        tokens = [Token(x) for x in ['what', 'was', 'the', 'top', 'year', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(
            world.get_agenda()) == {'<r,r> -> first', 'm -> date_column:year'}
        tokens = [
            Token(x) for x in
            ['what', 'was', 'the', 'year', 'in', 'the', 'bottom', 'row', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(
            world.get_agenda()) == {'<r,r> -> last', 'm -> date_column:year'}
def text_to_instance(
        self,  # type: ignore
        question: str,
        table_lines: List[List[str]],
        target_values: List[str],
        offline_search_output: List[str] = None) -> Instance:
    """
    Reads text inputs and makes an instance. WikitableQuestions dataset provides tables as TSV
    files pre-tagged using CoreNLP, which we use for training.

    Parameters
    ----------
    question : ``str``
        Input question
    table_lines : ``List[List[str]]``
        The table content preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines``
        for the expected format.
    target_values : ``List[str]``
        Stored as-is in a ``MetadataField``.
    offline_search_output : List[str], optional
        List of logical forms, produced by offline search. Not required during test.

    Returns
    -------
    An ``Instance``, or ``None`` when ``offline_search_output`` is given but none of its
    logical forms could be converted to an action sequence (such instances are skipped).
    NOTE(review): the ``-> Instance`` annotation does not reflect the ``None`` path; it
    should be ``Optional[Instance]`` once ``Optional`` is imported at the top of the file.
    """
    # pylint: disable=arguments-differ
    tokenized_question = self._tokenizer.tokenize(question.lower())
    question_field = TextField(tokenized_question,
                               self._question_token_indexers)
    # TODO(pradeep): We'll need a better way to input CoreNLP processed lines.
    table_context = TableQuestionContext.read_from_lines(
        table_lines, tokenized_question)
    target_values_field = MetadataField(target_values)
    world = WikiTablesVariableFreeWorld(table_context)
    world_field = MetadataField(world)
    # Note: Not passing any feature extractors when instantiating the field below. This will make
    # it use all the available extractors.
    table_field = KnowledgeGraphField(
        table_context.get_table_knowledge_graph(),
        tokenized_question,
        self._table_token_indexers,
        tokenizer=self._tokenizer,
        include_in_vocab=self._use_table_for_vocab,
        max_table_tokens=self._max_table_tokens)
    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(' -> ')
        # Instance-specific entities (e.g. column names) are not global rules.
        is_global_rule = not world.is_instance_specific_entity(
            rule_right_side)
        field = ProductionRuleField(production_rule,
                                    is_global_rule=is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)
    fields = {
        'question': question_field,
        'table': table_field,
        'world': world_field,
        'actions': action_field,
        'target_values': target_values_field
    }
    # We'll make each target action sequence a List[IndexField], where the index is into
    # the action list we made above. We need to ignore the type here because mypy doesn't
    # like `action.rule` - it's hard to tell mypy that the ListField is made up of
    # ProductionRuleFields.
    action_map = {
        action.rule: i
        for i, action in enumerate(action_field.field_list)
    }  # type: ignore
    if offline_search_output:
        action_sequence_fields: List[Field] = []
        for logical_form in offline_search_output:
            try:
                expression = world.parse_logical_form(logical_form)
            except ParsingError as error:
                # Expected failure mode for noisy search output: skip quietly.
                logger.debug(
                    f'Parsing error: {error.message}, skipping logical form'
                )
                logger.debug(f'Question was: {question}')
                logger.debug(f'Logical form was: {logical_form}')
                logger.debug(f'Table info was: {table_lines}')
                continue
            except Exception:
                # Fixed: was a bare ``except:`` (flake8 E722). Log the offending
                # logical form so unexpected failures are reproducible, then
                # re-raise.
                logger.error(logical_form)
                raise
            action_sequence = world.get_action_sequence(expression)
            try:
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(
                        IndexField(action_map[production_rule],
                                   action_field))
                action_sequence_fields.append(ListField(index_fields))
            except KeyError as error:
                # The sequence referenced a production that is not in this
                # instance's action list; skip the logical form.
                logger.debug(
                    f'Missing production rule: {error.args}, skipping logical form'
                )
                logger.debug(f'Question was: {question}')
                logger.debug(f'Table info was: {table_lines}')
                logger.debug(f'Logical form was: {logical_form}')
                continue
            # Cap the amount of supervision kept per instance.
            if len(action_sequence_fields
                   ) >= self._max_offline_logical_forms:
                break
        if not action_sequence_fields:
            # This is not great, but we're only doing it when we're passed logical form
            # supervision, so we're expecting labeled logical forms, but we can't actually
            # produce the logical forms. We should skip this instance. Note that this affects
            # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
            # full test data.
            return None
        fields['target_action_sequences'] = ListField(
            action_sequence_fields)
    if self._output_agendas:
        agenda_index_fields: List[Field] = []
        for agenda_string in world.get_agenda():
            # NOTE(review): an agenda string absent from ``action_map`` would
            # raise ``KeyError`` here (unlike the guarded loop above) —
            # presumably ``get_agenda`` only returns valid productions; confirm.
            agenda_index_fields.append(
                IndexField(action_map[agenda_string], action_field))
        if not agenda_index_fields:
            # Use a single -1 index as a padding-style placeholder when the
            # agenda is empty.
            agenda_index_fields = [IndexField(-1, action_field)]
        fields['agenda'] = ListField(agenda_index_fields)
    return Instance(fields)
class TestWikiTablesVariableFreeWorld(AllenNlpTestCase):
    """Extended tests for ``WikiTablesVariableFreeWorld`` over a
    ``TableQuestionContext``, including tables that lack number and/or date
    columns, with global-name aliases resolved through per-column-type name
    mappers via ``_get_alias``.

    NOTE(review): a class with this exact name appears more than once in this
    module; only the last definition survives import — confirm whether this
    file is a concatenation of several revisions.
    """

    def setUp(self):
        super().setUp()
        question_tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year',
                                              '2013', '?']]
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tagged'
        self.table_context = TableQuestionContext.read_from_file(self.table_file,
                                                                 question_tokens)
        self.world_with_2013 = WikiTablesVariableFreeWorld(self.table_context)
        usl_league_tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year', 'with',
                                                'usl', 'a', 'league', '?']]
        self.world_with_usl_a_league = self._get_world_with_question_tokens(usl_league_tokens)

    def _get_world_with_question_tokens(self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
        """Build a fresh world over the fixture table for ``tokens``."""
        table_context = TableQuestionContext.read_from_file(self.table_file, tokens)
        world = WikiTablesVariableFreeWorld(table_context)
        return world

    def test_get_valid_actions_returns_correct_set(self):
        """The grammar should contain exactly the expected types and productions."""
        # This test is long, but worth it. These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.
        valid_actions = self.world_with_2013.get_valid_actions()
        assert set(valid_actions.keys()) == {
                "<r,<g,s>>",
                "<r,<f,<n,r>>>",
                "<r,<c,r>>",
                "<r,<g,r>>",
                "<r,<r,<f,n>>>",
                "<r,<t,<s,r>>>",
                "<n,<n,<n,d>>>",
                "<r,<m,<d,r>>>",
                "<r,<f,n>>",
                "<r,r>",
                "<r,n>",
                "d",
                "n",
                "s",
                "m",
                "t",
                "f",
                "r",
                "@start@",
                }
        check_productions_match(valid_actions['<r,<g,s>>'], ['mode', 'select'])
        check_productions_match(valid_actions['<r,<f,<n,r>>>'],
                                ['filter_number_equals', 'filter_number_greater',
                                 'filter_number_greater_equals', 'filter_number_lesser',
                                 'filter_number_lesser_equals', 'filter_number_not_equals'])
        check_productions_match(valid_actions['<r,<c,r>>'], ['argmax', 'argmin'])
        check_productions_match(valid_actions['<r,<g,r>>'], ['same_as'])
        check_productions_match(valid_actions['<r,<r,<f,n>>>'], ['diff'])
        check_productions_match(valid_actions['<r,<t,<s,r>>>'], ['filter_in', 'filter_not_in'])
        check_productions_match(valid_actions['<n,<n,<n,d>>>'], ['date'])
        check_productions_match(valid_actions['<r,<m,<d,r>>>'],
                                ['filter_date_equals', 'filter_date_greater',
                                 'filter_date_greater_equals', 'filter_date_lesser',
                                 'filter_date_lesser_equals', 'filter_date_not_equals'])
        check_productions_match(valid_actions['<r,<f,n>>'],
                                ['average', 'max', 'min', 'sum'])
        check_productions_match(valid_actions['<r,r>'],
                                ['first', 'last', 'next', 'previous'])
        check_productions_match(valid_actions['<r,n>'], ['count'])
        # These are the columns in table, and are instance specific.
        check_productions_match(valid_actions['m'], ['date_column:year'])
        check_productions_match(valid_actions['f'],
                                ['number_column:avg_attendance', 'number_column:division'])
        check_productions_match(valid_actions['t'],
                                ['string_column:league', 'string_column:playoffs',
                                 'string_column:open_cup', 'string_column:regular_season'])
        check_productions_match(valid_actions['@start@'], ['d', 'n', 's'])
        # The question does not produce any strings. It produces just a number.
        check_productions_match(valid_actions['s'],
                                ['[<r,<g,s>>, r, m]', '[<r,<g,s>>, r, f]', '[<r,<g,s>>, r, t]'])
        check_productions_match(valid_actions['d'], ['[<n,<n,<n,d>>>, n, n, n]'])
        check_productions_match(valid_actions['n'],
                                ['2013', '-1', '[<r,<f,n>>, r, f]',
                                 '[<r,<r,<f,n>>>, r, r, f]', '[<r,n>, r]'])
        check_productions_match(valid_actions['r'],
                                ['all_rows',
                                 '[<r,<m,<d,r>>>, r, m, d]',
                                 '[<r,<g,r>>, r, m]',
                                 '[<r,<g,r>>, r, f]',
                                 '[<r,<g,r>>, r, t]',
                                 '[<r,<c,r>>, r, m]',
                                 '[<r,<c,r>>, r, f]',
                                 '[<r,<f,<n,r>>>, r, f, n]',
                                 '[<r,<t,<s,r>>>, r, t, s]',
                                 '[<r,r>, r]'])

    def test_get_valid_actions_in_world_without_number_columns(self):
        """Number-specific types/productions must vanish for number-less tables."""
        question_tokens = [Token(x) for x in ['what', 'was', 'the', 'first', 'title', '?']]
        table_file = self.FIXTURES_ROOT / 'data' / 'corenlp_processed_tables' / 'TEST-6.table'
        table_context = TableQuestionContext.read_from_file(table_file, question_tokens)
        # The table does not have a number column.
        assert "number" not in table_context.column_types.values()
        world = WikiTablesVariableFreeWorld(table_context)
        actions = world.get_valid_actions()
        assert set(actions.keys()) == {
                "<r,<g,s>>",
                "<r,<c,r>>",
                "<r,<g,r>>",
                "<r,<t,<s,r>>>",
                "<n,<n,<n,d>>>",
                "<r,<m,<d,r>>>",
                "<r,r>",
                "<r,n>",
                "d",
                "n",
                "s",
                "m",
                "t",
                "r",
                "@start@",
                }
        assert set([str(type_) for type_ in world.get_basic_types()]) == {'n', 'd', 's', 'r',
                                                                          't', 'm', 'g', 'c'}
        check_productions_match(actions['s'],
                                ['[<r,<g,s>>, r, m]', '[<r,<g,s>>, r, t]'])

    def test_get_valid_actions_in_world_without_date_columns(self):
        """Date-specific types/productions must vanish for date-less tables."""
        question_tokens = [Token(x) for x in ['what', 'was', 'the', 'first', 'title', '?']]
        table_file = self.FIXTURES_ROOT / 'data' / 'corenlp_processed_tables' / 'TEST-4.table'
        table_context = TableQuestionContext.read_from_file(table_file, question_tokens)
        # The table does not have a date column.
        assert "date" not in table_context.column_types.values()
        world = WikiTablesVariableFreeWorld(table_context)
        actions = world.get_valid_actions()
        assert set(actions.keys()) == {
                "<r,<g,s>>",
                "<r,<f,<n,r>>>",
                "<r,<c,r>>",
                "<r,<g,r>>",
                "<r,<r,<f,n>>>",
                "<r,<t,<s,r>>>",
                "<n,<n,<n,d>>>",
                "<r,<f,n>>",
                "<r,r>",
                "<r,n>",
                "d",
                "n",
                "s",
                "t",
                "f",
                "r",
                "@start@",
                }
        assert set([str(type_) for type_ in world.get_basic_types()]) == {'n', 'd', 's', 'r',
                                                                          't', 'f', 'g', 'c'}
        check_productions_match(actions['s'],
                                ['[<r,<g,s>>, r, f]', '[<r,<g,s>>, r, t]'])

    def test_get_valid_actions_in_world_without_comparable_columns(self):
        """With neither number nor date columns, comparison machinery disappears."""
        question_tokens = [Token(x) for x in ['what', 'was', 'the', 'first', 'title', '?']]
        table_file = self.FIXTURES_ROOT / 'data' / 'corenlp_processed_tables' / 'TEST-1.table'
        table_context = TableQuestionContext.read_from_file(table_file, question_tokens)
        # The table does not have date or number columns.
        assert "date" not in table_context.column_types.values()
        assert "number" not in table_context.column_types.values()
        world = WikiTablesVariableFreeWorld(table_context)
        actions = world.get_valid_actions()
        assert set(actions.keys()) == {
                "<r,<g,s>>",
                "<r,<g,r>>",
                "<r,<t,<s,r>>>",
                "<n,<n,<n,d>>>",
                "<r,r>",
                "<r,n>",
                "d",
                "n",
                "s",
                "t",
                "r",
                "@start@",
                }
        assert set([str(type_) for type_ in world.get_basic_types()]) == {'n', 'd', 's', 'r',
                                                                          't', 'g'}

    def test_parsing_logical_form_with_string_not_in_question_fails(self):
        """Logical forms with entities not linked to the question must not parse."""
        logical_form_with_usl_a_league = """(select (filter_in all_rows string_column:league usl_a_league) date_column:year)"""
        logical_form_with_2013 = """(select (filter_date_greater all_rows date_column:year (date 2013 -1 -1)) date_column:year)"""
        # Fixed: each parse gets its own ``assertRaises`` context. Previously
        # both calls shared one context, so the second call never executed once
        # the first raised, leaving it silently untested.
        with self.assertRaises(ParsingError):
            self.world_with_2013.parse_logical_form(logical_form_with_usl_a_league)
        with self.assertRaises(ParsingError):
            self.world_with_usl_a_league.parse_logical_form(logical_form_with_2013)

    @staticmethod
    def _get_alias(types_, name) -> str:
        """Resolve ``name`` to its alias via whichever name mapper declares it.

        Falls through generic, string-column, number-column and date-column
        mappers, defaulting to the comparable-column mapper.
        """
        if name in types_.generic_name_mapper.common_name_mapping:
            return types_.generic_name_mapper.get_alias(name)
        elif name in types_.string_column_name_mapper.common_name_mapping:
            return types_.string_column_name_mapper.get_alias(name)
        elif name in types_.number_column_name_mapper.common_name_mapping:
            return types_.number_column_name_mapper.get_alias(name)
        elif name in types_.date_column_name_mapper.common_name_mapping:
            return types_.date_column_name_mapper.get_alias(name)
        else:
            return types_.comparable_column_name_mapper.get_alias(name)

    def test_world_processes_logical_forms_correctly(self):
        """A parsed logical form should render with aliased global names."""
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(logical_form)
        f = partial(self._get_alias, types)
        # Cells (and parts) get mapped to strings.
        # Column names are mapped in local name mapping. For the global names, we can get their
        # aliases from the name mapper.
        assert str(expression) == f"{f('select')}({f('filter_in')}({f('all_rows')},C2,string:usl_a_league),C0)"

    def test_world_gets_correct_actions(self):
        """Parsing a logical form should yield the expected action sequence."""
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(logical_form)
        expected_sequence = ['@start@ -> s', 's -> [<r,<g,s>>, r, m]', '<r,<g,s>> -> select',
                             'r -> [<r,<t,<s,r>>>, r, t, s]', '<r,<t,<s,r>>> -> filter_in',
                             'r -> all_rows', 't -> string_column:league',
                             's -> string:usl_a_league', 'm -> date_column:year']
        assert self.world_with_usl_a_league.get_action_sequence(expression) == expected_sequence

    def test_world_gets_logical_form_from_actions(self):
        """parse -> actions -> logical form should round-trip exactly."""
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(logical_form)
        action_sequence = self.world_with_usl_a_league.get_action_sequence(expression)
        reconstructed_logical_form = self.world_with_usl_a_league.get_logical_form(action_sequence)
        assert logical_form == reconstructed_logical_form

    def test_world_processes_logical_forms_with_number_correctly(self):
        """Numbers in logical forms should map to ``num:`` prefixed names."""
        tokens = [Token(x) for x in ['when', 'was', 'the', 'attendance', 'higher', 'than',
                                     '3000', '?']]
        world = self._get_world_with_question_tokens(tokens)
        logical_form = """(select (filter_number_greater all_rows number_column:avg_attendance 3000) date_column:year)"""
        expression = world.parse_logical_form(logical_form)
        f = partial(self._get_alias, types)
        # Cells (and parts) get mapped to strings.
        # Column names are mapped in local name mapping. For the global names, we can get their
        # aliases from the name mapper.
        assert str(expression) == f"{f('select')}({f('filter_number_greater')}({f('all_rows')},C6,num:3000),C0)"

    def test_world_processes_logical_forms_with_date_correctly(self):
        """Date constructors should map negatives to ``num:~`` names."""
        logical_form = """(select (filter_date_greater all_rows date_column:year (date 2013 -1 -1)) date_column:year)"""
        expression = self.world_with_2013.parse_logical_form(logical_form)
        f = partial(self._get_alias, types)
        # Cells (and parts) get mapped to strings.
        # Column names are mapped in local name mapping. For the global names, we can get their
        # aliases from the name mapper.
        assert str(expression) == \
            f"{f('select')}({f('filter_date_greater')}({f('all_rows')},C0,{f('date')}(num:2013,num:~1,num:~1)),C0)"

    def test_get_agenda(self):
        """The agenda should reflect question tokens and linked entities."""
        tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year', '2000', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'n -> 2000', '<r,r> -> last', 'm -> date_column:year'}
        tokens = [Token(x) for x in ['what', 'was', 'the', 'difference', 'in', 'attendance',
                                     'between', 'years', '2001', 'and', '2005', '?']]
        world = self._get_world_with_question_tokens(tokens)
        # "year" column does not match because "years" occurs in the question.
        assert set(world.get_agenda()) == {'n -> 2001', 'n -> 2005', '<r,<r,<f,n>>> -> diff'}
        tokens = [Token(x) for x in ['what', 'was', 'the', 'total', 'avg.', 'attendance', 'in',
                                     'years', '2001', 'and', '2005', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'n -> 2001', 'n -> 2005', '<r,<f,n>> -> sum',
                                           'f -> number_column:avg_attendance'}
        tokens = [Token(x) for x in ['when', 'was', 'the', 'least', 'avg.', 'attendance', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,<c,r>> -> argmin',
                                           'f -> number_column:avg_attendance'}
        tokens = [Token(x) for x in ['what', 'is', 'the', 'least', 'avg.', 'attendance', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,<f,n>> -> min',
                                           'f -> number_column:avg_attendance'}
        tokens = [Token(x) for x in ['when', 'did', 'the', 'team', 'not', 'qualify', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'s -> string:qualify'}
        tokens = [Token(x) for x in ['when', 'was', 'the', 'avg.', 'attendance', 'at', 'least',
                                     '7000', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,<f,<n,r>>> -> filter_number_greater_equals',
                                           'f -> number_column:avg_attendance', 'n -> 7000'}
        tokens = [Token(x) for x in ['when', 'was', 'the', 'avg.', 'attendance', 'more', 'than',
                                     '7000', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,<f,<n,r>>> -> filter_number_greater',
                                           'f -> number_column:avg_attendance', 'n -> 7000'}
        tokens = [Token(x) for x in ['when', 'was', 'the', 'avg.', 'attendance', 'at', 'most',
                                     '7000', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,<f,<n,r>>> -> filter_number_lesser_equals',
                                           'f -> number_column:avg_attendance', 'n -> 7000'}
        tokens = [Token(x) for x in ['what', 'was', 'the', 'top', 'year', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,r> -> first', 'm -> date_column:year'}
        tokens = [Token(x) for x in ['what', 'was', 'the', 'year', 'in', 'the', 'bottom', 'row',
                                     '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,r> -> last', 'm -> date_column:year'}