Exemple #1
0
def eval():
    """Restore the trained checkpoint and preview A->B / B->A translations.

    Imports the meta graph stored next to ``model_path`` and restores its
    weights, then loops forever: draws a random test batch, runs both
    generators and shows the first input/output pair of each direction in
    OpenCV windows, blocking on a key press between batches.

    Relies on the module-level ``model_path``, ``image_dir``,
    ``image_height``, ``image_width`` and ``batch_size`` settings.
    """
    with tf.Session() as sess:
        ckpt_path = model_path
        saver = tf.train.import_meta_graph(ckpt_path + '.meta')
        saver.restore(sess, ckpt_path)

        # Look the tensors up by the names they were given at training time.
        graph = tf.get_default_graph()
        input_A_place = graph.get_tensor_by_name('input_A:0')
        input_B_place = graph.get_tensor_by_name('input_B:0')
        keep_prob_place = graph.get_tensor_by_name('keep_prob:0')
        is_training = graph.get_tensor_by_name('is_training:0')

        A2B_output = graph.get_tensor_by_name('A2B_output:0')
        B2A_output = graph.get_tensor_by_name('B2A_output:0')

        dataLoader = Pix2Pix_loader(image_dir,
                                    image_height,
                                    image_width,
                                    batch_size=batch_size)

        while True:
            images_A, images_B = dataLoader.random_next_test_batch()

            # keep_prob is 1.0 so dropout is disabled at inference time;
            # feeding 0.5 here would randomly drop activations and make the
            # previews noisy and different on every run.
            _A2B_output = sess.run(A2B_output,
                                   feed_dict={
                                       input_A_place: images_A,
                                       is_training: False,
                                       keep_prob_place: 1.0
                                   })
            # Map the generator output to 0..255 for display
            # (assumes the output lies in [-1, 1] -- TODO confirm).
            _A2B_output = (_A2B_output + 1) / 2 * 255.0
            _A2B_output = _A2B_output.astype(np.uint8)

            _B2A_output = sess.run(B2A_output,
                                   feed_dict={
                                       input_B_place: images_B,
                                       is_training: False,
                                       keep_prob_place: 1.0
                                   })
            _B2A_output = (_B2A_output + 1) / 2 * 255.0
            _B2A_output = _B2A_output.astype(np.uint8)

            # NOTE(review): cv2.imshow interprets float arrays as [0, 1]; if
            # the loader yields float images in 0..255 the two input windows
            # need a uint8 cast -- confirm the loader's dtype.
            cv2.imshow("A", images_A[0])
            cv2.imshow("B", images_B[0])
            cv2.imshow("A2B_output", _A2B_output[0])
            cv2.imshow("B2A_output", _B2A_output[0])

            # Block until a key is pressed, then show the next batch.
            cv2.waitKey(0)
Exemple #2
0
def eval():
    """Load a serialized GraphDef and preview A->B / B->A translations.

    Parses the ``.pb`` file at ``model_path``, imports it into the session
    graph, then loops forever: draws a random test batch, runs both
    generators and shows the first input/output pair of each direction in
    OpenCV windows, blocking on a key press between batches.

    Relies on the module-level ``model_path``, ``image_dir``,
    ``image_height``, ``image_width`` and ``batch_size`` settings.
    """
    # The context manager guarantees the session is closed even when the
    # preview loop is interrupted by an exception or Ctrl-C.
    with tf.Session() as sess:
        with tf.gfile.FastGFile(model_path, "rb") as fr:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(fr.read())

        # as_default() only takes effect when entered as a context manager.
        with sess.graph.as_default():
            tf.import_graph_def(graph_def, name="")

        # No variable initializer is run: a GraphDef carries no variable
        # collections, so there is nothing for it to initialize.

        input_A_place = sess.graph.get_tensor_by_name('input_A:0')
        input_B_place = sess.graph.get_tensor_by_name('input_B:0')

        A2B_output = sess.graph.get_tensor_by_name('A2B_output:0')
        B2A_output = sess.graph.get_tensor_by_name('B2A_output:0')

        is_training = sess.graph.get_tensor_by_name('is_training:0')

        dataLoader = Pix2Pix_loader(image_dir,
                                    image_height,
                                    image_width,
                                    batch_size=batch_size)

        while True:
            images_A, images_B = dataLoader.random_next_test_batch()

            _A2B_output = sess.run(A2B_output,
                                   feed_dict={
                                       input_A_place: images_A,
                                       is_training: False
                                   })
            # Map the generator output to 0..255 for display
            # (assumes the output lies in [-1, 1] -- TODO confirm).
            _A2B_output = (_A2B_output + 1) / 2 * 255.0
            _A2B_output = _A2B_output.astype(np.uint8)

            _B2A_output = sess.run(B2A_output,
                                   feed_dict={
                                       input_B_place: images_B,
                                       is_training: False
                                   })
            _B2A_output = (_B2A_output + 1) / 2 * 255.0
            _B2A_output = _B2A_output.astype(np.uint8)

            cv2.imshow("A", images_A[0])
            cv2.imshow("B", images_B[0])
            cv2.imshow("A2B_output", _A2B_output[0])
            cv2.imshow("B2A_output", _B2A_output[0])
            # Block until a key is pressed, then show the next batch.
            cv2.waitKey(0)
