예제 #1
0
    def create_table(self, table_name, field_list, temporary=False,
                     create_index=True):
        """Return the SQL statements for creating a table.

        Args:
            table_name (str):
                The name of the table to create.

            field_list (list):
                The fields to turn into columns. Any field whose exact type
                is ``models.ManyToManyField`` is skipped.

            temporary (bool, optional):
                Whether to emit ``CREATE TEMPORARY TABLE``. Columns of a
                temporary table are always ``NULL`` and never receive a
                ``REFERENCES`` clause.

            create_index (bool, optional):
                Whether to append the result of
                ``self.create_indexes_for_table()`` to the returned list.

        Returns:
            list:
            The ``CREATE TABLE`` statement, followed by any index statements.
        """
        quote = self.connection.ops.quote_name

        if temporary:
            header = 'CREATE TEMPORARY TABLE %s' % quote(table_name)
        else:
            header = 'CREATE TABLE %s' % quote(table_name)

        column_defs = []

        for field in field_list:
            if type(field) is models.ManyToManyField:
                # No column is generated for these fields.
                continue

            parts = [
                quote(field.column),
                field.db_type(connection=self.connection),
            ]

            # Always use null if this is a temporary table. It may be used
            # to create a new field (which will be null while data is copied
            # across from the old table).
            parts.append('NULL' if temporary or field.null else 'NOT NULL')

            if field.unique:
                parts.append('UNIQUE')

            if field.primary_key:
                parts.append('PRIMARY KEY')

            if not temporary and isinstance(field, models.ForeignKey):
                rel = get_remote_field(field)
                rel_meta = get_remote_field_model(rel)._meta
                quoted_table = quote(rel_meta.db_table)
                quoted_column = quote(
                    rel_meta.get_field(rel.field_name).column)

                parts.append(
                    'REFERENCES %s (%s) DEFERRABLE INITIALLY DEFERRED'
                    % (quoted_table, quoted_column))

            column_defs.append(' '.join(parts))

        statements = ['%s(%s);' % (header, ', '.join(column_defs))]

        if create_index:
            statements.extend(
                self.create_indexes_for_table(table_name, field_list))

        return statements
예제 #2
0
    def remove_field_constraints(self, field, opts, models, refs):
        """Return SQL for removing constraints that reference a field.

        Nothing is done unless constraints are supported and the field is
        a primary key. Otherwise, the ``through`` models of the local
        many-to-many fields are scanned for relations pointing at the
        field's column on this table.

        Args:
            field (django.db.models.Field):
                The field the constraints will be removed from.

            opts (django.db.models.options.Options):
                The Meta class for the model.

            models (list of django.db.models.Model):
                A caller-provided list. Every matching related model is
                appended to it.

            refs (dict):
                A caller-supplied dictionary that is populated with the
                references found. Keys are models, and values are lists of
                ``(through_model, through_field)`` tuples.

        Returns:
            list:
            The list of SQL statements for removing constraints on the field.
        """
        statements = []

        if not (self.supports_constraints and field.primary_key):
            return statements

        for m2m_field in opts.local_many_to_many:
            rel = get_remote_field(m2m_field)

            if not rel or not rel.through:
                continue

            through_model = rel.through

            for through_field in through_model._meta.local_fields:
                through_rel = get_remote_field(through_field)

                if not through_rel:
                    continue

                target_model = get_remote_field_model(through_rel)
                points_at_field = (
                    through_rel.field_name == field.column and
                    target_model._meta.db_table == opts.db_table)

                if points_at_field:
                    models.append(target_model)
                    refs.setdefault(target_model, []).append(
                        (through_model, through_field))

        # Hand a shallow copy to the constraint deletion helper, so the
        # caller's dictionary keeps the keys recorded above.
        pending_refs = refs.copy()

        for target_model in models:
            statements.extend(
                sql_delete_constraints(self.connection, target_model,
                                       pending_refs))

        return statements
예제 #3
0
def get_model_rel_tree():
    """Return the cached tree of relations pointing at each table.

    On the first call, every non-abstract model returned by
    ``get_models(include_auto_created=True)`` is inspected. Each forward
    relation field found directly on a model is recorded under the table
    name of the concrete model it points to. The result is stored in the
    module-level cache and returned as-is on later calls.

    This can be used to quickly locate any and all reverse relations made
    to a table.

    It is similar to the reverse relation tree Django maintains internally
    in :py:class:`django.db.models.options.Options`, but works across all
    supported versions of Django, and supports cache clearing.

    Version Added:
        2.2

    Returns:
        dict:
        A mapping of table names to lists of relation fields that reference
        the table.
    """
    global _rel_tree_cache

    if _rel_tree_cache is None:
        tree = defaultdict(list)

        for model in get_models(include_auto_created=True):
            if model._meta.abstract:
                continue

            # Only fields defined directly on this model are considered:
            # no inherited, reverse, or hidden fields.
            local_fields = iter_model_fields(model,
                                             include_parent_models=False,
                                             include_forward_fields=True,
                                             include_reverse_fields=False,
                                             include_hidden_fields=False)

            for field in local_fields:
                if not get_field_is_relation(field):
                    continue

                if get_remote_field_related_model(field) is None:
                    continue

                target_model = get_remote_field_model(
                    get_remote_field(field))

                # A string here means a "self" relation or similar, which
                # can't be looked up.
                if isinstance(target_model, six.string_types):
                    continue

                target_table = \
                    target_model._meta.concrete_model._meta.db_table
                tree[target_table].append(field)

        _rel_tree_cache = tree

    return _rel_tree_cache
