def create_table(self, table_name, field_list, temporary=False,
                 create_index=True):
    """Return the SQL statements for creating a table with the given fields.

    Args:
        table_name (str):
            The name of the table to create.

        field_list (list of django.db.models.Field):
            The fields to create columns for. Many-to-many fields
            (including subclasses) are skipped, since they don't get a
            column in this table.

        temporary (bool, optional):
            Whether to create a ``TEMPORARY`` table. Temporary tables allow
            ``NULL`` in every column and get no ``REFERENCES`` clauses.

        create_index (bool, optional):
            Whether to append the statements returned by
            ``create_indexes_for_table()``.

    Returns:
        list of str:
        The ``CREATE TABLE`` statement, followed by any index statements.
    """
    qn = self.connection.ops.quote_name

    create = ['CREATE']

    if temporary:
        create.append('TEMPORARY')

    create.append('TABLE %s' % qn(table_name))

    columns = []

    for field in field_list:
        # isinstance() (rather than an exact type() comparison) so that
        # subclasses of ManyToManyField are skipped as well; building a
        # column definition for them would fail.
        if isinstance(field, models.ManyToManyField):
            continue

        params = [
            qn(field.column),
            field.db_type(connection=self.connection),
        ]

        # Always use null if this is a temporary table. It may be
        # used to create a new field (which will be null while data is
        # copied across from the old table).
        if temporary or field.null:
            params.append('NULL')
        else:
            params.append('NOT NULL')

        if field.unique:
            params.append('UNIQUE')

        if field.primary_key:
            params.append('PRIMARY KEY')

        if not temporary and isinstance(field, models.ForeignKey):
            remote_field = get_remote_field(field)
            remote_meta = get_remote_field_model(remote_field)._meta
            params.append(
                'REFERENCES %s (%s) DEFERRABLE INITIALLY DEFERRED'
                % (qn(remote_meta.db_table),
                   qn(remote_meta.get_field(remote_field.field_name).column)))

        columns.append(' '.join(params))

    output = ['%s(%s);' % (' '.join(create), ', '.join(columns))]

    if create_index:
        output.extend(self.create_indexes_for_table(table_name, field_list))

    return output
def remove_field_constraints(self, field, opts, models, refs):
    """Build SQL that drops constraints referencing a primary-key field.

    Only applies when the backend supports constraints and ``field`` is a
    primary key. Every many-to-many "through" model on ``opts`` is scanned
    for relation fields that target ``field.column`` on ``opts.db_table``.
    Each match is recorded in the caller's ``models`` and ``refs``, and
    drop-constraint SQL is generated for each recorded model.

    Args:
        field (django.db.models.Field):
            The field whose constraints will be removed.

        opts (django.db.models.options.Options):
            The Meta options of the model owning ``field``.

        models (list of django.db.models.Model):
            Caller-owned list. Each model that has a constraint being
            removed is appended to it.

        refs (dict):
            Caller-owned dictionary. Keys are models; values are lists of
            ``(through model, field)`` tuples that are appended to.

    Returns:
        list:
        SQL statements that remove the constraints. Empty when the backend
        lacks constraint support or ``field`` is not a primary key.
    """
    statements = []

    if not (self.supports_constraints and field.primary_key):
        return statements

    for m2m_field in opts.local_many_to_many:
        rel = get_remote_field(m2m_field)

        if not rel or not rel.through:
            continue

        through = rel.through

        for through_field in through._meta.local_fields:
            through_rel = get_remote_field(through_field)

            if not through_rel:
                continue

            target_model = get_remote_field_model(through_rel)
            points_at_field = (
                through_rel.field_name == field.column and
                target_model._meta.db_table == opts.db_table)

            if points_at_field:
                models.append(target_model)
                refs.setdefault(target_model, []).append(
                    (through, through_field))

    remove_refs = refs.copy()

    for relto in models:
        statements.extend(
            sql_delete_constraints(self.connection, relto, remove_refs))

    return statements
