def remove_field_constraints(self, field, opts, models, refs):
    """Build the SQL that drops constraints pointing at a field.

    Nothing is generated unless the backend supports constraints and the
    field is a primary key. Otherwise, the through models of the model's
    local many-to-many fields are scanned for relation fields that target
    this field's column on this model's table.

    Args:
        field (django.db.models.Field):
            The field the constraints will be removed from.

        opts (django.db.models.options.Options):
            The Meta class for the model.

        models (list of django.db.models.Model):
            A caller-provided list. Every matching related model is
            appended to it.

        refs (dict):
            A caller-supplied dictionary, populated with the references
            that were found. Keys are models, and values are lists of
            ``(through_model, field)`` tuples.

    Returns:
        list:
        The list of SQL statements for removing constraints on the field.
    """
    if not (self.supports_constraints and field.primary_key):
        return []

    for m2m_field in opts.local_many_to_many:
        rel = get_remote_field(m2m_field)

        if not rel or not rel.through:
            continue

        through_model = rel.through

        for through_field in through_model._meta.local_fields:
            through_rel = get_remote_field(through_field)

            if not through_rel:
                continue

            target_model = get_remote_field_model(through_rel)

            if (through_rel.field_name == field.column and
                target_model._meta.db_table == opts.db_table):
                models.append(target_model)
                refs.setdefault(target_model, []).append(
                    (through_model, through_field))

    # The helper receives a shallow copy of refs, not the caller's dict.
    pending_refs = refs.copy()
    statements = []

    for rel_model in models:
        statements.extend(
            sql_delete_constraints(self.connection, rel_model, pending_refs))

    return statements
def rename_table(self, model, old_db_tablename, db_tablename):
    """Return SQL for renaming a table.

    Constraints held by many-to-many through tables that use the old table
    name are dropped before the rename and re-added afterward (when the
    backend supports constraints). The through models' ``db_table`` is
    updated in place to the new, length-truncated name.

    Args:
        model (django.db.models.Model):
            The model whose table is being renamed.

        old_db_tablename (str):
            The current table name.

        db_tablename (str):
            The new table name.

    Returns:
        SQLResult:
        The SQL for the rename, including any constraint removal (pre-SQL)
        and constraint creation (post-SQL). This is empty if both names
        are identical.
    """
    sql_result = SQLResult()

    if old_db_tablename == db_tablename:
        # No Operation
        return sql_result

    max_name_length = self.connection.ops.max_name_length()

    refs = {}
    models = []

    for field in model._meta.local_many_to_many:
        remote_field = get_remote_field(field)

        if (remote_field and
            remote_field.through and
            remote_field.through._meta.db_table == old_db_tablename):
            through = remote_field.through

            for m2m_field in through._meta.local_fields:
                remote_m2m_field = get_remote_field(m2m_field)

                if remote_m2m_field:
                    remote_m2m_field_model = get_remote_field_model(
                        remote_m2m_field)

                    if remote_m2m_field_model == model:
                        models.append(remote_m2m_field_model)
                        refs.setdefault(remote_m2m_field_model, []).append(
                            (through, m2m_field))

    remove_refs = refs.copy()

    if self.supports_constraints:
        for relto in models:
            sql_result.add_pre_sql(
                sql_delete_constraints(self.connection, relto, remove_refs))

    sql_result.add(
        self.get_rename_table_sql(model, old_db_tablename, db_tablename))

    for relto in models:
        # The same model can be listed more than once (both ends of a
        # self-referential relation point at it), and sql_add_constraints()
        # removes the model's entry from refs once it has been processed.
        # A plain refs[relto] lookup would then raise KeyError on the
        # repeated entry, so fall back to an empty sequence.
        for rel_class, f in refs.get(relto, ()):
            if rel_class._meta.db_table == old_db_tablename:
                rel_class._meta.db_table = db_tablename

            rel_class._meta.db_table = \
                truncate_name(rel_class._meta.db_table, max_name_length)

        if self.supports_constraints:
            sql_result.add_post_sql(
                sql_add_constraints(self.connection, relto, refs))

    return sql_result
def test_with_many_to_many_field_rel(self):
    """Testing get_field_is_relation with ManyToManyField relation"""
    m2m_field = CompatModelsTestModel._meta.get_field('m2m_field')
    rel = get_remote_field(m2m_field)

    # The relation object itself must be reported as a relation.
    self.assertIsInstance(rel, related.ManyToManyRel)
    self.assertTrue(get_field_is_relation(rel))
def test_with_one_to_one_field_rel(self):
    """Testing get_remote_field_related_model with OneToOneField relation
    """
    o2o_field = CompatModelsTestModel._meta.get_field('o2o_field')
    rel = get_remote_field(o2o_field)

    # The relation is expected to resolve back to the test model.
    related_model = get_remote_field_related_model(rel)
    self.assertIs(related_model, CompatModelsTestModel)
def test_with_foreign_key_rel(self):
    """Testing get_field_is_relation with ForeignKey relation"""
    fkey_field = CompatModelsTestModel._meta.get_field('fkey_field')
    rel = get_remote_field(fkey_field)

    # The relation object itself must be reported as a relation.
    self.assertIsInstance(rel, related.ManyToOneRel)
    self.assertTrue(get_field_is_relation(rel))
def get_default_index_name(self, table_name, field):
    """Return a default index name for the database.

    The name matches what the database or Django database backend would
    generate on its own when a field is marked as indexed or unique.
    Subclasses can override this if their backend names things differently.

    When the connection offers a schema editor and the field is a relation
    with ``db_constraint`` set, the name is built by the schema editor
    using a ``_fk_<table>_<column>`` suffix for the target field.
    Anything else is delegated to the parent class.

    Args:
        table_name (str):
            The name of the table for the index.

        field (django.db.models.Field):
            The field for the index.

    Returns:
        str:
        The name of the index.
    """
    connection = self.connection
    is_constrained_relation = (hasattr(connection, 'schema_editor') and
                               get_remote_field(field) and
                               field.db_constraint)

    if not is_constrained_relation:
        return super(EvolutionOperations, self).get_default_index_name(
            table_name, field)

    # Django >= 1.7
    target_field = get_rel_target_field(field)
    suffix = '_fk_%s_%s' % (target_field.model._meta.db_table,
                            target_field.column)

    return connection.schema_editor()._create_index_name(
        field.model,
        [field.column],
        suffix=suffix)
def test_with_one_to_one_field_rel(self):
    """Testing get_field_is_relation with OneToOneField relation"""
    o2o_field = CompatModelsTestModel._meta.get_field('o2o_field')
    rel = get_remote_field(o2o_field)

    # The relation object itself must be reported as a relation.
    self.assertIsInstance(rel, related.ManyToOneRel)
    self.assertTrue(get_field_is_relation(rel))
