def add_not_null(db, migrator, table, column_name, field):
    """Return parsed SQL commands that make *column_name* NOT NULL on *table*.

    If the field declares a default, existing NULL rows are first backfilled
    with that default so the constraint can be applied.

    NOTE(review): a second ``add_not_null`` with a different parameter order
    is defined later in this file and shadows this one — confirm which is
    intended.

    Raises:
        Exception: for database engines we do not know how to handle.
    """
    cmds = []
    compiler = db.compiler()
    if field.default is not None:
        # If default is a function, turn it into a value.  This won't work on
        # columns requiring uniqueness, like UUIDs, as all rows would share
        # the same called value.  (Idiom fix: callable() instead of
        # hasattr(..., '__call__').)
        default = field.default() if callable(field.default) else field.default
        op = pw.Clause(
            pw.SQL('UPDATE'), pw.Entity(table), pw.SQL('SET'),
            field.as_entity(), pw.SQL('='), default,
            pw.SQL('WHERE'), field.as_entity(), pw.SQL('IS NULL'))
        cmds.append(compiler.parse_node(op))
    if is_postgres(db) or is_sqlite(db):
        junk = migrator.add_not_null(table, column_name, generate=True)
        cmds += normalize_whatever_junk_peewee_migrations_gives_you(
            migrator, junk)
        return cmds
    elif is_mysql(db):
        # MySQL has no ADD CONSTRAINT for nullability; re-issue the full
        # column definition via MODIFY.
        op = pw.Clause(
            pw.SQL('ALTER TABLE'), pw.Entity(table), pw.SQL('MODIFY'),
            compiler.field_definition(field))
        cmds.append(compiler.parse_node(op))
        return cmds
    raise Exception('how do i add a not null for %s?' % db)
def change_column_type(db, migrator, table_name, column_name, field):
    """Build the normalized statements that alter *column_name* to *field*'s type.

    Postgres takes ALTER ... TYPE; MySQL re-issues the column via MODIFY.

    Raises:
        Exception: for database engines we do not know how to handle.
    """
    new_type = _field_type(field)
    if is_postgres(db):
        clause = pw.Clause(
            pw.SQL('ALTER TABLE'), pw.Entity(table_name),
            pw.SQL('ALTER'), field.as_entity(),
            pw.SQL('TYPE'), field.__ddl_column__(new_type))
    elif is_mysql(db):
        parts = [pw.SQL('ALTER TABLE'), pw.Entity(table_name), pw.SQL('MODIFY')]
        parts.extend(field.__ddl__(new_type))
        clause = pw.Clause(*parts)
    else:
        raise Exception('how do i change a column type for %s?' % db)
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, clause)
def set_default(db, migrator, table_name, column_name, field):
    """Emit ALTER TABLE ... SET DEFAULT for *column_name*.

    Callable defaults are invoked once; the resulting value is converted via
    the field's db_value() and bound as a parameter.
    """
    value = field.default
    if callable(value):
        value = value()
    clause = pw.Clause(
        pw.SQL('ALTER TABLE'), pw.Entity(table_name),
        pw.SQL('ALTER COLUMN'), pw.Entity(column_name),
        pw.SQL('SET DEFAULT'), pw.Param(field.db_value(value)))
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, clause)
def field_definition(self, field):
    """Return the column DDL clause for *field*, minus constraints handled elsewhere.

    Drops NOT NULL, DEFAULT NEXTVAL(...) and CHECK (...) fragments so the
    remaining clause carries only the pieces we want to re-issue.
    """
    column_type = self.get_column_type(field.get_db_field())
    kept = []
    for node in field.__ddl__(column_type):
        if isinstance(node, peewee.SQL):
            text = node.value
            if (text == 'NOT NULL'
                    or text.startswith('DEFAULT NEXTVAL(')
                    or text.startswith('CHECK (')):
                continue  # constraint fragment — strip it
        kept.append(node)
    return peewee.Clause(*kept)
