def test_uuid(self):
    """
    Make sure concrete UUID values and ``UUID4`` defaults are serialised
    properly.
    """
    serialised = serialise_params(params={"default": uuid.UUID(int=4)})
    # A concrete UUID value is serialised as a reconstructable repr.
    # assertEqual (rather than a bare assert) gives a readable diff on
    # failure.
    self.assertEqual(
        repr(serialised.params["default"]),
        'uuid.UUID("00000000-0000-0000-0000-000000000004")',
    )

    serialised = serialise_params(params={"default": UUID4()})
    # A UUID4 default serialises to its constructor call.
    self.assertEqual(repr(serialised.params["default"]), "UUID4()")
def alter_columns(self) -> AlterStatements:
    """
    Build a ``manager.alter_column`` statement for every column whose
    params changed between the snapshot and the current schema.
    """
    statements = []
    imports = []
    definitions = []

    for table in self.schema:
        snapshot_table = self._get_snapshot_table(table.class_name)
        if not snapshot_table:
            # No snapshot means a brand new table - nothing to alter.
            continue
        delta: TableDelta = table - snapshot_table

        for altered in delta.alter_columns:
            # Serialise both the new and old params, accumulating any
            # imports / definitions which they require.
            new_params = serialise_params(altered.params)
            old_params = serialise_params(altered.old_params)
            for serialised in (new_params, old_params):
                imports.extend(serialised.extra_imports)
                definitions.extend(serialised.extra_definitions)

            statements.append(
                f"manager.alter_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{altered.column_name}', params={new_params.params}, old_params={old_params.params})"  # noqa: E501
            )

    return AlterStatements(
        statements=statements,
        extra_imports=imports,
        extra_definitions=definitions,
    )
def test_lambda(self):
    """
    Make sure lambda functions are rejected.
    """
    with self.assertRaises(ValueError) as manager:
        serialise_params(params={"default": lambda x: x + 1})
    # str() is the idiomatic way to get an exception's message - avoid
    # calling the __str__ dunder directly.
    self.assertEqual(str(manager.exception), "Lambdas can't be serialised")
def alter_columns(self) -> AlterStatements:
    """
    Build a ``manager.alter_column`` statement for every column whose
    params changed between the snapshot and the current schema,
    including the old and new column classes.
    """

    def _class_import(column_class) -> Import:
        # The column class must be importable in the generated migration;
        # guard against clashes with Piccolo's own global names. Extracted
        # as a helper because the same construction was previously
        # duplicated for ``column_class`` and ``old_column_class``.
        return Import(
            module=column_class.__module__,
            target=column_class.__name__,
            expect_conflict_with_global_name=getattr(
                UniqueGlobalNames,
                f"COLUMN_{column_class.__name__.upper()}",
            ),
        )

    response: t.List[str] = []
    extra_imports: t.List[Import] = []
    extra_definitions: t.List[Definition] = []

    for table in self.schema:
        snapshot_table = self._get_snapshot_table(table.class_name)
        if snapshot_table:
            delta: TableDelta = table - snapshot_table
        else:
            # No snapshot means a brand new table - nothing to alter.
            continue

        for alter_column in delta.alter_columns:
            new_params = serialise_params(alter_column.params)
            extra_imports.extend(new_params.extra_imports)
            extra_definitions.extend(new_params.extra_definitions)

            old_params = serialise_params(alter_column.old_params)
            extra_imports.extend(old_params.extra_imports)
            extra_definitions.extend(old_params.extra_definitions)

            # The class names are interpolated directly into the generated
            # statement, so a missing class is rendered as the literal
            # ``None``.
            column_class = (
                alter_column.column_class.__name__
                if alter_column.column_class
                else "None"
            )
            old_column_class = (
                alter_column.old_column_class.__name__
                if alter_column.old_column_class
                else "None"
            )

            if alter_column.column_class is not None:
                extra_imports.append(
                    _class_import(alter_column.column_class)
                )
            if alter_column.old_column_class is not None:
                extra_imports.append(
                    _class_import(alter_column.old_column_class)
                )

            response.append(
                f"manager.alter_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{alter_column.column_name}', params={new_params.params}, old_params={old_params.params}, column_class={column_class}, old_column_class={old_column_class})"  # noqa: E501
            )

    return AlterStatements(
        statements=response,
        extra_imports=extra_imports,
        extra_definitions=extra_definitions,
    )
