Exemplo n.º 1
0
    def get_hint_params(self):
        """Return parameters for the mutation's hinted evolution.

        The new value is normalized before serialization:

        * ``index_together`` / ``unique_together`` values are always
          converted to lists (never tuples), for compatibility.
        * ``constraints`` (Django >= 2.2) and ``indexes`` (Django >= 1.11)
          values are lists of dictionaries. Each dictionary is converted to
          an :py:class:`~collections.OrderedDict` with its keys sorted, so
          the generated hint is stable.
        * Anything else is used as-is.

        Returns:
            list of unicode:
            A list of parameter strings to pass to the mutation's constructor
            in a hinted evolution.
        """
        prop_name = self.prop_name
        new_value = self.new_value

        if prop_name in ('constraints', 'indexes'):
            norm_value = [
                OrderedDict(sorted(item_data.items(),
                                   key=lambda pair: pair[0]))
                for item_data in new_value
            ]
        elif prop_name in ('index_together', 'unique_together'):
            norm_value = list(new_value)
        else:
            norm_value = new_value

        serialize = self.serialize_value

        return [
            serialize(self.model_name),
            serialize(prop_name),
            serialize(norm_value),
        ]
Exemplo n.º 2
0
def replace_models(database_state, apps_to_models):
    """Temporarily replace existing models with new definitions.

    Intended for unit tests. For each app label, the models currently
    returned by ``get_models()`` are remembered before the replacement
    models are registered. When the managed block finishes (normally or via
    an exception), the test models are unregistered and the remembered
    models are registered again with ``reset=True``.

    Args:
        database_state (django_evolution.db.state.DatabaseState):
            The currently-computed database state.

        apps_to_models (dict):
            A mapping of app labels to lists of tuples. Each tuple contains:

            1. The model class name to register (which can be different from
               the real class name).
            2. The model class.

    Context:
        The replaced models will be available in Django.
    """
    saved_models = []

    for label, replacements in six.iteritems(apps_to_models):
        app_module = get_app(label)
        saved_models.append((label, get_models(app_module)))

        # Needed in Django 1.6 to ensure the model isn't filtered out in
        # our own get_models() calls.
        for _registered_name, model_cls in replacements:
            model_cls.__module__ = app_module.__name__

    try:
        for label, replacements in six.iteritems(apps_to_models):
            register_models(database_state=database_state,
                            models=replacements,
                            new_app_label=label)

        yield
    finally:
        unregister_test_models()

        # Put back whatever was registered before we started.
        for label, previous_models in saved_models:
            registrations = [
                (model_cls._meta.object_name.lower(), model_cls)
                for model_cls in previous_models
            ]

            register_app_models(label, registrations, reset=True)
Exemplo n.º 3
0
    def __iter__(self):
        """Iterate through the list.

        Entries are grouped by app label, with the app labels visited in
        sorted (alphabetical) order. Within an app label, entries are yielded
        in the order in which migrations were added for that app label.

        Yields:
            info:
            A dictionary containing the following keys:

            ``app_label`` (:py:class:`unicode`):
                The app label for the migration.

            ``name`` (:py:class:`unicode`):
                The name of the migration.

            ``migration`` (:py:class:`django.db.migrations.Migration`):
                The optional migration instance.

            ``recorded_migration`` (:py:class:`django.db.migrations.recorder.MigrationRecorder.Migration`):
                The optional recorded migration.
        """
        by_app_label = self._by_app_label

        # Capture the per-app lists, in app label order, before yielding
        # anything.
        ordered_info_lists = [
            by_app_label[app_label]
            for app_label in sorted(by_app_label)
        ]

        for info_list in ordered_info_lists:
            for info in info_list:
                yield info
Exemplo n.º 4
0
    def serialize_to_python(cls, value):
        """Serialize an object to a Python code string.

        The object is deconstructed into a class path plus constructor
        arguments, and rendered as a constructor call. Classes whose path
        starts with ``django.db.models`` are referenced as ``models.<Name>``.
        Positional arguments come first, followed by keyword arguments
        sorted by name.

        Args:
            value (object):
                The object to serialize.

        Returns:
            unicode:
            The resulting Python code.
        """
        cls_path, args, kwargs = cls._deconstruct_object(value)

        # Only the final path component is used in the generated code. The
        # unpacking also requires the path to contain at least one '.'.
        _module_path, cls_name = cls_path.rsplit('.', 1)

        if cls_path.startswith('django.db.models'):
            cls_name = 'models.%s' % cls_name

        rendered_args = []

        if args:
            for arg in args:
                rendered_args.append(serialize_to_python(arg))

        if kwargs:
            for key, kwarg_value in sorted(six.iteritems(kwargs),
                                           key=lambda pair: pair[0]):
                rendered_args.append(
                    '%s=%s' % (key, serialize_to_python(kwarg_value)))

        return '%s(%s)' % (cls_name, ', '.join(rendered_args))
Exemplo n.º 5
0
    def rescan_tables(self):
        """Rescan the list of tables from the database.

        This will look up all tables found in the database, along with
        information (such as indexes) on those tables.

        For each table found in the database, any index information already
        tracked is cleared and re-read. Tables not yet tracked are added.
        """
        evolver = EvolutionOperationsMulti(self.db_name).get_evolver()
        connection = evolver.connection
        introspection = connection.introspection

        # The cursor is only needed to list the tables. Materialize the list
        # and close the cursor right away, rather than leaving it open.
        cursor = connection.cursor()

        try:
            table_names = list(introspection.get_table_list(cursor))
        finally:
            cursor.close()

        for table_name in table_names:
            # NOTE: The table names are already normalized, so there's no
            #       need to normalize them again.
            if hasattr(table_name, 'name'):
                # In Django >= 1.7, we get back TableInfo namedtuples,
                # which have 'name' and 'type' keys. We don't care about
                # anything but 'name'.
                table_name = table_name.name

            if self.has_table(table_name):
                self.clear_indexes(table_name)
            else:
                self.add_table(table_name)

            constraints = evolver.get_constraints_for_table(table_name)

            for constraint_name, constraint_info in six.iteritems(constraints):
                self.add_index(table_name=table_name,
                               index_name=constraint_name,
                               columns=constraint_info['columns'],
                               unique=constraint_info['unique'])
