def setup_method(self):
        """
        Prepare the fixtures shared by the tests: a tokenizer, the tokenized utterance, the
        knowledge graph read from the ``341.tagged`` table, a small vocabulary and the
        ``KnowledgeGraphField`` built from them. The parent ``setup_method`` runs last.
        """
        self.tokenizer = SpacyTokenizer(pos_tags=True)
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        table_path = self.FIXTURES_ROOT / "data" / "wikitables" / "tables" / "341.tagged"
        question_context = TableQuestionContext.read_from_file(table_path, self.utterance)
        self.graph = question_context.get_table_knowledge_graph()

        self.vocab = Vocabulary()

        def register(word):
            # Adds ``word`` to the "tokens" namespace and passes back whatever the vocabulary returns.
            return self.vocab.add_token_to_namespace(word, namespace="tokens")

        self.name_index = register("name")
        self.in_index = register("in")
        self.english_index = register("english")
        self.location_index = register("location")
        self.mersin_index = register("mersin")

        # This string is never added above; the index it resolves to is also stored as the
        # index for "edirne".
        self.oov_index = self.vocab.get_token_index("random OOV string", namespace="tokens")
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(
            self.graph, self.utterance, self.token_indexers, self.tokenizer
        )

        super().setup_method()
Beispiel #2
0
 def setUp(self):
     """Create a ``WikiTablesLanguage`` over ``sample_table.tagged`` for the tests to use."""
     super().setUp()
     # Adding a bunch of random tokens in here so we get them as constants in the language.
     words = [
         "what", "was", "the", "last", "year", "2013", "?",
         "quarterfinals", "a_league", "2010", "8000", "did_not_qualify",
         "2001", "2", "23", "2005", "1", "2002",
         "usl_a_league", "usl_first_division",
     ]
     question_tokens = list(map(Token, words))
     self.table_file = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_table.tagged"
     self.table_context = TableQuestionContext.read_from_file(self.table_file, question_tokens)
     self.language = WikiTablesLanguage(self.table_context)
 def test_rank_number_extraction(self):
     """Numbers pulled from a question containing "first" and "1943" against ``TEST-1.table``."""
     utterance = "what was the first tamil-language film in 1943?"
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-1.table"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     _entities, extracted_numbers = context.get_entities_from_question()
     # Each pair is (number as a string, index of the question token it came from).
     expected_numbers = [("1", 3), ("1943", 9)]
     assert extracted_numbers == expected_numbers
 def test_date_column_type_extraction_1(self):
     """``TEST-5.table`` should expose ``first_elected`` among its date columns."""
     tokens = self.tokenizer.tokenize("how many were elected?")
     table_path = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-5.table"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     assert "date_column:first_elected" in context.column_names
 def test_date_extraction(self):
     """Numbers pulled from a question that spells out a full date, against ``TEST-8.table``."""
     utterance = "how many laps did matt kenset complete on february 26, 2006."
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-8.table"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     _entities, extracted_numbers = context.get_entities_from_question()
     # Each pair is (number as a string, index of the question token it came from).
     expected_numbers = [("2", 8), ("26", 9), ("2006", 11)]
     assert extracted_numbers == expected_numbers
Beispiel #6
0
 def test_date_column_type_extraction_2(self):
     """``TEST-9.table`` should expose both of its date columns."""
     tokens = self.tokenizer.tokenize("how many were elected?")
     table_path = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-9.table'
     context = TableQuestionContext.read_from_file(table_path, tokens)
     available_columns = context.column_names
     for expected_column in ("date_column:date_of_appointment",
                             "date_column:date_of_election"):
         assert expected_column in available_columns
