示例#1
0
def main(_):
    """Adversarially train the generator against the discriminator on MNIST.

    Alternates a dis phase and a gen phase per epoch; the gen phase is a
    policy-gradient update driven by dis rewards.  Relies on module-level
    graph objects (tn_*/vd_* models, datagens, summary ops, flags, config).

    Side effects: restores two checkpoints, writes TF summaries to
    config.logs_dir, prints progress, and (when flags.collect_cr_data)
    pickles the per-batch accuracy curve to flags.all_learning_curve_p.
    """
    bst_acc = 0.0  # best generator test accuracy observed so far
    acc_list = []  # per-batch accuracies (filled only when collect_cr_data)
    writer = tf.summary.FileWriter(config.logs_dir,
                                   graph=tf.get_default_graph())

    def _gen_test_acc(sess):
        # Generator accuracy on the MNIST test split (shared by the
        # baseline eval and both branches of the training-loop eval).
        feed_dict = {
            vd_gen.image_ph: gen_mnist.test.images,
            vd_gen.hard_label_ph: gen_mnist.test.labels,
        }
        return sess.run(vd_gen.accuracy, feed_dict=feed_dict)

    with tf.train.MonitoredTrainingSession() as sess:
        sess.run(init_op)
        tn_dis.saver.restore(sess, flags.dis_model_ckpt)
        tn_gen.saver.restore(sess, flags.gen_model_ckpt)

        # Baseline accuracies of the freshly restored models.
        feed_dict = {
            vd_dis.image_ph: dis_mnist.test.images,
            vd_dis.hard_label_ph: dis_mnist.test.labels,
        }
        ini_dis = sess.run(vd_dis.accuracy, feed_dict=feed_dict)
        ini_gen = _gen_test_acc(sess)
        print('ini dis=%.4f ini gen=%.4f' % (ini_dis, ini_gen))

        start = time.time()
        batch_d, batch_g = -1, -1
        for epoch in range(flags.num_epoch):
            # Discriminator phase: real labels vs generator samples.
            for dis_epoch in range(flags.num_dis_epoch):
                for image_np_d, label_dat_d in dis_datagen.generate(
                        batch_size=flags.batch_size):
                    batch_d += 1
                    feed_dict = {tn_gen.image_ph: image_np_d}
                    label_gen_d, = sess.run([tn_gen.labels],
                                            feed_dict=feed_dict)
                    sample_np_d, label_np_d = utils.gan_dis_sample_dev(
                        flags, label_dat_d, label_gen_d)
                    feed_dict = {
                        tn_dis.image_ph: image_np_d,
                        tn_dis.sample_ph: sample_np_d,
                        tn_dis.dis_label_ph: label_np_d,
                    }
                    _, summary_d = sess.run(
                        [tn_dis.gan_update, dis_summary_op],
                        feed_dict=feed_dict)
                    writer.add_summary(summary_d, batch_d)

            # Generator phase: sample labels, score them with dis, and do a
            # reward-weighted (policy-gradient) update of gen.
            for gen_epoch in range(flags.num_gen_epoch):
                for image_np_g, label_dat_g in gen_datagen.generate(
                        batch_size=flags.batch_size):
                    batch_g += 1
                    feed_dict = {tn_gen.image_ph: image_np_g}
                    label_gen_g, = sess.run([tn_gen.labels],
                                            feed_dict=feed_dict)
                    sample_np_g = utils.generate_label(flags, label_dat_g,
                                                       label_gen_g)
                    feed_dict = {
                        tn_dis.image_ph: image_np_g,
                        tn_dis.sample_ph: sample_np_g,
                    }
                    reward_np_g, = sess.run([tn_dis.rewards],
                                            feed_dict=feed_dict)
                    feed_dict = {
                        tn_gen.image_ph: image_np_g,
                        tn_gen.sample_ph: sample_np_g,
                        tn_gen.reward_ph: reward_np_g,
                    }
                    _, summary_g = sess.run(
                        [tn_gen.gan_update, gen_summary_op],
                        feed_dict=feed_dict)
                    writer.add_summary(summary_g, batch_g)

                    if flags.collect_cr_data:
                        # Dense learning curve: evaluate every batch, but
                        # only report at eval_interval boundaries.
                        acc = _gen_test_acc(sess)
                        acc_list.append(acc)
                        if (batch_g + 1) % eval_interval != 0:
                            continue
                    else:
                        if (batch_g + 1) % eval_interval != 0:
                            continue
                        acc = _gen_test_acc(sess)

                    bst_acc = max(acc, bst_acc)
                    tot_time = time.time() - start
                    global_step = sess.run(tn_gen.global_step)
                    avg_time = (tot_time / global_step) * (tn_size /
                                                           flags.batch_size)
                    print(
                        '#%08d curacc=%.4f curbst=%.4f tot=%.0fs avg=%.2fs/epoch'
                        % (batch_g, acc, bst_acc, tot_time, avg_time))

                    # NOTE(review): bst_acc was just raised to at least acc,
                    # so this is always true and the save hook never runs.
                    if acc <= bst_acc:
                        continue
                    # save gen parameters if necessary
    tot_time = time.time() - start
    print('#mnist=%d bstacc=%.4f et=%.0fs' % (tn_size, bst_acc, tot_time))

    if flags.collect_cr_data:
        utils.create_pardir(flags.all_learning_curve_p)
        # Close the file deterministically instead of leaking the handle.
        with open(flags.all_learning_curve_p, 'wb') as fout:
            pickle.dump(acc_list, fout)
