コード例 #1
0
ファイル: pretrain_gen.py プロジェクト: xiaojiew1/KDGAN
def check_tfrecord(bt_list, batch_size):
    """Run the five fetches in ``bt_list`` three times and print result shapes.

    Args:
        bt_list: iterable of exactly five fetchable items, unpacked in the
            order user, image, text, label, image file.
        batch_size: not read by this function.
    """
    # Neither lookup table is read below; the calls are retained so that
    # whatever utils does while loading still happens.
    id_to_label = utils.load_id_to_label()
    id_to_token = utils.load_id_to_token()

    user_bt, image_bt, text_bt, label_bt, image_file_bt = bt_list
    fetches = [user_bt, image_bt, text_bt, label_bt, image_file_bt]
    with tf.Session() as sess, slim.queues.QueueRunners(sess):
        for _ in range(3):
            users, images, texts, labels, image_files = sess.run(fetches)
            # Only the shapes are reported, not the contents.
            print(users.shape, images.shape, texts.shape,
                  labels.shape, image_files.shape)
コード例 #2
0
ファイル: pretrain.py プロジェクト: xiaojiew1/KDGAN
def check_ts_list(ts_list):
    """Run the five fetches in ``ts_list`` three times and print the results.

    For every run this prints, in order: the user array, the image array's
    shape, the text array's shape followed by the array itself, the label
    array's shape, and the file array.

    Args:
        ts_list: iterable of exactly five fetchable items, unpacked in the
            order user, image, text, label, file.
    """
    # Neither lookup table is read below; the calls are retained so that
    # whatever utils does while loading still happens.
    id_to_label = utils.load_id_to_label()
    id_to_token = utils.load_id_to_token()

    user_ts, image_ts, text_ts, label_ts, file_ts = ts_list
    fetches = [user_ts, image_ts, text_ts, label_ts, file_ts]
    with tf.Session() as sess, slim.queues.QueueRunners(sess):
        for _ in range(3):
            users, images, texts, labels, files = sess.run(fetches)
            print(users)
            print(images.shape)
            print(texts.shape, texts)
            print(labels.shape)
            print(files)
コード例 #3
0
ファイル: eval_model.py プロジェクト: xiaojiew1/KDGAN
def main(_):
    """Evaluate ``vd_model`` and write per-image ranked label scores to a file.

    Restores ``tn_model``'s saver from ``model_ckpt``, runs ``vd_model.logits``
    on ``image_np`` (plus ``text_np`` when the model has a ``text_ph``
    attribute), and writes one line per image to ``flags.model_run``: the
    image id followed by ``<label> <logit>`` pairs ordered by descending
    logit.

    NOTE(review): ``tn_model``, ``vd_model``, ``model_ckpt``, ``image_np``,
    ``text_np`` and ``imgid_np`` are not defined in this function; they are
    presumably module-level names outside this excerpt -- confirm.

    Args:
        _: ignored (presumably the argv list passed by the app runner).
    """
    utils.create_pardir(flags.model_run)
    id_to_label = utils.load_id_to_label(flags.dataset)
    # The context manager guarantees the output file is closed even if the
    # restore, the forward pass, or a write raises.
    with open(flags.model_run, 'w') as fout:
        with tf.train.MonitoredTrainingSession() as sess:
            tn_model.saver.restore(sess, model_ckpt)
            # The image placeholder is always fed; text only when supported.
            feed_dict = {vd_model.image_ph: image_np}
            if hasattr(vd_model, 'text_ph'):
                feed_dict[vd_model.text_ph] = text_np
            logits_np = sess.run(vd_model.logits, feed_dict=feed_dict)
            # A distinct loop name avoids shadowing the full logits array.
            for imgid, logits in zip(imgid_np, logits_np):
                # Negating before argsort yields indices in descending order.
                sorted_labels = (-logits).argsort()
                fout.write('%s' % (imgid))
                for label in sorted_labels:
                    fout.write(' %s %.4f' % (id_to_label[label], logits[label]))
                fout.write('\n')
    print('result saved in %s' % flags.model_run)
コード例 #4
0
from tensorflow.contrib import slim

# Command-line flags. All are registered through tf.app.flags so the module
# uses a single, consistent entry point (tf.flags is an alias of it in
# TF 1.x). Help strings are left empty as in the original definitions.
tf.app.flags.DEFINE_integer('embedding_size', 10, '')
tf.app.flags.DEFINE_integer('num_epoch', 100, '')
tf.app.flags.DEFINE_float('init_learning_rate', 0.1, '')
tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.94, '')
tf.app.flags.DEFINE_float('num_epochs_per_decay', 10.0, '')
tf.app.flags.DEFINE_float('learning_rate', 0.01, '')
tf.app.flags.DEFINE_integer('cutoff', 3, '')
flags = tf.app.flags.FLAGS

# Total number of training batches over all epochs, and number of validation
# batches in one pass; int() truncates any fractional final batch.
num_batch_t = int(flags.num_epoch * config.train_data_size / config.train_batch_size)
num_batch_v = int(config.valid_data_size / config.valid_batch_size)
print('train#batch={} valid#batch={}'.format(num_batch_t, num_batch_v))

# Loaded once at import time and shared as a module-level lookup table.
id_to_label = utils.load_id_to_label()

def check_tfrecord(bt_list, batch_size):
    id_to_label = utils.load_id_to_label()
    id_to_token = utils.load_id_to_token()

    user_bt, image_bt, text_bt, label_bt, image_file_bt = bt_list
    with tf.Session() as sess:
        with slim.queues.QueueRunners(sess):
            for t in range(3):
                user_np, image_np, text_np, label_np, image_file_np = sess.run(
                        [user_bt, image_bt, text_bt, label_bt, image_file_bt])
                # for b in range(batch_size):
                #     print('{0}\n{0}'.format('#'*80))
                #     print(user_np[b])
                #     num_token = text_np[b].shape[0]