def main(_):
  print('#label={}'.format(config.num_label))
  gen_t = GEN(flags, is_training=True)
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  gen_v = GEN(flags, is_training=False)

  data_sources_t = utils.get_data_sources(flags, config.train_file, 500)
  data_sources_v = utils.get_data_sources(flags, config.valid_file, 1)
  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, config.train_batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, image_file_bt_t = bt_list_t
  user_bt_v, image_bt_v, text_bt_v, label_bt_v, image_file_bt_v = bt_list_v

  best_hit_v = -np.inf
  init_op = tf.global_variables_initializer()
  start = time.time()
  with tf.Session() as sess:
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    sess.run(init_op)
    with slim.queues.QueueRunners(sess):
      for batch_t in range(num_batch_t):
        image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
        feed_dict = {gen_t.image_ph: image_np_t, gen_t.label_ph: label_np_t}
        _, summary = sess.run([gen_t.train_op, gen_t.summary_op], feed_dict=feed_dict)
        writer.add_summary(summary, batch_t)
        # evaluate once per epoch
        if (batch_t + 1) % int(config.train_data_size / config.train_batch_size) != 0:
          continue
        hit_v = []
        image_file_v = set()
        for batch_v in range(num_batch_v):
          image_np_v, label_np_v, image_file_np_v = sess.run(
              [image_bt_v, label_bt_v, image_file_bt_v])
          feed_dict = {gen_v.image_ph: image_np_v}
          logit_np_v, = sess.run([gen_v.logits], feed_dict=feed_dict)
          for image_file in image_file_np_v:
            image_file_v.add(image_file)
          hit_bt = compute_hit(logit_np_v, label_np_v, flags.cutoff)
          hit_v.append(hit_bt)
        hit_v = np.mean(hit_v)
        total_time = time.time() - start
        avg_batch = total_time / (batch_t + 1)
        avg_epoch = avg_batch * (config.train_data_size / config.train_batch_size)
        s = '{0} hit={1:.4f} tot={2:.0f}s avg={3:.0f}s'
        print(s.format(batch_t, hit_v, total_time, avg_epoch))
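# The training scripts in this section assume module-level setup that is not
# shown. A minimal sketch, assuming `flags`/`config` carry the batch and
# dataset sizes; `config.valid_data_size` is a hypothetical field, the
# derivations of num_batch_t/num_batch_v/eval_interval are guesses consistent
# with how they are used above, and the import paths of the project's own
# `config`/`utils`/`metric` modules are omitted on purpose.
import math
import os
import time
from os import path

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

train_data_size = config.train_data_size
num_batch_t = int(flags.num_epoch * train_data_size / config.train_batch_size)
num_batch_v = int(config.valid_data_size / config.valid_batch_size)
eval_interval = int(train_data_size / config.train_batch_size)
num_batch_per_epoch = int(math.ceil(train_data_size / config.train_batch_size))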
def main(_):
  global_step = tf.train.create_global_step()
  tch_t = TCH(flags, is_training=True)
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  tch_v = TCH(flags, is_training=False)

  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('{}\t({} params)'.format(variable.name, num_params))

  data_sources_t = utils.get_data_sources(flags, is_training=True, single=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, config.train_batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

  best_hit_v = -np.inf
  init_op = tf.global_variables_initializer()
  start = time.time()
  with tf.Session() as sess:
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    sess.run(init_op)
    with slim.queues.QueueRunners(sess):
      for batch_t in range(num_batch_t):
        text_np_t, label_np_t = sess.run([text_bt_t, label_bt_t])
        feed_dict = {tch_t.text_ph: text_np_t, tch_t.label_ph: label_np_t}
        _, summary = sess.run([tch_t.train_op, tch_t.summary_op], feed_dict=feed_dict)
        writer.add_summary(summary, batch_t)
        if (batch_t + 1) % eval_interval != 0:
          continue
        hit_v = []
        for batch_v in range(num_batch_v):
          text_np_v, label_np_v = sess.run([text_bt_v, label_bt_v])
          feed_dict = {tch_v.text_ph: text_np_v}
          logit_np_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
          hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
          hit_v.append(hit_bt)
        hit_v = np.mean(hit_v)
        tot_time = time.time() - start
        print('#{0} hit={1:.4f} {2:.0f}s'.format(batch_t, hit_v, tot_time))
        if hit_v < best_hit_v:
          continue
        best_hit_v = hit_v
        ckpt_file = path.join(config.ckpt_dir, '%s.ckpt' % flags.tch_model)
        tch_t.saver.save(sess, ckpt_file)
  print('best hit={0:.4f}'.format(best_hit_v))
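# The evaluation loops above rely on metric.compute_hit (hit@cutoff). A
# minimal sketch of such a metric, assuming one row of logits per example
# and multi-hot ground-truth rows; the repo's actual implementation may
# differ in tie-breaking or averaging.
import numpy as np

def compute_hit(logits, labels, cutoff):
  # take the `cutoff` highest-scoring tags per example
  predictions = np.argsort(-logits, axis=1)[:, :cutoff]
  hits = []
  for pred, truth in zip(predictions, labels):
    true_tags = np.nonzero(truth)[0]
    # 1.0 if any predicted tag is a true tag, else 0.0
    hits.append(float(np.intersect1d(pred, true_tags).size > 0))
  return np.mean(hits)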
def get_batch(self, flags, is_training):
  if is_training:
    single = False
    stage = 'train'
    shuffle = True
  else:
    single = True
    stage = 'valid'
    shuffle = False
  data_sources = utils.get_data_sources(flags, is_training=is_training, single=single)
  print('#tfrecord=%d for %s' % (len(data_sources), stage))
  ts_list = utils.decode_tfrecord(flags, data_sources, shuffle=shuffle)
  bt_list = utils.generate_batch(ts_list, flags.batch_size)
  return bt_list
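# Hypothetical usage of get_batch: a wrapper defining it returns the same
# five-tensor tuple unpacked throughout these scripts (`model` below stands
# in for any such wrapper and is not a name from the repo).
bt_list_t = model.get_batch(flags, is_training=True)
user_bt, image_bt, text_bt, label_bt, file_bt = bt_list_t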
def main(_):
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('{}\t({} params)'.format(variable.name, num_params))

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
  # two independent readers over the training data: one stream for the
  # discriminator steps, one for the generator steps
  ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_d = utils.generate_batch(ts_list_d, config.train_batch_size)
  user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d
  ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_g = utils.generate_batch(ts_list_g, config.train_batch_size)
  user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

  best_hit_v = -np.inf
  start = time.time()
  init_op = tf.global_variables_initializer()
  with tf.Session() as sess:
    sess.run(init_op)
    dis_t.saver.restore(sess, flags.dis_model_ckpt)
    gen_t.saver.restore(sess, flags.gen_model_ckpt)
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      image_hit_v = utils.evaluate(flags, sess, gen_v, bt_list_v)
      print('init\thit={0:.4f}'.format(image_hit_v))
      batch_d, batch_g = -1, -1
      for epoch in range(flags.num_epoch):
        for dis_epoch in range(flags.num_dis_epoch):
          print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
          num_batch_d = math.ceil(train_data_size / config.train_batch_size)
          for _ in range(num_batch_d):
            batch_d += 1
            image_np_d, label_dat_d = sess.run([image_bt_d, label_bt_d])
            feed_dict = {gen_t.image_ph: image_np_d}
            label_gen_d, = sess.run([gen_t.labels], feed_dict=feed_dict)
            sample_np_d, label_np_d = gan_dis_sample(label_dat_d, label_gen_d)
            feed_dict = {
              dis_t.image_ph: image_np_d,
              dis_t.sample_ph: sample_np_d,
              dis_t.label_ph: label_np_d,
            }
            sess.run(dis_t.train_op, feed_dict=feed_dict)
        for gen_epoch in range(flags.num_gen_epoch):
          print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
          num_batch_g = math.ceil(train_data_size / config.train_batch_size)
          for _ in range(num_batch_g):
            batch_g += 1
            image_np_g, label_dat_g = sess.run([image_bt_g, label_bt_g])
            feed_dict = {gen_t.image_ph: image_np_g}
            label_gen_g, = sess.run([gen_t.labels], feed_dict=feed_dict)
            # each sample is a (batch_index, tag_index) pair
            sample_np_g = generate_label(label_dat_g, label_gen_g)
            feed_dict = {
              dis_t.image_ph: image_np_g,
              dis_t.sample_ph: sample_np_g,
            }
            reward_np_g, = sess.run([dis_t.rewards], feed_dict=feed_dict)
            feed_dict = {
              gen_t.image_ph: image_np_g,
              gen_t.sample_ph: sample_np_g,
              gen_t.reward_ph: reward_np_g,
            }
            sess.run([gen_t.gan_train_op], feed_dict=feed_dict)
            if (batch_g + 1) % eval_interval != 0:
              continue
            image_hit_v = utils.evaluate(flags, sess, gen_v, bt_list_v)
            tot_time = time.time() - start
            print('#%d hit=%.4f (%.0fs)' % (batch_g, image_hit_v, tot_time))
            if image_hit_v > best_hit_v:
              best_hit_v = image_hit_v
              print('best hit=%.4f (%.0fs)' % (image_hit_v, tot_time))
  print('best\thit={0:.4f}'.format(best_hit_v))
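# gan_dis_sample and generate_label (utils.* in later scripts) are the
# sampling helpers of the adversarial loop. A minimal sketch, assuming a
# "sample" is a (batch_index, tag_index) pair, that every example has at
# least one true tag, and that the generator's `labels` output is a
# non-negative distribution over tags; the repo's real versions may differ.
import numpy as np

def gan_dis_sample(label_dat, label_gen):
  samples, labels = [], []
  for i in range(label_dat.shape[0]):
    # a tag drawn from the ground-truth tags counts as real (1.0)
    pos = np.random.choice(np.nonzero(label_dat[i])[0])
    samples.append((i, pos))
    labels.append(1.0)
    # a tag drawn from the generator's distribution counts as fake (0.0)
    prob = label_gen[i] / label_gen[i].sum()
    neg = np.random.choice(label_gen.shape[1], p=prob)
    samples.append((i, neg))
    labels.append(0.0)
  return np.asarray(samples), np.asarray(labels, dtype=np.float32)

def generate_label(label_dat, label_gen):
  # draw one tag per example from the generator's distribution
  samples = []
  for i in range(label_dat.shape[0]):
    prob = label_gen[i] / label_gen[i].sum()
    samples.append((i, np.random.choice(label_gen.shape[1], p=prob)))
  return np.asarray(samples)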
def main(_):
  gen_t = GEN(flags, is_training=True)
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  gen_v = GEN(flags, is_training=False)

  tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate)
  tf.summary.scalar(gen_t.pre_loss.name, gen_t.pre_loss)
  summary_op = tf.summary.merge_all()
  init_op = tf.global_variables_initializer()

  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

  figure_data = []
  best_hit_v = -np.inf
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      for batch_t in range(num_batch_t):
        image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
        feed_dict = {gen_t.image_ph: image_np_t, gen_t.hard_label_ph: label_np_t}
        _, summary = sess.run([gen_t.pre_update, summary_op], feed_dict=feed_dict)
        writer.add_summary(summary, batch_t)

        # evaluate only at (approximate) epoch boundaries
        batch = batch_t + 1
        remain = (batch * flags.batch_size) % train_data_size
        epoch = (batch * flags.batch_size) // train_data_size
        if remain == 0:
          pass
        elif (train_data_size - remain) < flags.batch_size:
          epoch = epoch + 1
        else:
          continue

        hit_v = []
        for batch_v in range(num_batch_v):
          image_np_v, label_np_v = sess.run([image_bt_v, label_bt_v])
          feed_dict = {gen_v.image_ph: image_np_v}
          logit_np_v, = sess.run([gen_v.logits], feed_dict=feed_dict)
          hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
          hit_v.append(hit_bt)
        hit_v = np.mean(hit_v)
        figure_data.append((epoch, hit_v, batch_t))

        if hit_v < best_hit_v:
          continue
        tot_time = time.time() - start
        best_hit_v = hit_v
        print('#%03d curbst=%.4f time=%.0fs' % (epoch, hit_v, tot_time))
        gen_t.saver.save(sess, flags.gen_model_ckpt)
  print('bsthit=%.4f' % (best_hit_v))

  utils.create_if_nonexist(os.path.dirname(flags.gen_figure_data))
  fout = open(flags.gen_figure_data, 'w')
  for epoch, hit_v, batch_t in figure_data:
    fout.write('%d\t%.4f\t%d\n' % (epoch, hit_v, batch_t))
  fout.close()
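# flags.gen_figure_data is written above as a TSV of (epoch, hit, batch). A
# hypothetical snippet to plot it; matplotlib and the file name are
# assumptions, not part of the repo.
import matplotlib.pyplot as plt
import numpy as np

# path as written by the script above via flags.gen_figure_data
epochs, hits, _ = np.loadtxt('gen_figure_data.tsv', unpack=True)
plt.plot(epochs, hits, marker='o')
plt.xlabel('epoch')
plt.ylabel('hit@cutoff')
plt.savefig('gen_hit.png')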
def main(_):
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate)
  tf.summary.scalar(gen_t.kd_loss.name, gen_t.kd_loss)
  summary_op = tf.summary.merge_all()
  init_op = tf.global_variables_initializer()

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, config.train_batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t

  best_hit_v = -np.inf
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    gen_t.saver.restore(sess, flags.gen_model_ckpt)
    tch_t.saver.restore(sess, flags.tch_model_ckpt)
    with slim.queues.QueueRunners(sess):
      hit_v = utils.evaluate(flags, sess, gen_v, bt_list_v)
      print('init hit=%.4f' % (hit_v))
      for batch_t in range(num_batch_t):
        image_np_t, text_np_t, hard_labels = sess.run(
            [image_bt_t, text_bt_t, label_bt_t])
        # the teacher's predictions become the student's soft targets
        feed_dict = {tch_t.text_ph: text_np_t}
        soft_labels, = sess.run([tch_t.labels], feed_dict=feed_dict)
        feed_dict = {
          gen_t.image_ph: image_np_t,
          gen_t.hard_label_ph: hard_labels,
          gen_t.soft_label_ph: soft_labels,
        }
        _, summary = sess.run([gen_t.kd_update, summary_op], feed_dict=feed_dict)
        writer.add_summary(summary, batch_t)
        if (batch_t + 1) % eval_interval != 0:
          continue
        hit_v = utils.evaluate(flags, sess, gen_v, bt_list_v)
        tot_time = time.time() - start
        print('#%08d hit=%.4f %06ds' % (batch_t, hit_v, int(tot_time)))
        if hit_v <= best_hit_v:
          continue
        best_hit_v = hit_v
        print('best hit=%.4f' % (best_hit_v))
  print('best hit=%.4f' % (best_hit_v))
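# gen_t.kd_loss combines hard ground-truth supervision with the teacher's
# soft labels. A minimal sketch of such a distillation loss; `kd_lamda` and
# `temperature` are hypothetical hyperparameters and the exact form used by
# the GEN model class may differ.
import tensorflow as tf

def build_kd_loss(logits, hard_labels, soft_labels, kd_lamda=0.5, temperature=3.0):
  # hard term: multi-label cross-entropy against the ground-truth tags
  hard_loss = tf.losses.sigmoid_cross_entropy(hard_labels, logits)
  # soft term: match the teacher's (tempered) predictions
  soft_loss = tf.nn.l2_loss(tf.nn.sigmoid(logits / temperature) - soft_labels)
  return (1.0 - kd_lamda) * hard_loss + kd_lamda * soft_loss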
def main(_):
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  dis_summary_op = tf.summary.merge([
    tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
    tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
  ])
  gen_summary_op = tf.summary.merge([
    tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
    tf.summary.scalar(gen_t.gan_loss.name, gen_t.gan_loss),
  ])
  init_op = tf.global_variables_initializer()

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
  ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
  user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d
  ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
  user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

  figure_data = []
  best_hit_v = -np.inf
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    dis_t.saver.restore(sess, flags.dis_model_ckpt)
    gen_t.saver.restore(sess, flags.gen_model_ckpt)
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
      print('init hit=%.4f' % (hit_v))
      batch_d, batch_g = -1, -1
      for epoch in range(flags.num_epoch):
        for dis_epoch in range(flags.num_dis_epoch):
          print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
          num_batch_d = math.ceil(train_data_size / flags.batch_size)
          for _ in range(num_batch_d):
            batch_d += 1
            image_np_d, label_dat_d = sess.run([image_bt_d, label_bt_d])
            feed_dict = {gen_t.image_ph: image_np_d}
            label_gen_d, = sess.run([gen_t.labels], feed_dict=feed_dict)
            sample_np_d, label_np_d = utils.gan_dis_sample(
                flags, label_dat_d, label_gen_d)
            feed_dict = {
              dis_t.image_ph: image_np_d,
              dis_t.sample_ph: sample_np_d,
              dis_t.dis_label_ph: label_np_d,
            }
            _, summary_d = sess.run([dis_t.gan_update, dis_summary_op],
                feed_dict=feed_dict)
            writer.add_summary(summary_d, batch_d)
        for gen_epoch in range(flags.num_gen_epoch):
          print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
          num_batch_g = math.ceil(train_data_size / flags.batch_size)
          for _ in range(num_batch_g):
            batch_g += 1
            image_np_g, label_dat_g = sess.run([image_bt_g, label_bt_g])
            feed_dict = {gen_t.image_ph: image_np_g}
            label_gen_g, = sess.run([gen_t.labels], feed_dict=feed_dict)
            sample_np_g = utils.generate_label(flags, label_dat_g, label_gen_g)
            feed_dict = {
              dis_t.image_ph: image_np_g,
              dis_t.sample_ph: sample_np_g,
            }
            reward_np_g, = sess.run([dis_t.rewards], feed_dict=feed_dict)
            feed_dict = {
              gen_t.image_ph: image_np_g,
              gen_t.sample_ph: sample_np_g,
              gen_t.reward_ph: reward_np_g,
            }
            _, summary_g = sess.run([gen_t.gan_update, gen_summary_op],
                feed_dict=feed_dict)
            writer.add_summary(summary_g, batch_g)
        # per-batch evaluation disabled; evaluate once per epoch instead
        hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
        tot_time = time.time() - start
        print('#%03d curbst=%.4f %.0fs' % (epoch, hit_v, tot_time))
        figure_data.append((epoch, hit_v))
        if hit_v <= best_hit_v:
          continue
        best_hit_v = hit_v
        print('bsthit=%.4f' % (best_hit_v))

  utils.create_if_nonexist(os.path.dirname(flags.gan_figure_data))
  fout = open(flags.gan_figure_data, 'w')
  for epoch, hit_v in figure_data:
    fout.write('%d\t%.4f\n' % (epoch, hit_v))
  fout.close()
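# utils.evaluate_image runs the frozen validation model over the whole
# validation stream and averages hit@cutoff. A sketch, assuming the same
# five-tensor batch tuple and the hypothetical config.valid_data_size field
# from the setup sketch above; the repo's version may differ.
import numpy as np

def evaluate_image(flags, sess, model_v, bt_list_v):
  user_bt, image_bt, text_bt, label_bt, file_bt = bt_list_v
  num_batch = int(config.valid_data_size / config.valid_batch_size)
  hits = []
  for _ in range(num_batch):
    image_np, label_np = sess.run([image_bt, label_bt])
    logit_np, = sess.run([model_v.logits],
        feed_dict={model_v.image_ph: image_np})
    hits.append(metric.compute_hit(logit_np, label_np, flags.cutoff))
  return np.mean(hits)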
def main(_):
  dis_t = DIS(flags, is_training=True)
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  dis_v = DIS(flags, is_training=False)

  tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate)
  tf.summary.scalar(dis_t.pre_loss.name, dis_t.pre_loss)
  summary_op = tf.summary.merge_all()
  init_op = tf.global_variables_initializer()

  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

  start = time.time()
  best_hit_v = -np.inf
  with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      for batch_t in range(num_batch_t):
        image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
        feed_dict = {dis_t.image_ph: image_np_t, dis_t.hard_label_ph: label_np_t}
        _, summary = sess.run([dis_t.pre_update, summary_op], feed_dict=feed_dict)
        writer.add_summary(summary, batch_t)
        if (batch_t + 1) % eval_interval != 0:
          continue
        hit_v = []
        for batch_v in range(num_batch_v):
          image_np_v, label_np_v = sess.run([image_bt_v, label_bt_v])
          feed_dict = {dis_v.image_ph: image_np_v}
          logit_np_v, = sess.run([dis_v.logits], feed_dict=feed_dict)
          hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
          hit_v.append(hit_bt)
        hit_v = np.mean(hit_v)
        tot_time = time.time() - start
        print('#%08d hit=%.4f %06ds' % (batch_t, hit_v, int(tot_time)))
        if hit_v < best_hit_v:
          continue
        best_hit_v = hit_v
        dis_t.saver.save(sess, flags.dis_model_ckpt)
  print('best hit=%.4f' % (best_hit_v))
def main(_):
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  dis_summary_op = tf.summary.merge([
    tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
    tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
  ])
  gen_summary_op = tf.summary.merge([
    tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
    tf.summary.scalar(gen_t.kdgan_loss.name, gen_t.kdgan_loss),
  ])
  tch_summary_op = tf.summary.merge([
    tf.summary.scalar(tch_t.learning_rate.name, tch_t.learning_rate),
    tf.summary.scalar(tch_t.kdgan_loss.name, tch_t.kdgan_loss),
  ])
  init_op = tf.global_variables_initializer()

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
  # three independent readers over the training data: discriminator,
  # generator, and teacher each get their own shuffled stream
  ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
  user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d
  ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
  user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g
  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

  figure_data = []
  best_hit_v = -np.inf
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    dis_t.saver.restore(sess, flags.dis_model_ckpt)
    gen_t.saver.restore(sess, flags.gen_model_ckpt)
    tch_t.saver.restore(sess, flags.tch_model_ckpt)
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
      tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)
      print('hit gen=%.4f tch=%.4f' % (gen_hit, tch_hit))
      batch_d, batch_g, batch_t = -1, -1, -1
      for epoch in range(flags.num_epoch):
        # discriminator step: distinguish data tags from gen/tch samples
        for dis_epoch in range(flags.num_dis_epoch):
          print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
          for _ in range(num_batch_per_epoch):
            batch_d += 1
            image_d, text_d, label_dat_d = sess.run(
                [image_bt_d, text_bt_d, label_bt_d])
            feed_dict = {gen_t.image_ph: image_d}
            label_gen_d, = sess.run([gen_t.labels], feed_dict=feed_dict)
            feed_dict = {tch_t.text_ph: text_d, tch_t.image_ph: image_d}
            label_tch_d, = sess.run([tch_t.labels], feed_dict=feed_dict)
            sample_d, label_d = utils.kdgan_dis_sample(
                flags, label_dat_d, label_gen_d, label_tch_d)
            feed_dict = {
              dis_t.image_ph: image_d,
              dis_t.sample_ph: sample_d,
              dis_t.dis_label_ph: label_d,
            }
            _, summary_d = sess.run([dis_t.gan_update, dis_summary_op],
                feed_dict=feed_dict)
            writer.add_summary(summary_d, batch_d)
        # teacher step: policy gradient on rewards plus distillation from gen
        for tch_epoch in range(flags.num_tch_epoch):
          print('epoch %03d tch_epoch %03d' % (epoch, tch_epoch))
          for _ in range(num_batch_per_epoch):
            batch_t += 1
            image_t, text_t, label_dat_t = sess.run(
                [image_bt_t, text_bt_t, label_bt_t])
            feed_dict = {tch_t.text_ph: text_t, tch_t.image_ph: image_t}
            label_tch_t, = sess.run([tch_t.labels], feed_dict=feed_dict)
            sample_t = utils.generate_label(flags, label_dat_t, label_tch_t)
            feed_dict = {
              dis_t.image_ph: image_t,
              dis_t.sample_ph: sample_t,
            }
            reward_t, = sess.run([dis_t.rewards], feed_dict=feed_dict)
            feed_dict = {gen_t.image_ph: image_t}
            label_gen_g = sess.run(gen_t.logits, feed_dict=feed_dict)
            feed_dict = {
              tch_t.text_ph: text_t,
              tch_t.image_ph: image_t,
              tch_t.sample_ph: sample_t,
              tch_t.reward_ph: reward_t,
              tch_t.hard_label_ph: label_dat_t,
              tch_t.soft_label_ph: label_gen_g,
            }
            _, summary_t, tch_kdgan_loss = sess.run(
                [tch_t.kdgan_update, tch_summary_op, tch_t.kdgan_loss],
                feed_dict=feed_dict)
            writer.add_summary(summary_t, batch_t)
            # print('teacher kdgan loss:', tch_kdgan_loss)
        # generator step: policy gradient plus distillation from tch
        for gen_epoch in range(flags.num_gen_epoch):
          print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
          for _ in range(num_batch_per_epoch):
            batch_g += 1
            image_g, text_g, label_dat_g = sess.run(
                [image_bt_g, text_bt_g, label_bt_g])
            feed_dict = {tch_t.text_ph: text_g, tch_t.image_ph: image_g}
            label_tch_g, = sess.run([tch_t.labels], feed_dict=feed_dict)
            feed_dict = {gen_t.image_ph: image_g}
            label_gen_g, = sess.run([gen_t.labels], feed_dict=feed_dict)
            sample_g = utils.generate_label(flags, label_dat_g, label_gen_g)
            feed_dict = {
              dis_t.image_ph: image_g,
              dis_t.sample_ph: sample_g,
            }
            reward_g, = sess.run([dis_t.rewards], feed_dict=feed_dict)
            feed_dict = {
              gen_t.image_ph: image_g,
              gen_t.hard_label_ph: label_dat_g,
              gen_t.soft_label_ph: label_tch_g,
              gen_t.sample_ph: sample_g,
              gen_t.reward_ph: reward_g,
            }
            _, summary_g = sess.run([gen_t.kdgan_update, gen_summary_op],
                feed_dict=feed_dict)
            writer.add_summary(summary_g, batch_g)
        gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
        tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)
        tot_time = time.time() - start
        print('#%03d curgen=%.4f curtch=%.4f %.0fs'
            % (epoch, gen_hit, tch_hit, tot_time))
        figure_data.append((epoch, gen_hit, tch_hit))
        if gen_hit <= best_hit_v:
          continue
        best_hit_v = gen_hit
        print('epoch %d: new best validation hit=%.4f, saving...'
            % (epoch + 1, best_hit_v))
        gen_t.saver.save(sess, flags.kdgan_model_ckpt, global_step=epoch + 1)
        print('finish saving')
  print('best hit=%.4f' % (best_hit_v))

  utils.create_if_nonexist(os.path.dirname(flags.kdgan_figure_data))
  fout = open(flags.kdgan_figure_data, 'w')
  for epoch, gen_hit, tch_hit in figure_data:
    fout.write('%d\t%.4f\t%.4f\n' % (epoch, gen_hit, tch_hit))
  fout.close()
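# dis_t.rewards scores each sampled (example, tag) pair fed through
# dis_t.sample_ph. A sketch of how such a reward tensor could be built
# inside DIS; the actual definition lives in the model class and the
# zero-centering constant is an assumption.
import tensorflow as tf

def build_rewards(logits, sample_ph):
  # pick the discriminator logit of each sampled (batch_index, tag_index) pair
  sample_logits = tf.gather_nd(logits, sample_ph)
  # map sigmoid probabilities to rewards centered at zero
  return 2.0 * (tf.sigmoid(sample_logits) - 0.5)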