예제 #4
0
    def rename_table(self, model, old_db_tablename, db_tablename):
        """Return SQL for renaming a table.

        Constraints on many-to-many ``through`` tables that use the old
        table name and point back at the model are dropped before the
        rename and re-added afterward, when constraints are supported.
        The ``db_table`` of each such ``through`` model is updated to the
        new (truncated) name along the way.

        Args:
            model (django.db.models.Model):
                The model whose many-to-many fields are inspected.

            old_db_tablename (str):
                The current table name.

            db_tablename (str):
                The new table name.

        Returns:
            SQLResult:
            The SQL for the rename. This is empty if both names are equal.
        """
        result = SQLResult()

        if old_db_tablename == db_tablename:
            # Nothing to rename.
            return result

        max_name_length = self.connection.ops.max_name_length()

        refs = {}
        related_models = []

        for m2m_field in model._meta.local_many_to_many:
            rel = get_remote_field(m2m_field)

            if not (rel and rel.through and
                    rel.through._meta.db_table == old_db_tablename):
                continue

            through_model = rel.through

            for through_field in through_model._meta.local_fields:
                through_rel = get_remote_field(through_field)

                if not through_rel:
                    continue

                target_model = get_remote_field_model(through_rel)

                if target_model == model:
                    related_models.append(target_model)
                    refs.setdefault(target_model, []).append(
                        (through_model, through_field))

        # The deletion helper gets its own shallow copy, so that ``refs``
        # can still be used for re-adding the constraints below.
        pending_refs = refs.copy()

        if self.supports_constraints:
            for target_model in related_models:
                result.add_pre_sql(
                    sql_delete_constraints(self.connection, target_model,
                                           pending_refs))

        result.add(
            self.get_rename_table_sql(model, old_db_tablename, db_tablename))

        for target_model in related_models:
            for through_model, _through_field in refs[target_model]:
                through_meta = through_model._meta

                if through_meta.db_table == old_db_tablename:
                    through_meta.db_table = db_tablename

                through_meta.db_table = truncate_name(through_meta.db_table,
                                                      max_name_length)

            if self.supports_constraints:
                result.add_post_sql(
                    sql_add_constraints(self.connection, target_model, refs))

        return result
예제 #5
0
    def create_table(self, table_name, field_list, temporary=False,
                     create_index=True):
        """Build the SQL needed to create a table from a list of fields.

        Args:
            table_name (str):
                The name of the new table.

            field_list (list):
                The fields to create columns for. Fields whose exact type is
                ``models.ManyToManyField`` do not produce a column.

            temporary (bool, optional):
                Whether the table is a ``TEMPORARY`` table. If so, every
                column allows ``NULL`` and foreign keys get no
                ``REFERENCES`` clause.

            create_index (bool, optional):
                Whether the statements from
                ``self.create_indexes_for_table()`` are added to the result.

        Returns:
            list:
            A single ``CREATE TABLE`` statement, plus index statements if
            requested.
        """
        qn = self.connection.ops.quote_name

        def build_column(field):
            # Assemble "<name> <type> <constraints...>" for a single field.
            definition = [
                qn(field.column),
                field.db_type(connection=self.connection),
            ]

            # Always use null if this is a temporary table. It may be used
            # to create a new field (which will be null while data is copied
            # across from the old table).
            if temporary or field.null:
                definition.append('NULL')
            else:
                definition.append('NOT NULL')

            if field.unique:
                definition.append('UNIQUE')

            if field.primary_key:
                definition.append('PRIMARY KEY')

            if not temporary and isinstance(field, models.ForeignKey):
                remote = get_remote_field(field)
                remote_meta = get_remote_field_model(remote)._meta
                references = (
                    qn(remote_meta.db_table),
                    qn(remote_meta.get_field(remote.field_name).column),
                )

                definition.append(
                    'REFERENCES %s (%s) DEFERRABLE INITIALLY DEFERRED'
                    % references)

            return ' '.join(definition)

        keywords = ['CREATE', 'TABLE %s' % qn(table_name)]

        if temporary:
            keywords.insert(1, 'TEMPORARY')

        column_sql = ', '.join(
            build_column(field)
            for field in field_list
            if type(field) is not models.ManyToManyField)

        result = [' '.join(keywords) + '(' + column_sql + ');']

        if create_index:
            result += self.create_indexes_for_table(table_name, field_list)

        return result