示例#2
0
def main(_):
  """KDGAN training on YFCC: alternate dis, tch, and gen phases per epoch.

  dis is trained against samples from both gen and tch; tch is updated with
  dis rewards plus distillation from gen's logits; gen is updated with dis
  rewards plus distillation from tch's logits.  Relies on module-level
  graph objects (tn_*/vd_* models, yfccdata_* feeds, summary ops, flags).

  Side effects: restores three checkpoints, writes TF summaries to
  config.logs_dir, prints progress, and pickles the per-eval score tuples
  to flags.epk_learning_curve_p.
  """
  best_prec, bst_epk = 0.0, 0  # best gen precision and the epoch it occurred
  epk_score_list = []  # (p3, p5, f3, f5, ndcg3, ndcg5, ap, rr) per eval
  writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    tn_tch.saver.restore(sess, flags.tch_model_ckpt)
    start = time.time()

    # Baseline precision of the freshly restored models.
    ini_dis = yfcceval.compute_prec(flags, sess, vd_dis)
    ini_gen = yfcceval.compute_prec(flags, sess, vd_gen)
    ini_tch = yfcceval.compute_prec(flags, sess, vd_tch)
    print('ini dis=%.4f gen=%.4f tch=%.4f' % (ini_dis, ini_gen, ini_tch))

    batch_d, batch_g, batch_t = -1, -1, -1
    for epoch in range(flags.num_epoch):
      # Discriminator phase: real labels vs gen samples and tch samples.
      num_batch_d = math.ceil(flags.num_dis_epoch * tn_size / flags.batch_size)
      for _ in range(num_batch_d):
        batch_d += 1
        image_d, text_d, label_dat_d = yfccdata_d.next_batch(flags, sess)

        feed_dict = {tn_gen.image_ph:image_d}
        label_gen_d = sess.run(tn_gen.labels, feed_dict=feed_dict)
        sample_gen_d, gen_label_d = utils.gan_dis_sample(flags, label_dat_d, label_gen_d)

        feed_dict = {tn_tch.image_ph:image_d, tn_tch.text_ph:text_d}
        label_tch_d = sess.run(tn_tch.labels, feed_dict=feed_dict)
        sample_tch_d, tch_label_d = utils.gan_dis_sample(flags, label_dat_d, label_tch_d)

        feed_dict = {
          tn_dis.image_ph:image_d,
          tn_dis.gen_sample_ph:sample_gen_d,
          tn_dis.gen_label_ph:gen_label_d,
          tn_dis.tch_sample_ph:sample_tch_d,
          tn_dis.tch_label_ph:tch_label_d,
        }
        _, summary_d = sess.run([tn_dis.gan_update, dis_summary_op], feed_dict=feed_dict)
        writer.add_summary(summary_d, batch_d)

      # Teacher phase: dis rewards on tch samples + distillation from gen.
      num_batch_t = math.ceil(flags.num_tch_epoch * tn_size / flags.batch_size)
      for _ in range(num_batch_t):
        batch_t += 1
        image_t, text_t, label_dat_t = yfccdata_t.next_batch(flags, sess)

        feed_dict = {tn_tch.image_ph:image_t, tn_tch.text_ph:text_t}
        label_tch_t = sess.run(tn_tch.labels, feed_dict=feed_dict)
        sample_t = utils.generate_label(flags, label_dat_t, label_tch_t)
        feed_dict = {tn_dis.image_ph:image_t, tn_dis.tch_sample_ph:sample_t}
        reward_t = sess.run(tn_dis.tch_rewards, feed_dict=feed_dict)

        # Soft targets for distillation come from the generator.
        feed_dict = {vd_gen.image_ph:image_t}
        soft_logit_t = sess.run(vd_gen.logits, feed_dict=feed_dict)
        feed_dict = {
          tn_tch.image_ph:image_t,
          tn_tch.text_ph:text_t,
          tn_tch.sample_ph:sample_t,
          tn_tch.reward_ph:reward_t,
          tn_tch.hard_label_ph:label_dat_t,
          tn_tch.soft_logit_ph:soft_logit_t,
        }
        _, summary_t = sess.run([tn_tch.kdgan_update, tch_summary_op], feed_dict=feed_dict)
        writer.add_summary(summary_t, batch_t)

      # Generator phase: dis rewards on gen samples + distillation from tch.
      num_batch_g = math.ceil(flags.num_gen_epoch * tn_size / flags.batch_size)
      for _ in range(num_batch_g):
        batch_g += 1
        image_g, text_g, label_dat_g = yfccdata_g.next_batch(flags, sess)

        feed_dict = {tn_tch.image_ph:image_g, tn_tch.text_ph:text_g}
        logit_tch_g = sess.run(tn_tch.logits, feed_dict=feed_dict)

        feed_dict = {tn_gen.image_ph:image_g}
        label_gen_g = sess.run(tn_gen.labels, feed_dict=feed_dict)
        sample_g = utils.generate_label(flags, label_dat_g, label_gen_g)
        feed_dict = {tn_dis.image_ph:image_g, tn_dis.gen_sample_ph:sample_g}
        reward_g = sess.run(tn_dis.gen_rewards, feed_dict=feed_dict)

        feed_dict = {
          tn_gen.image_ph:image_g,
          tn_gen.hard_label_ph:label_dat_g,
          tn_gen.soft_logit_ph:logit_tch_g,
          tn_gen.sample_ph:sample_g,
          tn_gen.reward_ph:reward_g,
        }
        _, summary_g = sess.run([tn_gen.kdgan_update, gen_summary_op], feed_dict=feed_dict)
        writer.add_summary(summary_g, batch_g)

        if (batch_g + 1) % eval_interval != 0:
          continue
        scores = yfcceval.compute_score(flags, sess, vd_gen)
        epk_score_list.append(scores)
        prec = yfcceval.compute_prec(flags, sess, vd_gen)
        if prec > best_prec:
          bst_epk = epoch
        best_prec = max(prec, best_prec)
        tot_time = time.time() - start
        global_step = sess.run(tn_gen.global_step)
        avg_time = (tot_time / global_step) * (tn_size / flags.batch_size)
        print('#%08d@%d prec@%d=%.4f best@%d=%.4f tot=%.0fs avg=%.2fs/epoch' %
            (global_step, epoch, flags.cutoff, prec, bst_epk, best_prec, tot_time, avg_time))

        if prec < best_prec:
          continue
        # save if necessary
  tot_time = time.time() - start
  # NOTE(review): sibling scripts print flags.cutoff in the 'best@%d' slot;
  # confirm bst_epk is intended here.
  print('best@%d=%.4f et=%.0fs' % (bst_epk, best_prec, tot_time))

  utils.create_pardir(flags.epk_learning_curve_p)
  # Close the file deterministically instead of leaking the handle.
  with open(flags.epk_learning_curve_p, 'wb') as fout:
    pickle.dump(epk_score_list, fout)