Exemplo n.º 6
0
    def _get_changed_field_attrs(self, old_field_sig):
        """Return the attributes whose values differ from the old signature.

        Each attribute provided to the mutation is compared against the
        value stored in the old field signature. Only attributes that differ
        are included.

        Version Added:
            2.2

        Args:
            old_field_sig (django_evolution.signature.FieldSignature):
                The signature of the old field, before any changes are
                applied.

        Returns:
            dict:
            A mapping of attribute names to a field change dictionary with
            the following keys:

            * ``old_value``: The value in the old field signature.
            * ``new_value``: The new value provided to the mutation.
        """
        changes = {}

        for attr_name, new_value in six.iteritems(self.field_attrs):
            old_value = old_field_sig.get_attr_value(attr_name)

            # Skip unchanged attributes, so no useless SQL is generated.
            if old_value != new_value:
                changes[attr_name] = dict(old_value=old_value,
                                          new_value=new_value)

        return changes
Exemplo n.º 7
0
    def _deserialize_deconstructed(cls, payload):
        """Deserialize a deconstructed object payload.

        The payload's ``type`` is a dotted path to a class, which is imported.
        Its ``args`` and ``kwargs`` entries are each deserialized through
        ``deserialize_from_signature()``.

        Args:
            payload (dict):
                The payload representing a deconstructed object. It must
                contain ``type``, ``args``, and ``kwargs`` keys.

        Returns:
            tuple:
            A tuple containing:

            1. The object class.
            2. Positional arguments to pass to the constructor.
            3. Keyword arguments to pass to the constructor.

        Raises:
            ImportError:
                The module or class referenced by ``type`` could not be
                found.
        """
        type_path = payload['type']
        module_name, attr_name = type_path.rsplit('.', 1)

        try:
            module = import_module(module_name)
            cls_type = getattr(module, attr_name)
        except (AttributeError, ImportError):
            raise ImportError('Unable to locate value type %s' % type_path)

        args = tuple([
            deserialize_from_signature(arg_payload)
            for arg_payload in payload['args']
        ])

        kwargs = {}

        for key, kwarg_payload in six.iteritems(payload['kwargs']):
            kwargs[key] = deserialize_from_signature(kwarg_payload)

        return cls_type, args, kwargs
Exemplo n.º 8
0
    def fail(self, error, **error_vars):
        """Fail the simulation.

        This will end up raising a
        :py:class:`~django_evolution.errors.SimulationFailure` with an error
        message built from the mutation's simulation failure message and the
        provided message.

        The combined message is %-formatted using a dictionary made of the
        app label, the mutation's ``error_vars`` (each mapped to the
        corresponding attribute on the mutation), and finally any
        ``error_vars`` passed here, which take precedence.

        Args:
            error (unicode):
                The error message for this particular failure.

            **error_vars (dict):
                Variables to include in the error message. These will
                override any defaults for the mutation's error.

        Raises:
            django_evolution.errors.SimulationFailure:
                The resulting simulation failure with the given error.
        """
        mutation = self.mutation
        message = '%s %s' % (mutation.simulation_failure_error, error)

        error_dict = dict(app_label=self.app_label)

        for var_name, attr_name in six.iteritems(mutation.error_vars):
            error_dict[var_name] = getattr(mutation, attr_name)

        # Caller-supplied variables win over the mutation's defaults.
        error_dict.update(error_vars)

        raise SimulationFailure(message % error_dict)
Exemplo n.º 9
0
    def serialize_to_python(cls, value):
        """Serialize a dictionary to a Python code string.

        An :py:class:`~collections.OrderedDict` keeps its own key order.
        Any other dictionary has its entries sorted by key, so the output is
        stable.

        Args:
            value (dict):
                The dictionary to serialize.

        Returns:
            unicode:
            The resulting Python code.
        """
        entries = six.iteritems(value)

        if not isinstance(value, OrderedDict):
            entries = sorted(entries, key=lambda pair: pair[0])

        rendered = [
            '%s: %s' % (serialize_to_python(key), serialize_to_python(item))
            for key, item in entries
        ]

        return '{%s}' % ', '.join(rendered)
Exemplo n.º 10
0
def merge_dicts(dest, source):
    """Merge two dictionaries together.

    This will recursively merge a source dictionary into a destination
    dictionary with the following rules:

    * Any keys in the source that aren't in the destination will be placed
      directly into the destination (using the same instance of the value,
      not a copy).
    * Any lists that are in both the source and destination will be combined
      by appending the source list to the destination list (and this will not
      recurse into lists).
    * Any dictionaries that are in both the source and destination will be
      merged using this function.
    * Any keys that are not a list or dictionary that exist in both
      dictionaries will result in a :py:exc:`TypeError`.

    Version Added:
        2.1

    Args:
        dest (dict):
            The destination dictionary to merge into. This is modified in
            place.

        source (dict):
            The source dictionary to merge into the destination.

    Raises:
        TypeError:
            A key was present in both dictionaries with a type that could not
            be merged.
    """
    for key, value in source.items():
        if key not in dest:
            dest[key] = value
            continue

        existing = dest[key]

        if isinstance(value, list):
            if not isinstance(existing, list):
                raise TypeError(
                    'Cannot merge a list into a %r for key "%s".' %
                    (type(existing), key))

            dest[key] += value
        elif isinstance(value, dict):
            if not isinstance(existing, dict):
                raise TypeError(
                    'Cannot merge a dictionary into a %r for key "%s".' %
                    (type(existing), key))

            merge_dicts(existing, value)
        else:
            raise TypeError('Key "%s" was not an expected type (found %r) '
                            'when merging dictionaries.' %
                            (key, type(value)))
