Esempio n. 1
0
 def _get_config_from_config_file(cls, config_file: str) -> dict:
     _, extension = os.path.splitext(config_file)
     if extension in ('.yml', '.yaml'):
         import yaml
         with open(config_file, 'r') as f:
             config = yaml.safe_load(f)
     elif extension == '.json':
         with open(config_file, 'r') as f:
             config = json.load(f)
     else:
         raise ConfigurationError(
             'Unknown config extension {}, only .yml and .json are supported'.format(
                 extension
             )
         )
     return config
Esempio n. 2
0
 def _discover_models(cls, models_path, app_label) -> List[Type[Model]]:
     """
     Import ``models_path`` and collect the concrete ``Model`` subclasses in it.

     Abstract models are skipped. So are models whose ``_meta.app`` is already
     set to a different app. Every collected model gets ``_meta.app`` set to
     ``app_label``.

     :param models_path: Dotted path of the module to import.
     :param app_label: Name of the app the discovered models belong to.
     :return: The discovered model classes, in ``dir(module)`` order.
     :raises ConfigurationError: If the module cannot be imported. The original
         ``ImportError`` is chained, so a failing import *inside* an existing
         module stays visible instead of hiding behind "not found".
     """
     try:
         module = importlib.import_module(models_path)
     except ImportError as exc:
         raise ConfigurationError(
             'Module "{}" not found'.format(models_path)) from exc
     discovered_models = []
     for attr_name in dir(module):
         attr = getattr(module, attr_name)
         if isclass(attr) and issubclass(attr,
                                         Model) and not attr._meta.abstract:
             # Keep models explicitly bound to another app out of this one.
             if attr._meta.app and attr._meta.app != app_label:
                 continue
             attr._meta.app = app_label
             discovered_models.append(attr)
     return discovered_models
Esempio n. 3
0
    def _get_config_from_config_file(self, config_file: str) -> dict:
        _, extension = os.path.splitext(config_file)
        if extension in (".yml", ".yaml"):
            import yaml  # pylint: disable=C0415

            with open(config_file, "r") as f:
                config = yaml.safe_load(f)

        elif extension == ".json":
            with open(config_file, "r") as f:
                config = json.load(f)
        else:
            raise ConfigurationError(
                f"Unknown config extension {extension}, only .yml and .json are supported"
            )
        return config
Esempio n. 4
0
    async def generate_schemas(cls, safe: bool = True) -> None:
        """
        Generate schemas according to models provided to ``.init()`` method.

        Schema generation runs once for every configured connection. It is
        meant for initial setup and testing, not as part of the regular
        application workflow.

        Parameters
        ----------
        safe:
            When ``True`` (the default), tables are only created if they do not
            already exist. When ``False``, table creation is not guarded, so it
            is expected to fail if the schemas already exist.

        Raises
        ------
        ConfigurationError:
            If ``.init()`` has not been called yet.
        """
        if not cls._inited:
            raise ConfigurationError(
                "You have to call .init() first before generating schemas")
        for connection in cls._connections.values():
            await generate_schema_for_client(connection, safe)
Esempio n. 5
0
        def split_reference(reference: str) -> Tuple[str, str]:
            """
            Validate a model reference and split it into app and model name.

            A valid reference follows the ``<appname>.<modelname>`` naming
            convention, i.e. splitting on ``"."`` yields exactly two parts.

            :param reference: Model reference such as ``"models.Tournament"``.
            :return: An ``(app_name, model_name)`` tuple.
            :raises ConfigurationError: If the model reference is invalid.
            """
            items = reference.split(".")
            if len(items) != 2:  # pragma: nocoverage
                raise ConfigurationError(
                    f"'{reference}' is not a valid model reference Bad Reference."
                    " Should be something like <appname>.<modelname>."
                )

            app_name, model_name = items
            return (app_name, model_name)
Esempio n. 6
0
 def __init__(self,
              connection_name: str,
              pool=None,
              connection=None) -> None:
     """
     Initialise the client registered under ``connection_name``.

     At most one of ``pool`` and ``connection`` may be supplied. The
     single-connection wrapper class is created dynamically by mixing
     ``SingleConnectionWrapper`` into this instance's own class.

     :raises ConfigurationError: If both ``pool`` and ``connection`` are given.
     """
     if pool and connection:
         raise ConfigurationError('You must pass either connection or pool')
     own_class = self.__class__
     self._connection = connection
     self.log = logging.getLogger('db_client')
     self._pool = pool
     self.single_connection = True
     wrapper_bases = (SingleConnectionWrapper, own_class)
     self._single_connection_class = type('SingleConnectionWrapper', wrapper_bases, {})
     self._transaction_class = own_class
     self._old_context_value = None
     self.connection_name = connection_name
     self.transaction = None
Esempio n. 7
0
    def _resolve_field_for_model(self, model: "Type[Model]", table: Table, field: str) -> dict:
        """
        Resolve a ``__``-separated field path starting at ``model``.

        Every segment except the last must name a fetch (relational) field.
        Each such hop records a join and moves ``model``/``table`` on to the
        related model and its table. FK hops get the alias
        ``<table>__<field>``. The last segment resolves in one of two ways:

        - a fetch field resolves to the related table's pk column;
        - any other field resolves to its column on the current table. That
          column is optionally wrapped by the dialect's ``function_cast``
          when ``self.populate_field_object`` is set.

        :param model: Model the path starts from.
        :param table: Table corresponding to ``model``.
        :param field: Field path such as ``"event__tournament__name"``.
        :return: ``{"joins": [(table, field_name, related_field), ...], "field": term}``
            with joins listed in traversal order.
        :raises ConfigurationError: If an intermediate segment is not a fetch field.
        """
        joins = []
        fields = field.split("__")

        # Follow every relation segment except the final one.
        for iter_field in fields[:-1]:
            if iter_field not in model._meta.fetch_fields:
                raise ConfigurationError(f"{field} not resolvable")

            related_field = cast(RelationalField, model._meta.fields_map[iter_field])
            joins.append((table, iter_field, related_field))

            # Advance to the related model/table for the next segment.
            model = related_field.related_model
            related_table: Table = related_field.related_model._meta.basetable
            if isinstance(related_field, ForeignKeyFieldInstance):
                # Only FK's can be to same table, so we only auto-alias FK join tables
                related_table = related_table.as_(f"{table.get_table_name()}__{iter_field}")
            table = related_table

        last_field = fields[-1]
        if last_field in model._meta.fetch_fields:
            # Final segment is itself a relation: join it and point at its pk column.
            related_field = cast(RelationalField, model._meta.fields_map[last_field])
            related_field_meta = related_field.related_model._meta
            joins.append((table, last_field, related_field))
            related_table = related_field_meta.basetable

            if isinstance(related_field, BackwardFKRelation):
                # Self-referencing backward FK: alias so the table is not joined onto itself.
                if table == related_table:
                    related_table = related_table.as_(f"{table.get_table_name()}__{last_field}")

            field = related_table[related_field_meta.db_pk_column]
        else:
            # NOTE(review): an unknown last segment raises KeyError here, not
            # ConfigurationError like the intermediate segments above.
            field_object = model._meta.fields_map[last_field]
            # Prefer the explicit DB column name when the field defines one.
            if field_object.source_field:
                field = table[field_object.source_field]
            else:
                field = table[last_field]
            if self.populate_field_object:
                # Same object as ``field_object`` above; also kept on ``self``.
                self.field_object = model._meta.fields_map.get(last_field, None)
                if self.field_object:  # pragma: nobranch
                    func = self.field_object.get_for_dialect(
                        model._meta.db.capabilities.dialect, "function_cast"
                    )
                    if func:
                        field = func(self.field_object, field)

        return {"joins": joins, "field": field}
