コード例 #1
0
    def remove_field_constraints(self, field, opts, models, refs):
        """Build the SQL that drops constraints referencing a field.

        Nothing is generated unless the backend supports constraints and
        ``field`` is a primary key. Otherwise, the "through" model of every
        local many-to-many field in ``opts`` is scanned for relation fields
        that reference ``field`` on this model's table. Each match is
        recorded in ``models`` and ``refs``, and constraints are then
        deleted for every model listed in ``models``.

        Args:
            field (django.db.models.Field):
                The field whose constraints are being removed.

            opts (django.db.models.options.Options):
                The Meta options for the model owning the field.

            models (list of django.db.models.Model):
                A caller-owned list. Every model that constraints will be
                removed from is appended to it.

            refs (dict):
                A caller-owned dictionary that is populated with the
                references being removed.

                Each key is a model, and each value is a list of
                ``(through_model, field)`` tuples.

        Returns:
            list:
            The SQL statements that remove the constraints.
        """
        if not self.supports_constraints or not field.primary_key:
            return []

        for m2m_field in opts.local_many_to_many:
            rel = get_remote_field(m2m_field)

            if not rel or not rel.through:
                continue

            through_model = rel.through

            for through_field in through_model._meta.local_fields:
                through_rel = get_remote_field(through_field)

                if not through_rel:
                    continue

                target_model = get_remote_field_model(through_rel)

                # Only track relations that point at this exact column on
                # this exact table.
                if (through_rel.field_name == field.column and
                        target_model._meta.db_table == opts.db_table):
                    models.append(target_model)
                    refs.setdefault(target_model, []).append(
                        (through_model, through_field))

        # Hand a shallow copy to the constraint-deletion helper rather than
        # the caller's own dictionary.
        pending_removals = refs.copy()
        statements = []

        for target_model in models:
            statements.extend(
                sql_delete_constraints(self.connection, target_model,
                                       pending_removals))

        return statements
コード例 #2
0
    def rename_table(self, model, old_db_tablename, db_tablename):
        """Return SQL for renaming a table.

        If the old and new names match, an empty result is returned.
        Otherwise, relation fields on many-to-many "through" models that
        use the old table name and point back at ``model`` are collected.
        When constraints are supported, their constraints are dropped
        before the rename and re-added afterward. The ``db_table`` stored
        on each collected through model is updated (and truncated to the
        backend's maximum name length) along the way.

        Args:
            model (django.db.models.Model):
                The model whose table is being renamed.

            old_db_tablename (str):
                The current name of the table.

            db_tablename (str):
                The new name for the table.

        Returns:
            SQLResult:
            The SQL for performing the rename.
        """
        result = SQLResult()

        if old_db_tablename == db_tablename:
            # Nothing to rename.
            return result

        name_limit = self.connection.ops.max_name_length()

        refs = {}
        related_models = []

        for m2m_field in model._meta.local_many_to_many:
            rel = get_remote_field(m2m_field)

            if not (rel and rel.through and
                    rel.through._meta.db_table == old_db_tablename):
                continue

            through_model = rel.through

            for through_field in through_model._meta.local_fields:
                through_rel = get_remote_field(through_field)

                if not through_rel:
                    continue

                target_model = get_remote_field_model(through_rel)

                if target_model == model:
                    related_models.append(target_model)
                    refs.setdefault(target_model, []).append(
                        (through_model, through_field))

        # The deletion helper gets its own shallow copy, leaving ``refs``
        # available for re-adding the constraints below.
        pending_removals = refs.copy()

        if self.supports_constraints:
            for target_model in related_models:
                result.add_pre_sql(
                    sql_delete_constraints(self.connection, target_model,
                                           pending_removals))

        result.add(
            self.get_rename_table_sql(model, old_db_tablename, db_tablename))

        for target_model in related_models:
            for through_model, _through_field in refs[target_model]:
                through_meta = through_model._meta

                if through_meta.db_table == old_db_tablename:
                    through_meta.db_table = db_tablename

                through_meta.db_table = truncate_name(through_meta.db_table,
                                                      name_limit)

            if self.supports_constraints:
                result.add_post_sql(
                    sql_add_constraints(self.connection, target_model, refs))

        return result
コード例 #3
0
    def test_with_many_to_many_field_rel(self):
        """Testing get_field_is_relation with ManyToManyField relation"""
        m2m_field = CompatModelsTestModel._meta.get_field('m2m_field')
        m2m_rel = get_remote_field(m2m_field)

        # The relation object attached to the field must be of the expected
        # class and must be reported as a relation.
        self.assertIsInstance(m2m_rel, related.ManyToManyRel)
        self.assertTrue(get_field_is_relation(m2m_rel))
コード例 #4
0
    def test_with_one_to_one_field_rel(self):
        """Testing get_remote_field_related_model with OneToOneField relation
        """
        o2o_rel = get_remote_field(
            CompatModelsTestModel._meta.get_field('o2o_field'))

        # The relation's related model is expected to be the test model
        # that declares the field.
        self.assertIs(get_remote_field_related_model(o2o_rel),
                      CompatModelsTestModel)
コード例 #5
0
    def test_with_foreign_key_rel(self):
        """Testing get_field_is_relation with ForeignKey relation"""
        fkey_field = CompatModelsTestModel._meta.get_field('fkey_field')
        fkey_rel = get_remote_field(fkey_field)

        # The relation object attached to the field must be of the expected
        # class and must be reported as a relation.
        self.assertIsInstance(fkey_rel, related.ManyToOneRel)
        self.assertTrue(get_field_is_relation(fkey_rel))
コード例 #6
0
    def get_default_index_name(self, table_name, field):
        """Return the index name the database would generate by default.

        The result is meant to match the name the database or Django
        database backend would automatically pick when a field is marked as
        indexed or unique.

        When the connection provides a schema editor and the field is a
        relation with ``db_constraint`` enabled, the name is computed by the
        schema editor using a foreign key suffix. In all other cases, the
        parent class's implementation is used.

        Subclasses can override this if the database or Django database
        backend provides different values.

        Args:
            table_name (str):
                The name of the table for the index.

            field (django.db.models.Field):
                The field for the index.

        Returns:
            str:
            The name of the index.
        """
        has_schema_editor = hasattr(self.connection, 'schema_editor')

        if not (has_schema_editor and
                get_remote_field(field) and
                field.db_constraint):
            return super(EvolutionOperations,
                         self).get_default_index_name(table_name, field)

        # Django >= 1.7
        target_field = get_rel_target_field(field)
        schema_editor = self.connection.schema_editor()
        fk_suffix = '_fk_%s_%s' % (target_field.model._meta.db_table,
                                   target_field.column)

        return schema_editor._create_index_name(
            field.model, [field.column], suffix=fk_suffix)
コード例 #7
0
    def test_with_one_to_one_field_rel(self):
        """Testing get_field_is_relation with OneToOneField relation"""
        o2o_field = CompatModelsTestModel._meta.get_field('o2o_field')
        o2o_rel = get_remote_field(o2o_field)

        # NOTE(review): the class check is against ManyToOneRel; presumably
        # the one-to-one relation class derives from it -- confirm.
        self.assertIsInstance(o2o_rel, related.ManyToOneRel)
        self.assertTrue(get_field_is_relation(o2o_rel))
