예제 #1
0
    def test_uuid(self):
        """
        Make sure ``uuid.UUID`` values and ``UUID4`` defaults serialise to
        their expected source representations.
        """
        fixed_uuid = serialise_params(params={"default": uuid.UUID(int=4)})
        assert (repr(fixed_uuid.params["default"]) ==
                'uuid.UUID("00000000-0000-0000-0000-000000000004")')

        uuid4_default = serialise_params(params={"default": UUID4()})
        self.assertEqual(repr(uuid4_default.params["default"]), "UUID4()")
예제 #2
0
    def alter_columns(self) -> AlterStatements:
        """
        Build ``manager.alter_column(...)`` statements for every altered
        column of each table that also exists in the schema snapshot.

        Tables without a snapshot counterpart are skipped. The imports and
        definitions required by both the new and the old serialised params
        are collected alongside the statements.
        """
        statements = []
        imports = []
        definitions = []

        for table in self.schema:
            snapshot_table = self._get_snapshot_table(table.class_name)
            if not snapshot_table:
                continue
            delta: TableDelta = table - snapshot_table

            for altered in delta.alter_columns:
                new_params = serialise_params(altered.params)
                old_params = serialise_params(altered.old_params)

                # New-param requirements first, then old-param ones.
                for serialised in (new_params, old_params):
                    imports.extend(serialised.extra_imports)
                    definitions.extend(serialised.extra_definitions)

                statements.append(
                    f"manager.alter_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{altered.column_name}', params={new_params.params}, old_params={old_params.params})"  # noqa: E501
                )

        return AlterStatements(
            statements=statements,
            extra_imports=imports,
            extra_definitions=definitions,
        )
예제 #3
0
    def test_lambda(self):
        """
        Serialising a lambda default must fail with a ``ValueError``.
        """
        params = {"default": lambda x: x + 1}

        with self.assertRaises(ValueError) as context:
            serialise_params(params=params)

        self.assertEqual(str(context.exception),
                         "Lambdas can't be serialised")
예제 #4
0
    def alter_columns(self) -> AlterStatements:
        """
        Build ``manager.alter_column(...)`` statements for every altered
        column of each table that also exists in the schema snapshot.

        For each altered column, the imports / definitions required by its
        serialised new and old params are collected. An import is also added
        for the new and old column classes when they are set. Tables without
        a snapshot counterpart are skipped.
        """
        response: t.List[str] = []
        extra_imports: t.List[Import] = []
        extra_definitions: t.List[Definition] = []

        def column_class_import(column_class):
            # Import for a column class. Column classes with no matching
            # ``COLUMN_<NAME>`` attribute on UniqueGlobalNames (e.g. custom
            # column types) get None instead of raising AttributeError.
            return Import(
                module=column_class.__module__,
                target=column_class.__name__,
                expect_conflict_with_global_name=getattr(
                    UniqueGlobalNames,
                    f"COLUMN_{column_class.__name__.upper()}",
                    None,
                ),
            )

        for table in self.schema:
            snapshot_table = self._get_snapshot_table(table.class_name)
            if not snapshot_table:
                continue
            delta: TableDelta = table - snapshot_table

            for alter_column in delta.alter_columns:
                new_params = serialise_params(alter_column.params)
                extra_imports.extend(new_params.extra_imports)
                extra_definitions.extend(new_params.extra_definitions)

                old_params = serialise_params(alter_column.old_params)
                extra_imports.extend(old_params.extra_imports)
                extra_definitions.extend(old_params.extra_definitions)

                # Rendered as bare class names (or None) in the statement.
                column_class = (alter_column.column_class.__name__
                                if alter_column.column_class else "None")
                old_column_class = (alter_column.old_column_class.__name__
                                    if alter_column.old_column_class
                                    else "None")

                if alter_column.column_class is not None:
                    extra_imports.append(
                        column_class_import(alter_column.column_class))

                if alter_column.old_column_class is not None:
                    extra_imports.append(
                        column_class_import(alter_column.old_column_class))

                response.append(
                    f"manager.alter_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{alter_column.column_name}', params={new_params.params}, old_params={old_params.params}, column_class={column_class}, old_column_class={old_column_class})"  # noqa: E501
                )

        return AlterStatements(
            statements=response,
            extra_imports=extra_imports,
            extra_definitions=extra_definitions,
        )
