Ejemplo n.º 1
0
def train(args):
    """Adversarially train the generator against two discriminators.

    Builds a TF1 graph with three image placeholders (photo, superpixel
    target, cartoon), a generator, two LSGAN losses (one on the
    ``utils.color_shift`` outputs, one on ``guided_filter``-smoothed images),
    a VGG conv4_4 feature reconstruction loss and a total-variation loss.
    Generator variables are restored from the newest checkpoint under
    ``pretrain/saved_models`` and then the generator and discriminators are
    updated alternately for ``args.total_iter`` iterations.

    Side effects: writes summaries to ``<save_dir>/train_log``, checkpoints
    to ``<save_dir>/saved_models/model-<iter>`` and sample images to
    ``<save_dir>/images`` (the last two every 500 iterations).

    Relies on module-level names ``tf``, ``np``, ``tqdm``, ``network``,
    ``loss``, ``utils`` and ``guided_filter``.

    Args:
        args: Namespace-like object. Attributes read here: ``batch_size``,
            ``patch_size``, ``adv_train_lr``, ``gpu_fraction``, ``save_dir``,
            ``total_iter`` and ``use_enhance``.

    Raises:
        ValueError: If no checkpoint exists under ``pretrain/saved_models``.
    """
    image_shape = [args.batch_size, args.patch_size, args.patch_size, 3]
    input_photo = tf.placeholder(tf.float32, image_shape)
    input_superpixel = tf.placeholder(tf.float32, image_shape)
    input_cartoon = tf.placeholder(tf.float32, image_shape)

    # Generator output (the "fake" picture), post-processed by guided_filter
    # with the input photo passed as the first argument.
    output = network.unet_generator(input_photo)
    output = guided_filter(input_photo, output, r=1)

    # Self-filtered (r=5) versions of the fake and the real cartoon; these
    # feed the 'disc_blur' discriminator.
    blur_fake = guided_filter(output, output, r=5, eps=2e-1)
    blur_cartoon = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1)

    # color_shift outputs feed the 'disc_gray' discriminator.
    gray_fake, gray_cartoon = utils.color_shift(output, input_cartoon)

    d_loss_gray, g_loss_gray = loss.lsgan_loss(network.disc_sn,
                                               gray_cartoon,
                                               gray_fake,
                                               scale=1,
                                               patch=True,
                                               name='disc_gray')
    d_loss_blur, g_loss_blur = loss.lsgan_loss(network.disc_sn,
                                               blur_cartoon,
                                               blur_fake,
                                               scale=1,
                                               patch=True,
                                               name='disc_blur')

    # Feature-space reconstruction: the generator output is compared with
    # both the photo and the superpixel target on VGG19 conv4_4 features.
    vgg_model = loss.Vgg19('vgg19_no_fc.npy')
    vgg_photo = vgg_model.build_conv4_4(input_photo)
    vgg_output = vgg_model.build_conv4_4(output)
    vgg_superpixel = vgg_model.build_conv4_4(input_superpixel)
    h, w, c = vgg_photo.get_shape().as_list()[1:]
    feature_size = h * w * c

    photo_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_photo, vgg_output)) / feature_size
    superpixel_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_superpixel,
                                      vgg_output)) / feature_size
    recon_loss = photo_loss + superpixel_loss
    tv_loss = loss.total_variation_loss(output)

    # Fixed loss weights: 1e4 (TV), 1e-1 (blur GAN), 1 (gray GAN),
    # 2e2 (reconstruction).
    g_loss_total = (1e4 * tv_loss + 1e-1 * g_loss_blur + g_loss_gray +
                    2e2 * recon_loss)
    d_loss_total = d_loss_blur + d_loss_gray

    # Variables are split between the two optimizers by name substring.
    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'gene' in var.name]
    disc_vars = [var for var in all_vars if 'disc' in var.name]

    tf.summary.scalar('tv_loss', tv_loss)
    tf.summary.scalar('photo_loss', photo_loss)
    tf.summary.scalar('superpixel_loss', superpixel_loss)
    tf.summary.scalar('recon_loss', recon_loss)
    tf.summary.scalar('d_loss_gray', d_loss_gray)
    tf.summary.scalar('g_loss_gray', g_loss_gray)
    tf.summary.scalar('d_loss_blur', d_loss_blur)
    tf.summary.scalar('g_loss_blur', g_loss_blur)
    tf.summary.scalar('d_loss_total', d_loss_total)
    tf.summary.scalar('g_loss_total', g_loss_total)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        g_optim = tf.train.AdamOptimizer(
            args.adv_train_lr, beta1=0.5,
            beta2=0.99).minimize(g_loss_total, var_list=gene_vars)
        d_optim = tf.train.AdamOptimizer(
            args.adv_train_lr, beta1=0.5,
            beta2=0.99).minimize(d_loss_total, var_list=disc_vars)

    summary_op = tf.summary.merge_all()
    # Only generator variables are saved and restored.
    saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20)

    # The session reserves a fixed fraction of GPU memory. (An alternative is
    # tf.ConfigProto() with gpu_options.allow_growth = True.)
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    train_writer = tf.summary.FileWriter(args.save_dir + '/train_log')

    # try/finally guarantees pending summary events are flushed and the
    # session is released even when training is interrupted or fails.
    try:
        with tf.device('/device:GPU:0'):

            sess.run(tf.global_variables_initializer())

            # latest_checkpoint returns None when nothing is found; fail
            # with a message that names the directory instead of letting
            # saver.restore reject the None path.
            pretrain_ckpt = tf.train.latest_checkpoint('pretrain/saved_models')
            if pretrain_ckpt is None:
                raise ValueError(
                    "No pretrained generator checkpoint found in "
                    "'pretrain/saved_models'; run the pretraining stage "
                    "first.")
            saver.restore(sess, pretrain_ckpt)

            face_photo_list = utils.load_image_list('dataset/photo_face')
            scenery_photo_list = utils.load_image_list(
                'dataset/photo_scenery')
            face_cartoon_list = utils.load_image_list('dataset/cartoon_face')
            scenery_cartoon_list = utils.load_image_list(
                'dataset/cartoon_scenery')

            for total_iter in tqdm(range(args.total_iter)):

                # Every fifth iteration uses face data, the rest scenery.
                if total_iter % 5 == 0:
                    photo_list = face_photo_list
                    cartoon_list = face_cartoon_list
                else:
                    photo_list = scenery_photo_list
                    cartoon_list = scenery_cartoon_list
                photo_batch = utils.next_batch(photo_list, args.batch_size)
                cartoon_batch = utils.next_batch(cartoon_list,
                                                 args.batch_size)

                # Forward pass to obtain the current generator output; the
                # superpixel placeholder is fed the photo batch here, as the
                # real target is only derived from this output below.
                inter_out = sess.run(output,
                                     feed_dict={
                                         input_photo: photo_batch,
                                         input_superpixel: photo_batch,
                                         input_cartoon: cartoon_batch
                                     })

                # Note carried over from the original code: adaptive
                # coloring has to be applied with the clip_by_value in the
                # last layer of the generator network, which is not very
                # stable. To stably reproduce the results, use power=1.0 and
                # comment out the clip_by_value call in network.py first; if
                # that works, try adaptive coloring with clip_by_value.
                if args.use_enhance:
                    superpixel_batch = utils.selective_adacolor(inter_out,
                                                                power=1.2)
                else:
                    superpixel_batch = utils.simple_superpixel(inter_out,
                                                               seg_num=200)

                train_feed = {
                    input_photo: photo_batch,
                    input_superpixel: superpixel_batch,
                    input_cartoon: cartoon_batch
                }

                # One generator step followed by one discriminator step on
                # the same batch; summaries are taken with the latter.
                _, g_loss, r_loss = sess.run(
                    [g_optim, g_loss_total, recon_loss],
                    feed_dict=train_feed)
                _, d_loss, train_info = sess.run(
                    [d_optim, d_loss_total, summary_op],
                    feed_dict=train_feed)

                train_writer.add_summary(train_info, total_iter)

                step = total_iter + 1
                if step % 50 != 0:
                    continue

                print('Iter: {}, d_loss: {}, g_loss: {}, recon_loss: {}'.
                      format(total_iter, d_loss, g_loss, r_loss))

                if step % 500 != 0:
                    continue

                saver.save(sess,
                           args.save_dir + '/saved_models/model',
                           write_meta_graph=False,
                           global_step=total_iter)

                # Dump generator results on fresh face and scenery batches
                # next to the corresponding input photos.
                photo_face = utils.next_batch(face_photo_list,
                                              args.batch_size)
                cartoon_face = utils.next_batch(face_cartoon_list,
                                                args.batch_size)
                photo_scenery = utils.next_batch(scenery_photo_list,
                                                 args.batch_size)
                cartoon_scenery = utils.next_batch(scenery_cartoon_list,
                                                   args.batch_size)

                result_face = sess.run(output,
                                       feed_dict={
                                           input_photo: photo_face,
                                           input_superpixel: photo_face,
                                           input_cartoon: cartoon_face
                                       })
                result_scenery = sess.run(output,
                                          feed_dict={
                                              input_photo: photo_scenery,
                                              input_superpixel: photo_scenery,
                                              input_cartoon: cartoon_scenery
                                          })

                image_dir = args.save_dir + '/images'
                for images, suffix in ((result_face, '_face_result.jpg'),
                                       (photo_face, '_face_photo.jpg'),
                                       (result_scenery,
                                        '_scenery_result.jpg'),
                                       (photo_scenery,
                                        '_scenery_photo.jpg')):
                    utils.write_batch_image(images, image_dir,
                                            str(total_iter) + suffix, 4)
    finally:
        train_writer.close()
        sess.close()