コード例 #8
0
    def test_with_many_to_many_field_rel(self):
        """Testing get_remote_field_related_model with ManyToManyField
        relation
        """
        m2m_rel = get_remote_field(
            CompatModelsTestModel._meta.get_field('m2m_field'))

        # The relation's related model is expected to be the test model
        # that declares the field.
        self.assertIs(get_remote_field_related_model(m2m_rel),
                      CompatModelsTestModel)
コード例 #9
0
    def create_table(self,
                     table_name,
                     field_list,
                     temporary=False,
                     create_index=True):
        """Return SQL statements for creating a table.

        A single ``CREATE TABLE`` statement is built from the provided
        fields. ManyToManyFields, including subclasses, are skipped and get
        no column definition. Foreign keys get an inline ``REFERENCES``
        clause unless the table is temporary.

        Args:
            table_name (str):
                The name of the table to create.

            field_list (list of django.db.models.Field):
                The fields to create columns for.

            temporary (bool, optional):
                Whether to create a temporary table. Columns on temporary
                tables are always nullable and never get ``REFERENCES``
                clauses.

            create_index (bool, optional):
                Whether to append the statements returned by
                :py:meth:`create_indexes_for_table`.

        Returns:
            list:
            The ``CREATE TABLE`` statement, followed by any index
            statements.
        """
        qn = self.connection.ops.quote_name

        create = ['CREATE']

        if temporary:
            create.append('TEMPORARY')

        create.append('TABLE %s' % qn(table_name))
        columns = []

        for field in field_list:
            # Use isinstance() rather than an exact type comparison so that
            # subclasses of ManyToManyField are skipped as well, instead of
            # being treated like regular column-backed fields.
            if isinstance(field, models.ManyToManyField):
                continue

            column_name = qn(field.column)
            column_type = field.db_type(connection=self.connection)
            params = [column_name, column_type]

            # Always use null if this is a temporary table. It may be
            # used to create a new field (which will be null while data is
            # copied across from the old table).
            if temporary or field.null:
                params.append('NULL')
            else:
                params.append('NOT NULL')

            if field.unique:
                params.append('UNIQUE')

            if field.primary_key:
                params.append('PRIMARY KEY')

            if not temporary and isinstance(field, models.ForeignKey):
                remote_field = get_remote_field(field)
                remote_field_model = get_remote_field_model(remote_field)
                remote_column = remote_field_model._meta.get_field(
                    remote_field.field_name).column

                params.append(
                    'REFERENCES %s (%s) DEFERRABLE INITIALLY DEFERRED' %
                    (qn(remote_field_model._meta.db_table),
                     qn(remote_column)))

            columns.append(' '.join(params))

        output = [''.join([
            ' '.join(create),
            '(',
            ', '.join(columns),
            ');',
        ])]

        if create_index:
            output.extend(self.create_indexes_for_table(
                table_name, field_list))

        return output
コード例 #10
0
def get_model_rel_tree():
    """Return the full field relationship tree for all registered models.

    Every forward field on every non-abstract registered model (including
    auto-created models) is inspected. Each relation field found is stored
    under the table name of the concrete model it points to, producing a
    mapping of table names to the relation fields that reference them.

    The result is cached in a module-level variable, and the cached value
    is returned on subsequent calls.

    This can be used to quickly locate any and all reverse relations made
    to a field.

    This is similar to Django's built-in reverse relation tree used
    internally (with different implementations) in
    :py:class:`django.db.models.options.Options`, but works across all
    supported versions of Django, and supports cache clearing.

    Version Added:
        2.2

    Returns:
        dict:
        The model relation tree.
    """
    global _rel_tree_cache

    if _rel_tree_cache is None:
        tree = defaultdict(list)

        for model in get_models(include_auto_created=True):
            if model._meta.abstract:
                continue

            # Only the fields defined directly on this model are wanted.
            forward_fields = iter_model_fields(model,
                                               include_parent_models=False,
                                               include_forward_fields=True,
                                               include_reverse_fields=False,
                                               include_hidden_fields=False)

            for field in forward_fields:
                if not get_field_is_relation(field):
                    continue

                if get_remote_field_related_model(field) is None:
                    continue

                target_model = get_remote_field_model(
                    get_remote_field(field))

                # Make sure this isn't a "self" relation or similar, which
                # would still be a string here.
                if isinstance(target_model, six.string_types):
                    continue

                concrete_meta = target_model._meta.concrete_model._meta
                tree[concrete_meta.db_table].append(field)

        _rel_tree_cache = tree

    return _rel_tree_cache
コード例 #11
0
    def test_has_model_with_auto_created(self):
        """Testing DatabaseState.has_model with auto-created model"""
        groups_field = User._meta.get_field('groups')
        through_model = get_remote_field(groups_field).through
        self.assertTrue(through_model._meta.auto_created)

        # Without a scan, the state should not know about the model.
        state = DatabaseState(db_name='default', scan=False)
        self.assertFalse(state.has_model(through_model))

        # After rescanning the tables, the model should be found.
        state.rescan_tables()
        self.assertTrue(state.has_model(through_model))
コード例 #12
0
    def create_table(self, table_name, field_list, temporary=False,
                     create_index=True):
        """Return SQL statements for creating a table.

        A single ``CREATE TABLE`` statement is built from the provided
        fields. ManyToManyFields, including subclasses, are skipped and get
        no column definition. Foreign keys get an inline ``REFERENCES``
        clause unless the table is temporary.

        Args:
            table_name (str):
                The name of the table to create.

            field_list (list of django.db.models.Field):
                The fields to create columns for.

            temporary (bool, optional):
                Whether to create a temporary table. Columns on temporary
                tables are always nullable and never get ``REFERENCES``
                clauses.

            create_index (bool, optional):
                Whether to append the statements returned by
                :py:meth:`create_indexes_for_table`.

        Returns:
            list:
            The ``CREATE TABLE`` statement, followed by any index
            statements.
        """
        qn = self.connection.ops.quote_name

        create = ['CREATE']

        if temporary:
            create.append('TEMPORARY')

        create.append('TABLE %s' % qn(table_name))
        columns = []

        for field in field_list:
            # An exact type comparison would let ManyToManyField subclasses
            # fall through and be processed as regular columns, so check
            # with isinstance() instead.
            if isinstance(field, models.ManyToManyField):
                continue

            params = [
                qn(field.column),
                field.db_type(connection=self.connection),
            ]

            # Always use null if this is a temporary table. It may be
            # used to create a new field (which will be null while data is
            # copied across from the old table).
            if temporary or field.null:
                params.append('NULL')
            else:
                params.append('NOT NULL')

            if field.unique:
                params.append('UNIQUE')

            if field.primary_key:
                params.append('PRIMARY KEY')

            if not temporary and isinstance(field, models.ForeignKey):
                remote_field = get_remote_field(field)
                remote_field_model = get_remote_field_model(remote_field)
                remote_column = remote_field_model._meta.get_field(
                    remote_field.field_name).column

                params.append(
                    'REFERENCES %s (%s) DEFERRABLE INITIALLY DEFERRED'
                    % (qn(remote_field_model._meta.db_table),
                       qn(remote_column)))

            columns.append(' '.join(params))

        output = [''.join([
            ' '.join(create),
            '(',
            ', '.join(columns),
            ');',
        ])]

        if create_index:
            output.extend(
                self.create_indexes_for_table(table_name, field_list))

        return output