Esempio n. 8
0
    def _resolve_field_for_model(self, model: "Type[Model]", table: Table,
                                 field: str, *default_values: Any) -> dict:
        """
        Resolve a ``__``-separated field path and wrap it in ``self.database_func``.

        A single-segment path is resolved against ``table`` directly. A fetch
        field resolves to the related table's pk and produces one join.
        Multi-segment paths recurse through the first relation, aliasing FK
        tables as ``<table>__<field>``.

        :param model: Model the path starts from.
        :param table: Table corresponding to ``model``.
        :param field: Field path such as ``"tournament__name"``.
        :param default_values: Extra positional arguments forwarded to
            ``self.database_func`` after the resolved term.
        :return: ``{"joins": [...], "field": function_term}``. Each recursion
            level appends its own join after the inner call returns, so the
            joins run from the innermost hop back to the outermost one.
        :raises ConfigurationError: If a non-final segment is not a fetch field.
        """
        field_split = field.split("__")
        # Base case: a single segment is resolved on the current table.
        if not field_split[1:]:
            function_joins = []
            if field_split[0] in model._meta.fetch_fields:
                related_field = cast(RelationalField,
                                     model._meta.fields_map[field_split[0]])
                related_field_meta = related_field.model_class._meta
                join = (table, field_split[0], related_field)
                function_joins.append(join)
                # NOTE(review): unaliased basetable here, unlike the FK hop
                # below; may clash for self-referencing relations -- confirm.
                field = related_field_meta.basetable[
                    related_field_meta.db_pk_field]
            else:
                # NOTE(review): uses the attribute name as column, ignoring any
                # source_field mapping -- confirm this is intended.
                field = table[field_split[0]]

                if self.populate_field_object:
                    self.field_object = model._meta.fields_map.get(
                        field_split[0], None)
                    if self.field_object:  # pragma: nobranch
                        func = self.field_object.get_for_dialect(
                            model._meta.db.capabilities.dialect,
                            "function_cast")
                        if func:
                            field = func(self.field_object, field)

            function_field = self.database_func(field, *default_values)
            return {"joins": function_joins, "field": function_field}

        # Recursive case: the first segment must be a relation to follow.
        if field_split[0] not in model._meta.fetch_fields:
            raise ConfigurationError(f"{field} not resolvable")
        related_field = cast(RelationalField,
                             model._meta.fields_map[field_split[0]])
        join = (table, field_split[0], related_field)
        related_table = related_field.model_class._meta.basetable
        if isinstance(related_field, ForeignKeyFieldInstance):
            # Only FK's can be to same table, so we only auto-alias FK join tables
            related_table = related_table.as_(
                f"{table.get_table_name()}__{field_split[0]}")
        function = self._resolve_field_for_model(related_field.model_class,
                                                 related_table,
                                                 "__".join(field_split[1:]),
                                                 *default_values)
        # Appended after recursion, so outer joins come after inner ones.
        function["joins"].append(join)
        return function
Esempio n. 9
0
    def _init_apps(cls, apps_config: dict) -> None:
        """
        Register every configured app's models, then wire up relations.

        For each app, the configured ``default_connection`` (``"default"`` when
        omitted) must be known to ``connections``. The app's models are
        initialised without relations and tagged with that connection name.
        Relations and initial querysets are built once all apps are registered.

        :param apps_config: Mapping of app name to its config (a ``models``
            list and an optional ``default_connection``).
        :raises ConfigurationError: If an app references an unknown connection.
        """
        for name, info in apps_config.items():
            connection_name = info.get("default_connection", "default")
            try:
                connections.get(connection_name)
            except KeyError as exc:
                raise ConfigurationError(
                    'Unknown connection "{}" for app "{}"'.format(
                        connection_name, name)) from exc

            cls.init_models(info["models"], name, _init_relations=False)

            for model in cls.apps[name].values():
                model._meta.default_connection = connection_name

        # Relations can only be resolved once every app's models are registered.
        cls._init_relations()

        cls._build_initial_querysets()
Esempio n. 10
0
    async def _drop_databases(cls) -> None:
        """
        Drop every database described by the config given to ``.init()``.

        Intended for test suites only. All connections are closed first. Each
        registered client then deletes its database and is discarded from the
        ``connections`` registry. Finally the registered apps are reset.

        :raises ConfigurationError: If ``.init()`` has not been called yet.
        """
        if not cls._inited:
            raise ConfigurationError(
                "You have to call .init() first before deleting schemas")
        # ``discard=False`` keeps the clients registered; each one is discarded
        # individually below, right after its database has been deleted.
        await connections.close_all(discard=False)
        for client in connections.all():
            await client.db_delete()
            connections.discard(client.connection_name)

        await cls._reset_apps()
Esempio n. 11
0
    def db_config(self) -> "DBConfigType":
        """
        The database configuration given to
        :meth:`Tortoise.init<tortoise.Tortoise.init>` during initialization.

        :raises ConfigurationError:
            When read before :meth:`Tortoise.init<tortoise.Tortoise.init>`
            has stored a configuration.
        """
        config = self._db_config
        if config is not None:
            return config
        raise ConfigurationError(
            "DB configuration not initialised. Make sure to call "
            "Tortoise.init with a valid configuration before attempting "
            "to create connections."
        )
Esempio n. 12
0
 def __init__(self,
              model_name: str,
              through: Optional[str] = None,
              forward_key: Optional[str] = None,
              backward_key: str = "",
              related_name: str = "",
              **kwargs) -> None:
     """
     Store the options of a relation to the model referenced by ``model_name``.

     :param model_name: Target model reference in ``"app.Model"`` form.
     :param through: Stored as-is (presumably the intermediate table name).
     :param forward_key: Key name for this side. Defaults to ``"<model>_id"``,
         built from the lower-cased model part of ``model_name``.
     :param backward_key: Stored as-is.
     :param related_name: Stored as-is.
     :raises ConfigurationError: If ``model_name`` does not split into exactly
         two dot-separated parts.
     """
     super().__init__(**kwargs)
     app_and_model = model_name.split(".")
     if len(app_and_model) != 2:
         raise ConfigurationError(
             'Foreign key accepts model name in format "app.Model"')
     self.model_name = model_name
     self.related_name = related_name
     default_forward_key = "{}_id".format(app_and_model[1].lower())
     self.forward_key = forward_key or default_forward_key
     self.backward_key = backward_key
     self.through = through
     self._generated = False
