Exemplo n.º 1
0
def train():
    """Train the domain-conditioned GAN (L2-loss variant) from queue-fed data.

    Builds the graph from ``Holder()`` placeholders, ``gan.inference``,
    ``gan.sampler``, ``gan.loss_l2`` and ``gan.train``. It loads two datasets
    with ``data.Data().load`` (``FLAGS.data_path1`` / ``FLAGS.data_path2``,
    ``FLAGS.half_batch_size`` each). Then, for ``FLAGS.max_steps`` steps, it
    runs one discriminator update followed by one generator update.

    Every 100 steps the losses and regularizer values are printed. Every
    1000 steps a grid of generated samples is saved under ``./samples/`` and
    the current real batch under ``./samples_real/``.

    The queue-runner threads started for the input pipeline are registered
    with a ``tf.train.Coordinator``. When training finishes or raises, they
    are told to stop and are joined, and the session is closed.
    """

    def paired_noise():
        # One half-batch of uniform noise in [-1, 1), duplicated so the two
        # halves of the batch share the same z (presumably one half per
        # domain -- confirm against GenValsForHolder_TF).
        z_half = np.random.uniform(
            -1, 1, [FLAGS.half_batch_size, gan.Z_DIM]).astype(np.float32)
        return np.concatenate((z_half, z_half), 0)

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # Input placeholders; labels_h is not used by this training loop.
        z_h, images_h, labels_h, domain_labels_h = Holder()

        # Inference: D logits for the D and G objectives plus the G/D
        # regularization terms.
        D_logits_real, D_logits_fake, D_logits_fake_for_G, reg_g, reg_d = \
            gan.inference(images_h, domain_labels_h, z_h)

        sampler = gan.sampler(z_h, domain_labels_h)

        # Loss: here the regularizer weights are fixed values from FLAGS.
        G_loss, D_loss = gan.loss_l2(D_logits_real, D_logits_fake,
                                     D_logits_fake_for_G, reg_g, reg_d,
                                     FLAGS.weight_G, FLAGS.weight_D)

        # Train ops; C_vars is returned by GetVars but unused here.
        G_vars, D_vars, C_vars = GetVars()
        G_train_op, D_train_op = gan.train(G_loss, D_loss, G_vars, D_vars,
                                           global_step)

        data_set = data.Data().load(FLAGS.data_path1, FLAGS.half_batch_size)
        data_set2 = data.Data().load(FLAGS.data_path2, FLAGS.half_batch_size)

        sess = sess_init()
        # A Coordinator lets us signal the queue-runner threads to stop and
        # wait for them, instead of leaving them running after training.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # NOTE(review): the saver is created but never used, so this function
        # writes no checkpoint.
        saver = tf.train.Saver()

        try:
            for step in xrange(FLAGS.max_steps):
                _, images_v, domain_labels_v = \
                    GenValsForHolder_TF(data_set, data_set2, sess)
                z_for_G_v = paired_noise()

                # Discriminator update: feeds real images and noise.
                _, errD, reg_d_v = sess.run(
                    [D_train_op, D_loss, reg_d],
                    feed_dict={z_h: z_for_G_v, images_h: images_v,
                               domain_labels_h: domain_labels_v})

                # Generator update with the same noise and domain labels.
                _, errG, reg_g_v = sess.run(
                    [G_train_op, G_loss, reg_g],
                    feed_dict={z_h: z_for_G_v,
                               domain_labels_h: domain_labels_v})

                if step % 100 == 0:
                    print("step = %d, errD = %f, errG = %f reg_g = %f reg_d = %f"
                          % (step, errD, errG, reg_g_v, reg_d_v))

                if step % 1000 == 0:
                    # Fresh noise for the sample grid; domain labels are those
                    # of the current training batch.
                    samples = sess.run(
                        sampler,
                        feed_dict={z_h: paired_noise(),
                                   domain_labels_h: domain_labels_v})
                    save_images(samples, [2, FLAGS.half_batch_size],
                                './samples/train_{:d}.png'.format(step))
                    save_images(images_v, [2, FLAGS.half_batch_size],
                                './samples_real/train_{:d}.png'.format(step))
        finally:
            # Stop the input threads, wait for them, then release the session.
            coord.request_stop()
            try:
                coord.join(threads)
            finally:
                sess.close()