def rename_column(db, migrator, ntn, ocn, ncn, field):
    """Rename column *ocn* to *ncn* on table *ntn*.

    MySQL is handled with ALTER TABLE ... CHANGE plus the full column
    definition; every other engine goes through peewee's migrator.
    """
    compiler = db.compiler()
    if not is_mysql(db):
        raw = migrator.rename_column(ntn, ocn, ncn, generate=True)
    else:
        raw = pw.Clause(
            pw.SQL('ALTER TABLE'), pw.Entity(ntn), pw.SQL('CHANGE'),
            pw.Entity(ocn), compiler.field_definition(field))
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, raw)
def add_not_null(db, migrator, table, field, column_name):
    """Return parsed SQL commands adding a NOT NULL constraint to *column_name*.

    NOTE(review): this shadows an earlier ``add_not_null`` in this file (which
    also backfills defaults); confirm which definition is intended.

    Raises:
        Exception: for database engines we do not know how to handle.
    """
    qc = db.compiler()
    if is_postgres(db) or is_sqlite(db):
        junk = migrator.add_not_null(table, column_name, generate=True)
        # Consistency fix: every other call site in this file passes
        # (migrator, junk); the stray leading `db` argument looked stale.
        return normalize_whatever_junk_peewee_migrations_gives_you(
            migrator, junk)
    elif is_mysql(db):
        # MySQL re-issues the full column definition via MODIFY.
        op = pw.Clause(pw.SQL('ALTER TABLE'), pw.Entity(table),
                       pw.SQL('MODIFY'), qc.field_definition(field))
        return [qc.parse_node(op)]
    raise Exception('how do i add a not null for %s?' % db)
def _create_table(self, model_class, safe=False):
    """Build a CREATE TABLE clause for *model_class* from its field definitions.

    Args:
        model_class: the peewee model whose table DDL to build.
        safe: accepted for interface compatibility but currently unused —
            no IF NOT EXISTS is emitted.

    NOTE(review): the original computed the composite primary-key entities
    here but never used them; that dead code is removed.  Composite keys are
    therefore still not reflected in the generated DDL.
    """
    meta = model_class._meta
    columns = [self.field_definition(field) for field in meta.sorted_fields]
    return peewee.Clause(peewee.SQL('CREATE TABLE'),
                         model_class.as_entity(),
                         peewee.EnclosedClause(*columns))
def drop_default(db, migrator, table_name, column_name, field):
    """Emit ALTER TABLE ... DROP DEFAULT for *column_name* (field is unused)."""
    clause = pw.Clause(
        pw.SQL('ALTER TABLE'), pw.Entity(table_name),
        pw.SQL('ALTER COLUMN'), pw.Entity(column_name),
        pw.SQL('DROP DEFAULT'))
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, clause)
def drop_foreign_key(db, migrator, table_name, fk_name):
    """Drop foreign-key constraint *fk_name* from *table_name*.

    MySQL spells the operation 'drop foreign key'; other engines take the
    standard DROP CONSTRAINT.
    """
    keyword = 'DROP CONSTRAINT'
    if is_mysql(db):
        keyword = 'drop foreign key'
    clause = pw.Clause(pw.SQL('ALTER TABLE'), pw.Entity(table_name),
                       pw.SQL(keyword), pw.Entity(fk_name))
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, clause)
def drop_table(migrator, table_name):
    """Return the parsed DROP TABLE statement for *table_name*, as a one-item list."""
    qc = migrator.database.compiler()
    stmt = pw.Clause(pw.SQL('DROP TABLE'), pw.Entity(table_name))
    return [qc.parse_node(stmt)]