예제 #6
0
    def _get_rename_column_sql(self, opts, old_field, new_field):
        """Return the ALTER TABLE item for renaming a column.

        The result is a ``CHANGE COLUMN`` clause that restates the full
        definition of the new column. It is preceded by
        ``DROP PRIMARY KEY`` if the old field was a primary key.

        Args:
            opts (django.db.models.options.Options):
                The model's options, used as a fallback for the tablespace.

            old_field (django.db.models.Field):
                The field as it currently exists.

            new_field (django.db.models.Field):
                The field describing the renamed column.

        Returns:
            list of dict:
            A single-item list containing a dictionary with a ``sql`` key.
        """
        quote = self.connection.ops.quote_name
        style = color.no_style()
        column_type = new_field.db_type(connection=self.connection)
        tablespace = new_field.db_tablespace or opts.db_tablespace

        # Make the definition (e.g. 'foo VARCHAR(30)') for this field.
        null_keyword = 'NULL' if new_field.null else 'NOT NULL'
        definition = [
            style.SQL_FIELD(quote(new_field.column)),
            style.SQL_COLTYPE(column_type),
            style.SQL_KEYWORD(null_keyword),
        ]

        if new_field.primary_key:
            definition.append(style.SQL_KEYWORD('PRIMARY KEY'))

        if new_field.unique:
            definition.append(style.SQL_KEYWORD('UNIQUE'))

        if (tablespace and
                self.connection.features.supports_tablespaces and
                self.connection.features.autoindexes_primary_keys and
                (new_field.unique or new_field.primary_key)):
            # We must specify the index tablespace inline, because we
            # won't be generating a CREATE INDEX statement for this field.
            definition.append(
                self.connection.ops.tablespace_sql(tablespace, inline=True))

        remote = get_remote_field(new_field)

        if remote:
            remote_meta = get_remote_field_model(remote)._meta
            referenced_table = style.SQL_TABLE(quote(remote_meta.db_table))
            referenced_column = style.SQL_FIELD(
                quote(remote_meta.get_field(remote.field_name).column))

            definition.append('%s %s (%s)%s' % (
                style.SQL_KEYWORD('REFERENCES'),
                referenced_table,
                referenced_column,
                self.connection.ops.deferrable_sql(),
            ))

        prefix = 'DROP PRIMARY KEY, ' if old_field.primary_key else ''
        change_sql = 'CHANGE COLUMN %s %s' % (quote(old_field.column),
                                              ' '.join(definition))

        return [{'sql': prefix + change_sql}]
예제 #7
0
    def delete_column(self, model, f):
        """Return SQL for deleting a column from a table.

        If the field is a relation, the SQL for deleting the constraints
        referencing the related model is added first. The parent class's
        column deletion SQL follows.

        Args:
            model (django.db.models.Model):
                The model owning the column.

            f (django.db.models.Field):
                The field whose column is being deleted.

        Returns:
            AlterTableSQLResult:
            The SQL for deleting the column.
        """
        result = AlterTableSQLResult(self, model)
        rel = get_remote_field(f)

        if rel:
            target_model = get_remote_field_model(rel)
            constraint_refs = {
                target_model: [(model, f)],
            }

            result.add(sql_delete_constraints(self.connection, target_model,
                                              constraint_refs))

        column_sql = super(EvolutionOperations, self).delete_column(model, f)
        result.add_sql(column_sql)

        return result