def test_with_many_to_many_field_rel(self):
    """Testing get_remote_field_related_model with ManyToManyField relation
    """
    m2m_field = CompatModelsTestModel._meta.get_field('m2m_field')
    rel = get_remote_field(m2m_field)

    # The relation is expected to resolve back to the test model.
    related_model = get_remote_field_related_model(rel)
    self.assertIs(related_model, CompatModelsTestModel)
def create_table(self, table_name, field_list, temporary=False,
                 create_index=True):
    """Return SQL for creating a table from a list of fields.

    Many-to-many fields (including subclasses) are skipped, since they
    have no column of their own on this table.

    Args:
        table_name (str):
            The name of the table to create.

        field_list (list of django.db.models.Field):
            The fields to create columns for.

        temporary (bool, optional):
            Whether to create a temporary table. Temporary tables make
            every column nullable and omit foreign key references.

        create_index (bool, optional):
            Whether to append the index creation statements returned by
            :py:meth:`create_indexes_for_table`.

    Returns:
        list:
        The ``CREATE TABLE`` statement, followed by any index statements.
    """
    qn = self.connection.ops.quote_name

    create = ['CREATE']

    if temporary:
        create.append('TEMPORARY')

    create.append('TABLE %s' % qn(table_name))

    columns = []

    for field in field_list:
        # isinstance() rather than an exact type check, so that custom
        # ManyToManyField subclasses are skipped as well.
        if isinstance(field, models.ManyToManyField):
            continue

        column_name = qn(field.column)
        column_type = field.db_type(connection=self.connection)
        params = [column_name, column_type]

        # Always use null if this is a temporary table. It may be
        # used to create a new field (which will be null while data is
        # copied across from the old table).
        if temporary or field.null:
            params.append('NULL')
        else:
            params.append('NOT NULL')

        if field.unique:
            params.append('UNIQUE')

        if field.primary_key:
            params.append('PRIMARY KEY')

        if not temporary and isinstance(field, models.ForeignKey):
            remote_field = get_remote_field(field)
            remote_field_model = get_remote_field_model(remote_field)
            remote_meta = remote_field_model._meta

            params.append(
                'REFERENCES %s (%s) DEFERRABLE INITIALLY DEFERRED'
                % (qn(remote_meta.db_table),
                   qn(remote_meta.get_field(
                       remote_field.field_name).column)))

        columns.append(' '.join(params))

    output = [
        '%s(%s);' % (' '.join(create), ', '.join(columns)),
    ]

    if create_index:
        output.extend(self.create_indexes_for_table(table_name, field_list))

    return output
def get_model_rel_tree():
    """Return the full field relationship tree for all registered models.

    Every forward relation field local to each registered, non-abstract
    model (auto-created models included) is grouped under the table name
    of the concrete model it points to. The result is stored in a
    module-level cache and reused by later calls.

    This makes it quick to locate the reverse relations made to a model.

    This is similar to Django's built-in reverse relation tree used
    internally (with different implementations) in
    :py:class:`django.db.models.options.Options`, but works across all
    supported versions of Django, and supports cache clearing.

    Version Added:
        2.2

    Returns:
        dict:
        The model relation tree, mapping table names to lists of relation
        fields.
    """
    global _rel_tree_cache

    if _rel_tree_cache is None:
        tree = defaultdict(list)

        # Walk each model's own fields, collecting the ones that reference
        # another model.
        for model_cls in get_models(include_auto_created=True):
            if model_cls._meta.abstract:
                continue

            local_fields = iter_model_fields(model_cls,
                                             include_parent_models=False,
                                             include_forward_fields=True,
                                             include_reverse_fields=False,
                                             include_hidden_fields=False)

            for field in local_fields:
                if not get_field_is_relation(field):
                    continue

                if get_remote_field_related_model(field) is None:
                    continue

                target_model = get_remote_field_model(
                    get_remote_field(field))

                # Make sure this isn't a "self" relation or similar.
                if isinstance(target_model, six.string_types):
                    continue

                target_table = \
                    target_model._meta.concrete_model._meta.db_table
                tree[target_table].append(field)

        _rel_tree_cache = tree

    return _rel_tree_cache
def test_has_model_with_auto_created(self):
    """Testing DatabaseState.has_model with auto-created model"""
    groups_field = User._meta.get_field('groups')
    through_model = get_remote_field(groups_field).through
    self.assertTrue(through_model._meta.auto_created)

    # Without an initial scan, the state should not contain the model.
    database_state = DatabaseState(db_name='default', scan=False)
    self.assertFalse(database_state.has_model(through_model))

    # After rescanning the tables, it should.
    database_state.rescan_tables()
    self.assertTrue(database_state.has_model(through_model))
def create_table(self, table_name, field_list, temporary=False,
                 create_index=True):
    """Return SQL for creating a table from a list of fields.

    Each field other than a many-to-many field (or a subclass of one)
    becomes a column definition. Many-to-many fields are skipped because
    they have no column on this table.

    Args:
        table_name (str):
            The name of the table to create.

        field_list (list of django.db.models.Field):
            The fields to create columns for.

        temporary (bool, optional):
            Whether to create a temporary table. In that case, all columns
            are nullable and no foreign key references are written.

        create_index (bool, optional):
            Whether to append the statements returned by
            :py:meth:`create_indexes_for_table`.

    Returns:
        list:
        The ``CREATE TABLE`` statement, followed by any index statements.
    """
    qn = self.connection.ops.quote_name

    if temporary:
        create_sql = 'CREATE TEMPORARY TABLE %s' % qn(table_name)
    else:
        create_sql = 'CREATE TABLE %s' % qn(table_name)

    columns = []

    for field in field_list:
        # Use isinstance() so ManyToManyField subclasses are skipped too.
        # An exact type comparison would try to build a column for them.
        if isinstance(field, models.ManyToManyField):
            continue

        params = [
            qn(field.column),
            field.db_type(connection=self.connection),
        ]

        # Always use null if this is a temporary table. It may be
        # used to create a new field (which will be null while data is
        # copied across from the old table).
        if temporary or field.null:
            params.append('NULL')
        else:
            params.append('NOT NULL')

        if field.unique:
            params.append('UNIQUE')

        if field.primary_key:
            params.append('PRIMARY KEY')

        if not temporary and isinstance(field, models.ForeignKey):
            remote_field = get_remote_field(field)
            remote_field_model = get_remote_field_model(remote_field)

            params.append(
                'REFERENCES %s (%s) DEFERRABLE INITIALLY DEFERRED'
                % (qn(remote_field_model._meta.db_table),
                   qn(remote_field_model._meta.get_field(
                       remote_field.field_name).column)))

        columns.append(' '.join(params))

    output = [''.join([create_sql, '(', ', '.join(columns), ');'])]

    if create_index:
        output.extend(
            self.create_indexes_for_table(table_name, field_list))

    return output