コード例 #13
0
ファイル: db.py プロジェクト: beanbaginc/django-evolution
def sql_create_for_many_to_many_field(connection, model, field):
    """Return SQL statements for creating a ManyToManyField's table.

    This provides compatibility with all supported versions of Django. When
    a schema editor class is available, the "through" model is created via
    a SQL-collecting schema editor. Otherwise, the legacy
    ``connection.creation`` API is used.

    Args:
        connection (object):
            The database connection.

        model (django.db.models.Model):
            The model for the ManyToManyField's relations.

        field (django.db.models.ManyToManyField):
            The field setting up the many-to-many relation.

    Returns:
        list:
        The list of SQL statements for creating the table and constraints.
    """
    through_model = get_remote_field(field).through

    if BaseDatabaseSchemaEditor:
        # Django >= 1.7
        with connection.schema_editor(collect_sql=True) as editor:
            editor.create_model(through_model)

        return editor.collected_sql

    # Django < 1.7
    plain_style = color.no_style()

    if not through_model:
        return connection.creation.sql_for_many_to_many_field(
            model, field, plain_style)

    statements, created_refs = connection.creation.sql_create_model(
        through_model, plain_style)
    pending_refs = {}

    # Sort the list, in order to create consistency in the order of
    # ALTER TABLEs. This is primarily needed for unit tests.
    ordered_refs = sorted(six.iteritems(created_refs), key=repr)

    for ref_model, ref_list in ordered_refs:
        pending_refs.setdefault(ref_model, []).extend(ref_list)
        statements.extend(
            sql_add_constraints(connection, ref_model, pending_refs))

    statements.extend(
        sql_add_constraints(connection, through_model, pending_refs))

    return statements
コード例 #14
0
def sql_create_for_many_to_many_field(connection, model, field):
    """Return the SQL needed to create a ManyToManyField's table.

    The approach depends on what the running version of Django offers. If
    a schema editor class is available, the field's "through" model is
    created through a schema editor that collects SQL. If not, the older
    ``connection.creation`` methods generate the SQL.

    Args:
        connection (object):
            The database connection.

        model (django.db.models.Model):
            The model for the ManyToManyField's relations.

        field (django.db.models.ManyToManyField):
            The field setting up the many-to-many relation.

    Returns:
        list:
        The list of SQL statements for creating the table and constraints.
    """
    m2m_through = get_remote_field(field).through

    if BaseDatabaseSchemaEditor:
        # Django >= 1.7
        with connection.schema_editor(collect_sql=True) as collector:
            collector.create_model(m2m_through)

        return collector.collected_sql

    # Django < 1.7
    sql_style = color.no_style()
    creation = connection.creation

    if m2m_through:
        result_sql, model_refs = creation.sql_create_model(m2m_through,
                                                           sql_style)
        outstanding = {}

        # Sort the list, in order to create consistency in the order of
        # ALTER TABLEs. This is primarily needed for unit tests.
        for refto, ref_entries in sorted(six.iteritems(model_refs),
                                         key=repr):
            outstanding.setdefault(refto, []).extend(ref_entries)
            result_sql.extend(
                sql_add_constraints(connection, refto, outstanding))

        result_sql.extend(
            sql_add_constraints(connection, m2m_through, outstanding))
    else:
        result_sql = creation.sql_for_many_to_many_field(model, field,
                                                         sql_style)

    return result_sql
コード例 #15
0
    def _get_rename_column_sql(self, opts, old_field, new_field):
        """Return the ALTER TABLE item for renaming a column.

        A full column definition is built for the new field and used in a
        ``CHANGE COLUMN`` clause. If the old field was a primary key, the
        clause is prefixed with ``DROP PRIMARY KEY``.

        Args:
            opts (django.db.models.options.Options):
                The Meta options for the model. These provide the fallback
                tablespace.

            old_field (django.db.models.Field):
                The field as it currently exists.

            new_field (django.db.models.Field):
                The field describing the renamed column.

        Returns:
            list of dict:
            A single-item list containing a dictionary with the clause
            stored under the ``sql`` key.
        """
        qn = self.connection.ops.quote_name
        style = color.no_style()
        col_type = new_field.db_type(connection=self.connection)
        tablespace = new_field.db_tablespace or opts.db_tablespace

        # Make the definition (e.g. 'foo VARCHAR(30)') for this field.
        null_prefix = '' if new_field.null else 'NOT '
        definition = [
            style.SQL_FIELD(qn(new_field.column)),
            style.SQL_COLTYPE(col_type),
            style.SQL_KEYWORD('%sNULL' % null_prefix),
        ]

        if new_field.primary_key:
            definition.append(style.SQL_KEYWORD('PRIMARY KEY'))

        if new_field.unique:
            definition.append(style.SQL_KEYWORD('UNIQUE'))

        if (tablespace and
                self.connection.features.supports_tablespaces and
                self.connection.features.autoindexes_primary_keys and
                (new_field.unique or new_field.primary_key)):
            # We must specify the index tablespace inline, because we
            # won't be generating a CREATE INDEX statement for this field.
            definition.append(
                self.connection.ops.tablespace_sql(tablespace, inline=True))

        rel = get_remote_field(new_field)

        if rel:
            rel_meta = get_remote_field_model(rel)._meta
            rel_column = rel_meta.get_field(rel.field_name).column

            definition.append('%s %s (%s)%s' % (
                style.SQL_KEYWORD('REFERENCES'),
                style.SQL_TABLE(qn(rel_meta.db_table)),
                style.SQL_FIELD(qn(rel_column)),
                self.connection.ops.deferrable_sql(),
            ))

        # The old primary key has to be dropped as part of the same
        # statement.
        prefix = 'DROP PRIMARY KEY, ' if old_field.primary_key else ''
        change_clause = 'CHANGE COLUMN %s %s' % (qn(old_field.column),
                                                 ' '.join(definition))

        return [{'sql': prefix + change_clause}]
コード例 #16
0
    def add_primary_key_field_constraints(self, old_field, new_field, models,
                                          refs):
        """Return SQL for re-adding constraints after a primary key change.

        Nothing is done unless the backend supports constraints and
        ``old_field`` is a primary key. Otherwise, for each model, every
        referencing field in ``refs`` is re-pointed at the new field's
        column, the model's ``_meta._fields`` entry is swapped from the old
        field to the new one, and the constraint SQL is generated.

        Args:
            old_field (django.db.models.Field):
                The field being replaced.

            new_field (django.db.models.Field):
                The field replacing it.

            models (list of django.db.models.Model):
                The models to add constraints for.

            refs (dict):
                A mapping of models to lists of ``(model, field)`` tuples
                referencing them.

        Returns:
            list:
            The SQL statements for adding the constraints.
        """
        if not self.supports_constraints or not old_field.primary_key:
            return []

        statements = []

        for target_model in models:
            # Point every referencing field at the new column.
            for _ref_model, ref_field in refs[target_model]:
                get_remote_field(ref_field).field_name = new_field.column

            target_meta = target_model._meta
            del target_meta._fields[old_field.name]
            target_meta._fields[new_field.name] = new_field

            statements.extend(
                sql_add_constraints(self.connection, target_model, refs))

        return statements
