Esempio n. 1
0
def main(_):
  """Train the TCH model and checkpoint it whenever validation hit improves.

  Builds a training and a validation TCH instance (the latter is constructed
  after the current variable scope has been switched to reuse mode), streams
  batches decoded from TFRecord sources, runs `num_batch_t` training steps and
  evaluates every `eval_interval` steps. A checkpoint is written each time the
  mean validation hit is not worse than the best value seen so far.

  Relies on module-level names defined elsewhere in the file: `flags`,
  `config`, `utils`, `metric`, `num_batch_t`, `num_batch_v` and
  `eval_interval`.

  Args:
    _: Unused positional argument (presumably argv supplied by `tf.app.run`).
  """
  # Only the side effect (registering a global step in the graph) is needed.
  tf.train.create_global_step()
  tch_t = TCH(flags, is_training=True)
  # Everything built after this call is created with variable reuse enabled.
  tf.get_variable_scope().reuse_variables()
  tch_v = TCH(flags, is_training=False)

  # Report every trainable variable together with its element count.
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('{}\t({} params)'.format(variable.name, num_params))

  data_sources_t = utils.get_data_sources(flags, is_training=True, single=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))

  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, config.train_batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  # Only the text and label batches are consumed by this script.
  _user_bt_t, _image_bt_t, text_bt_t, label_bt_t, _file_bt_t = bt_list_t
  _user_bt_v, _image_bt_v, text_bt_v, label_bt_v, _file_bt_v = bt_list_v

  best_hit_v = -np.inf
  init_op = tf.global_variables_initializer()
  start = time.time()
  with tf.Session() as sess:
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    try:
      sess.run(init_op)
      with slim.queues.QueueRunners(sess):
        for batch_t in range(num_batch_t):
          # Fetch text and labels in one run so they come from the same batch.
          text_np_t, label_np_t = sess.run([text_bt_t, label_bt_t])
          feed_dict = {tch_t.text_ph: text_np_t, tch_t.label_ph: label_np_t}
          _, summary = sess.run([tch_t.train_op, tch_t.summary_op], feed_dict=feed_dict)
          writer.add_summary(summary, batch_t)

          # Evaluate only once every `eval_interval` training steps.
          if (batch_t + 1) % eval_interval != 0:
            continue

          hit_v = []
          for _ in range(num_batch_v):
            text_np_v, label_np_v = sess.run([text_bt_v, label_bt_v])
            feed_dict = {tch_v.text_ph: text_np_v}
            logit_np_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
            hit_v.append(metric.compute_hit(logit_np_v, label_np_v, flags.cutoff))
          hit_v = np.mean(hit_v)

          tot_time = time.time() - start
          print('#{0} hit={1:.4f} {2:.0f}s'.format(batch_t, hit_v, tot_time))

          # Keep the checkpoint only when validation hit did not get worse.
          if hit_v < best_hit_v:
            continue
          best_hit_v = hit_v
          ckpt_file = path.join(config.ckpt_dir, '%s.ckpt' % flags.tch_model)
          tch_t.saver.save(sess, ckpt_file)
    finally:
      # Flush pending summary events and release the event file even on error.
      writer.close()
  print('best hit={0:.4f}'.format(best_hit_v))
Esempio n. 2
0
import math
import os
import pickle
import time
import numpy as np
import tensorflow as tf
from os import path
from tensorflow.contrib import slim

# Size reported by utils.get_tn_size for the configured dataset
# (presumably the number of training examples -- confirm in utils).
tn_size = utils.get_tn_size(flags.dataset)
# Number of batches in tn_size examples; int() truncates any partial batch.
eval_interval = int(tn_size / flags.batch_size)
print('#tn_size=%d' % (tn_size))

# Training-mode instances are constructed first.
tn_dis = DIS(flags, is_training=True)
tn_gen = GEN(flags, is_training=True)
tn_tch = TCH(flags, is_training=True)
# Switch the current variable scope to reuse mode before building the
# validation-mode instances (presumably so they share the variables created
# above -- depends on how DIS/GEN/TCH create their variables).
scope = tf.get_variable_scope()
scope.reuse_variables()
vd_dis = DIS(flags, is_training=False)
vd_gen = GEN(flags, is_training=False)
vd_tch = TCH(flags, is_training=False)