def get_model_rel_tree():
    """Build (or fetch from cache) a map of tables to the fields relating to them.

    Every non-abstract model returned by ``get_models()`` (auto-created
    ones included) is scanned for its own forward relation fields. Each
    field whose target model is resolved is filed under the ``db_table`` of
    that target's concrete model. This means every reverse relation to a
    table can be found with a single lookup.

    This plays the same role as Django's internal reverse relation tree on
    :py:class:`django.db.models.options.Options`. Unlike that tree, it
    behaves the same across all supported Django versions and its cache
    can be cleared.

    Version Added:
        2.2

    Returns:
        dict:
        A ``defaultdict(list)`` mapping table names to lists of relation
        fields that point at that table. The result is stored in the
        module-level ``_rel_tree_cache`` and returned as-is on later calls.
    """
    global _rel_tree_cache

    if _rel_tree_cache is None:
        tree = defaultdict(list)

        for candidate in get_models(include_auto_created=True):
            if candidate._meta.abstract:
                continue

            forward_fields = iter_model_fields(
                candidate,
                include_parent_models=False,
                include_forward_fields=True,
                include_reverse_fields=False,
                include_hidden_fields=False)

            for field in forward_fields:
                if not get_field_is_relation(field):
                    continue

                if get_remote_field_related_model(field) is None:
                    continue

                target = get_remote_field_model(get_remote_field(field))

                # Skip targets that are still strings (a "self" relation
                # or similar), since they can't be keyed by table.
                if isinstance(target, six.string_types):
                    continue

                target_table = target._meta.concrete_model._meta.db_table
                tree[target_table].append(field)

        _rel_tree_cache = tree

    return _rel_tree_cache
def rename_table(self, model, old_db_tablename, db_tablename):
    """Return SQL for renaming a model's table.

    For each many-to-many field on ``model`` whose through model's
    ``db_table`` equals ``old_db_tablename``, the through model's relation
    fields that point back at ``model`` are collected. When the backend
    supports constraints, those constraints are dropped as pre-SQL and
    re-added as post-SQL around the rename. The collected through models
    have their ``_meta.db_table`` updated in place (renamed if it matched
    the old name, then truncated to the backend's maximum name length).

    Args:
        model (type):
            The model whose table is being renamed.

        old_db_tablename (str):
            The current table name.

        db_tablename (str):
            The new table name.

    Returns:
        SQLResult:
        The SQL for the rename. It is empty if the names are identical.
    """
    sql_result = SQLResult()

    if old_db_tablename == db_tablename:
        # No Operation
        return sql_result

    max_name_length = self.connection.ops.max_name_length()

    # refs maps each referenced model to a list of (through model, field)
    # tuples. models records the referenced models in discovery order, and
    # the same model may be appended more than once.
    refs = {}
    models = []

    for field in model._meta.local_many_to_many:
        remote_field = get_remote_field(field)

        if (remote_field and
            remote_field.through and
            remote_field.through._meta.db_table == old_db_tablename):
            through = remote_field.through

            for m2m_field in through._meta.local_fields:
                remote_m2m_field = get_remote_field(m2m_field)

                if remote_m2m_field:
                    remote_m2m_field_model = get_remote_field_model(
                        remote_m2m_field)

                    if remote_m2m_field_model == model:
                        models.append(remote_m2m_field_model)
                        refs.setdefault(remote_m2m_field_model, []).append(
                            (through, m2m_field))

    # Shallow copy used for the drop step. Presumably this keeps any key
    # removal done by sql_delete_constraints() from affecting refs, which
    # is reused below to re-add the constraints. NOTE(review): confirm.
    remove_refs = refs.copy()

    if self.supports_constraints:
        for relto in models:
            sql_result.add_pre_sql(
                sql_delete_constraints(self.connection, relto, remove_refs))

    sql_result.add(
        self.get_rename_table_sql(model, old_db_tablename, db_tablename))

    for relto in models:
        for rel_class, f in refs[relto]:
            # Point the through model at the new table name if it used the
            # old one, then keep it within the backend's name length limit.
            if rel_class._meta.db_table == old_db_tablename:
                rel_class._meta.db_table = db_tablename

            rel_class._meta.db_table = \
                truncate_name(rel_class._meta.db_table, max_name_length)

        if self.supports_constraints:
            sql_result.add_post_sql(
                sql_add_constraints(self.connection, relto, refs))

    return sql_result