コード例 #17
0
    def delete_column(self, model, f):
        """Return SQL for deleting a column from a model's table.

        If the field is a relation, the SQL for deleting the constraint on
        the related model is added first. The SQL from the parent class's
        implementation is then added.

        Args:
            model (django.db.models.Model):
                The model owning the column.

            f (django.db.models.Field):
                The field whose column is being deleted.

        Returns:
            AlterTableSQLResult:
            The SQL for deleting the column.
        """
        result = AlterTableSQLResult(self, model)
        rel = get_remote_field(f)

        if rel:
            target_model = get_remote_field_model(rel)
            constraint_refs = {target_model: [(model, f)]}

            result.add(
                sql_delete_constraints(self.connection, target_model,
                                       constraint_refs))

        parent_sql = super(EvolutionOperations, self).delete_column(model, f)
        result.add_sql(parent_sql)

        return result
コード例 #18
0
def iter_non_m2m_reverse_relations(field):
    """Iterate through non-M2M reverse relations pointing to a field.

    Reverse relations originating from a
    :py:class:`~django.db.models.ManyToManyField` are excluded. For every
    relation yielded, the walk recurses using the remote field of that
    relation, yielding whatever it finds there as well.

    Note that this may return duplicate results, or multiple relations
    pointing to the same field. It's up to the caller to handle this.

    Version Added:
        2.2

    Args:
        field (django.db.models.Field):
            The field that relations must point to.

    Yields:
        django.db.models.Field or object:
        Each field or relation object pointing to this field.

        The type of the relation object depends on the version of Django.
    """
    is_primary_key = field.primary_key
    field_name = field.name

    reverse_rels = iter_model_fields(field.model,
                                     include_parent_models=True,
                                     include_forward_fields=False,
                                     include_reverse_fields=True,
                                     include_hidden_fields=True)

    for rel in reverse_rels:
        source_field = rel.field

        # Exclude any ManyToManyFields.
        if get_field_is_many_to_many(source_field):
            continue

        # Make sure the referencing field points directly at this field,
        # either implicitly through the primary key or explicitly by name.
        targets_pk = is_primary_key and source_field.to_fields == [None]

        if not (targets_pk or field_name in source_field.to_fields):
            continue

        yield rel

        # Now do the same for the fields on the model of the related field.
        for nested_rel in iter_non_m2m_reverse_relations(
                get_remote_field(rel)):
            yield nested_rel
コード例 #19
0
    def add_column(self, model, f, initial):
        """Return SQL for adding a column to a model's table.

        Relation fields are added with an inline ``REFERENCES`` clause
        pointing at the related model's primary key column. Other fields
        are added based on ``initial``:

        * ``None``: The column is added with its null/unique constraints.
        * A callable: The column is added without a null constraint, its
          rows are populated with the result of calling ``initial``, and
          then the null constraint is applied if the field is not nullable.
        * Anything else: The column is added with the value as a temporary
          ``DEFAULT``, which is dropped afterward.

        If the field is unique or a primary key, a unique index for the
        column is also recorded in the database state.

        Args:
            model (django.db.models.Model):
                The model to add the column to.

            f (django.db.models.Field):
                The field describing the new column.

            initial (object):
                The initial value for existing rows. This may be ``None``,
                a value, or a callable returning SQL for the value.

        Returns:
            AlterTableSQLResult:
            The SQL for adding the column.
        """
        qn = self.connection.ops.quote_name
        sql_result = AlterTableSQLResult(self, model)
        table_name = model._meta.db_table

        null_constraints = '%sNULL' % ('' if f.null else 'NOT ')
        is_unique = f.unique or f.primary_key

        remote_field = get_remote_field(f)

        if remote_field:
            # it is a foreign key field
            # NOT NULL REFERENCES "django_evolution_addbasemodel"
            # ("id") DEFERRABLE INITIALLY DEFERRED

            # ALTER TABLE <tablename> ADD COLUMN <column name> NULL
            # REFERENCES <tablename1> ("<colname>") DEFERRABLE INITIALLY
            # DEFERRED
            related_model = get_remote_field_model(remote_field)
            related_table = related_model._meta.db_table

            # REFERENCES must name the database column of the related
            # primary key. That differs from the field's name when the key
            # sets db_column or is itself a relation (e.g. "base_ptr" vs.
            # "base_ptr_id"), so use the column rather than the name.
            related_pk_col = related_model._meta.pk.column
            constraints = [null_constraints]

            if is_unique:
                constraints.append('UNIQUE')

            sql_result.add_alter_table([{
                'op': 'ADD COLUMN',
                'column': f.column,
                'db_type': f.db_type(connection=self.connection),
                'params': constraints + [
                    'REFERENCES',
                    qn(related_table),
                    '(%s)' % qn(related_pk_col),
                    self.connection.ops.deferrable_sql(),
                ]
            }])
        else:
            unique_constraints = 'UNIQUE' if is_unique else ''

            # At this point, initial can only be None if null=True,
            # otherwise it is a user callable or the default
            # AddFieldInitialCallback which will shortly raise an exception.
            if initial is None:
                sql_result.add_alter_table([{
                    'op': 'ADD COLUMN',
                    'column': f.column,
                    'db_type': f.db_type(connection=self.connection),
                    'params': [null_constraints, unique_constraints],
                }])
            elif callable(initial):
                # The column is added without a null constraint so that
                # existing rows can be populated before NOT NULL is
                # enforced.
                sql_result.add_alter_table([{
                    'op': 'ADD COLUMN',
                    'column': f.column,
                    'db_type': f.db_type(connection=self.connection),
                    'params': [unique_constraints],
                }])

                sql_result.add_sql([
                    'UPDATE %s SET %s = %s WHERE %s IS NULL;' %
                    (qn(table_name), qn(f.column), initial(), qn(f.column))
                ])

                if not f.null:
                    # Only put this sql statement if the column cannot
                    # be null.
                    sql_result.add_sql(
                        self.set_field_null(model, f, f.null))
            else:
                sql_result.add_alter_table([{
                    'op': 'ADD COLUMN',
                    'column': f.column,
                    'db_type': f.db_type(connection=self.connection),
                    'params': [
                        null_constraints,
                        unique_constraints,
                        'DEFAULT',
                        '%s',
                    ],
                    'sql_params': [initial]
                }])

                # Django doesn't generate default columns, so now that
                # we've added one to get default values for existing
                # tables, drop that default.
                sql_result.add_post_sql([
                    'ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT;' %
                    (qn(table_name), qn(f.column))
                ])

        if is_unique:
            self.database_state.add_index(
                table_name=table_name,
                index_name=self.get_new_constraint_name(table_name, f.column),
                columns=[f.column],
                unique=True)

        return sql_result
