def test_create_table_class(self):
    """
    Make sure a basic `Table` can be created successfully.
    """
    # The tablename defaults to the snake_case form of the class name.
    table_class = create_table_class(class_name="MyTable")
    self.assertEqual(table_class._meta.tablename, "my_table")

    # An explicit tablename overrides the derived one.
    table_class = create_table_class(
        class_name="MyTable",
        class_kwargs={"tablename": "my_table_1"},
    )
    self.assertEqual(table_class._meta.tablename, "my_table_1")

    # Columns passed in as class members get registered on the Table.
    name_column = Varchar()
    table_class = create_table_class(
        class_name="MyTable",
        class_members={"name": name_column},
    )
    self.assertIn(name_column, table_class._meta.columns)
async def _run_add_columns(self, backwards=False):
    """
    Add columns, which belong to existing tables.

    :param backwards:
        If ``True``, the operation is reversed - i.e. the previously added
        columns are dropped instead.
    """
    if backwards:
        for add_column in self.add_columns.add_columns:
            if add_column.table_class_name in [
                i.class_name for i in self.add_tables
            ]:
                # Don't reverse the add column as the table is going to
                # be deleted.
                continue

            # A minimal stand-in Table class - only the tablename is
            # needed to issue the DROP COLUMN statement.
            _Table: t.Type[Table] = create_table_class(
                class_name=add_column.table_class_name,
                class_kwargs={"tablename": add_column.tablename},
            )
            await _Table.alter().drop_column(add_column.column).run()
    else:
        for table_class_name in self.add_columns.table_class_names:
            if table_class_name in [i.class_name for i in self.add_tables]:
                continue  # No need to add columns to new tables

            add_columns: t.List[
                AddColumnClass
            ] = self.add_columns.for_table_class_name(table_class_name)

            # Define the table, with the columns, so the metaclass
            # sets up the columns correctly.
            _Table: t.Type[Table] = create_table_class(
                class_name=add_columns[0].table_class_name,
                class_kwargs={"tablename": add_columns[0].tablename},
                class_members={
                    add_column.column._meta.name: add_column.column
                    for add_column in add_columns
                },
            )

            for add_column in add_columns:
                # We fetch the column from the Table, as the metaclass
                # copies and sets it up properly.
                column = _Table._meta.get_column_by_name(
                    add_column.column._meta.name
                )

                await _Table.alter().add_column(
                    name=column._meta.name, column=column
                ).run()
                # Create any index requested on the new column.
                if add_column.column._meta.index:
                    await _Table.create_index([add_column.column]).run()
async def _run_add_tables(self, backwards=False):
    """
    Create the tables added by this migration (or drop them when running
    backwards).
    """
    new_tables: t.List[t.Type[Table]] = []
    for add_table in self.add_tables:
        columns_for_table = self.add_columns.for_table_class_name(
            add_table.class_name
        )
        # Pass the columns in as class members, so the Table metaclass
        # wires them up.
        members = {
            item.column._meta.name: item.column
            for item in columns_for_table
        }
        new_tables.append(
            create_table_class(
                class_name=add_table.class_name,
                class_kwargs={"tablename": add_table.tablename},
                class_members=members,
            )
        )

    # Sort by foreign key, so they're created in the right order.
    ordered_tables = sort_table_classes(new_tables)

    if backwards:
        # Drop in reverse dependency order.
        for table_class in reversed(ordered_tables):
            await table_class.alter().drop_table(cascade=True).run()
    else:
        for table_class in ordered_tables:
            await table_class.create_table().run()