def create_field(project_sig,
                 field_name,
                 field_type,
                 field_attrs,
                 parent_model,
                 related_model=None):
    """Create a Django field instance for the given signature data.

    This creates a field in a way that's compatible with a variety of versions
    of Django. It takes in data such as the field's name and attributes
    and creates an instance that can be used like any field found on a model.

    For a :py:class:`django.db.models.ManyToManyField` with a parent model,
    a mock ``through`` model is also attached to the field's relation. It is
    either looked up from the ``through_model`` attribute, or generated to
    stand in for the intermediary model Django would auto-create.

    Args:
        project_sig (object):
            The project signature used to look up app and model signatures
            (through ``get_app_sig()``) for related and through models.

        field_name (unicode):
            The name of the field.

        field_type (cls):
            The class for the type of field being constructed. This must be a
            subclass of :py:class:`django.db.models.Field`.

        field_attrs (dict):
            Attributes to set on the field. This must not contain a
            ``related_model`` key.

        parent_model (cls):
            The parent model that would own this field. This must be a
            subclass of :py:class:`django.db.models.Model`.

        related_model (unicode, optional):
            The full class path to a model this relates to, in
            ``app_name.ModelName`` form. This requires a relation field
            type, such as :py:class:`django.db.models.ForeignKey`.

    Returns:
        django.db.models.Field:
        A new field instance matching the provided data.

    Raises:
        AssertionError:
            ``related_model`` was provided in ``field_attrs``.
    """
    # Convert to the standard string format for each version of Python, to
    # simulate what the format would be for the default name.
    field_name = str(field_name)

    assert 'related_model' not in field_attrs, \
           ('related_model cannot be passed in field_attrs when calling '
            'create_field(). Pass the related_model parameter instead.')

    if related_model:
        related_app_name, related_model_name = related_model.split('.')
        related_model_sig = (project_sig.get_app_sig(
            related_app_name, required=True).get_model_sig(related_model_name,
                                                           required=True))
        to = MockModel(project_sig=project_sig,
                       app_name=related_app_name,
                       model_name=related_model_name,
                       model_sig=related_model_sig,
                       stub=True)

        if (issubclass(field_type, models.ForeignKey)
                and hasattr(models, 'CASCADE')
                and 'on_delete' not in field_attrs):
            # Starting in Django 2.0, on_delete is a requirement for
            # ForeignKeys. If not provided in the signature, we want to
            # default this to CASCADE, which is the value that Django
            # previously defaulted to.
            #
            # A new dictionary is built so the caller's field_attrs is not
            # modified.
            field_attrs = dict({
                'on_delete': models.CASCADE,
            }, **field_attrs)

        field = field_type(to, name=field_name, **field_attrs)
    else:
        field = field_type(name=field_name, **field_attrs)

    if (issubclass(field_type, models.ManyToManyField)
            and parent_model is not None):
        # Starting in Django 1.2, a ManyToManyField must have a through
        # model defined. This will be set internally to an auto-created
        # model if one isn't specified. We have to fake that model.
        through_model = field_attrs.get('through_model')
        through_model_sig = None

        if through_model:
            through_app_name, through_model_name = through_model.split('.')
            through_model_sig = (project_sig.get_app_sig(
                through_app_name).get_model_sig(through_model_name))
        elif hasattr(field, '_get_m2m_attr'):
            # Django >= 1.2
            remote_field = get_remote_field(field)
            remote_field_model = get_remote_field_model(remote_field)

            to_field_name = remote_field_model._meta.object_name.lower()

            if (remote_field_model == RECURSIVE_RELATIONSHIP_CONSTANT or
                    to_field_name == parent_model._meta.object_name.lower()):
                from_field_name = 'from_%s' % to_field_name
                to_field_name = 'to_%s' % to_field_name
            else:
                from_field_name = parent_model._meta.object_name.lower()

            # This corresponds to the signature in
            # related.create_many_to_many_intermediary_model
            #
            # Note that the name must be a plain string. A trailing comma
            # after the formatting expression would turn it into a 1-tuple,
            # which would then be used as the model name below.
            through_app_name = parent_model.app_name
            through_model_name = '%s_%s' % (parent_model._meta.object_name,
                                            field.name)

            through_model_sig = ModelSignature(
                model_name=through_model_name,
                table_name=field._get_m2m_db_table(parent_model._meta),
                pk_column='id',
                unique_together=[(from_field_name, to_field_name)])

            # 'id' field
            through_model_sig.add_field_sig(
                FieldSignature(field_name='id',
                               field_type=models.AutoField,
                               field_attrs={
                                   'primary_key': True,
                               }))

            # 'from' field
            through_model_sig.add_field_sig(
                FieldSignature(
                    field_name=from_field_name,
                    field_type=models.ForeignKey,
                    field_attrs={
                        'related_name': '%s+' % through_model_name,
                    },
                    related_model='%s.%s' %
                    (parent_model.app_name, parent_model._meta.object_name)))

            # 'to' field
            through_model_sig.add_field_sig(
                FieldSignature(field_name=to_field_name,
                               field_type=models.ForeignKey,
                               field_attrs={
                                   'related_name': '%s+' % through_model_name,
                               },
                               related_model=related_model))

            field.auto_created = True

        if through_model_sig:
            # A generated through model (no explicit through_model) is
            # flagged as auto-created and managed. An explicit one is not.
            through = MockModel(project_sig=project_sig,
                                app_name=through_app_name,
                                model_name=through_model_name,
                                model_sig=through_model_sig,
                                auto_created=not through_model,
                                managed=not through_model)
            get_remote_field(field).through = through

        field.m2m_db_table = curry(field._get_m2m_db_table, parent_model._meta)
        field.set_attributes_from_rel()

    field.set_attributes_from_name(field_name)

    # Needed in Django >= 1.7, for index building.
    field.model = parent_model

    return field
