Example #1
0
    def test_create_table_class(self):
        """
        Check that `create_table_class` builds a usable `Table` subclass.

        The tablename is derived from the class name by default and can be
        overridden via `class_kwargs`. Columns passed in `class_members` end
        up in the table's `_meta.columns`.
        """
        # Without an explicit tablename, `MyTable` maps to `my_table`.
        default_table = create_table_class(class_name="MyTable")
        self.assertEqual(default_table._meta.tablename, "my_table")

        # An explicit tablename takes precedence.
        named_table = create_table_class(
            class_name="MyTable",
            class_kwargs={"tablename": "my_table_1"},
        )
        self.assertEqual(named_table._meta.tablename, "my_table_1")

        # Columns supplied as class members are registered on the table.
        varchar_column = Varchar()
        table_with_column = create_table_class(
            class_name="MyTable",
            class_members={"name": varchar_column},
        )
        self.assertIn(varchar_column, table_with_column._meta.columns)
Example #2
0
    async def _run_add_columns(self, backwards=False):
        """
        Add columns, which belong to existing tables.

        When ``backwards`` is True, the columns are dropped again instead.
        Columns belonging to tables added by this same migration
        (``self.add_tables``) are skipped in both directions. Going forwards,
        they are created together with their table. Going backwards, the
        whole table is dropped anyway.
        """
        # Computed once up front, instead of rebuilding a list for every
        # column / table visited below.
        new_table_class_names = {i.class_name for i in self.add_tables}

        if backwards:
            for add_column in self.add_columns.add_columns:
                if add_column.table_class_name in new_table_class_names:
                    # Don't reverse the add column as the table is going to
                    # be deleted.
                    continue

                _Table: t.Type[Table] = create_table_class(
                    class_name=add_column.table_class_name,
                    class_kwargs={"tablename": add_column.tablename},
                )

                await _Table.alter().drop_column(add_column.column).run()
        else:
            for table_class_name in self.add_columns.table_class_names:
                if table_class_name in new_table_class_names:
                    continue  # No need to add columns to new tables

                add_columns: t.List[
                    AddColumnClass] = self.add_columns.for_table_class_name(
                        table_class_name)

                if not add_columns:
                    # Nothing queued for this table - and `add_columns[0]`
                    # below would fail.
                    continue

                # Define the table, with the columns, so the metaclass
                # sets up the columns correctly.
                _Table: t.Type[Table] = create_table_class(
                    class_name=add_columns[0].table_class_name,
                    class_kwargs={"tablename": add_columns[0].tablename},
                    class_members={
                        add_column.column._meta.name: add_column.column
                        for add_column in add_columns
                    },
                )

                for add_column in add_columns:
                    # We fetch the column from the Table, as the metaclass
                    # copies and sets it up properly.
                    column = _Table._meta.get_column_by_name(
                        add_column.column._meta.name)
                    await _Table.alter().add_column(name=column._meta.name,
                                                    column=column).run()
                    if column._meta.index:
                        # Forward the column's configured index method (e.g.
                        # GIN). Without it, `create_index` falls back to its
                        # own default and the configured method is ignored.
                        index_method = getattr(column._meta, "index_method",
                                               None)
                        index_kwargs = ({
                            "method": index_method
                        } if index_method else {})
                        await _Table.create_index([column],
                                                  **index_kwargs).run()
Example #3
0
    async def _run_add_tables(self, backwards=False):
        """
        Create the tables registered in ``self.add_tables``, or drop them when
        ``backwards`` is True.

        Each table class is assembled together with any columns queued for
        it in ``self.add_columns``. The classes are ordered with
        ``sort_table_classes`` (by foreign key) before creation. Dropping
        walks that order in reverse, cascading.
        """
        def build_table_class(add_table):
            # Columns queued for this table become class members, keyed by
            # the column's name.
            members = {
                add_column.column._meta.name: add_column.column
                for add_column in self.add_columns.for_table_class_name(
                    add_table.class_name)
            }
            return create_table_class(
                class_name=add_table.class_name,
                class_kwargs={"tablename": add_table.tablename},
                class_members=members,
            )

        ordered_tables = sort_table_classes(
            [build_table_class(add_table) for add_table in self.add_tables])

        if backwards:
            # Undo in the opposite order to creation.
            for table_class in reversed(ordered_tables):
                await table_class.alter().drop_table(cascade=True).run()
            return

        for table_class in ordered_tables:
            await table_class.create_table().run()
