Exemple #1
0
 def get_django_tables(self, only_existing):
     """Return the Django table names via the newest API available.

     Lookup order (oldest fallbacks last):

     1. ``self.introspection.django_table_names`` (after the introspection
        refactoring, r8296);
     2. ``django.core.management.sql.django_table_names`` (before r8296);
     3. ``django.core.management.sql.django_table_list`` (before svn r7568).

     Availability is probed with ``getattr`` instead of wrapping the calls in
     ``try/except AttributeError``. An AttributeError raised *inside* the
     selected implementation therefore propagates, instead of being mistaken
     for "this Django version lacks the API".

     :param only_existing: forwarded unchanged to the selected API.
     :returns: whatever the selected API returns.
     """
     introspection = getattr(self, 'introspection', None)
     table_names = getattr(introspection, 'django_table_names', None)
     if table_names is not None:
         return table_names(only_existing=only_existing)

     # backwards compatibility for before introspection refactoring (r8296).
     # Imported here so the name is always bound, and only when needed.
     from django.core.management import sql as _sql
     table_names = getattr(_sql, 'django_table_names', None)
     if table_names is None:
         # backwards compatibility for before svn r7568
         table_names = _sql.django_table_list
     return table_names(only_existing=only_existing)
Exemple #2
0
    def handle_diff(self, app, **options):
        """Compare the models of ``app`` with the tables in the database.

        For every model returned by ``models.get_models(app)``, the columns
        reported by ``introspection_module.get_table_description`` are matched
        against the model's fields by lower-cased attribute name. Each mismatch
        is recorded as a human readable string. Per-model results are collected
        in ``model_diffs`` as ``(model_name, [diff, ...])`` tuples.

        :param app: the application's models module; the second-to-last
            component of its dotted ``__name__`` is used as the app name.
        :param options: command options; ``only_existing`` (default ``True``)
            is forwarded to Django's table-name helper.

        Returns early (``None``) when ``app`` defines no models.
        """
        import sys
        from django.db import models, connection, transaction, get_introspection_module
        from django.core.management import sql as _sql

        def clean(s):
            # Keep only the bare type name, e.g. "varchar(30) NOT NULL" -> "varchar".
            s = s.split(" ")[0]
            s = s.split("(")[0]
            return s

        def strip_type_keyword(s):
            # Remove one leading "serial" or "integer" keyword. str.lstrip()
            # cannot be used: it treats its argument as a set of characters.
            for prefix in ("serial", "integer"):
                if s.startswith(prefix):
                    return s[len(prefix):]
            return s

        def cmp_or_serialcmp(x, y):
            # Equal strings match. A "serial..." type also matches an
            # "integer..." type when everything after that keyword is equal.
            if x == y:
                return True
            def is_serial(a, b):
                return a.startswith("serial") and b.startswith("integer")
            if is_serial(x, y) or is_serial(y, x):
                return strip_type_keyword(x) == strip_type_keyword(y)
            return False

        only_existing = options.get('only_existing', True)
        app_name = app.__name__.split('.')[-2]

        try:
            django_tables = _sql.django_table_names(only_existing=only_existing)
        except AttributeError:
            # backwards compatibility for before svn r7568
            django_tables = _sql.django_table_list(only_existing=only_existing)
        # NOTE(review): django_tables is not read again in this method.
        django_tables = [django_table for django_table in django_tables if django_table.startswith(app_name)]

        app_models = models.get_models(app)
        if not app_models:
            return

        introspection_module = get_introspection_module()
        cursor = connection.cursor()
        model_diffs = []
        for app_model in app_models:
            _constraints = None
            _meta = app_model._meta
            table_name = _meta.db_table

            table_indexes = introspection_module.get_indexes(cursor, table_name)

            fieldmap = dict([(field.get_attname(), field) for field in _meta.fields])
            try:
                table_description = introspection_module.get_table_description(cursor, table_name)
            except Exception:
                # Record the failure for this model and move on to the next one.
                # sys.exc_info() keeps this valid on both Python 2 and 3.
                e = sys.exc_info()[1]
                model_diffs.append((app_model.__name__, [str(e).strip()]))
                transaction.rollback() # reset transaction
                continue
            diffs = []
            for row in table_description:
                # Turn the description row into constructor kwargs for the field
                # class its type code (row[1]) maps to:
                #   row[3] -> max_length, row[4] -> max_digits,
                #   row[5] -> decimal_places, truthy row[6] -> blank (and null
                #   for non-text types).
                att_name = row[0].lower()
                db_field_reverse_type = introspection_module.DATA_TYPES_REVERSE.get(row[1])
                kwargs = {}
                if row[3]:
                    kwargs['max_length'] = row[3]
                if row[4]:
                    kwargs['max_digits'] = row[4]
                if row[5]:
                    kwargs['decimal_places'] = row[5]
                if row[6]:
                    kwargs['blank'] = True
                    if db_field_reverse_type not in ('TextField', 'CharField'):
                        kwargs['null'] = True

                if att_name not in fieldmap:
                    diffs.append("field '%s' missing in model field" % att_name)
                    continue
                field = fieldmap.pop(att_name)

                # check type
                db_field_type = getattr(models, db_field_reverse_type)(**kwargs).db_type()
                model_type = field.db_type()
                # check if we can for constraints (only enabled on postgresql atm)
                if self.is_pgsql:
                    if _constraints is None:
                        # Fetched at most once per model (table).
                        sql = """
                        SELECT
                            pg_constraint.conname, pg_get_constraintdef(pg_constraint.oid)
                        FROM
                            pg_constraint, pg_attribute
                        WHERE
                            pg_constraint.conrelid = pg_attribute.attrelid
                            AND pg_attribute.attnum = any(pg_constraint.conkey)
                            AND pg_constraint.conname ~ %s"""
                        cursor.execute(sql, [table_name])
                        _constraints = [r for r in cursor.fetchall() if r[0].endswith("_check")]
                    for r_name, r_check in _constraints:
                        # Every name here ends with "_check". Drop exactly that
                        # suffix, so names that also contain "_check" earlier
                        # (e.g. a column called "is_checked") still compare equal.
                        if table_name + "_" + att_name == r_name[:-len("_check")]:
                            # Rewrite the constraint text (collapse doubled
                            # parentheses, quote the first token inside them)
                            # before appending it to db_field_type.
                            r_check = r_check.replace("((", "(").replace("))", ")")
                            pos = r_check.find("(")
                            r_check = "%s\"%s" % (r_check[:pos+1], r_check[pos+1:])
                            pos = pos+r_check[pos:].find(" ")
                            r_check = "%s\" %s" % (r_check[:pos], r_check[pos+1:])
                            db_field_type += " "+r_check
                else:
                    # remove constraints
                    model_type = model_type.split("CHECK")[0].strip()
                c_db_field_type = clean(db_field_type)
                c_model_type = clean(model_type)
                if not cmp_or_serialcmp(c_model_type, c_db_field_type):
                    diffs.append("field '%s' not of same type: db=%s, model=%s" % (att_name, c_db_field_type, c_model_type))
                elif not cmp_or_serialcmp(db_field_type, model_type):
                    diffs.append("field '%s' parameters differ: db=%s, model=%s" % (att_name, db_field_type, model_type))
            for field in _meta.fields:
                if field.db_index:
                    if field.attname not in table_indexes and not field.unique:
                        diffs.append("field '%s' INDEX defined in model missing in database" % (field.attname))
            # Whatever is left in fieldmap matched no column of the table.
            for att_name in fieldmap:
                diffs.append("field '%s' missing in database" % att_name)
            if diffs:
                model_diffs.append((app_model.__name__, diffs))
        # NOTE(review): model_diffs is neither returned nor reported in the code
        # shown here; confirm how the results are meant to reach the caller.