def create_table(self, table_name, field_list, temporary=False,
                 create_index=True):
    """Return the SQL statements for creating a table with the given fields.

    Args:
        table_name (str):
            The name of the table to create.

        field_list (list of django.db.models.Field):
            The fields to create columns for. Many-to-many fields
            (including subclasses) are skipped, since they don't get a
            column in this table.

        temporary (bool, optional):
            Whether to create a ``TEMPORARY`` table. Temporary tables allow
            ``NULL`` in every column and get no ``REFERENCES`` clauses.

        create_index (bool, optional):
            Whether to append the statements returned by
            ``create_indexes_for_table()``.

    Returns:
        list of str:
        The ``CREATE TABLE`` statement, followed by any index statements.
    """
    qn = self.connection.ops.quote_name

    create = ['CREATE']

    if temporary:
        create.append('TEMPORARY')

    create.append('TABLE %s' % qn(table_name))

    columns = []

    for field in field_list:
        # isinstance() (rather than an exact type() comparison) so that
        # subclasses of ManyToManyField are skipped as well; building a
        # column definition for them would fail.
        if isinstance(field, models.ManyToManyField):
            continue

        params = [
            qn(field.column),
            field.db_type(connection=self.connection),
        ]

        # Always use null if this is a temporary table. It may be
        # used to create a new field (which will be null while data is
        # copied across from the old table).
        if temporary or field.null:
            params.append('NULL')
        else:
            params.append('NOT NULL')

        if field.unique:
            params.append('UNIQUE')

        if field.primary_key:
            params.append('PRIMARY KEY')

        if not temporary and isinstance(field, models.ForeignKey):
            remote_field = get_remote_field(field)
            remote_meta = get_remote_field_model(remote_field)._meta
            params.append(
                'REFERENCES %s (%s) DEFERRABLE INITIALLY DEFERRED'
                % (qn(remote_meta.db_table),
                   qn(remote_meta.get_field(remote_field.field_name).column)))

        columns.append(' '.join(params))

    output = ['%s(%s);' % (' '.join(create), ', '.join(columns))]

    if create_index:
        output.extend(self.create_indexes_for_table(table_name, field_list))

    return output
def _get_rename_column_sql(self, opts, old_field, new_field):
    """Build a ``CHANGE COLUMN`` item that renames and redefines a column.

    The new definition is built from ``new_field``:

    * quoted column name and type
    * ``NULL``/``NOT NULL``
    * optional ``PRIMARY KEY`` / ``UNIQUE``
    * an inline index tablespace, when applicable
    * a ``REFERENCES`` clause for relations

    If ``old_field`` was a primary key, the item is prefixed with
    ``DROP PRIMARY KEY, ``.

    Args:
        opts (django.db.models.options.Options):
            The Meta options of the model owning the column.

        old_field (django.db.models.Field):
            The field as it currently exists.

        new_field (django.db.models.Field):
            The field as it should become.

    Returns:
        list of dict:
        A single-item list of the form ``[{'sql': ...}]``.
    """
    quote = self.connection.ops.quote_name
    sql_style = color.no_style()

    column_type = new_field.db_type(connection=self.connection)
    index_tablespace = new_field.db_tablespace or opts.db_tablespace

    if new_field.null:
        null_keyword = 'NULL'
    else:
        null_keyword = 'NOT NULL'

    # The column definition, e.g. 'foo VARCHAR(30) NOT NULL ...'.
    definition = [
        sql_style.SQL_FIELD(quote(new_field.column)),
        sql_style.SQL_COLTYPE(column_type),
        sql_style.SQL_KEYWORD(null_keyword),
    ]

    if new_field.primary_key:
        definition.append(sql_style.SQL_KEYWORD('PRIMARY KEY'))

    if new_field.unique:
        definition.append(sql_style.SQL_KEYWORD('UNIQUE'))

    wants_inline_tablespace = (
        index_tablespace and
        self.connection.features.supports_tablespaces and
        self.connection.features.autoindexes_primary_keys and
        (new_field.unique or new_field.primary_key))

    if wants_inline_tablespace:
        # No separate CREATE INDEX is generated for this field, so the
        # index tablespace has to be given inline.
        definition.append(
            self.connection.ops.tablespace_sql(index_tablespace,
                                               inline=True))

    target_rel = get_remote_field(new_field)

    if target_rel:
        target_meta = get_remote_field_model(target_rel)._meta
        target_column = target_meta.get_field(target_rel.field_name).column
        definition.append('%s %s (%s)%s' % (
            sql_style.SQL_KEYWORD('REFERENCES'),
            sql_style.SQL_TABLE(quote(target_meta.db_table)),
            sql_style.SQL_FIELD(quote(target_column)),
            self.connection.ops.deferrable_sql(),
        ))

    if old_field.primary_key:
        prefix = 'DROP PRIMARY KEY, '
    else:
        prefix = ''

    change_sql = '%sCHANGE COLUMN %s %s' % (prefix,
                                            quote(old_field.column),
                                            ' '.join(definition))

    return [{'sql': change_sql}]