Exemplo n.º 2
0
def train():
    """Train the domain-conditioned GAN with regularizer weights fed at run time.

    The D and G regularizer weights are fed through placeholders:
    ``FLAGS.weight_D`` on the discriminator update and ``FLAGS.weight_G`` on
    the generator update. Training data come from ``data1.DATA`` and
    ``data2.DATA`` (``FLAGS.batch_size`` each). Each step runs one
    discriminator update, then one generator update.

    Every 100 steps the losses and regularizer values are printed. Every
    1000 steps a sample grid is generated and saved as
    ``./samples/train_<step>.bmp``. In the grid, both halves share the same
    noise; the first half is labelled domain 0 and the second half domain 1.
    """
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # Placeholders (labels_h is not used by this loop).
        z_h, images_h, labels_h, domain_labels_h = Holder()
        reg_D_weight_h = tf.placeholder(tf.float32)
        reg_G_weight_h = tf.placeholder(tf.float32)

        # Network outputs.
        (D_logits_real, D_logits_fake, D_logits_fake_for_G,
         reg_g, reg_d) = gan.inference(images_h, domain_labels_h, z_h)
        sampler = gan.sampler(z_h, domain_labels_h)

        # Losses take the regularizer weights as graph inputs.
        G_loss, D_loss = gan.loss(
            D_logits_real, D_logits_fake, D_logits_fake_for_G,
            reg_g, reg_d, reg_G_weight_h, reg_D_weight_h)

        # Optimizer ops (C_vars is unused here).
        G_vars, D_vars, C_vars = GetVars()
        G_train_op, D_train_op = gan.train(
            G_loss, D_loss, G_vars, D_vars, global_step)

        first_data = data1.DATA(FLAGS.batch_size)
        first_data.load()
        second_data = data2.DATA(FLAGS.batch_size)
        second_data.load()

        sess = sess_init()
        # Created but not used within this function.
        saver = tf.train.Saver()

        for step in xrange(FLAGS.max_steps):
            # The first returned value (plain z) is not needed for training.
            _, noise, batch_images, batch_domains = GenValsForHolder(
                first_data, second_data)

            d_feed = {z_h: noise,
                      images_h: batch_images,
                      domain_labels_h: batch_domains,
                      reg_D_weight_h: FLAGS.weight_D}
            _, d_err, d_reg = sess.run([D_train_op, D_loss, reg_d],
                                       feed_dict=d_feed)

            g_feed = {z_h: noise,
                      domain_labels_h: batch_domains,
                      reg_G_weight_h: FLAGS.weight_G}
            _, g_err, g_reg = sess.run([G_train_op, G_loss, reg_g],
                                       feed_dict=g_feed)

            if step % 100 == 0:
                print("step = %d, errD = %f, errG = %f reg_g = %f reg_d = %f"
                      % (step, d_err, g_err, g_reg, d_reg))

            if step % 1000 != 0:
                continue

            # Sample grid: identical noise for both halves, one-hot domain
            # label 0 for the first half and 1 for the second half.
            half = FLAGS.half_batch_size
            half_noise = np.random.uniform(
                -1, 1, [half, gan.Z_DIM]).astype(np.float32)
            grid_noise = np.concatenate((half_noise, half_noise), axis=0)

            grid_domains = np.zeros((FLAGS.batch_size, 2))
            grid_domains[:half, 0] = 1
            grid_domains[half:, 1] = 1

            samples = sess.run(
                sampler, feed_dict={z_h: grid_noise,
                                    domain_labels_h: grid_domains})
            save_images(samples, [2, half],
                        './samples/train_{:d}.bmp'.format(step))
