# Fragment 1: training loop for the text-only teacher (TCH). Relies on the
# shared imports shown in the next fragment, plus project modules (config,
# metric, utils) and module-level constants num_batch_t / num_batch_v.
def main(_):
  global_step = tf.train.create_global_step()

  # Build a training tower, then a weight-sharing validation tower.
  tch_t = TCH(flags, is_training=True)
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  tch_v = TCH(flags, is_training=False)

  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('{}\t({} params)'.format(variable.name, num_params))

  data_sources_t = utils.get_data_sources(flags, is_training=True, single=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d'
      % (len(data_sources_t), len(data_sources_v)))

  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, config.train_batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

  best_hit_v = -np.inf
  init_op = tf.global_variables_initializer()
  start = time.time()
  with tf.Session() as sess:
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    sess.run(init_op)
    with slim.queues.QueueRunners(sess):
      for batch_t in range(num_batch_t):
        text_np_t, label_np_t = sess.run([text_bt_t, label_bt_t])
        feed_dict = {tch_t.text_ph: text_np_t, tch_t.label_ph: label_np_t}
        _, summary = sess.run([tch_t.train_op, tch_t.summary_op],
            feed_dict=feed_dict)
        writer.add_summary(summary, batch_t)

        # Evaluate once every eval_interval steps.
        if (batch_t + 1) % eval_interval != 0:
          continue
        hit_v = []
        for batch_v in range(num_batch_v):
          text_np_v, label_np_v = sess.run([text_bt_v, label_bt_v])
          feed_dict = {tch_v.text_ph: text_np_v}
          logit_np_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
          hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
          hit_v.append(hit_bt)
        hit_v = np.mean(hit_v)
        tot_time = time.time() - start
        print('#{0} hit={1:.4f} {2:.0f}s'.format(batch_t, hit_v, tot_time))

        # Keep only the checkpoint with the best validation hit rate.
        if hit_v < best_hit_v:
          continue
        best_hit_v = hit_v
        ckpt_file = path.join(config.ckpt_dir, '%s.ckpt' % flags.tch_model)
        tch_t.saver.save(sess, ckpt_file)
  print('best hit={0:.4f}'.format(best_hit_v))
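# Every evaluation loop in these fragments scores a batch with
# metric.compute_hit(logits, labels, cutoff). The project's own implementation
# is not shown in this excerpt; the sketch below is an assumption of a
# standard hit@k: an example counts as a hit if any of its top-`cutoff`
# scored labels is a true (multi-hot) label.
import numpy as np

def compute_hit(logits, labels, cutoff):
  predictions = np.argsort(-logits, axis=1)[:, :cutoff]  # top-k label indices
  hits = []
  for topk, label in zip(predictions, labels):
    hits.append(float(np.any(label[topk] > 0)))  # any predicted label is true
  return np.mean(hits)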
# Fragment 2: shared imports and graph setup for KDGAN training
# (discriminator DIS, generator GEN, teacher TCH).
import math
import os
import pickle
import time

import numpy as np
import tensorflow as tf

from os import path
from tensorflow.contrib import slim

tn_size = utils.get_tn_size(flags.dataset)
eval_interval = int(tn_size / flags.batch_size)
print('#tn_size=%d' % (tn_size))

# Training towers, then weight-sharing validation towers.
tn_dis = DIS(flags, is_training=True)
tn_gen = GEN(flags, is_training=True)
tn_tch = TCH(flags, is_training=True)
scope = tf.get_variable_scope()
scope.reuse_variables()
vd_dis = DIS(flags, is_training=False)
vd_gen = GEN(flags, is_training=False)
vd_tch = TCH(flags, is_training=False)

for variable in tf.trainable_variables():
  num_params = 1
  for dim in variable.shape:
    num_params *= dim.value
  print('%-50s (%d params)' % (variable.name, num_params))

dis_summary_op = tf.summary.merge([
    tf.summary.scalar(tn_dis.learning_rate.name, tn_dis.learning_rate),
    tf.summary.scalar(tn_dis.gan_loss.name, tn_dis.gan_loss),
    # further discriminator summaries omitted in this excerpt
])
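# The paired towers above rely on TF1 variable sharing: after
# tf.get_variable_scope().reuse_variables(), constructing the same model a
# second time reuses the training weights instead of allocating new ones.
# A self-contained toy illustration of the pattern (names are assumptions):
import tensorflow as tf

def build_tower(x):
  w = tf.get_variable('w', shape=[4, 2])  # created once, reused afterwards
  return tf.matmul(x, w)

x_tn = tf.placeholder(tf.float32, [None, 4])
x_vd = tf.placeholder(tf.float32, [None, 4])
logits_tn = build_tower(x_tn)               # creates 'w'
tf.get_variable_scope().reuse_variables()
logits_vd = build_tower(x_vd)               # reuses the same 'w'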
# Fragment 3: MNIST distillation setup (generator GEN distilled from teacher
# TCH); data_utils and AffineGenerator are project modules.
mnist = data_utils.read_data_sets(flags.dataset_dir,
    one_hot=True,
    train_size=flags.train_size,
    valid_size=flags.valid_size,
    reshape=True)
datagen = AffineGenerator(mnist)
tn_size, vd_size = mnist.train.num_examples, mnist.test.num_examples
print('tn size=%d vd size=%d' % (tn_size, vd_size))
tn_num_batch = int(flags.num_epoch * tn_size / flags.batch_size)
print('tn #batch=%d' % (tn_num_batch))
eval_interval = int(tn_size / flags.batch_size)
print('ev #interval=%d' % (eval_interval))

tn_gen = GEN(flags, mnist.train, is_training=True)
tn_tch = TCH(flags, mnist.train, is_training=True)
scope = tf.get_variable_scope()
scope.reuse_variables()
vd_gen = GEN(flags, mnist.test, is_training=False)
vd_tch = TCH(flags, mnist.test, is_training=False)

tf.summary.scalar(tn_gen.learning_rate.name, tn_gen.learning_rate)
tf.summary.scalar(tn_gen.kd_loss.name, tn_gen.kd_loss)
summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()

tot_params = 0
for variable in tf.trainable_variables():
  num_params = 1
  for dim in variable.shape:
    num_params *= dim.value
  tot_params += num_params  # accumulate; the excerpt cuts off here
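# Worked example of the schedule arithmetic above (sizes assumed):
# tn_size=50000, flags.batch_size=100, flags.num_epoch=200
#   tn_num_batch  = int(200 * 50000 / 100) = 100000 training steps in total
#   eval_interval = int(50000 / 100)       = 500 steps, i.e. once per epoch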
# Fragment 4: flag definitions and module-level setup for joint
# generator/teacher training.
tf.app.flags.DEFINE_float('tch_weight_decay', 0.00001, 'l2 coefficient')
tf.app.flags.DEFINE_integer('embedding_size', 10, '')
tf.app.flags.DEFINE_string('tch_model_ckpt', None, '')
tf.app.flags.DEFINE_integer('num_tch_epoch', 5, '')
# kdgan
tf.app.flags.DEFINE_integer('num_negative', 1, '')
tf.app.flags.DEFINE_integer('num_positive', 1, '')
flags = tf.app.flags.FLAGS

train_data_size = utils.get_train_data_size(flags.dataset)
num_batch_t = int(flags.num_epoch * train_data_size / config.train_batch_size)
eval_interval = int(train_data_size / config.train_batch_size)
print('tn:\t#batch=%d\neval:\t#interval=%d' % (num_batch_t, eval_interval))

gen_t = GEN(flags, is_training=True)
tch_t = TCH(flags, is_training=True)
scope = tf.get_variable_scope()
scope.reuse_variables()
gen_v = GEN(flags, is_training=False)
tch_v = TCH(flags, is_training=False)

def main(_):
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate)
  tf.summary.scalar(gen_t.kd_loss.name, gen_t.kd_loss)
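# Example invocation (script name, dataset, and values are assumptions; the
# flag names come from the definitions above and the usages that follow):
#   python pretrain_tch.py --dataset yfcc10k \
#       --num_epoch 200 --tch_model_ckpt ckpts/tch.ckpt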
# Fragment 5: teacher pretraining with epoch-boundary evaluation and
# figure-data logging.
def main(_):
  tch_t = TCH(flags, is_training=True)
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  tch_v = TCH(flags, is_training=False)

  tf.summary.scalar(tch_t.learning_rate.name, tch_t.learning_rate)
  tf.summary.scalar(tch_t.pre_loss.name, tch_t.pre_loss)
  summary_op = tf.summary.merge_all()
  init_op = tf.global_variables_initializer()

  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d'
      % (len(data_sources_t), len(data_sources_v)))

  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

  figure_data = []
  best_hit_v = -np.inf
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      for batch_t in range(num_batch_t):
        # Fetch image, text, and label in one sess.run so they stay aligned.
        # image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
        image_np_t, text_np_t, label_np_t = sess.run(
            [image_bt_t, text_bt_t, label_bt_t])
        feed_dict = {
          tch_t.text_ph: text_np_t,
          tch_t.image_ph: image_np_t,
          tch_t.hard_label_ph: label_np_t,
        }
        _, summary = sess.run([tch_t.pre_update, summary_op],
            feed_dict=feed_dict)
        writer.add_summary(summary, batch_t)

        # Evaluate only when this batch lands on or crosses an epoch boundary.
        batch = batch_t + 1
        remain = (batch * flags.batch_size) % train_data_size
        epoch = (batch * flags.batch_size) // train_data_size
        if remain == 0:
          pass  # print('%d\t%d\t%d' % (epoch, batch, remain))
        elif (train_data_size - remain) < flags.batch_size:
          epoch = epoch + 1  # print('%d\t%d\t%d' % (epoch, batch, remain))
        else:
          continue
        # if (batch_t + 1) % eval_interval != 0:
        #   continue

        hit_v = []
        for batch_v in range(num_batch_v):
          # image_np_v, label_np_v = sess.run([image_bt_v, label_bt_v])
          text_np_v, image_np_v, label_np_v = sess.run(
              [text_bt_v, image_bt_v, label_bt_v])
          feed_dict = {
            tch_v.text_ph: text_np_v,
            tch_v.image_ph: image_np_v,
          }
          logit_np_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
          hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
          hit_v.append(hit_bt)
        hit_v = np.mean(hit_v)
        figure_data.append((epoch, hit_v, batch_t))

        # Keep only the best checkpoint so far.
        if hit_v < best_hit_v:
          continue
        tot_time = time.time() - start
        best_hit_v = hit_v
        print('#%03d curbst=%.4f time=%.0fs' % (epoch, hit_v, tot_time))
        tch_t.saver.save(sess, flags.tch_model_ckpt)
  print('bsthit=%.4f' % (best_hit_v))

  # Dump (epoch, hit, step) triples for plotting.
  utils.create_if_nonexist(os.path.dirname(flags.gen_figure_data))
  fout = open(flags.gen_figure_data, 'w')
  for epoch, hit_v, batch_t in figure_data:
    fout.write('%d\t%.4f\t%d\n' % (epoch, hit_v, batch_t))
  fout.close()
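# Worked example of the epoch-boundary test above (values assumed):
# train_data_size=100, flags.batch_size=32
#   batch=3: 3*32=96  -> remain=96, 100-96=4  < 32 -> bump to epoch 1, evaluate
#   batch=4: 4*32=128 -> remain=28, 100-28=72 >= 32 -> skip (continue)
# i.e. evaluation fires only when a batch lands on or crosses an epoch boundary.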
# Fragment 6: another teacher pretraining variant (image + text inputs).
def main(_):
  tch_t = TCH(flags, is_training=True)
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  tch_v = TCH(flags, is_training=False)

  tf.summary.scalar(tch_t.learning_rate.name, tch_t.learning_rate)
  tf.summary.scalar(tch_t.pre_loss.name, tch_t.pre_loss)
  summary_op = tf.summary.merge_all()
  init_op = tf.global_variables_initializer()

  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d'
      % (len(data_sources_t), len(data_sources_v)))

  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

  best_hit_v = -np.inf
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      for batch_t in range(num_batch_t):
        # The original fetched text and image in two separate sess.run calls,
        # which dequeues two different batches and misaligns text with image
        # and label; fetch all three together instead (as fragment 5 does).
        image_np_t, text_np_t, label_np_t = sess.run(
            [image_bt_t, text_bt_t, label_bt_t])
        feed_dict = {
          tch_t.image_ph: image_np_t,
          tch_t.text_ph: text_np_t,
          tch_t.hard_label_ph: label_np_t,
        }
        _, summary = sess.run([tch_t.pre_update, summary_op],
            feed_dict=feed_dict)
        writer.add_summary(summary, batch_t)

        if (batch_t + 1) % eval_interval != 0:
          continue
        hit_v = []
        for batch_v in range(num_batch_v):
          # Same fix on the validation side: one sess.run per batch.
          image_np_v, text_np_v, label_np_v = sess.run(
              [image_bt_v, text_bt_v, label_bt_v])
          feed_dict = {
            tch_v.image_ph: image_np_v,
            tch_v.text_ph: text_np_v,
          }
          logit_np_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
          hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
          hit_v.append(hit_bt)
        hit_v = np.mean(hit_v)
        tot_time = time.time() - start
        print('#%08d hit=%.4f %06ds' % (batch_t, hit_v, int(tot_time)))

        if hit_v < best_hit_v:
          continue
        best_hit_v = hit_v
        tch_t.saver.save(sess, flags.tch_model_ckpt)
  print('best hit=%.4f' % (best_hit_v))
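# Once pretraining finishes, the checkpoint saved above can be restored with
# the standard TF1 pattern; a minimal sketch (variable set assumed to match
# the graph that saved it):
saver = tf.train.Saver(var_list=tf.trainable_variables())
with tf.Session() as sess:
  saver.restore(sess, flags.tch_model_ckpt)
  # ...run inference or resume training from the restored weights...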