예제 #5
0
    def test_lazy_table_reference(self):
        """
        Both ways of declaring a ``LazyTableReference`` (by app name or by
        module path) should serialise to the same output.
        """
        by_app_name = LazyTableReference(
            table_class_name="Manager", app_name="example_app"
        )
        by_module_path = LazyTableReference(
            table_class_name="Manager",
            module_path="tests.example_app.tables",
        )

        for reference in (by_app_name, by_module_path):
            serialised = serialise_params(params={"references": reference})
            self.assertEqual(repr(serialised.params["references"]), "Manager")

            imports = serialised.extra_imports
            self.assertEqual(len(imports), 1)
            self.assertEqual(str(imports[0]), "from piccolo.table import Table")

            definitions = serialised.extra_definitions
            self.assertEqual(len(definitions), 1)
            self.assertEqual(
                str(definitions[0]),
                'class Manager(Table, tablename="manager"): pass',
            )
예제 #6
0
    def new_table_columns(self) -> AlterStatements:
        """
        Build ``manager.add_column(...)`` statements for every column of each
        table that is in the schema but not in the schema snapshot.

        Tables whose class name appears in
        ``rename_tables_collection.new_class_names`` are skipped. For each
        column, the imports / definitions required by its serialised params
        are collected, plus an import of the column's class.
        """
        # Keep the schema's own order (dropping duplicates, as the set did)
        # so the generated migration is deterministic; iterating a set
        # difference yields the tables in arbitrary order.
        snapshot_tables = set(self.schema_snapshot)
        new_tables: t.List[DiffableTable] = [
            table for table in dict.fromkeys(self.schema)
            if table not in snapshot_tables
        ]

        response: t.List[str] = []
        extra_imports: t.List[Import] = []
        extra_definitions: t.List[Definition] = []
        for table in new_tables:
            if (table.class_name
                    in self.rename_tables_collection.new_class_names):
                continue

            for column in table.columns:
                # Serialise a copy, so nothing done during serialisation can
                # leak back into the column's own params.
                params = deepcopy(column._meta.params)
                _params = serialise_params(params)
                cleaned_params = _params.params
                extra_imports.extend(_params.extra_imports)
                extra_definitions.extend(_params.extra_definitions)

                extra_imports.append(
                    Import(
                        module=column.__class__.__module__,
                        target=column.__class__.__name__,
                    ))

                response.append(
                    f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column._meta.name}', column_class_name='{column.__class__.__name__}', column_class={column.__class__.__name__}, params={str(cleaned_params)})"  # noqa: E501
                )
        return AlterStatements(
            statements=response,
            extra_imports=extra_imports,
            extra_definitions=extra_definitions,
        )
예제 #7
0
    def add_columns(self) -> AlterStatements:
        """
        Build ``manager.add_column(...)`` statements for columns present in
        the current schema but missing from the snapshot of the same table.

        Tables without a snapshot counterpart are skipped, as are columns
        whose name appears in ``rename_columns_collection.new_column_names``.
        """
        statements = []
        imports = []
        definitions = []

        for table in self.schema:
            snapshot_table = self._get_snapshot_table(table.class_name)
            if not snapshot_table:
                continue
            delta: TableDelta = table - snapshot_table

            for added in delta.add_columns:
                is_renamed = (
                    added.column_name
                    in self.rename_columns_collection.new_column_names
                )
                if is_renamed:
                    continue

                serialised = serialise_params(added.params)
                cleaned_params = serialised.params
                imports.extend(serialised.extra_imports)
                definitions.extend(serialised.extra_definitions)

                statements.append(
                    f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{added.column_name}', column_class_name='{added.column_class_name}', params={str(cleaned_params)})"  # noqa: E501
                )

        return AlterStatements(
            statements=statements,
            extra_imports=imports,
            extra_definitions=definitions,
        )