def test_lazy_table_reference(self):
    """
    Make sure both forms of ``LazyTableReference`` (by app name, and by
    module path) serialise identically.
    """
    # These are equivalent:
    references_list = [
        LazyTableReference(
            table_class_name="Manager", app_name="example_app"
        ),
        LazyTableReference(
            table_class_name="Manager",
            module_path="tests.example_app.tables",
        ),
    ]

    for references in references_list:
        serialised = serialise_params(params={"references": references})
        # assertEqual (rather than assertTrue(a == b)) gives a readable
        # diff on failure; repr()/str() are preferred over calling the
        # dunders directly.
        self.assertEqual(repr(serialised.params["references"]), "Manager")
        self.assertEqual(len(serialised.extra_imports), 1)
        self.assertEqual(
            str(serialised.extra_imports[0]),
            "from piccolo.table import Table",
        )
        self.assertEqual(len(serialised.extra_definitions), 1)
        self.assertEqual(
            str(serialised.extra_definitions[0]),
            'class Manager(Table, tablename="manager"): pass',
        )
def new_table_columns(self) -> AlterStatements:
    """
    Build ``manager.add_column`` statements for every column belonging
    to a table which didn't exist in the snapshot.
    """
    added_tables: t.List[DiffableTable] = list(
        set(self.schema) - set(self.schema_snapshot)
    )

    statements: t.List[str] = []
    imports: t.List[Import] = []
    definitions: t.List[str] = []

    for table in added_tables:
        # Renamed tables are handled elsewhere - don't treat them as new.
        if (
            table.class_name
            in self.rename_tables_collection.new_class_names
        ):
            continue

        for column in table.columns:
            # In case we cause subtle bugs:
            serialised = serialise_params(deepcopy(column._meta.params))
            imports.extend(serialised.extra_imports)
            definitions.extend(serialised.extra_definitions)

            # The column class itself must be importable in the
            # generated migration.
            imports.append(
                Import(
                    module=column.__class__.__module__,
                    target=column.__class__.__name__,
                )
            )

            statements.append(
                f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column._meta.name}', column_class_name='{column.__class__.__name__}', column_class={column.__class__.__name__}, params={str(serialised.params)})"  # noqa: E501
            )

    return AlterStatements(
        statements=statements,
        extra_imports=imports,
        extra_definitions=definitions,
    )
def add_columns(self) -> AlterStatements:
    """
    Build ``manager.add_column`` statements for columns added to tables
    which already existed in the snapshot.
    """
    statements = []
    imports = []
    definitions = []

    for table in self.schema:
        snapshot_table = self._get_snapshot_table(table.class_name)
        if not snapshot_table:
            # Brand new table - its columns are handled elsewhere.
            continue
        delta: TableDelta = table - snapshot_table

        for added in delta.add_columns:
            # Renamed columns mustn't be treated as additions.
            if (
                added.column_name
                in self.rename_columns_collection.new_column_names
            ):
                continue

            serialised = serialise_params(added.params)
            imports.extend(serialised.extra_imports)
            definitions.extend(serialised.extra_definitions)

            statements.append(
                f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{added.column_name}', column_class_name='{added.column_class_name}', params={str(serialised.params)})"  # noqa: E501
            )

    return AlterStatements(
        statements=statements,
        extra_imports=imports,
        extra_definitions=definitions,
    )