def delete_column(self, model, f):
    """Return SQL for dropping a field's column from a model's table.

    If the field is a relation, SQL that drops its constraint against the
    related model is added first. The parent class's drop-column SQL is
    added after that.

    Args:
        model (type):
            The model whose table owns the column.

        f (django.db.models.Field):
            The field whose column is being dropped.

    Returns:
        AlterTableSQLResult:
        The combined SQL.
    """
    result = AlterTableSQLResult(self, model)
    rel = get_remote_field(f)

    if rel:
        target_model = get_remote_field_model(rel)
        constraint_refs = {target_model: [(model, f)]}
        result.add(sql_delete_constraints(self.connection, target_model,
                                          constraint_refs))

    drop_sql = super(EvolutionOperations, self).delete_column(model, f)
    result.add_sql(drop_sql)

    return result
def create_field(project_sig, field_name, field_type, field_attrs, parent_model,
                 related_model=None):
    """Create a Django field instance for the given signature data.

    This creates a field in a way that's compatible with a variety of
    versions of Django. It takes in data such as the field's name and
    attributes and creates an instance that can be used like any field
    found on a model.

    Args:
        project_sig (object):
            The project signature. Its ``get_app_sig()`` /
            ``get_model_sig()`` methods are used to look up the signatures
            of related and through models.

        field_name (unicode):
            The name of the field.

        field_type (cls):
            The class for the type of field being constructed. This must be
            a subclass of :py:class:`django.db.models.Field`.

        field_attrs (dict):
            Attributes to set on the field. For many-to-many fields, a
            ``through_model`` key in ``app_name.ModelName`` form selects an
            explicit through model.

        parent_model (cls):
            The parent model that would own this field. This must be a
            subclass of :py:class:`django.db.models.Model`.

        related_model (unicode, optional):
            The model this field relates to, in ``app_name.ModelName``
            form. Used for relation fields such as
            :py:class:`django.db.models.ForeignKey` and
            :py:class:`django.db.models.ManyToManyField`.

    Returns:
        django.db.models.Field:
        A new field instance matching the provided data.
    """
    # Convert to the standard string format for each version of Python, to
    # simulate what the format would be for the default name.
    field_name = str(field_name)

    assert 'related_model' not in field_attrs, \
        ('related_model cannot be passed in field_attrs when calling '
         'create_field(). Pass the related_model parameter instead.')

    if related_model:
        related_app_name, related_model_name = related_model.split('.')
        related_model_sig = (
            project_sig
            .get_app_sig(related_app_name, required=True)
            .get_model_sig(related_model_name, required=True))

        to = MockModel(project_sig=project_sig,
                       app_name=related_app_name,
                       model_name=related_model_name,
                       model_sig=related_model_sig,
                       stub=True)

        if (issubclass(field_type, models.ForeignKey) and
            hasattr(models, 'CASCADE') and
            'on_delete' not in field_attrs):
            # Starting in Django 2.0, on_delete is a requirement for
            # ForeignKeys. If not provided in the signature, we want to
            # default this to CASCADE, which is the value that Django
            # previously defaulted to.
            field_attrs = dict({
                'on_delete': models.CASCADE,
            }, **field_attrs)

        field = field_type(to, name=field_name, **field_attrs)
    else:
        field = field_type(name=field_name, **field_attrs)

    if (issubclass(field_type, models.ManyToManyField) and
        parent_model is not None):
        # Starting in Django 1.2, a ManyToManyField must have a through
        # model defined. This will be set internally to an auto-created
        # model if one isn't specified. We have to fake that model.
        through_model = field_attrs.get('through_model')
        through_model_sig = None

        if through_model:
            through_app_name, through_model_name = through_model.split('.')
            through_model_sig = (
                project_sig
                .get_app_sig(through_app_name)
                .get_model_sig(through_model_name))
        elif hasattr(field, '_get_m2m_attr'):
            # Django >= 1.2
            remote_field = get_remote_field(field)
            remote_field_model = get_remote_field_model(remote_field)

            to_field_name = remote_field_model._meta.object_name.lower()

            if (remote_field_model == RECURSIVE_RELATIONSHIP_CONSTANT or
                to_field_name == parent_model._meta.object_name.lower()):
                from_field_name = 'from_%s' % to_field_name
                to_field_name = 'to_%s' % to_field_name
            else:
                from_field_name = parent_model._meta.object_name.lower()

            # This corresponds to the signature in
            # related.create_many_to_many_intermediary_model
            through_app_name = parent_model.app_name

            # This must be a plain string. A trailing comma here would turn
            # it into a 1-tuple, which would then be passed as the model
            # name to ModelSignature and MockModel.
            through_model_name = '%s_%s' % (parent_model._meta.object_name,
                                            field.name)

            through_model_sig = ModelSignature(
                model_name=through_model_name,
                table_name=field._get_m2m_db_table(parent_model._meta),
                pk_column='id',
                unique_together=[(from_field_name, to_field_name)])

            # 'id' field
            through_model_sig.add_field_sig(
                FieldSignature(field_name='id',
                               field_type=models.AutoField,
                               field_attrs={
                                   'primary_key': True,
                               }))

            # 'from' field
            through_model_sig.add_field_sig(
                FieldSignature(
                    field_name=from_field_name,
                    field_type=models.ForeignKey,
                    field_attrs={
                        'related_name': '%s+' % through_model_name,
                    },
                    related_model='%s.%s' % (parent_model.app_name,
                                             parent_model._meta.object_name)))

            # 'to' field
            through_model_sig.add_field_sig(
                FieldSignature(field_name=to_field_name,
                               field_type=models.ForeignKey,
                               field_attrs={
                                   'related_name': '%s+' % through_model_name,
                               },
                               related_model=related_model))

            field.auto_created = True

        if through_model_sig:
            through = MockModel(project_sig=project_sig,
                                app_name=through_app_name,
                                model_name=through_model_name,
                                model_sig=through_model_sig,
                                auto_created=not through_model,
                                managed=not through_model)
            get_remote_field(field).through = through

        field.m2m_db_table = curry(field._get_m2m_db_table,
                                   parent_model._meta)
        field.set_attributes_from_rel()

    field.set_attributes_from_name(field_name)

    # Needed in Django >= 1.7, for index building.
    field.model = parent_model

    return field
