예제 #1
0
파일: pretrain.py 프로젝트: xiaojiew1/KDGAN
def main(_):
    """Pre-train the generator and report validation hit at a fixed interval.

    Builds a training ``GEN`` and, after enabling variable reuse on the
    current variable scope, an evaluation ``GEN``; sets up the train/valid
    TFRecord batch tensors; then runs ``num_batch_t`` training steps. Every
    ``int(config.train_data_size / config.train_batch_size)`` steps the
    evaluation model's logits are scored with ``compute_hit`` over
    ``num_batch_v`` validation batches and a progress line is printed.

    ``flags``, ``config``, ``num_batch_t`` and ``num_batch_v`` are read from
    module level; they are not defined inside this function.

    Args:
        _: Ignored.
    """
    print('#label={}'.format(config.num_label))
    gen_t = GEN(flags, is_training=True)
    # Enable reuse on the current scope before building the evaluation model.
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    gen_v = GEN(flags, is_training=False)

    # train_filename = 'yfcc10k_{}.train.tfrecord'.format(flags.model_name)
    # train_tfrecord = path.join(config.tfrecord_dir, train_filename)
    # valid_filename = 'yfcc10k_{}.valid.tfrecord'.format(flags.model_name)
    # valid_tfrecord = path.join(config.tfrecord_dir, valid_filename)

    data_sources_t = utils.get_data_sources(flags, config.train_file, 500)
    data_sources_v = utils.get_data_sources(flags, config.valid_file, 1)
    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    # check_ts_list(ts_list_t)
    # check_ts_list(ts_list_v)
    bt_list_t = utils.generate_batch(ts_list_t, config.train_batch_size)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
    # check_bt_list(bt_list_t, config.train_batch_size)
    # check_bt_list(bt_list_v, config.valid_batch_size)

    user_bt_t, image_bt_t, text_bt_t, label_bt_t, image_file_bt_t = bt_list_t
    user_bt_v, image_bt_v, text_bt_v, label_bt_v, image_file_bt_v = bt_list_v

    # NOTE: not updated anywhere in this function as shown.
    best_hit_v = -np.inf
    init_op = tf.global_variables_initializer()
    # Loop-invariant: number of training steps between two validation passes.
    steps_per_eval = int(config.train_data_size / config.train_batch_size)
    start = time.time()
    with tf.Session() as sess:
        writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
        # Close the writer on every exit path so buffered summaries are
        # flushed to disk instead of being dropped at interpreter exit.
        try:
            sess.run(init_op)
            with slim.queues.QueueRunners(sess):
                for batch_t in range(num_batch_t):
                    image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
                    feed_dict = {gen_t.image_ph:image_np_t, gen_t.label_ph:label_np_t}
                    _, summary = sess.run([gen_t.train_op, gen_t.summary_op], feed_dict=feed_dict)
                    writer.add_summary(summary, batch_t)

                    # Skip validation except on every steps_per_eval-th step.
                    if (batch_t + 1) % steps_per_eval != 0:
                        continue

                    hit_v = []
                    # Collected per validation pass; not read afterwards here.
                    image_file_v = set()
                    for batch_v in range(num_batch_v):
                        image_np_v, label_np_v, image_file_np_v = sess.run([image_bt_v, label_bt_v, image_file_bt_v])
                        feed_dict = {gen_v.image_ph:image_np_v}
                        logit_np_v, = sess.run([gen_v.logits], feed_dict=feed_dict)
                        for image_file in image_file_np_v:
                            image_file_v.add(image_file)
                        hit_bt = compute_hit(logit_np_v, label_np_v, flags.cutoff)
                        hit_v.append(hit_bt)
                    # Average the per-batch hit scores.
                    hit_v = np.mean(hit_v)

                    total_time = time.time() - start
                    avg_batch = total_time / (batch_t + 1)
                    # Estimated seconds per epoch from the mean step time.
                    avg_epoch = avg_batch * (config.train_data_size / config.train_batch_size)
                    s = '{0} hit={1:.4f} tot={2:.0f}s avg={3:.0f}s'
                    s = s.format(batch_t, hit_v, total_time, avg_epoch)
                    print(s)
        finally:
            writer.close()
