def main(_):
    """Pretrain the TCH (teacher) text model; checkpoint whenever validation hit@cutoff improves.

    NOTE(review): relies on module-level names (flags, config, num_batch_t,
    num_batch_v, eval_interval, TCH, utils, metric) defined elsewhere in the file.
    """
    global_step = tf.train.create_global_step()
    # Build a training graph and a validation graph sharing the same variables.
    tch_t = TCH(flags, is_training=True)
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    tch_v = TCH(flags, is_training=False)

    # Log the parameter count of every trainable variable.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('{}\t({} params)'.format(variable.name, num_params))

    # Input pipelines: shuffled training records, deterministic validation records.
    data_sources_t = utils.get_data_sources(flags, is_training=True, single=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_t = utils.generate_batch(ts_list_t, config.train_batch_size)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
    user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
    user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

    best_hit_v = -np.inf
    init_op = tf.global_variables_initializer()
    start = time.time()
    with tf.Session() as sess:
        writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
        sess.run(init_op)
        with slim.queues.QueueRunners(sess):
            for batch_t in range(num_batch_t):
                # One optimizer step on the teacher.
                text_np_t, label_np_t = sess.run([text_bt_t, label_bt_t])
                feed_dict = {tch_t.text_ph: text_np_t, tch_t.label_ph: label_np_t}
                _, summary = sess.run([tch_t.train_op, tch_t.summary_op],
                                      feed_dict=feed_dict)
                writer.add_summary(summary, batch_t)

                # Evaluate only every eval_interval batches.
                if (batch_t + 1) % eval_interval != 0:
                    continue

                hit_v = []
                for batch_v in range(num_batch_v):
                    text_np_v, label_np_v = sess.run([text_bt_v, label_bt_v])
                    feed_dict = {tch_v.text_ph: text_np_v}
                    logit_np_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
                    hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
                    hit_v.append(hit_bt)
                hit_v = np.mean(hit_v)
                tot_time = time.time() - start
                print('#{0} hit={1:.4f} {2:.0f}s'.format(batch_t, hit_v, tot_time))

                # Keep only the best-scoring checkpoint.
                if hit_v < best_hit_v:
                    continue
                best_hit_v = hit_v
                ckpt_file = path.join(config.ckpt_dir, '%s.ckpt' % flags.tch_model)
                tch_t.saver.save(sess, ckpt_file)
    print('best hit={0:.4f}'.format(best_hit_v))
def eval(sess, bt_list_v):
    """Run one validation pass over both generator and teacher; return the image hit rate.

    NOTE(review): this shadows the builtin `eval` — left unrenamed to keep the
    interface stable, but worth renaming file-wide. Also relies on module-level
    names (gen_v, tch_v, num_batch_v, flags, metric) defined elsewhere.
    The text hit rate is computed but only used by the commented-out print.
    """
    user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v
    image_hit_v, text_hit_v = [], []
    for batch_v in range(num_batch_v):
        image_np_v, text_np_v, label_np_v = sess.run(
            [image_bt_v, text_bt_v, label_bt_v])
        feed_dict = {gen_v.image_ph: image_np_v, tch_v.text_ph: text_np_v}

        # Generator (image) branch.
        image_logit_v, = sess.run([gen_v.logits], feed_dict=feed_dict)
        image_hit_bt = metric.compute_hit(image_logit_v, label_np_v, flags.cutoff)
        image_hit_v.append(image_hit_bt)

        # Teacher (text) branch.
        text_logit_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
        text_hit_bt = metric.compute_hit(text_logit_v, label_np_v, flags.cutoff)
        text_hit_v.append(text_hit_bt)
    image_hit_v = np.mean(image_hit_v)
    text_hit_v = np.mean(text_hit_v)
    # print('img:\thit=%.4f\ntxt:\thit=%.4f' % (image_hit_v, text_hit_v))
    return image_hit_v
def evaluate(flags, sess, gen_v, bt_list_v):
    """Compute the generator's mean hit@cutoff over the validation set.

    Args:
        flags: config object providing at least `cutoff`.
        sess: live tf.Session with queue runners started.
        gen_v: inference-mode generator exposing `image_ph` and `logits`.
        bt_list_v: 5-tuple of validation batch tensors (user, image, text, label, file).

    Returns:
        Mean hit rate (float) across all validation batches.
    """
    num_batch_v = int(config.valid_data_size / config.valid_batch_size)
    # print('vd:\t#batch=%d\n' % num_batch_v)
    user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v
    image_hit_v = []
    for batch_v in range(num_batch_v):
        image_np_v, label_np_v = sess.run([image_bt_v, label_bt_v])
        feed_dict = {gen_v.image_ph: image_np_v}
        image_logit_v, = sess.run([gen_v.logits], feed_dict=feed_dict)
        image_hit_bt = metric.compute_hit(image_logit_v, label_np_v, flags.cutoff)
        image_hit_v.append(image_hit_bt)
    image_hit_v = np.mean(image_hit_v)
    return image_hit_v
def evaluate_text(flags, sess, tch_v, bt_list_v):
    """Compute the teacher's mean hit@cutoff over the validation set.

    Args:
        flags: config object providing at least `dataset` and `cutoff`.
        sess: live tf.Session with queue runners started.
        tch_v: inference-mode teacher exposing `text_ph`, `image_ph` and `logits`.
        bt_list_v: 5-tuple of validation batch tensors (user, image, text, label, file).

    Returns:
        Mean hit rate (float) across all validation batches.
    """
    vd_size = get_vd_size(flags.dataset)
    num_batch_v = int(vd_size / config.valid_batch_size)
    # print('vd:\t#batch=%d\n' % num_batch_v)
    user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v
    text_hit_v = []
    for batch_v in range(num_batch_v):
        text_np_v, label_np_v, image_np_v = sess.run(
            [text_bt_v, label_bt_v, image_bt_v])
        # The teacher consumes both modalities at inference time.
        feed_dict = {tch_v.text_ph: text_np_v, tch_v.image_ph: image_np_v}
        text_logit_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
        text_hit_bt = metric.compute_hit(text_logit_v, label_np_v, flags.cutoff)
        text_hit_v.append(text_hit_bt)
    text_hit_v = np.mean(text_hit_v)
    return text_hit_v
def main(_):
    """Pretrain the GEN (generator) image model, checkpoint the best epoch, and
    dump the (epoch, hit, batch) learning curve to flags.gen_figure_data.

    Fix: the figure-data file was opened with a bare open()/close() pair, which
    leaks the handle (and can drop buffered writes) if a write raises; it is now
    managed with a `with` block.

    NOTE(review): relies on module-level names (flags, config, num_batch_t,
    num_batch_v, train_data_size, GEN, utils, metric) defined elsewhere in the file.
    """
    # Training and validation graphs share the same variables.
    gen_t = GEN(flags, is_training=True)
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    gen_v = GEN(flags, is_training=False)

    tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate)
    tf.summary.scalar(gen_t.pre_loss.name, gen_t.pre_loss)
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    # Log the parameter count of every trainable variable.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    # Input pipelines: shuffled training records, deterministic validation records.
    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
    user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
    user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

    figure_data = []
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
        with slim.queues.QueueRunners(sess):
            for batch_t in range(num_batch_t):
                # One pretraining step on the generator.
                image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
                feed_dict = {
                    gen_t.image_ph: image_np_t,
                    gen_t.hard_label_ph: label_np_t,
                }
                _, summary = sess.run([gen_t.pre_update, summary_op],
                                      feed_dict=feed_dict)
                writer.add_summary(summary, batch_t)

                # Evaluate only at (approximate) epoch boundaries: either the
                # epoch divides exactly, or fewer than one batch remains in it.
                batch = batch_t + 1
                remain = (batch * flags.batch_size) % train_data_size
                epoch = (batch * flags.batch_size) // train_data_size
                if remain == 0:
                    pass
                    # print('%d\t%d\t%d' % (epoch, batch, remain))
                elif (train_data_size - remain) < flags.batch_size:
                    epoch = epoch + 1
                    # print('%d\t%d\t%d' % (epoch, batch, remain))
                else:
                    continue
                # if (batch_t + 1) % eval_interval != 0:
                #   continue

                hit_v = []
                for batch_v in range(num_batch_v):
                    image_np_v, label_np_v = sess.run([image_bt_v, label_bt_v])
                    feed_dict = {gen_v.image_ph: image_np_v}
                    logit_np_v, = sess.run([gen_v.logits], feed_dict=feed_dict)
                    hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
                    hit_v.append(hit_bt)
                hit_v = np.mean(hit_v)
                figure_data.append((epoch, hit_v, batch_t))

                # Checkpoint only when the validation hit rate improves.
                if hit_v < best_hit_v:
                    continue
                tot_time = time.time() - start
                best_hit_v = hit_v
                print('#%03d curbst=%.4f time=%.0fs' % (epoch, hit_v, tot_time))
                gen_t.saver.save(sess, flags.gen_model_ckpt)
    print('bsthit=%.4f' % (best_hit_v))

    # Persist the learning curve; `with` guarantees the handle is closed
    # and buffers are flushed even if a write fails.
    utils.create_if_nonexist(os.path.dirname(flags.gen_figure_data))
    with open(flags.gen_figure_data, 'w') as fout:
        for epoch, hit_v, batch_t in figure_data:
            fout.write('%d\t%.4f\t%d\n' % (epoch, hit_v, batch_t))