def test_lazy_table_reference(self):
    """
    Make sure both forms of ``LazyTableReference`` (by app name, and by
    module path) serialise identically.
    """
    # These are equivalent:
    references_list = [
        LazyTableReference(table_class_name="Manager", app_name="music"),
        LazyTableReference(
            table_class_name="Manager",
            module_path="tests.example_apps.music.tables",
        ),
    ]

    for references in references_list:
        serialised = serialise_params(params={"references": references})
        # assertEqual (rather than assertTrue(a == b)) gives a readable
        # diff on failure; repr()/str() are preferred over calling the
        # dunders directly.
        self.assertEqual(repr(serialised.params["references"]), "Manager")
        self.assertEqual(len(serialised.extra_imports), 1)
        self.assertEqual(
            str(serialised.extra_imports[0]),
            "from piccolo.table import Table",
        )
        self.assertEqual(len(serialised.extra_definitions), 1)
        self.assertEqual(
            str(serialised.extra_definitions[0]),
            (
                'class Manager(Table, tablename="manager"): '
                "id = Serial(null=False, primary_key=True, unique=False, "
                "index=False, index_method=IndexMethod.btree, "
                "choices=None, db_column_name='id', secret=False)"
            ),
        )
def test_time(self):
    """
    Make sure ``TimeNow`` defaults are serialised properly.
    """
    serialised = serialise_params(params={"default": TimeNow()})
    # repr()/str() are preferred over calling the dunders directly, and
    # assertEqual gives a readable diff on failure.
    self.assertEqual(repr(serialised.params["default"]), "TimeNow()")
    self.assertEqual(len(serialised.extra_imports), 1)
    self.assertEqual(
        str(serialised.extra_imports[0]),
        "from piccolo.columns.defaults.time import TimeNow",
    )
def test_builtins(self):
    """
    Make sure builtins can be serialised properly.
    """
    serialised = serialise_params(params={"default": list})
    # A builtin serialises to its bare name, and needs no import.
    self.assertEqual(repr(serialised.params["default"]), "list")
    self.assertEqual(len(serialised.extra_imports), 0)
def test_function(self):
    """
    Make sure plain functions serialise to their bare name, plus an
    import from their defining module.
    """
    serialised = serialise_params(params={"default": example_function})
    # repr()/str() are preferred over calling the dunders directly, and
    # assertEqual gives a readable diff on failure.
    self.assertEqual(
        repr(serialised.params["default"]), "example_function"
    )
    self.assertEqual(len(serialised.extra_imports), 1)
    self.assertEqual(
        str(serialised.extra_imports[0]),
        (
            "from tests.apps.migrations.auto.test_serialisation import "
            "example_function"
        ),
    )
    self.assertEqual(len(serialised.extra_definitions), 0)
def test_builtin_enum_instance(self):
    """
    Make sure Enum instances defined in Piccolo can be serialised
    properly - for example, with on_delete.
    """
    serialised = serialise_params(params={"on_delete": OnDelete.cascade})
    # repr() is preferred over calling the __repr__ dunder directly.
    self.assertEqual(
        repr(serialised.params["on_delete"]), "OnDelete.cascade"
    )
    self.assertEqual(
        [repr(i) for i in serialised.extra_imports],
        ["from piccolo.columns.base import OnDelete"],
    )
    self.assertEqual(serialised.extra_definitions, [])
def test_custom_enum_instance(self):
    """
    Make sure custom Enum instances can be serialised properly.

    An example is when a user defines a choices Enum, and then sets the
    default to one of those choices.
    """

    class Choices(Enum):
        a = 1
        b = 2

    # A member of a user-defined Enum is flattened to its raw value,
    # requiring no extra imports or definitions.
    serialised = serialise_params(params={"default": Choices.a})

    self.assertEqual(serialised.params["default"], 1)
    self.assertEqual(serialised.extra_imports, [])
    self.assertEqual(serialised.extra_definitions, [])