Example #4
0
    async def _run_drop_columns(self, backwards=False):
        """
        Drop the columns registered in ``self.drop_columns``.

        When ``backwards`` is True, the dropped columns are restored instead.
        Each column definition is looked up on the table class returned by
        ``get_table_from_snaphot`` (with ``offset=-1``) and re-added.
        """
        if not backwards:
            for table_class_name in self.drop_columns.table_class_names:
                drop_column_list = self.drop_columns.for_table_class_name(
                    table_class_name)
                if not drop_column_list:
                    continue

                table_class: t.Type[Table] = create_table_class(
                    class_name=table_class_name,
                    class_kwargs={"tablename": drop_column_list[0].tablename},
                )
                for drop_column in drop_column_list:
                    drop_query = table_class.alter().drop_column(
                        column=drop_column.column_name)
                    await drop_query.run()
            return

        for drop_column in self.drop_columns.drop_columns:
            snapshot_table = await self.get_table_from_snaphot(
                table_class_name=drop_column.table_class_name,
                app_name=self.app_name,
                offset=-1,
            )
            restored_column = snapshot_table._meta.get_column_by_name(
                drop_column.column_name)
            restore_query = snapshot_table.alter().add_column(
                name=drop_column.column_name, column=restored_column)
            await restore_query.run()
Example #5
0
    async def _run_rename_tables(self, backwards=False):
        """
        Rename the tables registered in ``self.rename_tables``, or undo the
        renames when ``backwards`` is True.

        Undoing replays the renames in reverse order. A chain such as
        ``a -> b`` then ``b -> c`` can only be unwound as ``c -> b`` then
        ``b -> a``. Replaying in forward order would try to rename a table
        which doesn't exist yet. Independent renames are unaffected by the
        ordering.
        """
        rename_tables = list(self.rename_tables)
        if backwards:
            rename_tables.reverse()

        for rename_table in rename_tables:
            # Going backwards, the table currently carries the *new* names,
            # and the target is the old tablename.
            class_name = (rename_table.new_class_name
                          if backwards else rename_table.old_class_name)
            tablename = (rename_table.new_tablename
                         if backwards else rename_table.old_tablename)
            new_tablename = (rename_table.old_tablename
                             if backwards else rename_table.new_tablename)

            _Table: t.Type[Table] = create_table_class(
                class_name=class_name, class_kwargs={"tablename": tablename})

            await _Table.alter().rename_table(new_name=new_tablename).run()
Example #6
0
    def to_table_class(self) -> t.Type[Table]:
        """
        Build a ``Table`` subclass equivalent to this DiffableTable.

        The generated class is named ``self.class_name`` and uses
        ``self.tablename``. Every column in ``self.columns`` is attached as a
        class member, keyed by the column's ``_meta.name``.
        """
        return create_table_class(
            class_name=self.class_name,
            class_kwargs={"tablename": self.tablename},
            class_members=dict(
                (column._meta.name, column) for column in self.columns),
        )