Exemple #3
0
    def handle_diff(self, app, **options):
        """Compare the models of ``app`` with the tables in the database.

        For every model returned by ``models.get_models(app)``, the columns
        reported by the introspection module are matched against the model's
        fields by lower-cased attribute name. Each difference is recorded as a
        dict with a human readable ``'text'``, a ``'type'`` tag (``'type'``,
        ``'param'``, ``'missing-in-model'``, ``'missing-in-db'`` or
        ``'index-missing-in-db'``) and a ``'data'`` tuple. A failed table
        description is recorded as a plain string. Per-model results are
        collected in ``model_diffs`` as ``(model_name, [diff, ...])`` tuples.

        Several generations of Django's introspection API are supported
        (before svn r7568, before the r8296 introspection refactoring, and
        after it).

        :param app: the application's models module; the second-to-last
            component of its dotted ``__name__`` is used as the app name.
        :param options: command options; ``only_existing`` (default ``True``)
            is forwarded to Django's table-name helper.

        Returns early (``None``) when ``app`` defines no models.
        """
        import sys
        from django.db import models, connection, transaction
        from django.core.management import sql as _sql

        def clean(s):
            # Keep only the bare type name, e.g. "varchar(30) NOT NULL" -> "varchar".
            s = s.split(" ")[0]
            s = s.split("(")[0]
            return s

        def strip_type_keyword(s):
            # Remove one leading "serial" or "integer" keyword. str.lstrip()
            # cannot be used: it treats its argument as a set of characters.
            for prefix in ("serial", "integer"):
                if s.startswith(prefix):
                    return s[len(prefix):]
            return s

        def cmp_or_serialcmp(x, y):
            # Equal strings match. A "serial..." type also matches an
            # "integer..." type when everything after that keyword is equal.
            if x == y:
                return True
            def is_serial(a, b):
                return a.startswith("serial") and b.startswith("integer")
            if is_serial(x, y) or is_serial(y, x):
                return strip_type_keyword(x) == strip_type_keyword(y)
            return False

        only_existing = options.get('only_existing', True)
        app_name = app.__name__.split('.')[-2]

        try:
            django_tables = connection.introspection.django_table_names(only_existing=only_existing)
        except AttributeError:
            # backwards compatibility for before introspection refactoring (r8296)
            try:
                django_tables = _sql.django_table_names(only_existing=only_existing)
            except AttributeError:
                # backwards compatibility for before svn r7568
                django_tables = _sql.django_table_list(only_existing=only_existing)
        # NOTE(review): django_tables is not read again in this method.
        django_tables = [django_table for django_table in django_tables if django_table.startswith(app_name)]

        app_models = models.get_models(app)
        if not app_models:
            return

        try:
            from django.db import get_introspection_module
            introspection_module = get_introspection_module()
        except ImportError:
            introspection_module = connection.introspection

        cursor = connection.cursor()
        model_diffs = []
        for app_model in app_models:
            _constraints = None
            _meta = app_model._meta
            table_name = _meta.db_table
            table_indexes = introspection_module.get_indexes(cursor, table_name)

            fieldmap = dict([(field.get_attname(), field) for field in _meta.fields])

            if _meta.order_with_respect_to:
                # ORDERING_FIELD is not defined in this method; presumably a
                # module-level field describing the implicit "_order" column.
                fieldmap['_order'] = ORDERING_FIELD

            try:
                table_description = introspection_module.get_table_description(cursor, table_name)
            except Exception:
                # Record the failure for this model and move on to the next one.
                # sys.exc_info() keeps this valid on both Python 2 and 3.
                e = sys.exc_info()[1]
                model_diffs.append((app_model.__name__, [str(e).strip()]))
                transaction.rollback() # reset transaction
                continue
            diffs = []
            for row in table_description:
                att_name = row[0].lower()
                try:
                    db_field_reverse_type = introspection_module.data_types_reverse[row[1]]
                except AttributeError:
                    # backwards compatibility for before introspection refactoring (r8296)
                    db_field_reverse_type = introspection_module.DATA_TYPES_REVERSE.get(row[1])
                kwargs = {}
                if isinstance(db_field_reverse_type, tuple):
                    # (field class name, extra constructor kwargs)
                    kwargs.update(db_field_reverse_type[1])
                    db_field_reverse_type = db_field_reverse_type[0]

                if db_field_reverse_type == "CharField" and row[3]:
                    kwargs['max_length'] = row[3]

                if db_field_reverse_type == "DecimalField":
                    kwargs['max_digits'] = row[4]
                    kwargs['decimal_places'] = row[5]

                if row[6]:
                    kwargs['blank'] = True
                    if db_field_reverse_type not in ('TextField', 'CharField'):
                        kwargs['null'] = True

                if att_name not in fieldmap:
                    # The column has no model field, so there is no model type.
                    # Describe it by the type derived from this row only, rather
                    # than values left over from a previous row.
                    db_field_type = None
                    if db_field_reverse_type is not None:
                        db_field_type = getattr(models, db_field_reverse_type)(**kwargs).db_type()
                    diffs.append({
                        'text' : "field '%s' missing in model: %s" % (att_name, db_field_type),
                        'type' : 'missing-in-model',
                        'data' : (table_name, att_name, db_field_type, None)
                    })
                    continue
                field = fieldmap.pop(att_name)

                # check type
                db_field_type = getattr(models, db_field_reverse_type)(**kwargs).db_type()
                model_type = field.db_type()

                # remove mysql's auto_increment keyword
                if self.is_mysql and model_type.endswith("AUTO_INCREMENT"):
                    model_type = model_type.rsplit(' ', 1)[0].strip()

                # check if we can for constraints (only enabled on postgresql atm)
                if self.is_pgsql:
                    if _constraints is None:
                        # Fetched at most once per model (table).
                        sql = """
                        SELECT
                            pg_constraint.conname, pg_get_constraintdef(pg_constraint.oid)
                        FROM
                            pg_constraint, pg_attribute
                        WHERE
                            pg_constraint.conrelid = pg_attribute.attrelid
                            AND pg_attribute.attnum = any(pg_constraint.conkey)
                            AND pg_constraint.conname ~ %s"""
                        cursor.execute(sql, [table_name])
                        _constraints = [r for r in cursor.fetchall() if r[0].endswith("_check")]
                    for r_name, r_check in _constraints:
                        # Every name here ends with "_check". Drop exactly that
                        # suffix, so names that also contain "_check" earlier
                        # (e.g. a column called "is_checked") still compare equal.
                        if table_name + "_" + att_name == r_name[:-len("_check")]:
                            # Rewrite the constraint text (collapse doubled
                            # parentheses, quote the first token inside them)
                            # before appending it to db_field_type.
                            r_check = r_check.replace("((", "(").replace("))", ")")
                            pos = r_check.find("(")
                            r_check = "%s\"%s" % (r_check[:pos+1], r_check[pos+1:])
                            pos = pos+r_check[pos:].find(" ")
                            r_check = "%s\" %s" % (r_check[:pos], r_check[pos+1:])
                            db_field_type += " "+r_check
                else:
                    # remove constraints
                    model_type = model_type.split("CHECK")[0].strip()
                c_db_field_type = clean(db_field_type)
                c_model_type = clean(model_type)

                if self.is_sqlite and (c_db_field_type == "varchar" and c_model_type == "char"):
                    # On sqlite, accept a varchar column for a char model field.
                    # db_field_type starts with "varchar" here, so slicing off the
                    # "var" prefix is exact (str.lstrip would strip a char set).
                    c_db_field_type = "char"
                    db_field_type = db_field_type[len("var"):]

                if not cmp_or_serialcmp(c_model_type, c_db_field_type):
                    diffs.append({
                        'text' : "field '%s' not of same type: db=%s, model=%s" % (att_name, c_db_field_type, c_model_type),
                        'type' : 'type',
                        'data' : (table_name, att_name, c_db_field_type, c_model_type)
                    })
                elif not cmp_or_serialcmp(db_field_type, model_type):
                    diffs.append({
                        'text' : "field '%s' parameters differ: db=%s, model=%s" % (att_name, db_field_type, model_type),
                        'type' : 'param',
                        'data' : (table_name, att_name, db_field_type, model_type)
                    })
            for field in _meta.fields:
                if field.db_index:
                    if field.attname not in table_indexes and not field.unique:
                        # Carries 'type'/'data' like every other diff record.
                        diffs.append({
                            'text' : "field '%s' INDEX defined in model missing in database" % (field.attname),
                            'type' : 'index-missing-in-db',
                            'data' : (table_name, field.attname)
                        })
            # Whatever is left in fieldmap matched no column of the table.
            for att_name, field in fieldmap.items():
                diffs.append({
                    'text' : "field '%s' missing in database: %s" % (att_name, field.db_type()),
                    'type' : 'missing-in-db',
                    'data' : (table_name, att_name, field.db_type())
                })
            if diffs:
                model_diffs.append((app_model.__name__, diffs))
        # NOTE(review): model_diffs is neither returned nor reported in the code
        # shown here; confirm how the results are meant to reach the caller.