def sql_create_for_many_to_many_field(connection, model, field):
    """Return SQL statements for creating a ManyToManyField's table.

    This provides compatibility with all supported versions of Django.
    When a schema editor class is available, the through model is created
    in SQL-collecting mode. Otherwise, the legacy ``connection.creation``
    API is used.

    Args:
        connection (object):
            The database connection.

        model (django.db.models.Model):
            The model for the ManyToManyField's relations.

        field (django.db.models.ManyToManyField):
            The field setting up the many-to-many relation.

    Returns:
        list:
        The list of SQL statements for creating the table and constraints.
    """
    through = get_remote_field(field).through

    if BaseDatabaseSchemaEditor:
        # Django >= 1.7
        with connection.schema_editor(collect_sql=True) as schema_editor:
            schema_editor.create_model(through)

        return schema_editor.collected_sql

    # Django < 1.7
    style = color.no_style()

    if not through:
        return connection.creation.sql_for_many_to_many_field(
            model, field, style)

    sql, references = connection.creation.sql_create_model(through, style)
    pending_references = {}

    # Sort the list, in order to create consistency in the order of
    # ALTER TABLEs. This is primarily needed for unit tests.
    for refto, refs in sorted(six.iteritems(references), key=repr):
        pending_references.setdefault(refto, []).extend(refs)
        sql.extend(
            sql_add_constraints(connection, refto, pending_references))

    sql.extend(sql_add_constraints(connection, through, pending_references))

    return sql
def sql_create_for_many_to_many_field(connection, model, field):
    """Return SQL statements for creating a ManyToManyField's table.

    This provides compatibility with all supported versions of Django.

    Args:
        connection (object):
            The database connection.

        model (django.db.models.Model):
            The model for the ManyToManyField's relations.

        field (django.db.models.ManyToManyField):
            The field setting up the many-to-many relation.

    Returns:
        list:
        The list of SQL statements for creating the table and constraints.
    """
    through_model = get_remote_field(field).through

    if BaseDatabaseSchemaEditor:
        # Django >= 1.7
        #
        # Have the schema editor collect the statements for the through
        # model.
        with connection.schema_editor(collect_sql=True) as editor:
            editor.create_model(through_model)

        statements = editor.collected_sql
    else:
        # Django < 1.7
        creation = connection.creation
        style = color.no_style()

        if through_model:
            statements, references = creation.sql_create_model(
                through_model, style)
            pending = {}

            # Sort the list, in order to create consistency in the order of
            # ALTER TABLEs. This is primarily needed for unit tests.
            ordered_references = sorted(six.iteritems(references),
                                        key=lambda item: repr(item))

            for ref_model, ref_list in ordered_references:
                pending.setdefault(ref_model, []).extend(ref_list)
                statements.extend(
                    sql_add_constraints(connection, ref_model, pending))

            statements.extend(
                sql_add_constraints(connection, through_model, pending))
        else:
            statements = creation.sql_for_many_to_many_field(
                model, field, style)

    return statements
def _get_rename_column_sql(self, opts, old_field, new_field):
    """Return the ALTER TABLE item for renaming a column.

    The result is a ``CHANGE COLUMN`` clause carrying the full definition
    of the new column. It is preceded by ``DROP PRIMARY KEY`` when the old
    field was the primary key.

    Args:
        opts (django.db.models.options.Options):
            The Meta class for the model owning the field.

        old_field (django.db.models.Field):
            The field as it currently exists.

        new_field (django.db.models.Field):
            The field describing the renamed column.

    Returns:
        list of dict:
        A single-item list containing a dictionary with a ``sql`` key.
    """
    connection = self.connection
    qn = connection.ops.quote_name
    style = color.no_style()
    col_type = new_field.db_type(connection=connection)
    tablespace = new_field.db_tablespace or opts.db_tablespace

    if new_field.null:
        null_keyword = 'NULL'
    else:
        null_keyword = 'NOT NULL'

    # Make the definition (e.g. 'foo VARCHAR(30)') for this field.
    definition = [
        style.SQL_FIELD(qn(new_field.column)),
        style.SQL_COLTYPE(col_type),
        style.SQL_KEYWORD(null_keyword),
    ]

    if new_field.primary_key:
        definition.append(style.SQL_KEYWORD('PRIMARY KEY'))

    if new_field.unique:
        definition.append(style.SQL_KEYWORD('UNIQUE'))

    if (tablespace and
        connection.features.supports_tablespaces and
        connection.features.autoindexes_primary_keys and
        (new_field.unique or new_field.primary_key)):
        # We must specify the index tablespace inline, because we
        # won't be generating a CREATE INDEX statement for this field.
        definition.append(
            connection.ops.tablespace_sql(tablespace, inline=True))

    rel = get_remote_field(new_field)

    if rel:
        rel_meta = get_remote_field_model(rel)._meta
        rel_column = rel_meta.get_field(rel.field_name).column

        definition.append('%s %s (%s)%s' % (
            style.SQL_KEYWORD('REFERENCES'),
            style.SQL_TABLE(qn(rel_meta.db_table)),
            style.SQL_FIELD(qn(rel_column)),
            connection.ops.deferrable_sql(),
        ))

    if old_field.primary_key:
        prefix = 'DROP PRIMARY KEY, '
    else:
        prefix = ''

    change_sql = 'CHANGE COLUMN %s %s' % (qn(old_field.column),
                                          ' '.join(definition))

    return [{'sql': prefix + change_sql}]
def add_primary_key_field_constraints(self, old_field, new_field, models,
                                      refs):
    """Return SQL for re-adding constraints after a primary key change.

    Nothing is done unless the backend supports constraints and the old
    field was a primary key. Otherwise, for each model, the referencing
    relations are pointed at the new field's column. The model's field
    registry is updated to hold the new field in place of the old one, and
    the constraint SQL is generated.

    Args:
        old_field (django.db.models.Field):
            The previous primary key field.

        new_field (django.db.models.Field):
            The field replacing it.

        models (list of django.db.models.Model):
            The models that constraints will be added to.

        refs (dict):
            A mapping of models to lists of ``(model, field)`` tuples
            referencing them.

    Returns:
        list:
        The list of SQL statements for adding the constraints.
    """
    if not (self.supports_constraints and old_field.primary_key):
        return []

    statements = []

    for rel_model in models:
        for _rel_class, rel_field in refs[rel_model]:
            get_remote_field(rel_field).field_name = new_field.column

        # Swap the field entries, so lookups by name find the new field.
        del rel_model._meta._fields[old_field.name]
        rel_model._meta._fields[new_field.name] = new_field

        statements.extend(
            sql_add_constraints(self.connection, rel_model, refs))

    return statements
