示例#1
0
def test(imglist_a,
         imglist_b,
         model_path,
         base_net='resnet',
         batch_size=1,
         image_save_path_a='../output/cat_2_dog_test_cl3/testA/',
         image_save_path_b='../output/cat_2_dog_test_cl3/testB/',
         show_image_every_step=50,
         show_image=False):

    gan = CycleGAN(mode='test', base=base_net, verbose=False)
    gan.build()
    # gan.compile(learning_rate=lr, cycle_loss_weight=cycle_loss_weight, identity_loss_weight=identity_loss_weight,
    #             disc_loss_weight=disc_loss_weight, disc_2_loss_weight=disc_2_loss_weight)
    gan.test(imglist_a,
             model_path,
             is_a2b=True,
             batch_size=1,
             image_save_path=image_save_path_a,
             show_image=show_image,
             show_image_every_step=show_image_every_step)
    gan.test(imglist_b,
             model_path,
             is_a2b=False,
             batch_size=1,
             image_save_path=image_save_path_b,
             show_image=show_image,
             show_image_every_step=show_image_every_step)
示例#2
0
def train(imglist_a,
          imglist_b,
          base_net='resnet',
          lr=2e-4,
          cycle_loss_weight=3,
          identity_loss_weight=0,
          disc_loss_weight=0.5,
          disc_2_loss_weight=0.5,
          epochs=200,
          decay_from=100,
          steps_per_epoch=3000,
          true_label_value=1,
          batch_size=1,
          image_save_path='../output/20190325/',
          model_save_path='../models/20190325/',
          save_image_every_step=100,
          save_model_every_epoch=1,
          show_image=False,
          load_model=False,
          model_load_path='../models/20190325/'):

    gan = CycleGAN(mode='train', base=base_net, verbose=False)
    gan.build()
    gan.compile(learning_rate=lr,
                cycle_loss_weight=cycle_loss_weight,
                identity_loss_weight=identity_loss_weight,
                disc_loss_weight=disc_loss_weight,
                disc_2_loss_weight=disc_2_loss_weight)
    gan.train(imglist_a,
              imglist_b,
              epochs=epochs,
              decay_from=decay_from,
              steps_per_epoch=steps_per_epoch,
              true_label_value=true_label_value,
              batch_size=batch_size,
              image_save_path=image_save_path,
              model_save_path=model_save_path,
              save_image_every_step=save_image_every_step,
              save_model_every_epoch=save_model_every_epoch,
              show_image=show_image,
              load_model=load_model,
              model_load_path=model_load_path)
示例#3
0
def main():

    num_epoch = 40000
    pool_size = 20
    batch_size = 1
    oldpath = FLAGS.buckets
    picFpath = 'picF'
    picGpath = 'picG'
    useCopyfile = True

    if useCopyfile:
        trainfiles = ['picf1.zip', 'picf2.zip', 'picg1.zip']
        # trainfiles.extend(['picf3.zip','picf4.zip','picg2.zip'])

        print(trainfiles)

        for f in trainfiles:
            fn = utils.pai_copy(f, oldpath)
            utils.Unzip(fn)

        picFpath = os.path.join('temp', picFpath)
        picGpath = os.path.join('temp', picGpath)

    print(picFpath)
    print(picGpath)

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))

    cycle_gan = CycleGAN(X_train_file=picGpath,
                         Y_train_file=picFpath,
                         batch_size=batch_size,
                         image_size=(270, 480),
                         use_lsgan=True,
                         lossfunc='wgan',
                         norm='instance',
                         learning_rate=3e-3,
                         start_decay_step=5000,
                         decay_steps=350000
                         #optimizer = 'RMSProp'
                         )

    G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.build()
    optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.checkpointDir)
    saver = tf.train.Saver(max_to_keep=0)

    sess.run(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # save_path = saver.save(sess,os.path.join(FLAGS.checkpointDir,"model_pre.ckpt"))
    # print("Model saved in file: %s" % save_path)

    fake_Y_pool = ImagePool(pool_size)
    fake_X_pool = ImagePool(pool_size)
    print('start train')
    start_time = time.time()

    for step in range(1, num_epoch + 1):
        # get previously generated images
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        # train
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
            sess.run(
                [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                feed_dict={
                    cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                    cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                }))

        elapsed_time = time.time() - start_time
        start_time = time.time()

        if step % 25 == 0:
            print('G_loss : %s--D_Y_loss : %s--F_loss : %s--D_X_loss : %s--' %
                  (G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val))

            print('step : %s --elapsed_time : %s' % (step, elapsed_time))
            print('adding summary...')
            train_writer.add_summary(summary, step)
            train_writer.flush()

        # if step % 100 == 0:
        #     print('-----------Step %d:-------------' % step)
        #     print('  G_loss   : {}'.format(G_loss_val))
        #     print('  D_Y_loss : {}'.format(D_Y_loss_val))
        #     print('  F_loss   : {}'.format(F_loss_val))
        #     print('  D_X_loss : {}'.format(D_X_loss_val))

        if step % 1000 == 0:
            save_path = saver.save(sess,
                                   os.path.join(FLAGS.checkpointDir,
                                                "model.ckpt"),
                                   global_step=step,
                                   write_meta_graph=False)
            print("Model saved in file: %s" % save_path)

    coord.request_stop()
    coord.join(threads)