async def _run_drop_columns(self, backwards=False):
    """
    Drop columns from existing tables (or restore them when running
    backwards).
    """
    if backwards:
        for drop_column in self.drop_columns.drop_columns:
            # The dropped column's definition lives in an earlier
            # migration snapshot, so look it up there.
            snapshot_table = await self.get_table_from_snaphot(
                table_class_name=drop_column.table_class_name,
                app_name=self.app_name,
                offset=-1,
            )
            restored_column = snapshot_table._meta.get_column_by_name(
                drop_column.column_name
            )
            await snapshot_table.alter().add_column(
                name=drop_column.column_name, column=restored_column
            ).run()
    else:
        for table_class_name in self.drop_columns.table_class_names:
            to_drop = self.drop_columns.for_table_class_name(
                table_class_name
            )
            if not to_drop:
                continue

            table_class: t.Type[Table] = create_table_class(
                class_name=table_class_name,
                class_kwargs={"tablename": to_drop[0].tablename},
            )
            for item in to_drop:
                await table_class.alter().drop_column(
                    column=item.column_name
                ).run()
async def _run_rename_tables(self, backwards=False):
    """
    Rename tables. When running backwards, the old / new names swap
    roles.
    """
    for rename in self.rename_tables:
        if backwards:
            class_name = rename.new_class_name
            current_tablename = rename.new_tablename
            target_tablename = rename.old_tablename
        else:
            class_name = rename.old_class_name
            current_tablename = rename.old_tablename
            target_tablename = rename.new_tablename

        table_class: t.Type[Table] = create_table_class(
            class_name=class_name,
            class_kwargs={"tablename": current_tablename},
        )
        await table_class.alter().rename_table(
            new_name=target_tablename
        ).run()
def to_table_class(self) -> t.Type[Table]:
    """
    Converts the DiffableTable into a Table subclass.
    """
    members = {column._meta.name: column for column in self.columns}
    return create_table_class(
        class_name=self.class_name,
        class_kwargs={"tablename": self.tablename},
        class_members=members,
    )
def deserialise_legacy_params(name: str, value: str) -> t.Any:
    """
    Earlier versions of Piccolo serialised parameters differently. This is
    here purely for backwards compatibility.

    :param name:
        The name of the serialised parameter (e.g. ``'references'``).
    :param value:
        Its serialised string form.
    :returns:
        The deserialised value, or ``value`` unchanged if it isn't
        recognised as a legacy format.
    """
    if name == "references":
        parts = value.split("|")
        if len(parts) == 1:
            class_name, tablename = parts[0], None
        elif len(parts) == 2:
            class_name, tablename = parts
        else:
            raise ValueError(
                "Unrecognised Table serialisation - should either be "
                "`SomeClassName` or `SomeClassName|some_table_name`.")

        kwargs = {"tablename": tablename} if tablename else {}
        return create_table_class(
            class_name=class_name, class_kwargs=kwargs
        )

    ###########################################################################

    if name == "on_delete":
        enum_name, item_name = value.split(".")
        if enum_name == "OnDelete":
            return getattr(OnDelete, item_name)
    ###########################################################################
    elif name == "on_update":
        enum_name, item_name = value.split(".")
        if enum_name == "OnUpdate":
            return getattr(OnUpdate, item_name)

    ###########################################################################

    if name == "default":
        if value in {"TimestampDefault.now", "DatetimeDefault.now"}:
            return TimestampNow()

        try:
            parsed = datetime.datetime.fromisoformat(value)
        except ValueError:
            # Not an ISO datetime - fall through and return the raw value.
            pass
        else:
            return parsed

    return value
def test_protected_tablenames(self):
    """
    Make sure that the logic around protected tablenames still works as
    expected.
    """
    # A class name which maps to a protected tablename is rejected.
    with self.assertRaises(ValueError):
        create_table_class(class_name="User")

    # An explicitly protected tablename is also rejected.
    with self.assertRaises(ValueError):
        create_table_class(
            class_name="MyUser", class_kwargs={"tablename": "user"}
        )

    # This shouldn't raise an error - the explicit tablename is safe,
    # even though the class name alone would be protected.
    create_table_class(
        class_name="User", class_kwargs={"tablename": "my_user"}
    )