def delete_column(self, model, f):
    """Return SQL for deleting a column.

    If the field is a relation, the constraint it holds against the
    related model is dropped first. The column deletion SQL itself comes
    from the parent class.

    Args:
        model (django.db.models.Model):
            The model owning the column.

        f (django.db.models.Field):
            The field whose column is being deleted.

    Returns:
        AlterTableSQLResult:
        The SQL for removing any constraint and deleting the column.
    """
    result = AlterTableSQLResult(self, model)
    rel = get_remote_field(f)

    if rel:
        rel_model = get_remote_field_model(rel)
        constraint_refs = {
            rel_model: [(model, f)],
        }

        result.add(
            sql_delete_constraints(self.connection, rel_model,
                                   constraint_refs))

    parent_ops = super(EvolutionOperations, self)
    result.add_sql(parent_ops.delete_column(model, f))

    return result
def iter_non_m2m_reverse_relations(field):
    """Iterate through non-M2M reverse relations pointing to a field.

    Relations originating from a
    :py:class:`~django.db.models.ManyToManyField` are skipped, but the
    relation fields on their "through" tables are included.

    For every relation yielded, the relations pointing at that relation's
    own remote field are walked and yielded as well.

    Note that this may return duplicate results, or multiple relations
    pointing to the same field. It's up to the caller to handle this.

    Version Added:
        2.2

    Args:
        field (django.db.models.Field):
            The field that relations must point to.

    Yields:
        django.db.models.Field or object:
        Each field or relation object pointing to this field. The type of
        the relation object depends on the version of Django.
    """
    targets_pk = field.primary_key
    target_name = field.name

    reverse_rels = iter_model_fields(field.model,
                                     include_parent_models=True,
                                     include_forward_fields=False,
                                     include_reverse_fields=True,
                                     include_hidden_fields=True)

    for rel in reverse_rels:
        source_field = rel.field

        # Exclude any ManyToManyFields.
        if get_field_is_many_to_many(source_field):
            continue

        # Make sure the referencing field points directly at this field.
        # That is either implicitly, as the primary key, or by name.
        points_here = (
            (targets_pk and source_field.to_fields == [None]) or
            target_name in source_field.to_fields)

        if not points_here:
            continue

        yield rel

        # Now do the same for the fields on the model of the related field.
        for nested_rel in iter_non_m2m_reverse_relations(
                get_remote_field(rel)):
            yield nested_rel
def add_column(self, model, f, initial):
    """Return SQL for adding a column to a table.

    Relation fields are added with a ``REFERENCES`` clause targeting the
    related model's primary key column. Other fields are added based on
    ``initial``:

    * ``None``: The column is added with its null/unique constraints.
    * A callable: The column is added, existing rows are updated with the
      callable's result, and the column is then made non-null if needed.
    * Anything else: The column is added with the value as a temporary
      ``DEFAULT``, which is dropped afterward.

    Unique and primary key fields also register a unique index in the
    database state.

    Args:
        model (django.db.models.Model):
            The model owning the table.

        f (django.db.models.Field):
            The field to add a column for.

        initial (object):
            The initial value for existing rows. This may be ``None``, a
            callable returning the SQL value to write, or a literal value
            passed as a SQL parameter.

    Returns:
        AlterTableSQLResult:
        The SQL for adding the column.
    """
    qn = self.connection.ops.quote_name
    sql_result = AlterTableSQLResult(self, model)
    table_name = model._meta.db_table
    remote_field = get_remote_field(f)

    null_constraints = 'NULL' if f.null else 'NOT NULL'
    is_unique = f.unique or f.primary_key

    if remote_field:
        # it is a foreign key field
        # NOT NULL REFERENCES "django_evolution_addbasemodel"
        # ("id") DEFERRABLE INITIALLY DEFERRED

        # ALTER TABLE <tablename> ADD COLUMN <column name> NULL
        # REFERENCES <tablename1> ("<colname>") DEFERRABLE INITIALLY
        # DEFERRED
        related_model = get_remote_field_model(remote_field)
        related_table = related_model._meta.db_table

        # REFERENCES needs the database column, which differs from the
        # field name when the primary key sets a custom db_column.
        related_pk_col = related_model._meta.pk.column

        constraints = [null_constraints]

        if is_unique:
            constraints.append('UNIQUE')

        sql_result.add_alter_table([{
            'op': 'ADD COLUMN',
            'column': f.column,
            'db_type': f.db_type(connection=self.connection),
            'params': constraints + [
                'REFERENCES',
                qn(related_table),
                '(%s)' % qn(related_pk_col),
                self.connection.ops.deferrable_sql(),
            ]
        }])
    else:
        unique_constraints = 'UNIQUE' if is_unique else ''

        # At this point, initial can only be None if null=True,
        # otherwise it is a user callable or the default
        # AddFieldInitialCallback which will shortly raise an exception.
        if initial is not None:
            if callable(initial):
                sql_result.add_alter_table([{
                    'op': 'ADD COLUMN',
                    'column': f.column,
                    'db_type': f.db_type(connection=self.connection),
                    'params': [unique_constraints],
                }])

                sql_result.add_sql([
                    'UPDATE %s SET %s = %s WHERE %s IS NULL;'
                    % (qn(table_name), qn(f.column),
                       initial(), qn(f.column))
                ])

                if not f.null:
                    # Only put this sql statement if the column cannot
                    # be null.
                    sql_result.add_sql(
                        self.set_field_null(model, f, f.null))
            else:
                sql_result.add_alter_table([{
                    'op': 'ADD COLUMN',
                    'column': f.column,
                    'db_type': f.db_type(connection=self.connection),
                    'params': [
                        null_constraints,
                        unique_constraints,
                        'DEFAULT',
                        '%s',
                    ],
                    'sql_params': [initial]
                }])

                # Django doesn't generate default columns, so now that
                # we've added one to get default values for existing
                # tables, drop that default.
                sql_result.add_post_sql([
                    'ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT;'
                    % (qn(table_name), qn(f.column))
                ])
        else:
            sql_result.add_alter_table([{
                'op': 'ADD COLUMN',
                'column': f.column,
                'db_type': f.db_type(connection=self.connection),
                'params': [null_constraints, unique_constraints],
            }])

    if is_unique:
        self.database_state.add_index(
            table_name=table_name,
            index_name=self.get_new_constraint_name(table_name, f.column),
            columns=[f.column],
            unique=True)

    return sql_result
