def main(_):
  """Adversarially train the MNIST generator against the discriminator.

  Alternates discriminator epochs and generator (policy-gradient) epochs,
  periodically evaluating generator accuracy on the MNIST test set.  Relies on
  module-level globals (`flags`, `config`, `tn_dis`, `tn_gen`, `vd_dis`,
  `vd_gen`, `dis_mnist`, `gen_mnist`, `dis_datagen`, `gen_datagen`, `init_op`,
  `dis_summary_op`, `gen_summary_op`, `tn_size`, `eval_interval`) defined
  elsewhere in this file.
  """
  bst_acc = 0.0
  acc_list = []
  writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    # Warm-start both players from their pretrained checkpoints.
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    feed_dict = {
      vd_dis.image_ph: dis_mnist.test.images,
      vd_dis.hard_label_ph: dis_mnist.test.labels,
    }
    ini_dis = sess.run(vd_dis.accuracy, feed_dict=feed_dict)
    feed_dict = {
      vd_gen.image_ph: gen_mnist.test.images,
      vd_gen.hard_label_ph: gen_mnist.test.labels,
    }
    ini_gen = sess.run(vd_gen.accuracy, feed_dict=feed_dict)
    print('ini dis=%.4f ini gen=%.4f' % (ini_dis, ini_gen))
    start = time.time()
    batch_d, batch_g = -1, -1
    for epoch in range(flags.num_epoch):
      # Discriminator phase: label real (data) vs. generated samples.
      for dis_epoch in range(flags.num_dis_epoch):
        for image_np_d, label_dat_d in dis_datagen.generate(
            batch_size=flags.batch_size):
          batch_d += 1
          feed_dict = {tn_gen.image_ph: image_np_d}
          label_gen_d, = sess.run([tn_gen.labels], feed_dict=feed_dict)
          sample_np_d, label_np_d = utils.gan_dis_sample_dev(
              flags, label_dat_d, label_gen_d)
          feed_dict = {
            tn_dis.image_ph: image_np_d,
            tn_dis.sample_ph: sample_np_d,
            tn_dis.dis_label_ph: label_np_d,
          }
          _, summary_d = sess.run(
              [tn_dis.gan_update, dis_summary_op], feed_dict=feed_dict)
          writer.add_summary(summary_d, batch_d)
      # Generator phase: sample labels, score them with the discriminator,
      # and update the generator with the resulting rewards.
      for gen_epoch in range(flags.num_gen_epoch):
        for image_np_g, label_dat_g in gen_datagen.generate(
            batch_size=flags.batch_size):
          batch_g += 1
          feed_dict = {tn_gen.image_ph: image_np_g}
          label_gen_g, = sess.run([tn_gen.labels], feed_dict=feed_dict)
          sample_np_g = utils.generate_label(flags, label_dat_g, label_gen_g)
          feed_dict = {
            tn_dis.image_ph: image_np_g,
            tn_dis.sample_ph: sample_np_g,
          }
          reward_np_g, = sess.run([tn_dis.rewards], feed_dict=feed_dict)
          feed_dict = {
            tn_gen.image_ph: image_np_g,
            tn_gen.sample_ph: sample_np_g,
            tn_gen.reward_ph: reward_np_g,
          }
          _, summary_g = sess.run(
              [tn_gen.gan_update, gen_summary_op], feed_dict=feed_dict)
          writer.add_summary(summary_g, batch_g)
          test_feed = {
            vd_gen.image_ph: gen_mnist.test.images,
            vd_gen.hard_label_ph: gen_mnist.test.labels,
          }
          # When collecting learning-curve data, evaluate every batch and
          # record the accuracy; otherwise evaluate only at eval intervals.
          if flags.collect_cr_data:
            acc = sess.run(vd_gen.accuracy, feed_dict=test_feed)
            acc_list.append(acc)
          if (batch_g + 1) % eval_interval != 0:
            continue
          if not flags.collect_cr_data:
            acc = sess.run(vd_gen.accuracy, feed_dict=test_feed)
          bst_acc = max(acc, bst_acc)
          tot_time = time.time() - start
          global_step = sess.run(tn_gen.global_step)
          avg_time = (tot_time / global_step) * (tn_size / flags.batch_size)
          print('#%08d curacc=%.4f curbst=%.4f tot=%.0fs avg=%.2fs/epoch'
                % (batch_g, acc, bst_acc, tot_time, avg_time))
          if acc <= bst_acc:
            continue
          # save gen parameters if necessary
  tot_time = time.time() - start
  print('#mnist=%d bstacc=%.4f et=%.0fs' % (tn_size, bst_acc, tot_time))
  if flags.collect_cr_data:
    utils.create_pardir(flags.all_learning_curve_p)
    # Use a context manager so the dump file is always closed/flushed
    # (the original leaked the handle from pickle.dump(..., open(...))).
    with open(flags.all_learning_curve_p, 'wb') as fout:
      pickle.dump(acc_list, fout)