예제 #2
0
def main(_):
  """Pre-train the teacher model and checkpoint it when validation hit does not drop.

  Builds a training ``TCH`` and, after enabling variable reuse on the current
  variable scope, an evaluation ``TCH``; prints the parameter count of every
  trainable variable; then runs ``num_batch_t`` training steps on text
  batches. Every ``eval_interval`` steps the evaluation model is scored with
  ``metric.compute_hit`` over ``num_batch_v`` validation batches, and the
  training model's saver writes a checkpoint unless the score is strictly
  below the best score seen so far.

  ``flags``, ``config``, ``num_batch_t``, ``num_batch_v`` and
  ``eval_interval`` are read from module level.

  Args:
    _: Ignored.
  """
  # The returned tensor is not used here; the call itself is kept.
  global_step = tf.train.create_global_step()
  tch_t = TCH(flags, is_training=True)
  # Enable reuse on the current scope before building the evaluation model.
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  tch_v = TCH(flags, is_training=False)

  # Report the element count of each trainable variable.
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('{}\t({} params)'.format(variable.name, num_params))

  data_sources_t = utils.get_data_sources(flags, is_training=True, single=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))

  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, config.train_batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

  best_hit_v = -np.inf
  init_op = tf.global_variables_initializer()
  start = time.time()
  with tf.Session() as sess:
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    # Close the writer on every exit path so buffered summaries are flushed
    # to disk instead of being dropped at interpreter exit.
    try:
      sess.run(init_op)
      with slim.queues.QueueRunners(sess):
        for batch_t in range(num_batch_t):
          text_np_t, label_np_t = sess.run([text_bt_t, label_bt_t])
          feed_dict = {tch_t.text_ph:text_np_t, tch_t.label_ph:label_np_t}
          _, summary = sess.run([tch_t.train_op, tch_t.summary_op], feed_dict=feed_dict)
          writer.add_summary(summary, batch_t)

          # Skip validation except on every eval_interval-th step.
          if (batch_t + 1) % eval_interval != 0:
            continue

          hit_v = []
          for batch_v in range(num_batch_v):
            text_np_v, label_np_v = sess.run([text_bt_v, label_bt_v])
            feed_dict = {tch_v.text_ph:text_np_v}
            logit_np_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
            hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
            hit_v.append(hit_bt)
          # Average the per-batch hit scores.
          hit_v = np.mean(hit_v)

          tot_time = time.time() - start
          print('#{0} hit={1:.4f} {2:.0f}s'.format(batch_t, hit_v, tot_time))

          # Ties with the current best still overwrite the checkpoint.
          if hit_v < best_hit_v:
            continue
          best_hit_v = hit_v
          ckpt_file = path.join(config.ckpt_dir, '%s.ckpt' % flags.tch_model)
          tch_t.saver.save(sess, ckpt_file)
    finally:
      writer.close()
  print('best hit={0:.4f}'.format(best_hit_v))
예제 #3
0
 def get_batch(self, flags, is_training):
     """Build the batch tensors for either the training or the validation stage.

     For a truthy ``is_training`` the data sources are requested with
     ``single=False`` and decoded with shuffling; otherwise they are
     requested with ``single=True`` and decoded without shuffling. The
     number of data sources found is printed together with the stage name.

     Args:
         flags: Settings object forwarded to the ``utils`` helpers; its
             ``batch_size`` attribute sets the batch size.
         is_training: Selects the training (truthy) or validation (falsy)
             configuration.

     Returns:
         Whatever ``utils.generate_batch`` returns for the decoded tensors.
     """
     # The two stage-dependent switches are always opposites of each other.
     single = not is_training
     shuffle = not single
     stage = 'valid' if single else 'train'

     sources = utils.get_data_sources(
         flags, is_training=is_training, single=single)
     print('#tfrecord=%d for %s' % (len(sources), stage))

     decoded = utils.decode_tfrecord(flags, sources, shuffle=shuffle)
     return utils.generate_batch(decoded, flags.batch_size)
