def check_tfrecord(bt_list, batch_size):
  """Pulls a few batches from the queued TFRecord pipeline and prints their shapes."""
  id_to_label = utils.load_id_to_label()
  id_to_token = utils.load_id_to_token()
  user_bt, image_bt, text_bt, label_bt, image_file_bt = bt_list
  with tf.Session() as sess:
    with slim.queues.QueueRunners(sess):
      for t in range(3):
        user_np, image_np, text_np, label_np, image_file_np = sess.run(
            [user_bt, image_bt, text_bt, label_bt, image_file_bt])
        # for b in range(batch_size):
        #   print('{0}\n{0}'.format('#' * 80))
        #   print(user_np[b])
        #   num_token = text_np[b].shape[0]
        #   tokens = [id_to_token[text_np[b, i]] for i in range(num_token)]
        #   print(tokens)
        #   label_vt = label_np[b, :]
        #   label_ids = [i for i, l in enumerate(label_vt) if l != 0]
        #   labels = [id_to_label[label_id] for label_id in label_ids]
        #   print(labels)
        #   print(image_file_np[b])
        #   print('{0}\n{0}'.format('#' * 80))
        #   input()
        print(user_np.shape, image_np.shape, text_np.shape, label_np.shape,
              image_file_np.shape)
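# A minimal sketch (assumption) of how a bt_list for check_tfrecord could be
# built with the TF1 queue API. The file pattern, feature keys, and dimensions
# below are hypothetical, not taken from this code.
def build_batch_list(tfrecord_pattern, batch_size, image_dim, num_token, num_label):
  filenames = tf.gfile.Glob(tfrecord_pattern)
  filename_queue = tf.train.string_input_producer(filenames)
  reader = tf.TFRecordReader()
  _, serialized = reader.read(filename_queue)
  features = tf.parse_single_example(serialized, features={
      'user': tf.FixedLenFeature([], tf.string),
      'image': tf.FixedLenFeature([image_dim], tf.float32),
      'text': tf.FixedLenFeature([num_token], tf.int64),
      'label': tf.FixedLenFeature([num_label], tf.int64),
      'image_file': tf.FixedLenFeature([], tf.string),
  })
  # tf.train.batch yields the five batched tensors check_tfrecord unpacks.
  return tf.train.batch(
      [features['user'], features['image'], features['text'],
       features['label'], features['image_file']],
      batch_size=batch_size)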
def check_ts_list(ts_list):
  id_to_label = utils.load_id_to_label()
  id_to_token = utils.load_id_to_token()
  user_ts, image_ts, text_ts, label_ts, file_ts = ts_list
  with tf.Session() as sess:
    with slim.queues.QueueRunners(sess):
      for t in range(3):
        user_np, image_np, text_np, label_np, file_np = sess.run(
            [user_ts, image_ts, text_ts, label_ts, file_ts])
        print(user_np)
        print(image_np.shape)
        print(text_np.shape, text_np)
        print(label_np.shape)
        print(file_np)
def main(_):
  """Restores the checkpoint, runs inference, and writes ranked labels per image."""
  # model_ckpt, tn_model, vd_model, image_np, text_np, and imgid_np are
  # expected to be defined elsewhere in this script.
  utils.create_pardir(flags.model_run)
  id_to_label = utils.load_id_to_label(flags.dataset)
  fout = open(flags.model_run, 'w')
  with tf.train.MonitoredTrainingSession() as sess:
    tn_model.saver.restore(sess, model_ckpt)
    # Only feed text when the validation model actually consumes it.
    if hasattr(vd_model, 'text_ph'):
      feed_dict = {
          vd_model.image_ph: image_np,
          vd_model.text_ph: text_np,
      }
    else:
      feed_dict = {
          vd_model.image_ph: image_np,
      }
    logit_np = sess.run(vd_model.logits, feed_dict=feed_dict)
    for imgid, logit in zip(imgid_np, logit_np):
      sorted_labels = (-logit).argsort()
      fout.write('%s' % imgid)
      for label in sorted_labels:
        fout.write(' %s %.4f' % (id_to_label[label], logit[label]))
      fout.write('\n')
  fout.close()
  print('result saved in %s' % flags.model_run)
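# Sketch (assumption): main() writes one line per image of the form
#   <imgid> <label> <score> <label> <score> ...
# A small hypothetical helper to read that file back, e.g. to inspect the
# top-k predictions offline.
def load_predictions(path, topk=None):
  imgid_to_labels = {}
  with open(path) as fin:
    for line in fin:
      fields = line.strip().split()
      imgid, rest = fields[0], fields[1:]
      labels = [(rest[i], float(rest[i + 1])) for i in range(0, len(rest), 2)]
      imgid_to_labels[imgid] = labels[:topk] if topk else labels
  return imgid_to_labels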
import tensorflow as tf
from tensorflow.contrib import slim

import config
import utils

tf.app.flags.DEFINE_integer('embedding_size', 10, '')
tf.app.flags.DEFINE_integer('num_epoch', 100, '')
tf.app.flags.DEFINE_float('init_learning_rate', 0.1, '')
tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.94, '')
tf.app.flags.DEFINE_float('num_epochs_per_decay', 10.0, '')
tf.app.flags.DEFINE_float('learning_rate', 0.01, '')
tf.app.flags.DEFINE_integer('cutoff', 3, '')
flags = tf.app.flags.FLAGS

num_batch_t = int(flags.num_epoch * config.train_data_size / config.train_batch_size)
num_batch_v = int(config.valid_data_size / config.valid_batch_size)
print('train#batch={} valid#batch={}'.format(num_batch_t, num_batch_v))

id_to_label = utils.load_id_to_label()
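# A minimal sketch (assumption, not taken from this file) of how the decay
# flags above would typically feed tf.train.exponential_decay. The global_step
# creation is illustrative only; the returned tensor would be passed to an
# optimizer as its learning rate.
def build_decayed_learning_rate():
  global_step = tf.train.get_or_create_global_step()
  # One decay period spans num_epochs_per_decay epochs worth of batches.
  decay_steps = int(config.train_data_size / config.train_batch_size *
                    flags.num_epochs_per_decay)
  return tf.train.exponential_decay(
      flags.init_learning_rate,
      global_step,
      decay_steps,
      flags.learning_rate_decay_factor,
      staircase=True)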