Esempio n. 13
0
    def add_field(self, name: str, value: Field):
        """
        Register ``value`` on this meta under ``name``.

        The field is attached to the model and stored in ``fields_map``. It is
        added to the DB projection when it has a DB column, and tracked as an
        M2M or backward-FK field where applicable. Its filters are merged in
        before the fields are finalised.

        :raises ConfigurationError: If a field called ``name`` already exists.
        """
        if name in self.fields_map:
            raise ConfigurationError(f"Field {name} already present in meta")
        value.model = self._model
        self.fields_map[name] = value

        # Column name used for both the projection and the filters.
        column = value.source_field or name
        if value.has_db_field:
            self.fields_db_projection[name] = column

        if isinstance(value, fields.ManyToManyField):
            self.m2m_fields.add(name)
        elif isinstance(value, fields.BackwardFKRelation):
            self.backward_fk_fields.add(name)

        self._filters.update(
            get_filters_for_field(field_name=name, field=value, source_field=column)
        )
        self.finalise_fields()
Esempio n. 14
0
 def __init__(
     self,
     model_name: str,
     through: Optional[str] = None,
     forward_key: Optional[str] = None,
     backward_key: str = "",
     related_name: str = "",
     field_type: "Type[Model]" = None,  # type: ignore
     **kwargs: Any,
 ) -> None:
     """
     Store the options of a relation to the model referenced by ``model_name``.

     :param model_name: Target model reference in ``"app.Model"`` form.
     :param forward_key: Defaults to ``"<model>_id"``, built from the
         lower-cased model part of ``model_name``.
     :param field_type: Stored as ``self.model_class``.
     :raises ConfigurationError: If ``model_name`` does not split into exactly
         two dot-separated parts.
     """
     super().__init__(**kwargs)
     self.model_class: "Type[Model]" = field_type
     app_and_model = model_name.split(".")
     if len(app_and_model) != 2:
         raise ConfigurationError('Foreign key accepts model name in format "app.Model"')
     _, target_model = app_and_model
     self.model_name: str = model_name
     self.related_name: str = related_name
     self.forward_key: str = forward_key or f"{target_model.lower()}_id"
     self.backward_key: str = backward_key
     self.through: str = through  # type: ignore
     self._generated: bool = False
Esempio n. 15
0
    def _init_apps(self, apps_config: dict) -> None:
        for app_name, app_config in apps_config.items():
            connection_name = app_config.get("default_connection", "default")
            try:
                self.get_db_client(connection_name)
            except KeyError:
                raise ConfigurationError(
                    'Unknown connection "{}" for app "{}"'.format(
                        connection_name, app_name))

            app_models: List[Type[Model]] = []
            for module in app_config["models"]:
                app_models += self._discover_models(module, app_name)

            for model in app_models:
                model._meta.connection_name = connection_name

            self._app_models_map[app_name] = {
                model.__name__: model
                for model in app_models
            }
Esempio n. 16
0
    def _init_apps(cls, apps_config):
        for name, info in apps_config.items():
            try:
                connection = cls.get_connection(
                    info.get('default_connection', 'default'))
            except KeyError:
                raise ConfigurationError(
                    'Unknown connection "{}" for app "{}"'.format(
                        apps_config.get('default_connection', 'default'),
                        name,
                    ))
            app_models = []  # type: List[Type[Model]]
            for module in info['models']:
                app_models += cls._discover_models(module, name)

            models_map = {}
            for model in app_models:
                model._meta.default_db = connection
                models_map[model.__name__] = model

            cls.apps[name] = models_map
Esempio n. 17
0
    def get_create_schema_sql(self, safe=True) -> str:
        """
        Build the SQL creating every table whose model is bound to ``self.client``.

        Each such model is ``check()``-ed first. Tables are emitted in
        dependency order: a table is placed only once every table in its
        ``references`` set has been placed. M2M table statements come after
        all regular tables.

        :param safe: Forwarded to ``_get_table_sql`` for every model.
        :return: All creation statements joined by single spaces.
        :raises ConfigurationError: If the remaining tables cannot be ordered,
            e.g. because of cyclic FK references.
        """
        from tortoise import Tortoise

        models_to_create = []

        for app in Tortoise.apps.values():
            for model in app.values():
                if model._meta.db == self.client:
                    model.check()
                    models_to_create.append(model)

        tables_to_create = [self._get_table_sql(model, safe) for model in models_to_create]
        tables_to_create_count = len(tables_to_create)

        # Holds the ``"table"`` values of placed tables; they must be hashable
        # since they live in a set (presumably table-name strings).
        created_tables = set()  # type: Set[str]
        ordered_tables_for_create = []
        m2m_tables_to_create = []  # type: List[str]
        while len(created_tables) != tables_to_create_count:
            next_table_for_create = next(
                (t for t in tables_to_create if t["references"].issubset(created_tables)),
                None,
            )
            if next_table_for_create is None:
                raise ConfigurationError(
                    "Can't create schema due to cyclic fk references")
            tables_to_create.remove(next_table_for_create)
            created_tables.add(next_table_for_create["table"])
            ordered_tables_for_create.append(
                next_table_for_create["table_creation_string"])
            m2m_tables_to_create += next_table_for_create["m2m_tables"]

        schema_creation_string = " ".join(ordered_tables_for_create +
                                          m2m_tables_to_create)
        return schema_creation_string
Esempio n. 18
0
 def _discover_models(cls, models_path: str, app_label: str) -> List[Type[Model]]:
     """
     Import ``models_path`` and collect the concrete ``Model`` subclasses in it.

     Candidates come from the module's ``__models__`` iterable when present and
     non-empty. Otherwise every attribute of the module is a candidate.
     Abstract models are skipped. So are models whose ``_meta.app`` is already
     set to a different app. Every collected model gets ``_meta.app`` set to
     ``app_label``. A ``RuntimeWarning`` is emitted if nothing is found.

     :param models_path: Dotted path of the module to import.
     :param app_label: Name of the app the discovered models belong to.
     :return: The discovered model classes.
     :raises ConfigurationError: If the module cannot be imported. The original
         ``ImportError`` is chained, so a failing import *inside* an existing
         module stays visible instead of hiding behind "not found".
     """
     try:
         module = importlib.import_module(models_path)
     except ImportError as exc:
         raise ConfigurationError(f'Module "{models_path}" not found') from exc
     discovered_models = []
     possible_models = getattr(module, "__models__", None)
     try:
         possible_models = [*possible_models]
     except TypeError:
         # ``__models__`` missing (None) or not iterable: fall back to dir().
         possible_models = None
     if not possible_models:
         possible_models = [getattr(module, attr_name) for attr_name in dir(module)]
     for attr in possible_models:
         if isclass(attr) and issubclass(attr, Model) and not attr._meta.abstract:
             # Keep models explicitly bound to another app out of this one.
             if attr._meta.app and attr._meta.app != app_label:
                 continue
             attr._meta.app = app_label
             discovered_models.append(attr)
     if not discovered_models:
         warnings.warn(f'Module "{models_path}" has no models', RuntimeWarning, stacklevel=4)
     return discovered_models