示例#3
0
def main(_):
    """Plain GAN training on YFCC: alternate dis and gen phases per epoch.

    Relies on module-level graph objects (tn_*/vd_* models, yfccdata_*
    feeds, summary ops, flags, config).

    Side effects: restores two checkpoints, writes TF summaries to
    config.logs_dir, prints progress, and pickles the per-batch precision
    list to flags.all_learning_curve_p.
    """
    best_prec = 0.0  # best generator precision observed so far
    prec_list = []  # gen precision after every gen batch (learning curve)
    writer = tf.summary.FileWriter(config.logs_dir,
                                   graph=tf.get_default_graph())
    with tf.train.MonitoredTrainingSession() as sess:
        sess.run(init_op)
        tn_dis.saver.restore(sess, flags.dis_model_ckpt)
        tn_gen.saver.restore(sess, flags.gen_model_ckpt)

        # Baseline precision of the freshly restored models.
        ini_dis = yfcceval.compute_prec(flags, sess, vd_dis)
        ini_gen = yfcceval.compute_prec(flags, sess, vd_gen)
        print('ini dis=%.4f gen=%.4f' % (ini_dis, ini_gen))

        start = time.time()
        batch_d, batch_g = -1, -1
        for epoch in range(flags.num_epoch):
            # Discriminator phase: real labels vs generator samples.
            num_batch_d = math.ceil(flags.num_dis_epoch * tn_data_size /
                                    flags.batch_size)
            for _ in range(num_batch_d):
                batch_d += 1
                image_np_d, _, label_dat_d = yfccdata_d.next_batch(flags, sess)
                feed_dict = {tn_gen.image_ph: image_np_d}
                label_gen_d = sess.run(tn_gen.labels, feed_dict=feed_dict)
                sample_np_d, label_np_d = utils.gan_dis_sample(
                    flags, label_dat_d, label_gen_d)
                feed_dict = {
                    tn_dis.image_ph: image_np_d,
                    tn_dis.sample_ph: sample_np_d,
                    tn_dis.dis_label_ph: label_np_d,
                }
                _, summary_d = sess.run([tn_dis.gan_update, dis_summary_op],
                                        feed_dict=feed_dict)
                writer.add_summary(summary_d, batch_d)

            # Generator phase: reward-weighted (policy-gradient) update.
            num_batch_g = math.ceil(flags.num_gen_epoch * tn_data_size /
                                    flags.batch_size)
            for _ in range(num_batch_g):
                batch_g += 1
                image_np_g, _, label_dat_g = yfccdata_g.next_batch(flags, sess)
                feed_dict = {tn_gen.image_ph: image_np_g}
                label_gen_g = sess.run(tn_gen.labels, feed_dict=feed_dict)
                sample_np_g = utils.generate_label(flags, label_dat_g,
                                                   label_gen_g)
                feed_dict = {
                    tn_dis.image_ph: image_np_g,
                    tn_dis.sample_ph: sample_np_g,
                }
                reward_np_g = sess.run(tn_dis.rewards, feed_dict=feed_dict)
                feed_dict = {
                    tn_gen.image_ph: image_np_g,
                    tn_gen.sample_ph: sample_np_g,
                    tn_gen.reward_ph: reward_np_g,
                }
                _, summary_g = sess.run([tn_gen.gan_update, gen_summary_op],
                                        feed_dict=feed_dict)
                writer.add_summary(summary_g, batch_g)

                # Precision is computed every batch (not only at eval
                # points) so the pickled learning curve is dense.
                prec = yfcceval.compute_prec(flags, sess, vd_gen)
                prec_list.append(prec)
                if (batch_g + 1) % eval_interval != 0:
                    continue
                best_prec = max(prec, best_prec)
                tot_time = time.time() - start
                global_step = sess.run(tn_gen.global_step)
                avg_time = (tot_time / global_step) * (tn_data_size /
                                                       flags.batch_size)
                print(
                    '#%08d prec@%d=%.4f best=%.4f tot=%.0fs avg=%.2fs/epoch' %
                    (global_step, flags.cutoff, prec, best_prec, tot_time,
                     avg_time))

                # Falls through (to the save hook) only when this batch
                # ties the best precision seen so far.
                if prec < best_prec:
                    continue
    tot_time = time.time() - start
    print('best@%d=%.4f et=%.0fs' % (flags.cutoff, best_prec, tot_time))

    utils.create_pardir(flags.all_learning_curve_p)
    # Close the file deterministically instead of leaking the handle.
    with open(flags.all_learning_curve_p, 'wb') as fout:
        pickle.dump(prec_list, fout)
示例#4
0
def main(_):
    """Adversarial dis/std training loop on CIFAR.

    Each outer epoch runs a discriminator phase followed by a student
    phase; at eval_interval boundaries it reports elapsed hours and (in
    the student phase) the student's test accuracy and best-so-far.
    Relies on module-level graph objects (tn_*/vd_* models, cifar feed,
    flags, eval_interval, init_op).
    """
    bst_acc = 0.0
    start_time = time.time()
    with tf.train.MonitoredTrainingSession() as sess:
        sess.run(init_op)
        tn_dis.saver.restore(sess, flags.dis_model_ckpt)
        tn_std.saver.restore(sess, flags.std_model_ckpt)

        # Baseline accuracies of the restored models.
        ini_dis = cifar.compute_acc(sess, vd_dis)
        ini_std = cifar.compute_acc(sess, vd_std)
        print('ini dis=%.4f ini std=%.4f' % (ini_dis, ini_std))

        batch_d, batch_s = -1, -1
        for epoch in range(flags.num_epoch):
            # ---- discriminator phase ----
            phase_d = math.ceil(flags.num_dis_epoch * flags.train_size /
                                flags.batch_size)
            for _ in range(phase_d):
                batch_d += 1
                images, real_labels = cifar.next_batch(sess)
                std_labels = sess.run(tn_std.labels,
                                      feed_dict={tn_std.image_ph: images})
                samples, dis_labels = utils.gan_dis_sample(
                    flags, real_labels, std_labels)
                sess.run(tn_dis.gan_train, feed_dict={
                    tn_dis.image_ph: images,
                    tn_dis.std_sample_ph: samples,
                    tn_dis.std_label_ph: dis_labels,
                })

                if (batch_d + 1) % eval_interval == 0:
                    hours = (time.time() - start_time) / 3600
                    print('dis #batch=%d duration=%.4fh' % (batch_d, hours))
                    # evaluate dis if necessary

            # ---- student phase ----
            phase_s = math.ceil(flags.num_std_epoch * flags.train_size /
                                flags.batch_size)
            for _ in range(phase_s):
                batch_s += 1
                images, real_labels = cifar.next_batch(sess)
                std_labels = sess.run(tn_std.labels,
                                      feed_dict={tn_std.image_ph: images})
                samples = utils.generate_label(flags, real_labels, std_labels)
                rewards = sess.run(tn_dis.std_rewards, feed_dict={
                    tn_dis.image_ph: images,
                    tn_dis.std_sample_ph: samples,
                })
                sess.run(tn_std.gan_train, feed_dict={
                    tn_std.image_ph: images,
                    tn_std.sample_ph: samples,
                    tn_std.reward_ph: rewards,
                })

                if (batch_s + 1) % eval_interval == 0:
                    acc = cifar.compute_acc(sess, vd_std)
                    bst_acc = max(acc, bst_acc)
                    hours = (time.time() - start_time) / 3600
                    print('gen #batch=%d acc=%.4f bst_acc=%.4f duration=%.4fh' %
                          (batch_s + 1, acc, bst_acc, hours))
                    if acc >= bst_acc:
                        pass  # save std if necessary
    tot_time = time.time() - start_time
    print('#cifar=%d final=%.4f et=%.0fs' %
          (flags.train_size, bst_acc, tot_time))