Exemple #3
0
def eval():
    """Restore the checkpoint and plot input A, input B and the A->B result.

    Imports the meta graph stored next to ``model_path``, restores the
    weights, then loops forever: draws a random test batch, runs the A->B
    generator with dropout disabled and shows the first sample of A, B and
    the generated image side by side in a matplotlib figure.

    Relies on the module-level ``model_path``, ``image_dir``,
    ``image_height``, ``image_width`` and ``batch_size`` settings.
    """
    with tf.Session() as sess:
        restorer = tf.train.import_meta_graph(model_path + '.meta')
        restorer.restore(sess, model_path)

        # Fetch the tensors by the names assigned at training time.
        graph = tf.get_default_graph()
        source_tensor = graph.get_tensor_by_name('input_A:0')
        keep_prob_tensor = graph.get_tensor_by_name('keep_prob:0')
        training_flag = graph.get_tensor_by_name('is_training:0')
        generated_tensor = graph.get_tensor_by_name('A2B_output:0')

        loader = Pix2Pix_loader(image_dir,
                                image_height,
                                image_width,
                                batch_size=batch_size)

        while True:
            batch_A, batch_B = loader.random_next_test_batch()

            feeds = {
                source_tensor: batch_A,
                training_flag: False,
                keep_prob_tensor: 1.0,
            }
            generated = sess.run(generated_tensor, feed_dict=feeds)
            # Rescale to 0..255 and truncate to bytes for display.
            generated = ((generated + 1) / 2 * 255.0).astype(np.uint8)

            # One row, three columns: A, B, generated A->B.
            figure = plt.figure()
            panels = (
                (131, batch_A[0]),
                (132, batch_B[0]),
                (133, generated[0]),
            )
            for position, picture in panels:
                axis = figure.add_subplot(position)
                axis.imshow(np.uint8(picture))
            plt.show()
Exemple #4
0
def eval():
    """Load a serialized GraphDef and plot input A, input B and the A->B result.

    Parses the ``.pb`` file at ``model_path``, imports it into the session
    graph, then loops forever: draws a random test batch, runs the A->B
    generator with dropout disabled and shows the first sample of A, B and
    the generated image side by side in a matplotlib figure.

    Relies on the module-level ``model_path``, ``image_dir``,
    ``image_height``, ``image_width`` and ``batch_size`` settings.
    """
    # The context manager guarantees the session is closed even when the
    # preview loop is interrupted by an exception or Ctrl-C.
    with tf.Session() as sess:
        with tf.gfile.FastGFile(model_path, "rb") as fr:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(fr.read())

        # as_default() only takes effect when entered as a context manager.
        with sess.graph.as_default():
            tf.import_graph_def(graph_def, name="")

        # No variable initializer is run: a GraphDef carries no variable
        # collections, so there is nothing for it to initialize.

        input_A_place = sess.graph.get_tensor_by_name('input_A:0')

        A2B_output = sess.graph.get_tensor_by_name('A2B_output:0')
        keep_prob_place = sess.graph.get_tensor_by_name('keep_prob:0')
        is_training = sess.graph.get_tensor_by_name('is_training:0')

        dataLoader = Pix2Pix_loader(image_dir,
                                    image_height,
                                    image_width,
                                    batch_size=batch_size)

        while True:
            images_A, images_B = dataLoader.random_next_test_batch()

            _A2B_output = sess.run(A2B_output,
                                   feed_dict={
                                       input_A_place: images_A,
                                       is_training: False,
                                       keep_prob_place: 1.0
                                   })
            # Map the generator output to 0..255 for display
            # (assumes the output lies in [-1, 1] -- TODO confirm).
            _A2B_output = (_A2B_output + 1) / 2 * 255.0
            _A2B_output = _A2B_output.astype(np.uint8)

            # One row, three columns: A, B, generated A->B.
            fig = plt.figure()
            ax1 = fig.add_subplot(131)
            ax1.imshow(np.uint8(images_A[0]))
            ax2 = fig.add_subplot(132)
            ax2.imshow(np.uint8(images_B[0]))
            ax3 = fig.add_subplot(133)
            ax3.imshow(np.uint8(_A2B_output[0]))

            plt.show()