def main(_):
  """Run KDGAN training on YFCC: discriminator, teacher, and generator.

  Each epoch trains the discriminator (real vs. gen vs. tch samples), then the
  teacher (reward + distillation from the generator's logits), then the
  generator (reward + distillation from the teacher's logits).  Generator
  precision is evaluated every `eval_interval` generator batches and the full
  score tuple is appended to the epoch learning curve.  Relies on module-level
  globals (`flags`, `config`, `tn_dis`/`tn_gen`/`tn_tch`, `vd_dis`/`vd_gen`/
  `vd_tch`, `yfccdata_*`, `yfcceval`, summary ops, `tn_size`, `eval_interval`).
  """
  best_prec, bst_epk = 0.0, 0
  epk_score_list = []
  writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    tn_tch.saver.restore(sess, flags.tch_model_ckpt)
    start = time.time()
    ini_dis = yfcceval.compute_prec(flags, sess, vd_dis)
    ini_gen = yfcceval.compute_prec(flags, sess, vd_gen)
    ini_tch = yfcceval.compute_prec(flags, sess, vd_tch)
    print('ini dis=%.4f gen=%.4f tch=%.4f' % (ini_dis, ini_gen, ini_tch))
    batch_d, batch_g, batch_t = -1, -1, -1
    for epoch in range(flags.num_epoch):
      # Discriminator phase: distinguish data labels from generator and
      # teacher samples.
      num_batch_d = math.ceil(flags.num_dis_epoch * tn_size / flags.batch_size)
      for _ in range(num_batch_d):
        batch_d += 1
        image_d, text_d, label_dat_d = yfccdata_d.next_batch(flags, sess)
        feed_dict = {tn_gen.image_ph: image_d}
        label_gen_d = sess.run(tn_gen.labels, feed_dict=feed_dict)
        sample_gen_d, gen_label_d = utils.gan_dis_sample(
            flags, label_dat_d, label_gen_d)
        feed_dict = {tn_tch.image_ph: image_d, tn_tch.text_ph: text_d}
        label_tch_d = sess.run(tn_tch.labels, feed_dict=feed_dict)
        sample_tch_d, tch_label_d = utils.gan_dis_sample(
            flags, label_dat_d, label_tch_d)
        feed_dict = {
          tn_dis.image_ph: image_d,
          tn_dis.gen_sample_ph: sample_gen_d,
          tn_dis.gen_label_ph: gen_label_d,
          tn_dis.tch_sample_ph: sample_tch_d,
          tn_dis.tch_label_ph: tch_label_d,
        }
        _, summary_d = sess.run(
            [tn_dis.gan_update, dis_summary_op], feed_dict=feed_dict)
        writer.add_summary(summary_d, batch_d)
      # Teacher phase: discriminator rewards + soft logits from the generator.
      num_batch_t = math.ceil(flags.num_tch_epoch * tn_size / flags.batch_size)
      for _ in range(num_batch_t):
        batch_t += 1
        image_t, text_t, label_dat_t = yfccdata_t.next_batch(flags, sess)
        feed_dict = {tn_tch.image_ph: image_t, tn_tch.text_ph: text_t}
        label_tch_t = sess.run(tn_tch.labels, feed_dict=feed_dict)
        sample_t = utils.generate_label(flags, label_dat_t, label_tch_t)
        feed_dict = {tn_dis.image_ph: image_t, tn_dis.tch_sample_ph: sample_t}
        reward_t = sess.run(tn_dis.tch_rewards, feed_dict=feed_dict)
        feed_dict = {vd_gen.image_ph: image_t}
        soft_logit_t = sess.run(vd_gen.logits, feed_dict=feed_dict)
        feed_dict = {
          tn_tch.image_ph: image_t,
          tn_tch.text_ph: text_t,
          tn_tch.sample_ph: sample_t,
          tn_tch.reward_ph: reward_t,
          tn_tch.hard_label_ph: label_dat_t,
          tn_tch.soft_logit_ph: soft_logit_t,
        }
        _, summary_t = sess.run(
            [tn_tch.kdgan_update, tch_summary_op], feed_dict=feed_dict)
        writer.add_summary(summary_t, batch_t)
      # Generator phase: discriminator rewards + soft logits from the teacher.
      num_batch_g = math.ceil(flags.num_gen_epoch * tn_size / flags.batch_size)
      for _ in range(num_batch_g):
        batch_g += 1
        image_g, text_g, label_dat_g = yfccdata_g.next_batch(flags, sess)
        feed_dict = {tn_tch.image_ph: image_g, tn_tch.text_ph: text_g}
        logit_tch_g = sess.run(tn_tch.logits, feed_dict=feed_dict)
        feed_dict = {tn_gen.image_ph: image_g}
        label_gen_g = sess.run(tn_gen.labels, feed_dict=feed_dict)
        sample_g = utils.generate_label(flags, label_dat_g, label_gen_g)
        feed_dict = {tn_dis.image_ph: image_g, tn_dis.gen_sample_ph: sample_g}
        reward_g = sess.run(tn_dis.gen_rewards, feed_dict=feed_dict)
        feed_dict = {
          tn_gen.image_ph: image_g,
          tn_gen.hard_label_ph: label_dat_g,
          tn_gen.soft_logit_ph: logit_tch_g,
          tn_gen.sample_ph: sample_g,
          tn_gen.reward_ph: reward_g,
        }
        _, summary_g = sess.run(
            [tn_gen.kdgan_update, gen_summary_op], feed_dict=feed_dict)
        writer.add_summary(summary_g, batch_g)
        if (batch_g + 1) % eval_interval != 0:
          continue
        # scores is the tuple (p3, p5, f3, f5, ndcg3, ndcg5, ap, rr);
        # it is stored whole for the learning curve.
        scores = yfcceval.compute_score(flags, sess, vd_gen)
        epk_score_list.append(scores)
        prec = yfcceval.compute_prec(flags, sess, vd_gen)
        if prec > best_prec:
          bst_epk = epoch
        best_prec = max(prec, best_prec)
        tot_time = time.time() - start
        global_step = sess.run(tn_gen.global_step)
        avg_time = (tot_time / global_step) * (tn_size / flags.batch_size)
        print('#%08d@%d prec@%d=%.4f best@%d=%.4f tot=%.0fs avg=%.2fs/epoch'
              % (global_step, epoch, flags.cutoff, prec, bst_epk, best_prec,
                 tot_time, avg_time))
        if prec < best_prec:
          continue
        # save if necessary
  tot_time = time.time() - start
  print('best@%d=%.4f et=%.0fs' % (bst_epk, best_prec, tot_time))
  utils.create_pardir(flags.epk_learning_curve_p)
  # Context manager guarantees the pickle file is closed/flushed
  # (the original leaked the handle from pickle.dump(..., open(...))).
  with open(flags.epk_learning_curve_p, 'wb') as fout:
    pickle.dump(epk_score_list, fout)