コード例 #20
0
def register_models(database_state, models, register_indexes=False,
                    new_app_label='tests', db_name='default', app=evo_test):
    """Register models for testing purposes.

    Args:
        database_state (django_evolution.db.state.DatabaseState):
            The database state to populate with model information.

        models (list of django.db.models.Model):
            The models to register.

        register_indexes (bool, optional):
            Whether indexes should be registered for any models. Defaults to
            ``False``.

        new_app_label (str, optional):
            The label for the test app. Defaults to "tests".

        db_name (str, optional):
            The name of the database connection. Defaults to "default".

        app (module, optional):
            The application module for the test models.

    Returns:
        collections.OrderedDict:
        A dictionary of registered models. The keys are model names, and
        the values are the models.
    """
    app_cache = OrderedDict()
    evolver = EvolutionOperationsMulti(db_name, database_state).get_evolver()

    db_connection = connections[db_name or DEFAULT_DB_ALIAS]
    max_name_length = db_connection.ops.max_name_length()

    for new_object_name, model in reversed(models):
        # Grab some state from the model's meta instance. Some of this will
        # be original state that we'll keep around to help us unregister old
        # values and compute new ones.
        meta = model._meta

        orig_app_label = meta.app_label
        orig_db_table = meta.db_table
        orig_object_name = meta.object_name
        orig_model_name = get_model_name(model)

        # Find out if the table name being used is a custom table name, or
        # one generated by Django.
        new_model_name = new_object_name.lower()
        new_db_table = orig_db_table

        generated_db_table = truncate_name(
            '%s_%s' % (orig_app_label, orig_model_name),
            max_name_length)

        if orig_db_table == generated_db_table:
            # It was a generated one, so replace it with a version containing
            # the new model and app names.
            new_db_table = truncate_name('%s_%s' % (new_app_label,
                                                    new_model_name),
                                         max_name_length)
            meta.db_table = new_db_table

        # Set the new app/model names back on the meta instance.
        meta.app_label = new_app_label
        meta.object_name = new_object_name
        set_model_name(model, new_model_name)

        # Add an entry for the table in the database state, if it's not
        # already there.
        if not database_state.has_table(new_db_table):
            database_state.add_table(new_db_table)

        if register_indexes:
            # Now that we definitely have an entry, store the indexes for
            # all the fields in the database state, so that other operations
            # can look up the index names.
            for field in meta.local_fields:
                if field.db_index or field.unique:
                    new_index_name = create_index_name(
                        db_connection,
                        new_db_table,
                        field_names=[field.name],
                        col_names=[field.column],
                        unique=field.unique)

                    database_state.add_index(
                        index_name=new_index_name,
                        table_name=new_db_table,
                        columns=[field.column],
                        unique=field.unique)

            for field_names in meta.unique_together:
                fields = evolver.get_fields_for_names(model, field_names)
                new_index_name = create_index_name(
                    db_connection,
                    new_db_table,
                    field_names=field_names,
                    unique=True)

                database_state.add_index(
                    index_name=new_index_name,
                    table_name=new_db_table,
                    columns=[field.column for field in fields],
                    unique=True)

            for field_names in getattr(meta, 'index_together', []):
                # Django >= 1.5
                fields = evolver.get_fields_for_names(model, field_names)
                new_index_name = create_index_together_name(
                    db_connection,
                    new_db_table,
                    field_names=[field.name for field in fields])

                database_state.add_index(
                    index_name=new_index_name,
                    table_name=new_db_table,
                    columns=[field.column for field in fields])

            if getattr(meta, 'indexes', None):
                # Django >= 1.11
                for index, orig_index in zip(meta.indexes,
                                             meta.original_attrs['indexes']):
                    if not orig_index.name:
                        # The name was auto-generated. We'll need to generate
                        # it again for the new table name.
                        index.set_name_with_model(model)

                    fields = evolver.get_fields_for_names(
                        model, index.fields, allow_sort_prefixes=True)
                    database_state.add_index(
                        index_name=index.name,
                        table_name=new_db_table,
                        columns=[field.column for field in fields])

        # ManyToManyFields have their own tables, which will also need to be
        # renamed. Go through each of them and figure out what changes need
        # to be made.
        for field in meta.local_many_to_many:
            through = get_remote_field(field).through

            if not through:
                continue

            through_meta = through._meta
            through_orig_model_name = get_model_name(through)
            through_new_model_name = through_orig_model_name

            # Find out if the through table name is a custom table name, or
            # one generated by Django.
            generated_db_table = truncate_name(
                '%s_%s' % (orig_db_table, field.name),
                max_name_length)

            if through_meta.db_table == generated_db_table:
                # This is an auto-generated table name. Start changing the
                # state for it.
                assert through_meta.app_label == orig_app_label
                through_meta.app_label = new_app_label

                # Transform the 'through' table information only if we've
                # transformed the parent db_table.
                if new_db_table != orig_db_table:
                    through_meta.db_table = truncate_name(
                        '%s_%s' % (new_db_table, field.name),
                        max_name_length)

                    through_meta.object_name = \
                        through_meta.object_name.replace(orig_object_name,
                                                         new_object_name)

                    through_new_model_name = \
                        through_orig_model_name.replace(orig_model_name,
                                                        new_model_name)
                    set_model_name(through, through_new_model_name)

            # Change each of the columns for the fields on the
            # ManyToManyField's model to reflect the new model names.
            for through_field in through._meta.local_fields:
                through_remote_field = get_remote_field(through_field)

                if (through_remote_field and
                    get_remote_field_model(through_remote_field)):
                    column = through_field.column

                    if (column.startswith((orig_model_name,
                                           'to_%s' % orig_model_name,
                                           'from_%s' % orig_model_name))):
                        # This is a field that references one end of the
                        # relation or another. Update the model name in the
                        # field's column.
                        through_field.column = column.replace(orig_model_name,
                                                              new_model_name)

            # Replace the entry in the models cache for the through table,
            # removing the old name and adding the new one.
            if through_orig_model_name in all_models[orig_app_label]:
                unregister_app_model(orig_app_label, through_orig_model_name)

            app_cache[through_new_model_name] = through
            register_app_models(new_app_label,
                                [(through_new_model_name, through)])

        # Unregister with the old model name and register the new one.
        if orig_model_name in all_models[orig_app_label]:
            unregister_app_model(orig_app_label, orig_model_name)

        register_app_models(new_app_label, [(new_model_name, model)])
        app_cache[new_model_name] = model

    # If the app hasn't yet been registered, do that now.
    if not is_app_registered(app):
        register_app(new_app_label, app)

    return app_cache
コード例 #21
0
    def test_with_one_to_one_field(self):
        """Testing get_remote_field with OneToOneField"""
        # Look up the forward field first, then resolve its relation object.
        o2o_field = CompatModelsTestModel._meta.get_field('o2o_field')
        remote_field = get_remote_field(o2o_field)

        self.assertIsInstance(remote_field, related.OneToOneRel)