Example #7
0
def deserialise_legacy_params(name: str, value: str) -> t.Any:
    """
    Deserialise a parameter value written by an earlier version of Piccolo.

    Earlier versions serialised some parameters differently, so this exists
    purely for backwards compatibility:

    - ``references``: ``"SomeClassName"`` or ``"SomeClassName|some_table"``
      becomes a ``Table`` subclass built with ``create_table_class``.
    - ``on_delete`` / ``on_update``: ``"OnDelete.<item>"`` /
      ``"OnUpdate.<item>"`` becomes the matching enum member.
    - ``default``: ``"TimestampDefault.now"`` / ``"DatetimeDefault.now"``
      becomes ``TimestampNow()``. Strings accepted by
      ``datetime.datetime.fromisoformat`` become ``datetime`` objects.

    Anything not matched above is returned unchanged.

    :raises ValueError:
        If a ``references`` value contains more than one ``|``. Also if an
        ``on_delete`` / ``on_update`` value doesn't contain exactly one
        ``.``.
    """
    if name == "references":
        parts = value.split("|")
        if len(parts) > 2:
            raise ValueError(
                "Unrecognised Table serialisation - should either be "
                "`SomeClassName` or `SomeClassName|some_table_name`.")
        class_name = parts[0]
        tablename = parts[1] if len(parts) == 2 else None
        kwargs = {"tablename": tablename} if tablename else {}
        return create_table_class(class_name=class_name, class_kwargs=kwargs)

    ###########################################################################

    if name in ("on_delete", "on_update"):
        enum_name, item_name = value.split(".")
        if name == "on_delete" and enum_name == "OnDelete":
            return getattr(OnDelete, item_name)
        if name == "on_update" and enum_name == "OnUpdate":
            return getattr(OnUpdate, item_name)
        # Unrecognised enum prefix - leave the raw value alone.
        return value

    ###########################################################################

    if name == "default":
        if value in {"TimestampDefault.now", "DatetimeDefault.now"}:
            return TimestampNow()
        try:
            return datetime.datetime.fromisoformat(value)
        except ValueError:
            # Not a datetime - fall through and return the raw value.
            pass

    return value
Example #8
0
    def test_protected_tablenames(self):
        """
        Check the protected-tablename rules enforced by `create_table_class`.

        A class whose tablename ends up as ``user`` raises ``ValueError``,
        whether the tablename comes from the class name or is passed
        explicitly. A ``User`` class with a different tablename is allowed.
        """
        rejected_cases = [
            {"class_name": "User"},
            {"class_name": "MyUser", "class_kwargs": {"tablename": "user"}},
        ]
        for kwargs in rejected_cases:
            with self.assertRaises(ValueError):
                create_table_class(**kwargs)

        # A protected class name is fine as long as the tablename isn't:
        create_table_class(class_name="User",
                           class_kwargs={"tablename": "my_user"})
Example #9
0
    async def _run_rename_columns(self, backwards=False):
        """
        Apply the column renames in ``self.rename_columns`` table by table, or
        undo them when ``backwards`` is True.

        When undoing, each table's renames are replayed in reverse order. A
        chain such as ``a -> b`` then ``b -> c`` can only be unwound as
        ``c -> b`` then ``b -> a``. Forward order would reference a column
        which doesn't exist yet. Renames on different tables are independent,
        so the table order is left as-is.
        """
        for table_class_name in self.rename_columns.table_class_names:
            columns = self.rename_columns.for_table_class_name(
                table_class_name)

            if not columns:
                continue

            _Table: t.Type[Table] = create_table_class(
                class_name=table_class_name,
                class_kwargs={"tablename": columns[0].tablename},
            )

            if backwards:
                # Undo the renames in the opposite order they were applied.
                columns = list(columns)[::-1]

            for rename_column in columns:
                # Going backwards, the column currently has the *new* name
                # and is renamed back to the old one.
                column = (rename_column.new_db_column_name
                          if backwards else rename_column.old_db_column_name)
                new_name = (rename_column.old_db_column_name
                            if backwards else rename_column.new_db_column_name)

                await _Table.alter().rename_column(
                    column=column,
                    new_name=new_name,
                ).run()
Example #10
0
 def table(self, column: Column):
     """
     Build a throwaway ``MyTable`` class with ``column`` attached as
     ``my_column``.
     """
     members = {"my_column": column}
     return create_table_class(class_name="MyTable", class_members=members)
Example #11
0
 def tearDown(self):
     """
     Drop the ``MyTable`` and ``Migration`` tables if they exist, so later
     tests start from a clean slate.
     """
     for table_class in (create_table_class("MyTable"), Migration):
         table_class.alter().drop_table(if_exists=True).run_sync()