def calc_changes(db):
    """Diff the live database *db* against the declared models and return the
    ordered list of migration commands (parsed SQL tuples) needed to bring
    the schema in line: table adds/renames/drops, foreign keys, column
    adds/renames/drops/alters, and index changes.
    """
    migrator = None  # expose eventually?
    if migrator is None:
        migrator = auto_detect_migrator(db)
    # Snapshot the current state of the database.
    existing_tables = [unicode(t) for t in db.get_tables()]
    existing_indexes = {table: db.get_indexes(table)
                        for table in existing_tables}
    existing_columns_by_table = get_columns_by_table(db)
    foreign_keys_by_table = get_foreign_keys_by_table(db)
    table_names_to_models = {cls._meta.db_table: cls
                             for cls in all_models.keys()}
    qc = db.compiler()
    to_run = []
    # Table-level changes first: creates, foreign keys, renames.
    table_adds, add_fks, table_deletes, table_renames = calc_table_changes(
        existing_tables)
    table_renamed_from = {v: k for k, v in table_renames.items()}
    to_run += [qc.create_table(table_names_to_models[tbl])
               for tbl in table_adds]
    for field in add_fks:
        # Un-defer FKs that were deferred at declaration time so the
        # constraint DDL is actually generated.
        if hasattr(field, '__pwdbev__not_deferred') and field.__pwdbev__not_deferred:
            field.deferred = False
        op = qc._create_foreign_key(field.model_class, field)
        to_run.append(qc.parse_node(op))
    for k, v in table_renames.items():
        ops = migrator.rename_table(k, v, generate=True)
        if not hasattr(ops, '__iter__'):
            ops = [ops]  # sometimes pw return arrays, sometimes not
        to_run += [qc.parse_node(op) for op in ops]
    # Column-level changes, per existing table.
    rename_cols_by_table = {}
    deleted_cols_by_table = {}
    for etn, ecols in existing_columns_by_table.items():
        if etn in table_deletes:
            continue  # whole table is going away; skip column work
        ntn = table_renames.get(etn, etn)
        model = table_names_to_models.get(ntn)
        if not model:
            continue  # no declared model for this table
        defined_fields = model._meta.sorted_fields
        defined_column_name_to_field = {unicode(f.db_column): f
                                        for f in defined_fields}
        adds, deletes, renames, alter_statements = calc_column_changes(
            db, migrator, etn, ntn, ecols, defined_fields,
            foreign_keys_by_table[etn])
        for column_name in adds:
            field = defined_column_name_to_field[column_name]
            to_run += alter_add_column(db, migrator, ntn, column_name, field)
            if not field.null:
                # alter_add_column strips null constraints
                # add them back after setting any defaults
                if field.default is not None:
                    operation = migrator.apply_default(ntn, column_name, field,
                                                       generate=True)
                    to_run.append(qc.parse_node(operation))
                else:
                    to_run.append((
                        '-- adding a not null column without a default will fail if the table is not empty',
                        []))
                to_run += add_not_null(db, migrator, ntn, field, column_name)
        for column_name in deletes:
            to_run += drop_column(db, migrator, ntn, column_name)
        for ocn, ncn in renames.items():
            field = defined_column_name_to_field[ncn]
            to_run += rename_column(db, migrator, ntn, ocn, ncn, field)
        to_run += alter_statements
        rename_cols_by_table[ntn] = renames
        deleted_cols_by_table[ntn] = deletes
    # Index changes, ignoring indexes on columns that are being dropped.
    for ntn, model in table_names_to_models.items():
        etn = table_renamed_from.get(ntn, ntn)
        deletes = deleted_cols_by_table.get(ntn, set())
        existing_indexes_for_table = [
            i for i in existing_indexes.get(etn, [])
            if not any([(c in deletes) for c in i.columns])
        ]
        to_run += calc_index_changes(db, migrator, existing_indexes_for_table,
                                     model, rename_cols_by_table.get(ntn, {}))
    '''
    to_run += calc_perms_changes($schema_tables, noop) unless $check_perms_for.empty?
    '''
    # Finally, drop tables that no longer have models.
    to_run += [qc.parse_node(pw.Clause(pw.SQL('DROP TABLE'), pw.Entity(tbl)))
               for tbl in table_deletes]
    return to_run