예제 #9
0
    def add_column(self, model, f, initial):
        """Return SQL for adding a column to a table.

        The generated SQL depends on the field and the initial value:

        * Relation fields get a ``REFERENCES`` clause pointing at the
          related model's table.
        * With no initial value, the column is simply added.
        * With a callable initial value, the column is added, populated
          with an ``UPDATE`` using the callable's result, and then made
          non-null if the field requires it.
        * With any other initial value, the column is added with a
          ``DEFAULT``, which is dropped again afterward.

        If the field is unique or a primary key, a unique index is also
        recorded in the database state.

        Args:
            model (django.db.models.Model):
                The model owning the new column.

            f (django.db.models.Field):
                The field to add a column for.

            initial (object):
                The initial value for existing rows. This may be ``None``,
                a callable, or a plain value.

        Returns:
            AlterTableSQLResult:
            The SQL for adding the column.
        """
        quote = self.connection.ops.quote_name
        result = AlterTableSQLResult(self, model)
        table_name = model._meta.db_table

        rel = get_remote_field(f)
        column_type = f.db_type(connection=self.connection)
        null_sql = 'NULL' if f.null else 'NOT NULL'

        if rel:
            # This is a foreign key field. The result will look like:
            #
            # ALTER TABLE <tablename> ADD COLUMN <column name> NULL
            # REFERENCES <tablename1> ("<colname>") DEFERRABLE INITIALLY
            # DEFERRED
            target_meta = get_remote_field_model(rel)._meta
            target_table = target_meta.db_table

            # NOTE(review): This quotes the primary key's field name, not
            #               its column name.
            target_pk = target_meta.pk.name

            params = [null_sql]

            if f.unique or f.primary_key:
                params.append('UNIQUE')

            params += [
                'REFERENCES',
                quote(target_table),
                '(%s)' % quote(target_pk),
                self.connection.ops.deferrable_sql(),
            ]

            result.add_alter_table([{
                'op': 'ADD COLUMN',
                'column': f.column,
                'db_type': column_type,
                'params': params,
            }])
        else:
            unique_sql = 'UNIQUE' if f.unique or f.primary_key else ''

            # At this point, initial can only be None if null=True,
            # otherwise it is a user callable or the default
            # AddFieldInitialCallback which will shortly raise an exception.
            if initial is None:
                result.add_alter_table([{
                    'op': 'ADD COLUMN',
                    'column': f.column,
                    'db_type': column_type,
                    'params': [null_sql, unique_sql],
                }])
            elif callable(initial):
                result.add_alter_table([{
                    'op': 'ADD COLUMN',
                    'column': f.column,
                    'db_type': column_type,
                    'params': [unique_sql],
                }])

                # The callable's result is placed directly in the SQL.
                quoted_table = quote(table_name)
                quoted_column = quote(f.column)
                populate_sql = (
                    'UPDATE %s SET %s = %s WHERE %s IS NULL;' %
                    (quoted_table, quoted_column, initial(), quoted_column))
                result.add_sql([populate_sql])

                if not f.null:
                    # Only put this sql statement if the column cannot
                    # be null.
                    result.add_sql(self.set_field_null(model, f, f.null))
            else:
                result.add_alter_table([{
                    'op': 'ADD COLUMN',
                    'column': f.column,
                    'db_type': column_type,
                    'params': [null_sql, unique_sql, 'DEFAULT', '%s'],
                    'sql_params': [initial],
                }])

                # Django doesn't generate default columns, so now that
                # we've added one to get default values for existing
                # tables, drop that default.
                drop_default_sql = (
                    'ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT;' %
                    (quote(table_name), quote(f.column)))
                result.add_post_sql([drop_default_sql])

        if f.unique or f.primary_key:
            index_name = self.get_new_constraint_name(table_name, f.column)
            self.database_state.add_index(table_name=table_name,
                                          index_name=index_name,
                                          columns=[f.column],
                                          unique=True)

        return result
예제 #10
0
    def test_with_foreign_key(self):
        """Testing get_remote_field_model with ForeignKey"""
        # The field itself is passed here, rather than its relation.
        field = CompatModelsTestModel._meta.get_field('fkey_field')

        self.assertIs(get_remote_field_model(field), CompatModelsTestModel)
예제 #11
0
    def test_with_one_to_one_field_rel(self):
        """Testing get_remote_field_model with OneToOneField relation"""
        # The field's relation is passed here, rather than the field.
        rel = get_remote_field(
            CompatModelsTestModel._meta.get_field('o2o_field'))

        self.assertIs(get_remote_field_model(rel), CompatModelsAnchor)
예제 #12
0
    def test_with_one_to_one_field(self):
        """Testing get_remote_field_model with OneToOneField"""
        # The field itself is passed here, rather than its relation.
        field = CompatModelsTestModel._meta.get_field('o2o_field')

        self.assertIs(get_remote_field_model(field), CompatModelsTestModel)
예제 #13
0
    def test_with_many_to_many_field_rel(self):
        """Testing get_remote_field_model with ManyToManyField relation"""
        # The field's relation is passed here, rather than the field.
        rel = get_remote_field(
            CompatModelsTestModel._meta.get_field('m2m_field'))

        self.assertIs(get_remote_field_model(rel), CompatModelsAnchor)
