def search(tables_directory: str,
           input_examples_file: str,
           output_file: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool) -> None:
    data = [wikitables_util.parse_example_line(example_line)
            for example_line in open(input_examples_file)]
    tokenizer = WordTokenizer()
    with open(output_file, "w") as output_file_pointer:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            # pylint: disable=protected-access
            target_list = [TableQuestionContext._normalize_string(value)
                           for value in instance_data["target_values"]]
            try:
                target_value_list = evaluator.to_value_list(target_list)
            except Exception:
                # Log the offending target list, then re-raise instead of retrying the
                # same failing call as the original bare ``except`` did.
                print(target_list)
                raise
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            correct_logical_forms = []
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                agenda = world.get_agenda()
                print(f"Agenda: {agenda}", file=output_file_pointer)
                all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                         max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            for logical_form in all_logical_forms:
                try:
                    denotation = world.execute(logical_form)
                except ExecutionError:
                    print(f"Failed to execute: {logical_form}", file=sys.stderr)
                    continue
                if isinstance(denotation, list):
                    denotation_list = [str(denotation_item) for denotation_item in denotation]
                else:
                    # For numbers and dates
                    denotation_list = [str(denotation)]
                denotation_value_list = evaluator.to_value_list(denotation_list)
                if evaluator.check_denotation(target_value_list, denotation_value_list):
                    correct_logical_forms.append(logical_form)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
def test_get_nonterminal_productions_in_world_without_comparable_columns(self):
    question_tokens = [Token(x) for x in ['what', 'was', 'the', 'first', 'title', '?']]
    table_file = self.FIXTURES_ROOT / 'data' / 'corenlp_processed_tables' / 'TEST-1.table'
    table_context = TableQuestionContext.read_from_file(table_file, question_tokens)
    # The table does not have date or number columns.
    assert "date" not in table_context.column_types.values()
    assert "number" not in table_context.column_types.values()
    world = WikiTablesLanguage(table_context)
    actions = world.get_nonterminal_productions()
    assert set(actions.keys()) == {
        "<List[Row],Column:List[str]>",
        "<List[Row],Column:List[Row]>",
        "<List[Row],StringColumn,str:List[Row]>",
        "<Number,Number,Number:Date>",
        "<List[Row]:List[Row]>",
        "<List[Row]:Number>",
        "Date",
        "Number",
        "List[str]",
        "Column",
        "StringColumn",
        "List[Row]",
        "@start@",
    }
def setUp(self):
    self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
    self.utterance = self.tokenizer.tokenize("where is mersin?")
    self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}
    table_file = self.FIXTURES_ROOT / "data" / "wikitables" / "tables" / "341.tagged"
    self.graph = TableQuestionContext.read_from_file(table_file,
                                                     self.utterance).get_table_knowledge_graph()
    self.vocab = Vocabulary()
    self.name_index = self.vocab.add_token_to_namespace("name", namespace='tokens')
    self.in_index = self.vocab.add_token_to_namespace("in", namespace='tokens')
    self.english_index = self.vocab.add_token_to_namespace("english", namespace='tokens')
    self.location_index = self.vocab.add_token_to_namespace("location", namespace='tokens')
    self.mersin_index = self.vocab.add_token_to_namespace("mersin", namespace='tokens')
    self.oov_index = self.vocab.get_token_index('random OOV string', namespace='tokens')
    self.edirne_index = self.oov_index
    self.field = KnowledgeGraphField(self.graph, self.utterance,
                                     self.token_indexers, self.tokenizer)
    super(KnowledgeGraphFieldTest, self).setUp()
def test_knowledge_graph_has_correct_neighbors(self):
    question = "when was the attendance greater than 5000?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/wikitables/sample_table.tagged'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    knowledge_graph = table_question_context.get_table_knowledge_graph()
    neighbors = knowledge_graph.neighbors
    # '5000' is neighbors with number and date columns. '-1' is in entities because there is a
    # date column, which is its only neighbor.
    assert set(neighbors.keys()) == {'date_column:year', 'number_column:division',
                                     'string_column:league', 'string_column:regular_season',
                                     'string_column:playoffs', 'string_column:open_cup',
                                     'number_column:avg_attendance', '5000', '-1'}
    assert set(neighbors['date_column:year']) == {'5000', '-1'}
    assert neighbors['number_column:division'] == ['5000']
    assert neighbors['string_column:league'] == []
    assert neighbors['string_column:regular_season'] == []
    assert neighbors['string_column:playoffs'] == []
    assert neighbors['string_column:open_cup'] == []
    assert neighbors['number_column:avg_attendance'] == ['5000']
    assert set(neighbors['5000']) == {'date_column:year', 'number_column:division',
                                      'number_column:avg_attendance'}
    assert neighbors['-1'] == ['date_column:year']
def test_get_knowledge_graph(self):
    question = "other than m1 how many notations have 1 in them?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-11.table"
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    knowledge_graph = table_question_context.get_table_knowledge_graph()
    entities = knowledge_graph.entities
    # -1 is not in entities because there are no date columns in the table.
    assert sorted(entities) == ['1', 'number_column:position', 'string:m1',
                                'string_column:mnemonic', 'string_column:notation',
                                'string_column:short_name', 'string_column:swara']
    neighbors = knowledge_graph.neighbors
    # Each number extracted from the question will have all number and date columns as
    # neighbors. Each string entity extracted from the question will only have the
    # corresponding column as the neighbor.
    assert neighbors == {'1': ['number_column:position'],
                         'string_column:mnemonic': [],
                         'string_column:short_name': [],
                         'string_column:swara': [],
                         'number_column:position': ['1'],
                         'string:m1': ['string_column:notation'],
                         'string_column:notation': ['string:m1']}
    entity_text = knowledge_graph.entity_text
    assert entity_text == {'1': '1',
                           'string:m1': 'm1',
                           'string_column:notation': 'notation',
                           'string_column:mnemonic': 'mnemonic',
                           'string_column:short_name': 'short name',
                           'string_column:swara': 'swara',
                           'number_column:position': 'position'}
def test_date_extraction(self):
    question = "how many laps did matt kenset complete on february 26, 2006."
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-8.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    _, number_entities = table_question_context.get_entities_from_question()
    assert number_entities == [("2", 8), ("26", 9), ("2006", 11)]
def evaluate_logical_form(self,
                          logical_form: str,
                          target_value: List[str],
                          target_canon: List[str]) -> bool:
    """
    Taken from Chen's script
    """
    target_value_strings = tsv_unescape_list(target_value)
    normalized_target_value_strings = [TableQuestionContext.normalize_string(value)
                                       for value in target_value_strings]
    canon_value_strings = tsv_unescape_list(target_canon)
    target_value_list = to_value_list(normalized_target_value_strings, canon_value_strings)
    try:
        denotation = self.execute(logical_form)
    except ExecutionError:
        logger.warning(f'Failed to execute: {logical_form}')
        return False
    except Exception as ex:
        err_template = "Exception of type {0} occurred. Arguments:\n{1!r}"
        message = err_template.format(type(ex).__name__, ex.args)
        logger.warning(f'{message}')
        # Without this return, ``denotation`` would be unbound below.
        return False
    if isinstance(denotation, list):
        denotation_list = [str(denotation_item) for denotation_item in denotation]
    else:
        denotation_list = [str(denotation)]
    denotation_value_list = to_value_list(denotation_list)
    return check_denotation(target_value_list, denotation_value_list)
def setUp(self):
    super().setUp()
    # Adding a bunch of random tokens in here so we get them as constants in the language.
    question_tokens = [Token(x) for x in ["what", "was", "the", "last", "year", "2013", "?",
                                          "quarterfinals", "a_league", "2010", "8000",
                                          "did_not_qualify", "2001", "2", "23", "2005", "1",
                                          "2002", "usl_a_league", "usl_first_division"]]
    self.table_file = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_table.tagged"
    self.table_context = TableQuestionContext.read_from_file(self.table_file, question_tokens)
    self.language = WikiTablesLanguage(self.table_context)
def test_rank_number_extraction(self):
    question = "what was the first tamil-language film in 1943?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-1.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    _, numbers = table_question_context.get_entities_from_question()
    assert numbers == [("1", 3), ("1943", 9)]
def test_get_valid_actions_in_world_without_date_columns(self):
    question_tokens = [Token(x) for x in ['what', 'was', 'the', 'first', 'title', '?']]
    table_file = self.FIXTURES_ROOT / 'data' / 'corenlp_processed_tables' / 'TEST-4.table'
    table_context = TableQuestionContext.read_from_file(table_file, question_tokens)
    # The table does not have a date column.
    assert "date" not in table_context.column_types.values()
    world = WikiTablesVariableFreeWorld(table_context)
    actions = world.get_valid_actions()
    assert set(actions.keys()) == {
        "<r,<g,s>>",
        "<r,<f,<n,r>>>",
        "<r,<c,r>>",
        "<r,<g,r>>",
        "<r,<r,<f,n>>>",
        "<r,<t,<s,r>>>",
        "<n,<n,<n,d>>>",
        "<r,<f,n>>",
        "<r,r>",
        "<r,n>",
        "d",
        "n",
        "s",
        "t",
        "f",
        "r",
        "@start@",
    }
    assert set([str(type_) for type_ in world.get_basic_types()]) == {'n', 'd', 's', 'r',
                                                                      't', 'f', 'g', 'c'}
    check_productions_match(actions['s'],
                            ['[<r,<g,s>>, r, f]',
                             '[<r,<g,s>>, r, t]'])
def test_date_column_type_extraction_1(self):
    question = "how many were elected?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-5.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    predicted_types = table_question_context.column_types
    assert predicted_types["first_elected"] == "date"
def test_table_data(self):
    question = "what was the attendance when usl a league played?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/wikitables/sample_table.tagged'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    # A -1 in a Date field means that field (here, month and day) was unspecified in the cell.
    assert table_question_context.table_data == [{'date_column:year': Date(2001, -1, -1),
                                                  'number_column:year': 2001.0,
                                                  'string_column:year': '2001',
                                                  'number_column:division': 2.0,
                                                  'string_column:division': '2',
                                                  'string_column:league': 'usl_a_league',
                                                  'string_column:regular_season': '4th_western',
                                                  'number_column:regular_season': 4.0,
                                                  'string_column:playoffs': 'quarterfinals',
                                                  'string_column:open_cup': 'did_not_qualify',
                                                  'number_column:open_cup': None,
                                                  'number_column:avg_attendance': 7169.0,
                                                  'string_column:avg_attendance': '7_169'},
                                                 {'date_column:year': Date(2005, -1, -1),
                                                  'number_column:year': 2005.0,
                                                  'string_column:year': '2005',
                                                  'number_column:division': 2.0,
                                                  'string_column:division': '2',
                                                  'string_column:league': 'usl_first_division',
                                                  'string_column:regular_season': '5th',
                                                  'number_column:regular_season': 5.0,
                                                  'string_column:playoffs': 'quarterfinals',
                                                  'string_column:open_cup': '4th_round',
                                                  'number_column:open_cup': 4.0,
                                                  'number_column:avg_attendance': 6028.0,
                                                  'string_column:avg_attendance': '6_028'}]
def test_date_column_type_extraction_1(self):
    question = "how many were elected?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-5.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    column_names = table_question_context.column_names
    assert "date_column:first_elected" in column_names
def evaluate_logical_form(self, logical_form: str, target_list: List[str]) -> bool:
    """
    Takes a logical form, and the list of target values as strings from the original lisp
    string, and returns True iff the logical form executes to the target list.
    """
    normalized_target_list = [TableQuestionContext.normalize_string(value)
                              for value in target_list]
    target_value_list = evaluator.to_value_list(normalized_target_list)
    try:
        denotation = self.execute(logical_form)
    except ExecutionError:
        logger.warning(f'Failed to execute: {logical_form}')
        return False
    if isinstance(denotation, list):
        denotation_list = [str(denotation_item) for denotation_item in denotation]
    else:
        denotation_list = [str(denotation)]
    denotation_value_list = evaluator.to_value_list(denotation_list)
    return evaluator.check_denotation(target_value_list, denotation_value_list)
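# A minimal standalone sketch of the value comparison that evaluate_logical_form relies on.
# It assumes the official WikiTableQuestions evaluator is importable as below; the import
# path and the example values are illustrative assumptions, not part of the original code.
from allennlp.tools import wikitables_evaluator as evaluator

target_value_list = evaluator.to_value_list(["2005"])
denotation_value_list = evaluator.to_value_list(["2005"])
# check_denotation returns True iff the predicted values match the targets under the
# official evaluation rules, after normalization.
assert evaluator.check_denotation(target_value_list, denotation_value_list)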
def test_number_and_entity_extraction(self):
    question = "other than m1 how many notations have 1 in them?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-11.table"
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    string_entities, number_entities = table_question_context.get_entities_from_question()
    assert string_entities == [("string:m1", "string_column:notation")]
    assert number_entities == [("1", 2), ("1", 7)]
def test_number_extraction(self):
    question = """how many players on the 191617 illinois fighting illini men's
                  basketball team had more than 100 points scored?"""
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-7.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    _, number_entities = table_question_context.get_entities_from_question()
    assert number_entities == [("191617", 5), ("100", 16)]
def test_date_extraction_2(self):
    question = """how many different players scored for the san jose earthquakes
                  during their 1979 home opener against the timbers?"""
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-6.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    _, number_entities = table_question_context.get_entities_from_question()
    assert number_entities == [("1979", 12)]
def test_multiword_entity_extraction(self):
    question = "was the positioning better the year of the france venue or the year of the south korea venue?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-3.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    entities, _ = table_question_context.get_entities_from_question()
    assert entities == [("string:france", "string_column:venue"),
                        ("string:south_korea", "string_column:venue")]
def test_multiword_entity_extraction(self):
    question = "was the positioning better the year of the france venue or the year of the south korea venue?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-3.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    entities, _ = table_question_context.get_entities_from_question()
    assert entities == ["france", "south_korea"]
def search(
    tables_directory: str,
    data: JsonDict,
    output_path: str,
    max_path_length: int,
    max_num_logical_forms: int,
    use_agenda: bool,
    output_separate_files: bool,
    conservative_agenda: bool,
) -> None:
    print(f"Starting search with {len(data)} instances", file=sys.stderr)
    language_logger = logging.getLogger("allennlp.semparse.domain_languages.wikitables_language")
    language_logger.setLevel(logging.ERROR)
    tokenizer = SpacyTokenizer()
    if output_separate_files and not os.path.exists(output_path):
        os.makedirs(output_path)
    if not output_separate_files:
        output_file_pointer = open(output_path, "w")
    for instance_data in data:
        utterance = instance_data["question"]
        question_id = instance_data["id"]
        if utterance.startswith('"') and utterance.endswith('"'):
            utterance = utterance[1:-1]
        # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
        table_file = instance_data["table_filename"].replace("csv", "tagged")
        target_list = instance_data["target_values"]
        tokenized_question = tokenizer.tokenize(utterance)
        table_file = f"{tables_directory}/{table_file}"
        context = TableQuestionContext.read_from_file(table_file, tokenized_question)
        world = WikiTablesLanguage(context)
        walker = ActionSpaceWalker(world, max_path_length=max_path_length)
        correct_logical_forms = []
        if use_agenda:
            agenda = world.get_agenda(conservative=conservative_agenda)
            allow_partial_match = not conservative_agenda
            all_logical_forms = walker.get_logical_forms_with_agenda(
                agenda=agenda, max_num_logical_forms=10000, allow_partial_match=allow_partial_match
            )
        else:
            all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
        for logical_form in all_logical_forms:
            if world.evaluate_logical_form(logical_form, target_list):
                correct_logical_forms.append(logical_form)
        if output_separate_files and correct_logical_forms:
            with gzip.open(f"{output_path}/{question_id}.gz", "wt") as output_file_pointer:
                for logical_form in correct_logical_forms:
                    print(logical_form, file=output_file_pointer)
        elif not output_separate_files:
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                print(f"Agenda: {agenda}", file=output_file_pointer)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
    if not output_separate_files:
        output_file_pointer.close()
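# Hypothetical driver for the search function above. Every value here (the example record,
# paths, and limits) is an illustrative assumption, not part of the original script.
if __name__ == "__main__":
    examples = [{
        "id": "nt-0",
        "question": '"what was the last year?"',
        "table_filename": "csv/204-csv/590.csv",
        "target_values": ["2005"],
    }]
    search(tables_directory="/path/to/WikiTableQuestions",
           data=examples,
           output_path="searched_logical_forms.txt",
           max_path_length=10,
           max_num_logical_forms=100,
           use_agenda=False,
           output_separate_files=False,
           conservative_agenda=False)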
def test_number_comparison_works(self):
    # TableQuestionContext normalizes all strings according to some rules. We want to ensure
    # that the original numerical values of number cells are being correctly processed here.
    tokens = WordTokenizer().tokenize("when was the attendance the highest?")
    tagged_file = self.FIXTURES_ROOT / "data" / "corenlp_processed_tables" / "TEST-2.table"
    context = TableQuestionContext.read_from_file(tagged_file, tokens)
    executor = WikiTablesVariableFreeExecutor(context.table_data)
    result = executor.execute("(select (argmax all_rows number_column:attendance) date_column:date)")
    assert result == ["november_10"]
def test_number_and_entity_extraction(self):
    question = "other than m1 how many notations have 1 in them?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-11.table"
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    string_entities, number_entities = table_question_context.get_entities_from_question()
    assert string_entities == [("string:m1", ["string_column:notation"]),
                               ("string:1", ["string_column:position"])]
    assert number_entities == [("1", 2), ("1", 7)]
def test_null_extraction(self):
    question = "on what date did the eagles score the least points?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-2.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    entities, numbers = table_question_context.get_entities_from_question()
    # "Eagles" does not appear in the table.
    assert entities == []
    assert numbers == []
def setUp(self):
    super().setUp()
    question_tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']]
    self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tagged'
    self.table_context = TableQuestionContext.read_from_file(self.table_file, question_tokens)
    self.world_with_2013 = WikiTablesVariableFreeWorld(self.table_context)
    usl_league_tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year', 'with',
                                            'usl', 'a', 'league', '?']]
    self.world_with_usl_a_league = self._get_world_with_question_tokens(usl_league_tokens)
def test_date_column_type_extraction_2(self):
    question = "how many were elected?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-9.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    predicted_types = table_question_context.column_types
    assert predicted_types["date_of_appointment"] == "date"
    assert predicted_types["date_of_election"] == "date"
def test_string_column_types_extraction(self):
    question = "how many were elected?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-10.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    predicted_types = table_question_context.column_types
    assert predicted_types["birthplace"] == "string"
    assert predicted_types["advocate"] == "string"
    assert predicted_types["notability"] == "string"
    assert predicted_types["name"] == "string"
def test_string_column_types_extraction(self):
    question = "how many were elected?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-10.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    column_names = table_question_context.column_names
    assert "string_column:birthplace" in column_names
    assert "string_column:advocate" in column_names
    assert "string_column:notability" in column_names
    assert "string_column:name" in column_names
def search(tables_directory: str,
           input_examples_file: str,
           output_path: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool,
           output_separate_files: bool) -> None:
    data = [wikitables_util.parse_example_line(example_line)
            for example_line in open(input_examples_file)]
    tokenizer = WordTokenizer()
    if output_separate_files and not os.path.exists(output_path):
        os.makedirs(output_path)
    if not output_separate_files:
        output_file_pointer = open(output_path, "w")
    for instance_data in data:
        utterance = instance_data["question"]
        question_id = instance_data["id"]
        if utterance.startswith('"') and utterance.endswith('"'):
            utterance = utterance[1:-1]
        # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
        table_file = instance_data["table_filename"].replace("csv", "tagged")
        target_list = instance_data["target_values"]
        tokenized_question = tokenizer.tokenize(utterance)
        table_file = f"{tables_directory}/{table_file}"
        context = TableQuestionContext.read_from_file(table_file, tokenized_question)
        world = WikiTablesVariableFreeWorld(context)
        walker = ActionSpaceWalker(world, max_path_length=max_path_length)
        correct_logical_forms = []
        if use_agenda:
            agenda = world.get_agenda()
            all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                     max_num_logical_forms=10000,
                                                                     allow_partial_match=True)
        else:
            all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
        for logical_form in all_logical_forms:
            if world.evaluate_logical_form(logical_form, target_list):
                correct_logical_forms.append(logical_form)
        if output_separate_files and correct_logical_forms:
            with gzip.open(f"{output_path}/{question_id}.gz", "wt") as output_file_pointer:
                for logical_form in correct_logical_forms:
                    print(logical_form, file=output_file_pointer)
        elif not output_separate_files:
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                print(f"Agenda: {agenda}", file=output_file_pointer)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
    if not output_separate_files:
        output_file_pointer.close()
def test_numerical_column_type_extraction(self):
    question = """how many players on the 191617 illinois fighting illini men's
                  basketball team had more than 100 points scored?"""
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-7.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    predicted_types = table_question_context.column_types
    assert predicted_types["games_played"] == "number"
    assert predicted_types["field_goals"] == "number"
    assert predicted_types["free_throws"] == "number"
    assert predicted_types["points"] == "number"
def test_numerical_column_type_extraction(self):
    question = """how many players on the 191617 illinois fighting illini men's
                  basketball team had more than 100 points scored?"""
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-7.table'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    column_names = table_question_context.column_names
    assert "number_column:games_played" in column_names
    assert "number_column:field_goals" in column_names
    assert "number_column:free_throws" in column_names
    assert "number_column:points" in column_names
def test_get_knowledge_graph(self):
    question = "other than m1 how many notations have 1 in them?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-11.table"
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    knowledge_graph = table_question_context.get_table_knowledge_graph()
    entities = knowledge_graph.entities
    # -1 is not in entities because there are no date columns in the table.
    assert sorted(entities) == [
        "1",
        "number_column:notation",
        "number_column:position",
        "string:1",
        "string:m1",
        "string_column:mnemonic",
        "string_column:notation",
        "string_column:position",
        "string_column:short_name",
        "string_column:swara",
    ]
    neighbors = knowledge_graph.neighbors
    # Each number extracted from the question will have all number and date columns as
    # neighbors. Each string entity extracted from the question will only have the
    # corresponding column as the neighbor.
    neighbors_with_sets = {key: set(value) for key, value in neighbors.items()}
    assert neighbors_with_sets == {
        "1": {"number_column:position", "number_column:notation"},
        "string_column:mnemonic": set(),
        "string_column:short_name": set(),
        "string_column:swara": set(),
        "number_column:position": {"1"},
        "number_column:notation": {"1"},
        "string:m1": {"string_column:notation"},
        "string:1": {"string_column:position"},
        "string_column:notation": {"string:m1"},
        "string_column:position": {"string:1"},
    }
    entity_text = knowledge_graph.entity_text
    assert entity_text == {
        "1": "1",
        "string:m1": "m1",
        "string:1": "1",
        "string_column:notation": "notation",
        "number_column:notation": "notation",
        "string_column:mnemonic": "mnemonic",
        "string_column:short_name": "short name",
        "string_column:swara": "swara",
        "number_column:position": "position",
        "string_column:position": "position",
    }
def text_to_instance(self,  # type: ignore
                     logical_forms: List[str],
                     table_lines: List[List[str]],
                     question: str) -> Instance:
    # pylint: disable=arguments-differ
    tokenized_question = self._tokenizer.tokenize(question.lower())
    tokenized_question.insert(0, Token(START_SYMBOL))
    tokenized_question.append(Token(END_SYMBOL))
    question_field = TextField(tokenized_question, self._question_token_indexers)
    table_context = TableQuestionContext.read_from_lines(table_lines, tokenized_question)
    world = WikiTablesLanguage(table_context)
    action_sequences_list: List[List[str]] = []
    action_sequence_fields_list: List[TextField] = []
    for logical_form in logical_forms:
        try:
            action_sequence = world.logical_form_to_action_sequence(logical_form)
            action_sequence = reader_utils.make_bottom_up_action_sequence(action_sequence,
                                                                          world.is_nonterminal)
            action_sequence_field = TextField([Token(rule) for rule in action_sequence],
                                              self._rule_indexers)
            action_sequences_list.append(action_sequence)
            action_sequence_fields_list.append(action_sequence_field)
        except ParsingError as error:
            logger.debug(f'Parsing error: {error.message}, skipping logical form')
            logger.debug(f'Question was: {question}')
            logger.debug(f'Logical form was: {logical_form}')
            logger.debug(f'Table info was: {table_lines}')
        except:
            logger.error(logical_form)
            raise
    if not action_sequences_list:
        return None
    all_production_rule_fields: List[List[Field]] = []
    for action_sequence in action_sequences_list:
        all_production_rule_fields.append([])
        for production_rule in action_sequence:
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_instance_specific_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule=is_global_rule)
            all_production_rule_fields[-1].append(field)
    action_field = ListField([ListField(production_rule_fields) for
                              production_rule_fields in all_production_rule_fields])
    fields = {'action_sequences': ListField(action_sequence_fields_list),
              'target_tokens': question_field,
              'world': MetadataField(world),
              'actions': action_field}
    return Instance(fields)
def search(tables_directory: str,
           input_examples_file: str,
           output_path: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool,
           output_separate_files: bool) -> None:
    data = [wikitables_util.parse_example_line(example_line)
            for example_line in open(input_examples_file)]
    tokenizer = WordTokenizer()
    if output_separate_files and not os.path.exists(output_path):
        os.makedirs(output_path)
    if not output_separate_files:
        output_file_pointer = open(output_path, "w")
    for instance_data in data:
        utterance = instance_data["question"]
        question_id = instance_data["id"]
        if utterance.startswith('"') and utterance.endswith('"'):
            utterance = utterance[1:-1]
        # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
        table_file = instance_data["table_filename"].replace("csv", "tagged")
        target_list = instance_data["target_values"]
        tokenized_question = tokenizer.tokenize(utterance)
        table_file = f"{tables_directory}/{table_file}"
        context = TableQuestionContext.read_from_file(table_file, tokenized_question)
        world = WikiTablesVariableFreeWorld(context)
        walker = ActionSpaceWalker(world, max_path_length=max_path_length)
        correct_logical_forms = []
        if use_agenda:
            agenda = world.get_agenda()
            all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                     max_num_logical_forms=10000)
        else:
            all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
        for logical_form in all_logical_forms:
            if world.evaluate_logical_form(logical_form, target_list):
                correct_logical_forms.append(logical_form)
        if output_separate_files and correct_logical_forms:
            with gzip.open(f"{output_path}/{question_id}.gz", "wt") as output_file_pointer:
                for logical_form in correct_logical_forms:
                    print(logical_form, file=output_file_pointer)
        elif not output_separate_files:
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                print(f"Agenda: {agenda}", file=output_file_pointer)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
    if not output_separate_files:
        output_file_pointer.close()
def test_table_data_from_untagged_file(self):
    question = "what was the attendance when usl a league played?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/wikitables/sample_table.tsv'
    table_lines = [line.strip() for line in open(test_file).readlines()]
    table_question_context = TableQuestionContext.read_from_lines(table_lines, question_tokens)
    # The content in the table represented by the untagged file we are reading here is the same
    # as the one we had in the tagged file above, except that we have a "Score" column instead
    # of an "Avg. Attendance" column, which is changed to test the num2 extraction logic. I've
    # shown the values not being extracted here as well and commented them out.
    assert table_question_context.table_data == [
        {
            'number_column:year': 2001.0,
            # The value extraction logic we have for untagged lines does
            # not extract this value as a date.
            # 'date_column:year': Date(2001, -1, -1),
            'string_column:year': '2001',
            'number_column:division': 2.0,
            'string_column:division': '2',
            'string_column:league': 'usl_a_league',
            'string_column:regular_season': '4th_western',
            # We only check for strings that are entirely numbers. So 4.0
            # will not be extracted.
            # 'number_column:regular_season': 4.0,
            'string_column:playoffs': 'quarterfinals',
            'string_column:open_cup': 'did_not_qualify',
            # 'number_column:open_cup': None,
            'number_column:score': 20.0,
            'num2_column:score': 30.0,
            'string_column:score': '20_30'
        },
        {
            'number_column:year': 2005.0,
            # 'date_column:year': Date(2005, -1, -1),
            'string_column:year': '2005',
            'number_column:division': 2.0,
            'string_column:division': '2',
            'string_column:league': 'usl_first_division',
            'string_column:regular_season': '5th',
            # Same here as in the "regular_season" column for the first row.
            # 5.0 will not be extracted from "5th".
            # 'number_column:regular_season': 5.0,
            'string_column:playoffs': 'quarterfinals',
            'string_column:open_cup': '4th_round',
            # 'number_column:open_cup': 4.0,
            'number_column:score': 50.0,
            'num2_column:score': 40.0,
            'string_column:score': '50_40'
        }
    ]
def test_knowledge_graph_has_correct_neighbors(self):
    question = "when was the attendance greater than 5000?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f"{self.FIXTURES_ROOT}/data/wikitables/sample_table.tagged"
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    knowledge_graph = table_question_context.get_table_knowledge_graph()
    neighbors = knowledge_graph.neighbors
    # '5000' is neighbors with number and date columns. '-1' is in entities because there is a
    # date column, which is its only neighbor.
    assert set(neighbors.keys()) == {
        "date_column:year",
        "number_column:year",
        "string_column:year",
        "number_column:division",
        "string_column:division",
        "string_column:league",
        "string_column:regular_season",
        "number_column:regular_season",
        "string_column:playoffs",
        "string_column:open_cup",
        "number_column:open_cup",
        "number_column:avg_attendance",
        "string_column:avg_attendance",
        "5000",
        "-1",
    }
    assert set(neighbors["date_column:year"]) == {"5000", "-1"}
    assert neighbors["number_column:year"] == ["5000"]
    assert neighbors["string_column:year"] == []
    assert neighbors["number_column:division"] == ["5000"]
    assert neighbors["string_column:division"] == []
    assert neighbors["string_column:league"] == []
    assert neighbors["string_column:regular_season"] == []
    assert neighbors["number_column:regular_season"] == ["5000"]
    assert neighbors["string_column:playoffs"] == []
    assert neighbors["string_column:open_cup"] == []
    assert neighbors["number_column:open_cup"] == ["5000"]
    assert neighbors["number_column:avg_attendance"] == ["5000"]
    assert neighbors["string_column:avg_attendance"] == []
    assert set(neighbors["5000"]) == {
        "date_column:year",
        "number_column:year",
        "number_column:division",
        "number_column:avg_attendance",
        "number_column:regular_season",
        "number_column:open_cup",
    }
    assert neighbors["-1"] == ["date_column:year"]
def __init__(self, table_context: TableQuestionContext) -> None:
    super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                             "num": types.NUMBER_TYPE},
                     global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                     global_name_mapping=types.COMMON_NAME_MAPPING)
    # TODO (pradeep): Do we need constant type prefixes?
    self.table_context = table_context
    self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)
    # For every new column name seen, we update this counter to map it to a new NLTK name.
    self._column_counter = 0
    # Adding entities and numbers seen in questions to the mapping.
    self._question_entities, question_numbers = table_context.get_entities_from_question()
    self._question_numbers = [number for number, _ in question_numbers]
    for entity in self._question_entities:
        self._map_name(f"string:{entity}", keep_mapping=True)
    for number_in_question in self._question_numbers:
        self._map_name(f"num:{number_in_question}", keep_mapping=True)
    # Adding -1 to mapping because we need it for dates where not all three fields are
    # specified.
    self._map_name("num:-1", keep_mapping=True)
    # Keeps track of column name productions so that we can add them to the agenda.
    self._column_productions_for_agenda: Dict[str, str] = {}
    # Adding column names to the local name mapping.
    for column_name, column_type in table_context.column_types.items():
        self._map_name(f"{column_type}_column:{column_name}", keep_mapping=True)
    self.global_terminal_productions: Dict[str, str] = {}
    for predicate, mapped_name in self.global_name_mapping.items():
        if mapped_name in self.global_type_signatures:
            signature = self.global_type_signatures[mapped_name]
            self.global_terminal_productions[predicate] = f"{signature} -> {predicate}"
    # We don't need to recompute this ever; let's just compute it once and cache it.
    self._valid_actions: Dict[str, List[str]] = None
def test_table_data(self):
    question = "what was the attendance when usl a league played?"
    question_tokens = self.tokenizer.tokenize(question)
    test_file = f'{self.FIXTURES_ROOT}/data/wikitables/sample_table.tagged'
    table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
    assert table_question_context.table_data == [{'date_column:year': '2001',
                                                  'number_column:division': '2',
                                                  'string_column:league': 'usl_a_league',
                                                  'string_column:regular_season': '4th_western',
                                                  'string_column:playoffs': 'quarterfinals',
                                                  'string_column:open_cup': 'did_not_qualify',
                                                  'number_column:avg_attendance': '7_169'},
                                                 {'date_column:year': '2005',
                                                  'number_column:division': '2',
                                                  'string_column:league': 'usl_first_division',
                                                  'string_column:regular_season': '5th',
                                                  'string_column:playoffs': 'quarterfinals',
                                                  'string_column:open_cup': '4th_round',
                                                  'number_column:avg_attendance': '6_028'}]
def _get_world_with_question_tokens(self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
    table_context = TableQuestionContext.read_from_file(self.table_file, tokens)
    world = WikiTablesVariableFreeWorld(table_context)
    return world
def __init__(self, table_context: TableQuestionContext) -> None:
    super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                             "num": types.NUMBER_TYPE},
                     global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                     global_name_mapping=types.COMMON_NAME_MAPPING)
    self.table_context = table_context
    # We add name mapping and signatures corresponding to specific column types to the local
    # name mapping based on the table content here.
    column_types = table_context.column_types.values()
    self._table_has_string_columns = False
    self._table_has_date_columns = False
    self._table_has_number_columns = False
    if "string" in column_types:
        for name, translated_name in types.STRING_COLUMN_NAME_MAPPING.items():
            signature = types.STRING_COLUMN_TYPE_SIGNATURE[translated_name]
            self._add_name_mapping(name, translated_name, signature)
        self._table_has_string_columns = True
    if "date" in column_types:
        for name, translated_name in types.DATE_COLUMN_NAME_MAPPING.items():
            signature = types.DATE_COLUMN_TYPE_SIGNATURE[translated_name]
            self._add_name_mapping(name, translated_name, signature)
        # Adding -1 to mapping because we need it for dates where not all three fields are
        # specified. We want to do this only when the table has a date column. This is because
        # the knowledge graph is also constructed in such a way that -1 is an entity with date
        # columns as the neighbors only if any date columns exist in the table.
        self._map_name("num:-1", keep_mapping=True)
        self._table_has_date_columns = True
    if "number" in column_types:
        for name, translated_name in types.NUMBER_COLUMN_NAME_MAPPING.items():
            signature = types.NUMBER_COLUMN_TYPE_SIGNATURE[translated_name]
            self._add_name_mapping(name, translated_name, signature)
        self._table_has_number_columns = True
    if "date" in column_types or "number" in column_types:
        for name, translated_name in types.COMPARABLE_COLUMN_NAME_MAPPING.items():
            signature = types.COMPARABLE_COLUMN_TYPE_SIGNATURE[translated_name]
            self._add_name_mapping(name, translated_name, signature)
    self.table_graph = table_context.get_table_knowledge_graph()
    self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)
    # TODO (pradeep): Use a NameMapper for mapping entity names too.
    # For every new column name seen, we update this counter to map it to a new NLTK name.
    self._column_counter = 0
    # Adding entities and numbers seen in questions to the mapping.
    question_entities, question_numbers = table_context.get_entities_from_question()
    self._question_entities = [entity for entity, _ in question_entities]
    self._question_numbers = [number for number, _ in question_numbers]
    for entity in self._question_entities:
        # These entities all have prefix "string:"
        self._map_name(entity, keep_mapping=True)
    for number_in_question in self._question_numbers:
        self._map_name(f"num:{number_in_question}", keep_mapping=True)
    # Keeps track of column name productions so that we can add them to the agenda.
    self._column_productions_for_agenda: Dict[str, str] = {}
    # Adding column names to the local name mapping.
    for column_name, column_type in table_context.column_types.items():
        self._map_name(f"{column_type}_column:{column_name}", keep_mapping=True)
    self.terminal_productions: Dict[str, str] = {}
    name_mapping = list(self.global_name_mapping.items())
    name_mapping += list(self.local_name_mapping.items())
    signatures = self.global_type_signatures.copy()
    signatures.update(self.local_type_signatures)
    for predicate, mapped_name in name_mapping:
        if mapped_name in signatures:
            signature = signatures[mapped_name]
            self.terminal_productions[predicate] = f"{signature} -> {predicate}"
    # We don't need to recompute this ever; let's just compute it once and cache it.
    self._valid_actions: Dict[str, List[str]] = None