Esempio n. 19
0
    def create_relation(self, tortoise) -> None:
        """
        Resolve the remote model and materialise this relation on both sides.

        A ``<model_field_name>_id`` key field is added to the owning model. It
        is a deep copy of the remote model's pk, configured with this
        relation's options. Unless ``related_name`` is ``False``, a backward
        relation (default name ``<model>_set``) is also registered on the
        remote model.

        :param tortoise: Object whose ``get_model`` resolves the remote model.
        :raises ConfigurationError: If the backward relation name is already
            used on the remote model.
        """
        remote_model = tortoise.get_model(self.remote_model, self.model)
        key_name = f"{self.model_field_name}_id"
        self.id_field_name = key_name

        # Same type as the remote pk, but with this relation's column options.
        key_field = deepcopy(remote_model._meta.pk)
        key_field.primary_key = self.primary_key
        key_field.unique = self.unique
        key_field.db_index = self.db_index
        key_field.default = self.default
        key_field.null = self.null
        key_field.generated = self.generated
        key_field.auto_created = True
        key_field.reference = self
        key_field.description = self.description
        key_field.db_column = self.db_column or key_name

        self.db_column = key_field.db_column
        self.model._meta.add_field(key_name, key_field)
        self.remote_model = remote_model

        if self.primary_key:
            self.model._meta.pk_attr = key_name

        related_name = self.related_name
        if related_name is False:
            # Explicitly disabled: no backward relation on the remote model.
            return
        related_name = related_name or f"{self.model.__name__.lower()}_set"

        if related_name in remote_model._meta.fields_map:
            raise ConfigurationError(
                f"backward relation '{related_name}' duplicates in"
                f" model {remote_model}"
            )

        backward_field = self.backward_relation_class(
            self.model, key_name, True, self.description)
        remote_model._meta.add_field(related_name, backward_field)
Esempio n. 20
0
    def _init_apps(cls, apps_config: dict) -> None:
        for name, info in apps_config.items():
            try:
                cls.get_connection(info.get("default_connection", "default"))
            except KeyError:
                raise ConfigurationError(
                    'Unknown connection "{}" for app "{}"'.format(
                        info.get("default_connection", "default"), name))
            app_models = []  # type: List[Type[Model]]
            for module in info["models"]:
                app_models += cls._discover_models(module, name)

            models_map = {}
            for model in app_models:
                model._meta.default_connection = info.get(
                    "default_connection", "default")
                models_map[model.__name__] = model

            cls.apps[name] = models_map

        cls._init_relations()

        cls._build_initial_querysets()
Esempio n. 21
0
 def __init__(
     self,
     model_name: str,
     through: Optional[str] = None,
     forward_key: Optional[str] = None,
     backward_key: str = "",
     related_name: str = "",
     on_delete: str = CASCADE,
     field_type: "Type[Model]" = None,  # type: ignore
     **kwargs: Any,
 ) -> None:
     """
     Store the options of a relation to the model referenced by ``model_name``.

     :param model_name: Target model reference in ``"app.Model"`` form.
     :param forward_key: Defaults to ``"<model>_id"``, built from the
         lower-cased model part of ``model_name``.
     :param on_delete: Stored as-is on ``self.on_delete``.
     :param field_type: Passed positionally to the parent initialiser.
     :raises ConfigurationError: If ``model_name`` does not split into exactly
         two dot-separated parts.
     """
     # TODO: rename through to through_table
     # TODO: add through to use a Model
     super().__init__(field_type, **kwargs)
     app_and_model = model_name.split(".")
     if len(app_and_model) != 2:
         raise ConfigurationError('Foreign key accepts model name in format "app.Model"')
     _, target_model = app_and_model
     self.model_name: str = model_name
     self.related_name: str = related_name
     self.forward_key: str = forward_key or f"{target_model.lower()}_id"
     self.backward_key: str = backward_key
     self.through: str = through  # type: ignore
     self._generated: bool = False
     self.on_delete = on_delete
Esempio n. 22
0
    def add_field(self, name: str, value: Field):
        """
        Register ``value`` on this meta and on the model class under ``name``.

        Besides storing the field, this resets the derived attributes
        (``_fields``, ``_fields_db_projection_reverse``, ``_fetch_fields``)
        that the new field affects. It then merges the field's filters and
        regenerates them.

        :raises ConfigurationError: If a field called ``name`` already exists.
        """
        if name in self.fields_map:
            raise ConfigurationError("Field {} already present in meta".format(name))
        model = self._model
        setattr(model, name, value)
        value.model = model
        self.fields_map[name] = value
        self._fields = None  # reset; presumably rebuilt lazily elsewhere

        # Column name used for both the projection and the filters.
        column = value.source_field or name
        if value.has_db_field:
            self.fields_db_projection[name] = column
            self._fields_db_projection_reverse = None

        if isinstance(value, fields.ManyToManyField):
            tracked = self.m2m_fields
        elif isinstance(value, fields.BackwardFKRelation):
            tracked = self.backward_fk_fields
        else:
            tracked = None
        if tracked is not None:
            tracked.add(name)
            self._fetch_fields = None

        self._filters.update(
            get_filters_for_field(field_name=name, field=value, source_field=column)
        )
        self.generate_filters()