def register_models(database_state, models, register_indexes=False,
                    new_app_label='tests', db_name='default', app=evo_test):
    """Register models for testing purposes.

    Each model is renamed in place to the given object name, moved to the
    new app label, and registered in the app model cache. Auto-generated
    table names (for the model and for its many-to-many through tables)
    are regenerated to reflect the new names.

    Args:
        database_state (django_evolution.db.state.DatabaseState):
            The database state to populate with model information.

        models (list of tuple):
            The models to register, as ``(object_name, model)`` pairs.
            They are processed in reverse order.

        register_indexes (bool, optional):
            Whether indexes should be registered for any models. Defaults
            to ``False``.

        new_app_label (str, optional):
            The label for the test app. Defaults to "tests".

        db_name (str, optional):
            The name of the database connection. Defaults to "default".

        app (module, optional):
            The application module for the test models.

    Returns:
        collections.OrderedDict:
        A dictionary of registered models. The keys are model names, and
        the values are the models.
    """
    app_cache = OrderedDict()
    evolver = EvolutionOperationsMulti(db_name, database_state).get_evolver()

    db_connection = connections[db_name or DEFAULT_DB_ALIAS]
    max_name_length = db_connection.ops.max_name_length()

    for new_object_name, model in reversed(models):
        meta = model._meta

        # Remember the original state. It's needed both to unregister the
        # old entries and to compute the replacement names.
        orig_app_label = meta.app_label
        orig_db_table = meta.db_table
        orig_object_name = meta.object_name
        orig_model_name = get_model_name(model)

        new_model_name = new_object_name.lower()

        # Work out whether the table name was generated by Django or set
        # explicitly. Only a generated name gets replaced.
        default_db_table = truncate_name(
            '%s_%s' % (orig_app_label, orig_model_name),
            max_name_length)

        if orig_db_table == default_db_table:
            new_db_table = truncate_name(
                '%s_%s' % (new_app_label, new_model_name),
                max_name_length)
            meta.db_table = new_db_table
        else:
            new_db_table = orig_db_table

        # Set the new app/model names back on the meta instance.
        meta.app_label = new_app_label
        meta.object_name = new_object_name
        set_model_name(model, new_model_name)

        # Make sure the database state knows about the table.
        if not database_state.has_table(new_db_table):
            database_state.add_table(new_db_table)

        if register_indexes:
            # Store the indexes for all the fields in the database state,
            # so that other operations can look up the index names.
            for field in meta.local_fields:
                if not (field.db_index or field.unique):
                    continue

                database_state.add_index(
                    index_name=create_index_name(
                        db_connection,
                        new_db_table,
                        field_names=[field.name],
                        col_names=[field.column],
                        unique=field.unique),
                    table_name=new_db_table,
                    columns=[field.column],
                    unique=field.unique)

            for field_names in meta.unique_together:
                fields = evolver.get_fields_for_names(model, field_names)

                database_state.add_index(
                    index_name=create_index_name(
                        db_connection,
                        new_db_table,
                        field_names=field_names,
                        unique=True),
                    table_name=new_db_table,
                    columns=[field.column for field in fields],
                    unique=True)

            for field_names in getattr(meta, 'index_together', []):
                # Django >= 1.5
                fields = evolver.get_fields_for_names(model, field_names)

                database_state.add_index(
                    index_name=create_index_together_name(
                        db_connection,
                        new_db_table,
                        field_names=[field.name for field in fields]),
                    table_name=new_db_table,
                    columns=[field.column for field in fields])

            if getattr(meta, 'indexes', None):
                # Django >= 1.11
                index_pairs = zip(meta.indexes,
                                  meta.original_attrs['indexes'])

                for index, orig_index in index_pairs:
                    if not orig_index.name:
                        # The name was auto-generated. We'll need to
                        # generate it again for the new table name.
                        index.set_name_with_model(model)

                    fields = evolver.get_fields_for_names(
                        model, index.fields, allow_sort_prefixes=True)

                    database_state.add_index(
                        index_name=index.name,
                        table_name=new_db_table,
                        columns=[field.column for field in fields])

        # ManyToManyFields have their own tables, which will also need to
        # be renamed. Go through each of them and figure out what changes
        # need to be made.
        for field in meta.local_many_to_many:
            through = get_remote_field(field).through

            if not through:
                continue

            through_meta = through._meta
            through_orig_model_name = get_model_name(through)
            through_new_model_name = through_orig_model_name

            # Find out if the through table name is a custom table name,
            # or one generated by Django.
            default_through_table = truncate_name(
                '%s_%s' % (orig_db_table, field.name),
                max_name_length)

            if through_meta.db_table == default_through_table:
                # This is an auto-generated table name. Start changing the
                # state for it.
                assert through_meta.app_label == orig_app_label
                through_meta.app_label = new_app_label

                # Transform the 'through' table information only if we've
                # transformed the parent db_table.
                if new_db_table != orig_db_table:
                    through_meta.db_table = truncate_name(
                        '%s_%s' % (new_db_table, field.name),
                        max_name_length)

                    through_meta.object_name = \
                        through_meta.object_name.replace(orig_object_name,
                                                         new_object_name)

                    through_new_model_name = \
                        through_orig_model_name.replace(orig_model_name,
                                                        new_model_name)

                    set_model_name(through, through_new_model_name)

            # Change each of the columns for the fields on the
            # ManyToManyField's model to reflect the new model names.
            column_prefixes = (
                orig_model_name,
                'to_%s' % orig_model_name,
                'from_%s' % orig_model_name,
            )

            for through_field in through._meta.local_fields:
                through_remote_field = get_remote_field(through_field)

                if not (through_remote_field and
                        get_remote_field_model(through_remote_field)):
                    continue

                column = through_field.column

                if column.startswith(column_prefixes):
                    # This is a field that references one end of the
                    # relation or another. Update the model name in the
                    # field's column.
                    through_field.column = column.replace(orig_model_name,
                                                          new_model_name)

            # Replace the entry in the models cache for the through table,
            # removing the old name and adding the new one.
            if through_orig_model_name in all_models[orig_app_label]:
                unregister_app_model(orig_app_label, through_orig_model_name)

            app_cache[through_new_model_name] = through
            register_app_models(new_app_label,
                                [(through_new_model_name, through)])

        # Unregister with the old model name and register the new one.
        if orig_model_name in all_models[orig_app_label]:
            unregister_app_model(orig_app_label, orig_model_name)

        register_app_models(new_app_label, [(new_model_name, model)])
        app_cache[new_model_name] = model

    # If the app hasn't yet been registered, do that now.
    if not is_app_registered(app):
        register_app(new_app_label, app)

    return app_cache
def test_with_one_to_one_field(self):
    """Testing get_remote_field with OneToOneField"""
    o2o_field = CompatModelsTestModel._meta.get_field('o2o_field')

    self.assertIsInstance(get_remote_field(o2o_field), related.OneToOneRel)
