def main():
    """Build and train a CycleGAN.

    Two generators (G: X->Y, F: Y->X) and two discriminators (discY, discX)
    are trained adversarially with least-squares GAN losses plus an L1
    cycle-consistency penalty. Progress is logged to TensorBoard, sample
    images are written periodically, and checkpoints are saved both on a
    step cadence and once more on shutdown.

    All configuration comes from the command line via parseArguments().
    """
    args = parseArguments()

    # Unpack CLI arguments into the constant-style names used below.
    MAX_TRAIN_TIME_MINS = args.time
    LEARNING_RATE = args.lrate
    CHECKPOINT_FILE = args.check  # NOTE(review): parsed but unused below — confirm intent
    CHECKPOINT_DIR = args.checkpoint_dir
    BATCH_SIZE = args.batch_size
    SAMPLE_STEP = args.sample_freq
    SAVE_STEP = args.checkpoint_freq
    SOFT_LABELS = args.softL
    LOG_DIR = args.logdir
    LOG_FREQUENCY = args.log_frequency

    # Input-pipeline augmentation knobs (shared dict read by the Images feeder).
    PIPELINE_TWEAKS['random_flip'] = args.random_flip
    PIPELINE_TWEAKS['random_brightness'] = PIPELINE_TWEAKS[
        'random_saturation'] = PIPELINE_TWEAKS[
            'random_contrast'] = args.random_q
    PIPELINE_TWEAKS['crop_size'] = args.crop

    # Soft labels: the discriminators' "real" target is sampled near 1.0
    # (std dev softL_c) instead of exactly 1.0.
    if SOFT_LABELS:
        softL_c = 0.05
    else:
        softL_c = 0.0
    print('Soft Labeling: ', softL_c)

    sess = tf.Session()

    # DEFINE OUR MODEL AND LOSS FUNCTIONS
    # -------------------------------------------------------
    real_X = Images(args.input_prefix + '_trainA.tfrecords',
                    batch_size=BATCH_SIZE, name='real_X').feed()
    real_Y = Images(args.input_prefix + '_trainB.tfrecords',
                    batch_size=BATCH_SIZE, name='real_Y').feed()

    # genG(X) => Y (fake_B); genF(Y) => X (fake_A).
    genG = generator(real_X, norm=args.norm, rnorm=args.rnorm,
                     name="generatorG")
    genF = generator(real_Y, norm=args.norm, rnorm=args.rnorm,
                     name="generatorF")
    # Cycle reconstructions: genF(genG(X)) ~ X and genG(genF(Y)) ~ Y.
    genF_back = generator(genG, norm=args.norm, rnorm=args.rnorm,
                          name="generatorF", reuse=True)
    genG_back = generator(genF, norm=args.norm, rnorm=args.rnorm,
                          name="generatorG", reuse=True)

    # DY_fake scores genG(X) as a Y image; DX_fake scores genF(Y) as an X image.
    discY_fake = discriminator(genG, norm=args.norm, reuse=False, name="discY")
    discX_fake = discriminator(genF, norm=args.norm, reuse=False, name="discX")

    # NOTE(review): np.random.normal() executes once at graph-construction
    # time, so each loss uses a single fixed soft-label constant for the
    # whole run rather than a fresh draw per step — confirm this is intended.
    g_loss_G = tf.reduce_mean(
        (discY_fake - tf.ones_like(discY_fake)
         * np.abs(np.random.normal(1.0, softL_c))) ** 2) \
        + L1_lambda * tf.reduce_mean(tf.abs(real_X - genF_back)) \
        + L1_lambda * tf.reduce_mean(tf.abs(real_Y - genG_back))
    g_loss_F = tf.reduce_mean(
        (discX_fake - tf.ones_like(discX_fake)
         * np.abs(np.random.normal(1.0, softL_c))) ** 2) \
        + L1_lambda * tf.reduce_mean(tf.abs(real_X - genF_back)) \
        + L1_lambda * tf.reduce_mean(tf.abs(real_Y - genG_back))

    # Placeholders fed from the ImageCache history of previously generated
    # images (stabilises discriminator training).
    fake_X_sample = tf.placeholder(tf.float32, [None, 256, 256, 3],
                                   name="fake_X_sample")
    fake_Y_sample = tf.placeholder(tf.float32, [None, 256, 256, 3],
                                   name="fake_Y_sample")

    # DY/DX score real images; *_fake_sample score cached generated images.
    DY = discriminator(real_Y, norm=args.norm, reuse=True, name="discY")
    DX = discriminator(real_X, norm=args.norm, reuse=True, name="discX")
    DY_fake_sample = discriminator(fake_Y_sample, norm=args.norm, reuse=True,
                                   name="discY")
    DX_fake_sample = discriminator(fake_X_sample, norm=args.norm, reuse=True,
                                   name="discX")

    # Least-squares discriminator losses: real -> (soft) 1, fake -> 0.
    DY_loss_real = tf.reduce_mean(
        (DY - tf.ones_like(DY) * np.abs(np.random.normal(1.0, softL_c)))**2)
    DY_loss_fake = tf.reduce_mean(
        (DY_fake_sample - tf.zeros_like(DY_fake_sample))**2)
    DY_loss = (DY_loss_real + DY_loss_fake) / 2
    DX_loss_real = tf.reduce_mean(
        (DX - tf.ones_like(DX) * np.abs(np.random.normal(1.0, softL_c)))**2)
    DX_loss_fake = tf.reduce_mean(
        (DX_fake_sample - tf.zeros_like(DX_fake_sample))**2)
    DX_loss = (DX_loss_real + DX_loss_fake) / 2

    # Test-set pipelines and generator outputs consumed by sample_model().
    test_X = Images(args.input_prefix + '_testA.tfrecords', shuffle=False,
                    name='test_A').feed()
    test_Y = Images(args.input_prefix + '_testB.tfrecords', shuffle=False,
                    name='test_B').feed()
    testG = generator(test_X, norm=args.norm, rnorm=args.rnorm,
                      name="generatorG", reuse=True)
    testF = generator(test_Y, norm=args.norm, rnorm=args.rnorm,
                      name="generatorF", reuse=True)
    testF_back = generator(testG, norm=args.norm, rnorm=args.rnorm,
                           name="generatorF", reuse=True)
    testG_back = generator(testF, norm=args.norm, rnorm=args.rnorm,
                           name="generatorG", reuse=True)

    # Partition trainable variables by sub-network for the four optimizers.
    t_vars = tf.trainable_variables()
    DY_vars = [v for v in t_vars if 'discY' in v.name]
    DX_vars = [v for v in t_vars if 'discX' in v.name]
    g_vars_G = [v for v in t_vars if 'generatorG' in v.name]
    g_vars_F = [v for v in t_vars if 'generatorF' in v.name]

    # SETUP OUR SUMMARY VARIABLES FOR MONITORING
    # -------------------------------------------------------
    G_loss_sum = tf.summary.scalar("loss/G", g_loss_G)
    F_loss_sum = tf.summary.scalar("loss/F", g_loss_F)
    DY_loss_sum = tf.summary.scalar("loss/DY", DY_loss)
    DX_loss_sum = tf.summary.scalar("loss/DX", DX_loss)
    DY_loss_real_sum = tf.summary.scalar("loss/DY_real", DY_loss_real)
    DY_loss_fake_sum = tf.summary.scalar("loss/DY_fake", DY_loss_fake)
    DX_loss_real_sum = tf.summary.scalar("loss/DX_real", DX_loss_real)
    DX_loss_fake_sum = tf.summary.scalar("loss/DX_fake", DX_loss_fake)

    imgX = tf.summary.image('real_X', real_X, max_outputs=1)
    imgF = tf.summary.image('fake_X', genF, max_outputs=1)
    imgY = tf.summary.image('real_Y', real_Y, max_outputs=1)
    imgG = tf.summary.image('fake_Y', genG, max_outputs=1)

    # SETUP OUR TRAINING
    # -------------------------------------------------------
    def adam(loss, variables, start_lr, end_lr, lr_decay_start, start_beta,
             name_prefix):
        """Return (train_op, lr_summary): Adam with constant LR, then linear
        decay from start_lr to end_lr beginning at step lr_decay_start."""
        name = name_prefix + '_adam'
        global_step = tf.Variable(0, trainable=False)
        # The paper recommends learning at a fixed rate for several steps,
        # and then linearly stepping down to 0.
        learning_rate = (tf.where(
            tf.greater_equal(global_step, lr_decay_start),
            tf.train.polynomial_decay(start_lr, global_step - lr_decay_start,
                                      lr_decay_start, end_lr, power=1.0),
            start_lr))
        lr_sum = tf.summary.scalar('learning_rate/{}'.format(name),
                                   learning_rate)
        learning_step = (tf.train.AdamOptimizer(
            learning_rate, beta1=start_beta, name=name).minimize(
                loss, global_step=global_step, var_list=variables))
        return learning_step, lr_sum

    DX_optim, DX_lr = adam(DX_loss, DX_vars, LEARNING_RATE, args.end_lr,
                           args.lr_decay_start, MOMENTUM, 'D_X')
    DY_optim, DY_lr = adam(DY_loss, DY_vars, LEARNING_RATE, args.end_lr,
                           args.lr_decay_start, MOMENTUM, 'D_Y')
    G_optim, G_lr = adam(g_loss_G, g_vars_G, LEARNING_RATE, args.end_lr,
                         args.lr_decay_start, MOMENTUM, 'G')
    F_optim, F_lr = adam(g_loss_F, g_vars_F, LEARNING_RATE, args.end_lr,
                         args.lr_decay_start, MOMENTUM, 'F')

    G_sum = tf.summary.merge([G_loss_sum, G_lr])
    F_sum = tf.summary.merge([F_loss_sum, F_lr])
    DY_sum = tf.summary.merge(
        [DY_loss_sum, DY_loss_real_sum, DY_loss_fake_sum, DY_lr])
    DX_sum = tf.summary.merge(
        [DX_loss_sum, DX_loss_real_sum, DX_loss_fake_sum, DX_lr])
    images_sum = tf.summary.merge([imgX, imgG, imgY, imgF])

    # CREATE AND RUN OUR TRAINING LOOP
    # -------------------------------------------------------
    print("Starting the time")
    timer = utils.Timer()
    # FIX: start_time was referenced in the loop's progress print but never
    # defined anywhere, which raised NameError on the first iteration.
    start_time = time.time()

    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables())

    # FIX: honour --checkpoint_dir (CHECKPOINT_DIR was parsed but ignored in
    # favour of a hard-coded './checkpoint/').
    ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
    if ckpt and ckpt.model_checkpoint_path and not args.ignore_checkpoint:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    summary_op = tf.summary.merge_all()
    writer = tf.summary.FileWriter(LOG_DIR, sess.graph)

    # History buffers of generated images replayed to the discriminators.
    cache_X = ImageCache(50)
    cache_Y = ImageCache(50)

    counter = 0
    try:
        while not coord.should_stop():
            # FORWARD PASS: generate fakes, then jointly step all four
            # optimizers, feeding cached fakes to the discriminators.
            generated_X, generated_Y = sess.run([genF, genG])
            _, _, _, _, summary_str = sess.run(
                [G_optim, DY_optim, F_optim, DX_optim, summary_op],
                feed_dict={
                    fake_Y_sample: cache_Y.fetch(generated_Y),
                    fake_X_sample: cache_X.fetch(generated_X)
                })
            counter += 1
            print("[%4d] time: %4.4f" % (counter, time.time() - start_time))

            if np.mod(counter, LOG_FREQUENCY) == 0:
                print('writing')
                writer.add_summary(summary_str, counter)

            if np.mod(counter, SAMPLE_STEP) == 0:
                sample_model(sess, counter, test_X, test_Y, testG, testF,
                             testG_back, testF_back)

            # FIX: merged two consecutive identical SAVE_STEP checks into one
            # block; the max-train-time test runs at the same cadence as before.
            if np.mod(counter, SAVE_STEP) == 0:
                save_path = save_model(saver, sess, counter)
                print("Running for '{0:.2}' mins, saving to {1}".format(
                    timer.elapsed() / 60, save_path))
                elapsed_min = timer.elapsed() / 60
                if (elapsed_min >= MAX_TRAIN_TIME_MINS):
                    print(
                        "Trained for '{0:.2}' mins and reached the max limit of {1}. Saving model."
                        .format(elapsed_min, MAX_TRAIN_TIME_MINS))
                    coord.request_stop()
    except KeyboardInterrupt:
        print('Interrupted')
        coord.request_stop()
    except Exception as e:
        coord.request_stop(e)
    finally:
        # Always persist a final checkpoint, then shut down the queue runners.
        save_path = save_model(saver, sess, counter)
        print("Model saved in file: %s" % save_path)
        # When done, ask the threads to stop.
        coord.request_stop()
        coord.join(threads)