Example #12
0
    async def _run_alter_columns(self, backwards=False):
        """
        Apply the column alterations queued in ``self.alter_columns``.

        For each altered column, the following steps may run, in order:
        change the column type, set nullability, set the length, set the
        unique constraint, and create / drop / rebuild the index. After that
        come setting or dropping the default, and setting the digits.

        The type change only happens when both column classes are known and
        differ. Each other step only runs when ``params`` has a value for
        it. For ``default`` and ``digits``, an explicit ``None`` also counts.

        When ``backwards`` is True, ``params`` / ``old_params`` and
        ``column_class`` / ``old_column_class`` swap roles, so the alteration
        is reversed.
        """
        for table_class_name in self.alter_columns.table_class_names:
            alter_columns = self.alter_columns.for_table_class_name(
                table_class_name)

            if not alter_columns:
                continue

            # A bare table class (no columns declared) which carries the
            # tablename for the ALTER / index queries below.
            _Table: t.Type[Table] = create_table_class(
                class_name=table_class_name,
                class_kwargs={"tablename": alter_columns[0].tablename},
            )

            for alter_column in alter_columns:

                # `params` describes the target state, `old_params` the
                # current one. They swap roles when the migration is reversed.
                params = (alter_column.old_params
                          if backwards else alter_column.params)

                old_params = (alter_column.params
                              if backwards else alter_column.old_params)

                ###############################################################

                # Change the column type if possible
                column_class = (alter_column.old_column_class
                                if backwards else alter_column.column_class)
                old_column_class = (alter_column.column_class if backwards else
                                    alter_column.old_column_class)

                # A type change is only attempted when both column classes
                # are known and they differ.
                if (old_column_class is not None) and (column_class
                                                       is not None):
                    if old_column_class != column_class:
                        # Standalone column objects describing the current
                        # and target definitions. Both are bound to _Table
                        # under the altered column's name.
                        old_column = old_column_class(**old_params)
                        old_column._meta._table = _Table
                        old_column._meta._name = alter_column.column_name
                        old_column._meta.db_column_name = (
                            alter_column.db_column_name)

                        new_column = column_class(**params)
                        new_column._meta._table = _Table
                        new_column._meta._name = alter_column.column_name
                        new_column._meta.db_column_name = (
                            alter_column.db_column_name)

                        using_expression: t.Optional[str] = None

                        # Postgres won't automatically cast some types to
                        # others. We may as well try, as it will definitely
                        # fail otherwise.
                        if new_column.value_type != old_column.value_type:
                            # A missing "default" key (`...`) is treated like
                            # a non-None default, so it is dropped as well.
                            if old_params.get("default", ...) is not None:
                                # Unless the column's default value is also
                                # something which can be cast to the new type,
                                # it will also fail. Drop the default value for
                                # now - the proper default is set later on.
                                await _Table.alter().drop_default(old_column
                                                                  ).run()

                            # Builds `<db_column_name>::<column_type>`.
                            # NOTE(review): the column name is interpolated
                            # unquoted - confirm this is safe for names that
                            # need quoting (reserved words, mixed case).
                            using_expression = "{}::{}".format(
                                alter_column.db_column_name,
                                new_column.column_type,
                            )

                        # We can't migrate a SERIAL to a BIGSERIAL or vice
                        # versa, as SERIAL isn't a true type, just an alias to
                        # other commands.
                        if issubclass(column_class, Serial) and issubclass(
                                old_column_class, Serial):
                            colored_warning(
                                "Unable to migrate Serial to BigSerial and "
                                "vice versa. This must be done manually.")
                        else:
                            await _Table.alter().set_column_type(
                                old_column=old_column,
                                new_column=new_column,
                                using_expression=using_expression,
                            ).run()

                ###############################################################

                # The remaining steps run independently of any type change.
                # Each is skipped when `params` carries no value for it.

                null = params.get("null")
                if null is not None:
                    await _Table.alter().set_null(
                        column=alter_column.db_column_name,
                        boolean=null).run()

                length = params.get("length")
                if length is not None:
                    await _Table.alter().set_length(
                        column=alter_column.db_column_name,
                        length=length).run()

                unique = params.get("unique")
                if unique is not None:
                    # When modifying unique contraints, we need to pass in
                    # a column type, and not just the column name.
                    column = Column()
                    column._meta._table = _Table
                    column._meta._name = alter_column.column_name
                    column._meta.db_column_name = alter_column.db_column_name
                    await _Table.alter().set_unique(column=column,
                                                    boolean=unique).run()

                index = params.get("index")
                index_method = params.get("index_method")
                if index is None:
                    if index_method is not None:
                        # If the index value hasn't changed, but the
                        # index_method value has, this indicates we need
                        # to change the index type.
                        column = Column()
                        column._meta._table = _Table
                        column._meta._name = alter_column.column_name
                        column._meta.db_column_name = (
                            alter_column.db_column_name)
                        await _Table.drop_index([column]).run()
                        await _Table.create_index([column],
                                                  method=index_method,
                                                  if_not_exists=True).run()
                else:
                    # If the index value has changed, then we are either
                    # dropping, or creating an index.
                    column = Column()
                    column._meta._table = _Table
                    column._meta._name = alter_column.column_name
                    column._meta.db_column_name = alter_column.db_column_name

                    if index is True:
                        # Only pass `method` through when one was given.
                        kwargs = ({
                            "method": index_method
                        } if index_method else {})
                        await _Table.create_index([column],
                                                  if_not_exists=True,
                                                  **kwargs).run()
                    else:
                        await _Table.drop_index([column]).run()

                # None is a valid value, so retrieve ellipsis if not found.
                default = params.get("default", ...)
                if default is not ...:
                    column = Column()
                    column._meta._table = _Table
                    column._meta._name = alter_column.column_name
                    column._meta.db_column_name = alter_column.db_column_name

                    if default is None:
                        await _Table.alter().drop_default(column=column).run()
                    else:
                        # The value sent to the database is whatever the
                        # column's get_default_value() produces for it.
                        column.default = default
                        await _Table.alter().set_default(
                            column=column,
                            value=column.get_default_value()).run()

                # None is a valid value, so retrieve ellipsis if not found.
                digits = params.get("digits", ...)
                if digits is not ...:
                    await _Table.alter().set_digits(
                        column=alter_column.db_column_name,
                        digits=digits,
                    ).run()