def add_column(self, model, f, initial):
    """Return SQL for adding a column for a new field to a model's table.

    Args:
        model (type):
            The model whose table is being altered.

        f (django.db.models.Field):
            The field to add a column for.

        initial (object):
            The initial value for existing rows. How it is used depends on
            its form:

            * A callable: the column is added, then the callable's return
              value is interpolated verbatim into an ``UPDATE ... SET``
              statement. If the field is non-nullable, the column is
              afterwards switched to ``NOT NULL``.
            * Any other non-``None`` value: passed as a SQL parameter for
              a temporary ``DEFAULT``, which is dropped again as post-SQL.
            * ``None``: the column is added with no initial value.

            ``initial`` is not used for relation fields.

    Returns:
        AlterTableSQLResult:
        The SQL for adding the column (and any unique index bookkeeping).
    """
    qn = self.connection.ops.quote_name
    sql_result = AlterTableSQLResult(self, model)
    table_name = model._meta.db_table

    remote_field = get_remote_field(f)

    if remote_field:
        # it is a foreign key field
        # NOT NULL REFERENCES "django_evolution_addbasemodel"
        # ("id") DEFERRABLE INITIALLY DEFERRED

        # ALTER TABLE <tablename> ADD COLUMN <column name> NULL
        # REFERENCES <tablename1> ("<colname>") DEFERRABLE INITIALLY
        # DEFERRED
        related_model = get_remote_field_model(remote_field)
        related_meta = related_model._meta
        related_table = related_meta.db_table

        # REFERENCES needs the target *column*, not the field name. These
        # differ when the target has a db_column or is a parent link
        # (e.g. "parent_ptr" vs. "parent_ptr_id"). Honor to_field the same
        # way create_table() does, falling back to the primary key.
        related_field_name = getattr(remote_field, 'field_name', None)

        if related_field_name:
            related_col = related_meta.get_field(related_field_name).column
        else:
            related_col = related_meta.pk.column

        constraints = ['%sNULL' % (not f.null and 'NOT ' or '')]

        if f.unique or f.primary_key:
            constraints.append('UNIQUE')

        sql_result.add_alter_table([{
            'op': 'ADD COLUMN',
            'column': f.column,
            'db_type': f.db_type(connection=self.connection),
            'params': constraints + [
                'REFERENCES',
                qn(related_table),
                '(%s)' % qn(related_col),
                self.connection.ops.deferrable_sql(),
            ]
        }])
    else:
        null_constraints = '%sNULL' % (not f.null and 'NOT ' or '')

        if f.unique or f.primary_key:
            unique_constraints = 'UNIQUE'
        else:
            unique_constraints = ''

        # At this point, initial can only be None if null=True,
        # otherwise it is a user callable or the default
        # AddFieldInitialCallback which will shortly raise an exception.
        if initial is not None:
            if callable(initial):
                sql_result.add_alter_table([{
                    'op': 'ADD COLUMN',
                    'column': f.column,
                    'db_type': f.db_type(connection=self.connection),
                    'params': [unique_constraints],
                }])

                # NOTE: initial()'s result is inserted into the SQL
                # verbatim (not parameterized), so it must be a trusted
                # SQL expression.
                sql_result.add_sql([
                    'UPDATE %s SET %s = %s WHERE %s IS NULL;'
                    % (qn(table_name), qn(f.column), initial(),
                       qn(f.column))
                ])

                if not f.null:
                    # Only put this sql statement if the column cannot
                    # be null.
                    sql_result.add_sql(
                        self.set_field_null(model, f, f.null))
            else:
                sql_result.add_alter_table([{
                    'op': 'ADD COLUMN',
                    'column': f.column,
                    'db_type': f.db_type(connection=self.connection),
                    'params': [
                        null_constraints,
                        unique_constraints,
                        'DEFAULT',
                        '%s',
                    ],
                    'sql_params': [initial]
                }])

                # Django doesn't generate default columns, so now that
                # we've added one to get default values for existing
                # tables, drop that default.
                sql_result.add_post_sql([
                    'ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT;'
                    % (qn(table_name), qn(f.column))
                ])
        else:
            sql_result.add_alter_table([{
                'op': 'ADD COLUMN',
                'column': f.column,
                'db_type': f.db_type(connection=self.connection),
                'params': [null_constraints, unique_constraints],
            }])

    if f.unique or f.primary_key:
        self.database_state.add_index(
            table_name=table_name,
            index_name=self.get_new_constraint_name(table_name, f.column),
            columns=[f.column],
            unique=True)

    return sql_result
