Example 1
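This example trains a conditional WGAN variant on MNIST: besides the Wasserstein critic score, the discriminator predicts the class label and reconstructs a 2-D continuous code, InfoGAN-style.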
def main(_):
    X_image = tf.placeholder(tf.float32, [None, FLAGS.Out_DIm])
    y_label_index = tf.placeholder(tf.int32, [None])
    y_label = tf.one_hot(y_label_index, FLAGS.n_class)

    # Latent inputs: z is the main noise vector and con is a 2-D continuous
    # code the discriminator will be asked to reconstruct.
    con = tf.random_normal([FLAGS.batch_size, 2])
    z = tf.random_normal([FLAGS.batch_size, FLAGS.z_dim])
    G_image = Generator(z, con, labels=y_label)

    disc_real, real_class, _, _class_r = Discriminator(X_image)
    disc_fake, fake_class, con_fake, _class_f = Discriminator(G_image, True)

    class_label_real = tf.argmax(_class_r, 1)
    class_label_fake = tf.argmax(_class_f, 1)
    gen_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope="Generator")
    disc_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope="Discriminator")

    FSR_cost = 0
    # ******** feasible set reduction ********
    # NOTE: FSR_cost is computed here but, as written, is never added to
    # either loss below.
    if FLAGS.is_fsr:
        reduce_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
        FSR_cost = tf.nn.relu(reduce_cost)
    # ****************************************
    class_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y_label,
                                                logits=real_class))
    class_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y_label,
                                                logits=fake_class))
    con_cost = tf.reduce_mean(tf.square(con_fake - con))

    # WGAN objectives plus weighted auxiliary terms: the classification losses
    # tie both networks to the class label, and con_cost (an L2 reconstruction
    # of the continuous code) follows InfoGAN's treatment of continuous codes.
    gen_cost = -tf.reduce_mean(disc_fake) + 20 * (class_loss_fake + con_cost)
    disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(
        disc_real) + 20 * (class_loss_fake + class_loss_real + con_cost)

    # A single global_step is shared by both optimizers below, so the decayed
    # learning rate advances on generator and discriminator updates alike.
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(1e-3,
                                               global_step,
                                               200,
                                               0.96,
                                               staircase=True)

    # Three alternative ways to constrain the critic: hard weight clipping
    # (original WGAN), l2 regularization, or a gradient penalty (WGAN-GP).
    clip_ops = []
    if FLAGS.is_clip:
        clip_bound = [-.01, .01]
        for v in disc_params:
            clip_ops.append(
                tf.assign(v, tf.clip_by_value(v, clip_bound[0],
                                              clip_bound[1])))

        clip_weight_clip = tf.group(*clip_ops)

    elif FLAGS.is_l2:
        for v in disc_params:
            tf.add_to_collection(
                "loss", tf.multiply(tf.nn.l2_loss(v), FLAGS.l2_regular))
    elif FLAGS.is_gp:
        # WGAN-GP: penalize the critic's gradient norm at random points
        # interpolated between real and generated samples.
        alpha = tf.random_uniform(shape=[FLAGS.batch_size, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_image - X_image
        interpolates = X_image + (alpha * differences)
        source_logit, class_logit, _, _ = Discriminator(interpolates,
                                                        reuse=True)
        gradients = tf.gradients(source_logit, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
        disc_cost += 10 * gradient_penalty

    tf.add_to_collection("loss", disc_cost)
    dis_losses = tf.add_n(tf.get_collection_ref("loss"))

    gen_train = tf.train.AdamOptimizer(
        learning_rate, beta1=0.5, beta2=0.9).minimize(gen_cost,
                                                      global_step=global_step,
                                                      var_list=gen_params)
    disc_train = tf.train.AdamOptimizer(
        learning_rate, beta1=0.5, beta2=0.9).minimize(dis_losses,
                                                      global_step=global_step,
                                                      var_list=disc_params)

    # Fixed noise, class labels, and continuous codes so the saved sample
    # grids are comparable across training steps.
    tensor_noise = tf.constant(
        np.random.normal(size=(128, 128)).astype('float32'))
    label = [1 for i in range(128)]  # render every saved sample as class 1
    label_tensor = tf.one_hot(np.array(label), FLAGS.n_class)
    con_tensor = tf.constant(np.random.normal(size=(128, 2)).astype('float32'))
    gen_save_image = Generator(tensor_noise,
                               con=con_tensor,
                               labels=label_tensor,
                               reuse=True,
                               nums=128)
    _, _, class_gen_label, _ = Discriminator(gen_save_image, reuse=True)
    # NOTE: the third return value is treated as class logits here, while
    # above it was unpacked as the continuous code (con_fake); check
    # Discriminator's return order.
    gen_label = tf.argmax(class_gen_label, 1)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        gen = inf_train_gen()
        for i in range(FLAGS.iter_range):
            start_time = time.time()
            data_x, data_y = next(gen)
            if i > 0:
                _genc, _ = sess.run([gen_cost, gen_train],
                                    feed_dict={
                                        X_image: data_x,
                                        y_label_index: data_y
                                    })

            # Several critic updates per generator update, as in WGAN training.
            for x in range(FLAGS.disc_inter):
                _disc, _class_real, _class_fake, _ = sess.run(
                    [disc_cost, class_loss_real, class_loss_fake, disc_train],
                    feed_dict={
                        X_image: data_x,
                        y_label_index: data_y
                    })

            if i > 0:
                plot.plot("Discriminator", _disc)
                plot.plot('time', time.time() - start_time)
            # Apply WGAN weight clipping after the critic updates (no-op
            # unless is_clip is set).
            if FLAGS.is_clip:
                sess.run(clip_weight_clip)

            if i % 100 == 99:
                image = sess.run(gen_save_image)
                save_images.save_images(image.reshape((128, 28, 28)),
                                        "./gen_image_{}.png".format(i))
                gen_label_ = sess.run(gen_label)
                val_dis_list = []
                print "true_label:"
                print data_y
                print "class_real:"
                print sess.run(class_label_real,
                               feed_dict={
                                   X_image: data_x,
                                   y_label_index: data_y
                               })
                print "class_fake"
                print sess.run(class_label_fake,
                               feed_dict={
                                   X_image: data_x,
                                   y_label_index: data_y
                               })
                print "class_gen:"
                print gen_label_
                print "con_cost:"
                print sess.run(con_cost,
                               feed_dict={
                                   X_image: data_x,
                                   y_label_index: data_y
                               })
                print ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
                for images_, val_label in dev_data():
                    _dev_disc_cost = sess.run(disc_cost,
                                              feed_dict={
                                                  X_image: images_,
                                                  y_label_index: val_label
                                              })
                    val_dis_list.append(_dev_disc_cost)
                plot.plot("val_cost", np.mean(val_dis_list))
            if i < 5 or i % 100 == 99:
                plot.flush()

            plot.tick()
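Both examples rely on repo-local helpers not shown on this page: inf_train_gen and dev_data stream MNIST batches, while plot and save_images handle logging and image grids. Below is a minimal sketch of what the data helpers could look like, assuming the TF1 MNIST reader and the (images, integer-labels) batches Example 1 unpacks; Example 2's inf_train_gen evidently yields images only, and the original repo's versions may differ.

from tensorflow.examples.tutorials.mnist import input_data

def inf_train_gen():
    # Yield (images, integer labels) training batches forever.
    mnist = input_data.read_data_sets("../data", one_hot=False)
    while True:
        yield mnist.train.next_batch(FLAGS.batch_size)

def dev_data():
    # Yield (images, integer labels) batches covering the validation set once.
    mnist = input_data.read_data_sets("../data", one_hot=False)
    for _ in range(mnist.validation.num_examples // FLAGS.batch_size):
        yield mnist.validation.next_batch(FLAGS.batch_size)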
Example 2
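This example is the unconditional baseline: a plain WGAN on MNIST with the same choice of Lipschitz constraints (weight clipping, l2 regularization, or gradient penalty) for the critic.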
def main(_):
    X_image = tf.placeholder(tf.float32, [None, FLAGS.Out_DIm])

    z = tf.random_normal([FLAGS.batch_size, FLAGS.z_dim])
    G_image = Generator(z)

    disc_real = Discriminator(X_image)
    disc_fake = Discriminator(G_image, True)

    gen_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope="Generator")
    disc_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope="Discriminator")

    FSR_cost = 0
    # ******** feasible set reduction ********
    # NOTE: as in Example 1, FSR_cost is computed but never added to a loss.
    if FLAGS.is_fsr:
        reduce_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
        FSR_cost = tf.nn.relu(reduce_cost)
    # ****************************************

    # Plain WGAN objectives: the critic maximizes the real-fake score gap and
    # the generator maximizes the critic's score on its samples.
    gen_cost = -tf.reduce_mean(disc_fake)
    disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

    # A single global_step is shared by both optimizers below, so the decayed
    # learning rate advances on generator and discriminator updates alike.
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(1e-3,
                                               global_step,
                                               200,
                                               0.96,
                                               staircase=True)

    # Three alternative ways to constrain the critic: hard weight clipping
    # (original WGAN), l2 regularization, or a gradient penalty (WGAN-GP).
    # gp_cost is defined unconditionally so dis_losses below stays valid when
    # is_gp is off.
    gp_cost = tf.constant(0.)
    clip_ops = []
    if FLAGS.is_clip:
        clip_bound = [-.01, .01]
        for v in disc_params:
            clip_ops.append(
                tf.assign(v, tf.clip_by_value(v, clip_bound[0],
                                              clip_bound[1])))

        clip_weight_clip = tf.group(*clip_ops)

    elif FLAGS.is_l2:
        for v in disc_params:
            tf.add_to_collection(
                "loss", tf.multiply(tf.nn.l2_loss(v), FLAGS.l2_regular))
    elif FLAGS.is_gp:
        # WGAN-GP: penalize the critic's gradient norm at random points
        # interpolated between real and generated samples.
        alpha = tf.random_uniform(shape=[FLAGS.batch_size, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_image - X_image
        interpolates = X_image + (alpha * differences)
        gradients = tf.gradients(Discriminator(interpolates, reuse=True),
                                 [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
        gp_cost = 10 * gradient_penalty

    # Fold the critic loss and optional gradient penalty into the "loss"
    # collection so the l2 terms registered above are included when is_l2 is
    # set.
    tf.add_to_collection("loss", disc_cost + gp_cost)
    dis_losses = tf.add_n(tf.get_collection("loss"))
    gen_train = tf.train.AdamOptimizer(
        learning_rate, beta1=0.5, beta2=0.9).minimize(gen_cost,
                                                      global_step=global_step,
                                                      var_list=gen_params)
    disc_train = tf.train.AdamOptimizer(
        learning_rate, beta1=0.5, beta2=0.9).minimize(dis_losses,
                                                      global_step=global_step,
                                                      var_list=disc_params)

    # Fixed noise so the saved sample grids are comparable across steps.
    tensor_noise = tf.constant(
        np.random.normal(size=(128, 128)).astype('float32'))
    gen_save_image = Generator(tensor_noise, reuse=True, nums=128)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        gen = inf_train_gen()
        for i in range(FLAGS.iter_range):
            start_time = time.time()
            data = next(gen)
            if i > 0:
                _genc, _ = sess.run([gen_cost, gen_train],
                                    feed_dict={X_image: data})

            # Several critic updates per generator update, as in WGAN training.
            for x in range(FLAGS.disc_inter):
                _disc, _ = sess.run([disc_cost, disc_train],
                                    feed_dict={X_image: data})
            if i > 0:
                D_real, D_fake = sess.run([disc_real, disc_fake],
                                          feed_dict={X_image: data})
                plot.plot("Discriminator", _disc)
                plot.plot("D_real", np.mean(D_real))
                plot.plot("D_fake", np.mean(D_fake))
                plot.plot('time', time.time() - start_time)
            # Apply WGAN weight clipping after the critic updates (no-op
            # unless is_clip is set).
            if FLAGS.is_clip:
                sess.run(clip_weight_clip)
            if i % 1000 == 999:
                print("***************************************************")
                print("D:")
                print(D_real - D_fake)
                print("****************************************************")
            if i % 100 == 99:
                print("gp_cost:")
                print(sess.run(gp_cost, feed_dict={X_image: data}))
            if i % 100 == 99:
                image = sess.run(gen_save_image)
                save_images.save_images(image.reshape((128, 28, 28)),
                                        "./gen_image_{}.png".format(i))
                val_dis_list = []
                for images_, _ in dev_data():
                    _dev_disc_cost = sess.run(disc_cost,
                                              feed_dict={X_image: images_})
                    val_dis_list.append(_dev_disc_cost)
                plot.plot("val_cost", np.mean(val_dis_list))
            if i < 5 or i % 100 == 99:
                plot.flush()

            plot.tick()
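Both examples follow the tf.app.run convention, so the surrounding module must define the flags they read before main(_) runs. A minimal sketch of that scaffolding follows; the flag names match the FLAGS accesses above, but every default value here is an assumption, not taken from the original repo.

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("batch_size", 50, "Samples per training batch (assumed).")
flags.DEFINE_integer("z_dim", 128, "Dimensionality of the noise vector z (assumed).")
flags.DEFINE_integer("Out_DIm", 784, "Flattened image size, 28*28 for MNIST (assumed).")
flags.DEFINE_integer("n_class", 10, "Number of class labels (assumed).")
flags.DEFINE_integer("iter_range", 20000, "Total training iterations (assumed).")
flags.DEFINE_integer("disc_inter", 5, "Critic updates per generator update (assumed).")
flags.DEFINE_float("l2_regular", 1e-4, "Weight for the l2 penalty (assumed).")
flags.DEFINE_boolean("is_fsr", False, "Enable the feasible-set-reduction term.")
flags.DEFINE_boolean("is_clip", False, "Enable WGAN weight clipping.")
flags.DEFINE_boolean("is_l2", False, "Enable l2 regularization of the critic.")
flags.DEFINE_boolean("is_gp", True, "Enable the WGAN-GP gradient penalty.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
    tf.app.run()  # parses flags, then calls main(_)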