Example #13
0
async def create_table_class_from_db(table_class: t.Type[Table],
                                     tablename: str,
                                     schema_name: str) -> OutputSchema:
    """
    Introspect a database table and build an equivalent Piccolo ``Table``.

    The following are read from the database:

    - column types (unmapped types fall back to ``Column`` with a warning)
    - nullability, uniqueness and primary keys
    - indexes
    - foreign keys, with their ON UPDATE / ON DELETE actions
    - ``Varchar`` lengths and ``Numeric`` digits
    - defaults

    Tables referenced by foreign keys are introspected recursively. A
    self-reference uses ``"self"``.

    :param table_class:
        Passed through to the introspection helpers (presumably to select the
        database engine to query).
    :param tablename:
        The table to introspect.
    :param schema_name:
        The schema the table lives in.
    :returns:
        An ``OutputSchema`` holding the generated table, the imports it needs
        and any warnings. It is merged with the output schemas of any
        referenced tables.
    """
    indexes = await get_indexes(table_class=table_class,
                                tablename=tablename,
                                schema_name=schema_name)
    constraints = await get_constraints(table_class=table_class,
                                        tablename=tablename,
                                        schema_name=schema_name)
    triggers = await get_fk_triggers(table_class=table_class,
                                     tablename=tablename,
                                     schema_name=schema_name)
    table_schema = await get_table_schema(table_class=table_class,
                                          tablename=tablename,
                                          schema_name=schema_name)
    output_schema = OutputSchema()
    columns: t.Dict[str, Column] = {}

    for pg_row_meta in table_schema:
        data_type = pg_row_meta.data_type
        column_type = COLUMN_TYPE_MAP.get(data_type)
        column_name = pg_row_meta.column_name
        column_default = pg_row_meta.column_default
        if not column_type:
            # Unknown database type - fall back to a generic Column and
            # record it so the user can fix it up by hand.
            output_schema.warnings.append(
                f"{tablename}.{column_name} ['{data_type}']")
            column_type = Column

        kwargs: t.Dict[str, t.Any] = {
            "null": pg_row_meta.is_nullable == "YES",
            "unique": constraints.is_unique(column_name=column_name),
        }

        index = indexes.get_column_index(column_name=column_name)
        if index is not None:
            kwargs["index"] = True
            kwargs["index_method"] = index.method

        if constraints.is_primary_key(column_name=column_name):
            kwargs["primary_key"] = True
            # An integer primary key is generated as Serial.
            if column_type == Integer:
                column_type = Serial

        if constraints.is_foreign_key(column_name=column_name):
            fk_constraint_table = constraints.get_foreign_key_constraint_name(
                column_name=column_name)
            column_type = ForeignKey
            constraint_table = await get_foreign_key_reference(
                table_class=table_class,
                constraint_name=fk_constraint_table.name,
                constraint_schema=fk_constraint_table.schema,
            )
            if constraint_table.name:
                referenced_table: t.Union[str, t.Optional[t.Type[Table]]]

                if constraint_table.name == tablename:
                    # Self-referencing foreign key.
                    referenced_output_schema = output_schema
                    referenced_table = "self"
                else:
                    referenced_output_schema = (
                        await create_table_class_from_db(
                            table_class=table_class,
                            tablename=constraint_table.name,
                            schema_name=constraint_table.schema,
                        ))
                    referenced_table = (
                        referenced_output_schema.get_table_with_name(
                            tablename=constraint_table.name))
                kwargs["references"] = (referenced_table if referenced_table
                                        is not None else ForeignKeyPlaceholder)

                trigger = triggers.get_column_ref_trigger(
                    column_name, constraint_table.name)
                if trigger:
                    kwargs["on_update"] = ONUPDATE_MAP[trigger.on_update]
                    kwargs["on_delete"] = ONDELETE_MAP[trigger.on_delete]

                output_schema = sum(  # type: ignore
                    [output_schema, referenced_output_schema]  # type: ignore
                )  # type: ignore
            else:
                kwargs["references"] = ForeignKeyPlaceholder

        output_schema.imports.append(
            "from piccolo.columns.column_types import " + column_type.__name__)

        if column_type is Varchar:
            kwargs["length"] = pg_row_meta.character_maximum_length
        elif issubclass(column_type, Numeric):
            # `column_type` is a class, so this has to be `issubclass`. An
            # `isinstance` check is always False, and digits were never set.
            precision = pg_row_meta.numeric_precision
            scale = pg_row_meta.numeric_scale
            # An unconstrained NUMERIC column reports NULL precision / scale.
            # In that case the column's default digits are kept.
            if precision is not None and scale is not None:
                # information_schema reports these as plain digit counts; the
                # radix only says which base the digits are counted in (10
                # for NUMERIC). They must not be re-parsed in that base.
                kwargs["digits"] = (int(precision), int(scale))

        if column_default:
            default_value = get_column_default(column_type, column_default)
            if default_value:
                kwargs["default"] = default_value

        column = column_type(**kwargs)

        serialised_params = serialise_params(column._meta.params)
        for extra_import in serialised_params.extra_imports:
            output_schema.imports.append(repr(extra_import))

        columns[column_name] = column

    table = create_table_class(
        class_name=_snake_to_camel(tablename),
        class_kwargs={"tablename": get_table_name(tablename, schema_name)},
        class_members=columns,
    )
    output_schema.tables.append(table)
    return output_schema
Example #14
0
 def table_exists(self, tablename: str) -> bool:
     """
     Return whether a table called ``tablename`` exists.

     A throwaway ``Table`` class is generated for ``tablename`` (its class
     name is the upper-cased tablename), so that its ``table_exists`` query
     can be run synchronously.
     """
     probe_table = create_table_class(
         class_name=tablename.upper(),
         class_kwargs={"tablename": tablename},
     )
     return probe_table.table_exists().run_sync()
Example #15
0
 def run_sync(self, query):
     """
     Run the raw SQL ``query`` synchronously and return its result.

     ``raw`` is called on a table class, so a placeholder ``_Table`` class is
     generated purely to have something to call it on.
     """
     placeholder = create_table_class(class_name="_Table")
     raw_query = placeholder.raw(query)
     return raw_query.run_sync()