Beispiel #7
0
 def test_multiword_entity_extraction(self):
     """Entities extracted from a question mentioning "france" and "south korea" (``TEST-3.table``)."""
     utterance = "was the positioning better the year of the france venue or the year of the south korea venue?"
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-3.table'
     context = TableQuestionContext.read_from_file(table_path, tokens)
     extracted_entities, _numbers = context.get_entities_from_question()
     # Each pair is (entity name, list of columns it is linked to).
     expected_entities = [
         ("string:france", ["string_column:venue"]),
         ("string:south_korea", ["string_column:venue"]),
     ]
     assert extracted_entities == expected_entities
 def test_number_extraction(self):
     """Numbers extracted from a two-line question against ``TEST-7.table``."""
     utterance = """how many players on the 191617 illinois fighting illini men's basketball team
                   had more than 100 points scored?"""
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-7.table"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     _entities, extracted_numbers = context.get_entities_from_question()
     # Each pair is (number as a string, index of the question token it came from).
     expected_numbers = [("191617", 5), ("100", 16)]
     assert extracted_numbers == expected_numbers
 def test_null_extraction(self):
     """A question with nothing to link against ``TEST-2.table`` yields two empty lists."""
     utterance = "on what date did the eagles score the least points?"
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-2.table"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     extracted_entities, extracted_numbers = context.get_entities_from_question()
     # "Eagles" does not appear in the table.
     assert extracted_entities == []
     assert extracted_numbers == []
 def test_date_extraction_2(self):
     """Numbers extracted from a two-line question mentioning a year, against ``TEST-6.table``."""
     utterance = """how many different players scored for the san jose earthquakes during their
                   1979 home opener against the timbers?"""
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-6.table"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     _entities, extracted_numbers = context.get_entities_from_question()
     # Each pair is (number as a string, index of the question token it came from).
     expected_numbers = [("1979", 12)]
     assert extracted_numbers == expected_numbers