def main(_):
  """Adversarially train the YFCC generator against the discriminator.

  Per epoch: a discriminator phase over real-vs-generated samples, then a
  generator phase driven by discriminator rewards.  Generator precision is
  computed after every generator batch (for the learning curve) and reported
  every `eval_interval` batches.  Relies on module-level globals (`flags`,
  `config`, `tn_dis`, `tn_gen`, `vd_dis`, `vd_gen`, `yfccdata_d`,
  `yfccdata_g`, `yfcceval`, summary ops, `tn_data_size`, `eval_interval`).
  """
  best_prec = 0.0
  prec_list = []
  writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    ini_dis = yfcceval.compute_prec(flags, sess, vd_dis)
    ini_gen = yfcceval.compute_prec(flags, sess, vd_gen)
    print('ini dis=%.4f gen=%.4f' % (ini_dis, ini_gen))
    start = time.time()
    batch_d, batch_g = -1, -1
    for epoch in range(flags.num_epoch):
      # Discriminator phase.
      num_batch_d = math.ceil(
          flags.num_dis_epoch * tn_data_size / flags.batch_size)
      for _ in range(num_batch_d):
        batch_d += 1
        image_np_d, _, label_dat_d = yfccdata_d.next_batch(flags, sess)
        feed_dict = {tn_gen.image_ph: image_np_d}
        label_gen_d = sess.run(tn_gen.labels, feed_dict=feed_dict)
        sample_np_d, label_np_d = utils.gan_dis_sample(
            flags, label_dat_d, label_gen_d)
        feed_dict = {
          tn_dis.image_ph: image_np_d,
          tn_dis.sample_ph: sample_np_d,
          tn_dis.dis_label_ph: label_np_d,
        }
        _, summary_d = sess.run(
            [tn_dis.gan_update, dis_summary_op], feed_dict=feed_dict)
        writer.add_summary(summary_d, batch_d)
      # Generator phase: policy-gradient update from discriminator rewards.
      num_batch_g = math.ceil(
          flags.num_gen_epoch * tn_data_size / flags.batch_size)
      for _ in range(num_batch_g):
        batch_g += 1
        image_np_g, _, label_dat_g = yfccdata_g.next_batch(flags, sess)
        feed_dict = {tn_gen.image_ph: image_np_g}
        label_gen_g = sess.run(tn_gen.labels, feed_dict=feed_dict)
        sample_np_g = utils.generate_label(flags, label_dat_g, label_gen_g)
        feed_dict = {
          tn_dis.image_ph: image_np_g,
          tn_dis.sample_ph: sample_np_g,
        }
        reward_np_g = sess.run(tn_dis.rewards, feed_dict=feed_dict)
        feed_dict = {
          tn_gen.image_ph: image_np_g,
          tn_gen.sample_ph: sample_np_g,
          tn_gen.reward_ph: reward_np_g,
        }
        _, summary_g = sess.run(
            [tn_gen.gan_update, gen_summary_op], feed_dict=feed_dict)
        writer.add_summary(summary_g, batch_g)
        # Precision is recorded after every batch for the learning curve,
        # but only reported (and folded into best_prec) at eval intervals.
        prec = yfcceval.compute_prec(flags, sess, vd_gen)
        prec_list.append(prec)
        if (batch_g + 1) % eval_interval != 0:
          continue
        best_prec = max(prec, best_prec)
        tot_time = time.time() - start
        global_step = sess.run(tn_gen.global_step)
        avg_time = (tot_time / global_step) * (tn_data_size / flags.batch_size)
        print('#%08d prec@%d=%.4f best=%.4f tot=%.0fs avg=%.2fs/epoch'
              % (global_step, flags.cutoff, prec, best_prec, tot_time,
                 avg_time))
        if prec < best_prec:
          continue
        # save if necessary
  tot_time = time.time() - start
  print('best@%d=%.4f et=%.0fs' % (flags.cutoff, best_prec, tot_time))
  utils.create_pardir(flags.all_learning_curve_p)
  # Context manager guarantees the pickle file is closed/flushed
  # (the original leaked the handle from pickle.dump(..., open(...))).
  with open(flags.all_learning_curve_p, 'wb') as fout:
    pickle.dump(prec_list, fout)
def main(_):
  """Adversarially train the CIFAR student (`std`) against the discriminator.

  Alternates a discriminator phase (real vs. student-sampled labels) with a
  student phase (policy-gradient updates from discriminator rewards),
  evaluating student accuracy every `eval_interval` student batches.
  Relies on module-level globals (`flags`, `tn_dis`, `tn_std`, `vd_dis`,
  `vd_std`, `cifar`, `init_op`, `eval_interval`) defined elsewhere in this
  file.
  """
  bst_acc = 0.0
  start_time = time.time()
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    # Warm-start both players from their pretrained checkpoints.
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_std.saver.restore(sess, flags.std_model_ckpt)
    ini_dis = cifar.compute_acc(sess, vd_dis)
    ini_std = cifar.compute_acc(sess, vd_std)
    print('ini dis=%.4f ini std=%.4f' % (ini_dis, ini_std))
    batch_d, batch_s = -1, -1
    for epoch in range(flags.num_epoch):
      # Discriminator phase: flags.num_dis_epoch passes over the train set.
      num_batch_d = math.ceil(flags.num_dis_epoch * flags.train_size / flags.batch_size)
      for _ in range(num_batch_d):
        batch_d += 1
        tn_image_d, label_dat_d = cifar.next_batch(sess)
        feed_dict = {tn_std.image_ph: tn_image_d}
        # Labels sampled from the student for this batch of images.
        label_std_d = sess.run(tn_std.labels, feed_dict=feed_dict)
        # Mix data labels and student labels into one labeled sample batch.
        sample_np_d, std_label_d = utils.gan_dis_sample(
            flags, label_dat_d, label_std_d)
        feed_dict = {
          tn_dis.image_ph: tn_image_d,
          tn_dis.std_sample_ph: sample_np_d,
          tn_dis.std_label_ph: std_label_d,
        }
        sess.run(tn_dis.gan_train, feed_dict=feed_dict)
        if (batch_d + 1) % eval_interval != 0:
          continue
        end_time = time.time()
        duration = (end_time - start_time) / 3600  # hours
        print('dis #batch=%d duration=%.4fh' % (batch_d, duration))
        # evaluate dis if necessary
      # Student phase: flags.num_std_epoch passes, reward-driven updates.
      num_batch_s = math.ceil(flags.num_std_epoch * flags.train_size / flags.batch_size)
      for _ in range(num_batch_s):
        batch_s += 1
        tn_image_s, label_dat_s = cifar.next_batch(sess)
        feed_dict = {tn_std.image_ph: tn_image_s}
        label_gen_s = sess.run(tn_std.labels, feed_dict=feed_dict)
        sample_np_s = utils.generate_label(flags, label_dat_s, label_gen_s)
        feed_dict = {
          tn_dis.image_ph: tn_image_s,
          tn_dis.std_sample_ph: sample_np_s
        }
        # Discriminator rewards for the sampled labels drive the student update.
        reward_np_s = sess.run(tn_dis.std_rewards, feed_dict=feed_dict)
        feed_dict = {
          tn_std.image_ph: tn_image_s,
          tn_std.sample_ph: sample_np_s,
          tn_std.reward_ph: reward_np_s,
        }
        sess.run(tn_std.gan_train, feed_dict=feed_dict)
        if (batch_s + 1) % eval_interval != 0:
          continue
        acc = cifar.compute_acc(sess, vd_std)
        bst_acc = max(acc, bst_acc)
        end_time = time.time()
        duration = (end_time - start_time) / 3600  # hours
        print('gen #batch=%d acc=%.4f bst_acc=%.4f duration=%.4fh' % (batch_s + 1, acc, bst_acc, duration))
        # Only fall through (to the save hook below) on a new best accuracy.
        if acc < bst_acc:
          continue
        # save std if necessary
  tot_time = time.time() - start_time
  print('#cifar=%d final=%.4f et=%.0fs' % (flags.train_size, bst_acc, tot_time))