예제 #8
0
    def test_lazy_table_reference(self):
        """
        Both ways of declaring a ``LazyTableReference`` (by app name or by
        module path) should serialise to the same output.
        """
        references_list = (
            LazyTableReference(table_class_name="Manager", app_name="music"),
            LazyTableReference(
                table_class_name="Manager",
                module_path="tests.example_apps.music.tables",
            ),
        )
        expected_definition = (
            'class Manager(Table, tablename="manager"): '
            "id = Serial(null=False, primary_key=True, unique=False, "
            "index=False, index_method=IndexMethod.btree, "
            "choices=None, db_column_name='id', secret=False)"
        )

        for reference in references_list:
            serialised = serialise_params(params={"references": reference})
            self.assertEqual(repr(serialised.params["references"]), "Manager")

            imports = serialised.extra_imports
            self.assertEqual(len(imports), 1)
            self.assertEqual(str(imports[0]), "from piccolo.table import Table")

            definitions = serialised.extra_definitions
            self.assertEqual(len(definitions), 1)
            self.assertEqual(str(definitions[0]), expected_definition)
예제 #9
0
 def test_time(self):
     """
     Make sure a ``TimeNow`` default serialises and brings in its import.
     """
     serialised = serialise_params(params={"default": TimeNow()})
     rendered = repr(serialised.params["default"])
     self.assertEqual(rendered, "TimeNow()")

     imports = serialised.extra_imports
     self.assertEqual(len(imports), 1)
     self.assertEqual(
         str(imports[0]),
         "from piccolo.columns.defaults.time import TimeNow",
     )
예제 #10
0
    def test_builtins(self):
        """
        A builtin such as ``list`` should serialise to its bare name and need
        no extra imports.
        """
        serialised = serialise_params(params={"default": list})
        rendered = repr(serialised.params["default"])

        self.assertEqual(rendered, "list")
        self.assertEqual(len(serialised.extra_imports), 0)
예제 #11
0
    def test_function(self):
        """
        A module-level function should serialise to its name, together with
        an import of that function from its defining module.
        """
        serialised = serialise_params(params={"default": example_function})
        rendered = repr(serialised.params["default"])
        self.assertEqual(rendered, "example_function")

        expected_import = (
            "from tests.apps.migrations.auto.test_serialisation import "
            "example_function"
        )
        self.assertEqual(len(serialised.extra_imports), 1)
        self.assertEqual(str(serialised.extra_imports[0]), expected_import)

        self.assertEqual(len(serialised.extra_definitions), 0)
예제 #12
0
    def test_builtin_enum_instance(self):
        """
        Make sure Enum instances defined in Piccolo (for example, the value
        passed to ``on_delete``) can be serialised properly.
        """
        serialised = serialise_params(params={"on_delete": OnDelete.cascade})

        self.assertEqual(
            repr(serialised.params["on_delete"]), "OnDelete.cascade"
        )

        import_reprs = [repr(item) for item in serialised.extra_imports]
        self.assertEqual(
            import_reprs, ["from piccolo.columns.base import OnDelete"]
        )
        self.assertEqual(serialised.extra_definitions, [])
예제 #13
0
    def test_custom_enum_instance(self):
        """
        A member of a user-defined Enum (e.g. a ``choices`` Enum whose member
        is used as a default) should serialise to its raw value and need no
        imports or definitions.
        """
        class Choices(Enum):
            a = 1
            b = 2

        serialised = serialise_params(params={"default": Choices.a})

        self.assertEqual(serialised.params["default"], 1)
        for produced in (serialised.extra_imports,
                         serialised.extra_definitions):
            self.assertEqual(produced, [])
예제 #14
0
    def test_column_instance(self):
        """
        Column instances (e.g. a ``base_column`` passed to an ``Array``
        column) should serialise to a constructor call, along with the
        imports that call needs.
        """
        serialised = serialise_params(params={"base_column": Varchar()})

        expected_repr = "Varchar(length=255, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False)"  # noqa: E501
        self.assertEqual(repr(serialised.params["base_column"]), expected_repr)

        import_reprs = set(map(repr, serialised.extra_imports))
        self.assertEqual(
            import_reprs,
            {
                "from piccolo.columns.column_types import Varchar",
                "from piccolo.columns.indexes import IndexMethod",
            },
        )
