def image_to_patches(IMG1, IMG2, PATCH_HEIGHT, PATCH_WIDTH, IMG2_Path, IMG1_Path, filename, k, pair_thres, adj_thres):
    # Slide a PATCH_WIDTH x PATCH_HEIGHT window over the aligned image pair and
    # save patch pairs that match each other well but are not near-duplicates of
    # the previously saved patch. (filename is currently unused.)
    prev_patch = None
    for i in range(0, IMG1.size[1], PATCH_HEIGHT):
        for j in range(0, IMG1.size[0], PATCH_WIDTH):
            if j + PATCH_WIDTH <= IMG1.size[0] and i + PATCH_HEIGHT <= IMG1.size[1]:
                box = (j, i, j + PATCH_WIDTH, i + PATCH_HEIGHT)
                IMG2_patch = IMG2.crop(box)
                IMG1_patch = IMG1.crop(box)

                IMG1_cv2 = IMG1_patch.convert('RGB')
                IMG1_cv2 = np.array(IMG1_cv2)
                IMG1_cv2 = cv2.cvtColor(IMG1_cv2, cv2.COLOR_BGR2RGB)  # swap to OpenCV channel order

                # pair_eval = compare_ssim(np.array(IMG1_patch), np.array(IMG2_patch), multichannel=True)
                pair_eval = MultiScaleSSIM(np.expand_dims(IMG1_patch, axis=0),
                                           np.expand_dims(IMG2_patch, axis=0), max_val=255)

                # keep the pair only if MS-SSIM between the two patches is at least
                # pair_thres and the patch is not too similar to the last saved one
                if pair_eval >= pair_thres and (prev_patch is None or
                        compare_ssim(IMG1_cv2, prev_patch, multichannel=True) <= adj_thres):
                    IMG2_patch.save(IMG2_Path + '(' + str(k) + ').jpg')
                    IMG1_patch.save(IMG1_Path + '(' + str(k) + ').jpg')
                    k = k + 1
                    prev_patch = IMG1_cv2
    return k
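
# Usage sketch (illustrative, not from the original script): extract aligned
# training patches from one phone/DSLR image pair. The paths, patch size, and
# thresholds below are hypothetical, and the two images are assumed to be
# pre-aligned and equal-sized.
def _example_extract_patches():
    from PIL import Image
    phone_img = Image.open('phone/0001.jpg')
    dslr_img = Image.open('dslr/0001.jpg')
    next_k = image_to_patches(phone_img, dslr_img, PATCH_HEIGHT=100, PATCH_WIDTH=100,
                              IMG2_Path='patches/dslr/', IMG1_Path='patches/phone/',
                              filename='0001.jpg', k=0,
                              pair_thres=0.9,  # keep pairs with MS-SSIM >= 0.9
                              adj_thres=0.8)   # skip patches too similar to the last saved one
    print('saved %d patch pairs' % next_k)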

def main(args):
    # load training and test data
    logger.info("Loading test data...")
    test_data, test_answ = load_test_data(args.dataset, args.dataset_dir, args.test_size, args.patch_size)
    logger.info("Test data was loaded\n")

    logger.info("Loading training data...")
    train_data, train_answ = load_batch(args.dataset, args.dataset_dir, args.train_size, args.patch_size)
    logger.info("Training data was loaded\n")

    TEST_SIZE = test_data.shape[0]
    num_test_batches = int(test_data.shape[0] / args.batch_size)

    # defining system architecture
    with tf.Graph().as_default(), tf.Session() as sess:

        # placeholders for training data
        phone_ = tf.placeholder(tf.float32, [None, args.patch_size])
        phone_image = tf.reshape(phone_, [-1, args.patch_height, args.patch_width, 3])

        dslr_ = tf.placeholder(tf.float32, [None, args.patch_size])
        dslr_image = tf.reshape(dslr_, [-1, args.patch_height, args.patch_width, 3])

        adv_ = tf.placeholder(tf.float32, [None, 1])

        enhanced = unet(phone_image)
        [w, h, d] = enhanced.get_shape().as_list()[1:]

        # # learning rate with exponential decay
        # global_step = tf.Variable(0)
        # learning_rate = tf.train.exponential_decay(args.learning_rate, global_step,
        #                                            decay_steps=args.train_size / args.batch_size,
        #                                            decay_rate=0.98, staircase=True)

        '''
        Content loss, three options:
          1. vgg_loss, loading a .mat model;
          2. vgg_loss, loading a .npy model;
          3. an IQA model (meon_loss) over features and scores.
        '''
        # vgg = vgg19_loss.Vgg19(vgg_path=args.pretrain_weights)  # load VGG model
        # vgg_content = 2000 * tf.reduce_mean(tf.sqrt(tf.reduce_sum(
        #     tf.square(vgg.extract_feature(enhanced) - vgg.extract_feature(dslr_image))))) / (w * h * d)
        # loss_content = multi_content_loss(args.pretrain_weights, enhanced, dslr_image, args.batch_size)

        # MEON loss (loading the checkpoint directly is not convenient):
        # with tf.variable_scope('meon_loss') as scope:
        MEON_evaluate_model, loss_content = meon_loss(dslr_image, enhanced)

        loss_texture, discim_accuracy = texture_loss(enhanced, dslr_image, args.patch_width, args.patch_height, adv_)
        loss_discrim = -loss_texture

        loss_color = color_loss(enhanced, dslr_image, args.batch_size)
        loss_tv = variation_loss(enhanced, args.patch_width, args.patch_height, args.batch_size)

        loss_psnr = PSNR(enhanced, dslr_image)
        loss_ssim = MultiScaleSSIM(enhanced, dslr_image)

        loss_generator = args.w_content * loss_content + args.w_texture * loss_texture + args.w_tv * loss_tv \
                         + 1000 * (1 - loss_ssim) + args.w_color * loss_color

        # optimize parameters of the image enhancement (generator) and discriminator networks
        generator_vars = [v for v in tf.global_variables() if v.name.startswith("generator")]
        discriminator_vars = [v for v in tf.global_variables() if v.name.startswith("discriminator")]
        meon_vars = [v for v in tf.global_variables() if v.name.startswith("conv") or v.name.startswith("subtask")]

        # train_step_gen = tf.train.AdamOptimizer(args.learning_rate).minimize(loss_generator, var_list=generator_vars)
        # train_step_disc = tf.train.AdamOptimizer(args.learning_rate).minimize(loss_discrim, var_list=discriminator_vars)
        train_step_gen = tf.train.AdamOptimizer(5e-5).minimize(loss_generator, var_list=generator_vars)
        train_step_disc = tf.train.AdamOptimizer(5e-5).minimize(loss_discrim, var_list=discriminator_vars)

        saver = tf.train.Saver(var_list=generator_vars, max_to_keep=100)
        meon_saver = tf.train.Saver(var_list=meon_vars)

        logger.info('Initializing variables')
        sess.run(tf.global_variables_initializer())

        logger.info('Training network')

        train_loss_gen = 0.0
        train_acc_discrim = 0.0

        all_zeros = np.reshape(np.zeros((args.batch_size, 1)), [args.batch_size, 1])
        test_crops = test_data[np.random.randint(0, TEST_SIZE, 5), :]  # choose five images for visual inspection

        # summaries: add the scalars you want to track
        tf.summary.scalar('loss_generator', loss_generator)
        tf.summary.scalar('loss_content', loss_content)
        tf.summary.scalar('loss_color', loss_color)
        tf.summary.scalar('loss_texture', loss_texture)
        tf.summary.scalar('loss_tv', loss_tv)
        tf.summary.scalar('discim_accuracy', discim_accuracy)
        tf.summary.scalar('psnr', loss_psnr)
        tf.summary.scalar('ssim', loss_ssim)
        tf.summary.scalar('learning_rate', args.learning_rate)
        merge_summary = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(os.path.join(args.tesorboard_logs_dir, 'train', args.exp_name),
                                             sess.graph, filename_suffix=args.exp_name)
        test_writer = tf.summary.FileWriter(os.path.join(args.tesorboard_logs_dir, 'test', args.exp_name),
                                            sess.graph, filename_suffix=args.exp_name)

        tf.global_variables_initializer().run()

        # load checkpoint if one exists
        ckpt = tf.train.get_checkpoint_state(args.checkpoint_dir)
        start_i = 0
        if ckpt and ckpt.model_checkpoint_path:
            logger.info('loading checkpoint:' + ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            import re
            start_i = int(re.findall(r"_(\d+)\.ckpt", ckpt.model_checkpoint_path)[0])

        # initialize MEON with another model's pretrained weights
        MEON_evaluate_model.initialize(sess, meon_saver, args.meod_ckpt_path)

        # start training...
        for i in range(start_i, args.iter_max):
            iter_start = time.time()

            # train generator
            idx_train = np.random.randint(0, args.train_size, args.batch_size)

            phone_images = train_data[idx_train]
            dslr_images = train_answ[idx_train]

            [loss_temp, temp] = sess.run([loss_generator, train_step_gen],
                                         feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: all_zeros})
            train_loss_gen += loss_temp / args.eval_step

            # train discriminator
            idx_train = np.random.randint(0, args.train_size, args.batch_size)

            # generate image swaps (dslr or enhanced) for the discriminator
            swaps = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])

            phone_images = train_data[idx_train]
            dslr_images = train_answ[idx_train]

            # sess.run(train_step_disc) = train_step_disc.compute_gradients(loss, var)
            #                             + train_step_disc.apply_gradients(var)  @20190105
            [accuracy_temp, temp] = sess.run([discim_accuracy, train_step_disc],
                                             feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
            train_acc_discrim += accuracy_temp / args.eval_step

            if i % args.summary_step == 0:
                # summary intervals
                # enhance_f1_, enhance_f2_, enhance_s_, vgg_content_ = sess.run(
                #     [enhance_f1, enhance_f2, enhance_s, vgg_content],
                #     feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                # loss_content1_, loss_content2_, loss_content3_ = sess.run(
                #     [loss_content1, loss_content2, loss_content3],
                #     feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                # print("-----------------------------------------------")
                # print(enhance_f1_, enhance_f2_, enhance_s_, vgg_content_, loss_content1_, loss_content2_, loss_content3_)
                # print("-----------------------------------------------")
                train_summary = sess.run(merge_summary,
                                         feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                train_writer.add_summary(train_summary, i)

            if i % args.eval_step == 0:
                # test generator and discriminator CNNs
                test_losses_gen = np.zeros((1, 7))
                test_accuracy_disc = 0.0

                for j in range(num_test_batches):
                    be = j * args.batch_size
                    en = (j + 1) * args.batch_size

                    swaps = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])
                    phone_images = test_data[be:en]
                    dslr_images = test_answ[be:en]

                    [enhanced_crops, accuracy_disc, losses] = sess.run(
                        [enhanced, discim_accuracy,
                         [loss_generator, loss_content, loss_color, loss_texture, loss_tv, loss_psnr, loss_ssim]],
                        feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})

                    test_losses_gen += np.asarray(losses) / num_test_batches
                    test_accuracy_disc += accuracy_disc / num_test_batches

                logs_disc = "step %d/%d, %s | discriminator accuracy | train: %.4g, test: %.4g" % \
                            (i, args.iter_max, args.dataset, train_acc_discrim, test_accuracy_disc)
                logs_gen = "generator losses | train: %.4g, test: %.4g | content: %.4g, color: %.4g, " \
                           "texture: %.4g, tv: %.4g | psnr: %.4g, ssim: %.4g\n" % \
                           (train_loss_gen, test_losses_gen[0][0], test_losses_gen[0][1], test_losses_gen[0][2],
                            test_losses_gen[0][3], test_losses_gen[0][4], test_losses_gen[0][5], test_losses_gen[0][6])

                logger.info(logs_disc)
                logger.info(logs_gen)

                test_summary = sess.run(merge_summary,
                                        feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                test_writer.add_summary(test_summary, i)

                # save visual results for several test image crops
                if args.save_visual_result:
                    enhanced_crops = sess.run(enhanced,
                                              feed_dict={phone_: test_crops, dslr_: dslr_images, adv_: all_zeros})

                    idx = 0
                    for crop in enhanced_crops:
                        before_after = np.hstack(
                            (np.reshape(test_crops[idx], [args.patch_height, args.patch_width, 3]), crop))
                        misc.imsave(
                            os.path.join(args.checkpoint_dir,
                                         str(args.dataset) + str(idx) + '_iteration_' + str(i) + '.jpg'),
                            before_after)
                        idx += 1

                # save the model that corresponds to the current iteration
                if args.save_ckpt_file:
                    saver.save(sess,
                               os.path.join(args.checkpoint_dir,
                                            str(args.dataset) + '_iteration_' + str(i) + '.ckpt'),
                               write_meta_graph=False)

                train_loss_gen = 0.0
                train_acc_discrim = 0.0

                # reload a different batch of training data
                del train_data
                del train_answ
                del test_data
                del test_answ
                test_data, test_answ = load_test_data(args.dataset, args.dataset_dir, args.test_size, args.patch_size)
                train_data, train_answ = load_batch(args.dataset, args.dataset_dir, args.train_size, args.patch_size)
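
# Entry-point sketch (an assumption: the original argument parser is not shown
# here, so these flags merely mirror the args.* fields used above, and the
# defaults are placeholders rather than the project's tuned values).
def _example_cli():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='iphone')
    parser.add_argument('--dataset_dir', default='data/')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--train_size', type=int, default=30000)
    parser.add_argument('--test_size', type=int, default=1000)
    parser.add_argument('--patch_height', type=int, default=100)
    parser.add_argument('--patch_width', type=int, default=100)
    parser.add_argument('--patch_size', type=int, default=100 * 100 * 3)  # flattened patch length
    parser.add_argument('--iter_max', type=int, default=150000)
    parser.add_argument('--eval_step', type=int, default=1000)
    parser.add_argument('--summary_step', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=5e-5)
    parser.add_argument('--w_content', type=float, default=1.0)
    parser.add_argument('--w_texture', type=float, default=1.0)
    parser.add_argument('--w_color', type=float, default=1.0)
    parser.add_argument('--w_tv', type=float, default=1.0)
    parser.add_argument('--checkpoint_dir', default='checkpoints/')
    parser.add_argument('--tesorboard_logs_dir', default='logs/')
    parser.add_argument('--exp_name', default='exp0')
    parser.add_argument('--meod_ckpt_path', default='meon/')
    parser.add_argument('--save_visual_result', action='store_true')
    parser.add_argument('--save_ckpt_file', action='store_true')
    main(parser.parse_args())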

def Mssim_loss(target, prediction):
    # MS-SSIM is a similarity score in [0, 1]; turn it into a loss and scale it
    loss_Mssim = 1 - MultiScaleSSIM(target, prediction)
    return loss_Mssim * 1000
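
# Note that Mssim_loss(target, prediction) equals the 1000 * (1 - loss_ssim)
# term used in loss_generator above. A minimal sketch of combining it with a
# pixel-wise L1 term, assuming MultiScaleSSIM accepts batched NHWC tensors and
# returns a scalar (the 100.0 weighting is illustrative, not a tuned value):
def _example_combined_loss(target, prediction):
    l1 = tf.reduce_mean(tf.abs(target - prediction))  # pixel fidelity
    return Mssim_loss(target, prediction) + 100.0 * l1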

def main(args, data_params):
    procname = os.path.basename(args.checkpoint_dir)

    log.info('Preparing summary and checkpoint directory {}'.format(args.checkpoint_dir))
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    tf.set_random_seed(1234)  # make experiments repeatable

    # Select an architecture

    # Add model parameters to the graph (so they are saved to disk at checkpoint)

    # --- Train/Test datasets ---------------------------------------------------
    data_pipe = getattr(dp, args.data_pipeline)
    with tf.variable_scope('train_data'):
        train_data_pipeline = data_pipe(
            args.data_dir,
            shuffle=True,
            batch_size=args.batch_size,
            nthreads=args.data_threads,
            fliplr=args.fliplr,
            flipud=args.flipud,
            rotate=args.rotate,
            random_crop=args.random_crop,
            params=data_params,
            output_resolution=args.output_resolution,
            scale=args.scale)
        train_samples = train_data_pipeline.samples

    if args.eval_data_dir is not None:
        with tf.variable_scope('eval_data'):
            eval_data_pipeline = data_pipe(
                args.eval_data_dir,
                shuffle=True,
                batch_size=args.batch_size,
                nthreads=args.data_threads,
                fliplr=False,
                flipud=False,
                rotate=False,
                random_crop=False,
                params=data_params,
                output_resolution=args.output_resolution,
                scale=args.scale)
            eval_samples = eval_data_pipeline.samples
    # ---------------------------------------------------------------------------

    swaps = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])
    swaps = tf.convert_to_tensor(swaps)
    swaps = tf.cast(swaps, tf.float32)

    # Training graph
    with tf.variable_scope('inference'):
        prediction = unet(train_samples['image_input'])
        loss, loss_content, loss_texture, loss_color, loss_Mssim, loss_tv, discim_accuracy = \
            compute_loss.total_loss(train_samples['image_output'], prediction, swaps, args.batch_size)
        psnr = PSNR(train_samples['image_output'], prediction)
        loss_ssim = MultiScaleSSIM(train_samples['image_output'], prediction)

    # Evaluation graph
    if args.eval_data_dir is not None:
        with tf.name_scope('eval'):
            with tf.variable_scope('inference', reuse=True):
                eval_prediction = unet(eval_samples['image_input'])
                eval_psnr = PSNR(eval_samples['image_output'], eval_prediction)
                eval_ssim = MultiScaleSSIM(eval_samples['image_output'], eval_prediction)

    # Optimizer
    model_vars1 = [v for v in tf.global_variables() if v.name.startswith("inference/generator")]
    discriminator_vars1 = [v for v in tf.global_variables()
                           if v.name.startswith("inference/l2_loss/discriminator")]

    global_step = tf.contrib.framework.get_or_create_global_step()
    with tf.name_scope('optimizer'):
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        updates = tf.group(*update_ops, name='update_ops')
        log.info("Adding {} update ops".format(len(update_ops)))

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if reg_losses and args.weight_decay is not None and args.weight_decay > 0:
            print("Regularization losses:")
            for rl in reg_losses:
                print(" ", rl.name)
            opt_loss = loss + args.weight_decay * sum(reg_losses)
        else:
            print("No regularization.")
            opt_loss = loss

        with tf.control_dependencies([updates]):
            opt = tf.train.AdamOptimizer(args.learning_rate)
            minimize = opt.minimize(opt_loss, name='optimizer',
                                    global_step=global_step, var_list=model_vars1)
            minimize_discrim = opt.minimize(-loss_texture, name='discriminator',
                                            global_step=global_step, var_list=discriminator_vars1)

    # Average loss and psnr for display
    with tf.name_scope("moving_averages"):
        ema = tf.train.ExponentialMovingAverage(decay=0.99)
        update_ma = ema.apply([
            loss, loss_content, loss_texture, loss_color, loss_Mssim, loss_tv,
            discim_accuracy, psnr, loss_ssim
        ])
        loss = ema.average(loss)
        loss_content = ema.average(loss_content)
        loss_texture = ema.average(loss_texture)
        loss_color = ema.average(loss_color)
        loss_Mssim = ema.average(loss_Mssim)
        loss_tv = ema.average(loss_tv)
        discim_accuracy = ema.average(discim_accuracy)
        psnr = ema.average(psnr)
        loss_ssim = ema.average(loss_ssim)

    # Training stepper operations
    train_op = tf.group(minimize, update_ma)
    train_discrim_op = tf.group(minimize_discrim, update_ma)

    # Save a few scalars to the summaries
    summaries = [
        tf.summary.scalar('loss', loss),
        tf.summary.scalar('loss_content', loss_content),
        tf.summary.scalar('loss_color', loss_color),
        tf.summary.scalar('loss_texture', loss_texture),
        tf.summary.scalar('loss_ssim', loss_Mssim),
        tf.summary.scalar('loss_tv', loss_tv),
        tf.summary.scalar('discim_accuracy', discim_accuracy),
        tf.summary.scalar('psnr', psnr),
        tf.summary.scalar('ssim', loss_ssim),
        tf.summary.scalar('learning_rate', args.learning_rate),
        tf.summary.scalar('batch_size', args.batch_size),
    ]

    log_fetches = {
        "loss_content": loss_content,
        "loss_texture": loss_texture,
        "loss_color": loss_color,
        "loss_Mssim": loss_Mssim,
        "loss_tv": loss_tv,
        "discim_accuracy": discim_accuracy,
        "step": global_step,
        "loss": loss,
        "psnr": psnr,
        "loss_ssim": loss_ssim
    }

    model_vars = [v for v in tf.global_variables()
                  if not v.name.startswith("inference/l2_loss/discriminator")]
    discriminator_vars = [v for v in tf.global_variables()
                          if v.name.startswith("inference/l2_loss/discriminator")]

    # Train config
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # do not cannibalize the entire GPU
    sv = tf.train.Supervisor(
        saver=tf.train.Saver(var_list=model_vars, max_to_keep=100),
        local_init_op=tf.variables_initializer(discriminator_vars),
        logdir=args.checkpoint_dir,
        save_summaries_secs=args.summary_interval,
        save_model_secs=args.checkpoint_interval)

    # Train loop
    with sv.managed_session(config=config) as sess:
        sv.loop(args.log_interval, log_hook, (sess, log_fetches))
        last_eval = time.time()
        while True:
            if sv.should_stop():
                log.info("stopping supervisor")
                break
            try:
                step, _ = sess.run([global_step, train_op])
                _ = sess.run(train_discrim_op)

                since_eval = time.time() - last_eval
                if args.eval_data_dir is not None and since_eval > args.eval_interval:
                    log.info("Evaluating on {} images at step {}".format(3, step))

                    p_ = 0
                    s_ = 0
                    for it in range(3):
                        p_ += sess.run(eval_psnr)
                        s_ += sess.run(eval_ssim)
                    p_ /= 3
                    s_ /= 3

                    sv.summary_writer.add_summary(tf.Summary(value=[
                        tf.Summary.Value(tag="psnr/eval", simple_value=p_)]),
                        global_step=step)
                    sv.summary_writer.add_summary(tf.Summary(value=[
                        tf.Summary.Value(tag="ssim/eval", simple_value=s_)]),
                        global_step=step)

                    log.info("  Evaluation PSNR = {:.2f} dB".format(p_))
                    log.info("  Evaluation SSIM = {:.4f}".format(s_))

                    last_eval = time.time()
            except tf.errors.AbortedError:
                log.error("Aborted")
                break
            except KeyboardInterrupt:
                break

        chkpt_path = os.path.join(args.checkpoint_dir, 'on_stop.ckpt')
        log.info("Training complete, saving chkpt {}".format(chkpt_path))
        sv.saver.save(sess, chkpt_path)
        sv.request_stop()
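
# log_hook is passed to sv.loop() above but is not defined in this section.
# A minimal sketch of what it might look like, assuming it simply evaluates
# the log_fetches dictionary and prints the smoothed scalars:
def _example_log_hook(sess, log_fetches):
    fetched = sess.run(log_fetches)
    log.info("step {step} | loss: {loss:.4g} | psnr: {psnr:.2f} dB | "
             "ssim: {loss_ssim:.4f} | disc acc: {discim_accuracy:.4g}".format(**fetched))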

enhanced = EDSR(phone_image)
print(enhanced.shape)

# loss introduction
# loss_texture, discim_accuracy = texture_loss(enhanced, dslr_image, PATCH_WIDTH, PATCH_HEIGHT, adv_)
# loss_discrim = -loss_texture
# loss_content = content_loss(vgg_dir, enhanced, dslr_image, batch_size)
# loss_color = color_loss(enhanced, dslr_image, batch_size)
# loss_tv = variation_loss(enhanced, PATCH_WIDTH, PATCH_HEIGHT, batch_size)
# loss_generator = w_content * loss_content + w_texture * loss_texture + w_color * loss_color + w_tv * loss_tv
loss_generator = tf.losses.absolute_difference(labels=dslr_image, predictions=enhanced)

loss_psnr = PSNR(enhanced, dslr_, PATCH_SIZE, batch_size)
loss_ssim = MultiScaleSSIM(enhanced, dslr_image)

# optimize parameters of the image enhancement (generator) and discriminator networks
generator_vars = [v for v in tf.global_variables() if v.name.startswith("generator")]
# discriminator_vars = [v for v in tf.global_variables() if v.name.startswith("discriminator")]

train_step_gen = tf.train.AdamOptimizer(learning_rate).minimize(loss_generator, var_list=generator_vars)
# train_step_disc = tf.train.AdamOptimizer(learning_rate).minimize(loss_discrim, var_list=discriminator_vars)

saver = tf.train.Saver(var_list=generator_vars, max_to_keep=100)

print('Initializing variables')
sess.run(tf.global_variables_initializer())
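
# The fragment above stops after variable initialization. A minimal sketch of
# the generator training loop that would typically follow, mirroring main()
# above (train_data, train_answ, train_size, and iter_max are assumed to be
# defined the same way as there):
for i in range(iter_max):
    idx_train = np.random.randint(0, train_size, batch_size)
    phone_images = train_data[idx_train]
    dslr_images = train_answ[idx_train]
    loss_temp, _ = sess.run([loss_generator, train_step_gen],
                            feed_dict={phone_: phone_images, dslr_: dslr_images})
    if i % 1000 == 0:  # illustrative logging interval
        psnr_val, ssim_val = sess.run([loss_psnr, loss_ssim],
                                      feed_dict={phone_: phone_images, dslr_: dslr_images})
        print('step %d | L1 loss: %.4g | psnr: %.2f | ssim: %.4f' % (i, loss_temp, psnr_val, ssim_val))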