Example #1
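# NOTE: this listing is an excerpt from a larger CycleGAN training script. It
# assumes the usual module-level imports (import tensorflow as tf, import numpy
# as np) plus helpers defined elsewhere in the repo: parseArguments, Images,
# generator, discriminator, ImageCache, utils.Timer, sample_model, save_model,
# and the constants L1_lambda, MOMENTUM and PIPELINE_TWEAKS.
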
def main():
    args = parseArguments()

    MAX_TRAIN_TIME_MINS = args.time
    LEARNING_RATE = args.lrate
    CHECKPOINT_FILE = args.check
    CHECKPOINT_DIR = args.checkpoint_dir
    BATCH_SIZE = args.batch_size
    SAMPLE_STEP = args.sample_freq
    SAVE_STEP = args.checkpoint_freq
    SOFT_LABELS = args.softL
    LOG_DIR = args.logdir
    LOG_FREQUENCY = args.log_frequency
    PIPELINE_TWEAKS['random_flip'] = args.random_flip
    PIPELINE_TWEAKS['random_brightness'] = PIPELINE_TWEAKS[
        'random_saturation'] = PIPELINE_TWEAKS[
            'random_contrast'] = args.random_q
    PIPELINE_TWEAKS['crop_size'] = args.crop

    if SOFT_LABELS:
        softL_c = 0.05
        #softL_c = np.random.normal(1,0.05)
        #if softL_c > 1.15: softL_c = 1.15
        #if softL_c < 0.85: softL_c = 0.85
    else:
        softL_c = 0.0
    print('Soft Labeling: ', softL_c)
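    # Note on how softL_c is used: each discriminator's "real" target below is
    # tf.ones_like(...) scaled by np.abs(np.random.normal(1.0, softL_c)), a value
    # drawn once at graph-construction time from a normal centred on 1.0 with
    # standard deviation softL_c. With softL_c = 0.0 the target is exactly 1.0.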

    sess = tf.Session()

    # DEFINE OUR MODEL AND LOSS FUNCTIONS
    # -------------------------------------------------------

    real_X = Images(args.input_prefix + '_trainA.tfrecords',
                    batch_size=BATCH_SIZE,
                    name='real_X').feed()
    real_Y = Images(args.input_prefix + '_trainB.tfrecords',
                    batch_size=BATCH_SIZE,
                    name='real_Y').feed()

    # genG(X) => Y            - fake_B
    genG = generator(real_X,
                     norm=args.norm,
                     rnorm=args.rnorm,
                     name="generatorG")
    # genF(Y) => X            - fake_A
    genF = generator(real_Y,
                     norm=args.norm,
                     rnorm=args.rnorm,
                     name="generatorF")
    # genF( genG(X) ) => X    - fake_A_ (cycle reconstruction of X)
    genF_back = generator(genG,
                          norm=args.norm,
                          rnorm=args.rnorm,
                          name="generatorF",
                          reuse=True)
    # genG( genF(Y) ) => Y    - fake_B_ (cycle reconstruction of Y)
    genG_back = generator(genF,
                          norm=args.norm,
                          rnorm=args.rnorm,
                          name="generatorG",
                          reuse=True)
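    # Together these four generators form the CycleGAN cycle:
    #   X --genG--> fake Y --genF--> genF_back  (should reconstruct X)
    #   Y --genF--> fake X --genG--> genG_back  (should reconstruct Y)
    # reuse=True makes the second pass through each generator share the
    # variables created by the first pass.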

    # DY_fake is the discriminator for Y that takes in genG(X)
    # DX_fake is the discriminator for X that takes in genF(Y)
    discY_fake = discriminator(genG, norm=args.norm, reuse=False, name="discY")
    discX_fake = discriminator(genF, norm=args.norm, reuse=False, name="discX")

    g_loss_G = tf.reduce_mean((discY_fake - tf.ones_like(discY_fake) * np.abs(np.random.normal(1.0,softL_c))) ** 2) \
            + L1_lambda * tf.reduce_mean(tf.abs(real_X - genF_back)) \
            + L1_lambda * tf.reduce_mean(tf.abs(real_Y - genG_back))

    g_loss_F = tf.reduce_mean((discX_fake - tf.ones_like(discX_fake) * np.abs(np.random.normal(1.0,softL_c))) ** 2) \
            + L1_lambda * tf.reduce_mean(tf.abs(real_X - genF_back)) \
            + L1_lambda * tf.reduce_mean(tf.abs(real_Y - genG_back))
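    # In equation form, the two generator losses above are the least-squares GAN
    # objective plus a shared cycle-consistency term:
    #   L_G = E[(D_Y(G(X)) - 1)^2] + L1_lambda * (||X - F(G(X))||_1 + ||Y - G(F(Y))||_1)
    #   L_F = E[(D_X(F(Y)) - 1)^2] + L1_lambda * (||X - F(G(X))||_1 + ||Y - G(F(Y))||_1)
    # with the "1" target perturbed by the soft-label factor above.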

    fake_X_sample = tf.placeholder(tf.float32, [None, 256, 256, 3],
                                   name="fake_X_sample")
    fake_Y_sample = tf.placeholder(tf.float32, [None, 256, 256, 3],
                                   name="fake_Y_sample")

    # DY is the discriminator for Y that takes in Y
    # DX is the discriminator for X that takes in X
    DY = discriminator(real_Y, norm=args.norm, reuse=True, name="discY")
    DX = discriminator(real_X, norm=args.norm, reuse=True, name="discX")
    DY_fake_sample = discriminator(fake_Y_sample,
                                   norm=args.norm,
                                   reuse=True,
                                   name="discY")
    DX_fake_sample = discriminator(fake_X_sample,
                                   norm=args.norm,
                                   reuse=True,
                                   name="discX")

    DY_loss_real = tf.reduce_mean(
        (DY - tf.ones_like(DY) * np.abs(np.random.normal(1.0, softL_c)))**2)
    DY_loss_fake = tf.reduce_mean(
        (DY_fake_sample - tf.zeros_like(DY_fake_sample))**2)
    DY_loss = (DY_loss_real + DY_loss_fake) / 2

    DX_loss_real = tf.reduce_mean(
        (DX - tf.ones_like(DX) * np.abs(np.random.normal(1.0, softL_c)))**2)
    DX_loss_fake = tf.reduce_mean(
        (DX_fake_sample - tf.zeros_like(DX_fake_sample))**2)
    DX_loss = (DX_loss_real + DX_loss_fake) / 2
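    # Least-squares discriminator losses: each D is pushed towards the (soft) 1.0
    # target on real images and towards 0.0 on generated samples, and the two
    # terms are averaged. The generated samples arrive via the fake_*_sample
    # placeholders so they can be drawn from the image history caches rather
    # than only the current batch.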

    test_X = Images(args.input_prefix + '_testA.tfrecords',
                    shuffle=False,
                    name='test_A').feed()
    test_Y = Images(args.input_prefix + '_testB.tfrecords',
                    shuffle=False,
                    name='test_B').feed()

    testG = generator(test_X,
                      norm=args.norm,
                      rnorm=args.rnorm,
                      name="generatorG",
                      reuse=True)
    testF = generator(test_Y,
                      norm=args.norm,
                      rnorm=args.rnorm,
                      name="generatorF",
                      reuse=True)
    testF_back = generator(testG,
                           norm=args.norm,
                           rnorm=args.rnorm,
                           name="generatorF",
                           reuse=True)
    testG_back = generator(testF,
                           norm=args.norm,
                           rnorm=args.rnorm,
                           name="generatorG",
                           reuse=True)

    t_vars = tf.trainable_variables()
    DY_vars = [v for v in t_vars if 'discY' in v.name]
    DX_vars = [v for v in t_vars if 'discX' in v.name]
    g_vars_G = [v for v in t_vars if 'generatorG' in v.name]
    g_vars_F = [v for v in t_vars if 'generatorF' in v.name]
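    # These name filters rely on the variable scopes created by the name=
    # arguments above; each optimizer below updates only its own network's
    # variables.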

    # SETUP OUR SUMMARY VARIABLES FOR MONITORING
    # -------------------------------------------------------

    G_loss_sum = tf.summary.scalar("loss/G", g_loss_G)
    F_loss_sum = tf.summary.scalar("loss/F", g_loss_F)
    DY_loss_sum = tf.summary.scalar("loss/DY", DY_loss)
    DX_loss_sum = tf.summary.scalar("loss/DX", DX_loss)
    DY_loss_real_sum = tf.summary.scalar("loss/DY_real", DY_loss_real)
    DY_loss_fake_sum = tf.summary.scalar("loss/DY_fake", DY_loss_fake)
    DX_loss_real_sum = tf.summary.scalar("loss/DX_real", DX_loss_real)
    DX_loss_fake_sum = tf.summary.scalar("loss/DX_fake", DX_loss_fake)

    imgX = tf.summary.image('real_X', real_X, max_outputs=1)
    imgF = tf.summary.image('fake_X', genF, max_outputs=1)
    imgY = tf.summary.image('real_Y', real_Y, max_outputs=1)
    imgG = tf.summary.image('fake_Y', genG, max_outputs=1)

    # SETUP OUR TRAINING
    # -------------------------------------------------------

    def adam(loss, variables, start_lr, end_lr, lr_decay_start, start_beta,
             name_prefix):
        name = name_prefix + '_adam'
        global_step = tf.Variable(0, trainable=False)
        # The paper recommends learning at a fixed rate for several steps, and then linearly stepping down to 0
        learning_rate = (tf.where(
            tf.greater_equal(global_step, lr_decay_start),
            tf.train.polynomial_decay(start_lr,
                                      global_step - lr_decay_start,
                                      lr_decay_start,
                                      end_lr,
                                      power=1.0), start_lr))
        lr_sum = tf.summary.scalar('learning_rate/{}'.format(name),
                                   learning_rate)

        learning_step = (tf.train.AdamOptimizer(learning_rate,
                                                beta1=start_beta,
                                                name=name).minimize(
                                                    loss,
                                                    global_step=global_step,
                                                    var_list=variables))
        return learning_step, lr_sum
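    # Worked example for the adam() schedule above (values are illustrative, not
    # defaults from this repo): with start_lr = 2e-4, lr_decay_start = 100000 and
    # end_lr = 0.0, the rate holds at 2e-4 while global_step < 100000, then
    # tf.train.polynomial_decay with power=1.0 ramps it linearly down to 0.0
    # over the next 100000 steps, matching the schedule in the CycleGAN paper.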

    DX_optim, DX_lr = adam(DX_loss, DX_vars, LEARNING_RATE, args.end_lr,
                           args.lr_decay_start, MOMENTUM, 'D_X')

    DY_optim, DY_lr = adam(DY_loss, DY_vars, LEARNING_RATE, args.end_lr,
                           args.lr_decay_start, MOMENTUM, 'D_Y')

    G_optim, G_lr = adam(g_loss_G, g_vars_G, LEARNING_RATE, args.end_lr,
                         args.lr_decay_start, MOMENTUM, 'G')

    F_optim, F_lr = adam(g_loss_F, g_vars_F, LEARNING_RATE, args.end_lr,
                         args.lr_decay_start, MOMENTUM, 'F')

    G_sum = tf.summary.merge([G_loss_sum, G_lr])
    F_sum = tf.summary.merge([F_loss_sum, F_lr])
    DY_sum = tf.summary.merge(
        [DY_loss_sum, DY_loss_real_sum, DY_loss_fake_sum, DY_lr])
    DX_sum = tf.summary.merge(
        [DX_loss_sum, DX_loss_real_sum, DX_loss_fake_sum, DX_lr])

    images_sum = tf.summary.merge([imgX, imgG, imgY, imgF])

    # CREATE AND RUN OUR TRAINING LOOP
    # -------------------------------------------------------

    print("Starting the time")
    timer = utils.Timer()

    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(tf.global_variables())
    ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)

    if ckpt and ckpt.model_checkpoint_path and not args.ignore_checkpoint:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    summary_op = tf.summary.merge_all()
    writer = tf.summary.FileWriter(LOG_DIR, sess.graph)

    cache_X = ImageCache(50)
    cache_Y = ImageCache(50)
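    # ImageCache is assumed to implement the image-history buffer from the
    # CycleGAN paper: fetch(batch) stores the newly generated images and returns
    # a mix of current and previously generated ones (pool size 50 here), which
    # keeps each discriminator from chasing only the latest generator output.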

    counter = 0
    try:
        while not coord.should_stop():

            # FORWARD PASS
            generated_X, generated_Y = sess.run([genF, genG])
            _, _, _, _, summary_str = sess.run(
                [G_optim, DY_optim, F_optim, DX_optim, summary_op],
                feed_dict={
                    fake_Y_sample: cache_Y.fetch(generated_Y),
                    fake_X_sample: cache_X.fetch(generated_X)
                })

            counter += 1
            print("[%4d] time: %4.4f" % (counter, time.time() - start_time))

            if np.mod(counter, LOG_FREQUENCY) == 0:
                print('writing')
                writer.add_summary(summary_str, counter)

            if np.mod(counter, SAMPLE_STEP) == 0:
                sample_model(sess, counter, test_X, test_Y, testG, testF,
                             testG_back, testF_back)

            if np.mod(counter, SAVE_STEP) == 0:
                save_path = save_model(saver, sess, counter)
                elapsed_min = timer.elapsed() / 60
                print("Running for '{0:.2f}' mins, saving to {1}".format(
                    elapsed_min, save_path))

                if elapsed_min >= MAX_TRAIN_TIME_MINS:
                    print(
                        "Trained for '{0:.2f}' mins and reached the max limit "
                        "of {1} mins. Saving model.".format(
                            elapsed_min, MAX_TRAIN_TIME_MINS))
                    coord.request_stop()

    except KeyboardInterrupt:
        print('Interrupted')
        coord.request_stop()
    except Exception as e:
        coord.request_stop(e)
    finally:
        save_path = save_model(saver, sess, counter)
        print("Model saved in file: %s" % save_path)
        # When done, ask the threads to stop.
        coord.request_stop()
        coord.join(threads)
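
# For reference, a minimal sketch of what parseArguments() might look like,
# reconstructed from the attribute accesses above. The flag names mirror the
# attribute names; every type and default value is an illustrative assumption,
# not taken from this repo.
import argparse

def parseArguments():
    parser = argparse.ArgumentParser(description='CycleGAN training (sketch)')
    parser.add_argument('--time', type=int, default=600,
                        help='max training time in minutes')
    parser.add_argument('--lrate', type=float, default=2e-4,
                        help='initial learning rate')
    parser.add_argument('--end_lr', type=float, default=0.0,
                        help='final learning rate after linear decay')
    parser.add_argument('--lr_decay_start', type=int, default=100000,
                        help='step at which the learning-rate decay begins')
    parser.add_argument('--check', type=str, default=None,
                        help='checkpoint file to restore')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint/')
    parser.add_argument('--checkpoint_freq', type=int, default=1000,
                        help='steps between checkpoint saves')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--sample_freq', type=int, default=500,
                        help='steps between sample images')
    parser.add_argument('--softL', action='store_true',
                        help='enable soft labels for the discriminator targets')
    parser.add_argument('--logdir', type=str, default='./logs')
    parser.add_argument('--log_frequency', type=int, default=100)
    parser.add_argument('--random_flip', action='store_true')
    parser.add_argument('--random_q', action='store_true',
                        help='random brightness/saturation/contrast')
    parser.add_argument('--crop', type=int, default=256)
    parser.add_argument('--input_prefix', type=str, default='data',
                        help='prefix for the *_trainA/B.tfrecords files')
    parser.add_argument('--norm', type=str, default='instance')
    parser.add_argument('--rnorm', type=str, default='instance')
    parser.add_argument('--ignore_checkpoint', action='store_true',
                        help='start with fresh parameters even if a checkpoint exists')
    return parser.parse_args()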