예제 #4
0
def main(_):
    """Alternate discriminator and generator training from restored checkpoints.

    Prints the parameter count of every trainable variable, sets up two
    independently shuffled training batch pipelines (one for the
    discriminator loop, one for the generator loop) plus a validation
    pipeline, restores ``dis_t`` and ``gen_t`` from their checkpoints, and
    then for ``flags.num_epoch`` epochs runs ``flags.num_dis_epoch``
    discriminator passes followed by ``flags.num_gen_epoch`` generator passes
    over the training data. During generator training, every
    ``eval_interval`` steps ``gen_v`` is evaluated with ``utils.evaluate``
    and the best hit score is tracked and printed.

    ``flags``, ``config``, ``dis_t``, ``gen_t``, ``gen_v``,
    ``train_data_size``, ``eval_interval``, ``gan_dis_sample`` and
    ``generate_label`` are read from module level.

    Args:
        _: Ignored.
    """
    # Report the element count of each trainable variable.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('{}\t({} params)'.format(variable.name, num_params))

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    # Batches consumed by the discriminator loop.
    ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_d = utils.generate_batch(ts_list_d, config.train_batch_size)
    user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d

    # Batches consumed by the generator loop (same sources, separate pipeline).
    ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_g = utils.generate_batch(ts_list_g, config.train_batch_size)
    user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g

    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

    best_hit_v = -np.inf
    start = time.time()
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        # Restores run after the initializer, so checkpoint values win.
        dis_t.saver.restore(sess, flags.dis_model_ckpt)
        gen_t.saver.restore(sess, flags.gen_model_ckpt)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        # Close the writer on every exit path so the event file is flushed
        # to disk instead of being left open at interpreter exit.
        try:
            with slim.queues.QueueRunners(sess):
                image_hit_v = utils.evaluate(flags, sess, gen_v, bt_list_v)
                print('init\thit={0:.4f}'.format(image_hit_v))

                # Global step counters across all epochs; start at -1 because
                # they are incremented before use.
                batch_d, batch_g = -1, -1
                for epoch in range(flags.num_epoch):
                    for dis_epoch in range(flags.num_dis_epoch):
                        print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
                        num_batch_d = math.ceil(train_data_size /
                                                config.train_batch_size)
                        for _ in range(num_batch_d):
                            batch_d += 1
                            image_np_d, label_dat_d = sess.run(
                                [image_bt_d, label_bt_d])
                            # print(image_np_d.shape, label_dat_d.shape)
                            # Generator output for the batch, combined with
                            # the data labels to build discriminator samples.
                            feed_dict = {gen_t.image_ph: image_np_d}
                            label_gen_d, = sess.run([gen_t.labels],
                                                    feed_dict=feed_dict)
                            # print(label_gen_d.shape, type(label_gen_d))
                            sample_np_d, label_np_d = gan_dis_sample(
                                label_dat_d, label_gen_d)
                            feed_dict = {
                                dis_t.image_ph: image_np_d,
                                dis_t.sample_ph: sample_np_d,
                                dis_t.label_ph: label_np_d,
                            }
                            sess.run(dis_t.train_op, feed_dict=feed_dict)
                            # _, summary = sess.run([dis_t.train_op, dis_t.summary_op], feed_dict=feed_dict)
                            # writer.add_summary(summary, batch_d)

                            # if (batch_d + 1) % eval_interval != 0:
                            #   continue
                            # image_hit_v = utils.evaluate(flags, sess, dis_v, bt_list_v)
                            # tot_time = time.time() - start
                            # print('#%d hit=%.4f (%.0fs)' % (batch_d, image_hit_v, tot_time))

                    for gen_epoch in range(flags.num_gen_epoch):
                        print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
                        num_batch_g = math.ceil(train_data_size /
                                                config.train_batch_size)
                        for _ in range(num_batch_g):
                            batch_g += 1
                            image_np_g, label_dat_g = sess.run(
                                [image_bt_g, label_bt_g])
                            # print(image_np_g.shape, label_dat_g.shape)
                            feed_dict = {gen_t.image_ph: image_np_g}
                            label_gen_g, = sess.run([gen_t.labels],
                                                    feed_dict=feed_dict)
                            sample_np_g = generate_label(label_dat_g, label_gen_g)
                            # for sample in sample_np_g:
                            #   print(sample)
                            # Discriminator rewards for the sampled labels are
                            # fed back into the generator update below.
                            feed_dict = {
                                dis_t.image_ph: image_np_g,
                                dis_t.sample_ph: sample_np_g,
                            }
                            reward_np_g, = sess.run([dis_t.rewards],
                                                    feed_dict=feed_dict)
                            # for sample, reward in zip(sample_np_g, reward_np_g):
                            #   batch = sample[0]
                            #   label = [i for i, l in enumerate(label_dat_g[batch]) if l != 0.0]
                            #   print(sample, reward, label)
                            # print('%.2f %.2f' % (reward_np_g.min(), reward_np_g.max()))
                            # input()
                            feed_dict = {
                                gen_t.image_ph: image_np_g,
                                gen_t.sample_ph: sample_np_g,
                                gen_t.reward_ph: reward_np_g,
                            }
                            sess.run([gen_t.gan_train_op], feed_dict=feed_dict)
                            # Skip validation except on every
                            # eval_interval-th generator step.
                            if (batch_g + 1) % eval_interval != 0:
                                continue
                            image_hit_v = utils.evaluate(flags, sess, gen_v,
                                                         bt_list_v)
                            tot_time = time.time() - start
                            print('#%d hit=%.4f (%.0fs)' %
                                  (batch_g, image_hit_v, tot_time))
                            if image_hit_v > best_hit_v:
                                best_hit_v = image_hit_v
                                print('best hit=%.4f (%.0fs)' %
                                      (image_hit_v, tot_time))
                        # break

                print('best\thit={0:.4f}'.format(best_hit_v))
        finally:
            writer.close()