def test_with_foreign_key(self):
    """Testing get_remote_field_model with ForeignKey"""
    # Passing the field itself (not its relation) is expected to yield the
    # model that declares the field.
    fkey_field = CompatModelsTestModel._meta.get_field('fkey_field')

    self.assertIs(get_remote_field_model(fkey_field), CompatModelsTestModel)
def test_with_one_to_one_field_rel(self):
    """Testing get_remote_field_model with OneToOneField relation"""
    # Passing the relation object is expected to yield the target model.
    o2o_rel = get_remote_field(
        CompatModelsTestModel._meta.get_field('o2o_field'))

    self.assertIs(get_remote_field_model(o2o_rel), CompatModelsAnchor)
def test_with_one_to_one_field(self):
    """Testing get_remote_field_model with OneToOneField"""
    # Passing the field itself (not its relation) is expected to yield the
    # model that declares the field.
    o2o_field = CompatModelsTestModel._meta.get_field('o2o_field')

    self.assertIs(get_remote_field_model(o2o_field), CompatModelsTestModel)
def test_with_many_to_many_field_rel(self):
    """Testing get_remote_field_model with ManyToManyField relation"""
    # Passing the relation object is expected to yield the target model.
    m2m_rel = get_remote_field(
        CompatModelsTestModel._meta.get_field('m2m_field'))

    self.assertIs(get_remote_field_model(m2m_rel), CompatModelsAnchor)