Esempio n. 23
0
    def _resolve_field_for_model(self, field: str, model) -> dict:
        """
        Resolve a ``__``-separated field path and wrap it in ``self.aggregation_func``.

        Note the argument order: the field path comes before the model.

        :param field: Field path such as ``"events__id"``.
        :param model: Model the path starts from.
        :return: ``{"joins": [(Table, field_name, related_field), ...], "field": term}``.
            Each recursion level appends its join after the inner call returns,
            so joins run from the innermost hop back to the outermost one.
        :raises ConfigurationError: If a non-final segment is not a fetch field.
        """
        field_split = field.split("__")
        # Base case: single segment, resolved on this model's table.
        if not field_split[1:]:
            aggregation_joins = []  # type: list
            if field_split[0] in model._meta.fetch_fields:
                related_field = model._meta.fields_map[field_split[0]]
                join = (Table(model._meta.table), field_split[0],
                        related_field)
                aggregation_joins.append(join)
                # NOTE(review): hard-codes an ``id`` column on the related
                # table; models with a differently named pk would break -- confirm.
                aggregation_field = self.aggregation_func(
                    Table(related_field.type._meta.table).id)
            else:
                aggregation_field = self.aggregation_func(
                    getattr(Table(model._meta.table), field_split[0]))
            return {"joins": aggregation_joins, "field": aggregation_field}

        # Recursive case: the first segment must be a relation to follow.
        if field_split[0] not in model._meta.fetch_fields:
            raise ConfigurationError("{} not resolvable".format(field))
        related_field = model._meta.fields_map[field_split[0]]
        join = (Table(model._meta.table), field_split[0], related_field)
        aggregation = self._resolve_field_for_model("__".join(field_split[1:]),
                                                    related_field.type)
        aggregation["joins"].append(join)
        return aggregation
Esempio n. 24
0
    async def init(
        cls,
        config: Optional[dict] = None,
        config_file: Optional[str] = None,
        _create_db: bool = False,
        db_url: Optional[str] = None,
        modules: Optional[Dict[str, Iterable[Union[str, ModuleType]]]] = None,
        use_tz: bool = False,
        timezone: str = "UTC",
        routers: Optional[List[Union[str, Type]]] = None,
    ) -> None:
        """
        Sets up Tortoise-ORM.

        You can configure using only one of ``config``, ``config_file``
        and ``(db_url, modules)``.

        :param config:
            Dict containing config:

            .. admonition:: Example

                .. code-block:: python3

                    {
                        'connections': {
                            # Dict format for connection
                            'default': {
                                'engine': 'tortoise.backends.asyncpg',
                                'credentials': {
                                    'host': 'localhost',
                                    'port': '5432',
                                    'user': '******',
                                    'password': '******',
                                    'database': 'test',
                                }
                            },
                            # ...or, alternatively, a DB_URL string
                            # (use only one entry per connection name)
                            'default': 'postgres://*****:*****@localhost:5432/test'
                        },
                        'apps': {
                            'my_app': {
                                'models': ['__main__'],
                                # If no default_connection specified, defaults to 'default'
                                'default_connection': 'default',
                            }
                        },
                        'routers': ['path.router1', 'path.router2'],
                        'use_tz': False,
                        'timezone': 'UTC'
                    }

        :param config_file:
            Path to .json or .yml (if PyYAML installed) file containing config with
            same format as above.
        :param db_url:
            Use a DB_URL string. See :ref:`db_url`
        :param modules:
            Dictionary of ``key``: [``list_of_modules``] that defined "apps" and modules that
            should be discovered for models.
        :param _create_db:
            If ``True`` tries to create database for specified connections,
            could be used for testing purposes.
        :param use_tz:
            A boolean that specifies if datetime will be timezone-aware by default or not.
            A ``use_tz`` key in the config takes precedence.
        :param timezone:
            Timezone to use, default is UTC. A ``timezone`` key in the config
            takes precedence.
        :param routers:
            A list of db routers str path or module. A ``routers`` key in the
            config takes precedence.

        :raises ConfigurationError: For any configuration error
        """
        if cls._inited:
            await cls.close_connections()
            await cls._reset_apps()
        # Exactly one configuration source must be supplied.
        if sum(map(bool, (config, config_file, db_url))) != 1:
            raise ConfigurationError(
                'You should init either from "config", "config_file" or "db_url"'
            )

        if config_file:
            config = cls._get_config_from_config_file(config_file)

        if db_url:
            if not modules:
                raise ConfigurationError('You must specify "db_url" and "modules" together')
            config = generate_config(db_url, modules)

        # ``config`` may not be a mapping here (e.g. ``None`` from an empty YAML
        # file); report that as a missing section instead of a bare TypeError.
        try:
            connections_config = config["connections"]  # type: ignore
        except (KeyError, TypeError):
            raise ConfigurationError('Config must define "connections" section')

        try:
            apps_config = config["apps"]  # type: ignore
        except KeyError:
            raise ConfigurationError('Config must define "apps" section')

        use_tz = config.get("use_tz", use_tz)  # type: ignore
        timezone = config.get("timezone", timezone)  # type: ignore
        routers = config.get("routers", routers)  # type: ignore

        # Mask passwords in logs output
        passwords = []
        for info in connections_config.values():
            if isinstance(info, str):
                info = expand_db_url(info)
            password = info.get("credentials", {}).get("password")
            if password:
                # Config files can yield non-string passwords (e.g. a numeric
                # YAML scalar); mask their text form as rendered by str().
                passwords.append(str(password))

        str_connection_config = str(connections_config)
        # Longest first: masking a shorter password that is a substring of a
        # longer one first would leave part of the longer one visible.
        for password in sorted(passwords, key=len, reverse=True):
            str_connection_config = str_connection_config.replace(
                password,
                # Show one third of the password at beginning (may be better for debugging purposes)
                f"{password[0:len(password) // 3]}***",
            )

        logger.debug(
            "Tortoise-ORM startup\n    connections: %s\n    apps: %s",
            str_connection_config,
            str(apps_config),
        )

        cls._init_timezone(use_tz, timezone)
        await cls._init_connections(connections_config, _create_db)
        cls._init_apps(apps_config)
        cls._init_routers(routers)

        cls._inited = True