예제 #5
0
def main(_):
    """Pre-train the generator, checkpoint on improvement, and dump a hit curve.

    Builds a training ``GEN`` and, after enabling variable reuse on the
    current variable scope, an evaluation ``GEN``; registers learning-rate
    and pre-training-loss summaries; then runs ``num_batch_t`` training steps.
    Validation runs at (approximately) each epoch boundary; every result is
    appended to ``figure_data`` and the model is saved to
    ``flags.gen_model_ckpt`` unless the score is strictly below the best so
    far. Finally the ``(epoch, hit, batch)`` records are written as
    tab-separated lines to ``flags.gen_figure_data``.

    ``flags``, ``config``, ``num_batch_t``, ``num_batch_v`` and
    ``train_data_size`` are read from module level.

    Args:
        _: Ignored.
    """
    gen_t = GEN(flags, is_training=True)
    # Enable reuse on the current scope before building the evaluation model.
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    gen_v = GEN(flags, is_training=False)

    tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate)
    tf.summary.scalar(gen_t.pre_loss.name, gen_t.pre_loss)
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    # Report the element count of each trainable variable.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
    user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
    user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

    figure_data = []
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        # Close the writer on every exit path so buffered summaries are
        # flushed to disk instead of being dropped at interpreter exit.
        try:
            with slim.queues.QueueRunners(sess):
                for batch_t in range(num_batch_t):
                    image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
                    feed_dict = {
                        gen_t.image_ph: image_np_t,
                        gen_t.hard_label_ph: label_np_t
                    }
                    _, summary = sess.run([gen_t.pre_update, summary_op],
                                          feed_dict=feed_dict)
                    writer.add_summary(summary, batch_t)

                    # Validate only when the examples consumed so far end
                    # exactly on an epoch boundary (remain == 0) or fall less
                    # than one batch short of the next one; in the latter
                    # case the epoch index is rounded up.
                    batch = batch_t + 1
                    remain = (batch * flags.batch_size) % train_data_size
                    epoch = (batch * flags.batch_size) // train_data_size
                    if remain == 0:
                        pass
                        # print('%d\t%d\t%d' % (epoch, batch, remain))
                    elif (train_data_size - remain) < flags.batch_size:
                        epoch = epoch + 1
                        # print('%d\t%d\t%d' % (epoch, batch, remain))
                    else:
                        continue
                    # if (batch_t + 1) % eval_interval != 0:
                    #     continue

                    hit_v = []
                    for batch_v in range(num_batch_v):
                        image_np_v, label_np_v = sess.run([image_bt_v, label_bt_v])
                        feed_dict = {gen_v.image_ph: image_np_v}
                        logit_np_v, = sess.run([gen_v.logits], feed_dict=feed_dict)
                        hit_bt = metric.compute_hit(logit_np_v, label_np_v,
                                                    flags.cutoff)
                        hit_v.append(hit_bt)
                    # Average the per-batch hit scores.
                    hit_v = np.mean(hit_v)

                    figure_data.append((epoch, hit_v, batch_t))

                    # Ties with the current best still overwrite the checkpoint.
                    if hit_v < best_hit_v:
                        continue
                    tot_time = time.time() - start
                    best_hit_v = hit_v
                    print('#%03d curbst=%.4f time=%.0fs' %
                          (epoch, hit_v, tot_time))
                    gen_t.saver.save(sess, flags.gen_model_ckpt)
        finally:
            writer.close()
    print('bsthit=%.4f' % (best_hit_v))

    utils.create_if_nonexist(os.path.dirname(flags.gen_figure_data))
    # The context manager closes the file even if a write raises.
    with open(flags.gen_figure_data, 'w') as fout:
        for epoch, hit_v, batch_t in figure_data:
            fout.write('%d\t%.4f\t%d\n' % (epoch, hit_v, batch_t))