Exemplo n.º 3
0
def train():
    """Train the GAN jointly with a domain-aware classifier and evaluate it.

    The graph consists of:
      * ``gan.inference`` for the discriminator logits and regularizers;
      * ``gan.classifier`` on the training images, whose logits and labels
        are passed into ``gan.loss``;
      * ``gan.sampler`` for image generation;
      * two evaluation branches, one per domain ('Source' and 'Target'),
        built from ``holder_test()`` placeholders and ``gan.correct_num``.

    Each step runs one discriminator update, then one generator update. The
    discriminator regularizer weight is 0.0 until ``step`` exceeds
    ``FLAGS.start_regD``; after that it is ``FLAGS.weight_D``.

    Every 100 steps the losses are printed. Every 1000 steps both
    evaluation branches run through ``Eval``, and a sample grid is saved as
    ``./samples/train_<step>.bmp``.
    """
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # Training and evaluation placeholders.
        z_h, images_h, labels_h1, domain_labels_h = Holder()
        test_images_h1, test_images_h2, test_labels_h1, test_labels_h2, \
            test_domain_labels_h, test_domain_labels_h2 = holder_test()
        reg_D_weight_h = tf.placeholder(tf.float32)
        reg_G_weight_h = tf.placeholder(tf.float32)

        # Inference
        D_logits_real, D_logits_fake, D_logits_fake_for_G, reg_g, reg_d = \
            gan.inference(images_h, domain_labels_h, z_h)

        C_logits = gan.classifier(images_h, domain_labels_h, FLAGS.batch_size)

        sampler = gan.sampler(z_h, domain_labels_h)

        # Loss: classifier logits and labels are folded into gan.loss.
        G_loss, D_loss = gan.loss(D_logits_real, D_logits_fake,
                                  D_logits_fake_for_G, reg_g, reg_d,
                                  reg_G_weight_h, reg_D_weight_h,
                                  C_logits, labels_h1)

        # Train ops (C_vars is unused here).
        G_vars, D_vars, C_vars = GetVars()
        G_train_op, D_train_op = gan.train(G_loss, D_loss, G_vars, D_vars,
                                           global_step)

        # Evaluation ops for the two test streams. The trailing True is
        # presumably a variable-reuse flag -- confirm against gan.classifier.
        test_logits = gan.classifier(test_images_h1, test_domain_labels_h,
                                     FLAGS.test_batch_size, True)
        test_op = gan.correct_num(test_logits, test_labels_h1)
        test_logits2 = gan.classifier(test_images_h2, test_domain_labels_h2,
                                      FLAGS.test_batch_size2, True)
        test_op2 = gan.correct_num(test_logits2, test_labels_h2)

        data_set, data_set2 = load_data()

        sess = sess_init()
        # Created but not used within this function.
        saver = tf.train.Saver()

        for step in xrange(FLAGS.max_steps):
            _, z_for_G_v, images_v, labels_v1, domain_labels_v = \
                GenValsForHolder(data_set, data_set2)

            # NOTE(review): no classifier-loss op is run here, so errC is
            # always reported as 0.
            errC = 0

            # Keep the D regularizer off until step exceeds FLAGS.start_regD.
            reg_D_weight = FLAGS.weight_D if step > FLAGS.start_regD else 0.0
            _, errD, reg_d_v = sess.run(
                [D_train_op, D_loss, reg_d],
                feed_dict={z_h: z_for_G_v, images_h: images_v,
                           labels_h1: labels_v1,
                           domain_labels_h: domain_labels_v,
                           reg_D_weight_h: reg_D_weight})

            _, errG, reg_g_v = sess.run(
                [G_train_op, G_loss, reg_g],
                feed_dict={z_h: z_for_G_v, domain_labels_h: domain_labels_v,
                           reg_G_weight_h: FLAGS.weight_G})

            if step % 100 == 0:
                print("step = %d, errD = %f, errG = %f errC = %f reg_g = %f reg_d = %f"
                      % (step, errD, errG, errC, reg_g_v, reg_d_v))

            if step % 1000 == 0:
                # Each test batch has a constant one-hot domain label:
                # column 0 for the source stream, column 1 for the target.
                test_domain_labels_v = np.zeros((FLAGS.test_batch_size, 2))
                test_domain_labels_v[:, 0] = 1
                test_domain_labels_v2 = np.zeros((FLAGS.test_batch_size2, 2))
                test_domain_labels_v2[:, 1] = 1

                Eval((test_op, test_images_h1, test_labels_h1,
                      test_domain_labels_h, test_domain_labels_v, data_set,
                      FLAGS.test_batch_size, FLAGS.test_sample_size),
                     sess, step, 'Source')
                Eval((test_op2, test_images_h2, test_labels_h2,
                      test_domain_labels_h2, test_domain_labels_v2, data_set2,
                      FLAGS.test_batch_size2, FLAGS.test_sample_size2),
                     sess, step, 'Target')

                # Sample grid: both halves share noise; first half is labelled
                # domain 0, second half domain 1.
                z_v_half = np.random.uniform(
                    -1, 1,
                    [FLAGS.half_batch_size, gan.Z_DIM]).astype(np.float32)
                z_v = np.concatenate((z_v_half, z_v_half), 0)

                domain_labels_v = np.zeros((FLAGS.batch_size, 2))
                domain_labels_v[:FLAGS.half_batch_size, 0] = 1
                domain_labels_v[FLAGS.half_batch_size:, 1] = 1

                samples = sess.run(sampler,
                                   feed_dict={z_h: z_v,
                                              domain_labels_h: domain_labels_v})
                save_images(samples, [2, FLAGS.half_batch_size],
                            './samples/train_{:d}.bmp'.format(step))