async def _run_rename_columns(self, backwards=False):
    """
    Rename columns. When running backwards, the old / new names swap
    roles.
    """
    for table_class_name in self.rename_columns.table_class_names:
        renames = self.rename_columns.for_table_class_name(
            table_class_name
        )
        if not renames:
            continue

        table_class: t.Type[Table] = create_table_class(
            class_name=table_class_name,
            class_kwargs={"tablename": renames[0].tablename},
        )

        for rename_column in renames:
            if backwards:
                source_name = rename_column.new_db_column_name
                target_name = rename_column.old_db_column_name
            else:
                source_name = rename_column.old_db_column_name
                target_name = rename_column.new_db_column_name

            await table_class.alter().rename_column(
                column=source_name,
                new_name=target_name,
            ).run()
def table(self, column: Column):
    """
    Build a throwaway ``MyTable`` class containing ``column`` as the
    ``my_column`` member.
    """
    return create_table_class(
        class_name="MyTable",
        class_members={"my_column": column},
    )
def tearDown(self):
    """
    Remove the tables created during the test run.
    """
    my_table = create_table_class("MyTable")
    my_table.alter().drop_table(if_exists=True).run_sync()
    Migration.alter().drop_table(if_exists=True).run_sync()
async def _run_alter_columns(self, backwards=False):
    """
    Apply the stored column alterations (type changes, null / length /
    unique / index / default / digits changes) to existing tables.

    :param backwards:
        If ``True``, each alteration is reversed by swapping the old and
        new params / column classes.
    """
    for table_class_name in self.alter_columns.table_class_names:
        alter_columns = self.alter_columns.for_table_class_name(
            table_class_name
        )
        if not alter_columns:
            continue

        # A minimal stand-in Table class, used purely to issue ALTER
        # statements against the right tablename.
        _Table: t.Type[Table] = create_table_class(
            class_name=table_class_name,
            class_kwargs={"tablename": alter_columns[0].tablename},
        )

        for alter_column in alter_columns:
            # When running backwards, old and new params swap roles.
            params = (
                alter_column.old_params if backwards
                else alter_column.params
            )
            old_params = (
                alter_column.params if backwards
                else alter_column.old_params
            )

            ###############################################################
            # Change the column type if possible

            column_class = (
                alter_column.old_column_class if backwards
                else alter_column.column_class
            )
            old_column_class = (
                alter_column.column_class if backwards
                else alter_column.old_column_class
            )

            if (old_column_class is not None) and (column_class is not None):
                if old_column_class != column_class:
                    # Build throwaway column instances so the ALTER
                    # statement has fully set-up metadata to work with.
                    old_column = old_column_class(**old_params)
                    old_column._meta._table = _Table
                    old_column._meta._name = alter_column.column_name
                    old_column._meta.db_column_name = (
                        alter_column.db_column_name)

                    new_column = column_class(**params)
                    new_column._meta._table = _Table
                    new_column._meta._name = alter_column.column_name
                    new_column._meta.db_column_name = (
                        alter_column.db_column_name)

                    using_expression: t.Optional[str] = None

                    # Postgres won't automatically cast some types to
                    # others. We may as well try, as it will definitely
                    # fail otherwise.
                    if new_column.value_type != old_column.value_type:
                        if old_params.get("default", ...) is not None:
                            # Unless the column's default value is also
                            # something which can be cast to the new type,
                            # it will also fail. Drop the default value for
                            # now - the proper default is set later on.
                            await _Table.alter().drop_default(
                                old_column
                            ).run()

                        using_expression = "{}::{}".format(
                            alter_column.db_column_name,
                            new_column.column_type,
                        )

                    # We can't migrate a SERIAL to a BIGSERIAL or vice
                    # versa, as SERIAL isn't a true type, just an alias to
                    # other commands.
                    if issubclass(column_class, Serial) and issubclass(
                            old_column_class, Serial):
                        colored_warning(
                            "Unable to migrate Serial to BigSerial and "
                            "vice versa. This must be done manually.")
                    else:
                        await _Table.alter().set_column_type(
                            old_column=old_column,
                            new_column=new_column,
                            using_expression=using_expression,
                        ).run()

            ###############################################################
            # Apply the simpler param changes. Each one is only applied if
            # present in `params`.

            null = params.get("null")
            if null is not None:
                await _Table.alter().set_null(
                    column=alter_column.db_column_name,
                    boolean=null).run()

            length = params.get("length")
            if length is not None:
                await _Table.alter().set_length(
                    column=alter_column.db_column_name,
                    length=length).run()

            unique = params.get("unique")
            if unique is not None:
                # When modifying unique contraints, we need to pass in
                # a column type, and not just the column name.
                column = Column()
                column._meta._table = _Table
                column._meta._name = alter_column.column_name
                column._meta.db_column_name = alter_column.db_column_name
                await _Table.alter().set_unique(
                    column=column, boolean=unique).run()

            index = params.get("index")
            index_method = params.get("index_method")
            if index is None:
                if index_method is not None:
                    # If the index value hasn't changed, but the
                    # index_method value has, this indicates we need
                    # to change the index type.
                    column = Column()
                    column._meta._table = _Table
                    column._meta._name = alter_column.column_name
                    column._meta.db_column_name = (
                        alter_column.db_column_name)
                    await _Table.drop_index([column]).run()
                    await _Table.create_index(
                        [column],
                        method=index_method,
                        if_not_exists=True).run()
            else:
                # If the index value has changed, then we are either
                # dropping, or creating an index.
                column = Column()
                column._meta._table = _Table
                column._meta._name = alter_column.column_name
                column._meta.db_column_name = alter_column.db_column_name

                if index is True:
                    kwargs = ({
                        "method": index_method
                    } if index_method else {})
                    await _Table.create_index(
                        [column],
                        if_not_exists=True,
                        **kwargs).run()
                else:
                    await _Table.drop_index([column]).run()

            # None is a valid value, so retrieve ellipsis if not found.
            default = params.get("default", ...)
            if default is not ...:
                column = Column()
                column._meta._table = _Table
                column._meta._name = alter_column.column_name
                column._meta.db_column_name = alter_column.db_column_name
                if default is None:
                    await _Table.alter().drop_default(column=column).run()
                else:
                    column.default = default
                    await _Table.alter().set_default(
                        column=column,
                        value=column.get_default_value()).run()

            # None is a valid value, so retrieve ellipsis if not found.
            digits = params.get("digits", ...)
            if digits is not ...:
                await _Table.alter().set_digits(
                    column=alter_column.db_column_name,
                    digits=digits,
                ).run()