Exemplo n.º 11
0
def sql_create_for_many_to_many_field(connection, model, field):
    """Return SQL statements for creating a ManyToManyField's table.

    This provides compatibility with all supported versions of Django. When
    a schema editor is available (Django >= 1.7), the SQL is collected from
    it. Otherwise, the legacy ``connection.creation`` API is used.

    Args:
        connection (object):
            The database connection.

        model (django.db.models.Model):
            The model for the ManyToManyField's relations.

        field (django.db.models.ManyToManyField):
            The field setting up the many-to-many relation.

    Returns:
        list:
        The list of SQL statements for creating the table and constraints.
    """
    through = get_remote_field(field).through

    if BaseDatabaseSchemaEditor:
        # Django >= 1.7
        with connection.schema_editor(collect_sql=True) as editor:
            editor.create_model(through)

        return editor.collected_sql

    # Django < 1.7
    style = color.no_style()

    if not through:
        return connection.creation.sql_for_many_to_many_field(
            model, field, style)

    statements, references = connection.creation.sql_create_model(
        through, style)
    pending_references = {}

    # Sort the list, in order to create consistency in the order of
    # ALTER TABLEs. This is primarily needed for unit tests.
    for ref_to, refs in sorted(six.iteritems(references), key=repr):
        pending_references.setdefault(ref_to, []).extend(refs)
        statements.extend(
            sql_add_constraints(connection, ref_to, pending_references))

    statements.extend(
        sql_add_constraints(connection, through, pending_references))

    return statements
Exemplo n.º 12
0
    def _prepare_tasks(self):
        """Prepare all queued tasks for further operations.

        Once prepared, no new tasks can be added. This will be done before
        performing any operations requiring state from queued tasks.

        Only the first call does any work; later calls return immediately.
        """
        if self._tasks_prepared:
            return

        # The flag is set before any task class is prepared, so it stays set
        # even if preparation raises.
        self._tasks_prepared = True

        for task_cls, task_list in six.iteritems(self._tasks_by_class):
            task_cls.prepare_tasks(evolver=self,
                                   tasks=task_list,
                                   hinted=self.hinted)
Exemplo n.º 13
0
    def deserialize_from_deconstructed(cls, type_cls, args, kwargs):
        """Deserialize a Q-like object from deconstructed object information.

        The special ``_negated`` and ``_connector`` keyword arguments are
        removed from the keyword arguments. They are then applied to the
        constructed object via :py:attr:`connector` and :py:meth:`negate`.
        The caller's ``kwargs`` dictionary is not modified.

        Args:
            type_cls (type):
                The type of object to construct.

            args (tuple):
                The positional arguments passed to the constructor.

            kwargs (dict):
                The keyword arguments passed to the constructor.

        Returns:
            object:
            The resulting object.
        """
        norm_keywords = six.PY2

        # Work on a copy so that stripping out the special keys below doesn't
        # modify the caller's dictionary.
        kwargs = dict(kwargs)

        negated = kwargs.pop('_negated', False)
        connector = kwargs.pop('_connector', Q.default)

        new_args = []

        for arg in args:
            if isinstance(arg, (list, tuple)):
                if norm_keywords:
                    # On Python 2, keyword arguments should be native strings.
                    # This isn't a problem for general usage, but it does
                    # affect the string representation, which assertQEqual()
                    # uses to determine equality.
                    arg = (arg[0].encode('utf-8'), arg[1])

                new_args.append(tuple(arg))
            else:
                new_args.append(arg)

        if norm_keywords:
            # We also need to normalize anything found in kwargs.
            kwargs = {
                str(_key): _value
                for _key, _value in six.iteritems(kwargs)
            }

        q = type_cls(*new_args, **kwargs)
        q.connector = connector

        if negated:
            q.negate()

        return q
Exemplo n.º 14
0
    def serialize_to_signature(cls, value):
        """Serialize a dictionary to JSON-compatible signature data.

        Keys are kept as-is. Each value is serialized through the
        module-level ``serialize_to_signature()``.

        Args:
            value (dict):
                The dictionary to serialize.

        Returns:
            dict:
            The resulting dictionary.
        """
        serialized = {}

        for key, item in six.iteritems(value):
            serialized[key] = serialize_to_signature(item)

        return serialized
Exemplo n.º 15
0
    def deserialize_from_signature(cls, payload):
        """Deserialize dictionary signature data to a value.

        Keys are kept as-is. Each value is deserialized through the
        module-level ``deserialize_from_signature()``.

        Args:
            payload (dict):
                The payload to deserialize.

        Returns:
            dict:
            The resulting value.
        """
        return dict(
            (key, deserialize_from_signature(item_payload))
            for key, item_payload in six.iteritems(payload)
        )
Exemplo n.º 16
0
    def evolve(self):
        """Perform the evolution.

        This will run through all queued tasks and attempt to apply them,
        collecting the new evolutions from every task. The tables are
        rescanned after each task class runs, and the project signature is
        saved once all tasks have finished.

        The ``evolving`` signal is sent before any work begins. The
        ``evolved`` signal is sent on success. On failure,
        ``evolving_failed`` is sent with the exception, which is then
        re-raised.

        This can only be called once per evolver instance.

        Raises:
            django_evolution.errors.EvolutionException:
                Something went wrong during the evolution process. Details
                are in the error message. Note that a more specific exception
                may be raised.

            django_evolution.errors.EvolutionExecutionError:
                A specific evolution task failed. Details are in the error.
        """
        if self.evolved:
            raise EvolutionException(
                _('Evolver.evolve() has already been run once. It cannot be '
                  'run again.'))

        self._prepare_tasks()

        evolving.send(sender=self)

        try:
            collected_evolutions = []

            for task_cls, task_list in six.iteritems(self._tasks_by_class):
                # Each task class executes its own batch of tasks and is
                # responsible for raising any exceptions.
                task_cls.execute_tasks(evolver=self, tasks=task_list)

                for task in task_list:
                    collected_evolutions.extend(task.new_evolutions)

                # The schema may have changed, so refresh the database state
                # before moving on.
                self.database_state.rescan_tables()

            self._save_project_sig(new_evolutions=collected_evolutions)
            self.evolved = True
        except Exception as exc:
            evolving_failed.send(sender=self, exception=exc)
            raise

        evolved.send(sender=self)