예제 #6
0
def main(_):
    """Train the generator with teacher-produced soft labels and track best hit.

    Prints the parameter count of every trainable variable, registers
    learning-rate and ``kd_loss`` summaries, restores ``gen_t`` and ``tch_t``
    from their checkpoints, and runs ``num_batch_t`` training steps. On each
    step the teacher's ``labels`` are computed from the text batch and fed,
    together with the images and hard labels, into ``gen_t.kd_update``. Every
    ``eval_interval`` steps ``gen_v`` is evaluated with ``utils.evaluate`` and
    the best hit score is tracked and printed. No checkpoint is written here.

    ``flags``, ``config``, ``gen_t``, ``gen_v``, ``tch_t``, ``num_batch_t``
    and ``eval_interval`` are read from module level.

    Args:
        _: Ignored.
    """
    # Report the element count of each trainable variable.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate)
    tf.summary.scalar(gen_t.kd_loss.name, gen_t.kd_loss)
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_t = utils.generate_batch(ts_list_t, config.train_batch_size)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
    user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t

    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        # Close the writer on every exit path so buffered summaries are
        # flushed to disk instead of being dropped at interpreter exit.
        try:
            # Restores run after the initializer, so checkpoint values win.
            gen_t.saver.restore(sess, flags.gen_model_ckpt)
            tch_t.saver.restore(sess, flags.tch_model_ckpt)
            with slim.queues.QueueRunners(sess):
                hit_v = utils.evaluate(flags, sess, gen_v, bt_list_v)
                print('init hit=%.4f' % (hit_v))

                for batch_t in range(num_batch_t):
                    image_np_t, text_np_t, hard_labels = sess.run(
                        [image_bt_t, text_bt_t, label_bt_t])
                    # print('hard labels:\t{}'.format(hard_labels.shape))
                    # print(np.argsort(-hard_labels[0,:])[:10])

                    # Teacher output on the text batch becomes the soft labels.
                    feed_dict = {tch_t.text_ph: text_np_t}
                    soft_labels, = sess.run([tch_t.labels], feed_dict=feed_dict)
                    # print('soft labels:\t{}'.format(soft_labels.shape))
                    # print(np.argsort(-soft_labels[0,:])[:10])

                    feed_dict = {
                        gen_t.image_ph: image_np_t,
                        gen_t.hard_label_ph: hard_labels,
                        gen_t.soft_label_ph: soft_labels,
                    }
                    _, summary = sess.run([gen_t.kd_update, summary_op],
                                          feed_dict=feed_dict)
                    writer.add_summary(summary, batch_t)

                    # Skip validation except on every eval_interval-th step.
                    if (batch_t + 1) % eval_interval != 0:
                        continue
                    hit_v = utils.evaluate(flags, sess, gen_v, bt_list_v)
                    tot_time = time.time() - start
                    print('#%08d hit=%.4f %06ds' % (batch_t, hit_v, int(tot_time)))
                    if hit_v <= best_hit_v:
                        continue
                    best_hit_v = hit_v
                    print('best hit=%.4f' % (best_hit_v))
        finally:
            writer.close()
    print('best hit=%.4f' % (best_hit_v))
예제 #7
0
def main(_):
    """Alternate discriminator/generator training and record per-epoch hit.

    Prints the parameter count of every trainable variable, builds separate
    merged summary ops for the discriminator and the generator, sets up two
    independently shuffled training batch pipelines plus a validation
    pipeline, and restores ``dis_t`` and ``gen_t`` from their checkpoints.
    Each of the ``flags.num_epoch`` epochs runs ``flags.num_dis_epoch``
    discriminator passes and ``flags.num_gen_epoch`` generator passes over
    the training data, then evaluates ``gen_v`` with
    ``utils.evaluate_image``. The ``(epoch, hit)`` records are finally
    written as tab-separated lines to ``flags.gan_figure_data``.

    ``flags``, ``config``, ``dis_t``, ``gen_t``, ``gen_v`` and
    ``train_data_size`` are read from module level.

    Args:
        _: Ignored.
    """
    # Report the element count of each trainable variable.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    dis_summary_op = tf.summary.merge([
        tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
        tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
    ])
    gen_summary_op = tf.summary.merge([
        tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
        tf.summary.scalar(gen_t.gan_loss.name, gen_t.gan_loss),
    ])
    print(type(dis_summary_op), type(gen_summary_op))
    init_op = tf.global_variables_initializer()

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    # Batches consumed by the discriminator loop.
    ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
    user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d

    # Batches consumed by the generator loop (same sources, separate pipeline).
    ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
    user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g

    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

    figure_data = []
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        # Restores run after the initializer, so checkpoint values win.
        dis_t.saver.restore(sess, flags.dis_model_ckpt)
        gen_t.saver.restore(sess, flags.gen_model_ckpt)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        # Close the writer on every exit path so buffered summaries are
        # flushed to disk instead of being dropped at interpreter exit.
        try:
            with slim.queues.QueueRunners(sess):
                hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
                print('init hit=%.4f' % (hit_v))

                # Global step counters across all epochs; start at -1 because
                # they are incremented before use.
                batch_d, batch_g = -1, -1
                for epoch in range(flags.num_epoch):
                    for dis_epoch in range(flags.num_dis_epoch):
                        print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
                        num_batch_d = math.ceil(train_data_size / flags.batch_size)
                        for _ in range(num_batch_d):
                            batch_d += 1
                            image_np_d, label_dat_d = sess.run(
                                [image_bt_d, label_bt_d])
                            # Generator output for the batch, combined with
                            # the data labels to build discriminator samples.
                            feed_dict = {gen_t.image_ph: image_np_d}
                            label_gen_d, = sess.run([gen_t.labels],
                                                    feed_dict=feed_dict)
                            sample_np_d, label_np_d = utils.gan_dis_sample(
                                flags, label_dat_d, label_gen_d)
                            feed_dict = {
                                dis_t.image_ph: image_np_d,
                                dis_t.sample_ph: sample_np_d,
                                dis_t.dis_label_ph: label_np_d,
                            }
                            _, summary_d = sess.run(
                                [dis_t.gan_update, dis_summary_op],
                                feed_dict=feed_dict)
                            writer.add_summary(summary_d, batch_d)

                    for gen_epoch in range(flags.num_gen_epoch):
                        print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
                        num_batch_g = math.ceil(train_data_size / flags.batch_size)
                        for _ in range(num_batch_g):
                            batch_g += 1
                            image_np_g, label_dat_g = sess.run(
                                [image_bt_g, label_bt_g])
                            feed_dict = {gen_t.image_ph: image_np_g}
                            label_gen_g, = sess.run([gen_t.labels],
                                                    feed_dict=feed_dict)
                            sample_np_g = utils.generate_label(
                                flags, label_dat_g, label_gen_g)
                            # Discriminator rewards for the sampled labels are
                            # fed back into the generator update below.
                            feed_dict = {
                                dis_t.image_ph: image_np_g,
                                dis_t.sample_ph: sample_np_g,
                            }
                            reward_np_g, = sess.run([dis_t.rewards],
                                                    feed_dict=feed_dict)
                            feed_dict = {
                                gen_t.image_ph: image_np_g,
                                gen_t.sample_ph: sample_np_g,
                                gen_t.reward_ph: reward_np_g,
                            }
                            _, summary_g = sess.run(
                                [gen_t.gan_update, gen_summary_op],
                                feed_dict=feed_dict)
                            writer.add_summary(summary_g, batch_g)

                            # if (batch_g + 1) % eval_interval != 0:
                            #   continue
                            # hit_v = utils.evaluate(flags, sess, gen_v, bt_list_v)
                            # tot_time = time.time() - start
                            # print('#%08d hit=%.4f %06ds' % (batch_g, hit_v, int(tot_time)))
                            # if hit_v <= best_hit_v:
                            #   continue
                            # best_hit_v = hit_v
                            # print('best hit=%.4f' % (best_hit_v))
                    # One validation pass at the end of every outer epoch.
                    hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
                    tot_time = time.time() - start
                    print('#%03d curbst=%.4f %.0fs' % (epoch, hit_v, tot_time))
                    figure_data.append((epoch, hit_v))
                    if hit_v <= best_hit_v:
                        continue
                    best_hit_v = hit_v
        finally:
            writer.close()
    print('bsthit=%.4f' % (best_hit_v))

    utils.create_if_nonexist(os.path.dirname(flags.gan_figure_data))
    # The context manager closes the file even if a write raises.
    with open(flags.gan_figure_data, 'w') as fout:
        for epoch, hit_v in figure_data:
            fout.write('%d\t%.4f\n' % (epoch, hit_v))