Exemple #5
0
def train():
    """Build the CycleGAN training graph and run the training loop.

    Steps performed:
      * create placeholders for both image domains, for the two fake-image
        pools and for the ``is_training`` flag;
      * build the generator / discriminator losses via
        ``CycleGAN.build_CycleGAN`` and one Adam optimizer per side;
      * keep the learning rate at ``starter_learning_rate`` until
        ``start_decay_step``, then decay it polynomially (power 1.0) towards
        ``end_learning_rate`` over ``decay_steps``;
      * restore the latest checkpoint found in ``ckpt_path``, if any;
      * per step: compute fresh fakes, pass them through the image pools,
        run one discriminator update and two generator updates, print the
        losses;
      * every 100 steps write sample translations to ``result_dir``;
      * every 100000 steps export a constant-folded ``.pb`` to ``pb_path``
        and save a checkpoint.

    Relies on module-level settings (``image_height``, ``image_width``,
    ``lambda_reconst``, ``pool_size``, ``Train_Step``, ``sample_num``,
    ``model_name`` and the path/learning-rate names mentioned above).
    """
    # Graph inputs. The names are the handles used to look the tensors up
    # again after the graph has been exported.
    input_A_place = tf.placeholder(tf.float32,
                                   shape=[None, image_height, image_width, 3],
                                   name="input_A")
    input_B_place = tf.placeholder(tf.float32,
                                   shape=[None, image_height, image_width, 3],
                                   name="input_B")
    # Fake images fed back in from the Python-side image pools.
    fake_pool_A_place = tf.placeholder(
        tf.float32,
        shape=[None, image_height, image_width, 3],
        name="fake_pool_A")
    fake_pool_B_place = tf.placeholder(
        tf.float32,
        shape=[None, image_height, image_width, 3],
        name="fake_pool_B")
    is_training_place = tf.placeholder(tf.bool, shape=(), name="is_training")

    cycleGAN = CycleGAN(is_training_place, lambda_reconst)

    Gen_AB_loss, Gen_BA_loss, Dis_A_loss, Dis_B_loss,fake_A,fake_B= \
        cycleGAN.build_CycleGAN(input_A_place,input_B_place,
                                fake_pool_A_place,fake_pool_B_place)

    gen_A2B_vars, gen_B2A_vars, dis_A_vars, dis_B_vars = cycleGAN.get_vars()
    # Starts at -1 so that the first increment yields step 0.
    global_step = tf.Variable(-1, trainable=False, name="global_step")
    global_step_increase = tf.assign(global_step, tf.add(global_step, 1))

    # Constant learning rate before start_decay_step, polynomial decay after.
    learning_rate = (tf.where(
        tf.greater_equal(global_step, start_decay_step),
        tf.train.polynomial_decay(starter_learning_rate,
                                  global_step - start_decay_step,
                                  decay_steps,
                                  end_learning_rate,
                                  power=1.0), starter_learning_rate))
    # Do not use `with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS))`
    # to update the batch-norm parameters here, because tf.GraphKeys.UPDATE_OPS
    # contains the batch-norm updates of both the generators and the discriminators.
    train_op_G = tf.train.AdamOptimizer(learning_rate, beta1=0.5, ). \
        minimize(Gen_AB_loss+Gen_BA_loss, var_list=gen_A2B_vars+gen_B2A_vars)
    train_op_D = tf.train.AdamOptimizer(learning_rate, beta1=0.5, ). \
        minimize(Dis_A_loss+Dis_B_loss, var_list=dis_A_vars+dis_B_vars)

    # Sampling outputs; tf.identity gives them stable names so they can be
    # used as output nodes when the graph is exported below.
    A2B_out, ABA_out = cycleGAN.sample_generate(input_A_place, "A2B")
    A2B_output = tf.identity(A2B_out, name="A2B_output")
    B2A_out, BAB_out = cycleGAN.sample_generate(input_B_place, "B2A")
    B2A_output = tf.identity(B2A_out, name="B2A_output")

    # Pools of generated images that are fed to the discriminators
    # (presumably a history buffer of past fakes -- see ImagePool.query).
    fake_A_pool = ImagePool(pool_size)
    fake_B_pool = ImagePool(pool_size)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Resume from the latest checkpoint in ckpt_path when one exists.
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(ckpt_path, ckpt_name))

        # Advance the step counter once before the first iteration.
        _global_step = sess.run(global_step_increase)
        dataLoader = Pix2Pix_loader(image_dir,
                                    image_height,
                                    image_width,
                                    batch_size=batch_size,
                                    global_step=_global_step)
        while _global_step < Train_Step:
            images_A, images_B = dataLoader.next_batch()

            # First pass: only generate the fake images for this batch.
            feed_dict_pool = {
                input_A_place: images_A,
                input_B_place: images_B,
                is_training_place: True
            }

            fake_A_vals, fake_B_vals = sess.run([fake_A, fake_B],
                                                feed_dict=feed_dict_pool)

            # Second feed: same batch plus whatever the pools return for the
            # freshly generated fakes.
            feed_dict_train = {
                input_A_place: images_A,
                input_B_place: images_B,
                is_training_place: True,
                fake_pool_A_place: fake_A_pool.query(fake_A_vals),
                fake_pool_B_place: fake_B_pool.query(fake_B_vals)
            }

            # One discriminator update followed by two generator updates.
            sess.run(train_op_D, feed_dict=feed_dict_train)
            sess.run(train_op_G, feed_dict=feed_dict_train)
            sess.run(train_op_G, feed_dict=feed_dict_train)


            # Losses are evaluated in a separate run, after the updates above.
            _Gen_AB_loss, _Gen_BA_loss,_Dis_A_loss, _Dis_B_loss\
                = sess.run([Gen_AB_loss, Gen_BA_loss, Dis_A_loss, Dis_B_loss],feed_dict=feed_dict_train)

            # Losses are printed on every step (a "% 50" throttle was disabled).
            print(
                "Step:{},Gen_AB_loss:{},Gen_BA_loss:{},Dis_A_loss:{},Dis_B_loss:{}"
                .format(
                    _global_step,
                    _Gen_AB_loss,
                    _Gen_BA_loss,
                    _Dis_A_loss,
                    _Dis_B_loss,
                ))

            # Every 100 steps: dump sample translations for visual inspection.
            if _global_step % 100 == 0:
                test_images_A, test_images_B = dataLoader.random_next_train_batch(
                )

                # Save results from A to B (input, translation, reconstruction).
                _A2B_output, _ABA_out = sess.run([A2B_output, ABA_out],
                                                 feed_dict={
                                                     input_A_place:
                                                     test_images_A,
                                                     is_training_place: False
                                                 })
                # Rescale to 0..255 (assumes outputs in [-1, 1] -- TODO confirm).
                _A2B_output = (_A2B_output + 1) / 2 * 255.0
                _ABA_out = (_ABA_out + 1) / 2 * 255.0
                for ind, trg_image in enumerate(_A2B_output[:sample_num]):
                    scipy.misc.imsave(
                        result_dir + "/{}_{}_A.jpg".format(_global_step, ind),
                        test_images_A[ind])
                    scipy.misc.imsave(
                        result_dir +
                        "/{}_{}_A2B.jpg".format(_global_step, ind),
                        _A2B_output[ind])
                    scipy.misc.imsave(
                        result_dir +
                        "/{}_{}_ABA.jpg".format(_global_step, ind),
                        _ABA_out[ind])

                # Save results from B to A (input, translation, reconstruction).
                _B2A_output, _BAB_out = sess.run([B2A_output, BAB_out],
                                                 feed_dict={
                                                     input_B_place:
                                                     test_images_B,
                                                     is_training_place: False
                                                 })
                _B2A_output = (_B2A_output + 1) / 2 * 255.0
                _BAB_out = (_BAB_out + 1) / 2 * 255.0
                for ind, trg_image in enumerate(_B2A_output[:sample_num]):
                    scipy.misc.imsave(
                        result_dir + "/{}_{}_B.jpg".format(_global_step, ind),
                        test_images_B[ind])
                    scipy.misc.imsave(
                        result_dir +
                        "/{}_{}_B2A.jpg".format(_global_step, ind),
                        _B2A_output[ind])
                    scipy.misc.imsave(
                        result_dir +
                        "/{}_{}_BAB.jpg".format(_global_step, ind),
                        _BAB_out[ind])

            # Every 100000 steps (including step 0): export the model.
            if _global_step % 100000 == 0:
                # Save the PB: variables are converted to constants, keeping
                # the two named generator outputs as output nodes.
                constant_graph = graph_util.convert_variables_to_constants(
                    sess, sess.graph_def, ["A2B_output", "B2A_output"])
                save_model_name = model_name + "-" + str(_global_step) + ".pb"
                # NOTE: plain string concatenation -- pb_path is expected to
                # end with a path separator.
                with tf.gfile.FastGFile(pb_path + save_model_name,
                                        mode="wb") as fw:
                    fw.write(constant_graph.SerializeToString())
                # Save the CKPT (ckpt_path is concatenated the same way).
                saver.save(sess,
                           ckpt_path + model_name + ".ckpt",
                           global_step=_global_step)
                print("Successfully saved model {}".format(save_model_name))
                # (an early `return` after saving was disabled here)
            # Advance the step counter for the next iteration.
            _global_step = sess.run(global_step_increase)