Esempio n. 25
0
    def _init_relations(cls) -> None:
        def get_related_model(related_app_name: str, related_model_name: str) -> Type["Model"]:
            """
            Look up a registered model by app name and model name.

            :param related_app_name: Name of the registered app.
            :param related_model_name: Name of the model within that app.
            :return: The requested model class.
            :raises ConfigurationError: If the app is not registered, or if the
                app has no model with that name.
            """
            try:
                return cls.apps[related_app_name][related_model_name]
            except KeyError:
                # The messages below are self-explanatory; suppress the KeyError
                # context so tracebacks show only the configuration problem.
                if related_app_name not in cls.apps:
                    raise ConfigurationError(
                        f"No app with name '{related_app_name}' registered."
                    ) from None
                raise ConfigurationError(
                    f"No model with name '{related_model_name}' registered in"
                    f" app '{related_app_name}'."
                ) from None

        def split_reference(reference: str) -> Tuple[str, str]:
            """
            Validate a model reference and split it into app and model name.

            A valid reference follows the ``<appname>.<modelname>`` naming
            convention, i.e. splitting on ``"."`` yields exactly two parts.

            :param reference: Model reference such as ``"models.Tournament"``.
            :return: An ``(app_name, model_name)`` tuple.
            :raises ConfigurationError: If the model reference is invalid.
            """
            items = reference.split(".")
            if len(items) != 2:  # pragma: nocoverage
                raise ConfigurationError(
                    f"'{reference}' is not a valid model reference Bad Reference."
                    " Should be something like <appname>.<modelname>."
                )

            app_name, model_name = items
            return (app_name, model_name)

        for app_name, app in cls.apps.items():
            for model_name, model in app.items():
                if model._meta._inited:
                    continue
                model._meta._inited = True
                if not model._meta.db_table:
                    model._meta.db_table = model.__name__.lower()

                # TODO: refactor to share logic between FK & O2O
                for field in model._meta.fk_fields:
                    fk_object = cast(ForeignKeyFieldInstance, model._meta.fields_map[field])
                    reference = fk_object.model_name
                    related_app_name, related_model_name = split_reference(reference)
                    related_model = get_related_model(related_app_name, related_model_name)

                    if fk_object.to_field:
                        related_field = related_model._meta.fields_map.get(fk_object.to_field, None)
                        if related_field:
                            if related_field.unique:
                                key_fk_object = deepcopy(related_field)
                                fk_object.to_field_instance = related_field
                            else:
                                raise ConfigurationError(
                                    f'field "{fk_object.to_field}" in model'
                                    f' "{related_model_name}" is not unique'
                                )
                        else:
                            raise ConfigurationError(
                                f'there is no field named "{fk_object.to_field}"'
                                f' in model "{related_model_name}"'
                            )
                    else:
                        key_fk_object = deepcopy(related_model._meta.pk)
                        fk_object.to_field_instance = related_model._meta.pk
                        fk_object.to_field = related_model._meta.pk_attr

                    key_field = f"{field}_id"
                    key_fk_object.pk = False
                    key_fk_object.unique = False
                    key_fk_object.index = fk_object.index
                    key_fk_object.default = fk_object.default
                    key_fk_object.null = fk_object.null
                    key_fk_object.generated = fk_object.generated
                    key_fk_object.reference = fk_object
                    key_fk_object.description = fk_object.description
                    if fk_object.source_field:
                        key_fk_object.source_field = fk_object.source_field
                    else:
                        key_fk_object.source_field = key_field
                    model._meta.add_field(key_field, key_fk_object)

                    fk_object.related_model = related_model
                    fk_object.source_field = key_field
                    backward_relation_name = fk_object.related_name
                    if backward_relation_name is not False:
                        if not backward_relation_name:
                            backward_relation_name = f"{model._meta.db_table}s"
                        if backward_relation_name in related_model._meta.fields:
                            raise ConfigurationError(
                                f'backward relation "{backward_relation_name}" duplicates in'
                                f" model {related_model_name}"
                            )
                        fk_relation = BackwardFKRelation(
                            model,
                            f"{field}_id",
                            key_fk_object.source_field,
                            fk_object.null,
                            fk_object.description,
                        )
                        fk_relation.to_field_instance = fk_object.to_field_instance
                        related_model._meta.add_field(backward_relation_name, fk_relation)

                for field in model._meta.o2o_fields:
                    o2o_object = cast(OneToOneFieldInstance, model._meta.fields_map[field])
                    reference = o2o_object.model_name
                    related_app_name, related_model_name = split_reference(reference)
                    related_model = get_related_model(related_app_name, related_model_name)

                    if o2o_object.to_field:
                        related_field = related_model._meta.fields_map.get(
                            o2o_object.to_field, None
                        )
                        if related_field:
                            if related_field.unique:
                                key_o2o_object = deepcopy(related_field)
                                o2o_object.to_field_instance = related_field
                            else:
                                raise ConfigurationError(
                                    f'field "{o2o_object.to_field}" in model'
                                    f' "{related_model_name}" is not unique'
                                )
                        else:
                            raise ConfigurationError(
                                f'there is no field named "{o2o_object.to_field}"'
                                f' in model "{related_model_name}"'
                            )
                    else:
                        key_o2o_object = deepcopy(related_model._meta.pk)
                        o2o_object.to_field_instance = related_model._meta.pk
                        o2o_object.to_field = related_model._meta.pk_attr

                    key_field = f"{field}_id"
                    key_o2o_object.pk = o2o_object.pk
                    key_o2o_object.index = o2o_object.index
                    key_o2o_object.default = o2o_object.default
                    key_o2o_object.null = o2o_object.null
                    key_o2o_object.unique = o2o_object.unique
                    key_o2o_object.generated = o2o_object.generated
                    key_o2o_object.reference = o2o_object
                    key_o2o_object.description = o2o_object.description
                    if o2o_object.source_field:
                        key_o2o_object.source_field = o2o_object.source_field
                    else:
                        key_o2o_object.source_field = key_field
                    model._meta.add_field(key_field, key_o2o_object)

                    o2o_object.related_model = related_model
                    o2o_object.source_field = key_field
                    backward_relation_name = o2o_object.related_name
                    if backward_relation_name is not False:
                        if not backward_relation_name:
                            backward_relation_name = f"{model._meta.db_table}"
                        if backward_relation_name in related_model._meta.fields:
                            raise ConfigurationError(
                                f'backward relation "{backward_relation_name}" duplicates in'
                                f" model {related_model_name}"
                            )
                        o2o_relation = BackwardOneToOneRelation(
                            model,
                            f"{field}_id",
                            key_o2o_object.source_field,
                            null=True,
                            description=o2o_object.description,
                        )
                        o2o_relation.to_field_instance = o2o_object.to_field_instance
                        related_model._meta.add_field(backward_relation_name, o2o_relation)

                    if o2o_object.pk:
                        model._meta.pk_attr = key_field

                for field in list(model._meta.m2m_fields):
                    m2m_object = cast(ManyToManyFieldInstance, model._meta.fields_map[field])
                    if m2m_object._generated:
                        continue

                    backward_key = m2m_object.backward_key
                    if not backward_key:
                        backward_key = f"{model._meta.db_table}_id"
                        if backward_key == m2m_object.forward_key:
                            backward_key = f"{model._meta.db_table}_rel_id"
                        m2m_object.backward_key = backward_key

                    reference = m2m_object.model_name
                    related_app_name, related_model_name = split_reference(reference)
                    related_model = get_related_model(related_app_name, related_model_name)

                    m2m_object.related_model = related_model

                    backward_relation_name = m2m_object.related_name
                    if not backward_relation_name:
                        backward_relation_name = (
                            m2m_object.related_name
                        ) = f"{model._meta.db_table}s"
                    if backward_relation_name in related_model._meta.fields:
                        raise ConfigurationError(
                            f'backward relation "{backward_relation_name}" duplicates in'
                            f" model {related_model_name}"
                        )

                    if not m2m_object.through:
                        related_model_table_name = (
                            related_model._meta.db_table
                            if related_model._meta.db_table
                            else related_model.__name__.lower()
                        )

                        m2m_object.through = f"{model._meta.db_table}_{related_model_table_name}"

                    m2m_relation = ManyToManyFieldInstance(
                        f"{app_name}.{model_name}",
                        m2m_object.through,
                        forward_key=m2m_object.backward_key,
                        backward_key=m2m_object.forward_key,
                        related_name=field,
                        field_type=model,
                        description=m2m_object.description,
                    )
                    m2m_relation._generated = True
                    model._meta.filters.update(get_m2m_filters(field, m2m_object))
                    related_model._meta.add_field(backward_relation_name, m2m_relation)