def main(_):
  """Run KDGAN training on MNIST: discriminator, teacher, then generator.

  Per epoch: the discriminator learns to tell data labels from teacher
  samples, the teacher is updated with discriminator rewards (plus, unless
  `flags.kdgan_model == config.kdgan_odgan_flag`, hard labels and soft logits
  distilled from the generator), and the generator is updated with rewards
  plus the teacher's soft logits.  Relies on module-level globals (`flags`,
  `config`, `tn_dis`/`tn_gen`/`tn_tch`, `vd_dis`/`vd_gen`/`vd_tch`,
  `dis_mnist`/`gen_mnist`, `*_datagen`, `init_op`, `tn_size`,
  `eval_interval`) defined elsewhere in this file.
  """
  bst_gen_acc, bst_tch_acc, bst_eph = 0.0, 0.0, 0
  acc_list = []
  writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    tn_tch.saver.restore(sess, flags.tch_model_ckpt)
    feed_dict = {
      vd_dis.image_ph: dis_mnist.test.images,
      vd_dis.hard_label_ph: dis_mnist.test.labels,
    }
    ini_dis = sess.run(vd_dis.accuracy, feed_dict=feed_dict)
    feed_dict = {
      vd_gen.image_ph: gen_mnist.test.images,
      vd_gen.hard_label_ph: gen_mnist.test.labels,
    }
    ini_gen = sess.run(vd_gen.accuracy, feed_dict=feed_dict)
    print('ini dis=%.4f ini gen=%.4f' % (ini_dis, ini_gen))
    start = time.time()
    batch_d, batch_g, batch_t = -1, -1, -1
    for epoch in range(flags.num_epoch):
      # Discriminator phase: data labels vs. teacher samples.
      for dis_epoch in range(flags.num_dis_epoch):
        for image_d, label_dat_d in dis_datagen.generate(
            batch_size=flags.batch_size):
          batch_d += 1
          feed_dict = {tn_tch.image_ph: image_d}
          label_tch_d = sess.run(tn_tch.labels, feed_dict=feed_dict)
          sample_tch_d, dis_label_tch = utils.gan_dis_sample(
              flags, label_dat_d, label_tch_d)
          feed_dict = {
            tn_dis.image_ph: image_d,
            tn_dis.sample_ph: sample_tch_d,
            tn_dis.dis_label_ph: dis_label_tch,
          }
          sess.run(tn_dis.gan_update, feed_dict=feed_dict)
      # Teacher phase.
      for tch_epoch in range(flags.num_tch_epoch):
        for image_t, label_dat_t in tch_datagen.generate(
            batch_size=flags.batch_size):
          batch_t += 1
          feed_dict = {tn_tch.image_ph: image_t}
          label_tch_t = sess.run(tn_tch.labels, feed_dict=feed_dict)
          sample_t = utils.generate_label(flags, label_dat_t, label_tch_t)
          feed_dict = {
            tn_dis.image_ph: image_t,
            tn_dis.sample_ph: sample_t,
          }
          reward_t = sess.run(tn_dis.rewards, feed_dict=feed_dict)
          # Reward-only feed; replaced below by the full KD feed unless the
          # "odgan" variant (no distillation into the teacher) is selected.
          feed_dict = {
            tn_tch.image_ph: image_t,
            tn_tch.sample_ph: sample_t,
            tn_tch.reward_ph: reward_t,
          }
          if flags.kdgan_model != config.kdgan_odgan_flag:
            feed_dict = {vd_gen.image_ph: image_t}
            soft_logit_t = sess.run(vd_gen.logits, feed_dict=feed_dict)
            feed_dict = {
              tn_tch.image_ph: image_t,
              tn_tch.sample_ph: sample_t,
              tn_tch.reward_ph: reward_t,
              tn_tch.hard_label_ph: label_dat_t,
              tn_tch.soft_logit_ph: soft_logit_t,
            }
          sess.run(tn_tch.kdgan_update, feed_dict=feed_dict)
          if (batch_t + 1) % eval_interval != 0:
            continue
          feed_dict = {
            vd_tch.image_ph: gen_mnist.test.images,
            vd_tch.hard_label_ph: gen_mnist.test.labels,
          }
          # NOTE(review): tch_acc is currently unused — the best-teacher
          # tracking/printing below was commented out in the original.
          tch_acc = sess.run(vd_tch.accuracy, feed_dict=feed_dict)
          # bst_tch_acc = max(tch_acc, bst_tch_acc)
          # print('#%08d tchcur=%.4f tchbst=%.4f' % (batch_t, tch_acc, bst_tch_acc))
      # Generator phase: rewards + soft logits distilled from the teacher.
      for gen_epoch in range(flags.num_gen_epoch):
        for image_g, label_dat_g in gen_datagen.generate(
            batch_size=flags.batch_size):
          batch_g += 1
          feed_dict = {tn_gen.image_ph: image_g}
          label_gen_g = sess.run(tn_gen.labels, feed_dict=feed_dict)
          sample_g = utils.generate_label(flags, label_dat_g, label_gen_g)
          feed_dict = {
            tn_dis.image_ph: image_g,
            tn_dis.sample_ph: sample_g,
          }
          reward_g = sess.run(tn_dis.rewards, feed_dict=feed_dict)
          feed_dict = {vd_tch.image_ph: image_g}
          soft_logit_g = sess.run(vd_tch.logits, feed_dict=feed_dict)
          feed_dict = {
            tn_gen.image_ph: image_g,
            tn_gen.sample_ph: sample_g,
            tn_gen.reward_ph: reward_g,
            tn_gen.hard_label_ph: label_dat_g,
            tn_gen.soft_logit_ph: soft_logit_g,
          }
          sess.run(tn_gen.kdgan_update, feed_dict=feed_dict)
          test_feed = {
            vd_gen.image_ph: gen_mnist.test.images,
            vd_gen.hard_label_ph: gen_mnist.test.labels,
          }
          # When collecting learning-curve data, evaluate every batch and
          # record the accuracy; otherwise evaluate only at eval intervals.
          if flags.collect_cr_data:
            acc = sess.run(vd_gen.accuracy, feed_dict=test_feed)
            acc_list.append(acc)
          if (batch_g + 1) % eval_interval != 0:
            continue
          if not flags.collect_cr_data:
            acc = sess.run(vd_gen.accuracy, feed_dict=test_feed)
          if acc > bst_gen_acc:
            bst_gen_acc = acc
            bst_eph = epoch
          tot_time = time.time() - start
          global_step = sess.run(tn_gen.global_step)
          avg_time = (tot_time / global_step) * (tn_size / flags.batch_size)
          print('#%08d gencur=%.4f genbst=%.4f tot=%.0fs avg=%.2fs/epoch'
                % (batch_g, acc, bst_gen_acc, tot_time, avg_time))
          if acc <= bst_gen_acc:
            continue
          # save gen parameters if necessary
  tot_time = time.time() - start
  bst_gen_acc *= 100  # report as a percentage
  bst_eph += 1  # report epochs 1-based
  print('#mnist=%d kdgan_%s@%d=%.2f et=%.0fs'
        % (tn_size, flags.kdgan_model, bst_eph, bst_gen_acc, tot_time))
  if flags.collect_cr_data:
    utils.create_pardir(flags.all_learning_curve_p)
    # Context manager guarantees the pickle file is closed/flushed
    # (the original leaked the handle from pickle.dump(..., open(...))).
    with open(flags.all_learning_curve_p, 'wb') as fout:
      pickle.dump(acc_list, fout)