コード例 #22
0
def sql_add_constraints(connection, model, refs):
    """Build the SQL statements that add pending constraints for a model.

    This works with every supported version of Django. When a schema
    editor class is available, the foreign key SQL is generated here;
    otherwise the work is delegated to the connection's creation backend.

    Args:
        connection (object):
            The database connection.

        model (django.db.models.Model):
            The database model to add constraints on.

        refs (dict):
            A dictionary of constraint references to add.

            The keys are models. Each value is an iterable of
            (:py:class:`django.db.models.Model`,
            :py:class:`django.db.models.Field`) tuples. On the schema
            editor code path, the entry for ``model`` is deleted from this
            dictionary once its SQL has been generated.

    Returns:
        list:
        The list of SQL statements for adding constraints.
    """
    if not BaseDatabaseSchemaEditor:
        # Django < 1.7
        return connection.creation.sql_for_pending_references(
            model, color.no_style(), refs)

    # Django >= 1.7
    meta = model._meta

    if not meta.managed or meta.swapped:
        return []

    statements = []

    if model not in refs:
        # Nothing is pending for this model.
        return statements

    with connection.schema_editor() as schema_editor:
        quote_name = schema_editor.quote_name

        for rel_class, rel_field in refs[model]:
            # Ideally, we would use schema_editor._create_fk_sql here,
            # but it depends on a lot more state than we have available
            # currently in our mocks. So the SQL is built by hand.
            #
            # For reference, this is what we'd ideally do:
            #
            #     sql.append('%s;' % schema_editor._create_fk_sql(
            #         rel_class, f,
            #         '_fk_%(to_table)s_%(to_column)s'))
            #
            rel_table = rel_class._meta.db_table
            target_field_name = get_remote_field(rel_field).field_name
            to_column = meta.get_field(target_field_name).column

            constraint_suffix = '_fk_%(to_table)s_%(to_column)s' % {
                'to_table': meta.db_table,
                'to_column': to_column,
            }
            constraint_name = create_index_name(
                connection=connection,
                table_name=rel_table,
                col_names=[rel_field.column],
                suffix=constraint_suffix)

            fk_params = {
                'table': quote_name(rel_table),
                'name': quote_name(constraint_name),
                'column': quote_name(rel_field.column),
                'to_table': quote_name(meta.db_table),
                'to_column': quote_name(to_column),
                'deferrable': connection.ops.deferrable_sql(),
            }

            statements.append(
                '%s;' % (schema_editor.sql_create_fk % fk_params))

    # The references for this model have been handled, so drop them.
    del refs[model]

    return statements
コード例 #23
0
    def test_with_foreign_key_rel(self):
        """Testing get_remote_field_related_model with ForeignKey relation"""
        fkey_field = CompatModelsTestModel._meta.get_field('fkey_field')
        rel = get_remote_field(fkey_field)
        related_model = get_remote_field_related_model(rel)

        # The model resolved from the relation is expected to be
        # CompatModelsTestModel itself.
        self.assertIs(related_model, CompatModelsTestModel)
コード例 #24
0
def iter_model_fields(model,
                      include_parent_models=True,
                      include_forward_fields=True,
                      include_reverse_fields=False,
                      include_hidden_fields=False,
                      seen_models=None):
    """Yield the fields on a model that match the given criteria.

    This is roughly equivalent to Django's internal
    :py:func:`django.db.models.options.Option._get_fields` on Django 1.8+,
    but makes use of our model reverse relation tree, and works across all
    supported versions of Django.

    Version Added:
        2.2

    Args:
        model (type):
            The model owning the fields.

        include_parent_models (bool, optional):
            Whether to include fields defined on parent models.

        include_forward_fields (bool, optional):
            Whether to include fields owned by the model (or a parent).

        include_reverse_fields (bool, optional):
            Whether to include fields on other models that point to this
            model.

        include_hidden_fields (bool, optional):
            Whether to include hidden fields.

        seen_models (set, optional):
            Models seen during iteration. This is intended for internal
            use only by this function. The table name of each processed
            model is added to this set.

    Yields:
        django.db.models.Field:
        Each field matching the criteria.
    """
    model_meta = model._meta
    concrete_model = model_meta.concrete_model

    if seen_models is None:
        seen_models = set()

    candidate_models = (walk_model_tree(model)
                        if include_parent_models
                        else [model])

    rel_fields = []

    if include_reverse_fields:
        # Find all models containing fields that point to this model.
        rel_fields = get_model_rel_tree().get(
            concrete_model._meta.db_table, [])

    for cur_model in candidate_models:
        cur_meta = cur_model._meta
        table_name = cur_meta.db_table

        if table_name in seen_models:
            continue

        if cur_meta.concrete_model != concrete_model:
            continue

        seen_models.add(table_name)

        if include_parent_models:
            for parent in cur_meta.parents:
                # NOTE: This function only ever adds table names to
                #       seen_models, while `parent` is a model, so this
                #       guard only triggers if a caller seeded the set
                #       with models.
                if parent in seen_models:
                    continue

                # Each parent is walked with its own fresh seen set.
                for field in iter_model_fields(
                        parent,
                        include_parent_models=True,
                        include_forward_fields=include_forward_fields,
                        include_reverse_fields=include_reverse_fields,
                        include_hidden_fields=include_hidden_fields):
                    yield field

        if include_reverse_fields and not cur_meta.proxy:
            for rel_field in rel_fields:
                remote_field = get_remote_field(rel_field)

                if (include_hidden_fields or
                        not get_field_is_hidden(remote_field)):
                    yield remote_field

        if include_forward_fields:
            # Regular fields come first, followed by many-to-many fields.
            for fields_attr in ('local_fields', 'local_many_to_many'):
                for field in getattr(cur_meta, fields_attr):
                    yield field

    # Django >= 1.10
    for field in getattr(model_meta, 'private_fields', []):
        yield field
