Ejemplo n.º 1
0
    def __init__(self, table_context: TableQuestionContext) -> None:
        """Build the executable language for one table/question pair.

        The language starts from ``Number``, ``Date`` or ``List[str]``. Predicates are registered
        only for the column types present in the table, the question's entities and numbers are
        added as constants, and each column is added as a constant once for every type in its
        column hierarchy. Finally ``terminal_productions`` is built from the registered functions.

        Args:
            table_context: Context holding the table data, the column types, the table knowledge
                graph and the entities/numbers extracted from the question.
        """
        super().__init__(start_types={Number, Date, List[str]})
        self.table_context = table_context
        self.table_data = [Row(row) for row in table_context.table_data]

        # NOTE(review): execution-restriction settings; their consumers are outside this view.
        self.exe_restricted = True
        self.restrict_num = 1

        present_types = table_context.column_types.values()
        has_strings = "string" in present_types
        has_dates = "date" in present_types
        has_numbers = "number" in present_types

        def register(*predicate_names: str) -> None:
            # Each predicate is registered under the name of the method implementing it.
            for predicate_name in predicate_names:
                self.add_predicate(predicate_name, getattr(self, predicate_name))

        if has_strings:
            register('filter_in', 'filter_not_in')
        if has_dates:
            register('filter_date_greater', 'filter_date_greater_equals',
                     'filter_date_lesser', 'filter_date_lesser_equals',
                     'filter_date_equals', 'filter_date_not_equals')
            # -1 stands in for unspecified fields of partially specified dates. It is only added
            # when the table has a date column, mirroring the knowledge graph, in which -1 is an
            # entity (with the date columns as its neighbors) only if date columns exist.
            self.add_constant('-1', -1, type_=Number)
        if has_numbers:
            register('filter_number_greater', 'filter_number_greater_equals',
                     'filter_number_lesser', 'filter_number_lesser_equals',
                     'filter_number_equals', 'filter_number_not_equals',
                     'max', 'min', 'average', 'sum', 'diff')
        if has_dates or has_numbers:
            register('argmax', 'argmin')

        self.table_graph = table_context.get_table_knowledge_graph()

        # Entities and numbers mentioned in the question become constants of the language.
        question_entities, question_numbers = table_context.get_entities_from_question()
        self._question_entities = [entity_name for entity_name, _, _, _ in question_entities]
        self._question_numbers = [value for value, _ in question_numbers]
        for entity_name in self._question_entities:
            self.add_constant(entity_name, entity_name)
        for value in self._question_numbers:
            self.add_constant(str(value), float(value), type_=Number)

        # Column-name productions, kept so that they can be added to the agenda.
        self._column_productions_for_agenda: Dict[str, str] = {}

        # Every column is added once per type in its hierarchy: string columns as StringColumn
        # and Column; date/number columns as DateColumn/NumberColumn, ComparableColumn and Column.
        column_classes = {'string': StringColumn, 'date': DateColumn, 'number': NumberColumn}
        for raw_name, type_label in table_context.column_types.items():
            typed_name = f"{type_label}_column:{raw_name}"
            column_class = column_classes.get(type_label)
            column: Column = column_class(typed_name) if column_class is not None else None
            if type_label in ('date', 'number'):
                self.add_constant(typed_name, column, type_=ComparableColumn)
            self.add_constant(typed_name, column, type_=Column)
            self.add_constant(typed_name, column)
            production_lhs = str(PredicateType.get_type(type(column)))
            self._column_productions_for_agenda[typed_name] = f"{production_lhs} -> {typed_name}"

        # Terminal string -> production producing it. Used by the agenda-related methods; some
        # models also read it to know how many terminals to plan for.
        self.terminal_productions: Dict[str, str] = {
            function_name: f"{signature_types[0]!s} -> {function_name}"
            for function_name, signature_types in self._function_types.items()
        }
    def __init__(self, table_context: TableQuestionContext) -> None:
        """Set up the type system and name mappings for one table/question pair.

        Column-type-specific predicates are added to the local name mapping only for the column
        types present in the table. Question entities, question numbers and column names are then
        mapped, and ``terminal_productions`` is precomputed from the combined global and local
        name mappings and type signatures.

        Args:
            table_context: Context providing the table's column types and data, its knowledge
                graph and the entities/numbers extracted from the question.
        """
        super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                                 "num": types.NUMBER_TYPE},
                         global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                         global_name_mapping=types.COMMON_NAME_MAPPING)
        self.table_context = table_context
        # Name mappings and signatures specific to a column type go into the local name mapping,
        # and only for the column types that actually occur in this table.
        column_types = table_context.column_types.values()
        self._table_has_string_columns = False
        self._table_has_date_columns = False
        self._table_has_number_columns = False
        if "string" in column_types:
            for name, translated_name in types.STRING_COLUMN_NAME_MAPPING.items():
                signature = types.STRING_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            self._table_has_string_columns = True
        if "date" in column_types:
            for name, translated_name in types.DATE_COLUMN_NAME_MAPPING.items():
                signature = types.DATE_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            # -1 is needed for dates where not all three fields are specified. It is only mapped
            # when the table has a date column, mirroring the knowledge graph, in which -1 is an
            # entity (with the date columns as its neighbors) only if date columns exist.
            self._map_name("num:-1", keep_mapping=True)
            self._table_has_date_columns = True
        if "number" in column_types:
            for name, translated_name in types.NUMBER_COLUMN_NAME_MAPPING.items():
                signature = types.NUMBER_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            self._table_has_number_columns = True
        if "date" in column_types or "number" in column_types:
            for name, translated_name in types.COMPARABLE_COLUMN_NAME_MAPPING.items():
                signature = types.COMPARABLE_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)

        self.table_graph = table_context.get_table_knowledge_graph()

        self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)

        # TODO (pradeep): Use a NameMapper for mapping entity names too.
        # For every new column name seen, this counter is bumped to map it to a new NLTK name.
        self._column_counter = 0

        # Entities and numbers seen in the question are added to the mapping.
        question_entities, question_numbers = table_context.get_entities_from_question()
        self._question_entities = [entity for entity, _ in question_entities]
        self._question_numbers = [number for number, _ in question_numbers]
        for entity in self._question_entities:
            # These entities all have the prefix "string:".
            self._map_name(entity, keep_mapping=True)

        for number_in_question in self._question_numbers:
            self._map_name(f"num:{number_in_question}", keep_mapping=True)

        # Column-name productions, kept so that they can be added to the agenda. Initialized
        # before the column names are mapped below.
        self._column_productions_for_agenda: Dict[str, str] = {}

        # Column names are added to the local name mapping.
        for column_name, column_type in table_context.column_types.items():
            self._map_name(f"{column_type}_column:{column_name}", keep_mapping=True)

        # Terminal name -> production that produces it. Local entries are processed after global
        # ones (and local signatures override global ones), so local definitions win on clashes.
        self.terminal_productions: Dict[str, str] = {}
        name_mapping = list(self.global_name_mapping.items())
        name_mapping += list(self.local_name_mapping.items())
        signatures = self.global_type_signatures.copy()
        signatures.update(self.local_type_signatures)
        for predicate, mapped_name in name_mapping:
            if mapped_name in signatures:
                signature = signatures[mapped_name]
                self.terminal_productions[predicate] = f"{signature} -> {predicate}"

        # Valid actions never change, so they are computed once and cached; None until then.
        self._valid_actions: Dict[str, List[str]] = None
