Пример #1
0
 def test_number_comparison_works(self):
     # TableQuestionContext normalizes every string according to its own rules, so this checks
     # that number cells still keep their original numeric values when the executor compares them.
     question = "when was the attendance the highest?"
     question_tokens = WordTokenizer().tokenize(question)
     table_path = self.FIXTURES_ROOT / "data" / "corenlp_processed_tables" / "TEST-2.table"
     table_context = TableQuestionContext.read_from_file(table_path, question_tokens)
     variable_free_executor = WikiTablesVariableFreeExecutor(table_context.table_data)
     logical_form = "(select (argmax all_rows number_column:attendance) date_column:date)"
     assert variable_free_executor.execute(logical_form) == ["november_10"]
Пример #2
0
 def setUp(self):
     super().setUp()
     # Two table rows whose keys follow the "<type>_column:<name>" naming scheme.
     first_row = {"date_column:date": "january_2001",
                  "number_column:division": "2",
                  "string_column:league": "usl_a_league",
                  "string_column:regular_season": "4th_western",
                  "string_column:playoffs": "quarterfinals",
                  "string_column:open_cup": "did_not_qualify",
                  "number_column:avg_attendance": "7169"}
     second_row = {"date_column:date": "march_2005",
                   "number_column:division": "2",
                   "string_column:league": "usl_first_division",
                   "string_column:regular_season": "5th",
                   "string_column:playoffs": "quarterfinals",
                   "string_column:open_cup": "4th_round",
                   "number_column:avg_attendance": "6028"}
     self.executor = WikiTablesVariableFreeExecutor([first_row, second_row])
Пример #3
0
    def __init__(self, table_context: TableQuestionContext) -> None:
        super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                                 "num": types.NUMBER_TYPE},
                         global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                         global_name_mapping=types.COMMON_NAME_MAPPING)
        # TODO (pradeep): Do we need constant type prefixes?
        self.table_context = table_context
        # Executor built over this context's table data.
        self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)

        # For every new column name seen, we update this counter to map it to a new NLTK name.
        self._column_counter = 0

        # Question entities and question numbers get permanent local name mappings. So does -1,
        # because dates do not always specify all three fields.
        self._question_entities, question_numbers = table_context.get_entities_from_question()
        self._question_numbers = [value for value, _ in question_numbers]
        constant_names = ([f"string:{entity}" for entity in self._question_entities] +
                          [f"num:{value}" for value in self._question_numbers] +
                          ["num:-1"])
        for constant_name in constant_names:
            self._map_name(constant_name, keep_mapping=True)

        # Keeps track of column name productions so that we can add them to the agenda.
        # NOTE(review): this is created before the column names are mapped below, presumably
        # because `_map_name` fills it in. Keep this ordering.
        self._column_productions_for_agenda: Dict[str, str] = {}

        # Column names enter the local name mapping as "<type>_column:<name>".
        for column_name, column_type in table_context.column_types.items():
            self._map_name(f"{column_type}_column:{column_name}", keep_mapping=True)

        # "<signature> -> <predicate>" for every global predicate whose mapped name has a signature.
        self.global_terminal_productions: Dict[str, str] = {
                predicate: f"{self.global_type_signatures[mapped_name]} -> {predicate}"
                for predicate, mapped_name in self.global_name_mapping.items()
                if mapped_name in self.global_type_signatures}

        # We don't need to recompute this ever; let's just compute it once and cache it.
        self._valid_actions: Dict[str, List[str]] = None
Пример #4
0
 def setUp(self):
     super().setUp()
     # Cells common to both fixtures. The fixtures differ only in the leading year/date cell,
     # which is prepended below so that each row keeps its original key order.
     shared_cells = [{"fb:row.row.division": "fb:cell.2",
                      "fb:row.row.league": "fb:cell.usl_a_league",
                      "fb:row.row.regular_season": "fb:cell.4th_western",
                      "fb:row.row.playoffs": "fb:cell.quarterfinals",
                      "fb:row.row.open_cup": "fb:cell.did_not_qualify",
                      "fb:row.row.avg_attendance": "fb:cell.7169"},
                     {"fb:row.row.division": "fb:cell.2",
                      "fb:row.row.league": "fb:cell.usl_first_division",
                      "fb:row.row.regular_season": "fb:cell.5th",
                      "fb:row.row.playoffs": "fb:cell.quarterfinals",
                      "fb:row.row.open_cup": "fb:cell.4th_round",
                      "fb:row.row.avg_attendance": "fb:cell.6028"}]
     years = ["fb:cell.2001", "fb:cell.2005"]
     dates = ["fb:cell.january_2001", "fb:cell.march_2005"]
     # Each comprehension builds fresh dicts, so the two executors never share row objects.
     self.executor = WikiTablesVariableFreeExecutor(
             [{"fb:row.row.year": year, **cells} for year, cells in zip(years, shared_cells)])
     self.executor_with_date = WikiTablesVariableFreeExecutor(
             [{"fb:row.row.date": date, **cells} for date, cells in zip(dates, shared_cells)])