예제 #15
0
    def add_columns(self) -> AlterStatements:
        """
        Build ``manager.add_column(...)`` statements for columns present in
        the current schema but missing from the snapshot of the same table.

        Tables without a snapshot counterpart are skipped, as are columns
        whose name appears in ``rename_columns_collection.new_column_names``.
        For each added column, the imports / definitions its serialised
        params need are collected, plus an import of its column class.
        """
        response: t.List[str] = []
        extra_imports: t.List[Import] = []
        extra_definitions: t.List[Definition] = []
        for table in self.schema:
            snapshot_table = self._get_snapshot_table(table.class_name)
            if not snapshot_table:
                continue
            delta: TableDelta = table - snapshot_table

            for add_column in delta.add_columns:
                if (add_column.column_name
                        in self.rename_columns_collection.new_column_names):
                    continue

                params = serialise_params(add_column.params)
                cleaned_params = params.params
                extra_imports.extend(params.extra_imports)
                extra_definitions.extend(params.extra_definitions)

                column_class = add_column.column_class
                extra_imports.append(
                    Import(
                        module=column_class.__module__,
                        target=column_class.__name__,
                        # Column classes with no matching ``COLUMN_<NAME>``
                        # attribute on UniqueGlobalNames (e.g. custom column
                        # types) get None instead of raising AttributeError.
                        expect_conflict_with_global_name=getattr(
                            UniqueGlobalNames,
                            f"COLUMN_{column_class.__name__.upper()}",
                            None,
                        ),
                    ))

                response.append(
                    f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{add_column.column_name}', db_column_name='{add_column.db_column_name}', column_class_name='{add_column.column_class_name}', column_class={column_class.__name__}, params={str(cleaned_params)})"  # noqa: E501
                )
        return AlterStatements(
            statements=response,
            extra_imports=extra_imports,
            extra_definitions=extra_definitions,
        )
예제 #16
0
    def test_enum_type(self):
        """
        An Enum class (e.g. one passed as ``choices``) should serialise to an
        ``Enum(...)`` functional-API call, importing ``Enum`` and ``Choice``.
        """
        class Choices(Enum):
            a = 1
            b = 2
            c = Choice(value=3, display_name="c1")

        serialised = serialise_params(params={"choices": Choices})

        expected_repr = (
            "Enum('Choices', {'a': 1, 'b': 2, "
            "'c': Choice(value=3, display_name='c1')})"
        )
        self.assertEqual(repr(serialised.params["choices"]), expected_repr)

        expected_imports = {
            "from piccolo.columns.choices import Choice",
            "from enum import Enum",
        }
        self.assertEqual(
            set(map(repr, serialised.extra_imports)), expected_imports
        )
예제 #17
0
 def test_uuid(self):
     """Make sure a ``UUID4`` default serialises to ``UUID4()``."""
     serialised = serialise_params(params={"default": UUID4()})
     rendered = repr(serialised.params["default"])
     self.assertEqual(rendered, "UUID4()")