示例#5
0
def main(_):
    """KDGAN training on MNIST with three players: dis, tch, gen.

    Per epoch: dis trains against teacher samples; tch trains with dis
    rewards (plus gen distillation unless the odgan variant); gen trains
    with dis rewards plus tch distillation.  Relies on module-level graph
    objects (tn_*/vd_* models, datagens, flags, config, eval_interval).

    Side effects: restores three checkpoints, prints progress, and (when
    flags.collect_cr_data) pickles the gen accuracy curve to
    flags.all_learning_curve_p.
    """
    # bst_tch_acc is retained for parity with other variants but is never
    # updated here (teacher reporting is disabled below).
    bst_gen_acc, bst_tch_acc, bst_eph = 0.0, 0.0, 0
    acc_list = []  # gen accuracy curve (filled only when collect_cr_data)
    # NOTE(review): writer is never written to in this variant; kept for
    # parity with the sibling training scripts.
    writer = tf.summary.FileWriter(config.logs_dir,
                                   graph=tf.get_default_graph())

    def _gen_test_acc(sess):
        # Generator accuracy on the MNIST test split (shared by the
        # baseline eval and both branches of the training-loop eval).
        feed_dict = {
            vd_gen.image_ph: gen_mnist.test.images,
            vd_gen.hard_label_ph: gen_mnist.test.labels,
        }
        return sess.run(vd_gen.accuracy, feed_dict=feed_dict)

    with tf.train.MonitoredTrainingSession() as sess:
        sess.run(init_op)
        tn_dis.saver.restore(sess, flags.dis_model_ckpt)
        tn_gen.saver.restore(sess, flags.gen_model_ckpt)
        tn_tch.saver.restore(sess, flags.tch_model_ckpt)

        # Baseline accuracies of the freshly restored models.
        feed_dict = {
            vd_dis.image_ph: dis_mnist.test.images,
            vd_dis.hard_label_ph: dis_mnist.test.labels,
        }
        ini_dis = sess.run(vd_dis.accuracy, feed_dict=feed_dict)
        ini_gen = _gen_test_acc(sess)
        print('ini dis=%.4f ini gen=%.4f' % (ini_dis, ini_gen))

        start = time.time()
        batch_d, batch_g, batch_t = -1, -1, -1
        for epoch in range(flags.num_epoch):
            # Discriminator phase: real labels vs teacher samples.
            for dis_epoch in range(flags.num_dis_epoch):
                for image_d, label_dat_d in dis_datagen.generate(
                        batch_size=flags.batch_size):
                    batch_d += 1
                    feed_dict = {tn_tch.image_ph: image_d}
                    label_tch_d = sess.run(tn_tch.labels, feed_dict=feed_dict)
                    sample_tch_d, dis_label_tch = utils.gan_dis_sample(
                        flags, label_dat_d, label_tch_d)
                    feed_dict = {
                        tn_dis.image_ph: image_d,
                        tn_dis.sample_ph: sample_tch_d,
                        tn_dis.dis_label_ph: dis_label_tch,
                    }
                    sess.run(tn_dis.gan_update, feed_dict=feed_dict)

            # Teacher phase: dis rewards on tch samples, optionally plus
            # distillation from the generator's soft logits.
            for tch_epoch in range(flags.num_tch_epoch):
                for image_t, label_dat_t in tch_datagen.generate(
                        batch_size=flags.batch_size):
                    batch_t += 1
                    feed_dict = {tn_tch.image_ph: image_t}
                    label_tch_t = sess.run(tn_tch.labels, feed_dict=feed_dict)
                    sample_t = utils.generate_label(flags, label_dat_t,
                                                    label_tch_t)
                    feed_dict = {
                        tn_dis.image_ph: image_t,
                        tn_dis.sample_ph: sample_t,
                    }
                    reward_t = sess.run(tn_dis.rewards, feed_dict=feed_dict)

                    # Build the base GAN feed once; non-odgan variants
                    # extend it with distillation targets.
                    feed_dict = {
                        tn_tch.image_ph: image_t,
                        tn_tch.sample_ph: sample_t,
                        tn_tch.reward_ph: reward_t,
                    }
                    if flags.kdgan_model != config.kdgan_odgan_flag:
                        soft_logit_t = sess.run(
                            vd_gen.logits,
                            feed_dict={vd_gen.image_ph: image_t})
                        feed_dict[tn_tch.hard_label_ph] = label_dat_t
                        feed_dict[tn_tch.soft_logit_ph] = soft_logit_t
                    sess.run(tn_tch.kdgan_update, feed_dict=feed_dict)

                    if (batch_t + 1) % eval_interval != 0:
                        continue
                    feed_dict = {
                        vd_tch.image_ph: gen_mnist.test.images,
                        vd_tch.hard_label_ph: gen_mnist.test.labels,
                    }
                    # Evaluated but currently unused (teacher accuracy
                    # reporting is disabled).
                    tch_acc = sess.run(vd_tch.accuracy, feed_dict=feed_dict)

            # Generator phase: dis rewards plus tch distillation.
            for gen_epoch in range(flags.num_gen_epoch):
                for image_g, label_dat_g in gen_datagen.generate(
                        batch_size=flags.batch_size):
                    batch_g += 1
                    feed_dict = {tn_gen.image_ph: image_g}
                    label_gen_g = sess.run(tn_gen.labels, feed_dict=feed_dict)
                    sample_g = utils.generate_label(flags, label_dat_g,
                                                    label_gen_g)
                    feed_dict = {
                        tn_dis.image_ph: image_g,
                        tn_dis.sample_ph: sample_g,
                    }
                    reward_g = sess.run(tn_dis.rewards, feed_dict=feed_dict)

                    # Soft targets for distillation come from the teacher.
                    feed_dict = {vd_tch.image_ph: image_g}
                    soft_logit_g = sess.run(vd_tch.logits, feed_dict=feed_dict)

                    feed_dict = {
                        tn_gen.image_ph: image_g,
                        tn_gen.sample_ph: sample_g,
                        tn_gen.reward_ph: reward_g,
                        tn_gen.hard_label_ph: label_dat_g,
                        tn_gen.soft_logit_ph: soft_logit_g,
                    }
                    sess.run(tn_gen.kdgan_update, feed_dict=feed_dict)

                    if flags.collect_cr_data:
                        # Dense learning curve: evaluate every batch, but
                        # only report at eval_interval boundaries.
                        acc = _gen_test_acc(sess)
                        acc_list.append(acc)
                        if (batch_g + 1) % eval_interval != 0:
                            continue
                    else:
                        if (batch_g + 1) % eval_interval != 0:
                            continue
                        acc = _gen_test_acc(sess)

                    if acc > bst_gen_acc:
                        bst_gen_acc = acc
                        bst_eph = epoch
                    tot_time = time.time() - start
                    global_step = sess.run(tn_gen.global_step)
                    avg_time = (tot_time / global_step) * (tn_size /
                                                           flags.batch_size)
                    print(
                        '#%08d gencur=%.4f genbst=%.4f tot=%.0fs avg=%.2fs/epoch'
                        % (batch_g, acc, bst_gen_acc, tot_time, avg_time))

                    # NOTE(review): always true after the update above, so
                    # the save hook never runs.
                    if acc <= bst_gen_acc:
                        continue
                    # save gen parameters if necessary
    tot_time = time.time() - start
    bst_gen_acc *= 100  # report as a percentage
    bst_eph += 1  # report a 1-based epoch index
    print('#mnist=%d kdgan_%s@%d=%.2f et=%.0fs' %
          (tn_size, flags.kdgan_model, bst_eph, bst_gen_acc, tot_time))

    if flags.collect_cr_data:
        utils.create_pardir(flags.all_learning_curve_p)
        # Close the file deterministically instead of leaking the handle.
        with open(flags.all_learning_curve_p, 'wb') as fout:
            pickle.dump(acc_list, fout)