async def create_table_class_from_db(
    table_class: t.Type[Table], tablename: str, schema_name: str
) -> OutputSchema:
    """
    Reflect a database table into a dynamically created ``Table``
    subclass, wrapped in an ``OutputSchema`` (which also collects the
    required imports and any warnings).

    :param table_class:
        An existing Table class, used to run the introspection queries.
    :param tablename:
        The name of the database table to reflect.
    :param schema_name:
        The schema the table lives in.
    :returns:
        An ``OutputSchema`` containing the generated table class, plus
        the schemas of any tables it references via foreign keys.
    """
    indexes = await get_indexes(
        table_class=table_class,
        tablename=tablename,
        schema_name=schema_name,
    )
    constraints = await get_constraints(
        table_class=table_class,
        tablename=tablename,
        schema_name=schema_name,
    )
    triggers = await get_fk_triggers(
        table_class=table_class,
        tablename=tablename,
        schema_name=schema_name,
    )
    table_schema = await get_table_schema(
        table_class=table_class,
        tablename=tablename,
        schema_name=schema_name,
    )

    output_schema = OutputSchema()

    columns: t.Dict[str, Column] = {}

    for pg_row_meta in table_schema:
        data_type = pg_row_meta.data_type
        column_type = COLUMN_TYPE_MAP.get(data_type, None)
        column_name = pg_row_meta.column_name
        column_default = pg_row_meta.column_default

        if not column_type:
            # Unknown Postgres type - fall back to a plain Column, and
            # record a warning so the user can fix it up manually.
            output_schema.warnings.append(
                f"{tablename}.{column_name} ['{data_type}']"
            )
            column_type = Column

        kwargs: t.Dict[str, t.Any] = {
            "null": pg_row_meta.is_nullable == "YES",
            "unique": constraints.is_unique(column_name=column_name),
        }

        index = indexes.get_column_index(column_name=column_name)
        if index is not None:
            kwargs["index"] = True
            kwargs["index_method"] = index.method

        if constraints.is_primary_key(column_name=column_name):
            kwargs["primary_key"] = True
            # An integer primary key is assumed to be auto-incrementing.
            if column_type == Integer:
                column_type = Serial

        if constraints.is_foreign_key(column_name=column_name):
            fk_constraint_table = constraints.get_foreign_key_constraint_name(
                column_name=column_name
            )
            column_type = ForeignKey
            constraint_table = await get_foreign_key_reference(
                table_class=table_class,
                constraint_name=fk_constraint_table.name,
                constraint_schema=fk_constraint_table.schema,
            )
            if constraint_table.name:
                referenced_table: t.Union[str, t.Optional[t.Type[Table]]]

                if constraint_table.name == tablename:
                    # A self-referencing foreign key.
                    referenced_output_schema = output_schema
                    referenced_table = "self"
                else:
                    # Recursively reflect the referenced table, so it's
                    # also included in the output.
                    referenced_output_schema = (
                        await create_table_class_from_db(
                            table_class=table_class,
                            tablename=constraint_table.name,
                            schema_name=constraint_table.schema,
                        )
                    )
                    referenced_table = (
                        referenced_output_schema.get_table_with_name(
                            tablename=constraint_table.name
                        )
                    )

                kwargs["references"] = (
                    referenced_table
                    if referenced_table is not None
                    else ForeignKeyPlaceholder
                )

                trigger = triggers.get_column_ref_trigger(
                    column_name, constraint_table.name
                )
                if trigger:
                    kwargs["on_update"] = ONUPDATE_MAP[trigger.on_update]
                    kwargs["on_delete"] = ONDELETE_MAP[trigger.on_delete]

                # Merge the referenced table's schema into this one.
                output_schema = sum(  # type: ignore
                    [output_schema, referenced_output_schema]  # type: ignore
                )  # type: ignore
            else:
                kwargs["references"] = ForeignKeyPlaceholder

        output_schema.imports.append(
            "from piccolo.columns.column_types import "
            + column_type.__name__
        )

        if column_type is Varchar:
            kwargs["length"] = pg_row_meta.character_maximum_length
        elif issubclass(column_type, Numeric):
            # BUGFIX: this was `isinstance(column_type, Numeric)`, which
            # is always False because `column_type` is a class, not an
            # instance - so `digits` was never populated.
            radix = pg_row_meta.numeric_precision_radix
            precision = int(str(pg_row_meta.numeric_precision), radix)
            scale = int(str(pg_row_meta.numeric_scale), radix)
            kwargs["digits"] = (precision, scale)

        if column_default:
            default_value = get_column_default(column_type, column_default)
            if default_value:
                kwargs["default"] = default_value

        column = column_type(**kwargs)

        # Record any extra imports needed by the column's params
        # (e.g. enum or default value classes).
        serialised_params = serialise_params(column._meta.params)
        for extra_import in serialised_params.extra_imports:
            output_schema.imports.append(extra_import.__repr__())

        columns[column_name] = column

    table = create_table_class(
        class_name=_snake_to_camel(tablename),
        class_kwargs={"tablename": get_table_name(tablename, schema_name)},
        class_members=columns,
    )
    output_schema.tables.append(table)

    return output_schema
def table_exists(self, tablename: str) -> bool:
    """
    Check whether a table with the given name exists in the database.
    """
    table_class: t.Type[Table] = create_table_class(
        class_name=tablename.upper(),
        class_kwargs={"tablename": tablename},
    )
    return table_class.table_exists().run_sync()
def run_sync(self, query):
    """
    Run a raw SQL query synchronously, via a throwaway Table class.
    """
    throwaway_table = create_table_class(class_name="_Table")
    return throwaway_table.raw(query).run_sync()