Example #1
def collect_image(infile, outdir):
    post_image = {}
    fin = open(infile)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        post = fields[POST_INDEX]
        image = fields[IMAGE_INDEX]
        post_image[post] = image
    fin.close()
    utils.create_if_nonexist(outdir)
    fin = open(sample_file)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        post = fields[POST_INDEX]
        if post not in post_image:
            continue
        image_url = fields[IMAGE_INDEX]
        src_file = get_image_path(image_dir, image_url)
        image = post_image[post]
        dst_file = path.join(outdir, '%s.jpg' % image)
        if path.isfile(dst_file):
            continue
        shutil.copyfile(src_file, dst_file)
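
A minimal usage sketch for collect_image, assuming the module-level globals it reads (sample_file, image_dir, FIELD_SEPERATOR, and the *_INDEX constants) are configured as in Example #15; the wiring below is illustrative, not from the source:

# hypothetical call site: copy every image referenced by the dataset into ImageData/
data_file = path.join(dataset_dir, 'yfcc_rnd.data')
image_data_dir = path.join(dataset_dir, 'ImageData')
collect_image(data_file, image_data_dir)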
Example #2
def main(_):
    check_num_field()
    utils.create_if_nonexist(dataset_dir)
    if not utils.skip_if_exist(label_file):
        print('select top labels')
        select_top_label()
    if not utils.skip_if_exist(raw_file):
        print('select posts')
        select_posts()
    if not utils.skip_if_exist(data_file):
        print('tokenize dataset')
        tokenize_dataset()
        count_dataset()
    if (not utils.skip_if_exist(train_file)
            or not utils.skip_if_exist(valid_file)
            or not utils.skip_if_exist(vocab_file)):
        # if True:
        print('split dataset')
        split_dataset()

    # if path.isdir(image_dir):
    if False:
        print('collect images')
        # find ImageData/ -type f | wc -l
        collect_image(data_file, image_data_dir)
    # create_survey_data()

    create_tfrecord(valid_file, end_point_v, is_training=False)
    create_tfrecord(train_file, end_point_t, is_training=True)
Example #3
def main(_):
    if flags.overwrite:
        print('create yfcc small rnd dataset')
        utils.delete_if_exist(dataset_dir)
        utils.create_if_nonexist(dataset_dir)

    check_num_field()
    if flags.overwrite or (not utils.skip_if_exist(raw_file)):
        while True:
            print('random labels and posts')
            select_rnd_label()
            min_count = select_posts()
            if min_count < MIN_RND_POST:
                continue
            break

    if flags.overwrite or (not utils.skip_if_exist(data_file)):
        print('tokenize and collect images')
        tokenize_dataset()
        collect_image(data_file, image_data_dir)

    if (flags.overwrite or not utils.skip_if_exist(train_file)
            or not utils.skip_if_exist(valid_file)
            or not utils.skip_if_exist(vocab_file)):
        while True:
            print('split into train and valid')
            try:
                split_dataset()
                break
            except Exception:
                continue

    if flags.baseline:
        print('create survey data')
        create_survey_data()
Example #4
def count_dataset():
    utils.create_if_nonexist(temp_dir)
    user_count = {}
    fin = open(data_file)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        user = fields[USER_INDEX]
        if user not in user_count:
            user_count[user] = 0
        user_count[user] += 1
    fin.close()
    sorted_user_count = sorted(user_count.items(),
                               key=operator.itemgetter(0),
                               reverse=True)
    outfile = path.join(temp_dir, 'user_count')
    with open(outfile, 'w') as fout:
        for user, count in sorted_user_count:
            fout.write('{}\t{}\n'.format(user, count))

    label_count = {}
    fin = open(data_file)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        user = fields[USER_INDEX]
        labels = fields[LABEL_INDEX].split()
        assert len(labels) != 0
        for label in labels:
            if label not in label_count:
                label_count[label] = 0
            label_count[label] += 1
    fin.close()
    sorted_label_count = sorted(label_count.items(),
                                key=operator.itemgetter(0),
                                reverse=True)
    outfile = path.join(temp_dir, 'label_count')
    labels, lemms = set(), set()
    with open(outfile, 'w') as fout:
        for label, count in sorted_label_count:
            labels.add(label)
            lemm = lemmatizer.lemmatize(label)
            lemms.add(lemm)
            if lemm != label:
                print('{}->{}'.format(label, lemm))
            fout.write('{}\t{}\n'.format(label, count))
    print('#label={} #lemm={}'.format(len(labels), len(lemms)))
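
The lemmatizer global is not defined in this example; a reasonable assumption is NLTK's WordNetLemmatizer, set up along these lines:

# assumed setup for the `lemmatizer` used above (not shown in the source)
from nltk.stem import WordNetLemmatizer

lemmatizer = WordNetLemmatizer()
print(lemmatizer.lemmatize('cats'))  # -> 'cat'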
Example #5
def survey_image_data(infile):
    dataset = get_dataset(infile)
    image_data = path.join(surv_dir, dataset, 'ImageData')
    utils.create_if_nonexist(image_data)
    fout = open(path.join(image_data, '%s.txt' % dataset), 'w')
    fin = open(infile)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        image = fields[IMAGE_INDEX]
        image_file = '%s.jpg' % image
        fout.write('{}\n'.format(image_file))
    fin.close()
    fout.close()
    collect_image(infile, image_data)
Example #6
def survey_feature_sets(infile):
    dataset = get_dataset(infile)
    image_sets = path.join(surv_dir, dataset, 'ImageSets')
    utils.create_if_nonexist(image_sets)

    fout = open(path.join(image_sets, '%s.txt' % dataset), 'w')
    fin = open(infile)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        image = fields[IMAGE_INDEX]
        fout.write('{}\n'.format(image))
    fin.close()
    fout.close()

    fout = open(path.join(image_sets, 'holdout.txt'), 'w')
    fout.close()
Example #7
def survey_annotations(infile):
    dataset = get_dataset(infile)
    annotations = path.join(surv_dir, dataset, 'Annotations')
    utils.create_if_nonexist(annotations)
    concepts = 'concepts.txt'

    label_set = set()
    label_images = {}
    image_set = set()
    fin = open(infile)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        image = fields[IMAGE_INDEX]
        labels = fields[LABEL_INDEX].split()
        for label in labels:
            label_set.add(label)
            if label not in label_images:
                label_images[label] = []
            label_images[label].append(image)
        image_set.add(image)
    fin.close()
    fout = open(path.join(annotations, concepts), 'w')
    for label in sorted(label_set):
        fout.write('{}\n'.format(label))
    fout.close()

    concepts_dir = path.join(annotations, 'Image', concepts)
    utils.create_if_nonexist(concepts_dir)
    image_list = sorted(image_set)
    for label in label_set:
        label_filepath = path.join(concepts_dir, '%s.txt' % label)
        fout = open(label_filepath, 'w')
        for image in image_list:
            assessment = -1
            if image in label_images[label]:
                assessment = 1
            fout.write('{} {}\n'.format(image, assessment))
        fout.close()
Example #8
def survey_text_data(infile):
    seperator = '###'

    def _get_key(label_i, label_j):
        if label_i < label_j:
            key = label_i + seperator + label_j
        else:
            key = label_j + seperator + label_i
        return key

    def _get_labels(key):
        fields = key.split(seperator)
        label_i, label_j = fields[0], fields[1]
        return label_i, label_j

    dataset = get_dataset(infile)
    text_data = path.join(surv_dir, dataset, 'TextData')
    utils.create_if_nonexist(text_data)

    post_image = {}
    fin = open(infile)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        post = fields[POST_INDEX]
        image = fields[IMAGE_INDEX]
        post_image[post] = image
    fin.close()

    rawtags_file = path.join(text_data, 'id.userid.rawtags.txt')
    fout = open(rawtags_file, 'w')
    fin = open(rawtag_file)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        post = fields[POST_INDEX]
        if post not in post_image:
            continue
        image = post_image[post]
        user = fields[USER_INDEX]
        old_labels = fields[LABEL_INDEX].split(LABEL_SEPERATOR)
        new_labels = []
        for old_label in old_labels:
            old_label = urllib.parse.unquote(old_label)
            old_label = old_label.lower()
            new_label = ''
            for c in old_label:
                if not c.isalnum():
                    continue
                new_label += c
            if len(new_label) == 0:
                continue
            new_labels.append(new_label)
        labels = ' '.join(new_labels)
        fout.write('{}\t{}\t{}\n'.format(image, user, labels))
    fin.close()
    fout.close()

    lemmtags_file = path.join(text_data, 'id.userid.lemmtags.txt')
    fout = open(lemmtags_file, 'w')
    fin = open(rawtags_file)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        old_labels = fields[-1].split(' ')
        new_labels = []
        for old_label in old_labels:
            new_label = lemmatizer.lemmatize(old_label)
            new_labels.append(new_label)
        fields[-1] = ' '.join(new_labels)
        fout.write('{}\n'.format(FIELD_SEPERATOR.join(fields)))
    fin.close()
    fout.close()

    fin = open(lemmtags_file)
    label_users, label_images = {}, {}
    label_set = set()
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        image, user = fields[0], fields[1]
        labels = fields[2].split()
        for label in labels:
            if label not in label_users:
                label_users[label] = set()
            label_users[label].add(user)
            if label not in label_images:
                label_images[label] = set()
            label_images[label].add(image)
            label_set.add(label)
    fin.close()
    tagfreq_file = path.join(text_data, 'lemmtag.userfreq.imagefreq.txt')
    fout = open(tagfreq_file, 'w')
    label_count = {}
    for label in label_set:
        label_count[label] = len(
            label_users[label])  # + len(label_images[label])
    sorted_label_count = sorted(label_count.items(),
                                key=operator.itemgetter(1),
                                reverse=True)
    for label, _ in sorted_label_count:
        userfreq = len(label_users[label])
        imagefreq = len(label_images[label])
        fout.write('{} {} {}\n'.format(label, userfreq, imagefreq))
    fout.close()

    jointfreq_file = path.join(text_data, 'ucij.uuij.icij.iuij.txt')
    min_count = 4
    if not infile.endswith('.valid'):
        min_count = 8
    label_count = {}
    fin = open(lemmtags_file)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        image, user = fields[0], fields[1]
        labels = fields[2].split()
        for label in labels:
            if label not in label_count:
                label_count[label] = 0
            label_count[label] += 1
    fin.close()
    jointfreq_icij_init = {}
    fin = open(lemmtags_file)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        image, user = fields[0], fields[1]
        labels = fields[2].split()
        num_label = len(labels)
        for i in range(num_label - 1):
            for j in range(i + 1, num_label):
                label_i = labels[i]
                label_j = labels[j]
                if label_i == label_j:
                    continue
                if label_count[label_i] < min_count:
                    continue
                if label_count[label_j] < min_count:
                    continue
                key = _get_key(label_i, label_j)
                if key not in jointfreq_icij_init:
                    jointfreq_icij_init[key] = 0
                jointfreq_icij_init[key] += 1
    fin.close()
    keys = set()
    icij_images = {}
    iuij_images = {}
    fin = open(lemmtags_file)
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        image, user = fields[0], fields[1]
        labels = fields[2].split()
        num_label = len(labels)
        for i in range(num_label - 1):
            for j in range(i + 1, num_label):
                label_i = labels[i]
                label_j = labels[j]
                if label_i == label_j:
                    continue
                if label_i not in iuij_images:
                    iuij_images[label_i] = set()
                iuij_images[label_i].add(image)
                if label_j not in iuij_images:
                    iuij_images[label_j] = set()
                iuij_images[label_j].add(image)
                if label_count[label_i] < min_count:
                    continue
                if label_count[label_j] < min_count:
                    continue
                key = _get_key(label_i, label_j)
                if jointfreq_icij_init[key] < min_count:
                    continue
                keys.add(key)
                if key not in icij_images:
                    icij_images[key] = set()
                icij_images[key].add(image)
    fin.close()
    jointfreq_icij, jointfreq_iuij = {}, {}
    keys = sorted(keys)
    for key in keys:
        jointfreq_icij[key] = len(icij_images[key])
        label_i, label_j = _get_labels(key)
        label_i_images = iuij_images[label_i]
        label_j_images = iuij_images[label_j]
        jointfreq_iuij[key] = len(label_i_images.union(label_j_images))
    fout = open(jointfreq_file, 'w')
    for key in sorted(keys):
        label_i, label_j = _get_labels(key)
        fout.write('{} {} {} {} {} {}\n'.format(label_i, label_j,
                                                jointfreq_icij[key],
                                                jointfreq_iuij[key],
                                                jointfreq_icij[key],
                                                jointfreq_iuij[key]))
    fout.close()

    fin = open(lemmtags_file)
    vocab = set()
    while True:
        line = fin.readline().strip()
        if not line:
            break
        fields = line.split(FIELD_SEPERATOR)
        image, user = fields[0], fields[1]
        labels = fields[2].split()
        for label in labels:
            if wordnet.synsets(label):
                vocab.add(label)
    fin.close()
    vocab_file = path.join(text_data, 'wn.%s.txt' % dataset)
    fout = open(vocab_file, 'w')
    for label in sorted(vocab):
        fout.write('{}\n'.format(label))
    fout.close()
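
The _get_key/_get_labels helpers in this example canonicalize an unordered label pair into one dictionary key, so (a, b) and (b, a) are counted as the same pair. A self-contained sketch of the same trick:

# standalone illustration of the symmetric pair key used above
SEP = '###'

def pair_key(a, b):
    # order lexicographically so (a, b) and (b, a) map to the same key
    return a + SEP + b if a < b else b + SEP + a

assert pair_key('dog', 'cat') == pair_key('cat', 'dog') == 'cat###dog'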
Example #9
def create_tfrecord(infile, end_point, is_training=False):
    utils.create_if_nonexist(precomputed_dir)

    num_epoch = flags.num_epoch
    if not is_training:
        num_epoch = 1

    fields = path.basename(infile).split('.')
    dataset, version = fields[0], fields[1]
    filepath = path.join(precomputed_dir, tfrecord_tmpl)

    user_list = []
    file_list = []
    text_list = []
    label_list = []
    fin = open(infile)
    while True:
        line = fin.readline()
        if not line:
            break
        fields = line.strip().split(FIELD_SEPERATOR)
        user = fields[USER_INDEX]
        image = fields[IMAGE_INDEX]
        file = path.join(image_data_dir, '%s.jpg' % image)
        tokens = fields[TEXT_INDEX].split()
        labels = fields[LABEL_INDEX].split()
        user_list.append(user)
        file_list.append(file)
        text_list.append(tokens)
        label_list.append(labels)
    fin.close()

    label_to_id = utils.load_sth_to_id(label_file)
    num_label = len(label_to_id)
    print('#label={}'.format(num_label))
    token_to_id = utils.load_sth_to_id(vocab_file)
    unk_token_id = token_to_id[unk_token]
    vocab_size = len(token_to_id)
    print('#vocab={}'.format(vocab_size))

    reader = ImageReader()
    with tf.Session() as sess:
        init_fn(sess)
        for epoch in range(num_epoch):
            count = 0
            tfrecord_file = filepath.format(dataset, flags.model_name, epoch,
                                            version)
            if path.isfile(tfrecord_file):
                continue
            # print(tfrecord_file)
            # exit()
            with tf.python_io.TFRecordWriter(tfrecord_file) as fout:
                for user, file, text, labels in zip(user_list, file_list,
                                                    text_list, label_list):
                    user = bytes(user, encoding='utf-8')

                    image_np = np.array(Image.open(file))
                    # print(type(image_np), image_np.shape)
                    feed_dict = {image_ph: image_np}
                    image_t, = sess.run([end_point], feed_dict)
                    image_t = image_t.tolist()
                    # print(type(image_t), len(image_t))
                    # exit()

                    text = [
                        token_to_id.get(token, unk_token_id) for token in text
                    ]

                    label_ids = [label_to_id[label] for label in labels]
                    label_vec = np.zeros((num_label, ), dtype=np.int64)
                    label_vec[label_ids] = 1
                    label = label_vec.tolist()

                    file = bytes(file, encoding='utf-8')
                    # print(file)

                    example = build_example(user, image_t, text, label, file)
                    fout.write(example.SerializeToString())
                    count += 1
                    if (count % 500) == 0:
                        print('count={}'.format(count))
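
build_example is defined elsewhere in the project; for TF1-style TFRecords it plausibly wraps each field in a tf.train.Feature, along these lines (an assumption, not the source's actual definition):

# assumed shape of build_example, using the standard TF1 tf.train.Example protos
def build_example(user, image_t, text, label, file):
    feature = {
        'user': tf.train.Feature(bytes_list=tf.train.BytesList(value=[user])),
        'image': tf.train.Feature(float_list=tf.train.FloatList(value=image_t)),
        'text': tf.train.Feature(int64_list=tf.train.Int64List(value=text)),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label)),
        'file': tf.train.Feature(bytes_list=tf.train.BytesList(value=[file])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))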
Example #10
def main(_):
    gen_t = GEN(flags, is_training=True)
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    gen_v = GEN(flags, is_training=False)

    tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate)
    tf.summary.scalar(gen_t.pre_loss.name, gen_t.pre_loss)
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
    user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
    user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

    figure_data = []
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        with slim.queues.QueueRunners(sess):
            for batch_t in range(num_batch_t):
                image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
                feed_dict = {
                    gen_t.image_ph: image_np_t,
                    gen_t.hard_label_ph: label_np_t
                }
                _, summary = sess.run([gen_t.pre_update, summary_op],
                                      feed_dict=feed_dict)
                writer.add_summary(summary, batch_t)

                batch = batch_t + 1
                remain = (batch * flags.batch_size) % train_data_size
                epoch = (batch * flags.batch_size) // train_data_size
                if remain == 0:
                    pass
                    # print('%d\t%d\t%d' % (epoch, batch, remain))
                elif (train_data_size - remain) < flags.batch_size:
                    epoch = epoch + 1
                    # print('%d\t%d\t%d' % (epoch, batch, remain))
                else:
                    continue
                # if (batch_t + 1) % eval_interval != 0:
                #     continue

                hit_v = []
                for batch_v in range(num_batch_v):
                    image_np_v, label_np_v = sess.run([image_bt_v, label_bt_v])
                    feed_dict = {gen_v.image_ph: image_np_v}
                    logit_np_v, = sess.run([gen_v.logits], feed_dict=feed_dict)
                    hit_bt = metric.compute_hit(logit_np_v, label_np_v,
                                                flags.cutoff)
                    hit_v.append(hit_bt)
                hit_v = np.mean(hit_v)

                figure_data.append((epoch, hit_v, batch_t))

                if hit_v < best_hit_v:
                    continue
                tot_time = time.time() - start
                best_hit_v = hit_v
                print('#%03d curbst=%.4f time=%.0fs' %
                      (epoch, hit_v, tot_time))
                gen_t.saver.save(sess, flags.gen_model_ckpt)
    print('bsthit=%.4f' % (best_hit_v))

    utils.create_if_nonexist(os.path.dirname(flags.gen_figure_data))
    fout = open(flags.gen_figure_data, 'w')
    for epoch, hit_v, batch_t in figure_data:
        fout.write('%d\t%.4f\t%d\n' % (epoch, hit_v, batch_t))
    fout.close()
Example #11
def main(_):
    print('#label={}'.format(config.num_label))
    tch_t = TCH(flags, is_training=True)
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    tch_v = TCH(flags, is_training=False)

    ts_list_t = utils.decode_tfrecord(config.train_tfrecord, shuffle=True)
    ts_list_v = utils.decode_tfrecord(config.valid_tfrecord, shuffle=False)
    bt_list_t = utils.generate_text_batch(ts_list_t, config.train_batch_size)
    bt_list_v = utils.generate_text_batch(ts_list_v, config.valid_batch_size)
    # check_tfrecord(bt_list_t, config.train_batch_size)
    # check_tfrecord(bt_list_v, config.valid_batch_size)

    user_bt_t, text_bt_t, label_bt_t, image_file_bt_t = bt_list_t
    user_bt_v, text_bt_v, label_bt_v, image_file_bt_v = bt_list_v

    best_hit_v = -np.inf
    init_op = tf.global_variables_initializer()
    start = time.time()
    with tf.Session() as sess:
        writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
        sess.run(init_op)

        with slim.queues.QueueRunners(sess):
            for batch_t in range(num_batch_t):
                text_np_t, label_np_t = sess.run([text_bt_t, label_bt_t])
                feed_dict = {tch_t.text_ph:text_np_t, tch_t.label_ph:label_np_t}
                _, summary = sess.run([tch_t.train_op, tch_t.summary_op], feed_dict=feed_dict)
                writer.add_summary(summary, batch_t)

                if (batch_t + 1) % eval_interval != 0:
                    continue

                hit_v = []
                image_file_v = set()
                for batch_v in range(num_batch_v):
                    text_np_v, label_np_v, image_file_np_v = sess.run([text_bt_v, label_bt_v, image_file_bt_v])
                    feed_dict = {tch_v.text_ph:text_np_v}
                    logit_np_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
                    for image_file in image_file_np_v:
                        image_file_v.add(image_file)
                    hit_bt = compute_hit(logit_np_v, label_np_v, flags.cutoff)
                    hit_v.append(hit_bt)
                hit_v = np.mean(hit_v)

                total_time = time.time() - start
                avg_batch = total_time / (batch_t + 1)
                avg_epoch = avg_batch * (config.train_data_size / config.train_batch_size)
                s = '{0} hit={1:.4f} tot={2:.0f}s avg={3:.0f}s'
                s = s.format(batch_t, hit_v, total_time, avg_epoch)
                print(s)

                if hit_v < best_hit_v:
                    continue
                best_hit_v = hit_v
                ckpt_file = path.join(config.ckpt_dir, 'tch.ckpt')
                tch_t.saver.save(sess, ckpt_file)
    utils.create_if_nonexist(config.temp_dir)
    hit_file = path.join(config.temp_dir, 'tch.hit')
    with open(hit_file, 'w') as fout:
        fout.write('{0:.4f}'.format(best_hit_v))
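
compute_hit is imported from the project's metric module; a plausible reading, given how its per-batch results are averaged here, is hit@k: the fraction of examples whose top-cutoff logits include at least one true label. A sketch of that assumed behavior, not the source's definition:

# assumed behavior of compute_hit: average hit@cutoff over a batch
def compute_hit(logits, labels, cutoff):
    hits = []
    for logit, label in zip(logits, labels):
        top = np.argsort(-logit)[:cutoff]  # indices of the top-scoring labels
        hits.append(1.0 if label[top].sum() > 0 else 0.0)
    return np.mean(hits)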
Example #12
def main(_):
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    dis_summary_op = tf.summary.merge([
        tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
        tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
    ])
    gen_summary_op = tf.summary.merge([
        tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
        tf.summary.scalar(gen_t.gan_loss.name, gen_t.gan_loss),
    ])
    print(type(dis_summary_op), type(gen_summary_op))
    init_op = tf.global_variables_initializer()

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
    user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d

    ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
    user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g

    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

    figure_data = []
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        dis_t.saver.restore(sess, flags.dis_model_ckpt)
        gen_t.saver.restore(sess, flags.gen_model_ckpt)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        with slim.queues.QueueRunners(sess):
            hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
            print('init hit=%.4f' % (hit_v))

            batch_d, batch_g = -1, -1
            for epoch in range(flags.num_epoch):
                for dis_epoch in range(flags.num_dis_epoch):
                    print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
                    num_batch_d = math.ceil(train_data_size / flags.batch_size)
                    for _ in range(num_batch_d):
                        batch_d += 1
                        image_np_d, label_dat_d = sess.run(
                            [image_bt_d, label_bt_d])
                        feed_dict = {gen_t.image_ph: image_np_d}
                        label_gen_d, = sess.run([gen_t.labels],
                                                feed_dict=feed_dict)
                        sample_np_d, label_np_d = utils.gan_dis_sample(
                            flags, label_dat_d, label_gen_d)
                        feed_dict = {
                            dis_t.image_ph: image_np_d,
                            dis_t.sample_ph: sample_np_d,
                            dis_t.dis_label_ph: label_np_d,
                        }
                        _, summary_d = sess.run(
                            [dis_t.gan_update, dis_summary_op],
                            feed_dict=feed_dict)
                        writer.add_summary(summary_d, batch_d)

                for gen_epoch in range(flags.num_gen_epoch):
                    print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
                    num_batch_g = math.ceil(train_data_size / flags.batch_size)
                    for _ in range(num_batch_g):
                        batch_g += 1
                        image_np_g, label_dat_g = sess.run(
                            [image_bt_g, label_bt_g])
                        feed_dict = {gen_t.image_ph: image_np_g}
                        label_gen_g, = sess.run([gen_t.labels],
                                                feed_dict=feed_dict)
                        sample_np_g = utils.generate_label(
                            flags, label_dat_g, label_gen_g)
                        feed_dict = {
                            dis_t.image_ph: image_np_g,
                            dis_t.sample_ph: sample_np_g,
                        }
                        reward_np_g, = sess.run([dis_t.rewards],
                                                feed_dict=feed_dict)
                        feed_dict = {
                            gen_t.image_ph: image_np_g,
                            gen_t.sample_ph: sample_np_g,
                            gen_t.reward_ph: reward_np_g,
                        }
                        _, summary_g = sess.run(
                            [gen_t.gan_update, gen_summary_op],
                            feed_dict=feed_dict)
                        writer.add_summary(summary_g, batch_g)

                        # if (batch_g + 1) % eval_interval != 0:
                        #   continue
                        # hit_v = utils.evaluate(flags, sess, gen_v, bt_list_v)
                        # tot_time = time.time() - start
                        # print('#%08d hit=%.4f %06ds' % (batch_g, hit_v, int(tot_time)))
                        # if hit_v <= best_hit_v:
                        #   continue
                        # best_hit_v = hit_v
                        # print('best hit=%.4f' % (best_hit_v))
                hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
                tot_time = time.time() - start
                print('#%03d curbst=%.4f %.0fs' % (epoch, hit_v, tot_time))
                figure_data.append((epoch, hit_v))
                if hit_v <= best_hit_v:
                    continue
                best_hit_v = hit_v
    print('bsthit=%.4f' % (best_hit_v))

    utils.create_if_nonexist(os.path.dirname(flags.gan_figure_data))
    fout = open(flags.gan_figure_data, 'w')
    for epoch, hit_v in figure_data:
        fout.write('%d\t%.4f\n' % (epoch, hit_v))
    fout.close()
Example #13
def main(_):
  bst_gen_acc, bst_tch_acc, bst_eph = 0.0, 0.0, 0
  utils.create_if_nonexist(flags.gradient_dir)
  if flags.log_accuracy:
    acc_history = []
  if flags.evaluate_tch:
    tch_history = []
  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    tn_tch.saver.restore(sess, flags.tch_model_ckpt)

    feed_dict = {
      vd_dis.image_ph:dis_mnist.test.images,
      vd_dis.hard_label_ph:dis_mnist.test.labels,
    }
    ini_dis = sess.run(vd_dis.accuracy, feed_dict=feed_dict)
    feed_dict = {
      vd_gen.image_ph:gen_mnist.test.images,
      vd_gen.hard_label_ph:gen_mnist.test.labels,
    }
    ini_gen = sess.run(vd_gen.accuracy, feed_dict=feed_dict)
    print('ini dis=%.4f ini gen=%.4f' % (ini_dis, ini_gen))
    # exit()

    start = time.time()
    batch_d, batch_g, batch_t = -1, -1, -1
    gumbel_times = (math.log(flags.gumbel_end_temperature / flags.gumbel_temperature) 
        / math.log(flags.gumbel_temperature_decay_factor))
    for epoch in range(flags.num_epoch):
      for dis_epoch in range(flags.num_dis_epoch):
        # print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
        # num_batch_d = math.ceil(tn_size / flags.batch_size)
        # for _ in range(num_batch_d):
        #   image_d, label_dat_d = dis_mnist.train.next_batch(flags.batch_size)
        for image_d, label_dat_d in dis_datagen.generate(batch_size=flags.batch_size):
          batch_d += 1

          feed_dict = {tn_gen.image_ph:image_d}
          label_gen_d = sess.run(tn_gen.labels, feed_dict=feed_dict)
          sample_gen_d, gen_label_d = utils.gan_dis_sample(flags, label_dat_d, label_gen_d)

          feed_dict = {tn_tch.image_ph:image_d}
          label_tch_d = sess.run(tn_tch.labels, feed_dict=feed_dict)
          sample_tch_d, tch_label_d = utils.gan_dis_sample(flags, label_dat_d, label_tch_d)
          
          feed_dict = {
            tn_dis.image_ph:image_d,
            tn_dis.gen_sample_ph:sample_gen_d,
            tn_dis.gen_label_ph:gen_label_d,
            tn_dis.tch_sample_ph:sample_tch_d,
            tn_dis.tch_label_ph:tch_label_d,
          }
          sess.run(tn_dis.gan_update, feed_dict=feed_dict)

      for tch_epoch in range(flags.num_tch_epoch):
        # num_batch_t = math.ceil(tn_size / flags.batch_size)
        # for _ in range(num_batch_t):
        #   image_t, label_dat_t = tch_mnist.train.next_batch(flags.batch_size)
        for image_t, label_dat_t in tch_datagen.generate(batch_size=flags.batch_size):
          batch_t += 1

          feed_dict = {tn_tch.image_ph:image_t}
          label_tch_t = sess.run(tn_tch.labels, feed_dict=feed_dict)
          sample_t = utils.generate_label(flags, label_dat_t, label_tch_t)
          feed_dict = {
            tn_dis.image_ph:image_t,
            tn_dis.tch_sample_ph:sample_t,
          }
          reward_t = sess.run(tn_dis.tch_rewards, feed_dict=feed_dict)

          feed_dict = {vd_gen.image_ph:image_t}
          soft_logit_t = sess.run(vd_gen.logits, feed_dict=feed_dict)
          feed_dict = {
            tn_tch.image_ph:image_t,
            tn_tch.sample_ph:sample_t,
            tn_tch.reward_ph:reward_t,
            tn_tch.hard_label_ph:label_dat_t,
            tn_tch.soft_logit_ph:soft_logit_t,
          }
          
          sess.run(tn_tch.kdgan_update, feed_dict=feed_dict)

          if not flags.evaluate_tch:
            continue
          if (batch_t + 1) % eval_interval != 0:
            continue
          feed_dict = {
            vd_tch.image_ph:gen_mnist.test.images,
            vd_tch.hard_label_ph:gen_mnist.test.labels,
          }
          tch_acc = sess.run(vd_tch.accuracy, feed_dict=feed_dict)
          bst_tch_acc = max(tch_acc, bst_tch_acc)
          print('#%08d tchcur=%.4f tchbst=%.4f' % (batch_t, tch_acc, bst_tch_acc))
          tch_history.append(tch_acc)

      #### gumbel softmax
      if flags.enable_gumbel:
        if (epoch + 1) % max(int(flags.num_epoch / gumbel_times), 1) == 0:
          sess.run(tn_gen.gt_update)

      for gen_epoch in range(flags.num_gen_epoch):
        batch = -1
        # num_batch_g = math.ceil(tn_size / flags.batch_size)
        # for _ in range(num_batch_g):
        #   image_g, label_dat_g = gen_mnist.train.next_batch(flags.batch_size)
        for image_g, label_dat_g in gen_datagen.generate(batch_size=flags.batch_size):
          batch_g += 1
          batch += 1
          epk_bat = '%d.%d' % (epoch*flags.num_gen_epoch+gen_epoch, batch)
          ggrads_file = path.join(flags.gradient_dir, 'kdgan_ggrads.%s.p' % epk_bat)
          kgrads_file = path.join(flags.gradient_dir, 'kdgan_kgrads.%s.p' % epk_bat)

          feed_dict = {tn_gen.image_ph:image_g}

          if not flags.enable_gumbel:
            label_gen_g = sess.run(tn_gen.labels, feed_dict=feed_dict)
          else:
            label_gen_g = sess.run(tn_gen.gumbel_labels, feed_dict=feed_dict)

          sample_g = utils.generate_label(flags, label_dat_g, label_gen_g)
          feed_dict = {
            tn_dis.image_ph:image_g,
            tn_dis.gen_sample_ph:sample_g,
          }
          reward_g = sess.run(tn_dis.gen_rewards, feed_dict=feed_dict)
          # reward_g[reward_g>0.5] = 0.7
          # reward_g[reward_g<0.5] = 0.3

          feed_dict = {vd_tch.image_ph:image_g}
          soft_logit_g = sess.run(vd_tch.logits, feed_dict=feed_dict)
          # print(sample_g.shape, reward_g.shape, image_g.shape, soft_logit_g.shape)
          # exit()

          feed_dict = {
            tn_gen.image_ph:image_g,
            tn_gen.sample_ph:sample_g,
            tn_gen.reward_ph:reward_g,
            tn_gen.hard_label_ph:label_dat_g,
            tn_gen.soft_logit_ph:soft_logit_g,
          }
          # sess.run(tn_gen.kdgan_update, feed_dict=feed_dict)
          if flags.log_gradient:
            fetches = [tn_gen.kdgan_ggrads, tn_gen.kdgan_kgrads, tn_gen.kdgan_update]
            kdgan_ggrads, kdgan_kgrads, _ = sess.run(fetches, feed_dict=feed_dict)
            pickle.dump(kdgan_ggrads, open(ggrads_file, 'wb'))
            pickle.dump(kdgan_kgrads, open(kgrads_file, 'wb'))
          else:
            sess.run(tn_gen.kdgan_update, feed_dict=feed_dict)

          if flags.log_accuracy:
            feed_dict = {
              vd_gen.image_ph:gen_mnist.test.images,
              vd_gen.hard_label_ph:gen_mnist.test.labels,
            }
            acc = sess.run(vd_gen.accuracy, feed_dict=feed_dict)
            acc_history.append(acc)
            if (batch_g + 1) % eval_interval != 0:
              continue
          else:
            if (batch_g + 1) % eval_interval != 0:
              continue
            feed_dict = {
              vd_gen.image_ph:gen_mnist.test.images,
              vd_gen.hard_label_ph:gen_mnist.test.labels,
            }
            acc = sess.run(vd_gen.accuracy, feed_dict=feed_dict)

          if acc > bst_gen_acc:
            bst_gen_acc = acc
            bst_eph = epoch
          tot_time = time.time() - start
          global_step = sess.run(tn_gen.global_step)
          # avg_time = (tot_time / global_step) * (tn_size / flags.batch_size)
          if flags.evaluate_tch:
            gen_tch_pct = 100 * bst_gen_acc / bst_tch_acc
            print('#%08d/%08d gencur=%.4f genbst=%.4f (%.2f) tot=%.0fs' % 
                (batch_g, tot_batch, acc, bst_gen_acc, gen_tch_pct, tot_time))
          else:
            print('#%08d/%08d gencur=%.4f genbst=%.4f tot=%.0fs' % 
                (batch_g, tot_batch, acc, bst_gen_acc, tot_time))

          stdout.flush()
          if acc <= bst_gen_acc:
            continue
          # save gen parameters if necessary
    gumbel_temperature = sess.run(tn_gen.gumbel_temperature)
    print('gumbel_temperature=%.4f' % gumbel_temperature)
  tot_time = time.time() - start
  bst_gen_acc *= 100
  bst_eph += 1
  print('#mnist=%d kdgan@%d=%.2f et=%.0fs' % (tn_size, bst_eph, bst_gen_acc, tot_time))

  if flags.log_accuracy:
    utils.create_pardir(flags.acc_file)
    pickle.dump(acc_history, open(flags.acc_file, 'wb'))

  if flags.evaluate_tch:
    utils.create_pardir(flags.tch_file)
    pickle.dump(tch_history, open(flags.tch_file, 'wb'))
Example #14
def create_test_set():
    utils.create_if_nonexist(precomputed_dir)

    user_list = []
    file_list = []
    text_list = []
    label_list = []
    fin = open(valid_file)
    valid_size = 0
    while True:
        line = fin.readline()
        if not line:
            break
        fields = line.strip().split(FIELD_SEPERATOR)
        user = fields[USER_INDEX]
        image = fields[IMAGE_INDEX]
        file = path.join(image_data_dir, '%s.jpg' % image)
        tokens = fields[TEXT_INDEX].split()
        labels = fields[LABEL_INDEX].split()
        user_list.append(user)
        file_list.append(file)
        text_list.append(tokens)
        label_list.append(labels)
        valid_size += 1
    fin.close()

    label_to_id = utils.load_sth_to_id(label_file)
    num_label = len(label_to_id)
    print('#label={}'.format(num_label))
    token_to_id = utils.load_sth_to_id(vocab_file)
    unk_token_id = token_to_id[config.unk_token]
    vocab_size = len(token_to_id)
    print('#vocab={}'.format(vocab_size))

    image_npy = np.zeros((valid_size, 4096), dtype=np.float32)
    label_npy = np.zeros((valid_size, 100), dtype=np.int32)
    imgid_npy = []
    text_npy = []
    reader = ImageReader()
    with tf.Session() as sess:
        init_fn(sess)
        for i, (user, file, text, labels) in enumerate(
                zip(user_list, file_list, text_list, label_list)):
            user = bytes(user, encoding='utf-8')

            image_np = np.array(Image.open(file))
            # print(type(image_np), image_np.shape)
            feed_dict = {image_ph: image_np}
            image, = sess.run([end_point_v], feed_dict)
            image = image.tolist()
            # print(image)
            # print(type(image), len(image))
            image_npy[i, :] = image
            # print(image_npy)
            # input()

            text = [token_to_id.get(token, unk_token_id) for token in text]
            text_npy.append(text)

            label_ids = [label_to_id[label] for label in labels]
            label_vec = np.zeros((num_label, ), dtype=np.int32)
            label_vec[label_ids] = 1
            label = label_vec.tolist()
            label_npy[i, :] = label

            image_id = path.basename(file).split('.')[0]
            imgid_npy.append(image_id)
            # example = build_example(user, image, text, label, file)

    imgid_npy = np.asarray(imgid_npy)
    filename_tmpl = 'yfcc10k_%s.valid.%s'
    np.save(
        path.join(precomputed_dir,
                  filename_tmpl % (flags.model_name, 'image')), image_npy)
    np.save(
        path.join(precomputed_dir,
                  filename_tmpl % (flags.model_name, 'label')), label_npy)
    np.save(
        path.join(precomputed_dir,
                  filename_tmpl % (flags.model_name, 'imgid')), imgid_npy)

    def padding(x):
        # right-pad each row with zeros to the longest row's length;
        # converting to np.array before padding breaks on equal-length rows
        max_length = max(len(row) for row in x)
        x = np.array([row + [0] * (max_length - len(row)) for row in x])
        print(x.shape)
        return x

    text_npy = padding(text_npy)
    np.save(
        path.join(precomputed_dir, filename_tmpl % (flags.model_name, 'text')),
        text_npy)
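
The padding helper right-pads ragged token-id lists with zeros so they can be saved as one rectangular array; a tiny illustration:

# illustration of the zero-padding applied to text_npy above
ragged = [[3, 7], [1], [4, 4, 4]]
max_len = max(len(row) for row in ragged)
padded = np.array([row + [0] * (max_len - len(row)) for row in ragged])
print(padded.shape)  # (3, 3); short rows are filled with trailing zeros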
Example #15
DESC_INDEX = 4
LABEL_INDEX = -1

FIELD_SEPERATOR = '\t'
EXPECTED_NUM_FIELD = 6

MIN_RND_LABEL = 10
NUM_RND_LABEL = 250
MIN_RND_POST = MIN_RND_LABEL
NUM_RND_POST = 10000
TRAIN_DATA_RATIO = 0.80
SHUFFLE_SEED = 100

dataset = 'yfcc_rnd'
dataset_dir = config.yfcc_rnd_dir
utils.create_if_nonexist(dataset_dir)
raw_file = path.join(dataset_dir, '%s.raw' % dataset)
data_file = path.join(dataset_dir, '%s.data' % dataset)
train_file = path.join(dataset_dir, '%s.train' % dataset)
valid_file = path.join(dataset_dir, '%s.valid' % dataset)
label_file = path.join(dataset_dir, '%s.label' % dataset)
vocab_file = path.join(dataset_dir, '%s.vocab' % dataset)
image_data_dir = path.join(dataset_dir, 'ImageData')
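
With NUM_RND_POST = 10000 and TRAIN_DATA_RATIO = 0.80, the implied split is roughly 8,000 posts for yfcc_rnd.train and 2,000 for yfcc_rnd.valid:

# rough split sizes implied by the constants above
num_train = int(NUM_RND_POST * TRAIN_DATA_RATIO)  # 8000 -> yfcc_rnd.train
num_valid = NUM_RND_POST - num_train              # 2000 -> yfcc_rnd.valid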

################################################################
#
# create kdgan data
#
################################################################

Example #16
def main(_):
    for variable in tf.trainable_variables():
        num_params = 1
        for dim in variable.shape:
            num_params *= dim.value
        print('%-50s (%d params)' % (variable.name, num_params))

    dis_summary_op = tf.summary.merge([
        tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
        tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
    ])
    gen_summary_op = tf.summary.merge([
        tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
        tf.summary.scalar(gen_t.kdgan_loss.name, gen_t.kdgan_loss),
    ])
    tch_summary_op = tf.summary.merge([
        tf.summary.scalar(tch_t.learning_rate.name, tch_t.learning_rate),
        tf.summary.scalar(tch_t.kdgan_loss.name, tch_t.kdgan_loss),
    ])
    init_op = tf.global_variables_initializer()

    data_sources_t = utils.get_data_sources(flags, is_training=True)
    data_sources_v = utils.get_data_sources(flags, is_training=False)
    print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
          (len(data_sources_t), len(data_sources_v)))

    ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
    user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d

    ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
    user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g

    ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
    bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
    user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t

    ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
    bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

    figure_data = []
    best_hit_v = -np.inf
    start = time.time()
    with tf.Session() as sess:
        sess.run(init_op)
        dis_t.saver.restore(sess, flags.dis_model_ckpt)
        gen_t.saver.restore(sess, flags.gen_model_ckpt)
        tch_t.saver.restore(sess, flags.tch_model_ckpt)
        writer = tf.summary.FileWriter(config.logs_dir,
                                       graph=tf.get_default_graph())
        with slim.queues.QueueRunners(sess):
            gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
            tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)
            print('hit gen=%.4f tch=%.4f' % (gen_hit, tch_hit))

            batch_d, batch_g, batch_t = -1, -1, -1
            for epoch in range(flags.num_epoch):
                for dis_epoch in range(flags.num_dis_epoch):
                    print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
                    for _ in range(num_batch_per_epoch):
                        #continue
                        batch_d += 1
                        image_d, text_d, label_dat_d = sess.run(
                            [image_bt_d, text_bt_d, label_bt_d])

                        feed_dict = {gen_t.image_ph: image_d}
                        label_gen_d, = sess.run([gen_t.labels],
                                                feed_dict=feed_dict)
                        # print('gen label', label_gen_d.shape)
                        feed_dict = {
                            tch_t.text_ph: text_d,
                            tch_t.image_ph: image_d
                        }
                        label_tch_d, = sess.run([tch_t.labels],
                                                feed_dict=feed_dict)
                        # print('tch label', label_tch_d.shape)

                        sample_d, label_d = utils.kdgan_dis_sample(
                            flags, label_dat_d, label_gen_d, label_tch_d)
                        # print(sample_d.shape, label_d.shape)

                        feed_dict = {
                            dis_t.image_ph: image_d,
                            dis_t.sample_ph: sample_d,
                            dis_t.dis_label_ph: label_d,
                        }
                        _, summary_d = sess.run(
                            [dis_t.gan_update, dis_summary_op],
                            feed_dict=feed_dict)
                        writer.add_summary(summary_d, batch_d)

                for tch_epoch in range(flags.num_tch_epoch):
                    print('epoch %03d tch_epoch %03d' % (epoch, tch_epoch))
                    for _ in range(num_batch_per_epoch):
                        #continue
                        batch_t += 1
                        image_t, text_t, label_dat_t = sess.run(
                            [image_bt_t, text_bt_t, label_bt_t])

                        feed_dict = {
                            tch_t.text_ph: text_t,
                            tch_t.image_ph: image_t
                        }
                        label_tch_t, = sess.run([tch_t.labels],
                                                feed_dict=feed_dict)
                        sample_t = utils.generate_label(
                            flags, label_dat_t, label_tch_t)
                        feed_dict = {
                            dis_t.image_ph: image_t,
                            dis_t.sample_ph: sample_t,
                        }
                        reward_t, = sess.run([dis_t.rewards],
                                             feed_dict=feed_dict)

                        feed_dict = {
                            gen_t.image_ph: image_t,
                        }
                        label_gen_g = sess.run(gen_t.logits,
                                               feed_dict=feed_dict)
                        #print(len(label_dat_t), len(label_dat_t[0]))
                        #exit()
                        feed_dict = {
                            tch_t.text_ph: text_t,
                            tch_t.image_ph: image_t,
                            tch_t.sample_ph: sample_t,
                            tch_t.reward_ph: reward_t,
                            tch_t.hard_label_ph: label_dat_t,
                            tch_t.soft_label_ph: label_gen_g,
                        }

                        _, summary_t, tch_kdgan_loss = sess.run(
                            [
                                tch_t.kdgan_update, tch_summary_op,
                                tch_t.kdgan_loss
                            ],
                            feed_dict=feed_dict)
                        writer.add_summary(summary_t, batch_t)
                        #print("teacher kdgan loss:", tch_kdgan_loss)

                for gen_epoch in range(flags.num_gen_epoch):
                    print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
                    for _ in range(num_batch_per_epoch):
                        batch_g += 1
                        image_g, text_g, label_dat_g = sess.run(
                            [image_bt_g, text_bt_g, label_bt_g])

                        feed_dict = {
                            tch_t.text_ph: text_g,
                            tch_t.image_ph: image_g
                        }
                        label_tch_g, = sess.run([tch_t.labels],
                                                feed_dict=feed_dict)
                        # print('tch label {}'.format(label_tch_g.shape))

                        feed_dict = {gen_t.image_ph: image_g}
                        label_gen_g, = sess.run([gen_t.labels],
                                                feed_dict=feed_dict)
                        sample_g = utils.generate_label(
                            flags, label_dat_g, label_gen_g)
                        feed_dict = {
                            dis_t.image_ph: image_g,
                            dis_t.sample_ph: sample_g,
                        }
                        reward_g, = sess.run([dis_t.rewards],
                                             feed_dict=feed_dict)

                        feed_dict = {
                            gen_t.image_ph: image_g,
                            gen_t.hard_label_ph: label_dat_g,
                            gen_t.soft_label_ph: label_tch_g,
                            gen_t.sample_ph: sample_g,
                            gen_t.reward_ph: reward_g,
                        }
                        _, summary_g = sess.run(
                            [gen_t.kdgan_update, gen_summary_op],
                            feed_dict=feed_dict)
                        writer.add_summary(summary_g, batch_g)

                        # if (batch_g + 1) % eval_interval != 0:
                        #     continue
                        # gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
                        # tot_time = time.time() - start
                        # print('#%08d hit=%.4f %06ds' % (batch_g, gen_hit, int(tot_time)))
                        # if gen_hit <= best_hit_v:
                        #   continue
                        # best_hit_v = gen_hit
                        # print('best hit=%.4f' % (best_hit_v))
                gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
                tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)

                tot_time = time.time() - start
                print('#%03d curgen=%.4f curtch=%.4f %.0fs' %
                      (epoch, gen_hit, tch_hit, tot_time))
                figure_data.append((epoch, gen_hit, tch_hit))
                if gen_hit <= best_hit_v:
                    continue
                best_hit_v = gen_hit
                print("epoch ", epoch + 1, ":, new best validation hit:",
                      best_hit_v, "saving...")
                gen_t.saver.save(sess,
                                 flags.kdgan_model_ckpt,
                                 global_step=epoch + 1)
                print("finish saving")

    print('best hit=%.4f' % (best_hit_v))

    utils.create_if_nonexist(os.path.dirname(flags.kdgan_figure_data))
    fout = open(flags.kdgan_figure_data, 'w')
    for epoch, gen_hit, tch_hit in figure_data:
        fout.write('%d\t%.4f\t%.4f\n' % (epoch, gen_hit, tch_hit))
    fout.close()