def train():
    """Train the domain-conditioned GAN on two queue-fed datasets (L2 loss).

    Builds the graph -- placeholders from ``Holder()``, ``gan.inference``,
    ``gan.sampler``, ``gan.loss_l2`` weighted by ``FLAGS.weight_G`` /
    ``FLAGS.weight_D``, and the optimizer ops from ``gan.train`` -- then
    starts the input queue runners and alternates one discriminator update
    with one generator update for ``FLAGS.max_steps`` iterations.

    Every 100 steps the losses and regularizer values are printed. Every
    1000 steps a batch of samples (generated from the same noise duplicated
    across both batch halves, conditioned on the current batch's domain
    labels) and the current real batch are written as PNG grids under
    ``./samples`` and ``./samples_real``.

    Raises:
        Any exception raised by the training loop or reported by an input
        queue-runner thread; it is re-raised from ``coord.join`` after the
        runner threads have been stopped.
    """
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # input
        z_h, images_h, labels_h, domain_labels_h = Holder()

        # inference
        D_logits_real, D_logits_fake, D_logits_fake_for_G, reg_g, reg_d = \
            gan.inference(images_h, domain_labels_h, z_h)
        sampler = gan.sampler(z_h, domain_labels_h)

        # loss
        G_loss, D_loss = gan.loss_l2(D_logits_real, D_logits_fake,
                                     D_logits_fake_for_G, reg_g, reg_d,
                                     FLAGS.weight_G, FLAGS.weight_D)

        # train_op
        G_vars, D_vars, C_vars = GetVars()
        G_train_op, D_train_op = gan.train(G_loss, D_loss, G_vars, D_vars,
                                           global_step)

        data_set = data.Data().load(FLAGS.data_path1, FLAGS.half_batch_size)
        data_set2 = data.Data().load(FLAGS.data_path2, FLAGS.half_batch_size)

        sess = sess_init()
        # Without a Coordinator, a failing queue-runner thread only logs its
        # error and sess.run may then block forever on an empty queue; the
        # threads are also never shut down. The Coordinator lets failures stop
        # the loop, be re-raised from join(), and the threads exit cleanly.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # NOTE(review): the saver is created but never used to write a checkpoint.
        saver = tf.train.Saver()

        try:
            for step in xrange(FLAGS.max_steps):
                if coord.should_stop():
                    break

                _, images_v, domain_labels_v = \
                    GenValsForHolder_TF(data_set, data_set2, sess)
                # Both halves of the batch share the same noise vectors.
                z_v_half = np.random.uniform(
                    -1, 1, [FLAGS.half_batch_size, gan.Z_DIM]).astype(np.float32)
                z_for_G_v = np.concatenate((z_v_half, z_v_half), 0)

                _, errD, reg_d_v = sess.run(
                    [D_train_op, D_loss, reg_d],
                    feed_dict={z_h: z_for_G_v, images_h: images_v,
                               domain_labels_h: domain_labels_v})
                _, errG, reg_g_v = sess.run(
                    [G_train_op, G_loss, reg_g],
                    feed_dict={z_h: z_for_G_v,
                               domain_labels_h: domain_labels_v})

                if step % 100 == 0:
                    print("step = %d, errD = %f, errG = %f reg_g = %f reg_d = %f"
                          % (step, errD, errG, reg_g_v, reg_d_v))

                if step % 1000 == 0:
                    z_v_half = np.random.uniform(
                        -1, 1, [FLAGS.half_batch_size, gan.Z_DIM]).astype(np.float32)
                    z_v = np.concatenate((z_v_half, z_v_half), 0)
                    samples = sess.run(sampler,
                                       feed_dict={z_h: z_v,
                                                  domain_labels_h: domain_labels_v})
                    save_images(samples, [2, FLAGS.half_batch_size],
                                './samples/train_{:d}.png'.format(step))
                    save_images(images_v, [2, FLAGS.half_batch_size],
                                './samples_real/train_{:d}.png'.format(step))
        except Exception as e:
            # Hand the error to the coordinator so join() re-raises it after
            # the runner threads have been stopped.
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)
def train():
    """Train the domain-conditioned GAN with feedable regularizer weights.

    The discriminator and generator regularizer weights are graph
    placeholders, fed with ``FLAGS.weight_D`` / ``FLAGS.weight_G`` on every
    step. Both datasets are built with ``data1.DATA`` / ``data2.DATA`` and
    ``.load()``; no queue runners are started here. Each iteration performs
    one D update followed by one G update. Progress is printed every 100
    steps, and every 1000 steps a BMP sample grid is written to
    ``./samples``: the first half of the batch is conditioned on domain 0,
    the second half on domain 1, and both halves share the same noise.
    """
    def split_domain_labels():
        # (batch_size, 2) array: rows [0, half) set column 0, the rest column 1.
        labels = np.zeros((FLAGS.batch_size, 2))
        labels[:FLAGS.half_batch_size, 0] = 1
        labels[FLAGS.half_batch_size:, 1] = 1
        return labels

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # Graph inputs; the regularizer weights are fed per step.
        z_h, images_h, labels_h, domain_labels_h = Holder()
        reg_D_weight_h = tf.placeholder(tf.float32)
        reg_G_weight_h = tf.placeholder(tf.float32)

        # Forward passes.
        (D_logits_real, D_logits_fake, D_logits_fake_for_G,
         reg_g, reg_d) = gan.inference(images_h, domain_labels_h, z_h)
        sampler = gan.sampler(z_h, domain_labels_h)

        # Losses, then the optimizer ops over the G / D variable lists.
        G_loss, D_loss = gan.loss(D_logits_real, D_logits_fake,
                                  D_logits_fake_for_G, reg_g, reg_d,
                                  reg_G_weight_h, reg_D_weight_h)
        G_vars, D_vars, C_vars = GetVars()
        G_train_op, D_train_op = gan.train(G_loss, D_loss, G_vars, D_vars,
                                           global_step)

        # The two datasets.
        data_set = data1.DATA(FLAGS.batch_size)
        data_set.load()
        data_set2 = data2.DATA(FLAGS.batch_size)
        data_set2.load()

        sess = sess_init()
        saver = tf.train.Saver()

        for step in xrange(FLAGS.max_steps):
            _, z_batch, image_batch, domain_batch = \
                GenValsForHolder(data_set, data_set2)

            d_feed = {
                z_h: z_batch,
                images_h: image_batch,
                domain_labels_h: domain_batch,
                reg_D_weight_h: FLAGS.weight_D,
            }
            _, errD, reg_d_v = sess.run([D_train_op, D_loss, reg_d],
                                        feed_dict=d_feed)

            g_feed = {
                z_h: z_batch,
                domain_labels_h: domain_batch,
                reg_G_weight_h: FLAGS.weight_G,
            }
            _, errG, reg_g_v = sess.run([G_train_op, G_loss, reg_g],
                                        feed_dict=g_feed)

            if step % 100 == 0:
                print("step = %d, errD = %f, errG = %f reg_g = %f reg_d = %f"
                      % (step, errD, errG, reg_g_v, reg_d_v))

            if step % 1000 != 0:
                continue

            # Sample grid: same noise for both halves, one domain per half.
            noise = np.random.uniform(
                -1, 1, [FLAGS.half_batch_size, gan.Z_DIM]).astype(np.float32)
            shared_z = np.concatenate([noise, noise], axis=0)
            samples = sess.run(sampler, feed_dict={
                z_h: shared_z,
                domain_labels_h: split_domain_labels(),
            })
            save_images(samples, [2, FLAGS.half_batch_size],
                        './samples/train_{:d}.bmp'.format(step))