Ejemplo n.º 3
0
    def __init__(self, table_context: TableQuestionContext) -> None:
        """Set up the type system and name mappings for one table/question pair.

        Column-type-specific predicates are added to the local name mapping only for the column
        types present in the table. Question entities and numbers are mapped (and their positions
        recorded in ``ent2id``/``num2id``), column names are mapped, and ``terminal_productions``
        is precomputed from the combined global and local name mappings and type signatures.

        Args:
            table_context: Context providing the table's column types and data, its knowledge
                graph and the entities/numbers extracted from the question.
        """
        super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                                 "num": types.NUMBER_TYPE},
                         global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                         global_name_mapping=types.COMMON_NAME_MAPPING)
        self.table_context = table_context
        # Name mappings and signatures specific to a column type go into the local name mapping,
        # and only for the column types that actually occur in this table.
        column_types = table_context.column_types.values()
        self._table_has_string_columns = False
        self._table_has_date_columns = False
        self._table_has_number_columns = False
        if "string" in column_types:
            for name, translated_name in types.STRING_COLUMN_NAME_MAPPING.items():
                signature = types.STRING_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            self._table_has_string_columns = True
        if "date" in column_types:
            for name, translated_name in types.DATE_COLUMN_NAME_MAPPING.items():
                signature = types.DATE_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            # -1 is needed for dates where not all three fields are specified. It is only mapped
            # when the table has a date column, mirroring the knowledge graph, in which -1 is an
            # entity (with the date columns as its neighbors) only if date columns exist.
            self._map_name("num:-1", keep_mapping=True)
            self._table_has_date_columns = True
        if "number" in column_types:
            for name, translated_name in types.NUMBER_COLUMN_NAME_MAPPING.items():
                signature = types.NUMBER_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)
            self._table_has_number_columns = True
        if "date" in column_types or "number" in column_types:
            for name, translated_name in types.COMPARABLE_COLUMN_NAME_MAPPING.items():
                signature = types.COMPARABLE_COLUMN_TYPE_SIGNATURE[translated_name]
                self._add_name_mapping(name, translated_name, signature)

        self.table_graph = table_context.get_table_knowledge_graph()

        self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)

        # TODO (pradeep): Use a NameMapper for mapping entity names too.
        # For every new column name seen, this counter is bumped to map it to a new NLTK name.
        self._column_counter = 0

        # Entities and numbers seen in the question are added to the mapping.
        question_entities, question_numbers = table_context.get_entities_from_question()
        self._question_entities = [entity for entity, _, _, _ in question_entities]
        self._question_numbers = [number for number, _ in question_numbers]

        # Question entity -> its (start, end) positions as reported by the context
        # (presumably token offsets in the question — confirm against TableQuestionContext).
        self.ent2id = {entity: (start, end) for entity, start, end, _ in question_entities}
        # Question number -> its position. -1 is excluded; presumably because it is the
        # date-placeholder constant rather than a number actually mentioned — confirm.
        self.num2id = {num: _id for num, _id in question_numbers if num != -1}

        for entity in self._question_entities:
            # These entities all have the prefix "string:".
            self._map_name(entity, keep_mapping=True)

        for number_in_question in self._question_numbers:
            self._map_name(f"num:{number_in_question}", keep_mapping=True)

        # Column-name productions, kept so that they can be added to the agenda. Initialized
        # before the column names are mapped below.
        self._column_productions_for_agenda: Dict[str, str] = {}

        # Column names are added to the local name mapping.
        for column_name, column_type in table_context.column_types.items():
            self._map_name(f"{column_type}_column:{column_name}", keep_mapping=True)

        # Terminal name -> production that produces it. Local entries are processed after global
        # ones (and local signatures override global ones), so local definitions win on clashes.
        self.terminal_productions: Dict[str, str] = {}
        name_mapping = list(self.global_name_mapping.items())
        name_mapping += list(self.local_name_mapping.items())
        signatures = self.global_type_signatures.copy()
        signatures.update(self.local_type_signatures)
        for predicate, mapped_name in name_mapping:
            if mapped_name in signatures:
                signature = signatures[mapped_name]
                self.terminal_productions[predicate] = f"{signature} -> {predicate}"

        # Valid actions never change, so they are computed once and cached; None until then.
        self._valid_actions: Dict[str, List[str]] = None
