def __init__(self, table_context: TableQuestionContext) -> None:
    """
    Build the executable language for one table/question pair.

    Predicates are registered only for the column types that occur in ``table_context``.
    Afterwards the question's entities and numbers, and every table column, are added as
    constants.  ``terminal_productions`` is filled last; it maps each registered name to
    ``"<first registered type> -> <name>"``.
    """
    super().__init__(start_types={Number, Date, List[str]})
    self.table_context = table_context
    self.table_data = [Row(row) for row in table_context.table_data]
    # Stored for use outside this constructor; their semantics are not visible here.
    self.exe_restricted = True
    self.restrict_num = 1

    present_types = table_context.column_types.values()
    has_strings = "string" in present_types
    has_dates = "date" in present_types
    has_numbers = "number" in present_types

    def register(*predicate_names: str) -> None:
        # Register each predicate under its own name, bound to the method of that name.
        for predicate_name in predicate_names:
            self.add_predicate(predicate_name, getattr(self, predicate_name))

    if has_strings:
        register('filter_in', 'filter_not_in')
    if has_dates:
        register('filter_date_greater', 'filter_date_greater_equals',
                 'filter_date_lesser', 'filter_date_lesser_equals',
                 'filter_date_equals', 'filter_date_not_equals')
        # -1 is needed for dates where not all three fields are specified, so it is only
        # added when the table has a date column.  The knowledge graph follows the same rule:
        # -1 is an entity (with the date columns as neighbors) only if date columns exist.
        self.add_constant('-1', -1, type_=Number)
    if has_numbers:
        register('filter_number_greater', 'filter_number_greater_equals',
                 'filter_number_lesser', 'filter_number_lesser_equals',
                 'filter_number_equals', 'filter_number_not_equals',
                 'max', 'min', 'average', 'sum', 'diff')
    if has_dates or has_numbers:
        # argmax/argmin are registered whenever a date or a number column exists.
        register('argmax', 'argmin')

    self.table_graph = table_context.get_table_knowledge_graph()

    # Entities and numbers mentioned in the question become constants.
    question_entities, question_numbers = table_context.get_entities_from_question()
    self._question_entities = [text for text, _, _, _ in question_entities]
    self._question_numbers = [value for value, _ in question_numbers]
    for entity in self._question_entities:
        self.add_constant(entity, entity)
    for value in self._question_numbers:
        self.add_constant(str(value), float(value), type_=Number)

    # Column productions are remembered so they can be added to the agenda.
    self._column_productions_for_agenda: Dict[str, str] = {}
    # Each column is registered once per type on the path from its concrete class up to
    # Column: string columns as StringColumn and Column; date and number columns as
    # DateColumn / NumberColumn, ComparableColumn and Column.
    comparable_classes = {'date': DateColumn, 'number': NumberColumn}
    for bare_name, column_type in table_context.column_types.items():
        typed_name = f"{column_type}_column:{bare_name}"
        column: Column = None
        if column_type == 'string':
            column = StringColumn(typed_name)
        elif column_type in comparable_classes:
            column = comparable_classes[column_type](typed_name)
            self.add_constant(typed_name, column, type_=ComparableColumn)
        self.add_constant(typed_name, column, type_=Column)
        self.add_constant(typed_name, column)
        lhs = str(PredicateType.get_type(type(column)))
        self._column_productions_for_agenda[typed_name] = f"{lhs} -> {typed_name}"

    # Terminal string -> production that yields it.  Used by the agenda-related methods;
    # some models also read it to decide how many terminals to plan for.
    self.terminal_productions: Dict[str, str] = {
        name: f"{signature_types[0]} -> {name}"
        for name, signature_types in self._function_types.items()
    }