Esempio n. 26
0
    def __new__(mcs, name: str, bases, attrs: dict, *args, **kwargs):
        """
        Create a model class and attach a ``MetaInfo`` instance as ``_meta``.

        Steps performed:

        1. Collect ``fields.Field`` attributes inherited from ``bases`` and
           place them in front of the attributes defined in the class body.
        2. For every class not named ``"Model"``: validate the primary key
           declaration (adding ``id = fields.IntField(pk=True)`` when a
           non-abstract class declares no pk), then sort the fields into
           FK / O2O / M2M / other fields, building the DB-column projection
           and lookup filters for the "other" fields (plus ``pk`` filters for
           the primary key).
        3. Remove the field attributes from the class namespace, store the
           collected data on ``_meta``, create the class via
           ``super().__new__``, bind every field to the new class and call
           ``meta.finalise_fields()``.

        A class that ends up with no fields at all (e.g. ``Model`` itself,
        which skips step 2) is marked abstract.

        ``*args`` and ``**kwargs`` are accepted but not used here.

        :param name: Name of the class being created.
        :param bases: Base classes of the class being created.
        :param attrs: Namespace produced by the class body.
        :raises ConfigurationError: If more than one pk field is declared, if
            a generated pk is not a SmallIntField/IntField/BigIntField, or if a
            non-abstract class without an explicit pk already has an ``id``
            attribute that is not a pk field.
        :return: The newly created class.
        """
        fields_db_projection: Dict[str, str] = {}
        fields_map: Dict[str, fields.Field] = {}
        filters: Dict[str, Dict[str, dict]] = {}
        fk_fields: Set[str] = set()
        m2m_fields: Set[str] = set()
        o2o_fields: Set[str] = set()
        # Inner ``Meta`` options class, or an empty placeholder class if none.
        meta_class = attrs.get("Meta", type("Meta", (), {}))
        # Replaced below when a field declares pk=True.
        pk_attr: str = "id"

        # Searching for Field attributes in the class hierarchy
        def __search_for_field_attributes(base, attrs: dict):
            """
            Collect ``fields.Field`` attributes of ``base`` and its ancestors
            into ``attrs``.

            The ancestors (``base.__mro__[1:]``) are searched first, each one
            recursively, and ``base`` itself last:

            * a class that has a ``_meta`` attribute copies every entry of its
              ``_meta.fields_map`` into ``attrs``, overwriting existing keys;
            * any other class (e.g. a mixin) contributes the ``fields.Field``
              values of its own ``__dict__``, but only for keys that are not
              in ``attrs`` yet.

            The MRO is the 'natural' order in which Python traverses methods
            and fields. For more information on the magic behind it check out:
            `The Python 2.3 Method Resolution Order
            <https://www.python.org/download/releases/2.3/mro/>`_.

            NOTE(review): each recursive call walks the parent's own MRO again,
            so ancestors are visited repeatedly; for a linear inheritance chain
            the number of visits doubles with every level of depth. Also,
            because ancestors are processed before ``base``, the "only if
            missing" rule for mixins lets an ancestor mixin's field win over a
            field of the same name redefined in a descendant mixin - confirm
            that this precedence is intended.
            """
            for parent in base.__mro__[1:]:
                __search_for_field_attributes(parent, attrs)
            meta = getattr(base, "_meta", None)
            if meta:
                # Classes carrying ``_meta`` (built by this metaclass, abstract
                # or not): take their collected fields, overwriting earlier ones.
                for key, value in meta.fields_map.items():
                    attrs[key] = value
            else:
                # For mixin classes: only fill keys that are still missing.
                for key, value in base.__dict__.items():
                    if isinstance(value, fields.Field) and key not in attrs:
                        attrs[key] = value

        # Start searching for fields in the base classes.
        inherited_attrs: dict = {}
        for base in bases:
            __search_for_field_attributes(base, inherited_attrs)
        if inherited_attrs:
            # Ensure that the inherited fields are before the defined ones.
            # Attributes from the class body still win on key collisions.
            attrs = {**inherited_attrs, **attrs}

        if name != "Model":
            custom_pk_present = False
            for key, value in attrs.items():
                if isinstance(value, fields.Field):
                    if value.pk:
                        # Only a single field may declare pk=True.
                        if custom_pk_present:
                            raise ConfigurationError(
                                f"Can't create model {name} with two primary keys,"
                                " only single pk are supported")
                        # NOTE(review): SmallIntField is accepted by this check,
                        # but the error message only mentions IntField and
                        # BigIntField.
                        if value.generated and not isinstance(
                                value, (fields.SmallIntField, fields.IntField,
                                        fields.BigIntField)):
                            raise ConfigurationError(
                                "Generated primary key allowed only for IntField and BigIntField"
                            )
                        custom_pk_present = True
                        pk_attr = key

            # Non-abstract class without an explicit pk: fall back to ``id``.
            if not custom_pk_present and not getattr(meta_class, "abstract",
                                                     None):
                if "id" not in attrs:
                    # Inserted first so ``id`` precedes every other field.
                    attrs = {"id": fields.IntField(pk=True), **attrs}

                # An ``id`` attribute that already exists must itself be a pk field.
                if not isinstance(attrs["id"],
                                  fields.Field) or not attrs["id"].pk:
                    raise ConfigurationError(
                        f"Can't create model {name} without explicit primary key if field 'id'"
                        " already present")

            for key, value in attrs.items():
                if isinstance(value, fields.Field):
                    # Abstract classes store copies of the field instances
                    # (presumably so that the originals, which may be shared
                    # with other classes, are not mutated below - confirm).
                    if getattr(meta_class, "abstract", None):
                        value = deepcopy(value)

                    fields_map[key] = value
                    value.model_field_name = key

                    # Relational fields are only recorded here; they get no
                    # column projection or filters in this method.
                    if isinstance(value, fields.ForeignKeyField):
                        fk_fields.add(key)
                    elif isinstance(value, fields.OneToOneField):
                        o2o_fields.add(key)
                    elif isinstance(value, fields.ManyToManyFieldInstance):
                        m2m_fields.add(key)
                    else:
                        # Data field: DB column is ``source_field`` if set,
                        # otherwise the attribute name.
                        fields_db_projection[key] = value.source_field or key
                        filters.update(
                            get_filters_for_field(
                                field_name=key,
                                field=fields_map[key],
                                source_field=fields_db_projection[key],
                            ))
                        # The pk is additionally filterable under the alias "pk".
                        if value.pk:
                            filters.update(
                                get_filters_for_field(
                                    field_name="pk",
                                    field=fields_map[key],
                                    source_field=fields_db_projection[key],
                                ))

        # Remove field attributes from the class namespace; they live in
        # ``_meta.fields_map`` instead.
        for slot in fields_map:
            attrs.pop(slot, None)
        attrs["_meta"] = meta = MetaInfo(meta_class)

        meta.fields_map = fields_map
        meta.fields_db_projection = fields_db_projection
        meta._filters = filters
        meta.fk_fields = fk_fields
        meta.backward_fk_fields = set()
        meta.o2o_fields = o2o_fields
        meta.backward_o2o_fields = set()
        meta.m2m_fields = m2m_fields
        meta.default_connection = None
        meta.pk_attr = pk_attr
        meta._inited = False
        # A class without any fields (e.g. ``Model`` itself) is treated as abstract.
        if not fields_map:
            meta.abstract = True

        new_class: "Model" = super().__new__(mcs, name, bases,
                                             attrs)  # type: ignore
        # Bind each field instance to the class it now belongs to.
        for field in meta.fields_map.values():
            field.model = new_class

        meta._model = new_class
        meta.finalise_fields()
        return new_class