def test_with_many_to_many_field(self):
    """Testing get_remote_field_model with ManyToManyField"""
    # Passing the field itself (not its relation) is expected to yield the
    # model that declares the field.
    m2m_field = CompatModelsTestModel._meta.get_field('m2m_field')

    self.assertIs(get_remote_field_model(m2m_field), CompatModelsTestModel)
def test_with_foreign_key_rel(self):
    """Testing get_remote_field_model with ForeignKey relation"""
    # Passing the relation object is expected to yield the target model.
    fkey_rel = get_remote_field(
        CompatModelsTestModel._meta.get_field('fkey_field'))

    self.assertIs(get_remote_field_model(fkey_rel), CompatModelsAnchor)
def register_models(database_state, models, register_indexes=False,
                    new_app_label='tests', db_name='default', app=evo_test):
    """Register models for testing purposes.

    Each model is renamed in place:

    * its ``_meta`` gets the new app label and object name;
    * its table is renamed if it was Django-generated;
    * auto-generated many-to-many through models are renamed to match.

    The models are then re-registered under ``new_app_label``, and their
    tables (and optionally indexes) are recorded in ``database_state``.

    Args:
        database_state (django_evolution.db.state.DatabaseState):
            The database state to populate with model information.

        models (list of tuple):
            ``(new_object_name, model)`` pairs. Each model class is
            registered under the paired new object name. They are processed
            in reverse list order.

        register_indexes (bool, optional):
            Whether indexes should be registered for any models. Defaults
            to ``False``.

        new_app_label (str, optional):
            The label for the test app. Defaults to "tests".

        db_name (str, optional):
            The name of the database connection. Defaults to "default".

        app (module, optional):
            The application module for the test models.

    Returns:
        collections.OrderedDict:
        A dictionary of registered models. The keys are model names, and the
        values are the models. Each model's auto-generated through models
        are inserted just before the model itself.
    """
    app_cache = OrderedDict()
    evolver = EvolutionOperationsMulti(db_name, database_state).get_evolver()
    db_connection = connections[db_name or DEFAULT_DB_ALIAS]
    max_name_length = db_connection.ops.max_name_length()

    for new_object_name, model in reversed(models):
        # Grab some state from the model's meta instance. Some of this will
        # be original state that we'll keep around to help us unregister old
        # values and compute new ones.
        meta = model._meta
        orig_app_label = meta.app_label
        orig_db_table = meta.db_table
        orig_object_name = meta.object_name
        orig_model_name = get_model_name(model)

        # Find out if the table name being used is a custom table name, or
        # one generated by Django.
        new_model_name = new_object_name.lower()
        new_db_table = orig_db_table
        generated_db_table = truncate_name(
            '%s_%s' % (orig_app_label, orig_model_name),
            max_name_length)

        if orig_db_table == generated_db_table:
            # It was a generated one, so replace it with a version containing
            # the new model and app names.
            new_db_table = truncate_name('%s_%s' % (new_app_label,
                                                    new_model_name),
                                         max_name_length)
            meta.db_table = new_db_table

        # Set the new app/model names back on the meta instance.
        meta.app_label = new_app_label
        meta.object_name = new_object_name
        set_model_name(model, new_model_name)

        # Add an entry for the table in the database state, if it's not
        # already there.
        if not database_state.has_table(new_db_table):
            database_state.add_table(new_db_table)

        if register_indexes:
            # Now that we definitely have an entry, store the indexes for
            # all the fields in the database state, so that other operations
            # can look up the index names.
            for field in meta.local_fields:
                if field.db_index or field.unique:
                    new_index_name = create_index_name(
                        db_connection,
                        new_db_table,
                        field_names=[field.name],
                        col_names=[field.column],
                        unique=field.unique)

                    database_state.add_index(
                        index_name=new_index_name,
                        table_name=new_db_table,
                        columns=[field.column],
                        unique=field.unique)

            for field_names in meta.unique_together:
                fields = evolver.get_fields_for_names(model, field_names)
                new_index_name = create_index_name(
                    db_connection,
                    new_db_table,
                    field_names=field_names,
                    unique=True)

                database_state.add_index(
                    index_name=new_index_name,
                    table_name=new_db_table,
                    columns=[field.column for field in fields],
                    unique=True)

            for field_names in getattr(meta, 'index_together', []):
                # Django >= 1.5
                fields = evolver.get_fields_for_names(model, field_names)
                new_index_name = create_index_together_name(
                    db_connection,
                    new_db_table,
                    field_names=[field.name for field in fields])

                database_state.add_index(
                    index_name=new_index_name,
                    table_name=new_db_table,
                    columns=[field.column for field in fields])

            if getattr(meta, 'indexes', None):
                # Django >= 1.11
                # Pair each index with its originally declared version so
                # we can tell whether its name was explicit or generated.
                for index, orig_index in zip(meta.indexes,
                                             meta.original_attrs['indexes']):
                    if not orig_index.name:
                        # The name was auto-generated. We'll need to generate
                        # it again for the new table name.
                        index.set_name_with_model(model)

                    fields = evolver.get_fields_for_names(
                        model, index.fields, allow_sort_prefixes=True)

                    database_state.add_index(
                        index_name=index.name,
                        table_name=new_db_table,
                        columns=[field.column for field in fields])

        # ManyToManyFields have their own tables, which will also need to be
        # renamed. Go through each of them and figure out what changes need
        # to be made.
        for field in meta.local_many_to_many:
            through = get_remote_field(field).through

            if not through:
                continue

            through_meta = through._meta
            through_orig_model_name = get_model_name(through)
            through_new_model_name = through_orig_model_name

            # Find out if the through table name is a custom table name, or
            # one generated by Django.
            generated_db_table = truncate_name(
                '%s_%s' % (orig_db_table, field.name),
                max_name_length)

            if through_meta.db_table == generated_db_table:
                # This is an auto-generated table name. Start changing the
                # state for it.
                assert through_meta.app_label == orig_app_label
                through_meta.app_label = new_app_label

                # Transform the 'through' table information only if we've
                # transformed the parent db_table.
                if new_db_table != orig_db_table:
                    through_meta.db_table = truncate_name(
                        '%s_%s' % (new_db_table, field.name),
                        max_name_length)

                    through_meta.object_name = \
                        through_meta.object_name.replace(orig_object_name,
                                                         new_object_name)

                    through_new_model_name = \
                        through_orig_model_name.replace(orig_model_name,
                                                        new_model_name)
                    set_model_name(through, through_new_model_name)

                # Change each of the columns for the fields on the
                # ManyToManyField's model to reflect the new model names.
                for through_field in through._meta.local_fields:
                    through_remote_field = get_remote_field(through_field)

                    if (through_remote_field and
                        get_remote_field_model(through_remote_field)):
                        column = through_field.column

                        if (column.startswith((orig_model_name,
                                               'to_%s' % orig_model_name,
                                               'from_%s' % orig_model_name))):
                            # This is a field that references one end of the
                            # relation or another. Update the model name in
                            # the field's column.
                            through_field.column = column.replace(
                                orig_model_name,
                                new_model_name)

                # Replace the entry in the models cache for the through table,
                # removing the old name and adding the new one.
                if through_orig_model_name in all_models[orig_app_label]:
                    unregister_app_model(orig_app_label,
                                         through_orig_model_name)

                app_cache[through_new_model_name] = through
                register_app_models(new_app_label,
                                    [(through_new_model_name, through)])

        # Unregister with the old model name and register the new one.
        if orig_model_name in all_models[orig_app_label]:
            unregister_app_model(orig_app_label, orig_model_name)

        register_app_models(new_app_label, [(new_model_name, model)])
        app_cache[new_model_name] = model

    # If the app hasn't yet been registered, do that now.
    if not is_app_registered(app):
        register_app(new_app_label, app)

    return app_cache