예제 #14
0
    def test_with_many_to_many_field(self):
        """Testing get_remote_field_model with ManyToManyField"""
        # The field itself is passed here, rather than its relation.
        field = CompatModelsTestModel._meta.get_field('m2m_field')

        self.assertIs(get_remote_field_model(field), CompatModelsTestModel)
예제 #15
0
    def test_with_foreign_key_rel(self):
        """Testing get_remote_field_model with ForeignKey relation"""
        # The field's relation is passed here, rather than the field.
        rel = get_remote_field(
            CompatModelsTestModel._meta.get_field('fkey_field'))

        self.assertIs(get_remote_field_model(rel), CompatModelsAnchor)
예제 #16
0
def register_models(database_state, models, register_indexes=False,
                    new_app_label='tests', db_name='default', app=evo_test):
    """Register models for testing purposes.

    Each model is re-labeled in place: its meta's app label, object name,
    model name, and (if it was an auto-generated name) table name are
    rewritten to reflect the new app label and the new object name. The
    auto-generated ``through`` models of any local ManyToManyFields are
    updated in the same way. The models are then registered under the new
    app label, replacing any registration under the old names.

    Args:
        database_state (django_evolution.db.state.DatabaseState):
            The database state to populate with model information.

        models (list of tuple):
            The models to register. Each entry is a 2-tuple of
            ``(new_object_name, model)``. Entries are processed in reverse
            order.

        register_indexes (bool, optional):
            Whether indexes should be registered for any models. Defaults to
            ``False``.

        new_app_label (str, optional):
            The label for the test app. Defaults to "tests".

        db_name (str, optional):
            The name of the database connection. Defaults to "default".

        app (module, optional):
            The application module for the test models. It is registered
            under ``new_app_label`` if it isn't already registered.

    Returns:
        collections.OrderedDict:
        A dictionary of registered models. The keys are the new lowercase
        model names (including those of any ``through`` models), and the
        values are the models.
    """
    app_cache = OrderedDict()
    evolver = EvolutionOperationsMulti(db_name, database_state).get_evolver()

    db_connection = connections[db_name or DEFAULT_DB_ALIAS]
    max_name_length = db_connection.ops.max_name_length()

    for new_object_name, model in reversed(models):
        # Grab some state from the model's meta instance. Some of this will
        # be original state that we'll keep around to help us unregister old
        # values and compute new ones.
        meta = model._meta

        orig_app_label = meta.app_label
        orig_db_table = meta.db_table
        orig_object_name = meta.object_name
        orig_model_name = get_model_name(model)

        # Find out if the table name being used is a custom table name, or
        # one generated by Django.
        new_model_name = new_object_name.lower()
        new_db_table = orig_db_table

        generated_db_table = truncate_name(
            '%s_%s' % (orig_app_label, orig_model_name),
            max_name_length)

        if orig_db_table == generated_db_table:
            # It was a generated one, so replace it with a version containing
            # the new model and app names.
            new_db_table = truncate_name('%s_%s' % (new_app_label,
                                                    new_model_name),
                                         max_name_length)
            meta.db_table = new_db_table

        # Set the new app/model names back on the meta instance.
        meta.app_label = new_app_label
        meta.object_name = new_object_name
        set_model_name(model, new_model_name)

        # Add an entry for the table in the database state, if it's not
        # already there.
        if not database_state.has_table(new_db_table):
            database_state.add_table(new_db_table)

        if register_indexes:
            # Now that we definitely have an entry, store the indexes for
            # all the fields in the database state, so that other operations
            # can look up the index names.
            for field in meta.local_fields:
                if field.db_index or field.unique:
                    new_index_name = create_index_name(
                        db_connection,
                        new_db_table,
                        field_names=[field.name],
                        col_names=[field.column],
                        unique=field.unique)

                    database_state.add_index(
                        index_name=new_index_name,
                        table_name=new_db_table,
                        columns=[field.column],
                        unique=field.unique)

            # Each unique_together entry is stored as a unique index.
            for field_names in meta.unique_together:
                fields = evolver.get_fields_for_names(model, field_names)
                new_index_name = create_index_name(
                    db_connection,
                    new_db_table,
                    field_names=field_names,
                    unique=True)

                database_state.add_index(
                    index_name=new_index_name,
                    table_name=new_db_table,
                    columns=[field.column for field in fields],
                    unique=True)

            for field_names in getattr(meta, 'index_together', []):
                # Django >= 1.5
                fields = evolver.get_fields_for_names(model, field_names)
                new_index_name = create_index_together_name(
                    db_connection,
                    new_db_table,
                    field_names=[field.name for field in fields])

                database_state.add_index(
                    index_name=new_index_name,
                    table_name=new_db_table,
                    columns=[field.column for field in fields])

            if getattr(meta, 'indexes', None):
                # Django >= 1.11
                for index, orig_index in zip(meta.indexes,
                                             meta.original_attrs['indexes']):
                    if not orig_index.name:
                        # The name was auto-generated. We'll need to generate
                        # it again for the new table name.
                        index.set_name_with_model(model)

                    fields = evolver.get_fields_for_names(
                        model, index.fields, allow_sort_prefixes=True)
                    database_state.add_index(
                        index_name=index.name,
                        table_name=new_db_table,
                        columns=[field.column for field in fields])

        # ManyToManyFields have their own tables, which will also need to be
        # renamed. Go through each of them and figure out what changes need
        # to be made.
        for field in meta.local_many_to_many:
            through = get_remote_field(field).through

            if not through:
                continue

            through_meta = through._meta
            through_orig_model_name = get_model_name(through)
            through_new_model_name = through_orig_model_name

            # Find out if the through table name is a custom table name, or
            # one generated by Django.
            generated_db_table = truncate_name(
                '%s_%s' % (orig_db_table, field.name),
                max_name_length)

            if through_meta.db_table == generated_db_table:
                # This is an auto-generated table name. Start changing the
                # state for it.
                assert through_meta.app_label == orig_app_label
                through_meta.app_label = new_app_label

                # Transform the 'through' table information only if we've
                # transformed the parent db_table.
                if new_db_table != orig_db_table:
                    through_meta.db_table = truncate_name(
                        '%s_%s' % (new_db_table, field.name),
                        max_name_length)

                    through_meta.object_name = \
                        through_meta.object_name.replace(orig_object_name,
                                                         new_object_name)

                    through_new_model_name = \
                        through_orig_model_name.replace(orig_model_name,
                                                        new_model_name)
                    set_model_name(through, through_new_model_name)

            # Change each of the columns for the fields on the
            # ManyToManyField's model to reflect the new model names.
            for through_field in through._meta.local_fields:
                through_remote_field = get_remote_field(through_field)

                if (through_remote_field and
                    get_remote_field_model(through_remote_field)):
                    column = through_field.column

                    if (column.startswith((orig_model_name,
                                           'to_%s' % orig_model_name,
                                           'from_%s' % orig_model_name))):
                        # This is a field that references one end of the
                        # relation or another. Update the model name in the
                        # field's column.
                        through_field.column = column.replace(orig_model_name,
                                                              new_model_name)

            # Replace the entry in the models cache for the through table,
            # removing the old name and adding the new one.
            if through_orig_model_name in all_models[orig_app_label]:
                unregister_app_model(orig_app_label, through_orig_model_name)

            app_cache[through_new_model_name] = through
            register_app_models(new_app_label,
                                [(through_new_model_name, through)])

        # Unregister with the old model name and register the new one.
        if orig_model_name in all_models[orig_app_label]:
            unregister_app_model(orig_app_label, orig_model_name)

        register_app_models(new_app_label, [(new_model_name, model)])
        app_cache[new_model_name] = model

    # If the app hasn't yet been registered, do that now.
    if not is_app_registered(app):
        register_app(new_app_label, app)

    return app_cache