def test_column_instance(self):
    """
    Make sure Column instances can be serialised properly. An example
    use case is when a `base_column` argument is passed to an `Array`
    column.
    """
    serialised = serialise_params(params={"base_column": Varchar()})
    # repr() is preferred over calling the __repr__ dunder directly.
    self.assertEqual(
        repr(serialised.params["base_column"]),
        "Varchar(length=255, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False)",  # noqa: E501
    )
    # Import order isn't guaranteed, so compare as a set.
    self.assertEqual(
        {repr(i) for i in serialised.extra_imports},
        {
            "from piccolo.columns.column_types import Varchar",
            "from piccolo.columns.indexes import IndexMethod",
        },
    )
def add_columns(self) -> AlterStatements:
    """
    Build ``manager.add_column`` statements for columns added to tables
    which already existed in the snapshot.
    """
    statements: t.List[str] = []
    imports: t.List[Import] = []
    definitions: t.List[Definition] = []

    for table in self.schema:
        snapshot_table = self._get_snapshot_table(table.class_name)
        if not snapshot_table:
            # Brand new table - its columns are handled elsewhere.
            continue
        delta: TableDelta = table - snapshot_table

        for added in delta.add_columns:
            # Renamed columns mustn't be treated as additions.
            if (
                added.column_name
                in self.rename_columns_collection.new_column_names
            ):
                continue

            serialised = serialise_params(added.params)
            imports.extend(serialised.extra_imports)
            definitions.extend(serialised.extra_definitions)

            column_class = added.column_class
            # The column class itself must be importable in the
            # generated migration, guarding against clashes with
            # Piccolo's own global names.
            imports.append(
                Import(
                    module=column_class.__module__,
                    target=column_class.__name__,
                    expect_conflict_with_global_name=getattr(
                        UniqueGlobalNames,
                        f"COLUMN_{column_class.__name__.upper()}",
                    ),
                )
            )

            statements.append(
                f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{added.column_name}', db_column_name='{added.db_column_name}', column_class_name='{added.column_class_name}', column_class={column_class.__name__}, params={str(serialised.params)})"  # noqa: E501
            )

    return AlterStatements(
        statements=statements,
        extra_imports=imports,
        extra_definitions=definitions,
    )
def test_enum_type(self):
    """
    Make sure Enum types can be serialised properly.
    """

    class Choices(Enum):
        a = 1
        b = 2
        c = Choice(value=3, display_name="c1")

    serialised = serialise_params(params={"choices": Choices})
    # repr() is preferred over calling the __repr__ dunder directly.
    self.assertEqual(
        repr(serialised.params["choices"]),
        "Enum('Choices', {'a': 1, 'b': 2, 'c': Choice(value=3, display_name='c1')})",  # noqa: E501
    )
    # Import order isn't guaranteed, so compare as a set.
    self.assertEqual(
        {repr(i) for i in serialised.extra_imports},
        {
            "from piccolo.columns.choices import Choice",
            "from enum import Enum",
        },
    )
def test_uuid(self):
    """
    Make sure ``UUID4`` defaults are serialised properly.
    """
    serialised = serialise_params(params={"default": UUID4()})
    # assertEqual (rather than assertTrue(a == b)) gives a readable diff
    # on failure; repr() is preferred over calling __repr__ directly.
    self.assertEqual(repr(serialised.params["default"]), "UUID4()")