def main(_):
  """Adversarially train the YFCC generator, tracking the best epoch.

  Per epoch: a discriminator phase (real vs. generated label samples), then a
  generator phase driven by discriminator rewards, evaluating generator
  precision every `eval_interval` generator batches.  Relies on module-level
  globals (`flags`, `config`, `tn_dis`, `tn_gen`, `vd_gen`, `yfccdata_d`,
  `yfccdata_g`, `yfcceval`, summary ops, `init_op`, `tn_size`,
  `eval_interval`) defined elsewhere in this file.
  """
  best_prec, bst_epk = 0.0, 0
  writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    # Warm-start both players from their pretrained checkpoints.
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    start = time.time()
    init_prec = yfcceval.compute_prec(flags, sess, vd_gen)
    print('init@%d=%.4f' % (flags.cutoff, init_prec))
    batch_d, batch_g = -1, -1
    for epoch in range(flags.num_epoch):
      # Discriminator phase: flags.num_dis_epoch passes over the train set.
      num_batch_d = math.ceil(flags.num_dis_epoch * tn_size / flags.batch_size)
      for _ in range(num_batch_d):
        batch_d += 1
        image_np_d, _, label_dat_d = yfccdata_d.next_batch(flags, sess)
        feed_dict = {tn_gen.image_ph:image_np_d}
        # Labels sampled from the generator for this image batch.
        label_gen_d, = sess.run([tn_gen.labels], feed_dict=feed_dict)
        # Mix data labels and generated labels into one discriminator batch.
        sample_np_d, label_np_d = utils.gan_dis_sample(flags, label_dat_d, label_gen_d)
        feed_dict = {
          tn_dis.image_ph:image_np_d,
          tn_dis.sample_ph:sample_np_d,
          tn_dis.dis_label_ph:label_np_d,
        }
        _, summary_d = sess.run([tn_dis.gan_update, dis_summary_op], feed_dict=feed_dict)
        writer.add_summary(summary_d, batch_d)
      # Generator phase: policy-gradient updates from discriminator rewards.
      num_batch_g = math.ceil(flags.num_gen_epoch * tn_size / flags.batch_size)
      for _ in range(num_batch_g):
        batch_g += 1
        image_np_g, _, label_dat_g = yfccdata_g.next_batch(flags, sess)
        feed_dict = {tn_gen.image_ph:image_np_g}
        label_gen_g, = sess.run([tn_gen.labels], feed_dict=feed_dict)
        sample_np_g = utils.generate_label(flags, label_dat_g, label_gen_g)
        feed_dict = {
          tn_dis.image_ph:image_np_g,
          tn_dis.sample_ph:sample_np_g,
        }
        # Discriminator rewards for the sampled labels.
        reward_np_g, = sess.run([tn_dis.rewards], feed_dict=feed_dict)
        feed_dict = {
          tn_gen.image_ph:image_np_g,
          tn_gen.sample_ph:sample_np_g,
          tn_gen.reward_ph:reward_np_g,
        }
        _, summary_g = sess.run([tn_gen.gan_update, gen_summary_op], feed_dict=feed_dict)
        writer.add_summary(summary_g, batch_g)
        if (batch_g + 1) % eval_interval != 0:
          continue
        prec = yfcceval.compute_prec(flags, sess, vd_gen)
        # Remember the epoch that produced a new best precision.
        if prec > best_prec:
          bst_epk = epoch
        best_prec = max(prec, best_prec)
        tot_time = time.time() - start
        global_step = sess.run(tn_gen.global_step)
        avg_time = (tot_time / global_step) * (tn_size / flags.batch_size)
        print('#%08d@%d prec@%d=%.4f best@%d=%.4f tot=%.0fs avg=%.2fs/epoch' %
            (global_step, epoch, flags.cutoff, prec, bst_epk, best_prec, tot_time, avg_time))
        # Only fall through (to the save hook below) on a new best precision.
        if prec < best_prec:
          continue
        # save if necessary
  tot_time = time.time() - start
  print('best@%d=%.4f et=%.0fs' % (flags.cutoff, best_prec, tot_time))