예제 #18
0
async def create_table_class_from_db(table_class: t.Type[Table],
                                     tablename: str,
                                     schema_name: str) -> OutputSchema:
    """
    Introspect a database table and build a Piccolo ``Table`` subclass for it.

    Indexes, constraints, foreign key triggers and column metadata are
    fetched for ``tablename`` in ``schema_name``, and a column instance is
    created for each database column. Tables referenced by foreign keys are
    introspected recursively (unless the key is self-referencing), and their
    output is merged into the result.

    :param table_class:
        Passed through to the introspection helpers (presumably used to reach
        the database engine - confirm against those helpers).
    :param tablename:
        Name of the table to introspect.
    :param schema_name:
        Schema containing the table.
    :returns:
        An ``OutputSchema`` holding the generated table class, any referenced
        tables, the imports they need, and a warning for every column whose
        database type has no entry in ``COLUMN_TYPE_MAP``.
    """
    indexes = await get_indexes(table_class=table_class,
                                tablename=tablename,
                                schema_name=schema_name)
    constraints = await get_constraints(table_class=table_class,
                                        tablename=tablename,
                                        schema_name=schema_name)
    triggers = await get_fk_triggers(table_class=table_class,
                                     tablename=tablename,
                                     schema_name=schema_name)
    table_schema = await get_table_schema(table_class=table_class,
                                          tablename=tablename,
                                          schema_name=schema_name)
    output_schema = OutputSchema()
    columns: t.Dict[str, Column] = {}

    for pg_row_meta in table_schema:
        data_type = pg_row_meta.data_type
        column_type = COLUMN_TYPE_MAP.get(data_type, None)
        column_name = pg_row_meta.column_name
        column_default = pg_row_meta.column_default
        if not column_type:
            # Unmapped database type: warn and fall back to a generic Column.
            output_schema.warnings.append(
                f"{tablename}.{column_name} ['{data_type}']")
            column_type = Column

        kwargs: t.Dict[str, t.Any] = {
            "null": pg_row_meta.is_nullable == "YES",
            "unique": constraints.is_unique(column_name=column_name),
        }

        index = indexes.get_column_index(column_name=column_name)
        if index is not None:
            kwargs["index"] = True
            kwargs["index_method"] = index.method

        if constraints.is_primary_key(column_name=column_name):
            kwargs["primary_key"] = True
            # An Integer primary key is generated as a Serial column.
            if column_type == Integer:
                column_type = Serial

        if constraints.is_foreign_key(column_name=column_name):
            fk_constraint_table = constraints.get_foreign_key_constraint_name(
                column_name=column_name)
            column_type = ForeignKey
            constraint_table = await get_foreign_key_reference(
                table_class=table_class,
                constraint_name=fk_constraint_table.name,
                constraint_schema=fk_constraint_table.schema,
            )
            if constraint_table.name:
                referenced_table: t.Union[str, t.Optional[t.Type[Table]]]

                if constraint_table.name == tablename:
                    # Self-referencing key: reuse the current output instead
                    # of recursing into the same table again.
                    referenced_output_schema = output_schema
                    referenced_table = "self"
                else:
                    referenced_output_schema = (
                        await create_table_class_from_db(
                            table_class=table_class,
                            tablename=constraint_table.name,
                            schema_name=constraint_table.schema,
                        ))
                    referenced_table = (
                        referenced_output_schema.get_table_with_name(
                            tablename=constraint_table.name))
                kwargs["references"] = (referenced_table if referenced_table
                                        is not None else ForeignKeyPlaceholder)

                trigger = triggers.get_column_ref_trigger(
                    column_name, constraint_table.name)
                if trigger:
                    kwargs["on_update"] = ONUPDATE_MAP[trigger.on_update]
                    kwargs["on_delete"] = ONDELETE_MAP[trigger.on_delete]

                # Merge the referenced table's output into this one.
                output_schema = sum(  # type: ignore
                    [output_schema, referenced_output_schema]  # type: ignore
                )  # type: ignore
            else:
                kwargs["references"] = ForeignKeyPlaceholder

        output_schema.imports.append(
            "from piccolo.columns.column_types import " + column_type.__name__)

        if column_type is Varchar:
            kwargs["length"] = pg_row_meta.character_maximum_length
        elif issubclass(column_type, Numeric):
            # ``column_type`` is a class, not an instance, so the previous
            # ``isinstance`` check was always False and ``digits`` was never
            # set.
            precision = pg_row_meta.numeric_precision
            scale = pg_row_meta.numeric_scale
            # A numeric column declared without precision / scale presumably
            # reports these as None; leave ``digits`` unset rather than
            # crashing on ``int("None", ...)``.
            if precision is not None and scale is not None:
                radix = pg_row_meta.numeric_precision_radix
                kwargs["digits"] = (
                    int(str(precision), radix),
                    int(str(scale), radix),
                )

        if column_default:
            default_value = get_column_default(column_type, column_default)
            if default_value:
                kwargs["default"] = default_value

        column = column_type(**kwargs)

        # Collect any imports the column's params need (e.g. default types).
        serialised_params = serialise_params(column._meta.params)
        for extra_import in serialised_params.extra_imports:
            output_schema.imports.append(extra_import.__repr__())

        columns[column_name] = column

    table = create_table_class(
        class_name=_snake_to_camel(tablename),
        class_kwargs={"tablename": get_table_name(tablename, schema_name)},
        class_members=columns,
    )
    output_schema.tables.append(table)
    return output_schema