Ejemplo n.º 4
0
    def __init__(self, table_context: TableQuestionContext) -> None:
        """Build the executable language for one table/question pair.

        Start types come from ``_get_start_types_in_context``. Predicates are registered only for
        the column types present in the table (``num2`` columns are treated like ``number``
        columns), question entities and numbers become constants, and every column name is added
        as a constant once per type in its column hierarchy. Finally ``terminal_productions`` is
        built from the registered functions.

        Args:
            table_context: Context holding the table data, its column types and names, the table
                knowledge graph and the entities/numbers extracted from the question.
        """
        super().__init__(start_types=self._get_start_types_in_context(table_context))
        self.table_context = table_context
        self.table_data = [Row(row) for row in table_context.table_data]

        present_types = table_context.column_types
        has_strings = "string" in present_types
        has_dates = "date" in present_types
        has_numbers = "number" in present_types or "num2" in present_types
        self._table_has_string_columns = has_strings
        self._table_has_date_columns = has_dates
        self._table_has_number_columns = has_numbers

        def register(*predicate_names: str) -> None:
            # Each predicate is registered under the name of the method implementing it.
            for predicate_name in predicate_names:
                self.add_predicate(predicate_name, getattr(self, predicate_name))

        if has_strings:
            register('filter_in', 'filter_not_in')
        if has_dates:
            register('filter_date_greater', 'filter_date_greater_equals',
                     'filter_date_lesser', 'filter_date_lesser_equals',
                     'filter_date_equals', 'filter_date_not_equals',
                     'max_date', 'min_date')
            # -1 stands in for unspecified fields of partially specified dates. It is only added
            # when the table has a date column, mirroring the knowledge graph, in which -1 is an
            # entity (with the date columns as its neighbors) only if date columns exist.
            self.add_constant('-1', -1, type_=Number)
        if has_numbers:
            register('filter_number_greater', 'filter_number_greater_equals',
                     'filter_number_lesser', 'filter_number_lesser_equals',
                     'filter_number_equals', 'filter_number_not_equals',
                     'max_number', 'min_number', 'average', 'sum', 'diff')
        if has_dates or has_numbers:
            register('argmax', 'argmin')

        self.table_graph = table_context.get_table_knowledge_graph()

        # Entities and numbers mentioned in the question become constants of the language.
        question_entities, question_numbers = table_context.get_entities_from_question()
        self._question_entities = [entity_name for entity_name, _ in question_entities]
        self._question_numbers = [value for value, _ in question_numbers]
        # Entities are typed as List[str] so that the language treats constants and the outputs
        # of select-like statements the same way.
        for entity_name in self._question_entities:
            self.add_constant(entity_name, entity_name, type_=List[str])
        for value in self._question_numbers:
            self.add_constant(str(value), float(value), type_=Number)

        # Column-name productions, kept so that they can be added to the agenda.
        self._column_productions_for_agenda: Dict[str, str] = {}

        column_classes = {'string': StringColumn, 'date': DateColumn,
                          'number': NumberColumn, 'num2': NumberColumn}
        for column_name in table_context.column_names:
            # Column names are expected to look like "<type>_column:<name>"; keep "<type>".
            type_label = column_name.partition(":")[0].replace("_column", "")
            column_class = column_classes.get(type_label)
            column: Column = column_class(column_name) if column_class is not None else None
            if type_label in ('date', 'number', 'num2'):
                self.add_constant(column_name, column, type_=ComparableColumn)
            self.add_constant(column_name, column, type_=Column)
            self.add_constant(column_name, column)
            production_lhs = str(PredicateType.get_type(type(column)))
            self._column_productions_for_agenda[column_name] = f"{production_lhs} -> {column_name}"

        # Terminal string -> production producing it. Used by the agenda-related methods; some
        # models also read it to know how many terminals to plan for.
        self.terminal_productions: Dict[str, str] = {
            function_name: f"{signature_types[0]!s} -> {function_name}"
            for function_name, signature_types in self._function_types.items()
        }