コード例 #25
0
ファイル: db.py プロジェクト: beanbaginc/django-evolution
def sql_add_constraints(connection, model, refs):
    """Return the SQL statements that add constraints for a model.

    This provides compatibility with all supported versions of Django:
    the schema editor is used when available, and the legacy creation
    backend is used otherwise.

    Args:
        connection (object):
            The database connection.

        model (django.db.models.Model):
            The database model to add constraints on.

        refs (dict):
            A dictionary of constraint references to add.

            The keys are models. Each value is an iterable of
            (:py:class:`django.db.models.Model`,
            :py:class:`django.db.models.Field`) tuples. On the schema
            editor code path, the entry for ``model`` is removed once it
            has been processed.

    Returns:
        list:
        The list of SQL statements for adding constraints.
    """
    if not BaseDatabaseSchemaEditor:
        # Django < 1.7
        return connection.creation.sql_for_pending_references(
            model, color.no_style(), refs)

    # Django >= 1.7
    opts = model._meta

    if not (opts.managed and not opts.swapped):
        # Unmanaged and swapped-out models get no constraints.
        return []

    result = []

    if model in refs:
        with connection.schema_editor() as editor:
            qn = editor.quote_name

            for related_model, fk_field in refs[model]:
                # Ideally, we would use schema_editor._create_fk_sql here,
                # but it depends on a lot more state than we have
                # available currently in our mocks. So we have to build
                # the SQL ourselves. It's not a lot of work, fortunately.
                #
                # For reference, this is what we'd ideally do:
                #
                #     sql.append('%s;' % schema_editor._create_fk_sql(
                #         rel_class, f,
                #         '_fk_%(to_table)s_%(to_column)s'))
                #
                related_opts = related_model._meta
                remote_field = get_remote_field(fk_field)
                to_column = opts.get_field(remote_field.field_name).column

                fk_name = create_index_name(
                    connection=connection,
                    table_name=related_opts.db_table,
                    col_names=[fk_field.column],
                    suffix='_fk_%(to_table)s_%(to_column)s' % {
                        'to_table': opts.db_table,
                        'to_column': to_column,
                    })

                fk_sql = editor.sql_create_fk % {
                    'table': qn(related_opts.db_table),
                    'name': qn(fk_name),
                    'column': qn(fk_field.column),
                    'to_table': qn(opts.db_table),
                    'to_column': qn(to_column),
                    'deferrable': connection.ops.deferrable_sql(),
                }

                result.append('%s;' % fk_sql)

        # Everything pending for this model has been emitted.
        del refs[model]

    return result
コード例 #26
0
def create_field(project_sig,
                 field_name,
                 field_type,
                 field_attrs,
                 parent_model,
                 related_model=None):
    """Create a Django field instance for the given signature data.

    This creates a field in a way that's compatible with a variety of versions
    of Django. It takes in data such as the field's name and attributes
    and creates an instance that can be used like any field found on a model.

    Args:
        project_sig (object):
            The project signature. This is used to look up the app and
            model signatures for related models and explicit through
            models, and is passed along to any mock models created.

        field_name (unicode):
            The name of the field.

        field_type (cls):
            The class for the type of field being constructed. This must be a
            subclass of :py:class:`django.db.models.Field`.

        field_attrs (dict):
            Attributes to set on the field. This must not contain a
            ``related_model`` key. The caller's dictionary is not modified.

        parent_model (cls):
            The parent model that would own this field. This must be a
            subclass of :py:class:`django.db.models.Model`.

        related_model (unicode, optional):
            The full class path to a model this relates to, in the form of
            ``app_name.ModelName``. This requires a
            :py:class:`django.db.models.ForeignKey` field type.

    Returns:
        django.db.models.Field:
        A new field instance matching the provided data.
    """
    # Convert to the standard string format for each version of Python, to
    # simulate what the format would be for the default name.
    field_name = str(field_name)

    assert 'related_model' not in field_attrs, \
           ('related_model cannot be passed in field_attrs when calling '
            'create_field(). Pass the related_model parameter instead.')

    if related_model:
        related_app_name, related_model_name = related_model.split('.')
        related_model_sig = (project_sig.get_app_sig(
            related_app_name, required=True).get_model_sig(related_model_name,
                                                           required=True))
        to = MockModel(project_sig=project_sig,
                       app_name=related_app_name,
                       model_name=related_model_name,
                       model_sig=related_model_sig,
                       stub=True)

        if (issubclass(field_type, models.ForeignKey)
                and hasattr(models, 'CASCADE')
                and 'on_delete' not in field_attrs):
            # Starting in Django 2.0, on_delete is a requirement for
            # ForeignKeys. If not provided in the signature, we want to
            # default this to CASCADE, which is the value that Django
            # previously defaulted to.
            #
            # A new dictionary is built so the caller's is left untouched.
            field_attrs = dict({
                'on_delete': models.CASCADE,
            }, **field_attrs)

        field = field_type(to, name=field_name, **field_attrs)
    else:
        field = field_type(name=field_name, **field_attrs)

    if (issubclass(field_type, models.ManyToManyField)
            and parent_model is not None):
        # Starting in Django 1.2, a ManyToManyField must have a through
        # model defined. This will be set internally to an auto-created
        # model if one isn't specified. We have to fake that model.
        through_model = field_attrs.get('through_model')
        through_model_sig = None

        if through_model:
            through_app_name, through_model_name = through_model.split('.')
            through_model_sig = (project_sig.get_app_sig(
                through_app_name).get_model_sig(through_model_name))
        elif hasattr(field, '_get_m2m_attr'):
            # Django >= 1.2
            remote_field = get_remote_field(field)
            remote_field_model = get_remote_field_model(remote_field)

            to_field_name = remote_field_model._meta.object_name.lower()

            if (remote_field_model == RECURSIVE_RELATIONSHIP_CONSTANT or
                    to_field_name == parent_model._meta.object_name.lower()):
                # Both ends point to the same model, so the two columns
                # need distinct names.
                from_field_name = 'from_%s' % to_field_name
                to_field_name = 'to_%s' % to_field_name
            else:
                from_field_name = parent_model._meta.object_name.lower()

            # This corresponds to the signature in
            # related.create_many_to_many_intermediary_model
            #
            # NOTE: This must be a plain string. A trailing comma here
            #       would turn it into a 1-tuple, which would then be
            #       handed to ModelSignature and MockModel as the name.
            through_app_name = parent_model.app_name
            through_model_name = '%s_%s' % (parent_model._meta.object_name,
                                            field.name)

            through_model_sig = ModelSignature(
                model_name=through_model_name,
                table_name=field._get_m2m_db_table(parent_model._meta),
                pk_column='id',
                unique_together=[(from_field_name, to_field_name)])

            # 'id' field
            through_model_sig.add_field_sig(
                FieldSignature(field_name='id',
                               field_type=models.AutoField,
                               field_attrs={
                                   'primary_key': True,
                               }))

            # 'from' field
            through_model_sig.add_field_sig(
                FieldSignature(
                    field_name=from_field_name,
                    field_type=models.ForeignKey,
                    field_attrs={
                        'related_name': '%s+' % through_model_name,
                    },
                    related_model='%s.%s' %
                    (parent_model.app_name, parent_model._meta.object_name)))

            # 'to' field
            through_model_sig.add_field_sig(
                FieldSignature(field_name=to_field_name,
                               field_type=models.ForeignKey,
                               field_attrs={
                                   'related_name': '%s+' % through_model_name,
                               },
                               related_model=related_model))

            field.auto_created = True

        if through_model_sig:
            # An explicitly-specified through model is a regular, unmanaged
            # model here. One we generated ourselves is marked as
            # auto-created and managed.
            through = MockModel(project_sig=project_sig,
                                app_name=through_app_name,
                                model_name=through_model_name,
                                model_sig=through_model_sig,
                                auto_created=not through_model,
                                managed=not through_model)
            get_remote_field(field).through = through

        field.m2m_db_table = curry(field._get_m2m_db_table, parent_model._meta)
        field.set_attributes_from_rel()

    field.set_attributes_from_name(field_name)

    # Needed in Django >= 1.7, for index building.
    field.model = parent_model

    return field