Exemplo n.º 17
0
    def __init__(self, conflicts):
        """Initialize the error.

        Each app's conflicting migration names are sorted and listed
        together with the app label.

        Args:
            conflicts (dict):
                A dictionary of conflicts, provided by the migrations system,
                mapping app labels to the conflicting migration names.
        """
        conflict_descriptions = [
            '%s in %s' % (', '.join(sorted(names)), label)
            for label, names in six.iteritems(conflicts)
        ]

        # Note that we're using the same error message that Django's migrate
        # command uses.
        super(MigrationConflictsError, self).__init__(
            "Conflicting migrations detected; multiple leaf nodes "
            "in the migration graph: (%s).\n"
            "To fix them run 'python manage.py makemigrations "
            "--merge'" % '; '.join(conflict_descriptions))
Exemplo n.º 18
0
    def get_hint_params(self):
        """Return parameters for the mutation's hinted evolution.

        The serialized model name, field name, and field type come first.
        They are followed by the serialized field attributes (plus
        ``initial``, when it isn't ``None``) in sorted order.

        Returns:
            list of unicode:
            A list of parameter strings to pass to the mutation's constructor
            in a hinted evolution.
        """
        serialize_attr = self.serialize_attr

        attr_params = [
            serialize_attr(attr_name, attr_value)
            for attr_name, attr_value in self.field_attrs.items()
        ]

        if self.initial is not None:
            attr_params.append(serialize_attr('initial', self.initial))

        # Sort the attribute parameters so hints are generated consistently.
        attr_params.sort()

        leading_params = [
            self.serialize_value(self.model_name),
            self.serialize_value(self.field_name),
            self.serialize_value(self.field_type),
        ]

        return leading_params + attr_params