def sql_add_constraints(connection, model, refs):
    """Return SQL statements for adding constraints.

    This provides compatibility with all supported versions of Django.

    On the schema editor code path, the entry for ``model`` is removed
    from ``refs`` once its statements have been generated.

    Args:
        connection (object):
            The database connection.

        model (django.db.models.Model):
            The database model to add constraints on.

        refs (dict):
            A dictionary of constraint references to add. The keys are
            instances of :py:class:`django.db.models.Model`. The values
            are lists of tuples of (:py:class:`django.db.models.Model`,
            :py:class:`django.db.models.Field`).

    Returns:
        list:
        The list of SQL statements for adding constraints.
    """
    if not BaseDatabaseSchemaEditor:
        # Django < 1.7
        return connection.creation.sql_for_pending_references(
            model, color.no_style(), refs)

    # Django >= 1.7
    meta = model._meta

    if not meta.managed or meta.swapped:
        return []

    if model not in refs:
        return []

    statements = []

    with connection.schema_editor() as schema_editor:
        qn = schema_editor.quote_name

        for rel_class, f in refs[model]:
            # Ideally, we would use schema_editor._create_fk_sql here,
            # but it depends on a lot more state than we have
            # available currently in our mocks. So we have to build
            # the SQL ourselves. It's not a lot of work, fortunately.
            #
            # For reference, this is what we'd ideally do:
            #
            #     sql.append('%s;' % schema_editor._create_fk_sql(
            #         rel_class, f,
            #         '_fk_%(to_table)s_%(to_column)s'))
            #
            from_table = rel_class._meta.db_table
            target_field = meta.get_field(get_remote_field(f).field_name)
            to_column = target_field.column

            constraint_name = create_index_name(
                connection=connection,
                table_name=from_table,
                col_names=[f.column],
                suffix='_fk_%(to_table)s_%(to_column)s' % {
                    'to_table': meta.db_table,
                    'to_column': to_column,
                })

            fk_sql = schema_editor.sql_create_fk % {
                'table': qn(from_table),
                'name': qn(constraint_name),
                'column': qn(f.column),
                'to_table': qn(meta.db_table),
                'to_column': qn(to_column),
                'deferrable': connection.ops.deferrable_sql(),
            }

            statements.append('%s;' % fk_sql)

    # These references are handled now. Drop them from the pending set.
    del refs[model]

    return statements
def test_with_foreign_key_rel(self):
    """Testing get_remote_field_related_model with ForeignKey relation"""
    fkey_field = CompatModelsTestModel._meta.get_field('fkey_field')
    rel = get_remote_field(fkey_field)

    # The relation is expected to resolve back to the test model.
    related_model = get_remote_field_related_model(rel)
    self.assertIs(related_model, CompatModelsTestModel)
def iter_model_fields(model,
                      include_parent_models=True,
                      include_forward_fields=True,
                      include_reverse_fields=False,
                      include_hidden_fields=False,
                      seen_models=None):
    """Iterate through all fields on a model using the given criteria.

    This is roughly equivalent to Django's internal
    :py:func:`django.db.models.options.Option._get_fields` on Django 1.8+,
    but makes use of our model reverse relation tree, and works across all
    supported versions of Django.

    Version Added:
        2.2

    Args:
        model (type):
            The model owning the fields.

        include_parent_models (bool, optional):
            Whether to include fields defined on parent models.

        include_forward_fields (bool, optional):
            Whether to include fields owned by the model (or a parent).

        include_reverse_fields (bool, optional):
            Whether to include fields on other models that point to this
            model.

        include_hidden_fields (bool, optional):
            Whether to include hidden fields.

        seen_models (set, optional):
            Models seen during iteration. This is intended for internal
            use only by this function.

    Yields:
        django.db.models.Field:
        Each field matching the criteria.
    """
    concrete_model = model._meta.concrete_model

    if seen_models is None:
        seen_models = set()

    candidate_models = [model]

    if include_parent_models:
        candidate_models = walk_model_tree(model)

    rel_fields = []

    if include_reverse_fields:
        # Find all models containing fields that point to this model.
        rel_fields = get_model_rel_tree().get(
            model._meta.concrete_model._meta.db_table, [])

    for cur_model in candidate_models:
        cur_meta = cur_model._meta
        cur_model_label = cur_meta.db_table

        # Skip tables already visited and models backed by a different
        # concrete model.
        if cur_model_label in seen_models:
            continue

        if cur_meta.concrete_model != concrete_model:
            continue

        seen_models.add(cur_model_label)

        if include_parent_models:
            for parent in cur_meta.parents:
                # NOTE(review): seen_models holds table names, while
                # parent is a model, so this check looks like it always
                # passes. Kept as-is. Confirm the intent.
                if parent in seen_models:
                    continue

                parent_fields = iter_model_fields(
                    parent,
                    include_parent_models=True,
                    include_forward_fields=include_forward_fields,
                    include_reverse_fields=include_reverse_fields,
                    include_hidden_fields=include_hidden_fields)

                for field in parent_fields:
                    yield field

        if include_reverse_fields and not cur_meta.proxy:
            for rel_field in rel_fields:
                remote_field = get_remote_field(rel_field)

                if (include_hidden_fields or
                    not get_field_is_hidden(remote_field)):
                    yield remote_field

        if include_forward_fields:
            for field in cur_meta.local_fields:
                yield field

            for field in cur_meta.local_many_to_many:
                yield field

            # Django >= 1.10
            #
            # NOTE(review): this reads from the model passed in, not the
            # model being visited. Presumably intentional. Verify.
            for field in getattr(model._meta, 'private_fields', []):
                yield field
def sql_add_constraints(connection, model, refs):
    """Return SQL statements for adding constraints.

    This provides compatibility with all supported versions of Django.

    Args:
        connection (object):
            The database connection.

        model (django.db.models.Model):
            The database model to add constraints on.

        refs (dict):
            A dictionary of constraint references to add. The keys are
            instances of :py:class:`django.db.models.Model`. The values
            are lists of tuples of (:py:class:`django.db.models.Model`,
            :py:class:`django.db.models.Field`). When a schema editor is
            used, the entry for ``model`` is deleted after processing.

    Returns:
        list:
        The list of SQL statements for adding constraints.
    """
    if BaseDatabaseSchemaEditor:
        # Django >= 1.7
        model_meta = model._meta
        result = []

        if model_meta.managed and not model_meta.swapped and model in refs:
            to_table = model_meta.db_table

            with connection.schema_editor() as editor:
                quote = editor.quote_name

                for rel_class, f in refs[model]:
                    # Ideally, we would use schema_editor._create_fk_sql
                    # here, but it depends on a lot more state than we
                    # have available currently in our mocks. So we have to
                    # build the SQL ourselves. It's not a lot of work,
                    # fortunately.
                    #
                    # For reference, this is what we'd ideally do:
                    #
                    #     sql.append('%s;' % schema_editor._create_fk_sql(
                    #         rel_class, f,
                    #         '_fk_%(to_table)s_%(to_column)s'))
                    #
                    rel_table = rel_class._meta.db_table
                    rel_field_name = get_remote_field(f).field_name
                    to_column = model_meta.get_field(rel_field_name).column

                    name_suffix = '_fk_%(to_table)s_%(to_column)s' % {
                        'to_table': to_table,
                        'to_column': to_column,
                    }
                    fk_name = create_index_name(connection=connection,
                                                table_name=rel_table,
                                                col_names=[f.column],
                                                suffix=name_suffix)

                    fk_params = {
                        'table': quote(rel_table),
                        'name': quote(fk_name),
                        'column': quote(f.column),
                        'to_table': quote(to_table),
                        'to_column': quote(to_column),
                        'deferrable': connection.ops.deferrable_sql(),
                    }

                    result.append(
                        '%s;' % (editor.sql_create_fk % fk_params))

            del refs[model]

        return result
    else:
        # Django < 1.7
        return connection.creation.sql_for_pending_references(
            model, color.no_style(), refs)