예제 #8
0
def main(_):
    """Pre-train the discriminator and checkpoint it when validation hit does not drop.

    Builds a training ``DIS`` and, after enabling variable reuse on the
    current variable scope, an evaluation ``DIS``; registers learning-rate
    and pre-training-loss summaries; then runs ``num_batch_t`` training steps.
    Every ``eval_interval`` steps the evaluation model is scored with
    ``metric.compute_hit`` over ``num_batch_v`` validation batches, and the
    model is saved to ``flags.dis_model_ckpt`` unless the score is strictly
    below the best score seen so far.

    ``flags``, ``config``, ``num_batch_t``, ``num_batch_v`` and
    ``eval_interval`` are read from module level.

    Args:
        _: Ignored.
    """
    dis_t = DIS(flags, is_training=True)
    # Enable reuse on the current scope before building the evaluation model.
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    dis_v = DIS(flags, is_training=False)

    tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate)
    tf.summary.scalar(dis_t.pre_loss.name, dis_t.pre_loss)
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    # Report the element count of each trainable variable.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
    user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
    user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

    start = time.time()
    best_hit_v = -np.inf
    with tf.Session() as sess:
        sess.run(init_op)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        # Close the writer on every exit path so buffered summaries are
        # flushed to disk instead of being dropped at interpreter exit.
        try:
            with slim.queues.QueueRunners(sess):
                for batch_t in range(num_batch_t):
                    image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
                    feed_dict = {
                        dis_t.image_ph: image_np_t,
                        dis_t.hard_label_ph: label_np_t
                    }
                    _, summary = sess.run([dis_t.pre_update, summary_op],
                                          feed_dict=feed_dict)
                    writer.add_summary(summary, batch_t)

                    # Skip validation except on every eval_interval-th step.
                    if (batch_t + 1) % eval_interval != 0:
                        continue

                    hit_v = []
                    for batch_v in range(num_batch_v):
                        image_np_v, label_np_v = sess.run([image_bt_v, label_bt_v])
                        feed_dict = {dis_v.image_ph: image_np_v}
                        logit_np_v, = sess.run([dis_v.logits], feed_dict=feed_dict)
                        hit_bt = metric.compute_hit(logit_np_v, label_np_v,
                                                    flags.cutoff)
                        hit_v.append(hit_bt)
                    # Average the per-batch hit scores.
                    hit_v = np.mean(hit_v)

                    tot_time = time.time() - start
                    print('#%08d hit=%.4f %06ds' % (batch_t, hit_v, int(tot_time)))

                    # Ties with the current best still overwrite the checkpoint.
                    if hit_v < best_hit_v:
                        continue
                    best_hit_v = hit_v
                    dis_t.saver.save(sess, flags.dis_model_ckpt)
        finally:
            writer.close()
    print('best hit=%.4f' % (best_hit_v))