Exemplo n.º 19
0
    def to_sql(self):
        """Return a list of SQL statements for the table rebuild.

        Any :py:attr:`alter_table` operations will be collapsed together into
        a single table rebuild.

        If no queued operation requires a rebuild (only ``ADD DB INDEX`` /
        ``DROP DB INDEX`` operations were queued), the index changes are
        returned without rebuilding the table.

        Returns:
            list:
            The list of SQL statements to run for the rebuild. Alongside the
            statements from :py:attr:`pre_sql`, :py:attr:`sql`, and
            :py:attr:`post_sql`, this may contain ``(sql, params)`` tuples
            and, when renamed columns are referenced by other tables, a
            callable that takes a cursor and returns further SQL to run.

        Raises:
            ValueError:
                An Alter Table operation not supported for SQLite was
                provided.
        """
        evolver = self.evolver
        model = self.model
        connection = evolver.connection
        qn = connection.ops.quote_name
        table_name = model._meta.db_table

        # Calculate some state for the rebuild operations, based on the
        # Alter Table ops that were provided.
        added_fields = []
        deleted_columns = set()
        renamed_columns = {}
        replaced_fields = {}
        added_constraints = []
        new_initial = {}
        reffed_renamed_cols = []
        added_field_db_indexes = []
        dropped_field_db_indexes = []
        needs_rebuild = False
        sql = []

        for item in self.alter_table:
            op = item['op']

            if op == 'ADD COLUMN':
                needs_rebuild = True
                field = item['field']

                if field.db_type(connection=connection) is not None:
                    initial = item['initial']

                    added_fields.append(field)

                    if initial is not None:
                        new_initial[field.column] = initial
            elif op == 'DELETE COLUMN':
                needs_rebuild = True
                deleted_columns.add(item['column'])
            elif op == 'RENAME COLUMN':
                needs_rebuild = True
                old_field = item['old_field']
                new_field = item['new_field']
                old_column = old_field.column
                new_column = new_field.column

                renamed_columns[old_column] = new_field.column
                replaced_fields[old_column] = new_field

                if evolver.is_column_referenced(table_name, old_column):
                    reffed_renamed_cols.append((old_column, new_column))
            elif op == 'MODIFY COLUMN':
                needs_rebuild = True
                field = item['field']
                initial = item['initial']

                replaced_fields[field.column] = field

                if initial is not None:
                    new_initial[field.column] = initial
            elif op == 'CHANGE COLUMN TYPE':
                needs_rebuild = True
                old_field = item['old_field']
                new_field = item['new_field']
                column = old_field.column

                replaced_fields[column] = new_field
            elif op == 'ADD CONSTRAINTS':
                needs_rebuild = True
                added_constraints = item['constraints']
            elif op == 'REBUILD':
                # We're just rebuilding, not changing anything about it.
                # This is used to get rid of auto-indexes from SQLite.
                needs_rebuild = True
            elif op == 'ADD DB INDEX':
                added_field_db_indexes.append(item['field'])
            elif op == 'DROP DB INDEX':
                dropped_field_db_indexes.append(item['field'])
            else:
                raise ValueError(
                    '%s is not a valid Alter Table op for SQLite' % op)

        for field in dropped_field_db_indexes:
            sql += self.normalize_sql(evolver.drop_index(model, field))

        if not needs_rebuild:
            # We don't have any operations requiring a full table rebuild.
            # We may have indexes to add (which would normally be added
            # along with the rebuild).
            for field in added_field_db_indexes:
                sql += self.normalize_sql(evolver.create_index(model, field))

            return self.pre_sql + self.sql + sql + self.post_sql

        # Remove any Generic Fields.
        old_fields = [
            _field for _field in model._meta.local_fields
            if _field.db_type(connection=connection) is not None
        ]

        new_fields = [
            replaced_fields.get(_field.column, _field)
            for _field in old_fields + added_fields
            if _field.column not in deleted_columns
        ]

        # Maps each destination column (in the temp table) to the SQL
        # expression used to populate it in the INSERT ... SELECT below.
        field_values = OrderedDict()

        for field in old_fields:
            old_column = field.column

            if old_column not in deleted_columns:
                new_column = renamed_columns.get(old_column, old_column)

                field_values[new_column] = qn(old_column)

        # Parameters for '%s' placeholders in the SELECT list, keyed by the
        # destination column they belong to.
        initial_params = {}

        # If we have any new fields, add their defaults.
        if new_initial:
            for column, initial in six.iteritems(new_initial):
                # Note that initial will only be None if null=True. Otherwise,
                # it will be set to a user-defined callable or the default
                # AddFieldInitialCallback, which will raise an exception in
                # common code before we get too much further.
                if initial is not None:
                    initial, embed_initial = evolver.normalize_initial(initial)

                    if embed_initial:
                        field_values[column] = initial
                    else:
                        initial_params[column] = initial

                        if column in field_values:
                            field_values[column] = \
                                'coalesce(%s, %%s)' % qn(column)
                        else:
                            field_values[column] = '%s'

        # The '%s' placeholders appear in the SELECT list in field_values
        # order. Existing columns keep their table position there, which can
        # differ from the order of new_initial. Build the parameters in that
        # same order so each value is bound to the right column.
        field_initials = [
            initial_params[_column]
            for _column in field_values
            if _column in initial_params
        ]

        # The SQLite documentation defines the steps that should be taken to
        # safely alter the schema for a table. Unlike most types of databases,
        # SQLite doesn't provide a general ALTER TABLE that can modify any
        # part of the table, so for most things, we require a full table
        # rebuild, and it must be done correctly.
        #
        # Step 1: Create a temporary table representing the new table
        #         schema. This will be temporary, and we don't need to worry
        #         about any indexes yet. Later, this will become the new
        #         table.
        columns_sql = []
        columns_sql_params = []

        for field in new_fields:
            if not isinstance(field, models.ManyToManyField):
                schema = evolver.build_column_schema(model=model, field=field)

                columns_sql.append('%s %s %s' %
                                   (qn(schema['name']), schema['db_type'],
                                    ' '.join(schema['definition'])))
                columns_sql_params += schema['definition_sql_params']

        constraints_sql = []

        if added_constraints:
            # Django >= 2.2
            with connection.schema_editor(collect_sql=True) as schema_editor:
                for constraint in added_constraints:
                    constraint_sql = constraint.constraint_sql(
                        model, schema_editor)

                    if constraint_sql:
                        constraints_sql.append(constraint_sql)

        sql.append((
            'CREATE TABLE %s (%s);' %
            (qn(TEMP_TABLE_NAME), ', '.join(columns_sql + constraints_sql)),
            tuple(columns_sql_params),
        ))

        # Step 2: Copy over any data from the old table into the new one.
        sql.append(('INSERT INTO %s (%s) SELECT %s FROM %s;' % (
            qn(TEMP_TABLE_NAME),
            ', '.join(qn(column) for column in six.iterkeys(field_values)),
            ', '.join(
                six.text_type(_value)
                for _value in six.itervalues(field_values)),
            qn(table_name),
        ), tuple(field_initials)))

        # Step 3: Drop the old table, making room for us to recreate the
        #         new schema table in its place.
        sql += evolver.delete_table(table_name).to_sql()

        # Step 4: Move over the temp table to the destination table name.
        sql += evolver.rename_table(model=model,
                                    old_db_table=TEMP_TABLE_NAME,
                                    new_db_table=table_name).to_sql()

        # Step 5: Restore any indexes.
        class _Model(object):
            class _meta(object):
                db_table = table_name
                local_fields = new_fields
                db_tablespace = None
                managed = True
                proxy = False
                swapped = False
                index_together = []
                indexes = []

        sql += sql_indexes_for_model(connection, _Model)

        # We've added all the indexes above. Any that were already there
        # will be in the database state. However, if we've *specifically*
        # had requests to add indexes, those ones won't be. We'll need to
        # add them now.
        #
        # The easiest way is to use the same SQL generation functions we'd
        # normally use to generate per-field indexes, since those track
        # database state. We won't actually use the SQL.
        for field in added_field_db_indexes:
            evolver.create_index(model, field)

        if reffed_renamed_cols:
            # One or more tables referenced one or more renamed columns on
            # this table, so now we need to update them.
            #
            # There are issues with renaming columns referenced by a foreign
            # key in SQLite. Historically, we've allowed it, but the reality
            # is that it can result in those foreign keys pointing to the
            # wrong (old) column, causing any foreign key reference checks to
            # fail. This is noticeable with Django 2.2+, which explicitly
            # checks in its schema editor (which we invoke).
            #
            # We don't actually want or need to do a table rebuild on these.
            # SQLite has another trick (and this is recommended in their
            # documentation). We want to go through each of the tables that
            # reference these columns and rewrite their table creation SQL
            # in the sqlite_master table, and then tell SQLite to apply the
            # new schema.
            #
            # This requires that we enable writable schemas and bump up the
            # SQLite schema version for this database. This must be done at
            # the moment we want to run this SQL statement, so we'll be
            # adding this as a dynamic function to run later, rather than
            # hard-coding any SQL now.
            #
            # Most of this can be done in a transaction, but not all. We have
            # to execute much of this in its own transaction, and then write
            # the new schema to disk with a VACUUM outside of a transaction.
            def _update_refs(cursor):
                schema_version = \
                    cursor.execute('PRAGMA schema_version').fetchone()[0]

                refs_template = ' REFERENCES "%s" ("%%s") ' % table_name

                return [
                    NewTransactionSQL([
                        # Allow us to update the database schema by
                        # manipulating the sqlite_master table.
                        'PRAGMA writable_schema = 1;',
                    ] + [
                        # Update all tables that reference any renamed
                        # columns, setting their references to point to
                        # the new names.
                        ('UPDATE sqlite_master SET sql ='
                         ' replace(sql, %s, %s);',
                         (refs_template % old_column,
                          refs_template % new_column))
                        for old_column, new_column in reffed_renamed_cols
                    ] + [
                        # Tell SQLite that we're done writing the schema,
                        # and give it a new schema version number.
                        ('PRAGMA schema_version = %s;' % (schema_version + 1)),
                        'PRAGMA writable_schema = 0;',

                        # Make sure everything went well. We want to bail
                        # here before we commit the transaction if
                        # anything goes wrong.
                        'PRAGMA integrity_check;',
                    ]),
                    NoTransactionSQL(['VACUUM;']),
                ]

            sql.append(_update_refs)

        return self.pre_sql + sql + self.sql + self.post_sql