# Print each trainable variable with the product of its static dimensions.
for variable in tf.trainable_variables():
  num_params = 1
  for dim in variable.shape:
    num_params *= dim.value
  print('%-50s (%d params)' % (variable.name, num_params))

dis_summary_op = tf.summary.merge([
  tf.summary.scalar(tn_dis.learning_rate.name, tn_dis.learning_rate),
  tf.summary.scalar(tn_dis.gan_loss.name, tn_dis.gan_loss),
Esempio n. 3
0
# Load the dataset from flags.dataset_dir with the requested train/valid
# sizes; one_hot and reshape are passed through to data_utils.read_data_sets.
mnist = data_utils.read_data_sets(flags.dataset_dir,
                                  one_hot=True,
                                  train_size=flags.train_size,
                                  valid_size=flags.valid_size,
                                  reshape=True)
datagen = AffineGenerator(mnist)

# Note: the "vd" size is taken from mnist.test, not from a validation split.
tn_size, vd_size = mnist.train.num_examples, mnist.test.num_examples
print('tn size=%d vd size=%d' % (tn_size, vd_size))
# Total number of training batches over flags.num_epoch epochs (truncated).
tn_num_batch = int(flags.num_epoch * tn_size / flags.batch_size)
print('tn #batch=%d' % (tn_num_batch))
# Number of batches in one pass over the training split (truncated).
eval_interval = int(tn_size / flags.batch_size)
print('ev #interval=%d' % (eval_interval))

# Training-mode instances are built on mnist.train first.
tn_gen = GEN(flags, mnist.train, is_training=True)
tn_tch = TCH(flags, mnist.train, is_training=True)
# Switch the current variable scope to reuse mode before building the
# evaluation-mode instances on mnist.test (presumably so they share the
# variables created above -- depends on how GEN/TCH create their variables).
scope = tf.get_variable_scope()
scope.reuse_variables()
vd_gen = GEN(flags, mnist.test, is_training=False)
vd_tch = TCH(flags, mnist.test, is_training=False)

# Register scalar summaries tagged with the tensors' own names, then merge
# every summary collected so far into a single op.
tf.summary.scalar(tn_gen.learning_rate.name, tn_gen.learning_rate)
tf.summary.scalar(tn_gen.kd_loss.name, tn_gen.kd_loss)
summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()

# Per-variable element count (product of static dimensions). The step that
# accumulates into tot_params is not visible in this excerpt.
tot_params = 0
for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
        num_params *= dim.value
Esempio n. 4
0
# Command-line flags (earlier definitions, if any, are outside this excerpt).
tf.app.flags.DEFINE_float('tch_weight_decay', 0.00001, 'l2 coefficient')
tf.app.flags.DEFINE_integer('embedding_size', 10, '')
tf.app.flags.DEFINE_string('tch_model_ckpt', None, '')
tf.app.flags.DEFINE_integer('num_tch_epoch', 5, '')
# KDGAN flags (presumably per-example negative/positive sample counts).
tf.app.flags.DEFINE_integer('num_negative', 1, '')
tf.app.flags.DEFINE_integer('num_positive', 1, '')
flags = tf.app.flags.FLAGS

train_data_size = utils.get_train_data_size(flags.dataset)
# Total training batches over flags.num_epoch epochs; int() truncates.
num_batch_t = int(flags.num_epoch * train_data_size / config.train_batch_size)
# Number of batches in one pass over train_data_size examples; int() truncates.
eval_interval = int(train_data_size / config.train_batch_size)
print('tn:\t#batch=%d\neval:\t#interval=%d' % (num_batch_t, eval_interval))

# Training-mode instances are constructed first.
gen_t = GEN(flags, is_training=True)
tch_t = TCH(flags, is_training=True)
# Switch the current variable scope to reuse mode before building the
# validation-mode instances (presumably so they share the variables created
# above -- depends on how GEN/TCH create their variables).
scope = tf.get_variable_scope()
scope.reuse_variables()
gen_v = GEN(flags, is_training=False)
tch_v = TCH(flags, is_training=False)


def main(_):
    """Print per-variable parameter counts and register generator summaries.

    Uses the module-level `gen_t` instance; the positional argument is unused.
    """
    for var in tf.trainable_variables():
        # Element count is the product of the variable's static dimensions.
        count = 1
        for axis in var.shape:
            count = count * axis.value
        print('%-50s (%d params)' % (var.name, count))

    # Each scalar summary is tagged with the name of the tensor it tracks.
    lr_tensor = gen_t.learning_rate
    tf.summary.scalar(lr_tensor.name, lr_tensor)
    kd_tensor = gen_t.kd_loss
    tf.summary.scalar(kd_tensor.name, kd_tensor)
Esempio n. 5
0
def main(_):
    """Pre-train the TCH model on text and image inputs and log validation hit.

    Builds a training and a validation TCH instance (the latter is constructed
    after the current variable scope has been switched to reuse mode), runs
    `num_batch_t` training steps, and evaluates around each epoch boundary.
    The model is saved to `flags.tch_model_ckpt` whenever the mean validation
    hit is not worse than the best seen so far. One `(epoch, hit, batch)`
    record per evaluation is written to `flags.gen_figure_data` at the end.

    Relies on module-level names defined elsewhere in the file: `flags`,
    `config`, `utils`, `metric`, `num_batch_t`, `num_batch_v` and
    `train_data_size`.

    Args:
        _: Unused positional argument (presumably argv from `tf.app.run`).
    """
    tch_t = TCH(flags, is_training=True)
    # Everything built after this call is created with variable reuse enabled.
    tf.get_variable_scope().reuse_variables()
    tch_v = TCH(flags, is_training=False)

    tf.summary.scalar(tch_t.learning_rate.name, tch_t.learning_rate)
    tf.summary.scalar(tch_t.pre_loss.name, tch_t.pre_loss)
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    # Report every trainable variable together with its element count.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
    # The user and file batches are not consumed by this script.
    _user_bt_t, image_bt_t, text_bt_t, label_bt_t, _file_bt_t = bt_list_t
    _user_bt_v, image_bt_v, text_bt_v, label_bt_v, _file_bt_v = bt_list_v

    figure_data = []
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        try:
            with slim.queues.QueueRunners(sess):
                for batch_t in range(num_batch_t):
                    # One joint fetch keeps image, text and label aligned.
                    image_np_t, text_np_t, label_np_t = sess.run(
                        [image_bt_t, text_bt_t, label_bt_t])
                    feed_dict = {
                        tch_t.text_ph: text_np_t,
                        tch_t.image_ph: image_np_t,
                        tch_t.hard_label_ph: label_np_t,
                    }
                    _, summary = sess.run([tch_t.pre_update, summary_op],
                                          feed_dict=feed_dict)
                    writer.add_summary(summary, batch_t)

                    # Evaluate when the examples seen so far land exactly on
                    # an epoch boundary, or when less than one batch is left
                    # before the next boundary (counted as that next epoch).
                    num_seen = (batch_t + 1) * flags.batch_size
                    remain = num_seen % train_data_size
                    epoch = num_seen // train_data_size
                    if remain != 0:
                        if (train_data_size - remain) >= flags.batch_size:
                            continue
                        epoch = epoch + 1

                    hit_v = []
                    for _ in range(num_batch_v):
                        text_np_v, image_np_v, label_np_v = sess.run(
                            [text_bt_v, image_bt_v, label_bt_v])
                        feed_dict = {
                            tch_v.text_ph: text_np_v,
                            tch_v.image_ph: image_np_v,
                        }
                        logit_np_v, = sess.run([tch_v.logits],
                                               feed_dict=feed_dict)
                        hit_v.append(metric.compute_hit(
                            logit_np_v, label_np_v, flags.cutoff))
                    hit_v = np.mean(hit_v)

                    figure_data.append((epoch, hit_v, batch_t))

                    # Save only when validation hit did not get worse.
                    if hit_v < best_hit_v:
                        continue
                    tot_time = time.time() - start
                    best_hit_v = hit_v
                    print('#%03d curbst=%.4f time=%.0fs' %
                          (epoch, hit_v, tot_time))
                    tch_t.saver.save(sess, flags.tch_model_ckpt)
        finally:
            # Flush pending summary events and release the event file.
            writer.close()
    print('bsthit=%.4f' % (best_hit_v))

    utils.create_if_nonexist(os.path.dirname(flags.gen_figure_data))
    # The context manager guarantees the file is closed even if a write fails.
    with open(flags.gen_figure_data, 'w') as fout:
        for epoch, hit_v, batch_t in figure_data:
            fout.write('%d\t%.4f\t%d\n' % (epoch, hit_v, batch_t))
Esempio n. 6
0
def main(_):
    """Pre-train the TCH model on text and image inputs, keeping the best one.

    Builds a training and a validation TCH instance (the latter is constructed
    after the current variable scope has been switched to reuse mode), runs
    `num_batch_t` training steps and evaluates every `eval_interval` steps.
    The model is saved to `flags.tch_model_ckpt` whenever the mean validation
    hit is not worse than the best value seen so far.

    Relies on module-level names defined elsewhere in the file: `flags`,
    `config`, `utils`, `metric`, `num_batch_t`, `num_batch_v` and
    `eval_interval`.

    Args:
        _: Unused positional argument (presumably argv from `tf.app.run`).
    """
    tch_t = TCH(flags, is_training=True)
    # Everything built after this call is created with variable reuse enabled.
    tf.get_variable_scope().reuse_variables()
    tch_v = TCH(flags, is_training=False)

    tf.summary.scalar(tch_t.learning_rate.name, tch_t.learning_rate)
    tf.summary.scalar(tch_t.pre_loss.name, tch_t.pre_loss)
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    # Report every trainable variable together with its element count.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
    # The user and file batches are not consumed by this script.
    _user_bt_t, image_bt_t, text_bt_t, label_bt_t, _file_bt_t = bt_list_t
    _user_bt_v, image_bt_v, text_bt_v, label_bt_v, _file_bt_v = bt_list_v
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        try:
            with slim.queues.QueueRunners(sess):
                for batch_t in range(num_batch_t):
                    # Fetch all three tensors in ONE run. The batch tensors
                    # are fed by queue runners, so separate sess.run calls
                    # would each pull a different batch and pair text with
                    # images/labels of other examples.
                    text_np_t, image_np_t, label_np_t = sess.run(
                        [text_bt_t, image_bt_t, label_bt_t])
                    feed_dict = {
                        tch_t.image_ph: image_np_t,
                        tch_t.text_ph: text_np_t,
                        tch_t.hard_label_ph: label_np_t,
                    }
                    _, summary = sess.run([tch_t.pre_update, summary_op],
                                          feed_dict=feed_dict)
                    writer.add_summary(summary, batch_t)

                    # Evaluate only once every `eval_interval` steps.
                    if (batch_t + 1) % eval_interval != 0:
                        continue

                    hit_v = []
                    for _ in range(num_batch_v):
                        # Same joint fetch for validation, so logits and
                        # labels refer to the same examples.
                        text_np_v, image_np_v, label_np_v = sess.run(
                            [text_bt_v, image_bt_v, label_bt_v])
                        feed_dict = {
                            tch_v.image_ph: image_np_v,
                            tch_v.text_ph: text_np_v,
                        }
                        logit_np_v, = sess.run([tch_v.logits],
                                               feed_dict=feed_dict)
                        hit_v.append(metric.compute_hit(
                            logit_np_v, label_np_v, flags.cutoff))
                    hit_v = np.mean(hit_v)

                    tot_time = time.time() - start
                    print('#%08d hit=%.4f %06ds' %
                          (batch_t, hit_v, int(tot_time)))

                    # Save only when validation hit did not get worse.
                    if hit_v < best_hit_v:
                        continue
                    best_hit_v = hit_v
                    tch_t.saver.save(sess, flags.tch_model_ckpt)
        finally:
            # Flush pending summary events and release the event file.
            writer.close()
    print('best hit=%.4f' % (best_hit_v))