Пример #5
0
class TestWikiTablesVariableFreeExecutor(AllenNlpTestCase):
    """
    Tests for ``WikiTablesVariableFreeExecutor``.

    Unless a test builds its own executor, logical forms run against the two-row table created in
    ``setUp``:

    - the first row is dated january_2001, with league usl_a_league and average attendance 7169;
    - the second row is dated march_2005, with league usl_first_division and average attendance
      6028.

    Most ``filter_*`` tests also check that passing an argument of the wrong type raises
    ``ExecutionError``.
    """
    def setUp(self):
        super().setUp()
        table_data = [{"date_column:date": "january_2001", "number_column:division": "2",
                       "string_column:league": "usl_a_league", "string_column:regular_season": "4th_western",
                       "string_column:playoffs": "quarterfinals", "string_column:open_cup": "did_not_qualify",
                       "number_column:avg_attendance": "7169"},
                      {"date_column:date": "march_2005", "number_column:division": "2",
                       "string_column:league": "usl_first_division", "string_column:regular_season": "5th",
                       "string_column:playoffs": "quarterfinals", "string_column:open_cup": "4th_round",
                       "number_column:avg_attendance": "6028"}]
        self.executor = WikiTablesVariableFreeExecutor(table_data)

    def test_execute_fails_with_unknown_function(self):
        logical_form = "(unknown_function all_rows string_column:league)"
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_select(self):
        logical_form = "(select all_rows string_column:league)"
        cell_list = self.executor.execute(logical_form)
        assert set(cell_list) == {'usl_a_league', 'usl_first_division'}

    def test_execute_works_with_argmax(self):
        logical_form = "(select (argmax all_rows number_column:avg_attendance) string_column:league)"
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ['usl_a_league']

    def test_execute_works_with_argmax_on_dates(self):
        logical_form = "(select (argmax all_rows date_column:date) string_column:league)"
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ['usl_first_division']

    def test_execute_works_with_argmin(self):
        logical_form = "(select (argmin all_rows number_column:avg_attendance) date_column:date)"
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ['march_2005']

    def test_execute_works_with_argmin_on_dates(self):
        logical_form = "(select (argmin all_rows date_column:date) string_column:league)"
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ['usl_a_league']

    def test_execute_works_with_filter_number_greater(self):
        # Selecting cell values from all rows that have attendance greater than the min value of
        # attendance.
        logical_form = """(select (filter_number_greater all_rows number_column:avg_attendance
                                   (min all_rows number_column:avg_attendance)) string_column:league)"""
        cell_value_list = self.executor.execute(logical_form)
        assert cell_value_list == ['usl_a_league']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_number_greater all_rows number_column:avg_attendance
                                   all_rows) string_column:league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_greater(self):
        # Selecting cell values from all rows that have date greater than 2002.
        logical_form = """(select (filter_date_greater all_rows date_column:date
                                   (date 2002 -1 -1)) string_column:league)"""
        cell_value_list = self.executor.execute(logical_form)
        assert cell_value_list == ['usl_first_division']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_greater all_rows date_column:date
                                   2005) string_column:league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_number_greater_equals(self):
        # Counting rows that have attendance greater than or equal to the min value of attendance.
        logical_form = """(count (filter_number_greater_equals all_rows number_column:avg_attendance
                                  (min all_rows number_column:avg_attendance)))"""
        count_result = self.executor.execute(logical_form)
        assert count_result == 2
        # Replacing the filter value with an invalid value (a row list instead of a number). This
        # must exercise filter_number_greater_equals itself, not filter_number_greater.
        logical_form = """(count (filter_number_greater_equals all_rows number_column:avg_attendance all_rows))"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_greater_equals(self):
        # Selecting cell values from all rows that have date greater than or equal to 2005 February
        # 1st.
        logical_form = """(select (filter_date_greater_equals all_rows date_column:date
                                   (date 2005 2 1)) string_column:league)"""
        cell_value_list = self.executor.execute(logical_form)
        assert cell_value_list == ['usl_first_division']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_greater_equals all_rows date_column:date
                                   2005) string_column:league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_number_lesser(self):
        # Selecting cell values from all rows that have attendance lesser than the max value of
        # attendance.
        logical_form = """(select (filter_number_lesser all_rows number_column:avg_attendance
                                    (max all_rows number_column:avg_attendance)) string_column:league)"""
        cell_value_list = self.executor.execute(logical_form)
        assert cell_value_list == ['usl_first_division']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_number_lesser all_rows date_column:date
                                   (date 2005 -1 -1)) string_column:league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_lesser(self):
        # Selecting cell values from all rows that have date less than 2005 January.
        logical_form = """(select (filter_date_lesser all_rows date_column:date
                                   (date 2005 1 -1)) string_column:league)"""
        cell_value_list = self.executor.execute(logical_form)
        assert cell_value_list == ["usl_a_league"]
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_lesser all_rows date_column:date
                                   2005) string_column:league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_number_lesser_equals(self):
        # Counting rows that have attendance lesser than or equal to 8000.
        logical_form = """(count (filter_number_lesser_equals all_rows number_column:avg_attendance 8000))"""
        count_result = self.executor.execute(logical_form)
        assert count_result == 2
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_number_lesser_equals all_rows date_column:date
                                   (date 2005 -1 -1)) string_column:league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_lesser_equals(self):
        # Selecting cell values from all rows that have date less than or equal to 2001 February 23.
        logical_form = """(select (filter_date_lesser_equals all_rows date_column:date
                                   (date 2001 2 23)) string_column:league)"""
        cell_value_list = self.executor.execute(logical_form)
        assert cell_value_list == ['usl_a_league']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_lesser_equals all_rows date_column:date
                                   2005) string_column:league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_number_equals(self):
        # Counting rows that have attendance equal to 8000.
        logical_form = """(count (filter_number_equals all_rows number_column:avg_attendance 8000))"""
        count_result = self.executor.execute(logical_form)
        assert count_result == 0
        # Replacing the filter value with an invalid value.
        logical_form = """(count (filter_number_equals all_rows date_column:date (date 2010 -1 -1)))"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_equals(self):
        # Selecting cell values from all rows that have date equal to 2001.
        logical_form = """(select (filter_date_equals all_rows date_column:date
                                   (date 2001 -1 -1)) string_column:league)"""
        cell_value_list = self.executor.execute(logical_form)
        assert cell_value_list == ['usl_a_league']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_equals all_rows date_column:date
                                   2005) string_column:league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_number_not_equals(self):
        # Counting rows that have attendance not equal to 8000.
        logical_form = """(count (filter_number_not_equals all_rows number_column:avg_attendance 8000))"""
        count_result = self.executor.execute(logical_form)
        assert count_result == 2
        # Replacing the filter value with an invalid value.
        logical_form = """(count (filter_number_not_equals all_rows date_column:date (date 2010 -1 -1)))"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_not_equals(self):
        # Selecting cell values from all rows that have date not equal to 2001.
        logical_form = """(select (filter_date_not_equals all_rows date_column:date
                                   (date 2001 -1 -1)) string_column:league)"""
        cell_value_list = self.executor.execute(logical_form)
        assert cell_value_list == ['usl_first_division']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_not_equals all_rows date_column:date
                                   2005) string_column:league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_in(self):
        # Selecting "regular season" from rows that have "did not qualify" in "open cup" column.
        logical_form = """(select (filter_in all_rows string_column:open_cup string:did_not_qualify)
                                  string_column:regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["4th_western"]

    def test_execute_works_with_filter_not_in(self):
        # Selecting "regular season" from rows that do not have "did not qualify" in "open cup" column.
        logical_form = """(select (filter_not_in all_rows string_column:open_cup string:did_not_qualify)
                                   string_column:regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["5th"]

    def test_execute_works_with_first(self):
        # Selecting "regular season" from the first row.
        logical_form = """(select (first all_rows) string_column:regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["4th_western"]

    def test_execute_logs_warning_with_first_on_empty_list(self):
        # Selecting "regular season" from the first row where year is greater than 2010.
        with self.assertLogs("allennlp.semparse.executors.wikitables_variable_free_executor") as log:
            logical_form = """(select (first (filter_date_greater all_rows date_column:date
                                                (date 2010 -1 -1)))
                                      string_column:regular_season)"""
            self.executor.execute(logical_form)
        self.assertEqual(log.output,
                         ["WARNING:allennlp.semparse.executors.wikitables_variable_free_executor:"
                          "Trying to get first row from an empty list: "
                          "['filter_date_greater', 'all_rows', 'date_column:date', ['date', '2010', '-1', '-1']]"])

    def test_execute_works_with_last(self):
        # Selecting "regular season" from the last row where year is not equal to 2010.
        logical_form = """(select (last (filter_date_not_equals all_rows date_column:date
                                         (date 2010 -1 -1)))
                                  string_column:regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["5th"]

    def test_execute_logs_warning_with_last_on_empty_list(self):
        # Selecting "regular season" from the last row where year is greater than 2010.
        with self.assertLogs("allennlp.semparse.executors.wikitables_variable_free_executor") as log:
            logical_form = """(select (last (filter_date_greater all_rows date_column:date
                                                (date 2010 -1 -1)))
                                      string_column:regular_season)"""
            self.executor.execute(logical_form)
        self.assertEqual(log.output,
                         ["WARNING:allennlp.semparse.executors.wikitables_variable_free_executor:"
                          "Trying to get last row from an empty list: "
                          "['filter_date_greater', 'all_rows', 'date_column:date', ['date', '2010', '-1', '-1']]"])

    def test_execute_works_with_previous(self):
        # Selecting "regular season" from the row before last where year is not equal to 2010.
        logical_form = """(select (previous (last (filter_date_not_equals
                                                    all_rows date_column:date (date 2010 -1 -1))))
                                  string_column:regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["4th_western"]

    def test_execute_logs_warning_with_previous_on_empty_list(self):
        # Selecting "regular season" from the row before the one where year is greater than 2010.
        with self.assertLogs("allennlp.semparse.executors.wikitables_variable_free_executor") as log:
            logical_form = """(select (previous (filter_date_greater all_rows date_column:date (date 2010 -1 -1)))
                                      string_column:regular_season)"""
            self.executor.execute(logical_form)
        self.assertEqual(log.output,
                         ["WARNING:allennlp.semparse.executors.wikitables_variable_free_executor:"
                          "Trying to get the previous row from an empty list: "
                          "['filter_date_greater', 'all_rows', 'date_column:date', ['date', '2010', '-1', '-1']]"])

    def test_execute_works_with_next(self):
        # Selecting "regular season" from the row after first where year is not equal to 2010.
        logical_form = """(select (next (first (filter_date_not_equals
                                                all_rows date_column:date (date 2010 -1 -1))))
                                  string_column:regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["5th"]

    def test_execute_logs_warning_with_next_on_empty_list(self):
        # Selecting "regular season" from the row after the one where year is greater than 2010.
        with self.assertLogs("allennlp.semparse.executors.wikitables_variable_free_executor") as log:
            logical_form = """(select (next (filter_date_greater all_rows date_column:date (date 2010 -1 -1)))
                                      string_column:regular_season)"""
            self.executor.execute(logical_form)
        self.assertEqual(log.output,
                         ["WARNING:allennlp.semparse.executors.wikitables_variable_free_executor:"
                          "Trying to get the next row from an empty list: "
                          "['filter_date_greater', 'all_rows', 'date_column:date', ['date', '2010', '-1', '-1']]"])

    def test_execute_works_with_mode(self):
        # Most frequent division value.
        logical_form = """(mode all_rows number_column:division)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["2"]
        # If we used select instead, we should get a list of two values.
        logical_form = """(select all_rows number_column:division)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["2", "2"]
        # If we queried for the most frequent year instead, it should return two values since both
        # have the max frequency of 1.
        logical_form = """(mode all_rows date_column:date)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["january_2001", "march_2005"]

    def test_execute_works_with_same_as(self):
        # Select the "league" from all the rows that have the same value under "playoffs" as the
        # row that has the string "a league" under "league".
        logical_form = """(select (same_as (filter_in all_rows string_column:league string:a_league)
                                   string_column:playoffs)
                           string_column:league)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["usl_a_league", "usl_first_division"]

    def test_execute_works_with_sum(self):
        # Get total "avg attendance".
        logical_form = """(sum all_rows number_column:avg_attendance)"""
        sum_value = self.executor.execute(logical_form)
        assert sum_value == 13197
        # Total "avg attendance" where "playoffs" has "quarterfinals"
        logical_form = """(sum (filter_in all_rows string_column:playoffs string:quarterfinals)
                                number_column:avg_attendance)"""
        sum_value = self.executor.execute(logical_form)
        assert sum_value == 13197

    def test_execute_works_with_average(self):
        # Get average "avg attendance".
        logical_form = """(average all_rows number_column:avg_attendance)"""
        avg_value = self.executor.execute(logical_form)
        assert avg_value == 6598.5
        # Average "avg attendance" where "playoffs" has "quarterfinals"
        logical_form = """(average (filter_in all_rows string_column:playoffs string:quarterfinals)
                                number_column:avg_attendance)"""
        avg_value = self.executor.execute(logical_form)
        assert avg_value == 6598.5

    def test_execute_works_with_diff(self):
        # Difference in "avg attendance" between rows with "usl_a_league" and "usl_first_division"
        # in "league" columns.
        logical_form = """(diff (filter_in all_rows string_column:league string:usl_a_league)
                                (filter_in all_rows string_column:league string:usl_first_division)
                                number_column:avg_attendance)"""
        avg_value = self.executor.execute(logical_form)
        assert avg_value == 1141

    def test_execute_fails_with_diff_on_non_numerical_columns(self):
        logical_form = """(diff (filter_in all_rows string_column:league string:usl_a_league)
                                (filter_in all_rows string_column:league string:usl_first_division)
                                string_column:league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_fails_with_non_int_dates(self):
        logical_form = """(date 2015 1.5 1)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_date_comparison_works(self):
        assert Date(2013, 12, 31) > Date(2013, 12, 30)
        assert Date(2013, 12, 31) == Date(2013, 12, -1)
        assert Date(2013, -1, -1) >= Date(2013, 12, 31)
        # pylint: disable=singleton-comparison
        assert (Date(2013, 12, -1) > Date(2013, 12, 31)) == False
        assert (Date(2013, 12, 31) > 2013) == False
        assert (Date(2013, 12, 31) >= 2013) == False
        assert Date(2013, 12, 31) != 2013
        assert (Date(2018, 1, 1) >= Date(-1, 2, 1)) == False
        assert (Date(2018, 1, 1) < Date(-1, 2, 1)) == False
        # When year is unknown in both cases, we can compare months and days.
        assert Date(-1, 2, 1) < Date(-1, 2, 3)
        # If both year and month are not known in both cases, the comparison is undefined, and both
        # < and >= return False.
        assert (Date(-1, -1, 1) < Date(-1, -1, 3)) == False
        assert (Date(-1, -1, 1) >= Date(-1, -1, 3)) == False
        # Same when year is known, but months are not.
        assert (Date(2018, -1, 1) < Date(2018, -1, 3)) == False
        # TODO (pradeep): Figure out whether this is expected behavior by looking at data.
        assert (Date(2018, -1, 1) >= Date(2018, -1, 3)) == False

    def test_number_comparison_works(self):
        # TableQuestionContext normalizes all strings according to some rules. We want to ensure
        # that the original numerical values of number cells are being correctly processed here.
        tokens = WordTokenizer().tokenize("when was the attendance the highest?")
        tagged_file = self.FIXTURES_ROOT / "data" / "corenlp_processed_tables" / "TEST-2.table"
        context = TableQuestionContext.read_from_file(tagged_file, tokens)
        executor = WikiTablesVariableFreeExecutor(context.table_data)
        result = executor.execute("(select (argmax all_rows number_column:attendance) date_column:date)")
        assert result == ["november_10"]

    def test_evaluate_logical_form(self):
        logical_form = """(select (same_as (filter_in all_rows string_column:league string:a_league)
                                   string_column:playoffs)
                           string_column:league)"""
        assert self.executor.evaluate_logical_form(logical_form, ["USL A-League",
                                                                  "USL First Division"])

    def test_evaluate_logical_form_with_invalid_logical_form(self):
        logical_form = """(select (same_as (filter_in all_rows string_column:league INVALID_CONSTANT)
                                   string_column:playoffs)
                           string_column:league)"""
        assert not self.executor.evaluate_logical_form(logical_form, ["USL A-League",
                                                                      "USL First Division"])
Пример #6
0
class WikiTablesVariableFreeWorld(World):
    """
    World representation for the WikitableQuestions domain with the variable-free language used in
    the paper from Liang et al. (2018).

    Parameters
    ----------
    table_context : ``TableQuestionContext``
        Context associated with this world. It provides the table data handed to the executor,
        the entities and numbers extracted from the question, the column types, and the question
        tokens used to build the agenda.
    """
    # When we're converting from logical forms to action sequences, this set tells us which
    # functions in the logical form are curried functions, and how many arguments the function
    # actually takes.  This is necessary because NLTK curries all multi-argument functions to a
    # series of one-argument function applications.  See `world._get_transitions` for more info.
    curried_functions = {
            types.SELECT_TYPE: 2,
            types.ROW_FILTER_WITH_GENERIC_COLUMN: 2,
            types.ROW_FILTER_WITH_COMPARABLE_COLUMN: 2,
            types.ROW_NUM_OP: 2,
            types.ROW_FILTER_WITH_COLUMN_AND_NUMBER: 3,
            types.ROW_FILTER_WITH_COLUMN_AND_DATE: 3,
            types.ROW_FILTER_WITH_COLUMN_AND_STRING: 3,
            types.NUM_DIFF_WITH_COLUMN: 3,
            }

    def __init__(self, table_context: TableQuestionContext) -> None:
        super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                                 "num": types.NUMBER_TYPE},
                         global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                         global_name_mapping=types.COMMON_NAME_MAPPING)
        # TODO (pradeep): Do we need constant type prefixes?
        self.table_context = table_context

        self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)

        # For every new column name seen, we update this counter to map it to a new NLTK name.
        self._column_counter = 0

        # Keeps track of column name productions so that we can add them to the agenda. It is
        # written by ``_translate_name_and_add_mapping``, so it must exist before the first
        # ``_map_name`` call below.
        self._column_productions_for_agenda: Dict[str, str] = {}

        # Adding entities and numbers seen in questions to the mapping.
        self._question_entities, question_numbers = table_context.get_entities_from_question()
        self._question_numbers = [number for number, _ in question_numbers]
        for entity in self._question_entities:
            self._map_name(f"string:{entity}", keep_mapping=True)

        for number_in_question in self._question_numbers:
            self._map_name(f"num:{number_in_question}", keep_mapping=True)

        # Adding -1 to mapping because we need it for dates where not all three fields are
        # specified.
        self._map_name("num:-1", keep_mapping=True)

        # Adding column names to the local name mapping.
        for column_name, column_type in table_context.column_types.items():
            self._map_name(f"{column_type}_column:{column_name}", keep_mapping=True)

        # Terminal productions ("<signature> -> <predicate>") for every global predicate that has
        # a type signature; ``get_agenda`` looks function productions up here.
        self.global_terminal_productions: Dict[str, str] = {}
        for predicate, mapped_name in self.global_name_mapping.items():
            if mapped_name in self.global_type_signatures:
                signature = self.global_type_signatures[mapped_name]
                self.global_terminal_productions[predicate] = f"{signature} -> {predicate}"

        # We don't need to recompute this ever; let's just compute it once and cache it.
        self._valid_actions: Dict[str, List[str]] = None

    @overrides
    def _get_curried_functions(self) -> Dict[Type, int]:
        """Returns the class-level map from curried function types to their argument counts."""
        return WikiTablesVariableFreeWorld.curried_functions

    @overrides
    def get_basic_types(self) -> Set[Type]:
        """Returns ``types.BASIC_TYPES``."""
        return types.BASIC_TYPES

    @overrides
    def get_valid_starting_types(self) -> Set[Type]:
        """Returns ``types.STARTING_TYPES``."""
        return types.STARTING_TYPES

    def _translate_name_and_add_mapping(self, name: str) -> str:
        """
        Translates ``name`` into a name NLTK can handle, adds the mapping (with its type) to the
        local name mapping, and returns the translated name.

        Column names (anything containing ``_column:``) get fresh names ``C0``, ``C1``, ... and are
        also recorded as agenda candidates. ``string:`` constants are kept as they are. ``num:``
        constants have ``.`` replaced by ``_`` and a leading ``-`` replaced by ``~``; the mapping
        is keyed by the number without its ``num:`` prefix.

        Raises
        ------
        ParsingError
            If ``name`` has none of the recognized forms above.
        """
        if "_column:" in name:
            translated_name = "C%d" % self._column_counter
            self._column_counter += 1
            if name.startswith("number_column:"):
                column_type = types.NUMBER_COLUMN_TYPE
            elif name.startswith("string_column:"):
                column_type = types.STRING_COLUMN_TYPE
            else:
                column_type = types.DATE_COLUMN_TYPE
            self._add_name_mapping(name, translated_name, column_type)
            self._column_productions_for_agenda[name] = f"{column_type} -> {name}"
        elif name.startswith("string:"):
            # We do not need to translate these names.
            translated_name = name
            self._add_name_mapping(name, translated_name, types.STRING_TYPE)
        elif name.startswith("num:"):
            # NLTK throws an error if it sees a "." in constants, which will most likely happen
            # within numbers as a decimal point. We're changing those to underscores.
            translated_name = name.replace(".", "_")
            if re.match(r"num:-[0-9_]+", translated_name):
                # The string is a negative number. This makes NLTK interpret this as a negated
                # expression and force its type to be TRUTH_VALUE (t).
                translated_name = translated_name.replace("-", "~")
            original_name = name.replace("num:", "")
            self._add_name_mapping(original_name, translated_name, types.NUMBER_TYPE)
        else:
            # Previously this fell through to an UnboundLocalError on ``translated_name``.
            raise ParsingError(f"Cannot translate name with unrecognized prefix: {name}")
        return translated_name

    @overrides
    def _map_name(self, name: str, keep_mapping: bool = False) -> str:
        """
        Returns the NLTK name for ``name``, looking first in the common mapping and then in the
        local one. Unknown names are translated and added to the local mapping when
        ``keep_mapping`` is true; otherwise a ``ParsingError`` is raised.
        """
        if name in types.COMMON_NAME_MAPPING:
            return types.COMMON_NAME_MAPPING[name]
        if name in self.local_name_mapping:
            return self.local_name_mapping[name]
        if not keep_mapping:
            raise ParsingError(f"Encountered un-mapped name: {name}")
        return self._translate_name_and_add_mapping(name)

    def get_agenda(self) -> List[str]:
        """
        Returns production rules that are likely to be needed to answer the question, based on
        simple lexical cues in the question tokens.

        The agenda is built in this order:
        1. Productions for functions triggered by question words, in first-trigger order and
           without duplicates.
        2. Productions for column names that occur as whole words in the question.
        3. Productions for question entities not already covered by a matched column name.
        4. Productions for question numbers that literally occur in the question.
        """
        agenda_items = []
        question_tokens = [token.text for token in self.table_context.question_tokens]
        question = " ".join(question_tokens)
        for token in question_tokens:
            if token in ["next", "after", "below"]:
                agenda_items.append("next")
            if token in ["previous", "before", "above"]:
                agenda_items.append("previous")
            if token == "total":
                agenda_items.append("sum")
            if token == "difference":
                agenda_items.append("diff")
            if token == "average":
                agenda_items.append("average")
            if token in ["least", "top", "smallest", "shortest", "lowest"]:
                # This condition is too brittle. But for most logical forms with "min", there are
                # semantically equivalent ones with "argmin". The exceptions are rare.
                if "what is the least" in question:
                    agenda_items.append("min")
                else:
                    agenda_items.append("argmin")
            if token in ["most", "largest", "highest", "longest", "greatest"]:
                # This condition is too brittle. But for most logical forms with "max", there are
                # semantically equivalent ones with "argmax". The exceptions are rare.
                if "what is the most" in question:
                    agenda_items.append("max")
                else:
                    agenda_items.append("argmax")
            if token == "first":
                agenda_items.append("first")
            if token == "last":
                agenda_items.append("last")

        if "how many" in question:
            if "sum" not in agenda_items and "average" not in agenda_items:
                # The question probably just requires counting the rows. But this is not very
                # accurate. The question could also be asking for a value that is in the table.
                agenda_items.append("count")
        agenda = []
        # Adding productions from the global set. ``dict.fromkeys`` drops duplicates while keeping
        # first-occurrence order; iterating a ``set`` of strings would make the agenda order
        # depend on the per-process string hash seed.
        for agenda_item in dict.fromkeys(agenda_items):
            agenda.append(self.global_terminal_productions[agenda_item])

        # Adding column names that occur in question.
        question_with_underscores = "_".join(question_tokens)
        # NOTE(review): characters outside [a-z0-9_] (including uppercase letters) are dropped,
        # not lowercased; this assumes question tokens are already lowercased — confirm.
        # The surrounding underscores let the "_{word}_" whole-word checks below also match the
        # first and last words of the question.
        normalized_question = "_" + re.sub("[^a-z0-9_]", "", question_with_underscores) + "_"
        # We keep track of tokens that are in column names being added to the agenda. We will not
        # add string productions to the agenda if those tokens were already captured as column
        # names.
        # Note: If the same string occurs multiple times, this may cause string productions being
        # omitted from the agenda unnecessarily. That is fine, as we want to err on the side of
        # adding fewer rules to the agenda.
        tokens_in_column_names: Set[str] = set()
        for column_name_with_type, signature in self._column_productions_for_agenda.items():
            column_name = column_name_with_type.split(":")[1]
            # Underscores ensure that the match is of whole words.
            if f"_{column_name}_" in normalized_question:
                agenda.append(signature)
                tokens_in_column_names.update(column_name.split("_"))

        # Adding all productions that lead to entities and numbers extracted from the question.
        for entity in self._question_entities:
            if entity not in tokens_in_column_names:
                agenda.append(f"{types.STRING_TYPE} -> string:{entity}")

        for number in self._question_numbers:
            # The reason we check for the presence of the number in the question again is because
            # some of these numbers are extracted from number words like month names and ordinals
            # like "first". On looking at some agenda outputs, I found that they hurt more than help
            # in the agenda.
            if f"_{number}_" in normalized_question:
                agenda.append(f"{types.NUMBER_TYPE} -> {number}")

        return agenda

    def execute(self, logical_form: str) -> Union[List[str], int, float]:
        """
        Executes ``logical_form`` on this world's table with ``WikiTablesVariableFreeExecutor``.
        Returns a list of cell values or a number (for example, ``average`` can yield a float).
        """
        return self._executor.execute(logical_form)
Пример #7
0
class TestWikiTablesVariableFreeExecutor(AllenNlpTestCase):
    """
    Tests ``WikiTablesVariableFreeExecutor`` on a two-row table with ``fb:``-style column and cell
    names, plus a variant of the table with a date column.
    """
    def setUp(self):
        super().setUp()
        table_data = [{"fb:row.row.year": "fb:cell.2001", "fb:row.row.division": "fb:cell.2",
                       "fb:row.row.league": "fb:cell.usl_a_league", "fb:row.row.regular_season":
                       "fb:cell.4th_western", "fb:row.row.playoffs": "fb:cell.quarterfinals",
                       "fb:row.row.open_cup": "fb:cell.did_not_qualify",
                       "fb:row.row.avg_attendance": "fb:cell.7169"},
                      {"fb:row.row.year": "fb:cell.2005", "fb:row.row.division": "fb:cell.2",
                       "fb:row.row.league": "fb:cell.usl_first_division", "fb:row.row.regular_season":
                       "fb:cell.5th", "fb:row.row.playoffs": "fb:cell.quarterfinals",
                       "fb:row.row.open_cup": "fb:cell.4th_round",
                       "fb:row.row.avg_attendance": "fb:cell.6028"}]
        self.executor = WikiTablesVariableFreeExecutor(table_data)
        table_data_with_date = [{"fb:row.row.date": "fb:cell.january_2001", "fb:row.row.division": "fb:cell.2",
                                 "fb:row.row.league": "fb:cell.usl_a_league", "fb:row.row.regular_season":
                                 "fb:cell.4th_western", "fb:row.row.playoffs": "fb:cell.quarterfinals",
                                 "fb:row.row.open_cup": "fb:cell.did_not_qualify",
                                 "fb:row.row.avg_attendance": "fb:cell.7169"},
                                {"fb:row.row.date": "fb:cell.march_2005", "fb:row.row.division": "fb:cell.2",
                                 "fb:row.row.league": "fb:cell.usl_first_division", "fb:row.row.regular_season":
                                 "fb:cell.5th", "fb:row.row.playoffs": "fb:cell.quarterfinals",
                                 "fb:row.row.open_cup": "fb:cell.4th_round",
                                 "fb:row.row.avg_attendance": "fb:cell.6028"}]
        self.executor_with_date = WikiTablesVariableFreeExecutor(table_data_with_date)

    def test_execute_fails_with_unknown_function(self):
        logical_form = "(unknown_function all_rows fb:row.row.league)"
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_fails_with_unknown_constant(self):
        logical_form = "12fdw"
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_select(self):
        logical_form = "(select all_rows fb:row.row.league)"
        cell_list = self.executor.execute(logical_form)
        assert set(cell_list) == {'fb:cell.usl_a_league', 'fb:cell.usl_first_division'}

    def test_execute_works_with_argmax(self):
        logical_form = "(select (argmax all_rows fb:row.row.avg_attendance) fb:row.row.league)"
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ['fb:cell.usl_a_league']

    def test_execute_works_with_argmax_on_dates(self):
        logical_form = "(select (argmax all_rows fb:row.row.date) fb:row.row.league)"
        cell_list = self.executor_with_date.execute(logical_form)
        assert cell_list == ['fb:cell.usl_first_division']

    def test_execute_works_with_argmin(self):
        logical_form = "(select (argmin all_rows fb:row.row.avg_attendance) fb:row.row.year)"
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ['fb:cell.2005']

    def test_execute_works_with_argmin_on_dates(self):
        logical_form = "(select (argmin all_rows fb:row.row.date) fb:row.row.league)"
        cell_list = self.executor_with_date.execute(logical_form)
        assert cell_list == ['fb:cell.usl_a_league']

    def test_execute_works_with_filter_number_greater(self):
        # Selecting cell values from all rows that have attendance greater than the min value of
        # attendance.
        logical_form = """(select (filter_number_greater all_rows fb:row.row.avg_attendance
                                   (min all_rows fb:row.row.avg_attendance)) fb:row.row.league)"""
        cell_value_list = self.executor.execute(logical_form)
        assert cell_value_list == ['fb:cell.usl_a_league']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_number_greater all_rows fb:row.row.avg_attendance
                                   all_rows) fb:row.row.league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_greater(self):
        # Selecting cell values from all rows that have date greater than 2002.
        logical_form = """(select (filter_date_greater all_rows fb:row.row.date
                                   (date 2002 -1 -1)) fb:row.row.league)"""
        cell_value_list = self.executor_with_date.execute(logical_form)
        assert cell_value_list == ['fb:cell.usl_first_division']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_greater all_rows fb:row.row.date
                                   2005) fb:row.row.league)"""
        with self.assertRaises(ExecutionError):
            self.executor_with_date.execute(logical_form)

    def test_execute_works_with_filter_number_greater_equals(self):
        # Counting rows that have attendance greater than or equal to the min value of attendance.
        logical_form = """(count (filter_number_greater_equals all_rows fb:row.row.avg_attendance
                                  (min all_rows fb:row.row.avg_attendance)))"""
        count_result = self.executor.execute(logical_form)
        assert count_result == 2
        # Replacing the filter value with an invalid value. This must exercise
        # filter_number_greater_equals (the function under test), not filter_number_greater.
        logical_form = """(count (filter_number_greater_equals all_rows fb:row.row.avg_attendance
                                  all_rows))"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_greater_equals(self):
        # Selecting cell values from all rows that have date greater than or equal to 2005 February
        # 1st.
        logical_form = """(select (filter_date_greater_equals all_rows fb:row.row.date
                                   (date 2005 2 1)) fb:row.row.league)"""
        cell_value_list = self.executor_with_date.execute(logical_form)
        assert cell_value_list == ['fb:cell.usl_first_division']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_greater_equals all_rows fb:row.row.date
                                   2005) fb:row.row.league)"""
        with self.assertRaises(ExecutionError):
            self.executor_with_date.execute(logical_form)

    def test_execute_works_with_filter_number_lesser(self):
        # Selecting cell values from all rows that have year less than the max year (2005).
        logical_form = """(select (filter_number_lesser all_rows fb:row.row.year
                                    (max all_rows fb:row.row.year)) fb:row.row.league)"""
        cell_value_list = self.executor.execute(logical_form)
        assert cell_value_list == ['fb:cell.usl_a_league']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_number_lesser all_rows fb:row.row.year
                                   (date 2005 -1 -1)) fb:row.row.league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_lesser(self):
        # Selecting cell values from all rows that have date less than 2005 January
        logical_form = """(select (filter_date_lesser all_rows fb:row.row.date
                                   (date 2005 1 -1)) fb:row.row.league)"""
        cell_value_list = self.executor_with_date.execute(logical_form)
        assert cell_value_list == ["fb:cell.usl_a_league"]
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_lesser all_rows fb:row.row.date
                                   2005) fb:row.row.league)"""
        with self.assertRaises(ExecutionError):
            self.executor_with_date.execute(logical_form)

    def test_execute_works_with_filter_number_lesser_equals(self):
        # Counting rows that have year less than or equal to 2005.
        logical_form = """(count (filter_number_lesser_equals all_rows fb:row.row.year 2005))"""
        count_result = self.executor.execute(logical_form)
        assert count_result == 2
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_number_lesser_equals all_rows fb:row.row.year
                                   (date 2005 -1 -1)) fb:row.row.league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_lesser_equals(self):
        # Selecting cell values from all rows that have date less than or equal to 2001 February 23
        logical_form = """(select (filter_date_lesser_equals all_rows fb:row.row.date
                                   (date 2001 2 23)) fb:row.row.league)"""
        cell_value_list = self.executor_with_date.execute(logical_form)
        assert cell_value_list == ['fb:cell.usl_a_league']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_lesser_equals all_rows fb:row.row.date
                                   2005) fb:row.row.league)"""
        with self.assertRaises(ExecutionError):
            self.executor_with_date.execute(logical_form)

    def test_execute_works_with_filter_number_equals(self):
        # Counting rows that have year equal to 2010.
        logical_form = """(count (filter_number_equals all_rows fb:row.row.year 2010))"""
        count_result = self.executor.execute(logical_form)
        assert count_result == 0
        # Replacing the filter value with an invalid value. The parentheses must be balanced so
        # that the error comes from the date/number mismatch rather than from a parse failure.
        logical_form = """(count (filter_number_equals all_rows fb:row.row.year (date 2010 -1 -1)))"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_equals(self):
        # Selecting cell values from all rows that have date equal to 2001
        logical_form = """(select (filter_date_equals all_rows fb:row.row.date
                                   (date 2001 -1 -1)) fb:row.row.league)"""
        cell_value_list = self.executor_with_date.execute(logical_form)
        assert cell_value_list == ['fb:cell.usl_a_league']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_equals all_rows fb:row.row.date
                                   2005) fb:row.row.league)"""
        with self.assertRaises(ExecutionError):
            self.executor_with_date.execute(logical_form)

    def test_execute_works_with_filter_number_not_equals(self):
        # Counting rows that have year not equal to 2010.
        logical_form = """(count (filter_number_not_equals all_rows fb:row.row.year 2010))"""
        count_result = self.executor.execute(logical_form)
        assert count_result == 2
        # Replacing the filter value with an invalid value. The parentheses must be balanced so
        # that the error comes from the date/number mismatch rather than from a parse failure.
        logical_form = """(count (filter_number_not_equals all_rows fb:row.row.year
                                  (date 2010 -1 -1)))"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_works_with_filter_date_not_equals(self):
        # Selecting cell values from all rows that have date not equal to 2001
        logical_form = """(select (filter_date_not_equals all_rows fb:row.row.date
                                   (date 2001 -1 -1)) fb:row.row.league)"""
        cell_value_list = self.executor_with_date.execute(logical_form)
        assert cell_value_list == ['fb:cell.usl_first_division']
        # Replacing the filter value with an invalid value.
        logical_form = """(select (filter_date_not_equals all_rows fb:row.row.date
                                   2005) fb:row.row.league)"""
        with self.assertRaises(ExecutionError):
            self.executor_with_date.execute(logical_form)

    def test_execute_works_with_filter_in(self):
        # Selecting "regular season" from rows that have "did not qualify" in "open cup" column.
        logical_form = """(select (filter_in all_rows fb:row.row.open_cup did_not_qualify)
                                  fb:row.row.regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["fb:cell.4th_western"]

    def test_execute_works_with_filter_not_in(self):
        # Selecting "regular season" from rows that do not have "did not qualify" in "open cup" column.
        logical_form = """(select (filter_not_in all_rows fb:row.row.open_cup did_not_qualify)
                                   fb:row.row.regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["fb:cell.5th"]

    def test_execute_works_with_first(self):
        # Selecting "regular season" from the first row.
        logical_form = """(select (first all_rows) fb:row.row.regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["fb:cell.4th_western"]

    def test_execute_logs_warning_with_first_on_empty_list(self):
        # Selecting "regular season" from the first row where year is greater than 2010.
        with self.assertLogs("allennlp.semparse.executors.wikitables_variable_free_executor") as log:
            logical_form = """(select (first (filter_number_greater all_rows fb:row.row.year 2010))
                                      fb:row.row.regular_season)"""
            self.executor.execute(logical_form)
        self.assertEqual(log.output,
                         ["WARNING:allennlp.semparse.executors.wikitables_variable_free_executor:"
                          "Trying to get first row from an empty list: "
                          "['filter_number_greater', 'all_rows', 'fb:row.row.year', '2010']"])

    def test_execute_works_with_last(self):
        # Selecting "regular season" from the last row where year is not equal to 2010.
        logical_form = """(select (last (filter_number_not_equals all_rows fb:row.row.year 2010))
                                  fb:row.row.regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["fb:cell.5th"]

    def test_execute_logs_warning_with_last_on_empty_list(self):
        # Selecting "regular season" from the last row where year is greater than 2010.
        with self.assertLogs("allennlp.semparse.executors.wikitables_variable_free_executor") as log:
            logical_form = """(select (last (filter_number_greater all_rows fb:row.row.year 2010))
                                      fb:row.row.regular_season)"""
            self.executor.execute(logical_form)
        self.assertEqual(log.output,
                         ["WARNING:allennlp.semparse.executors.wikitables_variable_free_executor:"
                          "Trying to get last row from an empty list: "
                          "['filter_number_greater', 'all_rows', 'fb:row.row.year', '2010']"])

    def test_execute_works_with_previous(self):
        # Selecting "regular season" from the row before last where year is not equal to 2010.
        logical_form = """(select (previous (last (filter_number_not_equals
                                                    all_rows fb:row.row.year 2010)))
                                  fb:row.row.regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["fb:cell.4th_western"]

    def test_execute_logs_warning_with_previous_on_empty_list(self):
        # Selecting "regular season" from the row before the one where year is greater than 2010.
        with self.assertLogs("allennlp.semparse.executors.wikitables_variable_free_executor") as log:
            logical_form = """(select (previous (filter_number_greater all_rows fb:row.row.year 2010))
                                      fb:row.row.regular_season)"""
            self.executor.execute(logical_form)
        self.assertEqual(log.output,
                         ["WARNING:allennlp.semparse.executors.wikitables_variable_free_executor:"
                          "Trying to get the previous row from an empty list: "
                          "['filter_number_greater', 'all_rows', 'fb:row.row.year', '2010']"])

    def test_execute_works_with_next(self):
        # Selecting "regular season" from the row after first where year is not equal to 2010.
        logical_form = """(select (next (first (filter_number_not_equals
                                                    all_rows fb:row.row.year 2010)))
                                  fb:row.row.regular_season)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["fb:cell.5th"]

    def test_execute_logs_warning_with_next_on_empty_list(self):
        # Selecting "regular season" from the row after the one where year is greater than 2010.
        with self.assertLogs("allennlp.semparse.executors.wikitables_variable_free_executor") as log:
            logical_form = """(select (next (filter_number_greater all_rows fb:row.row.year 2010))
                                      fb:row.row.regular_season)"""
            self.executor.execute(logical_form)
        self.assertEqual(log.output,
                         ["WARNING:allennlp.semparse.executors.wikitables_variable_free_executor:"
                          "Trying to get the next row from an empty list: "
                          "['filter_number_greater', 'all_rows', 'fb:row.row.year', '2010']"])

    def test_execute_works_with_mode(self):
        # Most frequent division value.
        logical_form = """(mode all_rows fb:row.row.division)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["fb:cell.2"]
        # If we used select instead, we should get a list of two values.
        logical_form = """(select all_rows fb:row.row.division)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["fb:cell.2", "fb:cell.2"]
        # If we queried for the most frequent year instead, it should return two values since both
        # have the max frequency of 1.
        logical_form = """(mode all_rows fb:row.row.year)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["fb:cell.2001", "fb:cell.2005"]

    def test_execute_works_with_same_as(self):
        # Select the "league" from all the rows that have the same value under "playoffs" as the
        # row that has the string "a league" under "league".
        logical_form = """(select (same_as (filter_in all_rows fb:row.row.league a_league)
                                   fb:row.row.playoffs)
                           fb:row.row.league)"""
        cell_list = self.executor.execute(logical_form)
        assert cell_list == ["fb:cell.usl_a_league", "fb:cell.usl_first_division"]

    def test_execute_works_with_sum(self):
        # Get total "avg attendance".
        logical_form = """(sum all_rows fb:row.row.avg_attendance)"""
        sum_value = self.executor.execute(logical_form)
        assert sum_value == 13197
        # Total "avg attendance" where "playoffs" has "quarterfinals"
        logical_form = """(sum (filter_in all_rows fb:row.row.playoffs quarterfinals)
                                fb:row.row.avg_attendance)"""
        sum_value = self.executor.execute(logical_form)
        assert sum_value == 13197

    def test_execute_works_with_average(self):
        # Get average "avg attendance".
        logical_form = """(average all_rows fb:row.row.avg_attendance)"""
        avg_value = self.executor.execute(logical_form)
        assert avg_value == 6598.5
        # Average "avg attendance" where "playoffs" has "quarterfinals"
        logical_form = """(average (filter_in all_rows fb:row.row.playoffs quarterfinals)
                                fb:row.row.avg_attendance)"""
        avg_value = self.executor.execute(logical_form)
        assert avg_value == 6598.5

    def test_execute_works_with_diff(self):
        # Difference in "avg attendance" between rows with "usl_a_league" and "usl_first_division"
        # in "league" columns.
        logical_form = """(diff (filter_in all_rows fb:row.row.league usl_a_league)
                                (filter_in all_rows fb:row.row.league usl_first_division)
                                fb:row.row.avg_attendance)"""
        avg_value = self.executor.execute(logical_form)
        assert avg_value == 1141

    def test_execute_fails_with_diff_on_non_numerical_columns(self):
        logical_form = """(diff (filter_in all_rows fb:row.row.league usl_a_league)
                                (filter_in all_rows fb:row.row.league usl_first_division)
                                fb:row.row.league)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_execute_fails_with_non_int_dates(self):
        logical_form = """(date 2015 1.5 1)"""
        with self.assertRaises(ExecutionError):
            self.executor.execute(logical_form)

    def test_date_comparison_works(self):
        assert Date(2013, 12, 31) > Date(2013, 12, 30)
        assert Date(2013, 12, 31) == Date(2013, 12, -1)
        assert Date(2013, -1, -1) >= Date(2013, 12, 31)
        # pylint: disable=singleton-comparison
        assert (Date(2013, 12, -1) > Date(2013, 12, 31)) == False
        assert (Date(2013, 12, 31) > 2013) == False
        assert (Date(2013, 12, 31) >= 2013) == False
        assert Date(2013, 12, 31) != 2013
        assert (Date(2018, 1, 1) >= Date(-1, 2, 1)) == False
        assert (Date(2018, 1, 1) < Date(-1, 2, 1)) == False
        # When year is unknown in both cases, we can compare months and days.
        assert Date(-1, 2, 1) < Date(-1, 2, 3)
        # If both year and month are not known in both cases, the comparison is undefined, and both
        # < and >= return False.
        assert (Date(-1, -1, 1) < Date(-1, -1, 3)) == False
        assert (Date(-1, -1, 1) >= Date(-1, -1, 3)) == False
        # Same when year is known, but months are not.
        assert (Date(2018, -1, 1) < Date(2018, -1, 3)) == False
        # TODO (pradeep): Figure out whether this is expected behavior by looking at data.
        assert (Date(2018, -1, 1) >= Date(2018, -1, 3)) == False
    def __init__(self, table_context: TableQuestionContext) -> None:
        """
        Builds the name mappings, type signatures and terminal productions for one table/question.

        Parameters
        ----------
        table_context : ``TableQuestionContext``
            Supplies the column types, the table knowledge graph, the table data handed to the
            executor, and the entities and numbers extracted from the question.
        """
        super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                                 "num": types.NUMBER_TYPE},
                         global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                         global_name_mapping=types.COMMON_NAME_MAPPING)
        self.table_context = table_context
        # We add name mapping and signatures corresponding to specific column types to the local
        # name mapping based on the table content here.
        column_types = table_context.column_types.values()
        # These flags record which column types occur in the table; they start False and are set
        # below when the corresponding type is found.
        self._table_has_string_columns = False
        self._table_has_date_columns = False
        self._table_has_number_columns = False
        if "string" in column_types:
            for name, translated_name in types.STRING_COLUMN_NAME_MAPPING.items():
                signature = types.STRING_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            self._table_has_string_columns = True
        if "date" in column_types:
            for name, translated_name in types.DATE_COLUMN_NAME_MAPPING.items():
                signature = types.DATE_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            # Adding -1 to mapping because we need it for dates where not all three fields are
            # specified. We want to do this only when the table has a date column. This is because
            # the knowledge graph is also constructed in such a way that -1 is an entity with date
            # columns as the neighbors only if any date columns exist in the table.
            self._map_name(f"num:-1", keep_mapping=True)
            self._table_has_date_columns = True
        if "number" in column_types:
            for name, translated_name in types.NUMBER_COLUMN_NAME_MAPPING.items():
                signature = types.NUMBER_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            self._table_has_number_columns = True
        # Operations over comparable columns are available as soon as either dates or numbers are.
        if "date" in column_types or "number" in column_types:
            for name, translated_name in types.COMPARABLE_COLUMN_NAME_MAPPING.items():
                signature = types.COMPARABLE_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)

        self.table_graph = table_context.get_table_knowledge_graph()

        self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)

        # TODO (pradeep): Use a NameMapper for mapping entity names too.
        # For every new column name seen, we update this counter to map it to a new NLTK name.
        self._column_counter = 0

        # Adding entities and numbers seen in questions to the mapping.
        question_entities, question_numbers = table_context.get_entities_from_question()
        # Keep only the entity/number values; the second element of each pair is discarded here.
        self._question_entities = [entity for entity, _ in question_entities]
        self._question_numbers = [number for number, _ in question_numbers]
        for entity in self._question_entities:
            # These entities all have prefix "string:"
            self._map_name(entity, keep_mapping=True)

        for number_in_question in self._question_numbers:
            self._map_name(f"num:{number_in_question}", keep_mapping=True)

        # Keeps track of column name productions so that we can add them to the agenda.
        # NOTE(review): this presumably has to exist before the column names are mapped below,
        # because mapping a column name is expected to record its production here — confirm
        # against this class's ``_map_name``.
        self._column_productions_for_agenda: Dict[str, str] = {}

        # Adding column names to the local name mapping.
        for column_name, column_type in table_context.column_types.items():
            self._map_name(f"{column_type}_column:{column_name}", keep_mapping=True)

        # Maps each predicate name to its "<signature> -> <predicate>" production string, for
        # every global or local name whose mapped name has a known type signature.
        self.terminal_productions: Dict[str, str] = {}
        name_mapping = [(name, mapping) for name, mapping in self.global_name_mapping.items()]
        name_mapping += [(name, mapping) for name, mapping in self.local_name_mapping.items()]
        # Local signatures override global ones with the same key.
        signatures = self.global_type_signatures.copy()
        signatures.update(self.local_type_signatures)
        for predicate, mapped_name in name_mapping:
            if mapped_name in signatures:
                signature = signatures[mapped_name]
                self.terminal_productions[predicate] = f"{signature} -> {predicate}"

        # We don't need to recompute this ever; let's just compute it once and cache it.
        # (Starts as None until first computed.)
        self._valid_actions: Dict[str, List[str]] = None
class WikiTablesVariableFreeWorld(World):
    """
    World representation for the WikitableQuestions domain with the variable-free language used in
    the paper from Liang et al. (2018).

    Parameters
    ----------
    table_context : ``TableQuestionContext``
        Context associated with this world. It supplies the column types, the table knowledge
        graph, the table data handed to the executor, and the entities and numbers extracted from
        the question.
    """
    # When we're converting from logical forms to action sequences, this set tells us which
    # functions in the logical form are curried functions, and how many arguments the function
    # actually takes.  This is necessary because NLTK curries all multi-argument functions to a
    # series of one-argument function applications.  See `world._get_transitions` for more info.
    curried_functions = {
        types.SELECT_TYPE: 2,
        types.ROW_FILTER_WITH_GENERIC_COLUMN: 2,
        types.ROW_FILTER_WITH_COMPARABLE_COLUMN: 2,
        types.ROW_NUM_OP: 2,
        types.ROW_FILTER_WITH_COLUMN_AND_NUMBER: 3,
        types.ROW_FILTER_WITH_COLUMN_AND_DATE: 3,
        types.ROW_FILTER_WITH_COLUMN_AND_STRING: 3,
        types.NUM_DIFF_WITH_COLUMN: 3,
    }

    def __init__(self, table_context: TableQuestionContext) -> None:
        super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                                 "num": types.NUMBER_TYPE},
                         global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                         global_name_mapping=types.COMMON_NAME_MAPPING)
        self.table_context = table_context
        # We add name mapping and signatures corresponding to specific column types to the local
        # name mapping based on the table content here.
        column_types = table_context.column_types.values()
        self._table_has_string_columns = False
        self._table_has_date_columns = False
        self._table_has_number_columns = False
        if "string" in column_types:
            for name, translated_name in types.STRING_COLUMN_NAME_MAPPING.items():
                signature = types.STRING_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            self._table_has_string_columns = True
        if "date" in column_types:
            for name, translated_name in types.DATE_COLUMN_NAME_MAPPING.items():
                signature = types.DATE_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            # Adding -1 to mapping because we need it for dates where not all three fields are
            # specified. We want to do this only when the table has a date column. This is because
            # the knowledge graph is also constructed in such a way that -1 is an entity with date
            # columns as the neighbors only if any date columns exist in the table.
            self._map_name("num:-1", keep_mapping=True)
            self._table_has_date_columns = True
        if "number" in column_types:
            for name, translated_name in types.NUMBER_COLUMN_NAME_MAPPING.items():
                signature = types.NUMBER_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            self._table_has_number_columns = True
        # Operations over comparable columns are available as soon as either dates or numbers are.
        if "date" in column_types or "number" in column_types:
            for name, translated_name in types.COMPARABLE_COLUMN_NAME_MAPPING.items():
                signature = types.COMPARABLE_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)

        self.table_graph = table_context.get_table_knowledge_graph()

        self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)

        # TODO (pradeep): Use a NameMapper for mapping entity names too.
        # For every new column name seen, we update this counter to map it to a new NLTK name.
        self._column_counter = 0

        # Adding entities and numbers seen in questions to the mapping.
        question_entities, question_numbers = table_context.get_entities_from_question()
        self._question_entities = [entity for entity, _ in question_entities]
        self._question_numbers = [number for number, _ in question_numbers]
        for entity in self._question_entities:
            # These entities all have prefix "string:"
            self._map_name(entity, keep_mapping=True)

        for number_in_question in self._question_numbers:
            self._map_name(f"num:{number_in_question}", keep_mapping=True)

        # Keeps track of column name productions so that we can add them to the agenda. It must
        # exist before the column names are mapped below, because
        # ``_translate_name_and_add_mapping`` writes into it.
        self._column_productions_for_agenda: Dict[str, str] = {}

        # Adding column names to the local name mapping. The mapping order decides which NLTK
        # name (C0, C1, ...) each column gets.
        for column_name, column_type in table_context.column_types.items():
            self._map_name(f"{column_type}_column:{column_name}", keep_mapping=True)

        # Maps each predicate name to its "<signature> -> <predicate>" production string, for
        # every global or local name whose mapped name has a known type signature.
        self.terminal_productions: Dict[str, str] = {}
        name_mapping = list(self.global_name_mapping.items())
        name_mapping += list(self.local_name_mapping.items())
        # Local signatures override global ones with the same key.
        signatures = self.global_type_signatures.copy()
        signatures.update(self.local_type_signatures)
        for predicate, mapped_name in name_mapping:
            if mapped_name in signatures:
                signature = signatures[mapped_name]
                self.terminal_productions[predicate] = f"{signature} -> {predicate}"

        # We don't need to recompute this ever; let's just compute it once and cache it.
        # (Starts as None until first computed.)
        self._valid_actions: Dict[str, List[str]] = None

    @staticmethod
    def is_instance_specific_entity(entity_name: str) -> bool:
        """
        Instance specific entities are column names, strings and numbers. Returns True if the entity
        is one of those.

        Column names contain ``"_column:"``, strings start with ``"string:"``, and numbers are
        recognized by ``float()`` accepting the name.
        """
        entity_is_number = False
        try:
            float(entity_name)
            entity_is_number = True
        except ValueError:
            pass
        # Column names start with "*_column:", strings start with "string:"
        return "_column:" in entity_name or entity_name.startswith("string:") or entity_is_number

    @overrides
    def _get_curried_functions(self) -> Dict[Type, int]:
        """Returns the class-level table of curried function types and their true arities."""
        return WikiTablesVariableFreeWorld.curried_functions

    @overrides
    def get_basic_types(self) -> Set[Type]:
        """
        Returns ``types.BASIC_TYPES`` plus the column types that this table actually contains.
        The comparable column type is added when the table has date or number columns.
        """
        basic_types = set(types.BASIC_TYPES)
        if self._table_has_string_columns:
            basic_types.add(types.STRING_COLUMN_TYPE)
        if self._table_has_date_columns:
            basic_types.add(types.DATE_COLUMN_TYPE)
            basic_types.add(types.COMPARABLE_COLUMN_TYPE)
        if self._table_has_number_columns:
            basic_types.add(types.NUMBER_COLUMN_TYPE)
            basic_types.add(types.COMPARABLE_COLUMN_TYPE)
        return basic_types

    @overrides
    def get_valid_starting_types(self) -> Set[Type]:
        """Returns ``types.STARTING_TYPES``."""
        return types.STARTING_TYPES

    def _translate_name_and_add_mapping(self, name: str) -> str:
        """
        Returns the NLTK-safe name for ``name`` and registers the mapping, with its type, in the
        local name mapping.

        * Names containing ``"_column:"`` get fresh names ``C0``, ``C1``, ...; the column type is
          taken from the ``number_column:`` / ``string_column:`` prefix (anything else is treated
          as a date column), and the column's production is recorded for the agenda.
        * ``"string:"`` names are kept unchanged.
        * ``"num:"`` names get ``.`` replaced by ``_`` and, for negative numbers, ``-`` replaced by
          ``~``; the mapping is registered under the name without the ``"num:"`` prefix.

        Raises ``ParsingError`` for any other name, since no type can be assigned to it.
        """
        if "_column:" in name:
            # Column name
            translated_name = f"C{self._column_counter}"
            self._column_counter += 1
            if name.startswith("number_column:"):
                column_type = types.NUMBER_COLUMN_TYPE
            elif name.startswith("string_column:"):
                column_type = types.STRING_COLUMN_TYPE
            else:
                column_type = types.DATE_COLUMN_TYPE
            self._add_name_mapping(name, translated_name, column_type)
            self._column_productions_for_agenda[name] = f"{column_type} -> {name}"
        elif name.startswith("string:"):
            # We do not need to translate these names.
            translated_name = name
            self._add_name_mapping(name, translated_name, types.STRING_TYPE)
        elif name.startswith("num:"):
            # NLTK throws an error if it sees a "." in constants, which will most likely happen
            # within numbers as a decimal point. We're changing those to underscores.
            translated_name = name.replace(".", "_")
            if re.match("num:-[0-9_]+", translated_name):
                # The string is a negative number. This makes NLTK interpret this as a negated
                # expression and force its type to be TRUTH_VALUE (t).
                translated_name = translated_name.replace("-", "~")
            original_name = name.replace("num:", "")
            self._add_name_mapping(original_name, translated_name, types.NUMBER_TYPE)
        else:
            # Without a recognized prefix there is no type to register the name with.
            raise ParsingError(f"Cannot translate name with an unknown prefix: {name}")
        return translated_name

    @overrides
    def _map_name(self, name: str, keep_mapping: bool = False) -> str:
        """
        Returns the translated name for ``name``, looking first in the common mapping and then in
        the local one. Unknown names are translated and added to the local mapping when
        ``keep_mapping`` is True; otherwise ``ParsingError`` is raised.
        """
        if name in types.COMMON_NAME_MAPPING:
            return types.COMMON_NAME_MAPPING[name]
        if name in self.local_name_mapping:
            return self.local_name_mapping[name]
        if not keep_mapping:
            raise ParsingError(f"Encountered un-mapped name: {name}")
        return self._translate_name_and_add_mapping(name)

    def get_agenda(self) -> List[str]:
        """
        Returns the productions that the wording of the question suggests should appear in the
        logical form:

        * terminal productions for operators triggered by phrases or tokens in the question (e.g.
          "at least", "total", "average", "how many"), kept only if the operator exists in
          ``self.terminal_productions`` for this table;
        * column-name productions for columns whose name occurs as whole words in the question;
        * string productions for question entities not already covered by a matched column name,
          and number productions for question numbers that literally occur in the question.
        """
        agenda_items = []
        question_tokens = [token.text for token in self.table_context.question_tokens]
        question = " ".join(question_tokens)
        if "at least" in question:
            agenda_items.append("filter_number_greater_equals")
        if "at most" in question:
            agenda_items.append("filter_number_lesser_equals")

        comparison_triggers = ["greater", "larger", "more"]
        if any(f"no {word} than" in question for word in comparison_triggers):
            agenda_items.append("filter_number_lesser_equals")
        elif any(f"{word} than" in question for word in comparison_triggers):
            agenda_items.append("filter_number_greater")
        for token in question_tokens:
            if token in ["next", "after", "below"]:
                agenda_items.append("next")
            if token in ["previous", "before", "above"]:
                agenda_items.append("previous")
            if token == "total":
                agenda_items.append("sum")
            if token == "difference":
                agenda_items.append("diff")
            if token == "average":
                agenda_items.append("average")
            if token in ["least", "smallest", "shortest", "lowest"] and "at least" not in question:
                # This condition is too brittle. But for most logical forms with "min", there are
                # semantically equivalent ones with "argmin". The exceptions are rare.
                if "what is the least" in question:
                    agenda_items.append("min")
                else:
                    agenda_items.append("argmin")
            if token in ["most", "largest", "highest", "longest", "greatest"] and "at most" not in question:
                # This condition is too brittle. But for most logical forms with "max", there are
                # semantically equivalent ones with "argmax". The exceptions are rare.
                if "what is the most" in question:
                    agenda_items.append("max")
                else:
                    agenda_items.append("argmax")
            if token in ["first", "top"]:
                agenda_items.append("first")
            if token in ["last", "bottom"]:
                agenda_items.append("last")

        if "how many" in question:
            if "sum" not in agenda_items and "average" not in agenda_items:
                # The question probably just requires counting the rows. But this is not very
                # accurate. The question could also be asking for a value that is in the table.
                agenda_items.append("count")
        agenda = []
        # Adding productions from the global set. ``dict.fromkeys`` drops duplicate items while
        # keeping first-seen order; iterating a ``set`` would make the agenda order depend on
        # string hash randomization and differ between runs.
        for agenda_item in dict.fromkeys(agenda_items):
            # Some agenda items may not be present in the terminal productions because some of these
            # terminals are table-content specific. For example, if the question triggered "sum",
            # and the table does not have number columns, we should not add "<r,<f,n>> -> sum" to
            # the agenda.
            if agenda_item in self.terminal_productions:
                agenda.append(self.terminal_productions[agenda_item])

        # Adding column names that occur in question.
        question_with_underscores = "_".join(question_tokens)
        # Padding with underscores lets the first and last words of the question match the
        # "_word_" whole-word patterns below, just like words in the middle of the question.
        normalized_question = re.sub("[^a-z0-9_]", "", f"_{question_with_underscores}_")
        # We keep track of tokens that are in column names being added to the agenda. We will not
        # add string productions to the agenda if those tokens were already captured as column
        # names.
        # Note: If the same string occurs multiple times, this may cause string productions being
        # omitted from the agenda unnecessarily. That is fine, as we want to err on the side of
        # adding fewer rules to the agenda.
        tokens_in_column_names: Set[str] = set()
        for column_name_with_type, signature in self._column_productions_for_agenda.items():
            column_name = column_name_with_type.split(":")[1]
            # Underscores ensure that the match is of whole words.
            if f"_{column_name}_" in normalized_question:
                agenda.append(signature)
                tokens_in_column_names.update(column_name.split("_"))

        # Adding all productions that lead to entities and numbers extracted from the question.
        for entity in self._question_entities:
            if entity.replace("string:", "") not in tokens_in_column_names:
                agenda.append(f"{types.STRING_TYPE} -> {entity}")

        for number in self._question_numbers:
            # The reason we check for the presence of the number in the question again is because
            # some of these numbers are extracted from number words like month names and ordinals
            # like "first". On looking at some agenda outputs, I found that they hurt more than help
            # in the agenda.
            if f"_{number}_" in normalized_question:
                agenda.append(f"{types.NUMBER_TYPE} -> {number}")
        return agenda

    def execute(self, logical_form: str) -> Union[List[str], int, float]:
        """
        Executes ``logical_form`` against this world's table and returns the executor's result
        (a list of values, or a number such as a count, sum or average).
        """
        return self._executor.execute(logical_form)

    def evaluate_logical_form(self, logical_form: str, target_list: List[str]) -> bool:
        """
        Takes a logical form and a list of target values as strings from the original lisp
        representation of instances, and returns True iff the logical form executes to those values.
        """
        return self._executor.evaluate_logical_form(logical_form, target_list)
Пример #10
0
    def __init__(self, table_context: TableQuestionContext) -> None:
        """
        Builds the name mappings, type signatures and terminal productions for one table/question,
        and records where question entities and numbers occur in the question.

        Parameters
        ----------
        table_context : ``TableQuestionContext``
            Supplies the column types, the table knowledge graph, the table data handed to the
            executor, and the entities and numbers extracted from the question.
        """
        super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                                 "num": types.NUMBER_TYPE},
                         global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                         global_name_mapping=types.COMMON_NAME_MAPPING)
        self.table_context = table_context
        # We add name mapping and signatures corresponding to specific column types to the local
        # name mapping based on the table content here.
        column_types = table_context.column_types.values()
        # These flags record which column types occur in the table; they start False and are set
        # below when the corresponding type is found.
        self._table_has_string_columns = False
        self._table_has_date_columns = False
        self._table_has_number_columns = False
        if "string" in column_types:
            for name, translated_name in types.STRING_COLUMN_NAME_MAPPING.items():
                signature = types.STRING_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            self._table_has_string_columns = True
        if "date" in column_types:
            for name, translated_name in types.DATE_COLUMN_NAME_MAPPING.items():
                signature = types.DATE_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            # Adding -1 to mapping because we need it for dates where not all three fields are
            # specified. We want to do this only when the table has a date column. This is because
            # the knowledge graph is also constructed in such a way that -1 is an entity with date
            # columns as the neighbors only if any date columns exist in the table.
            self._map_name(f"num:-1", keep_mapping=True)
            self._table_has_date_columns = True
        if "number" in column_types:
            for name, translated_name in types.NUMBER_COLUMN_NAME_MAPPING.items():
                signature = types.NUMBER_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            self._table_has_number_columns = True
        # Operations over comparable columns are available as soon as either dates or numbers are.
        if "date" in column_types or "number" in column_types:
            for name, translated_name in types.COMPARABLE_COLUMN_NAME_MAPPING.items():
                signature = types.COMPARABLE_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)

        self.table_graph = table_context.get_table_knowledge_graph()

        self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)

        # TODO (pradeep): Use a NameMapper for mapping entity names too.
        # For every new column name seen, we update this counter to map it to a new NLTK name.
        self._column_counter = 0

        # Adding entities and numbers seen in questions to the mapping.
        # In this variant, question entities are 4-tuples (entity, start, end, <unused>) and
        # question numbers are pairs (number, <id>).
        question_entities, question_numbers = table_context.get_entities_from_question()
        self._question_entities = [entity for entity, _, _, _ in question_entities]
        self._question_numbers = [number for number, _ in question_numbers]

        # Maps each entity string to its (start, end) pair; presumably token positions in the
        # question — confirm. A repeated entity keeps the span of its last occurrence.
        self.ent2id = dict()
        for entity, start, end, _ in question_entities:
           self.ent2id[entity] = (start, end)
        # Maps each question number to the second element of its pair (presumably a token index).
        self.num2id = dict()
        for num, _id in question_numbers:
            # NOTE(review): ``num`` is interpolated into "num:..." names below, so it is likely a
            # string; if so, comparing it with the int -1 is always True and a "-1" entry is not
            # excluded. Confirm the element type returned by get_entities_from_question.
            if num != -1:
                self.num2id[num] = _id

        for entity in self._question_entities:
            # These entities all have prefix "string:"
            self._map_name(entity, keep_mapping=True)

        for number_in_question in self._question_numbers:
            self._map_name(f"num:{number_in_question}", keep_mapping=True)

        # Keeps track of column name productions so that we can add them to the agenda.
        # NOTE(review): this presumably has to exist before the column names are mapped below,
        # because mapping a column name is expected to record its production here — confirm.
        self._column_productions_for_agenda: Dict[str, str] = {}

        # Adding column names to the local name mapping.
        for column_name, column_type in table_context.column_types.items():
            self._map_name(f"{column_type}_column:{column_name}", keep_mapping=True)

        # Maps each predicate name to its "<signature> -> <predicate>" production string, for
        # every global or local name whose mapped name has a known type signature.
        self.terminal_productions: Dict[str, str] = {}
        name_mapping = [(name, mapping) for name, mapping in self.global_name_mapping.items()]
        name_mapping += [(name, mapping) for name, mapping in self.local_name_mapping.items()]
        # Local signatures override global ones with the same key.
        signatures = self.global_type_signatures.copy()
        signatures.update(self.local_type_signatures)
        for predicate, mapped_name in name_mapping:
            if mapped_name in signatures:
                signature = signatures[mapped_name]
                self.terminal_productions[predicate] = f"{signature} -> {predicate}"

        # We don't need to recompute this ever; let's just compute it once and cache it.
        # (Starts as None until first computed.)
        self._valid_actions: Dict[str, List[str]] = None