def train():
    """Train the GAN jointly with a classifier and periodically evaluate it.

    The classifier logits on the training batch (``gan.classifier`` on
    ``images_h``) and the true labels ``labels_h1`` are passed into
    ``gan.loss``. The discriminator regularizer weight is fed as 0.0 until
    ``step > FLAGS.start_regD`` and as ``FLAGS.weight_D`` afterwards; the
    generator regularizer weight is always ``FLAGS.weight_G``.

    Every 100 steps the losses are printed. Every 1000 steps ``Eval`` is run
    on the 'Source' test head (domain column 0, ``data_set``) and the
    'Target' test head (domain column 1, ``data_set2``), and a BMP sample
    grid is written to ``./samples``. In that grid the first half of the
    batch is conditioned on domain 0, the second half on domain 1, and both
    halves share the same noise.
    """
    def domain_labels(n, col):
        # (n, 2) array with column ``col`` set to 1 for every row.
        labels = np.zeros((n, 2))
        labels[:, col] = 1
        return labels

    def split_domain_labels():
        # (batch_size, 2) array: rows [0, half) set column 0, the rest column 1.
        labels = np.zeros((FLAGS.batch_size, 2))
        labels[:FLAGS.half_batch_size, 0] = 1
        labels[FLAGS.half_batch_size:, 1] = 1
        return labels

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # input
        z_h, images_h, labels_h1, domain_labels_h = Holder()
        (test_images_h1, test_images_h2, test_labels_h1, test_labels_h2,
         test_domain_labels_h, test_domain_labels_h2) = holder_test()
        reg_D_weight_h = tf.placeholder(tf.float32)
        reg_G_weight_h = tf.placeholder(tf.float32)

        # inference
        D_logits_real, D_logits_fake, D_logits_fake_for_G, reg_g, reg_d = \
            gan.inference(images_h, domain_labels_h, z_h)
        C_logits = gan.classifier(images_h, domain_labels_h, FLAGS.batch_size)
        sampler = gan.sampler(z_h, domain_labels_h)

        # loss
        G_loss, D_loss = gan.loss(D_logits_real, D_logits_fake,
                                  D_logits_fake_for_G, reg_g, reg_d,
                                  reg_G_weight_h, reg_D_weight_h,
                                  C_logits, labels_h1)

        # train_op
        G_vars, D_vars, C_vars = GetVars()
        G_train_op, D_train_op = gan.train(G_loss, D_loss, G_vars, D_vars,
                                           global_step)

        # for_eval: classifier heads on the source (h1) and target (h2) test
        # placeholders. NOTE(review): the trailing True presumably requests
        # variable reuse inside gan.classifier -- confirm.
        test_logits = gan.classifier(test_images_h1, test_domain_labels_h,
                                     FLAGS.test_batch_size, True)
        test_op = gan.correct_num(test_logits, test_labels_h1)
        test_logits2 = gan.classifier(test_images_h2, test_domain_labels_h2,
                                      FLAGS.test_batch_size2, True)
        test_op2 = gan.correct_num(test_logits2, test_labels_h2)

        data_set, data_set2 = load_data()
        sess = sess_init()
        # NOTE(review): the saver is created but never used to write a checkpoint.
        saver = tf.train.Saver()

        # errC is never computed in this version (gan.loss returns only the G
        # and D losses); it stays 0 and is printed only to keep the log format.
        errC = 0

        for step in xrange(FLAGS.max_steps):
            _, z_for_G_v, images_v, labels_v1, domain_labels_v = \
                GenValsForHolder(data_set, data_set2)

            # The D regularizer is switched off until step > FLAGS.start_regD.
            reg_D_weight_v = FLAGS.weight_D if step > FLAGS.start_regD else 0.0
            _, errD, reg_d_v = sess.run(
                [D_train_op, D_loss, reg_d],
                feed_dict={z_h: z_for_G_v, images_h: images_v,
                           labels_h1: labels_v1,
                           domain_labels_h: domain_labels_v,
                           reg_D_weight_h: reg_D_weight_v})
            _, errG, reg_g_v = sess.run(
                [G_train_op, G_loss, reg_g],
                feed_dict={z_h: z_for_G_v,
                           domain_labels_h: domain_labels_v,
                           reg_G_weight_h: FLAGS.weight_G})

            if step % 100 == 0:
                print("step = %d, errD = %f, errG = %f errC = %f reg_g = %f reg_d = %f"
                      % (step, errD, errG, errC, reg_g_v, reg_d_v))

            if step % 1000 == 0:
                # Source test set uses domain column 0, target uses column 1.
                para_list = (test_op, test_images_h1, test_labels_h1,
                             test_domain_labels_h,
                             domain_labels(FLAGS.test_batch_size, 0), data_set,
                             FLAGS.test_batch_size, FLAGS.test_sample_size)
                Eval(para_list, sess, step, 'Source')
                para_list = (test_op2, test_images_h2, test_labels_h2,
                             test_domain_labels_h2,
                             domain_labels(FLAGS.test_batch_size2, 1), data_set2,
                             FLAGS.test_batch_size2, FLAGS.test_sample_size2)
                Eval(para_list, sess, step, 'Target')

                # Sample grid: same noise for both halves, one domain per half.
                z_v_half = np.random.uniform(
                    -1, 1, [FLAGS.half_batch_size, gan.Z_DIM]).astype(np.float32)
                z_v = np.concatenate((z_v_half, z_v_half), 0)
                samples = sess.run(sampler, feed_dict={
                    z_h: z_v,
                    domain_labels_h: split_domain_labels(),
                })
                save_images(samples, [2, FLAGS.half_batch_size],
                            './samples/train_{:d}.bmp'.format(step))