async def create_table_class_from_db(
    table_class: t.Type[Table], tablename: str, schema_name: str
) -> OutputSchema:
    """
    Reflect a single database table into a dynamically created Piccolo
    ``Table`` subclass, wrapped in an ``OutputSchema`` along with the
    imports and warnings needed to render it as source code.

    :param table_class:
        An existing ``Table`` class, used to run the introspection
        queries against the database.
    :param tablename:
        The name of the database table to reflect.
    :param schema_name:
        The schema which the table belongs to.
    """
    # Gather all of the metadata about the table up front.
    indexes = await get_indexes(
        table_class=table_class, tablename=tablename, schema_name=schema_name
    )
    constraints = await get_constraints(
        table_class=table_class, tablename=tablename, schema_name=schema_name
    )
    triggers = await get_fk_triggers(
        table_class=table_class, tablename=tablename, schema_name=schema_name
    )
    table_schema = await get_table_schema(
        table_class=table_class, tablename=tablename, schema_name=schema_name
    )

    output_schema = OutputSchema()

    columns: t.Dict[str, Column] = {}

    for pg_row_meta in table_schema:
        data_type = pg_row_meta.data_type
        column_type = COLUMN_TYPE_MAP.get(data_type, None)
        column_name = pg_row_meta.column_name
        column_default = pg_row_meta.column_default
        if not column_type:
            # Unrecognised database type - warn, and fall back to the
            # base Column class.
            output_schema.warnings.append(
                f"{tablename}.{column_name} ['{data_type}']"
            )
            column_type = Column

        kwargs: t.Dict[str, t.Any] = {
            "null": pg_row_meta.is_nullable == "YES",
            "unique": constraints.is_unique(column_name=column_name),
        }

        index = indexes.get_column_index(column_name=column_name)
        if index is not None:
            kwargs["index"] = True
            kwargs["index_method"] = index.method

        if constraints.is_primary_key(column_name=column_name):
            kwargs["primary_key"] = True
            if column_type == Integer:
                # An integer primary key maps to Piccolo's Serial.
                column_type = Serial

        if constraints.is_foreign_key(column_name=column_name):
            fk_constraint_table = constraints.get_foreign_key_constraint_name(
                column_name=column_name
            )
            column_type = ForeignKey
            constraint_table = await get_foreign_key_reference(
                table_class=table_class,
                constraint_name=fk_constraint_table.name,
                constraint_schema=fk_constraint_table.schema,
            )
            if constraint_table.name:
                referenced_table: t.Union[str, t.Optional[t.Type[Table]]]

                if constraint_table.name == tablename:
                    # A self-referencing foreign key.
                    referenced_output_schema = output_schema
                    referenced_table = "self"
                else:
                    # Recursively reflect the referenced table, so it can
                    # be included in the output too.
                    referenced_output_schema = (
                        await create_table_class_from_db(
                            table_class=table_class,
                            tablename=constraint_table.name,
                            schema_name=constraint_table.schema,
                        )
                    )
                    referenced_table = (
                        referenced_output_schema.get_table_with_name(
                            tablename=constraint_table.name
                        )
                    )
                kwargs["references"] = (
                    referenced_table
                    if referenced_table is not None
                    else ForeignKeyPlaceholder
                )

                trigger = triggers.get_column_ref_trigger(
                    column_name, constraint_table.name
                )
                if trigger:
                    kwargs["on_update"] = ONUPDATE_MAP[trigger.on_update]
                    kwargs["on_delete"] = ONDELETE_MAP[trigger.on_delete]

                # Merge the referenced table's schema output into this one.
                output_schema = sum(  # type: ignore
                    [output_schema, referenced_output_schema]  # type: ignore
                )  # type: ignore
            else:
                # The referenced table couldn't be resolved.
                kwargs["references"] = ForeignKeyPlaceholder

        output_schema.imports.append(
            "from piccolo.columns.column_types import "
            + column_type.__name__
        )

        if column_type is Varchar:
            kwargs["length"] = pg_row_meta.character_maximum_length
        elif isinstance(column_type, Numeric):
            # NOTE(review): ``column_type`` is a class at this point, so
            # ``isinstance(column_type, Numeric)`` looks like it would
            # always be False - possibly ``issubclass`` was intended.
            # Confirm against the Numeric column handling upstream.
            radix = pg_row_meta.numeric_precision_radix
            precision = int(str(pg_row_meta.numeric_precision), radix)
            scale = int(str(pg_row_meta.numeric_scale), radix)
            kwargs["digits"] = (precision, scale)

        if column_default:
            default_value = get_column_default(column_type, column_default)
            if default_value:
                kwargs["default"] = default_value

        column = column_type(**kwargs)

        # Any imports which the column's params require must appear in
        # the generated source as well.
        serialised_params = serialise_params(column._meta.params)
        for extra_import in serialised_params.extra_imports:
            output_schema.imports.append(extra_import.__repr__())

        columns[column_name] = column

    table = create_table_class(
        class_name=_snake_to_camel(tablename),
        class_kwargs={"tablename": get_table_name(tablename, schema_name)},
        class_members=columns,
    )
    output_schema.tables.append(table)

    return output_schema