Beispiel #11
0
 def test_number_and_entity_extraction(self):
     """Both string entities and numbers are extracted from one question (``TEST-11.table``)."""
     utterance = "other than m1 how many notations have 1 in them?"
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-11.table"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     extracted_strings, extracted_numbers = context.get_entities_from_question()
     expected_strings = [
         ("string:m1", ["string_column:notation"]),
         ("string:1", ["string_column:position"]),
     ]
     expected_numbers = [("1", 2), ("1", 7)]
     assert extracted_strings == expected_strings
     assert extracted_numbers == expected_numbers
 def test_string_column_types_extraction(self):
     """``TEST-10.table`` should expose four specific string columns."""
     tokens = self.tokenizer.tokenize("how many were elected?")
     table_path = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-10.table"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     available_columns = context.column_names
     # Checked one at a time, in this order.
     expected_columns = (
         "string_column:birthplace",
         "string_column:advocate",
         "string_column:notability",
         "string_column:name",
     )
     for expected_column in expected_columns:
         assert expected_column in available_columns
 def test_get_knowledge_graph(self):
     """
     Builds the knowledge graph for ``TEST-11.table`` and compares its entities, neighbors and
     entity text with hand-written expectations.
     """
     utterance = "other than m1 how many notations have 1 in them?"
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-11.table"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     graph = context.get_table_knowledge_graph()

     # -1 is not in entities because there are no date columns in the table.
     expected_entities = [
         "1",
         "number_column:notation",
         "number_column:position",
         "string:1",
         "string:m1",
         "string_column:mnemonic",
         "string_column:notation",
         "string_column:position",
         "string_column:short_name",
         "string_column:swara",
     ]
     assert sorted(graph.entities) == expected_entities

     # Each number extracted from the question will have all number and date columns as
     # neighbors. Each string entity extracted from the question will only have the corresponding
     # column as the neighbor.
     expected_neighbors = {
         "1": {"number_column:position", "number_column:notation"},
         "string_column:mnemonic": set(),
         "string_column:short_name": set(),
         "string_column:swara": set(),
         "number_column:position": {"1"},
         "number_column:notation": {"1"},
         "string:m1": {"string_column:notation"},
         "string:1": {"string_column:position"},
         "string_column:notation": {"string:m1"},
         "string_column:position": {"string:1"},
     }
     # Neighbor collections are turned into sets so that their order does not matter.
     actual_neighbors = {
         entity: set(adjacent)
         for entity, adjacent in graph.neighbors.items()
     }
     assert actual_neighbors == expected_neighbors

     expected_entity_text = {
         "1": "1",
         "string:m1": "m1",
         "string:1": "1",
         "string_column:notation": "notation",
         "number_column:notation": "notation",
         "string_column:mnemonic": "mnemonic",
         "string_column:short_name": "short name",
         "string_column:swara": "swara",
         "number_column:position": "position",
         "string_column:position": "position",
     }
     assert graph.entity_text == expected_entity_text
 def test_numerical_column_type_extraction(self):
     """``TEST-7.table`` should expose four specific number columns."""
     utterance = """how many players on the 191617 illinois fighting illini men's basketball team
                   had more than 100 points scored?"""
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-7.table"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     available_columns = context.column_names
     # Checked one at a time, in this order.
     expected_columns = (
         "number_column:games_played",
         "number_column:field_goals",
         "number_column:free_throws",
         "number_column:points",
     )
     for expected_column in expected_columns:
         assert expected_column in available_columns
 def test_table_data_from_untagged_file(self):
     """
     Reads ``sample_table.tsv`` line by line, passes the stripped lines to
     ``TableQuestionContext.read_from_lines`` and checks the resulting ``table_data``.
     """
     question = "what was the attendance when usl a league played?"
     question_tokens = self.tokenizer.tokenize(question)
     test_file = f"{self.FIXTURES_ROOT}/data/wikitables/sample_table.tsv"
     # Read inside a ``with`` block so the file handle is closed as soon as the lines are
     # collected, instead of being left for the garbage collector.
     with open(test_file) as table_file:
         table_lines = [line.strip() for line in table_file]
     table_question_context = TableQuestionContext.read_from_lines(
         table_lines, question_tokens)
     # The content in the table represented by the untagged file we are reading here is the same as the one we
     # had in the tagged file above, except that we have a "Score" column instead of "Avg. Attendance" column,
     # which is changed to test the num2 extraction logic. I've shown the values not being extracted here as
     # well and commented them out.
     assert table_question_context.table_data == [
         {
             "number_column:year": 2001.0,
             # The value extraction logic we have for untagged lines does
             # not extract this value as a date.
             # 'date_column:year': Date(2001, -1, -1),
             "string_column:year": "2001",
             "number_column:division": 2.0,
             "string_column:division": "2",
             "string_column:league": "usl_a_league",
             "string_column:regular_season": "4th_western",
             # We only check for strings that are entirely numbers. So 4.0
             # will not be extracted.
             # 'number_column:regular_season': 4.0,
             "string_column:playoffs": "quarterfinals",
             "string_column:open_cup": "did_not_qualify",
             # 'number_column:open_cup': None,
             "number_column:score": 20.0,
             "num2_column:score": 30.0,
             "string_column:score": "20_30",
         },
         {
             "number_column:year": 2005.0,
             # 'date_column:year': Date(2005, -1, -1),
             "string_column:year": "2005",
             "number_column:division": 2.0,
             "string_column:division": "2",
             "string_column:league": "usl_first_division",
             "string_column:regular_season": "5th",
             # Same here as in the "regular_season" column for the first row.
             # 5.0 will not be extracted from "5th".
             # 'number_column:regular_season': 5.0,
             "string_column:playoffs": "quarterfinals",
             "string_column:open_cup": "4th_round",
             # 'number_column:open_cup': 4.0,
             "number_column:score": 50.0,
             "num2_column:score": 40.0,
             "string_column:score": "50_40",
         },
     ]
 def evaluate_denotation(self, denotation: Any, target_list: List[str]) -> bool:
     """
     Checks a denotation against the target values using the official evaluator.

     The targets are passed through ``TableQuestionContext.normalize_string``; the denotation
     (a single value or a list of values) is converted to strings. Both are then converted with
     ``evaluator.to_value_list`` and compared with ``evaluator.check_denotation``, whose result
     is returned.
     """
     normalized_targets = [
         TableQuestionContext.normalize_string(target) for target in target_list
     ]
     target_values = evaluator.to_value_list(normalized_targets)
     # A lone denotation is treated as a one-element list.
     denotation_items = denotation if isinstance(denotation, list) else [denotation]
     denotation_strings = [str(item) for item in denotation_items]
     denotation_values = evaluator.to_value_list(denotation_strings)
     return evaluator.check_denotation(target_values, denotation_values)
 def test_knowledge_graph_has_correct_neighbors(self):
     """
     Builds the knowledge graph for ``sample_table.tagged`` with a question containing "5000"
     and checks the neighbors recorded for every entity.
     """
     utterance = "when was the attendance greater than 5000?"
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f"{self.FIXTURES_ROOT}/data/wikitables/sample_table.tagged"
     context = TableQuestionContext.read_from_file(table_path, tokens)
     neighbors = context.get_table_knowledge_graph().neighbors

     # '5000' is neighbors with number and date columns. '-1' is in entities because there is a
     # date column, which is its only neighbor.
     expected_entities = {
         "date_column:year",
         "number_column:year",
         "string_column:year",
         "number_column:division",
         "string_column:division",
         "string_column:league",
         "string_column:regular_season",
         "number_column:regular_season",
         "string_column:playoffs",
         "string_column:open_cup",
         "number_column:open_cup",
         "number_column:avg_attendance",
         "string_column:avg_attendance",
         "5000",
         "-1",
     }
     assert set(neighbors.keys()) == expected_entities
     assert set(neighbors["date_column:year"]) == {"5000", "-1"}

     # Columns whose neighbor lists are compared exactly, checked in this order.
     expected_neighbor_lists = [
         ("number_column:year", ["5000"]),
         ("string_column:year", []),
         ("number_column:division", ["5000"]),
         ("string_column:division", []),
         ("string_column:league", []),
         ("string_column:regular_season", []),
         ("number_column:regular_season", ["5000"]),
         ("string_column:playoffs", []),
         ("string_column:open_cup", []),
         ("number_column:open_cup", ["5000"]),
         ("number_column:avg_attendance", ["5000"]),
         ("string_column:avg_attendance", []),
     ]
     for column, expected_list in expected_neighbor_lists:
         assert neighbors[column] == expected_list

     assert set(neighbors["5000"]) == {
         "date_column:year",
         "number_column:year",
         "number_column:division",
         "number_column:avg_attendance",
         "number_column:regular_season",
         "number_column:open_cup",
     }
     assert neighbors["-1"] == ["date_column:year"]