예제 #17
0
def register_models(database_state, models, register_indexes=False,
                    new_app_label='tests', db_name='default', app=evo_test):
    """Re-home models under a test app and record their database state.

    Each model is relabeled in place: the app label, object name, and model
    name on its meta are replaced, and its table name is regenerated when
    the current one matches the name that would be generated from the
    original app label and model name. Through models for the model's local
    many-to-many fields are processed as well. The models are then swapped
    in the app registry, and their tables (and, optionally, their indexes)
    are recorded in the database state.

    Args:
        database_state (django_evolution.db.state.DatabaseState):
            The database state that receives the table and index entries.

        models (list of tuple):
            The models to register. Each item is a 2-tuple of the new
            object name and the model. Items are processed in reverse
            order.

        register_indexes (bool, optional):
            Whether index entries should be stored in the database state
            for each model. Defaults to ``False``.

        new_app_label (str, optional):
            The app label to assign to the models. Defaults to "tests".

        db_name (str, optional):
            The name of the database connection. Defaults to "default".

        app (module, optional):
            The application module, registered under ``new_app_label`` if
            it isn't registered already.

    Returns:
        collections.OrderedDict:
        The models that were registered. Keys are the new model names
        (including those of any through models), and values are the models.
    """
    registered = OrderedDict()
    evolver = EvolutionOperationsMulti(db_name, database_state).get_evolver()

    connection = connections[db_name or DEFAULT_DB_ALIAS]
    name_limit = connection.ops.max_name_length()

    def fit_name(prefix, suffix):
        # Build a "<prefix>_<suffix>" name, truncated for this database.
        return truncate_name('%s_%s' % (prefix, suffix), name_limit)

    for new_object_name, model in reversed(models):
        meta = model._meta

        # Snapshot the original naming state before anything is changed.
        # It's needed later for unregistering and for renaming related
        # through models.
        old_app_label = meta.app_label
        old_db_table = meta.db_table
        old_object_name = meta.object_name
        old_model_name = get_model_name(model)

        new_model_name = new_object_name.lower()

        # Only a table name matching the generated form gets rebuilt from
        # the new app and model names. A custom table name is kept as-is.
        if old_db_table == fit_name(old_app_label, old_model_name):
            new_db_table = fit_name(new_app_label, new_model_name)
            meta.db_table = new_db_table
        else:
            new_db_table = old_db_table

        # Apply the new identity to the model.
        meta.app_label = new_app_label
        meta.object_name = new_object_name
        set_model_name(model, new_model_name)

        # Make sure the database state knows about the table.
        if not database_state.has_table(new_db_table):
            database_state.add_table(new_db_table)

        if register_indexes:
            # Record an index for every indexed or unique field, so that
            # other operations can look up the index names.
            for field in meta.local_fields:
                if not (field.db_index or field.unique):
                    continue

                field_index_name = create_index_name(
                    connection,
                    new_db_table,
                    field_names=[field.name],
                    col_names=[field.column],
                    unique=field.unique)

                database_state.add_index(
                    index_name=field_index_name,
                    table_name=new_db_table,
                    columns=[field.column],
                    unique=field.unique)

            # Record the unique_together indexes.
            for together_names in meta.unique_together:
                together_fields = evolver.get_fields_for_names(
                    model, together_names)
                together_index_name = create_index_name(
                    connection,
                    new_db_table,
                    field_names=together_names,
                    unique=True)

                database_state.add_index(
                    index_name=together_index_name,
                    table_name=new_db_table,
                    columns=[f.column for f in together_fields],
                    unique=True)

            # Record the index_together indexes (Django >= 1.5).
            for together_names in getattr(meta, 'index_together', []):
                together_fields = evolver.get_fields_for_names(
                    model, together_names)
                together_index_name = create_index_together_name(
                    connection,
                    new_db_table,
                    field_names=[f.name for f in together_fields])

                database_state.add_index(
                    index_name=together_index_name,
                    table_name=new_db_table,
                    columns=[f.column for f in together_fields])

            # Record the Meta.indexes entries (Django >= 1.11).
            meta_indexes = getattr(meta, 'indexes', None)

            if meta_indexes:
                index_pairs = zip(meta_indexes,
                                  meta.original_attrs['indexes'])

                for index, declared_index in index_pairs:
                    if not declared_index.name:
                        # No explicit name was declared, so the current
                        # name was generated. Generate it again now that
                        # the model has been renamed.
                        index.set_name_with_model(model)

                    index_fields = evolver.get_fields_for_names(
                        model, index.fields, allow_sort_prefixes=True)
                    database_state.add_index(
                        index_name=index.name,
                        table_name=new_db_table,
                        columns=[f.column for f in index_fields])

        # Column prefixes that mark a through-model column as pointing at
        # this model under its old name.
        old_column_prefixes = (old_model_name,
                               'to_%s' % old_model_name,
                               'from_%s' % old_model_name)

        # ManyToManyFields are backed by their own tables, which need the
        # same renaming treatment as the model itself.
        for m2m_field in meta.local_many_to_many:
            through = get_remote_field(m2m_field).through

            if not through:
                continue

            through_meta = through._meta
            through_old_name = get_model_name(through)
            through_new_name = through_old_name

            # A table name matching the generated form is treated as
            # auto-generated and gets moved over to the new app.
            if through_meta.db_table == fit_name(old_db_table,
                                                 m2m_field.name):
                assert through_meta.app_label == old_app_label
                through_meta.app_label = new_app_label

                # The through model is only renamed if the parent's table
                # was itself renamed.
                if new_db_table != old_db_table:
                    through_meta.db_table = fit_name(new_db_table,
                                                     m2m_field.name)

                    through_meta.object_name = \
                        through_meta.object_name.replace(old_object_name,
                                                         new_object_name)

                    through_new_name = through_old_name.replace(
                        old_model_name, new_model_name)
                    set_model_name(through, through_new_name)

            # Update the columns of any relation fields on the through
            # model that still carry the old model name.
            for through_field in through_meta.local_fields:
                relation = get_remote_field(through_field)

                if not relation or not get_remote_field_model(relation):
                    continue

                column = through_field.column

                if column.startswith(old_column_prefixes):
                    through_field.column = column.replace(old_model_name,
                                                          new_model_name)

            # Swap the through model's registration from the old name to
            # the new one.
            if through_old_name in all_models[old_app_label]:
                unregister_app_model(old_app_label, through_old_name)

            registered[through_new_name] = through
            register_app_models(new_app_label,
                                [(through_new_name, through)])

        # Swap the model's own registration from the old name to the new
        # one.
        if old_model_name in all_models[old_app_label]:
            unregister_app_model(old_app_label, old_model_name)

        register_app_models(new_app_label, [(new_model_name, model)])
        registered[new_model_name] = model

    # Register the app itself if that hasn't happened yet.
    if not is_app_registered(app):
        register_app(new_app_label, app)

    return registered