def create_field(project_sig, field_name, field_type, field_attrs,
                 parent_model, related_model=None):
    """Create a Django field instance for the given signature data.

    This creates a field in a way that's compatible with a variety of
    versions of Django. It takes in data such as the field's name and
    attributes and creates an instance that can be used like any field
    found on a model.

    For a :py:class:`django.db.models.ManyToManyField` with a parent model,
    a mock "through" model is also attached to the field. It is either
    looked up from the ``through_model`` attribute, or synthesized to
    mirror the intermediary model Django would auto-create.

    Args:
        project_sig (object):
            The project signature, used to look up the app and model
            signatures for related and through models.

        field_name (unicode):
            The name of the field.

        field_type (cls):
            The class for the type of field being constructed. This must be
            a subclass of :py:class:`django.db.models.Field`.

        field_attrs (dict):
            Attributes to set on the field. This must not contain a
            ``related_model`` key.

        parent_model (cls):
            The parent model that would own this field. This must be a
            subclass of :py:class:`django.db.models.Model`.

        related_model (unicode, optional):
            The full class path (``app_name.ModelName``) to a model this
            relates to. When provided, a mock of that model is passed to
            the field as its relation target, so this is only meaningful
            for relation fields such as
            :py:class:`django.db.models.ForeignKey` or
            :py:class:`django.db.models.ManyToManyField`.

    Returns:
        django.db.models.Field:
        A new field instance matching the provided data.
    """
    # Convert to the standard string format for each version of Python, to
    # simulate what the format would be for the default name.
    field_name = str(field_name)

    assert 'related_model' not in field_attrs, \
        ('related_model cannot be passed in field_attrs when calling '
         'create_field(). Pass the related_model parameter instead.')

    if related_model:
        related_app_name, related_model_name = related_model.split('.')
        related_model_sig = (project_sig.get_app_sig(
            related_app_name,
            required=True).get_model_sig(related_model_name,
                                         required=True))
        to = MockModel(project_sig=project_sig,
                       app_name=related_app_name,
                       model_name=related_model_name,
                       model_sig=related_model_sig,
                       stub=True)

        if (issubclass(field_type, models.ForeignKey) and
            hasattr(models, 'CASCADE') and
            'on_delete' not in field_attrs):
            # Starting in Django 2.0, on_delete is a requirement for
            # ForeignKeys. If not provided in the signature, we want to
            # default this to CASCADE, which is the value that Django
            # previously defaulted to.
            #
            # A new dictionary is built so the caller's field_attrs isn't
            # modified.
            field_attrs = dict({
                'on_delete': models.CASCADE,
            }, **field_attrs)

        field = field_type(to, name=field_name, **field_attrs)
    else:
        field = field_type(name=field_name, **field_attrs)

    if (issubclass(field_type, models.ManyToManyField) and
        parent_model is not None):
        # Starting in Django 1.2, a ManyToManyField must have a through
        # model defined. This will be set internally to an auto-created
        # model if one isn't specified. We have to fake that model.
        through_model = field_attrs.get('through_model')
        through_model_sig = None

        if through_model:
            through_app_name, through_model_name = through_model.split('.')
            through_model_sig = (project_sig.get_app_sig(
                through_app_name).get_model_sig(through_model_name))
        elif hasattr(field, '_get_m2m_attr'):
            # Django >= 1.2
            remote_field = get_remote_field(field)
            remote_field_model = get_remote_field_model(remote_field)

            to_field_name = remote_field_model._meta.object_name.lower()

            if (remote_field_model == RECURSIVE_RELATIONSHIP_CONSTANT or
                to_field_name == parent_model._meta.object_name.lower()):
                # Both ends of the relation point at the same model, so
                # the two columns need distinct names.
                from_field_name = 'from_%s' % to_field_name
                to_field_name = 'to_%s' % to_field_name
            else:
                from_field_name = parent_model._meta.object_name.lower()

            # This corresponds to the signature in
            # related.create_many_to_many_intermediary_model
            through_app_name = parent_model.app_name

            # This must be a plain string. It's used as the model name for
            # both the signature and the mock model below.
            through_model_name = '%s_%s' % (parent_model._meta.object_name,
                                            field.name)

            through_model_sig = ModelSignature(
                model_name=through_model_name,
                table_name=field._get_m2m_db_table(parent_model._meta),
                pk_column='id',
                unique_together=[(from_field_name, to_field_name)])

            # 'id' field
            through_model_sig.add_field_sig(
                FieldSignature(field_name='id',
                               field_type=models.AutoField,
                               field_attrs={
                                   'primary_key': True,
                               }))

            # 'from' field
            through_model_sig.add_field_sig(
                FieldSignature(
                    field_name=from_field_name,
                    field_type=models.ForeignKey,
                    field_attrs={
                        'related_name': '%s+' % through_model_name,
                    },
                    related_model='%s.%s' % (parent_model.app_name,
                                             parent_model._meta.object_name)))

            # 'to' field
            through_model_sig.add_field_sig(
                FieldSignature(field_name=to_field_name,
                               field_type=models.ForeignKey,
                               field_attrs={
                                   'related_name': '%s+' % through_model_name,
                               },
                               related_model=related_model))

            field.auto_created = True

        if through_model_sig:
            # A through model named explicitly in field_attrs is neither
            # auto-created nor managed. A synthesized one is both.
            through = MockModel(project_sig=project_sig,
                                app_name=through_app_name,
                                model_name=through_model_name,
                                model_sig=through_model_sig,
                                auto_created=not through_model,
                                managed=not through_model)
            get_remote_field(field).through = through

        field.m2m_db_table = curry(field._get_m2m_db_table,
                                   parent_model._meta)
        field.set_attributes_from_rel()

    field.set_attributes_from_name(field_name)

    # Needed in Django >= 1.7, for index building.
    field.model = parent_model

    return field