예제 #9
0
def main(_):
    """Run the KDGAN training loop for discriminator, teacher and generator.

    The function relies on module-level objects defined elsewhere in this
    file: the training models ``dis_t``, ``gen_t`` and ``tch_t``, the
    validation models ``gen_v`` and ``tch_v``, plus ``flags``, ``config``
    and ``num_batch_per_epoch``.

    Steps:
      1. Print the parameter count of every trainable variable.
      2. Build one summary op per model and three separately decoded
         training input pipelines (one per model) plus a validation one.
      3. Restore the three checkpoints named by ``flags.dis_model_ckpt``,
         ``flags.gen_model_ckpt`` and ``flags.tch_model_ckpt``.
      4. For each epoch run the discriminator, teacher and generator update
         phases, evaluate on the validation batches and save the generator
         whenever its validation hit strictly improves.
      5. Write ``(epoch, gen_hit, tch_hit)`` rows to
         ``flags.kdgan_figure_data`` as tab-separated text.

    Args:
      _: Unused positional argument (presumably the argv list handed over by
        ``tf.app.run`` -- confirm against the caller).
    """
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    dis_summary_op = tf.summary.merge([
        tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
        tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
    ])
    gen_summary_op = tf.summary.merge([
        tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
        tf.summary.scalar(gen_t.kdgan_loss.name, gen_t.kdgan_loss),
    ])
    tch_summary_op = tf.summary.merge([
        tf.summary.scalar(tch_t.learning_rate.name, tch_t.learning_rate),
        tf.summary.scalar(tch_t.kdgan_loss.name, tch_t.kdgan_loss),
    ])
    init_op = tf.global_variables_initializer()

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    # Each model gets its own decode/batch pipeline over the same training
    # sources. Every batch list must hold exactly five tensors; the first
    # (user) and last (file) ones are not needed by this loop.
    ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
    _, image_bt_d, text_bt_d, label_bt_d, _ = bt_list_d

    ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
    _, image_bt_g, text_bt_g, label_bt_g, _ = bt_list_g

    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
    _, image_bt_t, text_bt_t, label_bt_t, _ = bt_list_t

    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

    figure_data = []
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        # Restoring after init_op overwrites the fresh initial values with
        # the checkpointed ones.
        dis_t.saver.restore(sess, flags.dis_model_ckpt)
        gen_t.saver.restore(sess, flags.gen_model_ckpt)
        tch_t.saver.restore(sess, flags.tch_model_ckpt)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        try:
            with slim.queues.QueueRunners(sess):
                # Validation hit of the restored models before any update.
                gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
                tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)
                print('hit gen=%.4f tch=%.4f' % (gen_hit, tch_hit))

                # Per-model global step counters used as summary steps; they
                # start at -1 because they are incremented before use.
                batch_d, batch_g, batch_t = -1, -1, -1
                for epoch in range(flags.num_epoch):
                    # ---- discriminator phase ----
                    for dis_epoch in range(flags.num_dis_epoch):
                        print('epoch %03d dis_epoch %03d' %
                              (epoch, dis_epoch))
                        for _ in range(num_batch_per_epoch):
                            batch_d += 1
                            image_d, text_d, label_dat_d = sess.run(
                                [image_bt_d, text_bt_d, label_bt_d])

                            feed_dict = {gen_t.image_ph: image_d}
                            label_gen_d, = sess.run([gen_t.labels],
                                                    feed_dict=feed_dict)
                            feed_dict = {
                                tch_t.text_ph: text_d,
                                tch_t.image_ph: image_d
                            }
                            label_tch_d, = sess.run([tch_t.labels],
                                                    feed_dict=feed_dict)

                            # Build discriminator samples/targets from the
                            # data labels and both models' outputs.
                            sample_d, label_d = utils.kdgan_dis_sample(
                                flags, label_dat_d, label_gen_d, label_tch_d)

                            feed_dict = {
                                dis_t.image_ph: image_d,
                                dis_t.sample_ph: sample_d,
                                dis_t.dis_label_ph: label_d,
                            }
                            _, summary_d = sess.run(
                                [dis_t.gan_update, dis_summary_op],
                                feed_dict=feed_dict)
                            writer.add_summary(summary_d, batch_d)

                    # ---- teacher phase ----
                    for tch_epoch in range(flags.num_tch_epoch):
                        print('epoch %03d tch_epoch %03d' %
                              (epoch, tch_epoch))
                        for _ in range(num_batch_per_epoch):
                            batch_t += 1
                            image_t, text_t, label_dat_t = sess.run(
                                [image_bt_t, text_bt_t, label_bt_t])

                            feed_dict = {
                                tch_t.text_ph: text_t,
                                tch_t.image_ph: image_t
                            }
                            label_tch_t, = sess.run([tch_t.labels],
                                                    feed_dict=feed_dict)
                            sample_t = utils.generate_label(
                                flags, label_dat_t, label_tch_t)
                            feed_dict = {
                                dis_t.image_ph: image_t,
                                dis_t.sample_ph: sample_t,
                            }
                            reward_t, = sess.run([dis_t.rewards],
                                                 feed_dict=feed_dict)

                            # NOTE(review): the teacher is distilled from the
                            # generator's *logits*, whereas the generator
                            # phase below uses the teacher's *labels* as its
                            # soft target -- confirm this asymmetry is
                            # intended.
                            feed_dict = {
                                gen_t.image_ph: image_t,
                            }
                            logit_gen_t = sess.run(gen_t.logits,
                                                   feed_dict=feed_dict)
                            feed_dict = {
                                tch_t.text_ph: text_t,
                                tch_t.image_ph: image_t,
                                tch_t.sample_ph: sample_t,
                                tch_t.reward_ph: reward_t,
                                tch_t.hard_label_ph: label_dat_t,
                                tch_t.soft_label_ph: logit_gen_t,
                            }
                            _, summary_t = sess.run(
                                [tch_t.kdgan_update, tch_summary_op],
                                feed_dict=feed_dict)
                            writer.add_summary(summary_t, batch_t)

                    # ---- generator phase ----
                    for gen_epoch in range(flags.num_gen_epoch):
                        print('epoch %03d gen_epoch %03d' %
                              (epoch, gen_epoch))
                        for _ in range(num_batch_per_epoch):
                            batch_g += 1
                            image_g, text_g, label_dat_g = sess.run(
                                [image_bt_g, text_bt_g, label_bt_g])

                            feed_dict = {
                                tch_t.text_ph: text_g,
                                tch_t.image_ph: image_g
                            }
                            label_tch_g, = sess.run([tch_t.labels],
                                                    feed_dict=feed_dict)

                            feed_dict = {gen_t.image_ph: image_g}
                            label_gen_g, = sess.run([gen_t.labels],
                                                    feed_dict=feed_dict)
                            sample_g = utils.generate_label(
                                flags, label_dat_g, label_gen_g)
                            feed_dict = {
                                dis_t.image_ph: image_g,
                                dis_t.sample_ph: sample_g,
                            }
                            reward_g, = sess.run([dis_t.rewards],
                                                 feed_dict=feed_dict)

                            feed_dict = {
                                gen_t.image_ph: image_g,
                                gen_t.hard_label_ph: label_dat_g,
                                gen_t.soft_label_ph: label_tch_g,
                                gen_t.sample_ph: sample_g,
                                gen_t.reward_ph: reward_g,
                            }
                            _, summary_g = sess.run(
                                [gen_t.kdgan_update, gen_summary_op],
                                feed_dict=feed_dict)
                            writer.add_summary(summary_g, batch_g)

                    # ---- end-of-epoch validation ----
                    gen_hit = utils.evaluate_image(flags, sess, gen_v,
                                                   bt_list_v)
                    tch_hit = utils.evaluate_text(flags, sess, tch_v,
                                                  bt_list_v)

                    tot_time = time.time() - start
                    print('#%03d curgen=%.4f curtch=%.4f %.0fs' %
                          (epoch, gen_hit, tch_hit, tot_time))
                    figure_data.append((epoch, gen_hit, tch_hit))
                    # Checkpoint only on a strict improvement of the
                    # generator's validation hit.
                    if gen_hit <= best_hit_v:
                        continue
                    best_hit_v = gen_hit
                    print("epoch ", epoch + 1, ":, new best validation hit:",
                          best_hit_v, "saving...")
                    gen_t.saver.save(sess,
                                     flags.kdgan_model_ckpt,
                                     global_step=epoch + 1)
                    print("finish saving")
        finally:
            # Flush pending summary events and release the event file even
            # when training is interrupted by an exception.
            writer.close()

    print('best hit=%.4f' % (best_hit_v))

    utils.create_if_nonexist(os.path.dirname(flags.kdgan_figure_data))
    # The context manager closes the file even if a write fails.
    with open(flags.kdgan_figure_data, 'w') as fout:
        for epoch, gen_hit, tch_hit in figure_data:
            fout.write('%d\t%.4f\t%.4f\n' % (epoch, gen_hit, tch_hit))