예제 #19
0
    def __sub__(self, value: DiffableTable) -> TableDelta:
        """
        Diff this table against ``value`` (e.g. the snapshot of the same
        table) and return the columns to add, drop and alter.

        Columns only in ``self`` become ``AddColumn`` entries, columns only in
        ``value`` become ``DropColumn`` entries, and columns present in both
        whose serialised params differ become ``AlterColumn`` entries. Add and
        drop entries follow the order of ``self.columns`` / ``value.columns``
        respectively, so downstream output is deterministic.

        :raises ValueError:
            If ``value`` isn't a ``DiffableTable``, or its ``class_name``
            differs from this table's.
        """
        if not isinstance(value, DiffableTable):
            raise ValueError(
                "Can only diff with other DiffableTable instances"
            )

        if value.class_name != self.class_name:
            raise ValueError(
                "The two tables don't appear to have the same name."
            )

        def ordered_difference(columns, exclude):
            # Same membership semantics as ``set(columns) - set(exclude)``
            # (including de-duplication), but keeps ``columns``' order
            # instead of arbitrary set order.
            excluded = set(exclude)
            seen = set()
            result = []
            for column in columns:
                if column in excluded or column in seen:
                    continue
                seen.add(column)
                result.append(column)
            return result

        add_columns = [
            AddColumn(
                table_class_name=self.class_name,
                column_name=i._meta.name,
                column_class_name=i.__class__.__name__,
                params=i._meta.params,
            )
            for i in ordered_difference(self.columns, value.columns)
        ]

        drop_columns = [
            DropColumn(
                table_class_name=self.class_name,
                column_name=i._meta.name,
                tablename=value.tablename,
            )
            for i in ordered_difference(value.columns, self.columns)
        ]

        alter_columns: t.List[AlterColumn] = []

        for existing_column in value.columns:
            column = self.columns_map.get(existing_column._meta.name)
            if column is None:
                # The column no longer exists on this table - it was
                # removed, and is already captured in drop_columns above.
                continue

            delta = compare_dicts(
                serialise_params(column._meta.params).params,
                serialise_params(existing_column._meta.params).params,
            )
            if not delta:
                continue

            # Previous values of just the params that changed.
            old_params = {
                key: existing_column._meta.params.get(key) for key in delta
            }

            alter_columns.append(
                AlterColumn(
                    table_class_name=self.class_name,
                    tablename=self.tablename,
                    column_name=column._meta.name,
                    params=deserialise_params(delta),
                    old_params=old_params,
                )
            )

        return TableDelta(
            add_columns=add_columns,
            drop_columns=drop_columns,
            alter_columns=alter_columns,
        )
예제 #20
0
 def test_date(self):
     """Make sure a ``DateNow`` default serialises to ``DateNow()``."""
     serialised = serialise_params(params={"default": DateNow()})
     rendered = repr(serialised.params["default"])
     self.assertEqual(rendered, "DateNow()")
예제 #21
0
 def test_timestamp(self):
     """Make sure a ``TimestampNow`` default serialises to its constructor."""
     serialised = serialise_params(params={"default": TimestampNow()})
     rendered = repr(serialised.params["default"])
     self.assertEqual(rendered, "TimestampNow()")
예제 #22
0
 def test_decimal(self):
     """Make sure a ``decimal.Decimal`` default serialises to source code."""
     default = decimal.Decimal("1.2")
     serialised = serialise_params(params={"default": default})
     expected = 'decimal.Decimal("1.2")'
     assert repr(serialised.params["default"]) == expected