コード例 #27
0
ファイル: utils.py プロジェクト: beanbaginc/django-evolution
def register_models(database_state, models, register_indexes=False,
                    new_app_label='tests', db_name='default', app=evo_test):
    """Register models for testing purposes.

    Each model is renamed and moved to the new app label. This modifies
    the model's ``_meta`` in place (app label, object name, model name and,
    for generated table names, the table name), does the same for the
    through models of any local ManyToManyFields, and swaps the old
    entries in the app/model registry for the new ones.

    Args:
        database_state (django_evolution.db.state.DatabaseState):
            The database state to populate with model information.

        models (list of tuple):
            The models to register. Each entry is a 2-tuple of the new
            object name for the model and the model itself. Entries are
            processed in reverse order.

        register_indexes (bool, optional):
            Whether indexes should be registered for any models. Defaults to
            ``False``.

        new_app_label (str, optional):
            The label for the test app. Defaults to "tests".

        db_name (str, optional):
            The name of the database connection. Defaults to "default".

        app (module, optional):
            The application module for the test models. It will be
            registered under ``new_app_label`` if it isn't already
            registered.

    Returns:
        collections.OrderedDict:
        A dictionary of registered models. The keys are model names
        (lowercased for the provided models), and the values are the
        models. Through models for ManyToManyFields are included.
    """
    app_cache = OrderedDict()
    evolver = EvolutionOperationsMulti(db_name, database_state).get_evolver()

    db_connection = connections[db_name or DEFAULT_DB_ALIAS]
    max_name_length = db_connection.ops.max_name_length()

    for new_object_name, model in reversed(models):
        # Grab some state from the model's meta instance. Some of this will
        # be original state that we'll keep around to help us unregister old
        # values and compute new ones.
        meta = model._meta

        orig_app_label = meta.app_label
        orig_db_table = meta.db_table
        orig_object_name = meta.object_name
        orig_model_name = get_model_name(model)

        # Find out if the table name being used is a custom table name, or
        # one generated by Django.
        new_model_name = new_object_name.lower()
        new_db_table = orig_db_table

        generated_db_table = truncate_name(
            '%s_%s' % (orig_app_label, orig_model_name),
            max_name_length)

        if orig_db_table == generated_db_table:
            # It was a generated one, so replace it with a version containing
            # the new model and app names.
            new_db_table = truncate_name('%s_%s' % (new_app_label,
                                                    new_model_name),
                                         max_name_length)
            meta.db_table = new_db_table

        # Set the new app/model names back on the meta instance.
        meta.app_label = new_app_label
        meta.object_name = new_object_name
        set_model_name(model, new_model_name)

        # Add an entry for the table in the database state, if it's not
        # already there.
        if not database_state.has_table(new_db_table):
            database_state.add_table(new_db_table)

        if register_indexes:
            # Now that we definitely have an entry, store the indexes for
            # all the fields in the database state, so that other operations
            # can look up the index names.
            for field in meta.local_fields:
                if field.db_index or field.unique:
                    new_index_name = create_index_name(
                        db_connection,
                        new_db_table,
                        field_names=[field.name],
                        col_names=[field.column],
                        unique=field.unique)

                    database_state.add_index(
                        index_name=new_index_name,
                        table_name=new_db_table,
                        columns=[field.column],
                        unique=field.unique)

            # Each unique_together entry is registered as a unique index.
            for field_names in meta.unique_together:
                fields = evolver.get_fields_for_names(model, field_names)
                new_index_name = create_index_name(
                    db_connection,
                    new_db_table,
                    field_names=field_names,
                    unique=True)

                database_state.add_index(
                    index_name=new_index_name,
                    table_name=new_db_table,
                    columns=[field.column for field in fields],
                    unique=True)

            for field_names in getattr(meta, 'index_together', []):
                # Django >= 1.5
                fields = evolver.get_fields_for_names(model, field_names)
                new_index_name = create_index_together_name(
                    db_connection,
                    new_db_table,
                    field_names=[field.name for field in fields])

                database_state.add_index(
                    index_name=new_index_name,
                    table_name=new_db_table,
                    columns=[field.column for field in fields])

            if getattr(meta, 'indexes', None):
                # Django >= 1.11
                for index, orig_index in zip(meta.indexes,
                                             meta.original_attrs['indexes']):
                    if not orig_index.name:
                        # The name was auto-generated. We'll need to generate
                        # it again for the new table name.
                        index.set_name_with_model(model)

                    fields = evolver.get_fields_for_names(
                        model, index.fields, allow_sort_prefixes=True)
                    database_state.add_index(
                        index_name=index.name,
                        table_name=new_db_table,
                        columns=[field.column for field in fields])

        # ManyToManyFields have their own tables, which will also need to be
        # renamed. Go through each of them and figure out what changes need
        # to be made.
        for field in meta.local_many_to_many:
            through = get_remote_field(field).through

            if not through:
                continue

            through_meta = through._meta
            through_orig_model_name = get_model_name(through)
            through_new_model_name = through_orig_model_name

            # Find out if the through table name is a custom table name, or
            # one generated by Django.
            generated_db_table = truncate_name(
                '%s_%s' % (orig_db_table, field.name),
                max_name_length)

            if through_meta.db_table == generated_db_table:
                # This is an auto-generated table name. Start changing the
                # state for it.
                assert through_meta.app_label == orig_app_label
                through_meta.app_label = new_app_label

                # Transform the 'through' table information only if we've
                # transformed the parent db_table.
                if new_db_table != orig_db_table:
                    through_meta.db_table = truncate_name(
                        '%s_%s' % (new_db_table, field.name),
                        max_name_length)

                    through_meta.object_name = \
                        through_meta.object_name.replace(orig_object_name,
                                                         new_object_name)

                    through_new_model_name = \
                        through_orig_model_name.replace(orig_model_name,
                                                        new_model_name)
                    set_model_name(through, through_new_model_name)

            # Change each of the columns for the fields on the
            # ManyToManyField's model to reflect the new model names.
            for through_field in through._meta.local_fields:
                through_remote_field = get_remote_field(through_field)

                if (through_remote_field and
                    get_remote_field_model(through_remote_field)):
                    column = through_field.column

                    if (column.startswith((orig_model_name,
                                           'to_%s' % orig_model_name,
                                           'from_%s' % orig_model_name))):
                        # This is a field that references one end of the
                        # relation or another. Update the model name in the
                        # field's column.
                        through_field.column = column.replace(orig_model_name,
                                                              new_model_name)

            # Replace the entry in the models cache for the through table,
            # removing the old name and adding the new one.
            if through_orig_model_name in all_models[orig_app_label]:
                unregister_app_model(orig_app_label, through_orig_model_name)

            app_cache[through_new_model_name] = through
            register_app_models(new_app_label,
                                [(through_new_model_name, through)])

        # Unregister with the old model name and register the new one.
        if orig_model_name in all_models[orig_app_label]:
            unregister_app_model(orig_app_label, orig_model_name)

        register_app_models(new_app_label, [(new_model_name, model)])
        app_cache[new_model_name] = model

    # If the app hasn't yet been registered, do that now.
    if not is_app_registered(app):
        register_app(new_app_label, app)

    return app_cache