def main(_):
    """Pretrain the DIS (discriminator) image model; checkpoint whenever
    validation hit@cutoff improves.

    NOTE(review): relies on module-level names (flags, config, num_batch_t,
    num_batch_v, eval_interval, DIS, utils, metric) defined elsewhere in the file.
    """
    # Training and validation graphs share the same variables.
    dis_t = DIS(flags, is_training=True)
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    dis_v = DIS(flags, is_training=False)

    tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate)
    tf.summary.scalar(dis_t.pre_loss.name, dis_t.pre_loss)
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    # Log the parameter count of every trainable variable.
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    # Input pipelines: shuffled training records, deterministic validation records.
    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' % (len(data_sources_t), len(data_sources_v)))
    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
    user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
    user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

    start = time.time()
    best_hit_v = -np.inf
    with tf.Session() as sess:
        sess.run(init_op)
        writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
        with slim.queues.QueueRunners(sess):
            for batch_t in range(num_batch_t):
                # One pretraining step on the discriminator.
                image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
                feed_dict = {
                    dis_t.image_ph: image_np_t,
                    dis_t.hard_label_ph: label_np_t,
                }
                _, summary = sess.run([dis_t.pre_update, summary_op],
                                      feed_dict=feed_dict)
                writer.add_summary(summary, batch_t)

                # Evaluate only every eval_interval batches.
                if (batch_t + 1) % eval_interval != 0:
                    continue

                hit_v = []
                for batch_v in range(num_batch_v):
                    image_np_v, label_np_v = sess.run([image_bt_v, label_bt_v])
                    feed_dict = {dis_v.image_ph: image_np_v}
                    logit_np_v, = sess.run([dis_v.logits], feed_dict=feed_dict)
                    hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
                    hit_v.append(hit_bt)
                hit_v = np.mean(hit_v)
                tot_time = time.time() - start
                print('#%08d hit=%.4f %06ds' % (batch_t, hit_v, int(tot_time)))

                # Checkpoint only when the validation hit rate improves.
                if hit_v < best_hit_v:
                    continue
                best_hit_v = hit_v
                dis_t.saver.save(sess, flags.dis_model_ckpt)
    print('best hit=%.4f' % (best_hit_v))