Ejemplo n.º 2
0
def train(args):
    """Pretrain the generator to reconstruct its own input photo.

    Builds a TF1 graph in which ``network.unet_generator`` is optimized with
    Adam on the absolute difference between the input photo and the
    generator output, then runs ``args.total_iter`` training iterations.

    Side effects: writes summaries (including the graph) to
    ``<save_dir>/train_log``, generator checkpoints to
    ``<save_dir>/saved_models/model-<iter>`` and sample images to
    ``<save_dir>/images`` (the last two every 500 iterations).

    Relies on module-level names ``tf``, ``np``, ``tqdm``, ``network`` and
    ``utils``.

    Args:
        args: Namespace-like object. Attributes read here: ``batch_size``,
            ``patch_size``, ``adv_train_lr``, ``gpu_fraction``, ``save_dir``
            (expected without a trailing separator) and ``total_iter``.
    """
    input_photo = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])

    output = network.unet_generator(input_photo)

    recon_loss = tf.reduce_mean(
        tf.losses.absolute_difference(input_photo, output))

    tf.summary.scalar('recon_loss', recon_loss)
    summary_op = tf.summary.merge_all()

    # Only variables whose name contains 'gene' are trained and saved.
    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'gene' in var.name]

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optim = tf.train.AdamOptimizer(
            args.adv_train_lr, beta1=0.5,
            beta2=0.99).minimize(recon_loss, var_list=gene_vars)

    saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20)

    # The session reserves a fixed fraction of GPU memory. (An alternative is
    # tf.ConfigProto() with gpu_options.allow_growth = True.)
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    # Paths are joined with an explicit '/' like the image directory below;
    # the previous "save_dir + 'save_models/model'" glued the folder name
    # onto save_dir and did not match the 'saved_models' folder name.
    checkpoint_prefix = args.save_dir + '/saved_models/model'
    summary_writer = tf.summary.FileWriter(args.save_dir + '/train_log',
                                           tf.get_default_graph())

    # try/finally guarantees pending summary events are flushed and the
    # session is released even when training is interrupted or fails.
    try:
        with tf.device('/device:GPU:0'):

            sess.run(tf.global_variables_initializer())
            face_photo_list = utils.load_image_list('dataset/photo_face')
            scenery_photo_list = utils.load_image_list(
                'dataset/photo_scenery')

            for total_iter in tqdm(range(args.total_iter)):

                # Every fifth iteration uses face photos, the rest scenery.
                if total_iter % 5 == 0:
                    photo_list = face_photo_list
                else:
                    photo_list = scenery_photo_list
                photo_batch = utils.next_batch(photo_list, args.batch_size)

                _, r_loss, summary_str = sess.run(
                    [optim, recon_loss, summary_op],
                    feed_dict={input_photo: photo_batch})

                summary_writer.add_summary(summary_str,
                                           global_step=total_iter)

                step = total_iter + 1
                if step % 50 != 0:
                    continue

                print('pretrain, iter: {}, recon_loss: {}'.format(
                    total_iter, r_loss))

                if step % 500 != 0:
                    continue

                saver.save(sess,
                           checkpoint_prefix,
                           write_meta_graph=False,
                           global_step=total_iter)

                # Dump generator results on fresh face and scenery batches
                # next to the corresponding input photos.
                photo_face = utils.next_batch(face_photo_list,
                                              args.batch_size)
                photo_scenery = utils.next_batch(scenery_photo_list,
                                                 args.batch_size)

                result_face = sess.run(output,
                                       feed_dict={input_photo: photo_face})
                result_scenery = sess.run(
                    output, feed_dict={input_photo: photo_scenery})

                image_dir = args.save_dir + '/images'
                for images, suffix in ((result_face, '_face_result.jpg'),
                                       (photo_face, '_face_photo.jpg'),
                                       (result_scenery,
                                        '_scenery_result.jpg'),
                                       (photo_scenery,
                                        '_scenery_photo.jpg')):
                    utils.write_batch_image(images, image_dir,
                                            str(total_iter) + suffix, 4)
    finally:
        summary_writer.close()
        sess.close()