def __sub__(self, value: DiffableTable) -> TableDelta:
    """
    Diff this table against an older version of itself, producing a
    ``TableDelta`` describing the columns to add, drop, and alter.

    :param value:
        The snapshot (older) version of the same table.
    :raises ValueError:
        If ``value`` isn't a ``DiffableTable``, or has a different class
        name to this table.
    """
    if not isinstance(value, DiffableTable):
        raise ValueError(
            "Can only diff with other DiffableTable instances"
        )

    if value.class_name != self.class_name:
        raise ValueError(
            "The two tables don't appear to have the same name."
        )

    # Columns present here but not in the snapshot are additions.
    add_columns = [
        AddColumn(
            table_class_name=self.class_name,
            column_name=i._meta.name,
            column_class_name=i.__class__.__name__,
            params=i._meta.params,
        )
        for i in (set(self.columns) - set(value.columns))
    ]

    # Columns present in the snapshot but not here have been dropped.
    drop_columns = [
        DropColumn(
            table_class_name=self.class_name,
            column_name=i._meta.name,
            tablename=value.tablename,
        )
        for i in (set(value.columns) - set(self.columns))
    ]

    alter_columns: t.List[AlterColumn] = []

    for existing_column in value.columns:
        column = self.columns_map.get(existing_column._meta.name)
        if not column:
            # This is a new column - already captured above.
            continue

        # Diff the serialised params, so values which aren't directly
        # comparable are compared via their serialised forms.
        delta = compare_dicts(
            serialise_params(column._meta.params).params,
            serialise_params(existing_column._meta.params).params,
        )

        # For each changed param, record the value it had in the snapshot.
        old_params = {
            key: existing_column._meta.params.get(key)
            for key, _ in delta.items()
        }

        if delta:
            alter_columns.append(
                AlterColumn(
                    table_class_name=self.class_name,
                    tablename=self.tablename,
                    column_name=column._meta.name,
                    params=deserialise_params(delta),
                    old_params=old_params,
                )
            )

    return TableDelta(
        add_columns=add_columns,
        drop_columns=drop_columns,
        alter_columns=alter_columns,
    )
def test_date(self):
    """
    Make sure ``DateNow`` defaults are serialised properly.
    """
    serialised = serialise_params(params={"default": DateNow()})
    # repr() is preferred over calling the __repr__ dunder directly.
    self.assertEqual(repr(serialised.params["default"]), "DateNow()")
def test_timestamp(self):
    """
    Make sure ``TimestampNow`` defaults are serialised properly.
    """
    serialised = serialise_params(params={"default": TimestampNow()})
    # assertEqual (rather than assertTrue(a == b)) gives a readable diff
    # on failure; repr() is preferred over calling __repr__ directly.
    self.assertEqual(repr(serialised.params["default"]), "TimestampNow()")
def test_decimal(self):
    """
    Make sure ``decimal.Decimal`` defaults are serialised properly.
    """
    serialised = serialise_params(
        params={"default": decimal.Decimal("1.2")}
    )
    # assertEqual (rather than a bare assert) gives a readable diff on
    # failure.
    self.assertEqual(
        repr(serialised.params["default"]), 'decimal.Decimal("1.2")'
    )