def __init__(self, table_context: TableQuestionContext) -> None:
    """
    Build the NLTK-backed variable-free world for one table/question pair.

    Global predicates and signatures are passed to the parent constructor.  Local name
    mappings are then added for:

    * the predicate groups matching the column types present in ``table_context``;
    * the question's entities and numbers;
    * every table column.

    ``terminal_productions`` maps each mapped name that has a type signature to
    ``"<signature> -> <name>"``.
    """
    super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                             "num": types.NUMBER_TYPE},
                     global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                     global_name_mapping=types.COMMON_NAME_MAPPING)
    self.table_context = table_context

    def add_mappings(name_mapping, type_signatures) -> None:
        # Add each (name -> translated name) pair with the signature keyed by the translated name.
        for name, translated_name in name_mapping.items():
            self._add_name_mapping(name, translated_name, type_signatures[translated_name])

    present_types = table_context.column_types.values()
    has_strings = "string" in present_types
    has_dates = "date" in present_types
    has_numbers = "number" in present_types

    self._table_has_string_columns = False
    self._table_has_date_columns = False
    self._table_has_number_columns = False
    if has_strings:
        add_mappings(types.STRING_COLUMN_NAME_MAPPING, types.STRING_COLUMN_TYPE_SIGNATURE)
        self._table_has_string_columns = True
    if has_dates:
        add_mappings(types.DATE_COLUMN_NAME_MAPPING, types.DATE_COLUMN_TYPE_SIGNATURE)
        # -1 is needed for dates where not all three fields are specified, so it is only
        # mapped when the table has a date column.  The knowledge graph follows the same rule:
        # -1 is an entity (with the date columns as neighbors) only if date columns exist.
        self._map_name("num:-1", keep_mapping=True)
        self._table_has_date_columns = True
    if has_numbers:
        add_mappings(types.NUMBER_COLUMN_NAME_MAPPING, types.NUMBER_COLUMN_TYPE_SIGNATURE)
        self._table_has_number_columns = True
    if has_dates or has_numbers:
        add_mappings(types.COMPARABLE_COLUMN_NAME_MAPPING,
                     types.COMPARABLE_COLUMN_TYPE_SIGNATURE)

    self.table_graph = table_context.get_table_knowledge_graph()
    self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)

    # TODO (pradeep): Use a NameMapper for mapping entity names too.
    # Incremented for every new column name so each one maps to a fresh NLTK name.
    self._column_counter = 0

    # Map the entities and numbers seen in the question.
    question_entities, question_numbers = table_context.get_entities_from_question()
    self._question_entities = [text for text, _ in question_entities]
    self._question_numbers = [value for value, _ in question_numbers]
    for entity in self._question_entities:
        # These entities all have the prefix "string:".
        self._map_name(entity, keep_mapping=True)
    for value in self._question_numbers:
        self._map_name(f"num:{value}", keep_mapping=True)

    # Column productions are tracked so they can be added to the agenda; this must exist
    # before the column names below are mapped.
    self._column_productions_for_agenda: Dict[str, str] = {}
    for bare_name, column_type in table_context.column_types.items():
        self._map_name(f"{column_type}_column:{bare_name}", keep_mapping=True)

    all_mappings = (list(self.global_name_mapping.items())
                    + list(self.local_name_mapping.items()))
    # Local signatures take precedence over global ones with the same key.
    signatures = {**self.global_type_signatures, **self.local_type_signatures}
    self.terminal_productions: Dict[str, str] = {
        predicate: f"{signatures[mapped_name]} -> {predicate}"
        for predicate, mapped_name in all_mappings
        if mapped_name in signatures
    }

    # Cache slot for the valid actions: they never change, so they only need computing once.
    self._valid_actions: Dict[str, List[str]] = None