示例#6
0
def main(_):
  """Plain GAN loop on YFCC: one dis phase then one gen phase per epoch.

  Writes summaries for both players and, at eval_interval boundaries,
  reports the generator's precision while tracking the best value and
  the epoch it occurred in.  Relies on module-level graph objects
  (tn_*/vd_* models, yfccdata_* feeds, summary ops, flags, config).
  """
  best_prec, bst_epk = 0.0, 0
  writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    start = time.time()

    init_prec = yfcceval.compute_prec(flags, sess, vd_gen)
    print('init@%d=%.4f' % (flags.cutoff, init_prec))

    batch_d, batch_g = -1, -1
    for epoch in range(flags.num_epoch):
      # Discriminator phase: real labels vs generator samples.
      dis_batches = math.ceil(flags.num_dis_epoch * tn_size / flags.batch_size)
      for _ in range(dis_batches):
        batch_d += 1
        img_d, _, real_lbl_d = yfccdata_d.next_batch(flags, sess)
        fake_lbl_d = sess.run(tn_gen.labels,
                              feed_dict={tn_gen.image_ph: img_d})
        smp_d, dis_lbl_d = utils.gan_dis_sample(flags, real_lbl_d, fake_lbl_d)
        _, summary_d = sess.run(
            [tn_dis.gan_update, dis_summary_op],
            feed_dict={
              tn_dis.image_ph: img_d,
              tn_dis.sample_ph: smp_d,
              tn_dis.dis_label_ph: dis_lbl_d,
            })
        writer.add_summary(summary_d, batch_d)

      # Generator phase: reward-weighted update driven by dis.
      gen_batches = math.ceil(flags.num_gen_epoch * tn_size / flags.batch_size)
      for _ in range(gen_batches):
        batch_g += 1
        img_g, _, real_lbl_g = yfccdata_g.next_batch(flags, sess)
        fake_lbl_g = sess.run(tn_gen.labels,
                              feed_dict={tn_gen.image_ph: img_g})
        smp_g = utils.generate_label(flags, real_lbl_g, fake_lbl_g)
        rwd_g = sess.run(
            tn_dis.rewards,
            feed_dict={tn_dis.image_ph: img_g, tn_dis.sample_ph: smp_g})
        _, summary_g = sess.run(
            [tn_gen.gan_update, gen_summary_op],
            feed_dict={
              tn_gen.image_ph: img_g,
              tn_gen.sample_ph: smp_g,
              tn_gen.reward_ph: rwd_g,
            })
        writer.add_summary(summary_g, batch_g)

        if (batch_g + 1) % eval_interval != 0:
          continue
        prec = yfcceval.compute_prec(flags, sess, vd_gen)
        if prec > best_prec:
          best_prec = prec
          bst_epk = epoch
        tot_time = time.time() - start
        global_step = sess.run(tn_gen.global_step)
        avg_time = (tot_time / global_step) * (tn_size / flags.batch_size)
        print('#%08d@%d prec@%d=%.4f best@%d=%.4f tot=%.0fs avg=%.2fs/epoch' %
            (global_step, epoch, flags.cutoff, prec, bst_epk, best_prec,
             tot_time, avg_time))

        if prec < best_prec:
          continue
        # save if necessary
  tot_time = time.time() - start
  print('best@%d=%.4f et=%.0fs' % (flags.cutoff, best_prec, tot_time))
示例#7
0
def main(_):
    """Adversarially train an image-tag generator against a discriminator.

    Alternates, per epoch, a discriminator phase and a generator phase over
    shuffled TFRecord batches, logging summaries to TensorBoard, evaluating
    the generator on the validation split after each epoch, and finally
    dumping (epoch, hit) pairs to ``flags.gan_figure_data``.

    NOTE(review): depends on module-level objects defined elsewhere in this
    file (``dis_t``/``gen_t``/``gen_v`` model graphs, ``flags``, ``config``,
    ``utils``, ``train_data_size``) — verify against the surrounding module.
    """
    # Per-variable parameter counts: a quick model-size sanity check.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    dis_summary_op = tf.summary.merge([
        tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
        tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
    ])
    gen_summary_op = tf.summary.merge([
        tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
        tf.summary.scalar(gen_t.gan_loss.name, gen_t.gan_loss),
    ])
    print(type(dis_summary_op), type(gen_summary_op))
    init_op = tf.global_variables_initializer()

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    # Two independent shuffled readers over the training data: one stream
    # feeds the discriminator phase, the other the generator phase.
    ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
    user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d

    ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
    user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g

    # Validation reader: deterministic order, its own batch size.
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

    figure_data = []
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        # Warm-start both players from pretrained checkpoints.
        dis_t.saver.restore(sess, flags.dis_model_ckpt)
        gen_t.saver.restore(sess, flags.gen_model_ckpt)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        with slim.queues.QueueRunners(sess):
            # Baseline hit rate of the restored generator.
            hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
            print('init hit=%.4f' % (hit_v))

            batch_d, batch_g = -1, -1
            for epoch in range(flags.num_epoch):
                # --- Discriminator phase: train on a mix of data labels and
                # generator samples. ---
                for dis_epoch in range(flags.num_dis_epoch):
                    print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
                    num_batch_d = math.ceil(train_data_size / flags.batch_size)
                    for _ in range(num_batch_d):
                        batch_d += 1
                        image_np_d, label_dat_d = sess.run(
                            [image_bt_d, label_bt_d])
                        feed_dict = {gen_t.image_ph: image_np_d}
                        label_gen_d, = sess.run([gen_t.labels],
                                                feed_dict=feed_dict)
                        sample_np_d, label_np_d = utils.gan_dis_sample(
                            flags, label_dat_d, label_gen_d)
                        feed_dict = {
                            dis_t.image_ph: image_np_d,
                            dis_t.sample_ph: sample_np_d,
                            dis_t.dis_label_ph: label_np_d,
                        }
                        _, summary_d = sess.run(
                            [dis_t.gan_update, dis_summary_op],
                            feed_dict=feed_dict)
                        writer.add_summary(summary_d, batch_d)

                # --- Generator phase: policy-gradient update driven by the
                # discriminator's rewards on sampled labels. ---
                for gen_epoch in range(flags.num_gen_epoch):
                    print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
                    num_batch_g = math.ceil(train_data_size / flags.batch_size)
                    for _ in range(num_batch_g):
                        batch_g += 1
                        image_np_g, label_dat_g = sess.run(
                            [image_bt_g, label_bt_g])
                        feed_dict = {gen_t.image_ph: image_np_g}
                        label_gen_g, = sess.run([gen_t.labels],
                                                feed_dict=feed_dict)
                        sample_np_g = utils.generate_label(
                            flags, label_dat_g, label_gen_g)
                        feed_dict = {
                            dis_t.image_ph: image_np_g,
                            dis_t.sample_ph: sample_np_g,
                        }
                        reward_np_g, = sess.run([dis_t.rewards],
                                                feed_dict=feed_dict)
                        feed_dict = {
                            gen_t.image_ph: image_np_g,
                            gen_t.sample_ph: sample_np_g,
                            gen_t.reward_ph: reward_np_g,
                        }
                        _, summary_g = sess.run(
                            [gen_t.gan_update, gen_summary_op],
                            feed_dict=feed_dict)
                        writer.add_summary(summary_g, batch_g)

                # Evaluate once per epoch and track the best validation hit.
                hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
                tot_time = time.time() - start
                print('#%03d curbst=%.4f %.0fs' % (epoch, hit_v, tot_time))
                figure_data.append((epoch, hit_v))
                if hit_v <= best_hit_v:
                    continue
                best_hit_v = hit_v
    print('bsthit=%.4f' % (best_hit_v))

    utils.create_if_nonexist(os.path.dirname(flags.gan_figure_data))
    # Context manager guarantees the figure file is flushed and closed even
    # if a write raises (the original leaked the handle on error).
    with open(flags.gan_figure_data, 'w') as fout:
        for epoch, hit_v in figure_data:
            fout.write('%d\t%.4f\n' % (epoch, hit_v))