Ejemplo n.º 5
0
    def __init__(self, table_context: TableQuestionContext) -> None:
        super().__init__(start_types={Number, Date, List[str]})
        self.table_context = table_context
        self.table_data = [Row(row) for row in table_context.table_data]

        # if the last colum is total, remove it
        _name = f"string_column:{table_context.column_index_to_name[0]}"
        if _name in table_context.table_data[-1] and  "total" in table_context.table_data[-1][_name]:
            self.table_data.pop()

        column_types = table_context.column_types
        if "string" in column_types:
            self.add_predicate('filter_in', self.filter_in)
            self.add_predicate('filter_not_in', self.filter_not_in)
        if "date" in column_types:
            self.add_predicate('filter_date_greater', self.filter_date_greater)
            self.add_predicate('filter_date_greater_equals', self.filter_date_greater_equals)
            self.add_predicate('filter_date_lesser', self.filter_date_lesser)
            self.add_predicate('filter_date_lesser_equals', self.filter_date_lesser_equals)
            self.add_predicate('filter_date_equals', self.filter_date_equals)
            self.add_predicate('filter_date_not_equals', self.filter_date_not_equals)
        if "number" in column_types or "num2" in column_types:
            self.add_predicate('filter_number_greater', self.filter_number_greater)
            self.add_predicate('filter_number_greater_equals', self.filter_number_greater_equals)
            self.add_predicate('filter_number_lesser', self.filter_number_lesser)
            self.add_predicate('filter_number_lesser_equals', self.filter_number_lesser_equals)
            self.add_predicate('filter_number_equals', self.filter_number_equals)
            self.add_predicate('filter_number_not_equals', self.filter_number_not_equals)
            self.add_predicate('max', self.max)
            self.add_predicate('min', self.min)
            self.add_predicate('average', self.average)
            self.add_predicate('sum', self.sum)
            self.add_predicate('diff', self.diff)
        if "date" in column_types or "number" in column_types or "num2" in column_types:
            self.add_predicate('argmax', self.argmax)
            self.add_predicate('argmin', self.argmin)

        # Adding entities and numbers seen in questions as constants.
        for entity in table_context._entity2id:
            self.add_constant(entity, entity)
        for number in table_context._num2id:
            self.add_constant(str(number), float(number), type_=Number)
        for date_str in table_context._date2id:
            date_obj = Date.make_date(date_str)
            self.add_constant(date_str, date_obj, type_=Date)
        
        self.table_graph = table_context.get_table_knowledge_graph()

        # Adding column names as constants.  Each column gets added once for every
        # type in the hierarchy going from its concrete class to the base Column.  String columns
        # get added as StringColumn and Column, and date and number columns get added as DateColumn
        # (or NumberColumn), ComparableColumn, and Column.
        for column_name, column_types in table_context.column2types.items():
            for column_type in column_types:
                typed_column_name = f"{column_type}_column:{column_name}"
                column: Column = None
                if column_type == 'string':
                    column = StringColumn(typed_column_name)
                elif column_type == 'date':
                    column = DateColumn(typed_column_name)
                    self.add_constant(typed_column_name, column, type_=ComparableColumn)
                elif column_type in ['number', 'num2']:
                    column = NumberColumn(typed_column_name)
                    self.add_constant(typed_column_name, column, type_=ComparableColumn)
                self.add_constant(typed_column_name, column, type_=Column)
                self.add_constant(typed_column_name, column)
                column_type_name = str(PredicateType.get_type(type(column)))