Exemple #6
0
def train():
    """Build the DiscoGAN training graph, train it and export the model.

    Steps performed:
      * create placeholders for both image domains, the ``is_training`` flag
        (defaults to False) and the reconstruction rate;
      * build the generator / discriminator losses via
        ``DiscoGAN.build_DiscoGAN`` and one Adam optimizer per side;
      * restore the latest checkpoint found in ``ckpt_path``, if any;
      * per step: feed ``starting_rate`` as reconstruction rate before step
        10000 and ``change_rate`` afterwards, update the discriminator on
        even steps only and the generator on every step;
      * every 50 steps print the losses, every 100 steps write sample
        translations to ``result_dir``;
      * at step ``Train_Step - 5`` export a constant-folded ``.pb`` to
        ``pb_path``, save a checkpoint and return.

    Relies on module-level settings (``image_height``, ``image_width``,
    ``learning_rate``, ``Train_Step``, ``sample_num``, ``model_name`` and
    the names mentioned above).
    """
    # Graph inputs; the names are the handles used after export.
    input_A_place = tf.placeholder(
        tf.float32, shape=[None, image_height, image_width, 3], name="input_A")
    input_B_place = tf.placeholder(
        tf.float32, shape=[None, image_height, image_width, 3], name="input_B")
    is_training_place = tf.placeholder_with_default(
        False, shape=(), name="is_training")
    reconst_rate_place = tf.placeholder(
        tf.float32, shape=(), name="reconst_rate")

    model = DiscoGAN(is_training_place, reconst_rate_place)
    G_loss, D_loss = model.build_DiscoGAN(input_A_place, input_B_place)
    g_vars, d_vars = model.get_vars()

    # Starts at -1 so that the first increment yields step 0.
    global_step = tf.Variable(-1, trainable=False, name="global_step")
    global_step_increase = tf.assign(global_step, tf.add(global_step, 1))

    # Discriminator op is built before the generator op, as in training
    # checkpoints produced so far.
    optimizer_D = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
    train_op_D = optimizer_D.minimize(D_loss, var_list=d_vars)
    optimizer_G = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
    train_op_G = optimizer_G.minimize(G_loss, var_list=g_vars)

    # Named sampling outputs, used as output nodes when exporting the PB.
    A2B_out, ABA_out = model.sample_generate(input_A_place, "A2B")
    A2B_output = tf.identity(A2B_out, name="A2B_output")
    B2A_out, BAB_out = model.sample_generate(input_B_place, "B2A")
    B2A_output = tf.identity(B2A_out, name="B2A_output")

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Resume from the latest checkpoint in ckpt_path when one exists.
        latest = tf.train.get_checkpoint_state(ckpt_path)
        if latest and latest.model_checkpoint_path:
            latest_name = os.path.basename(latest.model_checkpoint_path)
            saver.restore(sess, os.path.join(ckpt_path, latest_name))

        step = sess.run(global_step_increase)
        loader = Pix2Pix_loader(image_dir,
                                image_height,
                                image_width,
                                batch_size=batch_size,
                                global_step=step)

        while step < Train_Step:
            reconst_rate = starting_rate if step < 10000 else change_rate

            # Original note on this call: images are in 0~255.
            images_A, images_B = loader.next_batch()
            feeds = {
                input_A_place: images_A,
                input_B_place: images_B,
                is_training_place: True,
                reconst_rate_place: reconst_rate,
            }

            # Discriminator on even steps only; generator on every step.
            if step % 2 == 0:
                sess.run(train_op_D, feed_dict=feeds)
            sess.run(train_op_G, feed_dict=feeds)
            step, d_loss_value, g_loss_value = sess.run(
                [global_step, D_loss, G_loss], feed_dict=feeds)

            if step % 50 == 0:
                print("Step:{},Reconst_rate:{},D_loss:{},G_loss:{}".format(
                    step, reconst_rate, d_loss_value, g_loss_value))

            if step % 100 == 0:
                test_A, test_B = loader.random_next_test_batch()

                # Save results from A to B (input, translation, reconstruction).
                # is_training is not fed, so its default (False) applies.
                fake_B, cycle_A = sess.run([A2B_output, ABA_out],
                                           feed_dict={input_A_place: test_A})
                fake_B = (fake_B + 1) / 2 * 255.0
                cycle_A = (cycle_A + 1) / 2 * 255.0
                for ind in range(len(fake_B[:sample_num])):
                    scipy.misc.imsave(
                        result_dir + "/{}_{}_A.jpg".format(step, ind),
                        test_A[ind])
                    scipy.misc.imsave(
                        result_dir + "/{}_{}_A2B.jpg".format(step, ind),
                        fake_B[ind])
                    scipy.misc.imsave(
                        result_dir + "/{}_{}_ABA.jpg".format(step, ind),
                        cycle_A[ind])

                # Save results from B to A (input, translation, reconstruction).
                fake_A, cycle_B = sess.run([B2A_output, BAB_out],
                                           feed_dict={input_B_place: test_B})
                fake_A = (fake_A + 1) / 2 * 255.0
                cycle_B = (cycle_B + 1) / 2 * 255.0
                for ind in range(len(fake_A[:sample_num])):
                    scipy.misc.imsave(
                        result_dir + "/{}_{}_B.jpg".format(step, ind),
                        test_B[ind])
                    scipy.misc.imsave(
                        result_dir + "/{}_{}_B2A.jpg".format(step, ind),
                        fake_A[ind])
                    scipy.misc.imsave(
                        result_dir + "/{}_{}_BAB.jpg".format(step, ind),
                        cycle_B[ind])

            if step == Train_Step - 5:
                # Save the PB: variables are converted to constants, keeping
                # the two named generator outputs as output nodes.
                frozen_graph = graph_util.convert_variables_to_constants(
                    sess, sess.graph_def, ["A2B_output", "B2A_output"])
                save_model_name = model_name + "-" + str(step) + ".pb"
                with tf.gfile.FastGFile(pb_path + save_model_name,
                                        mode="wb") as fw:
                    fw.write(frozen_graph.SerializeToString())
                # Save the CKPT.
                saver.save(sess,
                           ckpt_path + model_name + ".ckpt",
                           global_step=step)
                print("Successfully saved model {}".format(save_model_name))
                return

            step = sess.run(global_step_increase)