def main(_):
  """Train the image generator against the discriminator on tfrecord data.

  Builds summary ops and tfrecord input pipelines, then alternates
  discriminator and generator epochs under slim queue runners, evaluating the
  generator once per epoch and writing the (epoch, hit) learning curve to
  `flags.gan_figure_data`.  Relies on module-level globals (`flags`, `config`,
  `dis_t`, `gen_t`, `gen_v`, `utils`, `slim`, `train_data_size`) defined
  elsewhere in this file.
  """
  # Report the parameter count of every trainable variable.
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))
  dis_summary_op = tf.summary.merge([
    tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
    tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
  ])
  gen_summary_op = tf.summary.merge([
    tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
    tf.summary.scalar(gen_t.gan_loss.name, gen_t.gan_loss),
  ])
  print(type(dis_summary_op), type(gen_summary_op))
  init_op = tf.global_variables_initializer()
  # Separate (shuffled) tfrecord readers for the dis and gen phases, plus an
  # unshuffled one for validation.
  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d'
        % (len(data_sources_t), len(data_sources_v)))
  ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
  user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d
  ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
  user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  figure_data = []
  best_hit_v = -np.inf
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    dis_t.saver.restore(sess, flags.dis_model_ckpt)
    gen_t.saver.restore(sess, flags.gen_model_ckpt)
    writer = tf.summary.FileWriter(
        config.logs_dir, graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
      print('init hit=%.4f' % (hit_v))
      batch_d, batch_g = -1, -1
      for epoch in range(flags.num_epoch):
        # Discriminator phase.
        for dis_epoch in range(flags.num_dis_epoch):
          print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
          num_batch_d = math.ceil(train_data_size / flags.batch_size)
          for _ in range(num_batch_d):
            batch_d += 1
            image_np_d, label_dat_d = sess.run([image_bt_d, label_bt_d])
            feed_dict = {gen_t.image_ph: image_np_d}
            label_gen_d, = sess.run([gen_t.labels], feed_dict=feed_dict)
            sample_np_d, label_np_d = utils.gan_dis_sample(
                flags, label_dat_d, label_gen_d)
            feed_dict = {
              dis_t.image_ph: image_np_d,
              dis_t.sample_ph: sample_np_d,
              dis_t.dis_label_ph: label_np_d,
            }
            _, summary_d = sess.run(
                [dis_t.gan_update, dis_summary_op], feed_dict=feed_dict)
            writer.add_summary(summary_d, batch_d)
        # Generator phase: policy-gradient updates from dis rewards.
        for gen_epoch in range(flags.num_gen_epoch):
          print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
          num_batch_g = math.ceil(train_data_size / flags.batch_size)
          for _ in range(num_batch_g):
            batch_g += 1
            image_np_g, label_dat_g = sess.run([image_bt_g, label_bt_g])
            feed_dict = {gen_t.image_ph: image_np_g}
            label_gen_g, = sess.run([gen_t.labels], feed_dict=feed_dict)
            sample_np_g = utils.generate_label(
                flags, label_dat_g, label_gen_g)
            feed_dict = {
              dis_t.image_ph: image_np_g,
              dis_t.sample_ph: sample_np_g,
            }
            reward_np_g, = sess.run([dis_t.rewards], feed_dict=feed_dict)
            feed_dict = {
              gen_t.image_ph: image_np_g,
              gen_t.sample_ph: sample_np_g,
              gen_t.reward_ph: reward_np_g,
            }
            _, summary_g = sess.run(
                [gen_t.gan_update, gen_summary_op], feed_dict=feed_dict)
            writer.add_summary(summary_g, batch_g)
        # Per-epoch evaluation for the learning curve.
        hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
        tot_time = time.time() - start
        print('#%03d curbst=%.4f %.0fs' % (epoch, hit_v, tot_time))
        figure_data.append((epoch, hit_v))
        if hit_v <= best_hit_v:
          continue
        best_hit_v = hit_v
        print('bsthit=%.4f' % (best_hit_v))
  utils.create_if_nonexist(os.path.dirname(flags.gan_figure_data))
  # Context manager guarantees the figure file is closed even if a write
  # fails (the original used a bare open()/close() pair).
  with open(flags.gan_figure_data, 'w') as fout:
    for epoch, hit_v in figure_data:
      fout.write('%d\t%.4f\n' % (epoch, hit_v))
def main(_):
  """Run MNIST KDGAN with optional Gumbel sampling and gradient logging.

  Per epoch: discriminator phase (data vs. generator vs. teacher samples),
  teacher phase (rewards + generator soft logits), optional Gumbel temperature
  decay, then generator phase (rewards + teacher soft logits).  Depending on
  flags, per-batch generator accuracy, teacher accuracy, and KDGAN gradients
  are logged to disk.  Relies on module-level globals (`flags`, `tn_dis`/
  `tn_gen`/`tn_tch`, `vd_dis`/`vd_gen`/`vd_tch`, `dis_mnist`/`gen_mnist`,
  `*_datagen`, `init_op`, `tn_size`, `tot_batch`, `eval_interval`) defined
  elsewhere in this file.
  """
  bst_gen_acc, bst_tch_acc, bst_eph = 0.0, 0.0, 0
  utils.create_if_nonexist(flags.gradient_dir)
  if flags.log_accuracy:
    acc_history = []
  if flags.evaluate_tch:
    tch_history = []
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    tn_tch.saver.restore(sess, flags.tch_model_ckpt)
    feed_dict = {
      vd_dis.image_ph:dis_mnist.test.images,
      vd_dis.hard_label_ph:dis_mnist.test.labels,
    }
    ini_dis = sess.run(vd_dis.accuracy, feed_dict=feed_dict)
    feed_dict = {
      vd_gen.image_ph:gen_mnist.test.images,
      vd_gen.hard_label_ph:gen_mnist.test.labels,
    }
    ini_gen = sess.run(vd_gen.accuracy, feed_dict=feed_dict)
    print('ini dis=%.4f ini gen=%.4f' % (ini_dis, ini_gen))
    start = time.time()
    batch_d, batch_g, batch_t = -1, -1, -1
    # Number of decay steps needed to anneal the Gumbel temperature from
    # gumbel_temperature to gumbel_end_temperature at the configured factor.
    gumbel_times = (math.log(flags.gumbel_end_temperature / flags.gumbel_temperature)
        / math.log(flags.gumbel_temperature_decay_factor))
    for epoch in range(flags.num_epoch):
      # Discriminator phase: data vs. generator vs. teacher samples.
      for dis_epoch in range(flags.num_dis_epoch):
        for image_d, label_dat_d in dis_datagen.generate(batch_size=flags.batch_size):
          batch_d += 1
          feed_dict = {tn_gen.image_ph:image_d}
          label_gen_d = sess.run(tn_gen.labels, feed_dict=feed_dict)
          sample_gen_d, gen_label_d = utils.gan_dis_sample(flags, label_dat_d, label_gen_d)
          feed_dict = {tn_tch.image_ph:image_d}
          label_tch_d = sess.run(tn_tch.labels, feed_dict=feed_dict)
          sample_tch_d, tch_label_d = utils.gan_dis_sample(flags, label_dat_d, label_tch_d)
          feed_dict = {
            tn_dis.image_ph:image_d,
            tn_dis.gen_sample_ph:sample_gen_d,
            tn_dis.gen_label_ph:gen_label_d,
            tn_dis.tch_sample_ph:sample_tch_d,
            tn_dis.tch_label_ph:tch_label_d,
          }
          sess.run(tn_dis.gan_update, feed_dict=feed_dict)
      # Teacher phase: rewards + soft logits distilled from the generator.
      for tch_epoch in range(flags.num_tch_epoch):
        for image_t, label_dat_t in tch_datagen.generate(batch_size=flags.batch_size):
          batch_t += 1
          feed_dict = {tn_tch.image_ph:image_t}
          label_tch_t = sess.run(tn_tch.labels, feed_dict=feed_dict)
          sample_t = utils.generate_label(flags, label_dat_t, label_tch_t)
          feed_dict = {
            tn_dis.image_ph:image_t,
            tn_dis.tch_sample_ph:sample_t,
          }
          reward_t = sess.run(tn_dis.tch_rewards, feed_dict=feed_dict)
          feed_dict = {vd_gen.image_ph:image_t}
          soft_logit_t = sess.run(vd_gen.logits, feed_dict=feed_dict)
          feed_dict = {
            tn_tch.image_ph:image_t,
            tn_tch.sample_ph:sample_t,
            tn_tch.reward_ph:reward_t,
            tn_tch.hard_label_ph:label_dat_t,
            tn_tch.soft_logit_ph:soft_logit_t,
          }
          sess.run(tn_tch.kdgan_update, feed_dict=feed_dict)
          if not flags.evaluate_tch:
            continue
          if (batch_t + 1) % eval_interval != 0:
            continue
          feed_dict = {
            vd_tch.image_ph:gen_mnist.test.images,
            vd_tch.hard_label_ph:gen_mnist.test.labels,
          }
          tch_acc = sess.run(vd_tch.accuracy, feed_dict=feed_dict)
          bst_tch_acc = max(tch_acc, bst_tch_acc)
          print('#%08d tchcur=%.4f tchbst=%.4f' % (batch_t, tch_acc, bst_tch_acc))
          tch_history.append(tch_acc)
      #### gumbel softmax
      # Decay the generator's Gumbel temperature a fixed number of times
      # spread evenly over all epochs.
      if flags.enable_gumbel:
        if (epoch + 1) % max(int(flags.num_epoch / gumbel_times), 1) == 0:
          sess.run(tn_gen.gt_update)
      # Generator phase: rewards + soft logits distilled from the teacher.
      for gen_epoch in range(flags.num_gen_epoch):
        batch = -1  # batch index within this gen_epoch, for gradient filenames
        for image_g, label_dat_g in gen_datagen.generate(batch_size=flags.batch_size):
          batch_g += 1
          batch += 1
          epk_bat = '%d.%d' % (epoch*flags.num_gen_epoch+gen_epoch, batch)
          ggrads_file = path.join(flags.gradient_dir, 'kdgan_ggrads.%s.p' % epk_bat)
          kgrads_file = path.join(flags.gradient_dir, 'kdgan_kgrads.%s.p' % epk_bat)
          feed_dict = {tn_gen.image_ph:image_g}
          if not flags.enable_gumbel:
            label_gen_g = sess.run(tn_gen.labels, feed_dict=feed_dict)
          else:
            label_gen_g = sess.run(tn_gen.gumbel_labels, feed_dict=feed_dict)
          sample_g = utils.generate_label(flags, label_dat_g, label_gen_g)
          feed_dict = {
            tn_dis.image_ph:image_g,
            tn_dis.gen_sample_ph:sample_g,
          }
          reward_g = sess.run(tn_dis.gen_rewards, feed_dict=feed_dict)
          feed_dict = {vd_tch.image_ph:image_g}
          soft_logit_g = sess.run(vd_tch.logits, feed_dict=feed_dict)
          feed_dict = {
            tn_gen.image_ph:image_g,
            tn_gen.sample_ph:sample_g,
            tn_gen.reward_ph:reward_g,
            tn_gen.hard_label_ph:label_dat_g,
            tn_gen.soft_logit_ph:soft_logit_g,
          }
          if flags.log_gradient:
            # Fetch both gradient sets together with the update so they are
            # computed in the same step, then persist them per batch.
            fetches = [tn_gen.kdgan_ggrads, tn_gen.kdgan_kgrads, tn_gen.kdgan_update]
            kdgan_ggrads, kdgan_kgrads, _ = sess.run(fetches, feed_dict=feed_dict)
            # Context managers close the per-batch dump files promptly; the
            # original leaked one handle per batch via pickle.dump(open(...)).
            with open(ggrads_file, 'wb') as fout:
              pickle.dump(kdgan_ggrads, fout)
            with open(kgrads_file, 'wb') as fout:
              pickle.dump(kdgan_kgrads, fout)
          else:
            sess.run(tn_gen.kdgan_update, feed_dict=feed_dict)
          test_feed = {
            vd_gen.image_ph:gen_mnist.test.images,
            vd_gen.hard_label_ph:gen_mnist.test.labels,
          }
          # When logging accuracy, evaluate every batch and record it;
          # otherwise evaluate only at eval intervals.
          if flags.log_accuracy:
            acc = sess.run(vd_gen.accuracy, feed_dict=test_feed)
            acc_history.append(acc)
          if (batch_g + 1) % eval_interval != 0:
            continue
          if not flags.log_accuracy:
            acc = sess.run(vd_gen.accuracy, feed_dict=test_feed)
          if acc > bst_gen_acc:
            bst_gen_acc = acc
            bst_eph = epoch
          tot_time = time.time() - start
          global_step = sess.run(tn_gen.global_step)
          if flags.evaluate_tch:
            gen_tch_pct = 100 * bst_gen_acc / bst_tch_acc
            print('#%08d/%08d gencur=%.4f genbst=%.4f (%.2f) tot=%.0fs'
                % (batch_g, tot_batch, acc, bst_gen_acc, gen_tch_pct, tot_time))
          else:
            print('#%08d/%08d gencur=%.4f genbst=%.4f tot=%.0fs'
                % (batch_g, tot_batch, acc, bst_gen_acc, tot_time))
          stdout.flush()
          if acc <= bst_gen_acc:
            continue
          # save gen parameters if necessary
    gumbel_temperature = sess.run(tn_gen.gumbel_temperature)
    print('gumbel_temperature=%.4f' % gumbel_temperature)
  tot_time = time.time() - start
  bst_gen_acc *= 100  # report as a percentage
  bst_eph += 1  # report epochs 1-based
  print('#mnist=%d kdgan@%d=%.2f et=%.0fs'
      % (tn_size, bst_eph, bst_gen_acc, tot_time))
  if flags.log_accuracy:
    utils.create_pardir(flags.acc_file)
    with open(flags.acc_file, 'wb') as fout:
      pickle.dump(acc_history, fout)
  if flags.evaluate_tch:
    utils.create_pardir(flags.tch_file)
    with open(flags.tch_file, 'wb') as fout:
      pickle.dump(tch_history, fout)
def main(_):
  """Run KDGAN training, alternating discriminator / teacher / generator updates.

  Relies on module-level objects defined elsewhere in this file: the training
  graphs ``dis_t``/``gen_t``/``tch_t``, the validation graphs ``gen_v``/``tch_v``,
  the parsed ``flags``, ``config``, ``slim``, ``utils`` and
  ``num_batch_per_epoch``.  Saves the best generator checkpoint (by validation
  hit) and writes the per-epoch learning curve to ``flags.kdgan_figure_data``.
  """
  # Report the parameter count of every trainable variable for sanity checking.
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  # One merged summary op per model so each phase logs independently.
  dis_summary_op = tf.summary.merge([
    tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
    tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
  ])
  gen_summary_op = tf.summary.merge([
    tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
    tf.summary.scalar(gen_t.kdgan_loss.name, gen_t.kdgan_loss),
  ])
  tch_summary_op = tf.summary.merge([
    tf.summary.scalar(tch_t.learning_rate.name, tch_t.learning_rate),
    tf.summary.scalar(tch_t.kdgan_loss.name, tch_t.kdgan_loss),
  ])
  init_op = tf.global_variables_initializer()

  # Three independently shuffled training pipelines (one per phase) plus a
  # deterministic pipeline for validation.
  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
  ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
  user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d
  ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
  user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g
  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

  figure_data = []      # (epoch, gen_hit, tch_hit) learning-curve points
  best_hit_v = -np.inf  # best generator validation hit seen so far
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    # All three models start from their pretrained checkpoints.
    dis_t.saver.restore(sess, flags.dis_model_ckpt)
    gen_t.saver.restore(sess, flags.gen_model_ckpt)
    tch_t.saver.restore(sess, flags.tch_model_ckpt)
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      # Baseline scores before any adversarial training.
      gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
      tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)
      print('hit gen=%.4f tch=%.4f' % (gen_hit, tch_hit))
      batch_d, batch_g, batch_t = -1, -1, -1
      for epoch in range(flags.num_epoch):
        # --- discriminator phase: real labels vs gen/tch samples ---
        for dis_epoch in range(flags.num_dis_epoch):
          print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
          for _ in range(num_batch_per_epoch):
            batch_d += 1
            image_d, text_d, label_dat_d = sess.run(
                [image_bt_d, text_bt_d, label_bt_d])
            feed_dict = {gen_t.image_ph: image_d}
            label_gen_d, = sess.run([gen_t.labels], feed_dict=feed_dict)
            feed_dict = {tch_t.text_ph: text_d, tch_t.image_ph: image_d}
            label_tch_d, = sess.run([tch_t.labels], feed_dict=feed_dict)
            # Mix data labels with generator and teacher samples into
            # labeled positives/negatives for the discriminator.
            sample_d, label_d = utils.kdgan_dis_sample(
                flags, label_dat_d, label_gen_d, label_tch_d)
            feed_dict = {
              dis_t.image_ph: image_d,
              dis_t.sample_ph: sample_d,
              dis_t.dis_label_ph: label_d,
            }
            _, summary_d = sess.run(
                [dis_t.gan_update, dis_summary_op], feed_dict=feed_dict)
            writer.add_summary(summary_d, batch_d)
        # --- teacher phase: policy-gradient reward + distillation from gen ---
        for tch_epoch in range(flags.num_tch_epoch):
          print('epoch %03d tch_epoch %03d' % (epoch, tch_epoch))
          for _ in range(num_batch_per_epoch):
            batch_t += 1
            image_t, text_t, label_dat_t = sess.run(
                [image_bt_t, text_bt_t, label_bt_t])
            feed_dict = {tch_t.text_ph: text_t, tch_t.image_ph: image_t}
            label_tch_t, = sess.run([tch_t.labels], feed_dict=feed_dict)
            sample_t = utils.generate_label(flags, label_dat_t, label_tch_t)
            feed_dict = {dis_t.image_ph: image_t, dis_t.sample_ph: sample_t}
            reward_t, = sess.run([dis_t.rewards], feed_dict=feed_dict)
            # Generator logits serve as the soft (distillation) targets
            # for the teacher update.
            feed_dict = {gen_t.image_ph: image_t}
            label_gen_g = sess.run(gen_t.logits, feed_dict=feed_dict)
            feed_dict = {
              tch_t.text_ph: text_t,
              tch_t.image_ph: image_t,
              tch_t.sample_ph: sample_t,
              tch_t.reward_ph: reward_t,
              tch_t.hard_label_ph: label_dat_t,
              tch_t.soft_label_ph: label_gen_g,
            }
            # kdgan_loss is fetched for optional debugging.
            _, summary_t, tch_kdgan_loss = sess.run(
                [tch_t.kdgan_update, tch_summary_op, tch_t.kdgan_loss],
                feed_dict=feed_dict)
            writer.add_summary(summary_t, batch_t)
        # --- generator phase: reward from dis + distillation from tch ---
        for gen_epoch in range(flags.num_gen_epoch):
          print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
          for _ in range(num_batch_per_epoch):
            batch_g += 1
            image_g, text_g, label_dat_g = sess.run(
                [image_bt_g, text_bt_g, label_bt_g])
            feed_dict = {tch_t.text_ph: text_g, tch_t.image_ph: image_g}
            label_tch_g, = sess.run([tch_t.labels], feed_dict=feed_dict)
            feed_dict = {gen_t.image_ph: image_g}
            label_gen_g, = sess.run([gen_t.labels], feed_dict=feed_dict)
            sample_g = utils.generate_label(flags, label_dat_g, label_gen_g)
            feed_dict = {dis_t.image_ph: image_g, dis_t.sample_ph: sample_g}
            reward_g, = sess.run([dis_t.rewards], feed_dict=feed_dict)
            feed_dict = {
              gen_t.image_ph: image_g,
              gen_t.hard_label_ph: label_dat_g,
              gen_t.soft_label_ph: label_tch_g,
              gen_t.sample_ph: sample_g,
              gen_t.reward_ph: reward_g,
            }
            _, summary_g = sess.run(
                [gen_t.kdgan_update, gen_summary_op], feed_dict=feed_dict)
            writer.add_summary(summary_g, batch_g)
        # End-of-epoch validation; checkpoint only on a new best gen hit.
        gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
        tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)
        tot_time = time.time() - start
        print('#%03d curgen=%.4f curtch=%.4f %.0fs' %
            (epoch, gen_hit, tch_hit, tot_time))
        figure_data.append((epoch, gen_hit, tch_hit))
        if gen_hit <= best_hit_v:
          continue
        best_hit_v = gen_hit
        print("epoch ", epoch + 1, ":, new best validation hit:", best_hit_v,
            "saving...")
        gen_t.saver.save(sess, flags.kdgan_model_ckpt, global_step=epoch + 1)
        print("finish saving")
  print('best hit=%.4f' % (best_hit_v))
  # Persist the learning curve.  A context manager replaces the original
  # open()/close() pair so the file is closed even if a write raises.
  utils.create_if_nonexist(os.path.dirname(flags.kdgan_figure_data))
  with open(flags.kdgan_figure_data, 'w') as fout:
    for epoch, gen_hit, tch_hit in figure_data:
      fout.write('%d\t%.4f\t%.4f\n' % (epoch, gen_hit, tch_hit))