def shuffle_schema(
    schema: SQLSchema, *, rng: Optional[random.Random] = None
) -> SQLSchema:
    """Return a copy of ``schema`` with the order of columns and tables shuffled.

    The three column mappings (``column_names``, ``tokenized_column_names``,
    ``original_column_names``) are reordered with one shared permutation.
    The three table mappings (``table_names``, ``tokenized_table_names``,
    ``original_table_names``) are reordered with another shared permutation.
    This keeps each group aligned. In addition, the column tuple of every
    entry in ``table_to_columns`` is shuffled independently. All other fields
    are carried over unchanged via ``replace``.

    Args:
        schema: The schema to shuffle. Its three column mappings must have
            the same keys in the same order, and likewise its three table
            mappings.
        rng: Optional random number generator to draw permutations from,
            e.g. a seeded ``random.Random`` for reproducible shuffles. When
            omitted, the module-level ``random.shuffle``/``random.sample``
            functions are used.

    Returns:
        The shuffled copy of ``schema``.

    Raises:
        AssertionError: If the key order of aligned mappings disagrees
            (only while assertions are enabled).
    """
    shuffle = random.shuffle if rng is None else rng.shuffle
    sample = random.sample if rng is None else rng.sample

    column_ids = list(schema.column_names.keys())
    assert (
        column_ids
        == list(schema.tokenized_column_names.keys())
        == list(schema.original_column_names.keys())
    )
    table_ids = list(schema.table_names.keys())
    assert (
        table_ids
        == list(schema.tokenized_table_names.keys())
        == list(schema.original_table_names.keys())
    )

    # Randomness is drawn in a fixed order: the global column order, then
    # the global table order, then each table's column tuple (in
    # table_to_columns iteration order).
    shuffle(column_ids)
    shuffle(table_ids)
    shuffled_table_to_columns = frozendict(
        {
            table_id: tuple(sample(table_columns, len(table_columns)))
            for table_id, table_columns in schema.table_to_columns.items()
        }
    )

    def reorder(mapping, ids):
        # Rebuild ``mapping`` with its entries in the order given by ``ids``.
        return OrderedFrozenDict([(key, mapping[key]) for key in ids])

    return replace(
        schema,
        column_names=reorder(schema.column_names, column_ids),
        tokenized_column_names=reorder(schema.tokenized_column_names, column_ids),
        original_column_names=reorder(schema.original_column_names, column_ids),
        table_names=reorder(schema.table_names, table_ids),
        tokenized_table_names=reorder(schema.tokenized_table_names, table_ids),
        original_table_names=reorder(schema.original_table_names, table_ids),
        table_to_columns=shuffled_table_to_columns,
    )
def freeze_order(xxs: Iterable[Iterable[T]]) -> FrozenDict[T, int]:
    """Map each element of the concatenated inner iterables to its position.

    Positions are 0-based and count across the concatenation of all inner
    iterables of ``xxs``. If an element occurs more than once, it keeps its
    first insertion slot in the mapping. Its value is the position of its
    last occurrence, because later assignments overwrite earlier ones.
    """
    flattened = itertools.chain.from_iterable(xxs)
    positions = {}
    for position, element in enumerate(flattened):
        positions[element] = position
    return frozendict(positions)