Esempio n. 27
0
 def db(self) -> BaseDBAsyncClient:
     """
     Return the DB client for this model's ``default_connection``.

     Looks up ``self.default_connection`` in ``current_transaction_map`` and
     returns the result of calling ``.get()`` on that entry.

     :raises ConfigurationError: If ``current_transaction_map`` has no entry
         for ``self.default_connection`` (the original ``KeyError`` is kept
         as the explicit cause).
     """
     try:
         return current_transaction_map[self.default_connection].get()
     except KeyError as exc:
         raise ConfigurationError("No DB associated to model") from exc
Esempio n. 28
0
    async def init(
        cls,
        config: Optional[dict] = None,
        config_file: Optional[str] = None,
        _create_db: bool = False,
        db_url: Optional[str] = None,
        modules: Optional[Dict[str, List[str]]] = None,
        use_tz: bool = False,
        timezone: str = "UTC",
    ) -> None:
        """
        Sets up Tortoise-ORM.

        You can configure using only one of ``config``, ``config_file``
        and ``(db_url, modules)``.

        The arguments and the resolved config are validated *before* any
        existing connections are closed, so a call that fails validation
        leaves a previously initialised ORM untouched.

        :param config:
            Dict containing config:

            .. admonition:: Example

                .. code-block:: python3

                    {
                        'connections': {
                            # Dict format for connection
                            'default': {
                                'engine': 'tortoise.backends.asyncpg',
                                'credentials': {
                                    'host': 'localhost',
                                    'port': '5432',
                                    'user': '******',
                                    'password': '******',
                                    'database': 'test',
                                }
                            },
                            # ...or, instead of the dict above, a DB_URL string:
                            # 'default': 'postgres://*****:*****@localhost:5432/test'
                        },
                        'apps': {
                            'my_app': {
                                'models': ['__main__'],
                                # If no default_connection specified, defaults to 'default'
                                'default_connection': 'default',
                            }
                        },
                        'use_tz': False,
                        'timezone': 'UTC'
                    }

        :param config_file:
            Path to .json or .yml (if PyYAML installed) file containing config with
            same format as above.
        :param db_url:
            Use a DB_URL string. See :ref:`db_url`
        :param modules:
            Dictionary of ``key``: [``list_of_modules``] that defined "apps" and modules that
            should be discovered for models. Required together with ``db_url``.
        :param _create_db:
            If ``True`` tries to create database for specified connections,
            could be used for testing purposes.
        :param use_tz:
            A boolean that specifies if datetime will be timezone-aware by default or not.
            A ``use_tz`` key in the config takes precedence over this argument.
        :param timezone:
            Timezone to use, default is UTC.
            A ``timezone`` key in the config takes precedence over this argument.

        :raises ConfigurationError: For any configuration error
        """
        # --- Validate and resolve the configuration (no side effects yet). ---
        if int(bool(config) + bool(config_file) + bool(db_url)) != 1:
            raise ConfigurationError(
                'You should init either from "config", "config_file" or "db_url"'
            )

        if config_file:
            config = cls._get_config_from_config_file(config_file)

        if db_url:
            if not modules:
                raise ConfigurationError('You must specify "db_url" and "modules" together')
            config = generate_config(db_url, modules)

        try:
            connections_config = config["connections"]  # type: ignore
        except KeyError:
            raise ConfigurationError('Config must define "connections" section') from None

        try:
            apps_config = config["apps"]  # type: ignore
        except KeyError:
            raise ConfigurationError('Config must define "apps" section') from None

        use_tz = config.get("use_tz", use_tz)  # type: ignore
        timezone = config.get("timezone", timezone)  # type: ignore

        # --- Tear down a previous initialisation only once the new
        # configuration is known to be structurally valid. ---
        if cls._inited:
            await cls.close_connections()
            await cls._reset_apps()

        logger.info(
            "Tortoise-ORM startup\n    connections: %s\n    apps: %s",
            str(connections_config),
            str(apps_config),
        )

        cls._init_timezone(use_tz, timezone)
        await cls._init_connections(connections_config, _create_db)
        cls._init_apps(apps_config)

        cls._inited = True
Esempio n. 29
0
 def __init__(self, max_length: int = 0, **kwargs) -> None:
     """
     Initialise a ``str``-typed field with a maximum length.

     :param max_length: Maximum length; converted with ``int()`` and must be
         at least 1. The default of 0 is rejected, so callers must pass it.
     :param kwargs: Forwarded to the parent constructor.
     :raises ConfigurationError: If ``max_length`` is less than 1.
     """
     length = int(max_length)
     if length < 1:
         raise ConfigurationError("'max_length' must be >= 1")
     self.max_length = length
     super().__init__(str, **kwargs)
Esempio n. 30
0
 def __init__(self, enum_type: Type[Enum], **kwargs):
     """
     Initialise a field bound to the enum class ``enum_type``.

     ``enum_type`` is validated before the parent constructor runs. The
     parent constructor then receives ``128`` as its first positional
     argument plus all ``kwargs``.

     :param enum_type: The :class:`enum.Enum` subclass used by this field.
     :param kwargs: Forwarded to the parent constructor.
     :raises ConfigurationError: If ``enum_type`` is not a class derived
         from ``Enum``.
     """
     # ``issubclass`` raises TypeError for non-class arguments (e.g. an enum
     # member or any instance), so check for a class first; every invalid
     # value is then reported as a ConfigurationError.
     if not (isinstance(enum_type, type) and issubclass(enum_type, Enum)):
         raise ConfigurationError(
             "{} is not a subclass of Enum!".format(enum_type))
     super().__init__(128, **kwargs)
     self._enum_type = enum_type