Exemplo n.º 20
0
def sql_create_models(models,
                      tables=None,
                      db_name=None,
                      return_deferred=False):
    """Return SQL statements for creating a list of models.

    This provides compatibility with all supported versions of Django.

    It's recommended that callers include auto-created models in the list,
    to ensure all references are correct.

    Version Changed:
        2.2:
        Added the ``return_deferred`` argument.

    Args:
        models (list of type):
            The list of :py:class:`~django.db.models.Model` subclasses.

        tables (list of unicode, optional):
            A list of existing table names from the database. If not provided,
            this will be introspected from the database. This is only used on
            Django < 1.7 (when no schema editor is available).

        db_name (str, optional):
            The database connection name. Defaults to the default database
            connection.

        return_deferred (bool, optional):
            Whether to return any deferred SQL separately from the model
            creation SQL. If ``True``, the return type will change to a tuple.

    Returns:
        list or tuple:
        If ``return_deferred=False`` (the default), this will be a list of
        SQL statements used to create the models for the app.

        If ``return_deferred=True``, this will be a 2-tuple in the form of
        ``(list_of_sql, list_of_deferred_sql)``.
    """
    connection = connections[db_name or DEFAULT_DB_ALIAS]

    if BaseDatabaseSchemaEditor:
        # Django >= 1.7
        #
        # collect_sql=True makes the schema editor record the statements in
        # ``collected_sql``, which is what gets returned below.
        with connection.schema_editor(collect_sql=True) as schema_editor:
            for model in models:
                schema_editor.create_model(model)

            if return_deferred:
                # Snapshot the SQL while still inside the ``with`` block.
                # NOTE: Django's schema editor appears to flush its
                # ``deferred_sql`` into ``collected_sql`` when the block
                # exits, so copying here keeps the two lists separate.
                collected_sql = list(schema_editor.collected_sql)
                deferred_sql = [
                    '%s;' % _statement
                    for _statement in schema_editor.deferred_sql
                ]

                return collected_sql, deferred_sql

        # Read only after the ``with`` block has exited, presumably so that
        # any deferred SQL flushed on exit is included in the result.
        return schema_editor.collected_sql
    else:
        # Django < 1.7
        creation = connection.creation
        style = color.no_style()

        # Maps each referenced model (``ref_to``) to the list of references
        # to it collected so far from ``sql_create_model``.
        pending_references = {}

        if tables is None:
            tables = connection.introspection.table_names()

        # Presumably the models whose tables already exist in the database.
        # Models are added to this set as SQL is generated for them.
        seen_models = connection.introspection.installed_models(tables)
        sql = []
        deferred_sql = []

        for model in models:
            model_sql, references = creation.sql_create_model(
                model, style, seen_models)
            seen_models.add(model)

            sql += model_sql

            for ref_to, refs in six.iteritems(references):
                pending_references.setdefault(ref_to, []).extend(refs)

                # The referenced model has already been handled, so the SQL
                # for references to it can be generated now.
                if ref_to in seen_models:
                    deferred_sql += creation.sql_for_pending_references(
                        ref_to, style, pending_references)

            # Generate SQL for any pending references that point at the
            # model that was just created.
            deferred_sql += creation.sql_for_pending_references(
                model, style, pending_references)

        # Index SQL is collected with the deferred SQL, after all tables.
        for model in models:
            deferred_sql += creation.sql_indexes_for_model(model, style)

        if return_deferred:
            return sql, deferred_sql
        else:
            return sql + deferred_sql
Exemplo n.º 21
0
    def __str__(self):
        """Describe the diff in a human-readable form.

        Deleted applications are listed first. Then, for each changed
        application, the output lists:

        * its deleted models;
        * any app-level metadata changes;
        * for each changed model, its added, deleted and changed fields,
          followed by its changed Meta properties.

        Returns:
            unicode:
            A newline-separated description of the changes in the diff.
        """
        result = []

        for deleted_app in self.deleted:
            result.append('The application %s has been deleted' % deleted_app)

        for app_label, app_changes in six.iteritems(self.changed):
            for deleted_model in app_changes.get('deleted', {}):
                result.append('The model %s.%s has been deleted'
                              % (app_label, deleted_model))

            app_meta = app_changes.get('meta_changed', {})

            if app_meta:
                result.append('In app %s:' % app_label)

                if 'app_id' in app_meta or 'legacy_app_label' in app_meta:
                    result.append('    App label has changed')

                if 'upgrade_method' in app_meta:
                    result.append('    Schema upgrade method changed')

            for model_name, model_change in six.iteritems(
                    app_changes.get('changed', {})):
                result.append('In model %s.%s:' % (app_label, model_name))

                for added_field in model_change.get('added', []):
                    result.append("    Field '%s' has been added"
                                  % added_field)

                for deleted_field in model_change.get('deleted', []):
                    result.append("    Field '%s' has been deleted"
                                  % deleted_field)

                for field_name, field_change in six.iteritems(
                        model_change.get('changed', {})):
                    result.append("    In field '%s':" % field_name)

                    # A field type change supersedes every other attribute
                    # change, so only that one is reported.
                    if 'field_type' in field_change:
                        changed_props = ['field_type']
                    else:
                        changed_props = field_change

                    for prop in changed_props:
                        result.append("        Property '%s' has changed"
                                      % prop)

                for prop_name in model_change.get('meta_changed', []):
                    result.append("    Meta property '%s' has changed"
                                  % prop_name)

        return '\n'.join(result)