def preprocess_schema_uncached(
    schema: SpiderSchema,
    db_path: Optional[str],
    tokenize: Callable[[Optional[str], List[str], str], List[str]],
) -> SQLSchema:
    """Convert a ``SpiderSchema`` into an immutable ``SQLSchema``.

    Columns and tables are keyed by the string form of their index in
    ``schema.columns`` / ``schema.tables``.

    Args:
        schema: The schema to convert.
        db_path: Stored verbatim as ``db_path`` on the result.
        tokenize: Called as ``tokenize(column.type, column.name,
            column.unsplit_name)`` for every column and as
            ``tokenize(None, table.name, table.unsplit_name)`` for every
            table. Its result (converted to a tuple) becomes the tokenized
            name.

    Returns:
        A ``SQLSchema`` with the following fields:

        * ``column_names`` / ``table_names`` hold the ``unsplit_name`` values.
        * ``original_*_names`` hold the ``orig_name`` values.
        * ``column_to_table`` maps every column to its table id, or ``None``
          when the column has no table.
        * ``table_to_columns`` contains only tables with at least one column.
        * ``foreign_keys`` maps referencing column ids to referenced column ids.
        * ``foreign_keys_tables`` maps a table id to the ids of the tables its
          foreign keys reference.
        * ``primary_keys`` lists each table's primary-key column ids, in table
          order.
    """
    column_names = []
    tokenized_column_names = []
    original_column_names = []
    column_to_table = {}
    table_to_columns = {}
    foreign_keys = {}
    foreign_keys_tables = defaultdict(set)
    for i, column in enumerate(schema.columns):
        tokenized_name = tokenize(column.type, column.name, column.unsplit_name)
        column_names.append(column.unsplit_name)
        tokenized_column_names.append(tokenized_name)
        original_column_names.append(column.orig_name)
        table_id = None if column.table is None else column.table.id
        column_to_table[str(i)] = table_id
        if table_id is not None:
            table_to_columns.setdefault(str(table_id), []).append(i)
        if column.foreign_key_for is not None:
            # NOTE(review): keyed by column.id here but by the enumerate
            # index elsewhere; presumably these coincide — confirm.
            foreign_keys[str(column.id)] = column.foreign_key_for.id
            foreign_keys_tables[str(column.table.id)].add(
                column.foreign_key_for.table.id)

    table_names = []
    tokenized_table_names = []
    original_table_names = []
    for table in schema.tables:
        table_names.append(table.unsplit_name)
        tokenized_table_names.append(
            tokenize(None, table.name, table.unsplit_name))
        original_table_names.append(table.orig_name)

    # Presumably converts the id sets into sorted sequences so that the
    # output is deterministic — see serialization.to_dict_with_sorted_values.
    foreign_keys_tables = serialization.to_dict_with_sorted_values(
        foreign_keys_tables)
    primary_keys = [
        column.id for table in schema.tables for column in table.primary_keys
    ]

    return SQLSchema(
        column_names=OrderedFrozenDict([
            (ColumnId(str(idx)), name) for idx, name in enumerate(column_names)
        ]),
        tokenized_column_names=OrderedFrozenDict([
            (ColumnId(str(idx)), tuple(tokenized_name))
            for idx, tokenized_name in enumerate(tokenized_column_names)
        ]),
        original_column_names=OrderedFrozenDict([
            (ColumnId(str(idx)), original_name)
            for idx, original_name in enumerate(original_column_names)
        ]),
        # Table mappings are keyed by table ids, so wrap them in TableId
        # (consistent with column_to_table / table_to_columns below).
        table_names=OrderedFrozenDict([
            (TableId(str(idx)), name) for idx, name in enumerate(table_names)
        ]),
        tokenized_table_names=OrderedFrozenDict([
            (TableId(str(idx)), tuple(tokenized_name))
            for idx, tokenized_name in enumerate(tokenized_table_names)
        ]),
        original_table_names=OrderedFrozenDict([
            (TableId(str(idx)), original_name)
            for idx, original_name in enumerate(original_table_names)
        ]),
        column_to_table=frozendict({
            ColumnId(column): TableId(str(table)) if table is not None else None
            for column, table in column_to_table.items()
        }),
        table_to_columns=frozendict({
            TableId(table): tuple(ColumnId(str(column)) for column in columns)
            for table, columns in table_to_columns.items()
        }),
        foreign_keys=frozendict({
            ColumnId(this): ColumnId(str(other))
            for this, other in foreign_keys.items()
        }),
        foreign_keys_tables=frozendict({
            TableId(this): tuple(TableId(str(other)) for other in others)
            for this, others in foreign_keys_tables.items()
        }),
        primary_keys=tuple(ColumnId(str(column)) for column in primary_keys),
        db_id=schema.db_id,
        db_path=db_path,
    )