def register_models(database_state, models, register_indexes=False,
                    new_app_label='tests', db_name='default', app=evo_test):
    """Register models for testing purposes.

    Each model is re-labeled to live under ``new_app_label`` with the
    provided name. Its ``_meta`` state (app label, object name, model name
    and, for generated table names, the table name) is updated in place,
    along with the through models of its ManyToManyFields. Tables and,
    optionally, indexes are recorded in the database state.

    Args:
        database_state (django_evolution.db.state.DatabaseState):
            The database state to populate with model information.

        models (list of tuple):
            The models to register. Each item is a 2-tuple of the new
            object name for the model and the
            :py:class:`django.db.models.Model` subclass itself. The list
            is processed in reverse order.

        register_indexes (bool, optional):
            Whether indexes should be registered for any models. Defaults
            to ``False``.

        new_app_label (str, optional):
            The label for the test app. Defaults to "tests".

        db_name (str, optional):
            The name of the database connection. Defaults to "default".

        app (module, optional):
            The application module for the test models. Defaults to
            ``evo_test``.

    Returns:
        collections.OrderedDict:
        A dictionary of registered models, including any ManyToManyField
        through models. The keys are model names, and the values are the
        models.
    """
    app_cache = OrderedDict()
    evolver = EvolutionOperationsMulti(db_name, database_state).get_evolver()

    db_connection = connections[db_name or DEFAULT_DB_ALIAS]
    max_name_length = db_connection.ops.max_name_length()

    for new_object_name, model in reversed(models):
        # Grab some state from the model's meta instance. Some of this will
        # be original state that we'll keep around to help us unregister old
        # values and compute new ones.
        meta = model._meta

        orig_app_label = meta.app_label
        orig_db_table = meta.db_table
        orig_object_name = meta.object_name
        orig_model_name = get_model_name(model)

        # Find out if the table name being used is a custom table name, or
        # one generated by Django.
        new_model_name = new_object_name.lower()
        new_db_table = orig_db_table

        generated_db_table = truncate_name(
            '%s_%s' % (orig_app_label, orig_model_name),
            max_name_length)

        if orig_db_table == generated_db_table:
            # It was a generated one, so replace it with a version containing
            # the new model and app names.
            new_db_table = truncate_name('%s_%s' % (new_app_label,
                                                    new_model_name),
                                         max_name_length)
            meta.db_table = new_db_table

        # Set the new app/model names back on the meta instance.
        meta.app_label = new_app_label
        meta.object_name = new_object_name
        set_model_name(model, new_model_name)

        # Add an entry for the table in the database state, if it's not
        # already there.
        if not database_state.has_table(new_db_table):
            database_state.add_table(new_db_table)

        if register_indexes:
            # Now that we definitely have an entry, store the indexes for
            # all the fields in the database state, so that other operations
            # can look up the index names.
            for field in meta.local_fields:
                if field.db_index or field.unique:
                    new_index_name = create_index_name(
                        db_connection,
                        new_db_table,
                        field_names=[field.name],
                        col_names=[field.column],
                        unique=field.unique)

                    database_state.add_index(
                        index_name=new_index_name,
                        table_name=new_db_table,
                        columns=[field.column],
                        unique=field.unique)

            # Each unique_together entry is recorded as a unique index.
            for field_names in meta.unique_together:
                fields = evolver.get_fields_for_names(model, field_names)
                new_index_name = create_index_name(
                    db_connection,
                    new_db_table,
                    field_names=field_names,
                    unique=True)

                database_state.add_index(
                    index_name=new_index_name,
                    table_name=new_db_table,
                    columns=[field.column for field in fields],
                    unique=True)

            # Django >= 1.5
            for field_names in getattr(meta, 'index_together', []):
                fields = evolver.get_fields_for_names(model, field_names)
                new_index_name = create_index_together_name(
                    db_connection,
                    new_db_table,
                    field_names=[field.name for field in fields])

                database_state.add_index(
                    index_name=new_index_name,
                    table_name=new_db_table,
                    columns=[field.column for field in fields])

            if getattr(meta, 'indexes', None):
                # Django >= 1.11
                for index, orig_index in zip(meta.indexes,
                                             meta.original_attrs['indexes']):
                    if not orig_index.name:
                        # The name was auto-generated. We'll need to generate
                        # it again for the new table name.
                        index.set_name_with_model(model)

                    fields = evolver.get_fields_for_names(
                        model, index.fields, allow_sort_prefixes=True)

                    database_state.add_index(
                        index_name=index.name,
                        table_name=new_db_table,
                        columns=[field.column for field in fields])

        # ManyToManyFields have their own tables, which will also need to be
        # renamed. Go through each of them and figure out what changes need
        # to be made.
        for field in meta.local_many_to_many:
            through = get_remote_field(field).through

            if not through:
                continue

            through_meta = through._meta
            through_orig_model_name = get_model_name(through)
            through_new_model_name = through_orig_model_name

            # Find out if the through table name is a custom table name, or
            # one generated by Django.
            generated_db_table = truncate_name(
                '%s_%s' % (orig_db_table, field.name),
                max_name_length)

            if through_meta.db_table == generated_db_table:
                # This is an auto-generated table name. Start changing the
                # state for it.
                assert through_meta.app_label == orig_app_label
                through_meta.app_label = new_app_label

                # Transform the 'through' table information only if we've
                # transformed the parent db_table.
                if new_db_table != orig_db_table:
                    through_meta.db_table = truncate_name(
                        '%s_%s' % (new_db_table, field.name),
                        max_name_length)

                    through_meta.object_name = \
                        through_meta.object_name.replace(orig_object_name,
                                                         new_object_name)

                    through_new_model_name = \
                        through_orig_model_name.replace(orig_model_name,
                                                        new_model_name)
                    set_model_name(through, through_new_model_name)

            # Change each of the columns for the fields on the
            # ManyToManyField's model to reflect the new model names.
            for through_field in through._meta.local_fields:
                through_remote_field = get_remote_field(through_field)

                if (through_remote_field and
                    get_remote_field_model(through_remote_field)):
                    column = through_field.column

                    if (column.startswith((orig_model_name,
                                           'to_%s' % orig_model_name,
                                           'from_%s' % orig_model_name))):
                        # This is a field that references one end of the
                        # relation or another. Update the model name in the
                        # field's column.
                        through_field.column = column.replace(orig_model_name,
                                                              new_model_name)

            # Replace the entry in the models cache for the through table,
            # removing the old name and adding the new one.
            if through_orig_model_name in all_models[orig_app_label]:
                unregister_app_model(orig_app_label, through_orig_model_name)

            app_cache[through_new_model_name] = through
            register_app_models(new_app_label,
                                [(through_new_model_name, through)])

        # Unregister with the old model name and register the new one.
        if orig_model_name in all_models[orig_app_label]:
            unregister_app_model(orig_app_label, orig_model_name)

        register_app_models(new_app_label, [(new_model_name, model)])
        app_cache[new_model_name] = model

    # If the app hasn't yet been registered, do that now.
    if not is_app_registered(app):
        register_app(new_app_label, app)

    return app_cache