Beispiel #18
0
 def setUp(self):
     """Build a ``WikiTablesLanguage`` from ``sample_table.tagged`` and keep it on ``self``."""
     super().setUp()
     # Adding a bunch of random tokens in here so we get them as constants in the language.
     constant_words = (
         'what', 'was', 'the', 'last', 'year', '2013', '?',
         'quarterfinals', 'a_league', '2010', '8000', 'did_not_qualify',
         '2001', '2', '23', '2005', '1', '2002', 'usl_a_league',
         'usl_first_division',
     )
     question_tokens = [Token(word) for word in constant_words]
     wikitables_dir = self.FIXTURES_ROOT / 'data' / 'wikitables'
     self.table_file = wikitables_dir / 'sample_table.tagged'
     self.table_context = TableQuestionContext.read_from_file(self.table_file, question_tokens)
     self.language = WikiTablesLanguage(self.table_context)
Beispiel #19
0
 def test_knowledge_graph_has_correct_neighbors(self):
     """
     Checks the neighbor structure of the knowledge graph built from ``sample_table.tagged``
     and a question that mentions the number 5000.
     """
     utterance = "when was the attendance greater than 5000?"
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f'{self.FIXTURES_ROOT}/data/wikitables/sample_table.tagged'
     context = TableQuestionContext.read_from_file(table_path, tokens)
     graph = context.get_table_knowledge_graph()
     neighbors = graph.neighbors

     # '5000' is neighbors with number and date columns. '-1' is in entities because there is a
     # date column, which is its only neighbor.
     all_entities = {
         'date_column:year', 'number_column:year', 'string_column:year',
         'number_column:division', 'string_column:division',
         'string_column:league', 'string_column:regular_season',
         'number_column:regular_season', 'string_column:playoffs',
         'string_column:open_cup', 'number_column:open_cup',
         'number_column:avg_attendance', 'string_column:avg_attendance',
         '5000', '-1',
     }
     columns_next_to_5000 = {
         'date_column:year', 'number_column:year', 'number_column:division',
         'number_column:avg_attendance', 'number_column:regular_season',
         'number_column:open_cup',
     }
     only_5000 = ['5000']
     nothing = []

     assert set(neighbors.keys()) == all_entities
     assert set(neighbors['date_column:year']) == {'5000', '-1'}
     assert neighbors['number_column:year'] == only_5000
     assert neighbors['string_column:year'] == nothing
     assert neighbors['number_column:division'] == only_5000
     assert neighbors['string_column:division'] == nothing
     assert neighbors['string_column:league'] == nothing
     assert neighbors['string_column:regular_season'] == nothing
     assert neighbors['number_column:regular_season'] == only_5000
     assert neighbors['string_column:playoffs'] == nothing
     assert neighbors['string_column:open_cup'] == nothing
     assert neighbors['number_column:open_cup'] == only_5000
     assert neighbors['number_column:avg_attendance'] == only_5000
     assert neighbors['string_column:avg_attendance'] == nothing
     assert set(neighbors['5000']) == columns_next_to_5000
     assert neighbors['-1'] == ['date_column:year']
 def test_table_data(self):
     """``table_data`` read from ``sample_table.tagged`` should contain exactly these two rows."""
     utterance = "what was the attendance when usl a league played?"
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f"{self.FIXTURES_ROOT}/data/wikitables/sample_table.tagged"
     context = TableQuestionContext.read_from_file(table_path, tokens)

     first_row = {
         "date_column:year": Date(2001, -1, -1),
         "number_column:year": 2001.0,
         "string_column:year": "2001",
         "number_column:division": 2.0,
         "string_column:division": "2",
         "string_column:league": "usl_a_league",
         "string_column:regular_season": "4th_western",
         "number_column:regular_season": 4.0,
         "string_column:playoffs": "quarterfinals",
         "string_column:open_cup": "did_not_qualify",
         "number_column:open_cup": None,
         "number_column:avg_attendance": 7169.0,
         "string_column:avg_attendance": "7_169",
     }
     second_row = {
         "date_column:year": Date(2005, -1, -1),
         "number_column:year": 2005.0,
         "string_column:year": "2005",
         "number_column:division": 2.0,
         "string_column:division": "2",
         "string_column:league": "usl_first_division",
         "string_column:regular_season": "5th",
         "number_column:regular_season": 5.0,
         "string_column:playoffs": "quarterfinals",
         "string_column:open_cup": "4th_round",
         "number_column:open_cup": 4.0,
         "number_column:avg_attendance": 6028.0,
         "string_column:avg_attendance": "6_028",
     }
     assert context.table_data == [first_row, second_row]