Example #8
0
def main(_):
  """KDGAN training on MNIST: alternate discriminator / teacher / generator.

  Each epoch runs a discriminator phase (real vs. generator vs. teacher
  samples), a teacher phase (reward + distillation update), an optional
  Gumbel temperature decay, and a generator phase (reward + distillation
  update), tracking the generator's best test accuracy. Optionally logs
  per-batch accuracies, teacher accuracies, and gradient pickles.

  NOTE(review): depends on module-level objects defined elsewhere in this
  file (``tn_*``/``vd_*`` model graphs, ``*_datagen``, ``*_mnist``,
  ``flags``, ``init_op``, ``eval_interval``, ``tot_batch``, ``tn_size``) —
  verify against the surrounding module.
  """
  bst_gen_acc, bst_tch_acc, bst_eph = 0.0, 0.0, 0
  utils.create_if_nonexist(flags.gradient_dir)
  if flags.log_accuracy:
    acc_history = []
  if flags.evaluate_tch:
    tch_history = []
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    tn_tch.saver.restore(sess, flags.tch_model_ckpt)

    # Baseline test accuracies of the restored models before training.
    feed_dict = {
      vd_dis.image_ph:dis_mnist.test.images,
      vd_dis.hard_label_ph:dis_mnist.test.labels,
    }
    ini_dis = sess.run(vd_dis.accuracy, feed_dict=feed_dict)
    feed_dict = {
      vd_gen.image_ph:gen_mnist.test.images,
      vd_gen.hard_label_ph:gen_mnist.test.labels,
    }
    ini_gen = sess.run(vd_gen.accuracy, feed_dict=feed_dict)
    print('ini dis=%.4f ini gen=%.4f' % (ini_dis, ini_gen))

    start = time.time()
    batch_d, batch_g, batch_t = -1, -1, -1
    # Number of decay steps needed to anneal the Gumbel temperature from
    # its start value down to gumbel_end_temperature.
    gumbel_times = (math.log(flags.gumbel_end_temperature / flags.gumbel_temperature)
        / math.log(flags.gumbel_temperature_decay_factor))
    for epoch in range(flags.num_epoch):
      # --- Discriminator phase: distinguish data labels from generator and
      # teacher samples. ---
      for dis_epoch in range(flags.num_dis_epoch):
        for image_d, label_dat_d in dis_datagen.generate(batch_size=flags.batch_size):
          batch_d += 1

          feed_dict = {tn_gen.image_ph:image_d}
          label_gen_d = sess.run(tn_gen.labels, feed_dict=feed_dict)
          sample_gen_d, gen_label_d = utils.gan_dis_sample(flags, label_dat_d, label_gen_d)

          feed_dict = {tn_tch.image_ph:image_d}
          label_tch_d = sess.run(tn_tch.labels, feed_dict=feed_dict)
          sample_tch_d, tch_label_d = utils.gan_dis_sample(flags, label_dat_d, label_tch_d)

          feed_dict = {
            tn_dis.image_ph:image_d,
            tn_dis.gen_sample_ph:sample_gen_d,
            tn_dis.gen_label_ph:gen_label_d,
            tn_dis.tch_sample_ph:sample_tch_d,
            tn_dis.tch_label_ph:tch_label_d,
          }
          sess.run(tn_dis.gan_update, feed_dict=feed_dict)

      # --- Teacher phase: reward from discriminator + distillation from the
      # (fixed) generator logits. ---
      for tch_epoch in range(flags.num_tch_epoch):
        for image_t, label_dat_t in tch_datagen.generate(batch_size=flags.batch_size):
          batch_t += 1

          feed_dict = {tn_tch.image_ph:image_t}
          label_tch_t = sess.run(tn_tch.labels, feed_dict=feed_dict)
          sample_t = utils.generate_label(flags, label_dat_t, label_tch_t)
          feed_dict = {
            tn_dis.image_ph:image_t,
            tn_dis.tch_sample_ph:sample_t,
          }
          reward_t = sess.run(tn_dis.tch_rewards, feed_dict=feed_dict)

          feed_dict = {vd_gen.image_ph:image_t}
          soft_logit_t = sess.run(vd_gen.logits, feed_dict=feed_dict)
          feed_dict = {
            tn_tch.image_ph:image_t,
            tn_tch.sample_ph:sample_t,
            tn_tch.reward_ph:reward_t,
            tn_tch.hard_label_ph:label_dat_t,
            tn_tch.soft_logit_ph:soft_logit_t,
          }

          sess.run(tn_tch.kdgan_update, feed_dict=feed_dict)

          if not flags.evaluate_tch:
            continue
          if (batch_t + 1) % eval_interval != 0:
            continue
          feed_dict = {
            vd_tch.image_ph:gen_mnist.test.images,
            vd_tch.hard_label_ph:gen_mnist.test.labels,
          }
          tch_acc = sess.run(vd_tch.accuracy, feed_dict=feed_dict)
          bst_tch_acc = max(tch_acc, bst_tch_acc)
          print('#%08d tchcur=%.4f tchbst=%.4f' % (batch_t, tch_acc, bst_tch_acc))
          tch_history.append(tch_acc)

      #### gumbel softmax
      # Decay the temperature evenly so it reaches the end value after
      # roughly num_epoch epochs.
      if flags.enable_gumbel:
        if (epoch + 1) % max(int(flags.num_epoch / gumbel_times), 1) == 0:
          sess.run(tn_gen.gt_update)

      # --- Generator phase: reward from discriminator + distillation from
      # the teacher logits. ---
      for gen_epoch in range(flags.num_gen_epoch):
        batch = -1
        for image_g, label_dat_g in gen_datagen.generate(batch_size=flags.batch_size):
          batch_g += 1
          batch += 1
          epk_bat = '%d.%d' % (epoch*flags.num_gen_epoch+gen_epoch, batch)
          ggrads_file = path.join(flags.gradient_dir, 'kdgan_ggrads.%s.p' % epk_bat)
          kgrads_file = path.join(flags.gradient_dir, 'kdgan_kgrads.%s.p' % epk_bat)

          feed_dict = {tn_gen.image_ph:image_g}

          if not flags.enable_gumbel:
            label_gen_g = sess.run(tn_gen.labels, feed_dict=feed_dict)
          else:
            label_gen_g = sess.run(tn_gen.gumbel_labels, feed_dict=feed_dict)

          sample_g = utils.generate_label(flags, label_dat_g, label_gen_g)
          feed_dict = {
            tn_dis.image_ph:image_g,
            tn_dis.gen_sample_ph:sample_g,
          }
          reward_g = sess.run(tn_dis.gen_rewards, feed_dict=feed_dict)

          feed_dict = {vd_tch.image_ph:image_g}
          soft_logit_g = sess.run(vd_tch.logits, feed_dict=feed_dict)

          feed_dict = {
            tn_gen.image_ph:image_g,
            tn_gen.sample_ph:sample_g,
            tn_gen.reward_ph:reward_g,
            tn_gen.hard_label_ph:label_dat_g,
            tn_gen.soft_logit_ph:soft_logit_g,
          }
          if flags.log_gradient:
            fetches = [tn_gen.kdgan_ggrads, tn_gen.kdgan_kgrads, tn_gen.kdgan_update]
            kdgan_ggrads, kdgan_kgrads, _ = sess.run(fetches, feed_dict=feed_dict)
            # Close the gradient dumps deterministically instead of leaking a
            # file handle on every logged batch.
            with open(ggrads_file, 'wb') as fggrads:
              pickle.dump(kdgan_ggrads, fggrads)
            with open(kgrads_file, 'wb') as fkgrads:
              pickle.dump(kdgan_kgrads, fkgrads)
          else:
            sess.run(tn_gen.kdgan_update, feed_dict=feed_dict)

          # When logging accuracy, evaluate every batch; otherwise only at
          # the eval interval.
          if flags.log_accuracy:
            feed_dict = {
              vd_gen.image_ph:gen_mnist.test.images,
              vd_gen.hard_label_ph:gen_mnist.test.labels,
            }
            acc = sess.run(vd_gen.accuracy, feed_dict=feed_dict)
            acc_history.append(acc)
            if (batch_g + 1) % eval_interval != 0:
              continue
          else:
            if (batch_g + 1) % eval_interval != 0:
              continue
            feed_dict = {
              vd_gen.image_ph:gen_mnist.test.images,
              vd_gen.hard_label_ph:gen_mnist.test.labels,
            }
            acc = sess.run(vd_gen.accuracy, feed_dict=feed_dict)

          if acc > bst_gen_acc:
            # acc is strictly greater here, so it is the new best.
            bst_gen_acc = acc
            bst_eph = epoch
          tot_time = time.time() - start
          global_step = sess.run(tn_gen.global_step)
          if flags.evaluate_tch:
            gen_tch_pct = 100 * bst_gen_acc / bst_tch_acc
            print('#%08d/%08d gencur=%.4f genbst=%.4f (%.2f) tot=%.0fs' %
                (batch_g, tot_batch, acc, bst_gen_acc, gen_tch_pct, tot_time))
          else:
            print('#%08d/%08d gencur=%.4f genbst=%.4f tot=%.0fs' %
                (batch_g, tot_batch, acc, bst_gen_acc, tot_time))

          stdout.flush()
          if acc <= bst_gen_acc:
            continue
          # save gen parameters if necessary
    # NOTE(review): fetched even when flags.enable_gumbel is False — assumes
    # the graph always defines gumbel_temperature; verify.
    gumbel_temperature = sess.run(tn_gen.gumbel_temperature)
    print('gumbel_temperature=%.4f' % gumbel_temperature)
  tot_time = time.time() - start
  bst_gen_acc *= 100
  bst_eph += 1
  print('#mnist=%d kdgan@%d=%.2f et=%.0fs' % (tn_size, bst_eph, bst_gen_acc, tot_time))

  if flags.log_accuracy:
    utils.create_pardir(flags.acc_file)
    # Context manager closes the history file even if pickling fails.
    with open(flags.acc_file, 'wb') as facc:
      pickle.dump(acc_history, facc)

  if flags.evaluate_tch:
    utils.create_pardir(flags.tch_file)
    with open(flags.tch_file, 'wb') as ftch:
      pickle.dump(tch_history, ftch)