Exemplo n.º 22
0
    def evolution(self):
        """Return the mutations needed for resolving the diff.

        This will attempt to return a hinted evolution, consisting of a series
        of mutations for each affected application. These mutations will
        convert the database from the original to the target signatures.

        The result is computed once and stored on ``self._mutations``.
        Subsequent calls return that stored dictionary.

        Only applications in ``self.changed`` are considered. For each one,
        mutations are emitted in this order:

        1. For each changed model, in turn:

           * ``AddField``
           * ``DeleteField``
           * ``ChangeField``
           * ``ChangeMeta``, in the order ``constraints``, ``indexes``,
             ``index_together``, ``unique_together``.

        2. ``DeleteModel`` for each deleted model.
        3. A ``RenameAppLabel``, if the app's label information changed.

        Applications that end up with no mutations are left out of the
        result.

        Returns:
            collections.OrderedDict:
            An ordered dictionary of mutations. Each key is an application
            label, and each value is a list of mutations for the application.
        """
        if self._mutations is not None:
            return self._mutations

        mutations = OrderedDict()

        for app_label, app_changes in six.iteritems(self.changed):
            # Signatures come from the target project signature, so the
            # mutation parameters describe the desired end state.
            app_sig = self.target_project_sig.get_app_sig(app_label)
            model_changes = app_changes.get('changed', {})
            app_mutations = []

            for model_name, model_change in six.iteritems(model_changes):
                model_sig = app_sig.get_model_sig(model_name)

                # Process the list of added fields for the model.
                for field_name in model_change.get('added', {}):
                    field_sig = model_sig.get_field_sig(field_name)
                    field_type = field_sig.field_type

                    # Start from the target field's attributes and add the
                    # field type so AddField can reconstruct the field.
                    add_params = field_sig.field_attrs.copy()
                    add_params['field_type'] = field_type

                    if (not issubclass(field_type, models.ManyToManyField) and
                        not field_sig.get_attr_value('null')):
                        # This field requires an initial value. Inject either
                        # a suitable initial value or a placeholder that must
                        # be filled in by the developer.
                        add_params['initial'] = \
                            self._get_initial_value(app_label=app_label,
                                                    model_name=model_name,
                                                    field_name=field_name)

                    if field_sig.related_model:
                        add_params['related_model'] = field_sig.related_model

                    app_mutations.append(AddField(
                        model_name=model_name,
                        field_name=field_name,
                        **add_params))

                # Process the list of deleted fields for the model.
                app_mutations += [
                    DeleteField(model_name=model_name,
                                field_name=field_name)
                    for field_name in model_change.get('deleted', [])
                ]

                # Process the list of changed fields for the model.
                field_changes = model_change.get('changed', {})

                for field_name, field_change in six.iteritems(field_changes):
                    field_sig = model_sig.get_field_sig(field_name)
                    changed_attrs = OrderedDict()

                    field_type_changed = 'field_type' in field_change

                    if field_type_changed:
                        # If the field type changes, we're doing a hard
                        # reset on the attributes. We won't be showing the
                        # difference between any other attributes on here.
                        changed_attrs['field_type'] = field_sig.field_type
                        changed_attrs.update(field_sig.field_attrs)
                    else:
                        # Only include the attributes that changed, using
                        # their values from the target signature.
                        changed_attrs.update(
                            (attr, field_sig.get_attr_value(attr))
                            for attr in field_change
                        )

                    if ('null' in field_change and
                        not field_sig.get_attr_value('null') and
                        not issubclass(field_sig.field_type,
                                       models.ManyToManyField)):
                        # The field no longer allows null values, meaning an
                        # initial value is required. Inject either a suitable
                        # initial value or a placeholder that must be filled
                        # in by the developer.
                        changed_attrs['initial'] = \
                            self._get_initial_value(app_label=app_label,
                                                    model_name=model_name,
                                                    field_name=field_name)

                    # Use the signature's related_model value, overriding
                    # anything placed in changed_attrs above.
                    if 'related_model' in field_change:
                        changed_attrs['related_model'] = \
                            field_sig.related_model

                    app_mutations.append(ChangeField(
                        model_name=model_name,
                        field_name=field_name,
                        **changed_attrs))

                # Process the Meta attribute changes for the model.
                meta_changed = model_change.get('meta_changed', [])

                # Check if the Meta.constraints property has any changes.
                # They'll all be assembled into a single ChangeMeta.
                if 'constraints' in meta_changed:
                    app_mutations.append(ChangeMeta(
                        model_name=model_name,
                        prop_name='constraints',
                        new_value=[
                            dict({
                                'type': constraint_sig.type,
                                'name': constraint_sig.name,
                            }, **constraint_sig.attrs)
                            for constraint_sig in model_sig.constraint_sigs
                        ]))

                # Check if the Meta.indexes property has any changes.
                # They'll all be assembled into a single ChangeMeta.
                if 'indexes' in meta_changed:
                    change_meta_indexes = []

                    for index_sig in model_sig.index_sigs:
                        change_meta_index = index_sig.attrs.copy()

                        # Only include expressions/fields/name when they are
                        # set (truthy) on the index signature.
                        if index_sig.expressions:
                            change_meta_index['expressions'] = \
                                index_sig.expressions

                        if index_sig.fields:
                            change_meta_index['fields'] = index_sig.fields

                        if index_sig.name:
                            change_meta_index['name'] = index_sig.name

                        change_meta_indexes.append(change_meta_index)

                    app_mutations.append(ChangeMeta(
                        model_name=model_name,
                        prop_name='indexes',
                        new_value=change_meta_indexes))

                # Check Meta.index_together and Meta.unique_together.
                # Falsy values (such as None) are normalized to an empty list.
                app_mutations += [
                    ChangeMeta(model_name=model_name,
                               prop_name=prop_name,
                               new_value=getattr(model_sig, prop_name) or [])
                    for prop_name in ('index_together', 'unique_together')
                    if prop_name in meta_changed
                ]

            # Process the list of deleted models for the application.
            app_mutations += [
                DeleteModel(model_name=model_name)
                for model_name in app_changes.get('deleted', {})
            ]

            # See if any important details about the app have changed.
            meta_changed = app_changes.get('meta_changed', {})
            app_label_changed = meta_changed.get('app_id', {})
            legacy_app_label_changed = meta_changed.get('legacy_app_label', {})

            if app_label_changed or legacy_app_label_changed:
                # Fall back to the target app signature's values for
                # whichever part of the label did not change.
                app_mutations.append(RenameAppLabel(
                    app_label_changed.get('old', app_sig.app_id),
                    app_label_changed.get('new', app_sig.app_id),
                    legacy_app_label=legacy_app_label_changed.get(
                        'new', app_sig.legacy_app_label)))

            if app_mutations:
                mutations[app_label] = app_mutations

        self._mutations = mutations

        return mutations