def search(
    tables_directory: str,
    data: JsonDict,
    output_path: str,
    max_path_length: int,
    max_num_logical_forms: int,
    use_agenda: bool,
    output_separate_files: bool,
    conservative_agenda: bool,
) -> None:
    """
    Searches for logical forms that evaluate to the target values of each instance in ``data``
    and writes the ones found to disk.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that the per-instance tagged table paths are resolved against.
    data : ``JsonDict``
        Iterable of instances; each one is read through the keys ``"question"``, ``"id"``,
        ``"table_filename"`` and ``"target_values"``.
    output_path : ``str``
        A directory (created if needed) when ``output_separate_files`` is set, otherwise the
        path of the single text file that receives all results.
    max_path_length : ``int``
        Passed to ``ActionSpaceWalker`` as ``max_path_length``.
    max_num_logical_forms : ``int``
        Upper bound on the correct logical forms written per question to the combined file.
        The walker is always asked for up to 10000 candidates, and the per-question gzip files
        receive every correct logical form.
    use_agenda : ``bool``
        If set, candidates come from ``get_logical_forms_with_agenda`` and the agenda is also
        written to the combined file; otherwise from ``get_all_logical_forms``.
    output_separate_files : ``bool``
        If set, write one ``<question_id>.gz`` file per question that has at least one correct
        logical form, instead of a single combined file.
    conservative_agenda : ``bool``
        Passed to ``get_agenda``; partial agenda matches are allowed only when this is false.
    """
    print(f"Starting search with {len(data)} instances", file=sys.stderr)
    language_logger = logging.getLogger(
        "allennlp.semparse.domain_languages.wikitables_language")
    language_logger.setLevel(logging.ERROR)
    tokenizer = WordTokenizer()
    if output_separate_files:
        # ``exist_ok`` avoids the check-then-create race of testing for the directory first.
        os.makedirs(output_path, exist_ok=True)
        combined_output = None
    else:
        combined_output = open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Drop one pair of surrounding double quotes, if present.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file,
                                                          tokenized_question)
            world = WikiTablesLanguage(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            if use_agenda:
                agenda = world.get_agenda(conservative=conservative_agenda)
                all_logical_forms = walker.get_logical_forms_with_agenda(
                    agenda=agenda,
                    max_num_logical_forms=10000,
                    allow_partial_match=not conservative_agenda)
            else:
                all_logical_forms = walker.get_all_logical_forms(
                    max_num_logical_forms=10000)
            # Keep only the candidates that the language says evaluate to the targets.
            correct_logical_forms = [
                logical_form for logical_form in all_logical_forms
                if world.evaluate_logical_form(logical_form, target_list)
            ]
            if output_separate_files:
                # Questions without any correct logical form get no file at all.
                if correct_logical_forms:
                    with gzip.open(f"{output_path}/{question_id}.gz",
                                   "wt") as question_output:
                        for logical_form in correct_logical_forms:
                            print(logical_form, file=question_output)
            else:
                print(f"{question_id} {utterance}", file=combined_output)
                if use_agenda:
                    print(f"Agenda: {agenda}", file=combined_output)
                if not correct_logical_forms:
                    print("NO LOGICAL FORMS FOUND!", file=combined_output)
                for logical_form in correct_logical_forms[:max_num_logical_forms]:
                    print(logical_form, file=combined_output)
                print(file=combined_output)
    finally:
        # Close the combined file even when an instance fails part-way through the search.
        if combined_output is not None:
            combined_output.close()
    def text_to_instance(
        self,  # type: ignore
        question: str,
        table_lines: List[List[str]],
        target_values: List[str] = None,
        offline_search_output: List[str] = None,
    ) -> Instance:
        """
        Turns a question and its table into an ``Instance``. The ``table_lines`` are handed to
        ``TableQuestionContext.read_from_lines``, which accepts them either as lines from the CoreNLP
        processed tagged files that come with the dataset, or in a plain tsv format where each line
        is a row and the cells are tab-separated.

        Parameters
        ----------
        question : ``str``
            Input question. It is lower-cased before tokenization.
        table_lines : ``List[List[str]]``
            The table content optionally preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines``
            for the expected format.
        target_values : ``List[str]``, optional
            Target values for the denotations the logical forms should execute to. Not required for testing.
        offline_search_output : ``List[str]``, optional
            List of logical forms, produced by offline search. Not required during test.

        Returns
        -------
        The assembled ``Instance``, or ``None`` when ``offline_search_output`` was given but none
        of its logical forms could be converted into an action sequence.
        """
        question_tokens = self._tokenizer.tokenize(question.lower())
        question_field = TextField(question_tokens, self._question_token_indexers)
        metadata: Dict[str, Any] = {
            "question_tokens": [token.text for token in question_tokens]
        }
        table_context = TableQuestionContext.read_from_lines(table_lines, question_tokens)
        world = WikiTablesLanguage(table_context)
        world_field = MetadataField(world)
        # Note: Not passing any feature extractors when instantiating the field below. This will make
        # it use all the available extractors.
        table_field = KnowledgeGraphField(
            table_context.get_table_knowledge_graph(),
            question_tokens,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens,
        )

        # One ProductionRuleField per production; a rule counts as global unless its right-hand
        # side is an instance-specific entity.
        production_rule_fields: List[Field] = []
        for rule in world.all_possible_productions():
            _, right_hand_side = rule.split(" -> ")
            instance_specific = world.is_instance_specific_entity(right_hand_side)
            production_rule_fields.append(
                ProductionRuleField(rule, is_global_rule=not instance_specific))
        action_field = ListField(production_rule_fields)

        fields = {
            "question": question_field,
            "metadata": MetadataField(metadata),
            "table": table_field,
            "world": world_field,
            "actions": action_field,
        }
        if target_values is not None:
            fields["target_values"] = MetadataField(target_values)

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {action.rule: index for index, action in enumerate(action_field.field_list)}  # type: ignore

        if offline_search_output:
            action_sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    action_sequence = world.logical_form_to_action_sequence(logical_form)
                    index_fields: List[Field] = [
                        IndexField(action_map[rule_in_sequence], action_field)
                        for rule_in_sequence in action_sequence
                    ]
                    action_sequence_fields.append(ListField(index_fields))
                except ParsingError as error:
                    logger.debug(
                        f"Parsing error: {error.message}, skipping logical form"
                    )
                    logger.debug(f"Question was: {question}")
                    logger.debug(f"Logical form was: {logical_form}")
                    logger.debug(f"Table info was: {table_lines}")
                    continue
                except KeyError as error:
                    logger.debug(
                        f"Missing production rule: {error.args}, skipping logical form"
                    )
                    logger.debug(f"Question was: {question}")
                    logger.debug(f"Table info was: {table_lines}")
                    logger.debug(f"Logical form was: {logical_form}")
                    continue
                except:  # noqa
                    # Anything unexpected: record the offending logical form, then propagate.
                    logger.error(logical_form)
                    raise
                # Stop once enough logical forms have been converted successfully.
                if len(action_sequence_fields) >= self._max_offline_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields["target_action_sequences"] = ListField(action_sequence_fields)

        if self._output_agendas:
            agenda_index_fields: List[Field] = [
                IndexField(action_map[agenda_string], action_field)
                for agenda_string in world.get_agenda(conservative=True)
            ]
            # An empty agenda is represented by a single IndexField holding -1.
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields["agenda"] = ListField(agenda_index_fields)
        return Instance(fields)
Beispiel #23
0
 def _get_world_with_question_tokens_and_table_file(
     self, tokens: List[Token], table_file: str
 ) -> WikiTablesLanguage:
     """Read ``table_file`` with ``tokens`` as the question and wrap the context in a language."""
     return WikiTablesLanguage(TableQuestionContext.read_from_file(table_file, tokens))
Beispiel #24
0
 def test_table_data(self):
     """Compares ``table_data`` read from ``sample_table.tagged`` with the two expected rows."""
     utterance = "what was the attendance when usl a league played?"
     tokens = self.tokenizer.tokenize(utterance)
     table_path = f'{self.FIXTURES_ROOT}/data/wikitables/sample_table.tagged'
     context = TableQuestionContext.read_from_file(table_path, tokens)
     expected_table_data = [
         {
             'date_column:year': Date(2001, -1, -1),
             'number_column:year': 2001.0,
             'string_column:year': '2001',
             'number_column:division': 2.0,
             'string_column:division': '2',
             'string_column:league': 'usl_a_league',
             'string_column:regular_season': '4th_western',
             'number_column:regular_season': 4.0,
             'string_column:playoffs': 'quarterfinals',
             'string_column:open_cup': 'did_not_qualify',
             'number_column:open_cup': None,
             'number_column:avg_attendance': 7169.0,
             'string_column:avg_attendance': '7_169',
         },
         {
             'date_column:year': Date(2005, -1, -1),
             'number_column:year': 2005.0,
             'string_column:year': '2005',
             'number_column:division': 2.0,
             'string_column:division': '2',
             'string_column:league': 'usl_first_division',
             'string_column:regular_season': '5th',
             'number_column:regular_season': 5.0,
             'string_column:playoffs': 'quarterfinals',
             'string_column:open_cup': '4th_round',
             'number_column:open_cup': 4.0,
             'number_column:avg_attendance': 6028.0,
             'string_column:avg_attendance': '6_028',
         },
     ]
     assert context.table_data == expected_table_data
    def __init__(self, table_context: TableQuestionContext) -> None:
        """
        Set up the language for a single table/question context.

        The base initializer is called with the start types returned by
        ``self._get_start_types_in_context(table_context)``. After that this method:

        - stores the context and wraps every entry of ``table_context.table_data`` in a ``Row``;
        - registers predicates depending on which of ``"string"``, ``"date"``, ``"number"`` /
          ``"num2"`` appear in ``table_context.column_types``;
        - stores the graph returned by ``table_context.get_table_knowledge_graph()``;
        - adds the entities and numbers returned by
          ``table_context.get_entities_from_question()`` as constants;
        - adds every name in ``table_context.column_names`` as a column constant and records a
          production string for it in ``self._column_productions_for_agenda``;
        - fills ``self.terminal_productions`` from ``self._function_types``.

        Parameters
        ----------
        table_context : ``TableQuestionContext``
            Source of the table data, column names/types, knowledge graph and question
            entities/numbers read above.
        """
        super().__init__(
            start_types=self._get_start_types_in_context(table_context))
        self.table_context = table_context
        # Each entry of the context's table data is wrapped in a ``Row``.
        self.table_data = [Row(row) for row in table_context.table_data]

        column_types = table_context.column_types
        # Predicates are registered only for the column types present in this table; the three
        # flags below record which groups were registered.
        self._table_has_string_columns = False
        self._table_has_date_columns = False
        self._table_has_number_columns = False
        if "string" in column_types:
            self.add_predicate("filter_in", self.filter_in)
            self.add_predicate("filter_not_in", self.filter_not_in)
            self._table_has_string_columns = True
        if "date" in column_types:
            self.add_predicate("filter_date_greater", self.filter_date_greater)
            self.add_predicate("filter_date_greater_equals",
                               self.filter_date_greater_equals)
            self.add_predicate("filter_date_lesser", self.filter_date_lesser)
            self.add_predicate("filter_date_lesser_equals",
                               self.filter_date_lesser_equals)
            self.add_predicate("filter_date_equals", self.filter_date_equals)
            self.add_predicate("filter_date_not_equals",
                               self.filter_date_not_equals)
            self.add_predicate("max_date", self.max_date)
            self.add_predicate("min_date", self.min_date)
            # Adding -1 to mapping because we need it for dates where not all three fields are
            # specified. We want to do this only when the table has a date column. This is because
            # the knowledge graph is also constructed in such a way that -1 is an entity with date
            # columns as the neighbors only if any date columns exist in the table.
            self.add_constant("-1", -1, type_=Number)
            self._table_has_date_columns = True
        # "num2" is treated the same as "number" throughout this method.
        if "number" in column_types or "num2" in column_types:
            self.add_predicate("filter_number_greater",
                               self.filter_number_greater)
            self.add_predicate("filter_number_greater_equals",
                               self.filter_number_greater_equals)
            self.add_predicate("filter_number_lesser",
                               self.filter_number_lesser)
            self.add_predicate("filter_number_lesser_equals",
                               self.filter_number_lesser_equals)
            self.add_predicate("filter_number_equals",
                               self.filter_number_equals)
            self.add_predicate("filter_number_not_equals",
                               self.filter_number_not_equals)
            self.add_predicate("max_number", self.max_number)
            self.add_predicate("min_number", self.min_number)
            self.add_predicate("average", self.average)
            self.add_predicate("sum", self.sum)
            self.add_predicate("diff", self.diff)
            self._table_has_number_columns = True
        # argmax/argmin are registered when the table has any date or number column, i.e. the
        # same column types that are added as ``ComparableColumn`` constants further down.
        if "date" in column_types or "number" in column_types or "num2" in column_types:
            self.add_predicate("argmax", self.argmax)
            self.add_predicate("argmin", self.argmin)

        self.table_graph = table_context.get_table_knowledge_graph()

        # Adding entities and numbers seen in questions as constants.
        question_entities, question_numbers = table_context.get_entities_from_question(
        )
        # Both results are sequences of pairs; only the first element of each pair is kept.
        self._question_entities = [entity for entity, _ in question_entities]
        self._question_numbers = [number for number, _ in question_numbers]
        for entity in self._question_entities:
            # Forcing the type of entities to be List[str] here to ensure that the language deals with the outputs
            # of select-like statements and constants similarly.
            self.add_constant(entity, entity, type_=List[str])

        for number in self._question_numbers:
            # The constant's name is the number's string form; its value is converted to float.
            self.add_constant(str(number), float(number), type_=Number)

        # Keeps track of column name productions so that we can add them to the agenda.
        self._column_productions_for_agenda: Dict[str, str] = {}

        # Adding column names as constants.
        for column_name in table_context.column_names:
            # The type is the text before the first ":" with "_column" removed
            # (NOTE(review): implies names shaped like "string_column:league" -- confirm against
            # TableQuestionContext).
            column_type = column_name.split(":")[0].replace("_column", "")
            # Stays None if ``column_type`` matches none of the branches below.
            column: Column = None
            if column_type == "string":
                column = StringColumn(column_name)
            elif column_type == "date":
                column = DateColumn(column_name)
                # Date columns are additionally registered under ComparableColumn.
                self.add_constant(column_name, column, type_=ComparableColumn)
            elif column_type in {"number", "num2"}:
                column = NumberColumn(column_name)
                # Number columns are additionally registered under ComparableColumn.
                self.add_constant(column_name, column, type_=ComparableColumn)
            # Every column is registered under the generic ``Column`` type and once more without
            # an explicit ``type_`` (NOTE(review): presumably add_constant then derives the type
            # from the value -- confirm in the base class).
            self.add_constant(column_name, column, type_=Column)
            self.add_constant(column_name, column)
            # The left-hand side of the production is the type name obtained from the concrete
            # column class.
            column_type_name = str(PredicateType.get_type(type(column)))
            self._column_productions_for_agenda[
                column_name] = f"{column_type_name} -> {column_name}"

        # Mapping from terminal strings to productions that produce them.  We use this in the
        # agenda-related methods, and some models that use this language look at this field to know
        # how many terminals to plan for.
        self.terminal_productions: Dict[str, str] = {}
        # ``_function_types`` is not assigned in this method (NOTE(review): presumably maintained
        # by the base class as predicates/constants are added -- confirm). Only the first type
        # listed for each name is used as the production's left-hand side.
        for name, types in self._function_types.items():
            self.terminal_productions[name] = "%s -> %s" % (types[0], name)