Example #9
0
def main(_):
    """KDGAN training: alternate discriminator / teacher / generator phases.

    Each epoch trains the discriminator against samples mixed from data,
    generator, and teacher labels, then updates the teacher and generator
    with discriminator rewards plus mutual distillation. Evaluates both
    students per epoch, checkpoints the generator on a new best validation
    hit, and dumps (epoch, gen_hit, tch_hit) to ``flags.kdgan_figure_data``.

    NOTE(review): depends on module-level objects defined elsewhere in this
    file (``dis_t``/``gen_t``/``tch_t``/``gen_v``/``tch_v`` graphs,
    ``flags``, ``config``, ``utils``, ``num_batch_per_epoch``) — verify
    against the surrounding module.
    """
    # Per-variable parameter counts: a quick model-size sanity check.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    dis_summary_op = tf.summary.merge([
        tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
        tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
    ])
    gen_summary_op = tf.summary.merge([
        tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
        tf.summary.scalar(gen_t.kdgan_loss.name, gen_t.kdgan_loss),
    ])
    tch_summary_op = tf.summary.merge([
        tf.summary.scalar(tch_t.learning_rate.name, tch_t.learning_rate),
        tf.summary.scalar(tch_t.kdgan_loss.name, tch_t.kdgan_loss),
    ])
    init_op = tf.global_variables_initializer()

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    # Three independent shuffled readers over the training data, one per
    # training phase (discriminator / generator / teacher).
    ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
    user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d

    ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
    user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g

    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
    user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t

    # Validation reader: deterministic order, its own batch size.
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

    figure_data = []
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        # Warm-start all three players from pretrained checkpoints.
        dis_t.saver.restore(sess, flags.dis_model_ckpt)
        gen_t.saver.restore(sess, flags.gen_model_ckpt)
        tch_t.saver.restore(sess, flags.tch_model_ckpt)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        with slim.queues.QueueRunners(sess):
            # Baseline hit rates of the restored generator and teacher.
            gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
            tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)
            print('hit gen=%.4f tch=%.4f' % (gen_hit, tch_hit))

            batch_d, batch_g, batch_t = -1, -1, -1
            for epoch in range(flags.num_epoch):
                # --- Discriminator phase: samples mixed from data,
                # generator, and teacher labels. ---
                for dis_epoch in range(flags.num_dis_epoch):
                    print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
                    for _ in range(num_batch_per_epoch):
                        batch_d += 1
                        image_d, text_d, label_dat_d = sess.run(
                            [image_bt_d, text_bt_d, label_bt_d])

                        feed_dict = {gen_t.image_ph: image_d}
                        label_gen_d, = sess.run([gen_t.labels],
                                                feed_dict=feed_dict)
                        feed_dict = {
                            tch_t.text_ph: text_d,
                            tch_t.image_ph: image_d
                        }
                        label_tch_d, = sess.run([tch_t.labels],
                                                feed_dict=feed_dict)

                        sample_d, label_d = utils.kdgan_dis_sample(
                            flags, label_dat_d, label_gen_d, label_tch_d)

                        feed_dict = {
                            dis_t.image_ph: image_d,
                            dis_t.sample_ph: sample_d,
                            dis_t.dis_label_ph: label_d,
                        }
                        _, summary_d = sess.run(
                            [dis_t.gan_update, dis_summary_op],
                            feed_dict=feed_dict)
                        writer.add_summary(summary_d, batch_d)

                # --- Teacher phase: discriminator reward + distillation
                # from the generator's logits. ---
                for tch_epoch in range(flags.num_tch_epoch):
                    print('epoch %03d tch_epoch %03d' % (epoch, tch_epoch))
                    for _ in range(num_batch_per_epoch):
                        batch_t += 1
                        image_t, text_t, label_dat_t = sess.run(
                            [image_bt_t, text_bt_t, label_bt_t])

                        feed_dict = {
                            tch_t.text_ph: text_t,
                            tch_t.image_ph: image_t
                        }
                        label_tch_t, = sess.run([tch_t.labels],
                                                feed_dict=feed_dict)
                        sample_t = utils.generate_label(
                            flags, label_dat_t, label_tch_t)
                        feed_dict = {
                            dis_t.image_ph: image_t,
                            dis_t.sample_ph: sample_t,
                        }
                        reward_t, = sess.run([dis_t.rewards],
                                             feed_dict=feed_dict)

                        feed_dict = {
                            gen_t.image_ph: image_t,
                        }
                        label_gen_g = sess.run(gen_t.logits,
                                               feed_dict=feed_dict)
                        feed_dict = {
                            tch_t.text_ph: text_t,
                            tch_t.image_ph: image_t,
                            tch_t.sample_ph: sample_t,
                            tch_t.reward_ph: reward_t,
                            tch_t.hard_label_ph: label_dat_t,
                            tch_t.soft_label_ph: label_gen_g,
                        }

                        _, summary_t, tch_kdgan_loss = sess.run(
                            [
                                tch_t.kdgan_update, tch_summary_op,
                                tch_t.kdgan_loss
                            ],
                            feed_dict=feed_dict)
                        writer.add_summary(summary_t, batch_t)

                # --- Generator phase: discriminator reward + distillation
                # from the teacher's labels. ---
                for gen_epoch in range(flags.num_gen_epoch):
                    print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
                    for _ in range(num_batch_per_epoch):
                        batch_g += 1
                        image_g, text_g, label_dat_g = sess.run(
                            [image_bt_g, text_bt_g, label_bt_g])

                        feed_dict = {
                            tch_t.text_ph: text_g,
                            tch_t.image_ph: image_g
                        }
                        label_tch_g, = sess.run([tch_t.labels],
                                                feed_dict=feed_dict)

                        feed_dict = {gen_t.image_ph: image_g}
                        label_gen_g, = sess.run([gen_t.labels],
                                                feed_dict=feed_dict)
                        sample_g = utils.generate_label(
                            flags, label_dat_g, label_gen_g)
                        feed_dict = {
                            dis_t.image_ph: image_g,
                            dis_t.sample_ph: sample_g,
                        }
                        reward_g, = sess.run([dis_t.rewards],
                                             feed_dict=feed_dict)

                        feed_dict = {
                            gen_t.image_ph: image_g,
                            gen_t.hard_label_ph: label_dat_g,
                            gen_t.soft_label_ph: label_tch_g,
                            gen_t.sample_ph: sample_g,
                            gen_t.reward_ph: reward_g,
                        }
                        _, summary_g = sess.run(
                            [gen_t.kdgan_update, gen_summary_op],
                            feed_dict=feed_dict)
                        writer.add_summary(summary_g, batch_g)

                # Evaluate both students once per epoch; checkpoint the
                # generator only on a new best validation hit.
                gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
                tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)

                tot_time = time.time() - start
                print('#%03d curgen=%.4f curtch=%.4f %.0fs' %
                      (epoch, gen_hit, tch_hit, tot_time))
                figure_data.append((epoch, gen_hit, tch_hit))
                if gen_hit <= best_hit_v:
                    continue
                best_hit_v = gen_hit
                print("epoch ", epoch + 1, ":, new best validation hit:",
                      best_hit_v, "saving...")
                gen_t.saver.save(sess,
                                 flags.kdgan_model_ckpt,
                                 global_step=epoch + 1)
                print("finish saving")

    print('best hit=%.4f' % (best_hit_v))

    utils.create_if_nonexist(os.path.dirname(flags.kdgan_figure_data))
    # Context manager guarantees the figure file is flushed and closed even
    # if a write raises (the original leaked the handle on error).
    with open(flags.kdgan_figure_data, 'w') as fout:
        for epoch, gen_hit, tch_hit in figure_data:
            fout.write('%d\t%.4f\t%.4f\n' % (epoch, gen_hit, tch_hit))