def __init__(self, table_context: TableQuestionContext) -> None:
    """
    Build the NLTK-backed variable-free world for one table/question pair.

    Global predicates and signatures are passed to the parent constructor.  Local name
    mappings are then added for:

    * the predicate groups matching the column types present in ``table_context``;
    * the question's entities and numbers;
    * every table column.

    ``ent2id`` / ``num2id`` record the positions paired with question entities and numbers.
    ``terminal_productions`` maps each mapped name that has a type signature to
    ``"<signature> -> <name>"``.
    """
    super().__init__(constant_type_prefixes={"string": types.STRING_TYPE,
                                             "num": types.NUMBER_TYPE},
                     global_type_signatures=types.COMMON_TYPE_SIGNATURE,
                     global_name_mapping=types.COMMON_NAME_MAPPING)
    self.table_context = table_context

    def add_mappings(name_mapping, type_signatures) -> None:
        # Add each (name -> translated name) pair with the signature keyed by the translated name.
        for name, translated_name in name_mapping.items():
            self._add_name_mapping(name, translated_name, type_signatures[translated_name])

    present_types = table_context.column_types.values()
    has_strings = "string" in present_types
    has_dates = "date" in present_types
    has_numbers = "number" in present_types

    self._table_has_string_columns = False
    self._table_has_date_columns = False
    self._table_has_number_columns = False
    if has_strings:
        add_mappings(types.STRING_COLUMN_NAME_MAPPING, types.STRING_COLUMN_TYPE_SIGNATURE)
        self._table_has_string_columns = True
    if has_dates:
        add_mappings(types.DATE_COLUMN_NAME_MAPPING, types.DATE_COLUMN_TYPE_SIGNATURE)
        # -1 is needed for dates where not all three fields are specified, so it is only
        # mapped when the table has a date column.  The knowledge graph follows the same rule:
        # -1 is an entity (with the date columns as neighbors) only if date columns exist.
        self._map_name("num:-1", keep_mapping=True)
        self._table_has_date_columns = True
    if has_numbers:
        add_mappings(types.NUMBER_COLUMN_NAME_MAPPING, types.NUMBER_COLUMN_TYPE_SIGNATURE)
        self._table_has_number_columns = True
    if has_dates or has_numbers:
        add_mappings(types.COMPARABLE_COLUMN_NAME_MAPPING,
                     types.COMPARABLE_COLUMN_TYPE_SIGNATURE)

    self.table_graph = table_context.get_table_knowledge_graph()
    self._executor = WikiTablesVariableFreeExecutor(self.table_context.table_data)

    # TODO (pradeep): Use a NameMapper for mapping entity names too.
    # Incremented for every new column name so each one maps to a fresh NLTK name.
    self._column_counter = 0

    # Record and map the entities and numbers seen in the question.
    question_entities, question_numbers = table_context.get_entities_from_question()
    self._question_entities = [text for text, _, _, _ in question_entities]
    self._question_numbers = [value for value, _ in question_numbers]
    # (start, end) pair reported for each question entity, keyed by entity name.
    self.ent2id = {entity: (start, end) for entity, start, end, _ in question_entities}
    # Id paired with each question number; numbers equal to -1 are skipped.
    # NOTE(review): if the numbers arrive as strings, ``!= -1`` never filters anything -- confirm.
    self.num2id = {value: index for value, index in question_numbers if value != -1}
    for entity in self._question_entities:
        # These entities all have the prefix "string:".
        self._map_name(entity, keep_mapping=True)
    for value in self._question_numbers:
        self._map_name(f"num:{value}", keep_mapping=True)

    # Column productions are tracked so they can be added to the agenda; this must exist
    # before the column names below are mapped.
    self._column_productions_for_agenda: Dict[str, str] = {}
    for bare_name, column_type in table_context.column_types.items():
        self._map_name(f"{column_type}_column:{bare_name}", keep_mapping=True)

    all_mappings = (list(self.global_name_mapping.items())
                    + list(self.local_name_mapping.items()))
    # Local signatures take precedence over global ones with the same key.
    signatures = {**self.global_type_signatures, **self.local_type_signatures}
    self.terminal_productions: Dict[str, str] = {
        predicate: f"{signatures[mapped_name]} -> {predicate}"
        for predicate, mapped_name in all_mappings
        if mapped_name in signatures
    }

    # Cache slot for the valid actions: they never change, so they only need computing once.
    self._valid_actions: Dict[str, List[str]] = None
def __init__(self, table_context: TableQuestionContext) -> None:
    """
    Build the executable language for one table/question pair.

    Start types come from ``self._get_start_types_in_context``.  Predicates are registered
    only for the column types present in ``table_context``, with ``"num2"`` treated like
    ``"number"``.  The ``_table_has_*_columns`` flags record which groups were added.
    Question entities (typed ``List[str]``), question numbers and all table columns are then
    added as constants, and ``terminal_productions`` is filled last.
    """
    super().__init__(start_types=self._get_start_types_in_context(table_context))
    self.table_context = table_context
    self.table_data = [Row(row) for row in table_context.table_data]

    column_types = table_context.column_types
    has_strings = "string" in column_types
    has_dates = "date" in column_types
    has_numbers = "number" in column_types or "num2" in column_types

    self._table_has_string_columns = False
    self._table_has_date_columns = False
    self._table_has_number_columns = False

    def register(*predicate_names: str) -> None:
        # Register each predicate under its own name, bound to the method of that name.
        for predicate_name in predicate_names:
            self.add_predicate(predicate_name, getattr(self, predicate_name))

    if has_strings:
        register('filter_in', 'filter_not_in')
        self._table_has_string_columns = True
    if has_dates:
        register('filter_date_greater', 'filter_date_greater_equals',
                 'filter_date_lesser', 'filter_date_lesser_equals',
                 'filter_date_equals', 'filter_date_not_equals',
                 'max_date', 'min_date')
        # -1 is needed for dates where not all three fields are specified, so it is only
        # added when the table has a date column.  The knowledge graph follows the same rule:
        # -1 is an entity (with the date columns as neighbors) only if date columns exist.
        self.add_constant('-1', -1, type_=Number)
        self._table_has_date_columns = True
    if has_numbers:
        register('filter_number_greater', 'filter_number_greater_equals',
                 'filter_number_lesser', 'filter_number_lesser_equals',
                 'filter_number_equals', 'filter_number_not_equals',
                 'max_number', 'min_number', 'average', 'sum', 'diff')
        self._table_has_number_columns = True
    if has_dates or has_numbers:
        # argmax/argmin are registered whenever a date or a numeric column exists.
        register('argmax', 'argmin')

    self.table_graph = table_context.get_table_knowledge_graph()

    # Entities and numbers mentioned in the question become constants.
    question_entities, question_numbers = table_context.get_entities_from_question()
    self._question_entities = [text for text, _ in question_entities]
    self._question_numbers = [value for value, _ in question_numbers]
    for entity in self._question_entities:
        # Entities are typed List[str] so the language treats them the same way as the
        # outputs of select-like statements.
        self.add_constant(entity, entity, type_=List[str])
    for value in self._question_numbers:
        self.add_constant(str(value), float(value), type_=Number)

    # Column productions are remembered so they can be added to the agenda.
    self._column_productions_for_agenda: Dict[str, str] = {}
    comparable_classes = {'date': DateColumn, 'number': NumberColumn, 'num2': NumberColumn}
    for column_name in table_context.column_names:
        # Names are expected to look like "<type>_column:<name>"; recover <type> from the prefix.
        column_type = column_name.partition(":")[0].replace("_column", "")
        column: Column = None
        if column_type == 'string':
            column = StringColumn(column_name)
        elif column_type in comparable_classes:
            column = comparable_classes[column_type](column_name)
            self.add_constant(column_name, column, type_=ComparableColumn)
        self.add_constant(column_name, column, type_=Column)
        self.add_constant(column_name, column)
        lhs = str(PredicateType.get_type(type(column)))
        self._column_productions_for_agenda[column_name] = f"{lhs} -> {column_name}"

    # Terminal string -> production that yields it.  Used by the agenda-related methods;
    # some models also read it to decide how many terminals to plan for.
    self.terminal_productions: Dict[str, str] = {
        name: f"{signature_types[0]} -> {name}"
        for name, signature_types in self._function_types.items()
    }