Exemplo n.º 23
0
    def perform_mutations(self,
                          evolutions,
                          end,
                          end_sig,
                          sql_name=None,
                          rescan_indexes=True,
                          db_name=None,
                          create_test_data_func=None):
        """Apply mutations to a test database and check the generated SQL.

        The starting database state and signature are cloned. The evolution
        chain is then run through an ``AppMutator`` for the ``tests`` app
        inside an ``ensure_test_db`` context. The resulting SQL is passed to
        ``execute_test_sql``. If ``sql_name`` is given, that SQL is compared
        against the registered SQL mapping with that name.

        Args:
            evolutions (list of django_evolution.mutations.BaseMutation):
                The evolution chain to apply.

            end (collections.OrderedDict):
                The expected model map at the end of the evolution. This
                is generated by :py:meth:`make_end_signatures`.

            end_sig (django_evolution.signature.ProjectSignature):
                The expected ending signature. This is generated by
                :py:meth:`make_end_signatures`. It is accepted here but not
                consulted; this method does not compare signatures.

            sql_name (unicode, optional):
                The name of the registered SQL content for the database being
                tested. If not provided, the SQL won't be compared.

            rescan_indexes (bool, optional):
                Whether to re-scan the database state's tables before applying
                the mutations.

            db_name (unicode, optional):
                The name of the database to apply the evolutions against.
                Defaults to ``self.default_database_name``.

            create_test_data_func (callable, optional):
                A function called with the database name before the mutations
                are run, in order to create data for the initial models.

        Raises:
            AssertionError:
                The resulting SQL did not match.

            django.db.utils.OperationalError:
                There was an error executing SQL.
        """
        target_app = 'tests'
        db_name = db_name or self.default_database_name

        self.test_database_state = self.database_state.clone()
        mutated_sig = self.start_sig.clone()

        def _build_sql():
            # Invoked inside the ensure_test_db context below, so any rescan
            # happens while that context is active.
            if rescan_indexes:
                self.test_database_state.rescan_tables()

            mutator = AppMutator(app_label=target_app,
                                 project_sig=mutated_sig,
                                 database_state=self.test_database_state,
                                 database=db_name)
            mutator.run_mutations(evolutions)

            return mutator.to_sql()

        with ensure_test_db(model_entries=six.iteritems(self.start),
                            end_model_entries=six.iteritems(end),
                            app_label=target_app,
                            database=db_name):
            if create_test_data_func:
                create_test_data_func(db_name)

            mutation_sql = _build_sql()
            sql = execute_test_sql(mutation_sql, database=db_name)

        if sql_name is not None:
            self.assertSQLMappingEqual(sql, sql_name, database=db_name)
Exemplo n.º 24
0
def get_evolution_dependencies(app, evolution_label, custom_evolutions=None):
    """Return dependencies for an evolution.

    Evolutions can depend on other evolutions or migrations, and can be
    marked as being a dependency of them as well (forcing the evolution to
    apply before another evolution/migration).

    Dependencies are generally specified as a tuple of ``(app_label, name)``,
    where ``name`` is either a migration name or an evolution label.

    Dependencies on evolutions can also be specified as simply a string
    containing an app label, which will reference the sequence of evolutions
    as a whole for that app.

    In addition to the declared dependencies, each of the evolution's
    mutations may contribute dependencies through its
    ``generate_dependencies()`` method. These are merged into the result.

    Version Changed:
        2.2:
        Added the ``custom_evolutions`` argument.

    Version Added:
        2.1

    Args:
        app (module):
            The app the evolution is for.

        evolution_label (unicode):
            The label identifying the evolution for the app.

        custom_evolutions (list of dict, optional):
            An optional list of custom evolutions pertaining to the app, which
            will be searched if a module for ``evolution_label`` could not
            be found. The first entry with a matching label is used.

            Each item is a dictionary containing:

            Keys:
                label (unicode):
                    The evolution label (which ``evolution_label`` will be
                    compared against).

                mutations (list):
                    The list of mutations for the evolution. This key is
                    required for an entry that matches ``evolution_label``.

                after_evolutions (list, optional):
                    A list of evolutions that this would apply after. Each
                    item can be a string (indicating an evolution label within
                    this app) or a tuple in the form of:

                        ('app_label', 'evolution_label')

                after_migrations (list of tuple, optional):
                    A list of migrations that this would apply after. Each
                    item must be a tuple in the form of:

                        ('app_label', 'migration_name')

                before_evolutions (list, optional):
                    A list of evolutions that this would apply before. Each
                    item can be a string (indicating an evolution label within
                    this app) or a tuple in the form of:

                        ('app_label', 'evolution_label')

                before_migrations (list of tuple, optional):
                    A list of migrations that this would apply before. Each
                    item must be a tuple in the form of:

                        ('app_label', 'migration_name')

            Version Added:
                2.2

    Returns:
        dict:
        A dictionary of dependency information for the evolution. Each value
        is a set. This has the following keys:

        * ``before_migrations``
        * ``after_migrations``
        * ``before_evolutions``
        * ``after_evolutions``

        If neither an evolution module nor a matching custom evolution was
        found, this will return ``None`` instead.

    Raises:
        AssertionError:
            A mutation generated a dependency key other than the four keys
            listed above.
    """
    module = get_evolution_module(app=app, evolution_label=evolution_label)

    if module:
        deps = {
            'after_evolutions': set(getattr(module, 'AFTER_EVOLUTIONS', [])),
            'after_migrations': set(getattr(module, 'AFTER_MIGRATIONS', [])),
            'before_evolutions': set(getattr(module, 'BEFORE_EVOLUTIONS', [])),
            'before_migrations': set(getattr(module, 'BEFORE_MIGRATIONS', [])),
        }
        mutations = getattr(module, 'MUTATIONS', [])
    else:
        # No evolution module exists, so look for a custom evolution with a
        # matching label. The first match wins.
        custom_evolution = next(
            (
                _custom_evolution
                for _custom_evolution in (custom_evolutions or [])
                if _custom_evolution['label'] == evolution_label
            ),
            None)

        if custom_evolution is None:
            return None

        deps = {
            _key: set(custom_evolution.get(_key, []))
            for _key in ('after_evolutions', 'after_migrations',
                         'before_evolutions', 'before_migrations')
        }
        mutations = custom_evolution['mutations']

    app_label = get_app_label(app)

    # Merge in any dependencies that the mutations themselves require.
    for mutation in mutations:
        mutation_deps = mutation.generate_dependencies(app_label=app_label)

        if mutation_deps:
            for key, value in six.iteritems(mutation_deps):
                assert key in deps, (
                    '"%s" is not a valid dependency key from mutation %r' %
                    (key, mutation))

                deps[key].update(value)

    return deps