def create_model(inputs, targets, max_steps):
    model = Pix2Pix()
    out_channels = int(targets.get_shape()[-1])
    outputs = model.get_generator(inputs, out_channels, ngf=args.ngf,
                                  conv_type=args.conv_type,
                                  channel_multiplier=args.channel_multiplier,
                                  padding='SAME', net_type=args.net_type,
                                  reuse=False, upsampe_method=args.upsampe_method)

    # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
    predict_real = model.get_discriminator(inputs, targets, ndf=args.ndf,
                                           spectral_normed=True,
                                           update_collection=None,
                                           conv_type=args.conv_type,
                                           channel_multiplier=args.channel_multiplier,
                                           padding='VALID', net_type=args.net_type,
                                           reuse=False)

    # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
    # 'NO_OPS' so the reused discriminator does not update the spectral-norm
    # power-iteration vectors a second time within the same step.
    predict_fake = model.get_discriminator(inputs, outputs, ndf=args.ndf,
                                           spectral_normed=True,
                                           update_collection='NO_OPS',
                                           conv_type=args.conv_type,
                                           channel_multiplier=args.channel_multiplier,
                                           padding='VALID', net_type=args.net_type,
                                           reuse=True)

    with tf.name_scope("d_loss"):
        # discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
        discrim_loss, _ = lib.misc.get_loss(predict_real, predict_fake, loss_type=args.loss_type)
        if args.loss_type == 'WGAN-GP':
            # Gradient penalty: penalize the discriminator's gradient norm on
            # random interpolates between real and generated samples.
            alpha = tf.random_uniform(shape=[args.batch_size, 1, 1, 1], minval=0., maxval=1.)
            differences = outputs - targets
            interpolates = targets + (alpha * differences)
            gradients = tf.gradients(
                model.get_discriminator(inputs, interpolates, ndf=args.ndf,
                                        spectral_normed=True,
                                        update_collection=None,
                                        conv_type=args.conv_type,
                                        channel_multiplier=args.channel_multiplier,
                                        padding='VALID', net_type=args.net_type,
                                        reuse=True),
                [interpolates])[0]
            # 1e-10 keeps the sqrt differentiable when the gradient is all-zero.
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-10)
            gradient_penalty = 10 * tf.reduce_mean(tf.square(slopes - 1.))
            discrim_loss += gradient_penalty

    with tf.name_scope("g_loss"):
        # gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        _, gen_loss_GAN = lib.misc.get_loss(predict_real, predict_fake, loss_type=args.loss_type)
        if args.g_bce:
            # Binary cross-entropy content loss on deprocessed ([0, 1]) images,
            # clipped away from 0 and 1 for numerical stability.
            outputs_ = deprocess(outputs)
            targets_ = deprocess(targets)
            gen_loss_content = -tf.reduce_mean(
                targets_ * tf.log(tf.clip_by_value(outputs_, 1e-10, 1.0 - 1e-10)) +
                (1.0 - targets_) * tf.log(tf.clip_by_value(1.0 - outputs_, 1e-10, 1.0 - 1e-10)))
        else:
            # L1 content loss.
            gen_loss_content = tf.reduce_mean(tf.abs(targets - outputs))
        gen_loss = gen_loss_GAN * args.gan_weight + gen_loss_content * args.l1_weight

    with tf.name_scope('global_step'):
        global_step = tf.train.get_or_create_global_step()

    with tf.name_scope('lr_decay'):
        decay = 1.
        # Alternative decay schedules, kept for reference:
        # decay = tf.where(tf.less(global_step, 23600),
        #                  tf.maximum(0., 1. - (tf.cast(global_step, tf.float32) / 47200)),
        #                  0.5)
        # decay = tf.where(tf.less(global_step, int(max_steps * 0.5)),
        #                  1.,
        #                  tf.maximum(0., 1. - ((tf.cast(global_step, tf.float32)
        #                                        - int(max_steps * 0.5)) / max_steps)))
        if args.TTUR:
            # Two time-scale update rule: the discriminator learns 4x faster
            # than the generator.
            print('\nUsing TTUR!\n')
            LR_D = tf.constant(0.0004)  # initial discriminator learning rate
            LR_G = tf.constant(0.0001)  # initial generator learning rate
        else:
            print('\nNot using TTUR!\n')
            LR_D = tf.constant(0.0002)  # initial discriminator learning rate
            LR_G = tf.constant(0.0002)  # initial generator learning rate
        lr_d = LR_D * decay
        lr_g = LR_G * decay

    with tf.name_scope("d_train"):
        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("d_net")]
        discrim_optim = tf.train.AdamOptimizer(lr_d, beta1=args.beta1, beta2=args.beta2)
        discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("g_train"):
        gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("g_net")]
        gen_optim = tf.train.AdamOptimizer(lr_g, beta1=args.beta1, beta2=args.beta2)
        gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
        # Only the generator step advances global_step, so one discriminator +
        # generator update pair counts as a single training step.
        gen_train = gen_optim.apply_gradients(gen_grads_and_vars, global_step=global_step)

    # Exponentially smoothed losses for reporting.
    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_content])

    return Model(
        lr=lr_d + lr_g,
        outputs=outputs,
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=ema.average(discrim_loss),
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=ema.average(gen_loss_GAN),
        gen_loss_content=ema.average(gen_loss_content),
        gen_grads_and_vars=gen_grads_and_vars,
        d_train=discrim_train,
        g_train=gen_train,
        losses=update_losses,
        global_step=global_step)
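# ---------------------------------------------------------------------------
# `lib.misc.get_loss` is defined elsewhere in this repo. The helper below is a
# minimal sketch (an assumption, not the repo's implementation) of what it
# plausibly returns for the loss types referenced above, based on how it is
# called: a (discriminator_loss, generator_loss) pair computed from raw
# discriminator outputs. `_get_loss_sketch` and the 'HINGE' branch name are
# hypothetical; the vanilla branch matches the commented-out expressions in
# create_model above.
# ---------------------------------------------------------------------------
def _get_loss_sketch(predict_real, predict_fake, loss_type='GAN', eps=1e-12):
    if loss_type == 'HINGE':
        # Hinge loss, commonly paired with spectral normalization.
        d_loss = (tf.reduce_mean(tf.nn.relu(1. - predict_real)) +
                  tf.reduce_mean(tf.nn.relu(1. + predict_fake)))
        g_loss = -tf.reduce_mean(predict_fake)
    elif loss_type in ('WGAN', 'WGAN-GP'):
        # Wasserstein critic loss; WGAN-GP adds the gradient penalty
        # separately, as done in the d_loss scope above.
        d_loss = tf.reduce_mean(predict_fake) - tf.reduce_mean(predict_real)
        g_loss = -tf.reduce_mean(predict_fake)
    else:
        # Standard non-saturating GAN loss, with predict_* in (0, 1).
        d_loss = tf.reduce_mean(-(tf.log(predict_real + eps) +
                                  tf.log(1. - predict_fake + eps)))
        g_loss = tf.reduce_mean(-tf.log(predict_fake + eps))
    return d_loss, g_loss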
# NOTE: this second definition of create_model shadows the one above when the
# module is loaded; train() below relies on the Model fields this variant
# returns (gen_loss_L1 rather than gen_loss_content).
def create_model(inputs, targets, max_steps):
    model = Pix2Pix()
    out_channels = int(targets.get_shape()[-1])
    outputs = model.get_generator(inputs, out_channels, ngf=args.ngf,
                                  conv_type=args.conv_type,
                                  channel_multiplier=args.channel_multiplier,
                                  padding='SAME', net_type=args.net_type,
                                  reuse=False, upsampe_method=args.upsampe_method)

    # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
    predict_real = model.get_discriminator(inputs, targets, ndf=args.ndf,
                                           spectral_normed=True,
                                           update_collection=None,
                                           conv_type=args.conv_type,
                                           channel_multiplier=args.channel_multiplier,
                                           padding='VALID', net_type=args.net_type,
                                           reuse=False)

    # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
    # (The variant above passes update_collection='NO_OPS' here instead, so the
    # spectral-norm vectors are updated only once per step.)
    predict_fake = model.get_discriminator(inputs, outputs, ndf=args.ndf,
                                           spectral_normed=True,
                                           update_collection=None,
                                           conv_type=args.conv_type,
                                           channel_multiplier=args.channel_multiplier,
                                           padding='VALID', net_type=args.net_type,
                                           reuse=True)

    with tf.name_scope("d_loss"):
        # discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
        discrim_loss, _ = lib.misc.get_loss(predict_real, predict_fake, loss_type=args.loss_type)
        if args.loss_type == 'WGAN-GP':
            # Gradient penalty on random interpolates between real and fake.
            alpha = tf.random_uniform(shape=[args.batch_size, 1, 1, 1], minval=0., maxval=1.)
            differences = outputs - targets
            interpolates = targets + (alpha * differences)
            gradients = tf.gradients(
                model.get_discriminator(inputs, interpolates, ndf=args.ndf,
                                        spectral_normed=True,
                                        update_collection=None,
                                        conv_type=args.conv_type,
                                        channel_multiplier=args.channel_multiplier,
                                        padding='VALID', net_type=args.net_type,
                                        reuse=True),
                [interpolates])[0]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-10)
            gradient_penalty = 10 * tf.reduce_mean(tf.square(slopes - 1.))
            discrim_loss += gradient_penalty

    with tf.name_scope("g_loss"):
        # gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        _, gen_loss_GAN = lib.misc.get_loss(predict_real, predict_fake, loss_type=args.loss_type)
        gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
        gen_loss = gen_loss_GAN * args.gan_weight + gen_loss_L1 * args.l1_weight

    with tf.name_scope('global_step'):
        global_step = tf.train.get_or_create_global_step()

    with tf.name_scope('lr_decay'):
        learning_rate = tf.train.polynomial_decay(
            learning_rate=args.initial_lr,
            global_step=global_step,
            decay_steps=max_steps,
            end_learning_rate=args.end_lr)

    with tf.name_scope("d_train"):
        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("d_net")]
        # Fixed TTUR-style rates; swap in `learning_rate` to use the
        # polynomial decay defined above.
        discrim_optim = tf.train.AdamOptimizer(0.0004, beta1=args.beta1, beta2=args.beta2)
        # discrim_optim = tf.train.AdamOptimizer(learning_rate, beta1=args.beta1, beta2=args.beta2)
        discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("g_train"):
        gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("g_net")]
        gen_optim = tf.train.AdamOptimizer(0.0001, beta1=args.beta1, beta2=args.beta2)
        # gen_optim = tf.train.AdamOptimizer(learning_rate, beta1=args.beta1, beta2=args.beta2)
        gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
        gen_train = gen_optim.apply_gradients(gen_grads_and_vars, global_step=global_step)

    # Exponentially smoothed losses for reporting.
    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])

    return Model(
        outputs=outputs,
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=ema.average(discrim_loss),
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=ema.average(gen_loss_GAN),
        gen_loss_L1=ema.average(gen_loss_L1),
        gen_grads_and_vars=gen_grads_and_vars,
        d_train=discrim_train,
        g_train=gen_train,
        losses=update_losses,
        global_step=global_step)
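# `Model` is assumed to be a collections.namedtuple defined earlier in this
# script; a sketch consistent with the fields passed in the return statement
# above (the first create_model variant additionally passes `lr` and names the
# content loss `gen_loss_content`), kept commented so it cannot shadow the
# real definition:
#
#   import collections
#   Model = collections.namedtuple(
#       "Model",
#       ["outputs", "predict_real", "predict_fake",
#        "discrim_loss", "discrim_grads_and_vars",
#        "gen_loss_GAN", "gen_loss_L1", "gen_grads_and_vars",
#        "d_train", "g_train", "losses", "global_step"])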
def train():
    if args.seed is None:
        args.seed = random.randint(0, 2**31 - 1)
    tf.set_random_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if args.mode in ("test", "export"):
        if args.checkpoint_dir is None:
            raise Exception("checkpoint required for test mode")
        # load some options from the checkpoint (disabled):
        # options = {"which_direction", "ngf", "ndf", "lab_colorization"}
        # with open(os.path.join(args.checkpoint_dir, "options.json"), 'r') as f:
        #     for key, val in json.loads(f.read()).items():
        #         if key in options:
        #             print("loaded", key, "=", val)
        #             setattr(args, key, val)
        # disable these features in test mode
        args.scale_size = CROP_SIZE
        args.flip = False

    for k, v in args._get_kwargs():
        print(k, "=", v)

    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))

    if args.mode == "export":
        # Export the generator to a meta graph that can be imported later for
        # standalone generation.
        if args.lab_colorization:
            raise Exception("export not supported for lab_colorization")
        inputs = tf.placeholder(tf.string, shape=[1])
        input_data = tf.decode_base64(inputs[0])
        input_image = tf.image.decode_png(input_data)
        # remove alpha channel if present
        input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 4),
                              lambda: input_image[:, :, :3],
                              lambda: input_image)
        # convert grayscale to RGB
        input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 1),
                              lambda: tf.image.grayscale_to_rgb(input_image),
                              lambda: input_image)
        input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32)
        input_image.set_shape([CROP_SIZE, CROP_SIZE, 3])
        batch_input = tf.expand_dims(input_image, axis=0)

        model_ = Pix2Pix()
        batch_output = deprocess(
            model_.get_generator(preprocess(batch_input), 3, ngf=args.ngf,
                                 conv_type=args.conv_type,
                                 channel_multiplier=args.channel_multiplier,
                                 padding='SAME'))
        output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0]
        if args.output_filetype == "png":
            output_data = tf.image.encode_png(output_image)
        elif args.output_filetype == "jpeg":
            output_data = tf.image.encode_jpeg(output_image, quality=80)
        else:
            raise Exception("invalid filetype")
        output = tf.convert_to_tensor([tf.encode_base64(output_data)])

        key = tf.placeholder(tf.string, shape=[1])
        inputs = {"key": key.name, "input": inputs.name}
        tf.add_to_collection("inputs", json.dumps(inputs))
        outputs = {
            "key": tf.identity(key).name,
            "output": output.name,
        }
        tf.add_to_collection("outputs", json.dumps(outputs))

        init_op = tf.global_variables_initializer()
        restore_saver = tf.train.Saver()
        export_saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init_op)
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(args.checkpoint_dir)
            restore_saver.restore(sess, checkpoint)
            print("exporting model")
            export_saver.export_meta_graph(
                filename=os.path.join(args.output_dir, "export.meta"))
            export_saver.save(sess, os.path.join(args.output_dir, "export"),
                              write_meta_graph=False)
        return

    examples = load_examples()
    print("examples count = %d" % examples.count)

    max_steps = 2**32
    if args.max_epochs is not None:
        max_steps = examples.steps_per_epoch * args.max_epochs
    if args.max_steps is not None:
        max_steps = args.max_steps

    # inputs and targets are [batch_size, height, width, channels]
    modelNamedtuple = create_model(examples.inputs, examples.targets, max_steps)

    # undo colorization splitting on images that we use for display/output
    if args.lab_colorization:
        if args.which_direction == "AtoB":
            # inputs is brightness, which is handled fine as a grayscale image;
            # targets and outputs need to be augmented with brightness
            targets = augment(examples.targets, examples.inputs)
            outputs = augment(modelNamedtuple.outputs, examples.inputs)
            # inputs can be deprocessed normally and handled as single-channel
            # grayscale images
            inputs = deprocess(examples.inputs)
        elif args.which_direction == "BtoA":
            # inputs will be color channels only; get brightness from targets
            inputs = augment(examples.inputs, examples.targets)
            targets = deprocess(examples.targets)
            outputs = deprocess(modelNamedtuple.outputs)
        else:
            raise Exception("invalid direction")
    else:
        inputs = deprocess(examples.inputs)
        targets = deprocess(examples.targets)
        outputs = deprocess(modelNamedtuple.outputs)

    def convert(image):
        if args.aspect_ratio != 1.0:
            # upscale to correct aspect ratio
            size = [CROP_SIZE, int(round(CROP_SIZE * args.aspect_ratio))]
            image = tf.image.resize_images(image, size=size,
                                           method=tf.image.ResizeMethod.BICUBIC)
        return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)

    # reverse any processing on images so they can be written to disk or
    # displayed to the user
    with tf.name_scope("convert_inputs"):
        converted_inputs = convert(inputs)
    with tf.name_scope("convert_targets"):
        converted_targets = convert(targets)
    with tf.name_scope("convert_outputs"):
        converted_outputs = convert(outputs)

    with tf.name_scope("encode_images"):
        if args.multiple_A:
            # keep only the second half of the stacked A channels for display
            converted_inputs = tf.split(converted_inputs, 2, 3)[1]
            print('\nmultiple_A: converted_inputs shape = {}\n'.format(
                converted_inputs.shape.as_list()))
        display_fetches = {
            "paths": examples.paths,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs,
                                dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets,
                                 dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs,
                                 dtype=tf.string, name="output_pngs"),
        }

    # summaries
    with tf.name_scope("targets_summary"):
        tf.summary.image("targets", converted_targets)
    with tf.name_scope("outputs_summary"):
        tf.summary.image("outputs", converted_outputs)
    tf.summary.scalar("discriminator_loss", modelNamedtuple.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", modelNamedtuple.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", modelNamedtuple.gen_loss_L1)
    for grad, var in modelNamedtuple.discrim_grads_and_vars + modelNamedtuple.gen_grads_and_vars:
        tf.summary.histogram(var.op.name + "/gradients", grad)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver(max_to_keep=5)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        summary_writer = tf.summary.FileWriter(args.output_dir, sess.graph)
        sess.run(tf.global_variables_initializer())
        print("parameter_count =", sess.run(parameter_count))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        if args.checkpoint_dir is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(args.checkpoint_dir)
            saver.restore(sess, checkpoint)

        if args.mode == "test":
            # testing: process the test data at most once
            start = time.time()
            max_steps = min(examples.steps_per_epoch, max_steps)
            for step in range(max_steps):
                results = sess.run(display_fetches)
                filesets = save_images(results)
                for f in filesets:
                    print("evaluated image", f["name"])
                index_path = append_index(filesets)
            print("wrote index at", index_path)
            print("rate", (time.time() - start) / max_steps)
        else:
            # training
            start = time.time()
            for step in range(max_steps):
                def should(freq):
                    return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)

                # n_dis discriminator updates per generator update
                for i in range(args.n_dis):
                    sess.run(modelNamedtuple.d_train)

                fetches = {
                    "g_train": modelNamedtuple.g_train,
                    "losses": modelNamedtuple.losses,
                    "global_step": modelNamedtuple.global_step,
                }
                if should(args.progress_freq):
                    fetches["discrim_loss"] = modelNamedtuple.discrim_loss
                    fetches["gen_loss_GAN"] = modelNamedtuple.gen_loss_GAN
                    fetches["gen_loss_L1"] = modelNamedtuple.gen_loss_L1
                if should(args.summary_freq):
                    fetches["summary"] = summary_op
                if should(args.display_freq):
                    fetches["display"] = display_fetches

                results = sess.run(fetches)

                if should(args.summary_freq):
                    summary_writer.add_summary(results["summary"], results["global_step"])
                if should(args.display_freq):
                    filesets = save_images(results["display"], step=results["global_step"])
                    append_index(filesets, step=True)
                if should(args.progress_freq):
                    # global_step will have the correct step count if we resume
                    # from a checkpoint
                    train_epoch = math.ceil(results["global_step"] / examples.steps_per_epoch)
                    train_step = (results["global_step"] - 1) % examples.steps_per_epoch + 1
                    rate = (step + 1) * args.batch_size / (time.time() - start)
                    remaining = (max_steps - step) * args.batch_size / rate
                    print("progress epoch %d step %d image/sec %0.1f remaining %dm" %
                          (train_epoch, train_step, rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])
                if should(args.save_freq):
                    print("saving model...")
                    saver.save(sess, os.path.join(args.output_dir, "model"),
                               global_step=modelNamedtuple.global_step)

        coord.request_stop()
        coord.join(threads)
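# The script is assumed to end with a standard entry point; a sketch, with the
# exact CLI flags being whatever the argparse setup (defined elsewhere in this
# file) declares for the `args.*` attributes used above:
#
#   if __name__ == '__main__':
#       train()
#
# Hypothetical invocation (flag names inferred from the args attributes):
#   python train.py --mode train --output_dir ./ckpt --max_epochs 200 --TTUR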