def __init__(self, table_context: TableQuestionContext) -> None: super().__init__(start_types={Number, Date, List[str]}) self.table_context = table_context self.table_data = [Row(row) for row in table_context.table_data] # if the last colum is total, remove it _name = f"string_column:{table_context.column_index_to_name[0]}" if _name in table_context.table_data[-1] and "total" in table_context.table_data[-1][_name]: self.table_data.pop() column_types = table_context.column_types if "string" in column_types: self.add_predicate('filter_in', self.filter_in) self.add_predicate('filter_not_in', self.filter_not_in) if "date" in column_types: self.add_predicate('filter_date_greater', self.filter_date_greater) self.add_predicate('filter_date_greater_equals', self.filter_date_greater_equals) self.add_predicate('filter_date_lesser', self.filter_date_lesser) self.add_predicate('filter_date_lesser_equals', self.filter_date_lesser_equals) self.add_predicate('filter_date_equals', self.filter_date_equals) self.add_predicate('filter_date_not_equals', self.filter_date_not_equals) if "number" in column_types or "num2" in column_types: self.add_predicate('filter_number_greater', self.filter_number_greater) self.add_predicate('filter_number_greater_equals', self.filter_number_greater_equals) self.add_predicate('filter_number_lesser', self.filter_number_lesser) self.add_predicate('filter_number_lesser_equals', self.filter_number_lesser_equals) self.add_predicate('filter_number_equals', self.filter_number_equals) self.add_predicate('filter_number_not_equals', self.filter_number_not_equals) self.add_predicate('max', self.max) self.add_predicate('min', self.min) self.add_predicate('average', self.average) self.add_predicate('sum', self.sum) self.add_predicate('diff', self.diff) if "date" in column_types or "number" in column_types or "num2" in column_types: self.add_predicate('argmax', self.argmax) self.add_predicate('argmin', self.argmin) # Adding entities and numbers seen in questions as 
constants. for entity in table_context._entity2id: self.add_constant(entity, entity) for number in table_context._num2id: self.add_constant(str(number), float(number), type_=Number) for date_str in table_context._date2id: date_obj = Date.make_date(date_str) self.add_constant(date_str, date_obj, type_=Date) self.table_graph = table_context.get_table_knowledge_graph() # Adding column names as constants. Each column gets added once for every # type in the hierarchy going from its concrete class to the base Column. String columns # get added as StringColumn and Column, and date and number columns get added as DateColumn # (or NumberColumn), ComparableColumn, and Column. for column_name, column_types in table_context.column2types.items(): for column_type in column_types: typed_column_name = f"{column_type}_column:{column_name}" column: Column = None if column_type == 'string': column = StringColumn(typed_column_name) elif column_type == 'date': column = DateColumn(typed_column_name) self.add_constant(typed_column_name, column, type_=ComparableColumn) elif column_type in ['number', 'num2']: column = NumberColumn(typed_column_name) self.add_constant(typed_column_name, column, type_=ComparableColumn) self.add_constant(typed_column_name, column, type_=Column) self.add_constant(typed_column_name, column) column_type_name = str(PredicateType.get_type(type(column)))