def collect_image(infile, outdir):
  post_image = {}
  fin = open(infile)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    post = fields[POST_INDEX]
    image = fields[IMAGE_INDEX]
    post_image[post] = image
  fin.close()

  utils.create_if_nonexist(outdir)
  fin = open(sample_file)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    post = fields[POST_INDEX]
    if post not in post_image:
      continue
    image_url = fields[IMAGE_INDEX]
    src_file = get_image_path(image_dir, image_url)
    image = post_image[post]
    dst_file = path.join(outdir, '%s.jpg' % image)
    if path.isfile(dst_file):
      continue
    shutil.copyfile(src_file, dst_file)
  fin.close()
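# collect_image() relies on a get_image_path() helper defined elsewhere in
# the repo. Its exact behavior depends on how the downloaded YFCC images are
# laid out under image_dir; a minimal sketch, assuming each image is stored
# flat under image_dir by the basename of its url (hypothetical layout):
#
# def get_image_path(image_dir, image_url):
#   # strip the directory part of the url and look the file up locally
#   return path.join(image_dir, path.basename(image_url))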
def main(_):
  check_num_field()
  utils.create_if_nonexist(dataset_dir)
  if not utils.skip_if_exist(label_file):
    print('select top labels')
    select_top_label()
  if not utils.skip_if_exist(raw_file):
    print('select posts')
    select_posts()
  if not utils.skip_if_exist(data_file):
    print('tokenize dataset')
    tokenize_dataset()
  count_dataset()
  # note: the original had a misplaced parenthesis that passed
  # `valid_file or ...` into skip_if_exist; fixed to three separate checks
  if (not utils.skip_if_exist(train_file) or
      not utils.skip_if_exist(valid_file) or
      not utils.skip_if_exist(vocab_file)):
    print('split dataset')
    split_dataset()
  # if path.isdir(image_dir):
  if False:
    print('collect images')
    # find ImageData/ -type f | wc -l
    collect_image(data_file, image_data_dir)
  # create_survey_data()
  create_tfrecord(valid_file, end_point_v, is_training=False)
  create_tfrecord(train_file, end_point_t, is_training=True)
def main(_):
  if flags.overwrite:
    print('create yfcc small rnd dataset')
    utils.delete_if_exist(dataset_dir)
  utils.create_if_nonexist(dataset_dir)
  check_num_field()
  if flags.overwrite or (not utils.skip_if_exist(raw_file)):
    # resample until every random label has at least MIN_RND_POST posts
    while True:
      print('random labels and posts')
      select_rnd_label()
      min_count = select_posts()
      if min_count < MIN_RND_POST:
        continue
      break
  if flags.overwrite or (not utils.skip_if_exist(data_file)):
    print('tokenize and collect images')
    tokenize_dataset()
    collect_image(data_file, image_data_dir)
  if (flags.overwrite or
      not utils.skip_if_exist(train_file) or
      not utils.skip_if_exist(valid_file) or
      not utils.skip_if_exist(vocab_file)):
    # retry splitting until a valid split is produced
    while True:
      print('split into train and valid')
      try:
        split_dataset()
        break
      except Exception:
        continue
  if flags.baseline:
    print('create survey data')
    create_survey_data()
def count_dataset():
  utils.create_if_nonexist(temp_dir)

  user_count = {}
  fin = open(data_file)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    user = fields[USER_INDEX]
    if user not in user_count:
      user_count[user] = 0
    user_count[user] += 1
  fin.close()
  sorted_user_count = sorted(user_count.items(),
                             key=operator.itemgetter(0),
                             reverse=True)
  outfile = path.join(temp_dir, 'user_count')
  with open(outfile, 'w') as fout:
    for user, count in sorted_user_count:
      fout.write('{}\t{}\n'.format(user, count))

  label_count = {}
  fin = open(data_file)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    user = fields[USER_INDEX]
    labels = fields[LABEL_INDEX].split()
    assert len(labels) != 0
    for label in labels:
      if label not in label_count:
        label_count[label] = 0
      label_count[label] += 1
  fin.close()
  sorted_label_count = sorted(label_count.items(),
                              key=operator.itemgetter(0),
                              reverse=True)
  outfile = path.join(temp_dir, 'label_count')
  labels, lemms = set(), set()
  with open(outfile, 'w') as fout:
    for label, count in sorted_label_count:
      labels.add(label)
      lemm = lemmatizer.lemmatize(label)
      lemms.add(lemm)
      if lemm != label:
        print('{}->{}'.format(lemm, label))
      fout.write('{}\t{}\n'.format(label, count))
  print('#label={} #lemm={}'.format(len(labels), len(lemms)))
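# count_dataset() and the survey helpers below use a module-level lemmatizer
# and the wordnet corpus. A minimal setup sketch with NLTK (assuming the
# wordnet corpus has been fetched once via nltk.download('wordnet')):
#
# from nltk.corpus import wordnet
# from nltk.stem import WordNetLemmatizer
#
# lemmatizer = WordNetLemmatizer()
# lemmatizer.lemmatize('dogs')  # -> 'dog'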
def survey_image_data(infile):
  dataset = get_dataset(infile)
  image_data = path.join(surv_dir, dataset, 'ImageData')
  utils.create_if_nonexist(image_data)
  fout = open(path.join(image_data, '%s.txt' % dataset), 'w')
  fin = open(infile)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    image = fields[IMAGE_INDEX]
    image_file = '%s.jpg' % image
    fout.write('{}\n'.format(image_file))
  fin.close()
  fout.close()
  collect_image(infile, image_data)
def survey_feature_sets(infile):
  dataset = get_dataset(infile)
  image_sets = path.join(surv_dir, dataset, 'ImageSets')
  utils.create_if_nonexist(image_sets)
  fout = open(path.join(image_sets, '%s.txt' % dataset), 'w')
  fin = open(infile)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    image = fields[IMAGE_INDEX]
    fout.write('{}\n'.format(image))
  fin.close()
  fout.close()
  fout = open(path.join(image_sets, 'holdout.txt'), 'w')
  fout.close()
def survey_annotations(infile):
  dataset = get_dataset(infile)
  annotations = path.join(surv_dir, dataset, 'Annotations')
  utils.create_if_nonexist(annotations)
  concepts = 'concepts.txt'

  label_set = set()
  label_images = {}
  image_set = set()
  fin = open(infile)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    image = fields[IMAGE_INDEX]
    labels = fields[LABEL_INDEX].split()
    for label in labels:
      label_set.add(label)
      if label not in label_images:
        label_images[label] = []
      label_images[label].append(image)
    image_set.add(image)
  fin.close()

  fout = open(path.join(annotations, concepts), 'w')
  for label in sorted(label_set):
    fout.write('{}\n'.format(label))
  fout.close()

  concepts_dir = path.join(annotations, 'Image', concepts)
  utils.create_if_nonexist(concepts_dir)
  image_list = sorted(image_set)
  for label in label_set:
    label_filepath = path.join(concepts_dir, '%s.txt' % label)
    fout = open(label_filepath, 'w')
    for image in image_list:
      assessment = -1
      if image in label_images[label]:
        assessment = 1
      fout.write('{} {}\n'.format(image, assessment))
    fout.close()
def survey_text_data(infile):
  seperator = '###'

  def _get_key(label_i, label_j):
    if label_i < label_j:
      key = label_i + seperator + label_j
    else:
      key = label_j + seperator + label_i
    return key

  def _get_labels(key):
    fields = key.split(seperator)
    label_i, label_j = fields[0], fields[1]
    return label_i, label_j

  dataset = get_dataset(infile)
  text_data = path.join(surv_dir, dataset, 'TextData')
  utils.create_if_nonexist(text_data)

  post_image = {}
  fin = open(infile)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    post = fields[POST_INDEX]
    image = fields[IMAGE_INDEX]
    post_image[post] = image
  fin.close()

  rawtags_file = path.join(text_data, 'id.userid.rawtags.txt')
  fout = open(rawtags_file, 'w')
  fin = open(rawtag_file)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    post = fields[0]
    if post not in post_image:
      continue
    post = fields[POST_INDEX]
    image = post_image[post]
    user = fields[USER_INDEX]
    old_labels = fields[LABEL_INDEX].split(LABEL_SEPERATOR)
    new_labels = []
    for old_label in old_labels:
      old_label = urllib.parse.unquote(old_label)
      old_label = old_label.lower()
      new_label = ''
      for c in old_label:
        if not c.isalnum():
          continue
        new_label += c
      if len(new_label) == 0:
        continue
      new_labels.append(new_label)
    labels = ' '.join(new_labels)
    fout.write('{}\t{}\t{}\n'.format(image, user, labels))
  fin.close()
  fout.close()

  lemmtags_file = path.join(text_data, 'id.userid.lemmtags.txt')
  fout = open(lemmtags_file, 'w')
  fin = open(rawtags_file)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    old_labels = fields[-1].split(' ')
    new_labels = []
    for old_label in old_labels:
      new_label = lemmatizer.lemmatize(old_label)
      new_labels.append(new_label)
    fields[-1] = ' '.join(new_labels)
    fout.write('{}\n'.format(FIELD_SEPERATOR.join(fields)))
  fin.close()
  fout.close()

  fin = open(lemmtags_file)
  label_users, label_images = {}, {}
  label_set = set()
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    image, user = fields[0], fields[1]
    labels = fields[2].split()
    for label in labels:
      if label not in label_users:
        label_users[label] = set()
      label_users[label].add(user)
      if label not in label_images:
        label_images[label] = set()
      label_images[label].add(image)
      label_set.add(label)
  fin.close()

  tagfreq_file = path.join(text_data, 'lemmtag.userfreq.imagefreq.txt')
  fout = open(tagfreq_file, 'w')
  label_count = {}
  for label in label_set:
    label_count[label] = len(label_users[label])  # + len(label_images[label])
  sorted_label_count = sorted(label_count.items(),
                              key=operator.itemgetter(1),
                              reverse=True)
  for label, _ in sorted_label_count:
    userfreq = len(label_users[label])
    imagefreq = len(label_images[label])
    fout.write('{} {} {}\n'.format(label, userfreq, imagefreq))
  fout.close()

  jointfreq_file = path.join(text_data, 'ucij.uuij.icij.iuij.txt')
  min_count = 4
  if not infile.endswith('.valid'):
    min_count = 8
  label_count = {}
  fin = open(lemmtags_file)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    image, user = fields[0], fields[1]
    labels = fields[2].split()
    for label in labels:
      if label not in label_count:
        label_count[label] = 0
      label_count[label] += 1
  fin.close()

  jointfreq_icij_init = {}
  fin = open(lemmtags_file)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    image, user = fields[0], fields[1]
    labels = fields[2].split()
    num_label = len(labels)
    for i in range(num_label - 1):
      for j in range(i + 1, num_label):
        label_i = labels[i]
        label_j = labels[j]
        if label_i == label_j:
          continue
        if label_count[label_i] < min_count:
          continue
        if label_count[label_j] < min_count:
          continue
        key = _get_key(label_i, label_j)
        if key not in jointfreq_icij_init:
          jointfreq_icij_init[key] = 0
        jointfreq_icij_init[key] += 1
  fin.close()

  keys = set()
  icij_images = {}
  iuij_images = {}
  fin = open(lemmtags_file)
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    image, user = fields[0], fields[1]
    labels = fields[2].split()
    num_label = len(labels)
    for i in range(num_label - 1):
      for j in range(i + 1, num_label):
        label_i = labels[i]
        label_j = labels[j]
        if label_i == label_j:
          continue
        if label_i not in iuij_images:
          iuij_images[label_i] = set()
        iuij_images[label_i].add(image)
        if label_j not in iuij_images:
          iuij_images[label_j] = set()
        iuij_images[label_j].add(image)
        if label_count[label_i] < min_count:
          continue
        if label_count[label_j] < min_count:
          continue
        key = _get_key(label_i, label_j)
        if jointfreq_icij_init[key] < min_count:
          continue
        keys.add(key)
        if key not in icij_images:
          icij_images[key] = set()
        icij_images[key].add(image)
  fin.close()

  jointfreq_icij, jointfreq_iuij = {}, {}
  keys = sorted(keys)
  for key in keys:
    jointfreq_icij[key] = len(icij_images[key])
    label_i, label_j = _get_labels(key)
    label_i_images = iuij_images[label_i]
    label_j_images = iuij_images[label_j]
    jointfreq_iuij[key] = len(label_i_images.union(label_j_images))
  fout = open(jointfreq_file, 'w')
  for key in sorted(keys):
    label_i, label_j = _get_labels(key)
    fout.write('{} {} {} {} {} {}\n'.format(
        label_i, label_j,
        jointfreq_icij[key], jointfreq_iuij[key],
        jointfreq_icij[key], jointfreq_iuij[key]))
  fout.close()

  fin = open(lemmtags_file)
  vocab = set()
  while True:
    line = fin.readline().strip()
    if not line:
      break
    fields = line.split(FIELD_SEPERATOR)
    image, user = fields[0], fields[1]
    labels = fields[2].split()
    for label in labels:
      if wordnet.synsets(label):
        vocab.add(label)
  fin.close()
  vocab_file = path.join(text_data, 'wn.%s.txt' % dataset)
  fout = open(vocab_file, 'w')
  for label in sorted(vocab):
    fout.write('{}\n'.format(label))
  fout.close()
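# The survey_* helpers above mirror the directory layout consumed by tag
# relevance baselines (ImageData/, ImageSets/, Annotations/, TextData/).
# A plausible sketch of the create_survey_data() driver referenced in main(),
# assuming it simply runs every helper on both splits:
#
# def create_survey_data():
#   for infile in [train_file, valid_file]:
#     survey_image_data(infile)
#     survey_feature_sets(infile)
#     survey_annotations(infile)
#     survey_text_data(infile)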
def create_tfrecord(infile, end_point, is_training=False):
  utils.create_if_nonexist(precomputed_dir)
  num_epoch = flags.num_epoch
  if not is_training:
    num_epoch = 1
  fields = path.basename(infile).split('.')
  dataset, version = fields[0], fields[1]
  filepath = path.join(precomputed_dir, tfrecord_tmpl)

  user_list = []
  file_list = []
  text_list = []
  label_list = []
  fin = open(infile)
  while True:
    line = fin.readline()
    if not line:
      break
    fields = line.strip().split(FIELD_SEPERATOR)
    user = fields[USER_INDEX]
    image = fields[IMAGE_INDEX]
    file = path.join(image_data_dir, '%s.jpg' % image)
    tokens = fields[TEXT_INDEX].split()
    labels = fields[LABEL_INDEX].split()
    user_list.append(user)
    file_list.append(file)
    text_list.append(tokens)
    label_list.append(labels)
  fin.close()

  label_to_id = utils.load_sth_to_id(label_file)
  num_label = len(label_to_id)
  print('#label={}'.format(num_label))
  token_to_id = utils.load_sth_to_id(vocab_file)
  unk_token_id = token_to_id[unk_token]
  vocab_size = len(token_to_id)
  print('#vocab={}'.format(vocab_size))

  reader = ImageReader()
  with tf.Session() as sess:
    init_fn(sess)
    for epoch in range(num_epoch):
      count = 0
      tfrecord_file = filepath.format(dataset, flags.model_name,
                                      epoch, version)
      if path.isfile(tfrecord_file):
        continue
      with tf.python_io.TFRecordWriter(tfrecord_file) as fout:
        for user, file, text, labels in zip(user_list, file_list,
                                            text_list, label_list):
          user = bytes(user, encoding='utf-8')
          image_np = np.array(Image.open(file))
          feed_dict = {image_ph: image_np}
          image_t, = sess.run([end_point], feed_dict)
          image_t = image_t.tolist()
          text = [token_to_id.get(token, unk_token_id) for token in text]
          label_ids = [label_to_id[label] for label in labels]
          label_vec = np.zeros((num_label,), dtype=np.int64)
          label_vec[label_ids] = 1
          label = label_vec.tolist()
          file = bytes(file, encoding='utf-8')
          example = build_example(user, image_t, text, label, file)
          fout.write(example.SerializeToString())
          count += 1
          if (count % 500) == 0:
            print('count={}'.format(count))
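# build_example() is defined elsewhere; create_tfrecord() passes it a user id
# (bytes), a precomputed image feature vector (floats), token ids, a 0/1
# label vector, and the image file path (bytes). A minimal sketch with the
# tf.train.Example proto; the feature names here are assumptions, the actual
# keys must match what utils.decode_tfrecord() reads back:
#
# def build_example(user, image, text, label, file):
#   def _bytes(value):
#     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
#   def _floats(value):
#     return tf.train.Feature(float_list=tf.train.FloatList(value=value))
#   def _int64s(value):
#     return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
#   return tf.train.Example(features=tf.train.Features(feature={
#       'user': _bytes(user),
#       'image': _floats(image),
#       'text': _int64s(text),
#       'label': _int64s(label),
#       'file': _bytes(file),
#   }))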
def main(_):
  gen_t = GEN(flags, is_training=True)
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  gen_v = GEN(flags, is_training=False)

  tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate)
  tf.summary.scalar(gen_t.pre_loss.name, gen_t.pre_loss)
  summary_op = tf.summary.merge_all()
  init_op = tf.global_variables_initializer()

  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
        (len(data_sources_t), len(data_sources_v)))
  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  user_bt_v, image_bt_v, text_bt_v, label_bt_v, file_bt_v = bt_list_v

  figure_data = []
  best_hit_v = -np.inf
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(config.logs_dir,
                                   graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      for batch_t in range(num_batch_t):
        image_np_t, label_np_t = sess.run([image_bt_t, label_bt_t])
        feed_dict = {
            gen_t.image_ph: image_np_t,
            gen_t.hard_label_ph: label_np_t,
        }
        _, summary = sess.run([gen_t.pre_update, summary_op],
                              feed_dict=feed_dict)
        writer.add_summary(summary, batch_t)

        # only evaluate at (approximate) epoch boundaries: proceed when the
        # current batch ends an epoch or the remainder is below one batch
        batch = batch_t + 1
        remain = (batch * flags.batch_size) % train_data_size
        epoch = (batch * flags.batch_size) // train_data_size
        if remain == 0:
          pass
        elif (train_data_size - remain) < flags.batch_size:
          epoch = epoch + 1
        else:
          continue

        hit_v = []
        for batch_v in range(num_batch_v):
          image_np_v, label_np_v = sess.run([image_bt_v, label_bt_v])
          feed_dict = {gen_v.image_ph: image_np_v}
          logit_np_v, = sess.run([gen_v.logits], feed_dict=feed_dict)
          hit_bt = metric.compute_hit(logit_np_v, label_np_v, flags.cutoff)
          hit_v.append(hit_bt)
        hit_v = np.mean(hit_v)
        figure_data.append((epoch, hit_v, batch_t))
        if hit_v < best_hit_v:
          continue
        tot_time = time.time() - start
        best_hit_v = hit_v
        print('#%03d curbst=%.4f time=%.0fs' % (epoch, hit_v, tot_time))
        gen_t.saver.save(sess, flags.gen_model_ckpt)
  print('bsthit=%.4f' % (best_hit_v))

  utils.create_if_nonexist(os.path.dirname(flags.gen_figure_data))
  fout = open(flags.gen_figure_data, 'w')
  for epoch, hit_v, batch_t in figure_data:
    fout.write('%d\t%.4f\t%d\n' % (epoch, hit_v, batch_t))
  fout.close()
def main(_):
  print('#label={}'.format(config.num_label))
  tch_t = TCH(flags, is_training=True)
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  tch_v = TCH(flags, is_training=False)

  ts_list_t = utils.decode_tfrecord(config.train_tfrecord, shuffle=True)
  ts_list_v = utils.decode_tfrecord(config.valid_tfrecord, shuffle=False)
  bt_list_t = utils.generate_text_batch(ts_list_t, config.train_batch_size)
  bt_list_v = utils.generate_text_batch(ts_list_v, config.valid_batch_size)
  # check_tfrecord(bt_list_t, config.train_batch_size)
  # check_tfrecord(bt_list_v, config.valid_batch_size)
  user_bt_t, text_bt_t, label_bt_t, image_file_bt_t = bt_list_t
  user_bt_v, text_bt_v, label_bt_v, image_file_bt_v = bt_list_v

  best_hit_v = -np.inf
  init_op = tf.global_variables_initializer()
  start = time.time()
  with tf.Session() as sess:
    writer = tf.summary.FileWriter(config.logs_dir,
                                   graph=tf.get_default_graph())
    sess.run(init_op)
    with slim.queues.QueueRunners(sess):
      for batch_t in range(num_batch_t):
        text_np_t, label_np_t = sess.run([text_bt_t, label_bt_t])
        feed_dict = {tch_t.text_ph: text_np_t, tch_t.label_ph: label_np_t}
        _, summary = sess.run([tch_t.train_op, tch_t.summary_op],
                              feed_dict=feed_dict)
        writer.add_summary(summary, batch_t)

        # the original was missing the modulus operand; evaluate every
        # eval_interval batches, matching the guard used elsewhere
        if (batch_t + 1) % eval_interval != 0:
          continue
        hit_v = []
        image_file_v = set()
        for batch_v in range(num_batch_v):
          text_np_v, label_np_v, image_file_np_v = sess.run(
              [text_bt_v, label_bt_v, image_file_bt_v])
          feed_dict = {tch_v.text_ph: text_np_v}
          logit_np_v, = sess.run([tch_v.logits], feed_dict=feed_dict)
          for image_file in image_file_np_v:
            image_file_v.add(image_file)
          hit_bt = compute_hit(logit_np_v, label_np_v, flags.cutoff)
          hit_v.append(hit_bt)
        hit_v = np.mean(hit_v)

        total_time = time.time() - start
        avg_batch = total_time / (batch_t + 1)
        avg_epoch = avg_batch * (config.train_data_size /
                                 config.train_batch_size)
        s = '{0} hit={1:.4f} tot={2:.0f}s avg={3:.0f}s'
        s = s.format(batch_t, hit_v, total_time, avg_epoch)
        print(s)

        if hit_v < best_hit_v:
          continue
        best_hit_v = hit_v
        ckpt_file = path.join(config.ckpt_dir, 'tch.ckpt')
        tch_t.saver.save(sess, ckpt_file)

  utils.create_if_nonexist(config.temp_dir)
  hit_file = path.join(config.temp_dir, 'tch.hit')
  with open(hit_file, 'w') as fout:
    fout.write('{0:.4f}'.format(best_hit_v))
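# The loop above calls compute_hit() directly (elsewhere it is
# metric.compute_hit). Assuming the two are the same hit@cutoff metric, a
# minimal numpy sketch: a row scores 1 if any of its top-cutoff predicted
# labels is a true label, and the batch score is the mean over rows.
#
# def compute_hit(logits, labels, cutoff):
#   # logits: (batch, num_label) scores; labels: (batch, num_label) 0/1
#   predictions = np.argsort(-logits, axis=1)[:, :cutoff]
#   hits = []
#   for prediction, label in zip(predictions, labels):
#     hits.append(1.0 if label[prediction].sum() > 0 else 0.0)
#   return np.mean(hits)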
def main(_):
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  dis_summary_op = tf.summary.merge([
      tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
      tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
  ])
  gen_summary_op = tf.summary.merge([
      tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
      tf.summary.scalar(gen_t.gan_loss.name, gen_t.gan_loss),
  ])
  print(type(dis_summary_op), type(gen_summary_op))
  init_op = tf.global_variables_initializer()

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
        (len(data_sources_t), len(data_sources_v)))
  # independent shuffled pipelines for the dis and gen updates
  ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
  user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d
  ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
  user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

  figure_data = []
  best_hit_v = -np.inf
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    dis_t.saver.restore(sess, flags.dis_model_ckpt)
    gen_t.saver.restore(sess, flags.gen_model_ckpt)
    writer = tf.summary.FileWriter(config.logs_dir,
                                   graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
      print('init hit=%.4f' % (hit_v))

      batch_d, batch_g = -1, -1
      for epoch in range(flags.num_epoch):
        for dis_epoch in range(flags.num_dis_epoch):
          print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
          num_batch_d = math.ceil(train_data_size / flags.batch_size)
          for _ in range(num_batch_d):
            batch_d += 1
            image_np_d, label_dat_d = sess.run([image_bt_d, label_bt_d])
            feed_dict = {gen_t.image_ph: image_np_d}
            label_gen_d, = sess.run([gen_t.labels], feed_dict=feed_dict)
            sample_np_d, label_np_d = utils.gan_dis_sample(
                flags, label_dat_d, label_gen_d)
            feed_dict = {
                dis_t.image_ph: image_np_d,
                dis_t.sample_ph: sample_np_d,
                dis_t.dis_label_ph: label_np_d,
            }
            _, summary_d = sess.run([dis_t.gan_update, dis_summary_op],
                                    feed_dict=feed_dict)
            writer.add_summary(summary_d, batch_d)

        for gen_epoch in range(flags.num_gen_epoch):
          print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
          num_batch_g = math.ceil(train_data_size / flags.batch_size)
          for _ in range(num_batch_g):
            batch_g += 1
            image_np_g, label_dat_g = sess.run([image_bt_g, label_bt_g])
            feed_dict = {gen_t.image_ph: image_np_g}
            label_gen_g, = sess.run([gen_t.labels], feed_dict=feed_dict)
            sample_np_g = utils.generate_label(flags, label_dat_g,
                                               label_gen_g)
            feed_dict = {
                dis_t.image_ph: image_np_g,
                dis_t.sample_ph: sample_np_g,
            }
            reward_np_g, = sess.run([dis_t.rewards], feed_dict=feed_dict)
            feed_dict = {
                gen_t.image_ph: image_np_g,
                gen_t.sample_ph: sample_np_g,
                gen_t.reward_ph: reward_np_g,
            }
            _, summary_g = sess.run([gen_t.gan_update, gen_summary_op],
                                    feed_dict=feed_dict)
            writer.add_summary(summary_g, batch_g)

        # per-batch evaluation (every eval_interval batches) was disabled
        # in favor of the per-epoch evaluation below
        hit_v = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
        tot_time = time.time() - start
        print('#%03d curbst=%.4f %.0fs' % (epoch, hit_v, tot_time))
        figure_data.append((epoch, hit_v))
        if hit_v <= best_hit_v:
          continue
        best_hit_v = hit_v
  print('bsthit=%.4f' % (best_hit_v))

  utils.create_if_nonexist(os.path.dirname(flags.gan_figure_data))
  fout = open(flags.gan_figure_data, 'w')
  for epoch, hit_v in figure_data:
    fout.write('%d\t%.4f\n' % (epoch, hit_v))
  fout.close()
def main(_):
  bst_gen_acc, bst_tch_acc, bst_eph = 0.0, 0.0, 0
  utils.create_if_nonexist(flags.gradient_dir)
  if flags.log_accuracy:
    acc_history = []
  if flags.evaluate_tch:
    tch_history = []

  with tf.train.MonitoredTrainingSession() as sess:
    sess.run(init_op)
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    tn_tch.saver.restore(sess, flags.tch_model_ckpt)

    feed_dict = {
        vd_dis.image_ph: dis_mnist.test.images,
        vd_dis.hard_label_ph: dis_mnist.test.labels,
    }
    ini_dis = sess.run(vd_dis.accuracy, feed_dict=feed_dict)
    feed_dict = {
        vd_gen.image_ph: gen_mnist.test.images,
        vd_gen.hard_label_ph: gen_mnist.test.labels,
    }
    ini_gen = sess.run(vd_gen.accuracy, feed_dict=feed_dict)
    print('ini dis=%.4f ini gen=%.4f' % (ini_dis, ini_gen))

    start = time.time()
    batch_d, batch_g, batch_t = -1, -1, -1
    # number of decays needed to take the gumbel temperature from its
    # start value to its end value
    gumbel_times = (math.log(flags.gumbel_end_temperature /
                             flags.gumbel_temperature) /
                    math.log(flags.gumbel_temperature_decay_factor))
    for epoch in range(flags.num_epoch):
      for dis_epoch in range(flags.num_dis_epoch):
        for image_d, label_dat_d in dis_datagen.generate(
            batch_size=flags.batch_size):
          batch_d += 1
          feed_dict = {tn_gen.image_ph: image_d}
          label_gen_d = sess.run(tn_gen.labels, feed_dict=feed_dict)
          sample_gen_d, gen_label_d = utils.gan_dis_sample(
              flags, label_dat_d, label_gen_d)
          feed_dict = {tn_tch.image_ph: image_d}
          label_tch_d = sess.run(tn_tch.labels, feed_dict=feed_dict)
          sample_tch_d, tch_label_d = utils.gan_dis_sample(
              flags, label_dat_d, label_tch_d)
          feed_dict = {
              tn_dis.image_ph: image_d,
              tn_dis.gen_sample_ph: sample_gen_d,
              tn_dis.gen_label_ph: gen_label_d,
              tn_dis.tch_sample_ph: sample_tch_d,
              tn_dis.tch_label_ph: tch_label_d,
          }
          sess.run(tn_dis.gan_update, feed_dict=feed_dict)

      for tch_epoch in range(flags.num_tch_epoch):
        for image_t, label_dat_t in tch_datagen.generate(
            batch_size=flags.batch_size):
          batch_t += 1
          feed_dict = {tn_tch.image_ph: image_t}
          label_tch_t = sess.run(tn_tch.labels, feed_dict=feed_dict)
          sample_t = utils.generate_label(flags, label_dat_t, label_tch_t)
          feed_dict = {
              tn_dis.image_ph: image_t,
              tn_dis.tch_sample_ph: sample_t,
          }
          reward_t = sess.run(tn_dis.tch_rewards, feed_dict=feed_dict)
          feed_dict = {vd_gen.image_ph: image_t}
          soft_logit_t = sess.run(vd_gen.logits, feed_dict=feed_dict)
          feed_dict = {
              tn_tch.image_ph: image_t,
              tn_tch.sample_ph: sample_t,
              tn_tch.reward_ph: reward_t,
              tn_tch.hard_label_ph: label_dat_t,
              tn_tch.soft_logit_ph: soft_logit_t,
          }
          sess.run(tn_tch.kdgan_update, feed_dict=feed_dict)
          if not flags.evaluate_tch:
            continue
          if (batch_t + 1) % eval_interval != 0:
            continue
          feed_dict = {
              vd_tch.image_ph: gen_mnist.test.images,
              vd_tch.hard_label_ph: gen_mnist.test.labels,
          }
          tch_acc = sess.run(vd_tch.accuracy, feed_dict=feed_dict)
          bst_tch_acc = max(tch_acc, bst_tch_acc)
          print('#%08d tchcur=%.4f tchbst=%.4f' %
                (batch_t, tch_acc, bst_tch_acc))
          tch_history.append(tch_acc)

      # gumbel softmax: decay the temperature on a fixed epoch schedule
      if flags.enable_gumbel:
        if (epoch + 1) % max(int(flags.num_epoch / gumbel_times), 1) == 0:
          sess.run(tn_gen.gt_update)

      for gen_epoch in range(flags.num_gen_epoch):
        batch = -1
        for image_g, label_dat_g in gen_datagen.generate(
            batch_size=flags.batch_size):
          batch_g += 1
          batch += 1
          epk_bat = '%d.%d' % (epoch * flags.num_gen_epoch + gen_epoch,
                               batch)
          ggrads_file = path.join(flags.gradient_dir,
                                  'kdgan_ggrads.%s.p' % epk_bat)
          kgrads_file = path.join(flags.gradient_dir,
                                  'kdgan_kgrads.%s.p' % epk_bat)
          feed_dict = {tn_gen.image_ph: image_g}
          if not flags.enable_gumbel:
            label_gen_g = sess.run(tn_gen.labels, feed_dict=feed_dict)
          else:
            label_gen_g = sess.run(tn_gen.gumbel_labels,
                                   feed_dict=feed_dict)
          sample_g = utils.generate_label(flags, label_dat_g, label_gen_g)
          feed_dict = {
              tn_dis.image_ph: image_g,
              tn_dis.gen_sample_ph: sample_g,
          }
          reward_g = sess.run(tn_dis.gen_rewards, feed_dict=feed_dict)
          feed_dict = {vd_tch.image_ph: image_g}
          soft_logit_g = sess.run(vd_tch.logits, feed_dict=feed_dict)
          feed_dict = {
              tn_gen.image_ph: image_g,
              tn_gen.sample_ph: sample_g,
              tn_gen.reward_ph: reward_g,
              tn_gen.hard_label_ph: label_dat_g,
              tn_gen.soft_logit_ph: soft_logit_g,
          }
          if flags.log_gradient:
            fetches = [tn_gen.kdgan_ggrads, tn_gen.kdgan_kgrads,
                       tn_gen.kdgan_update]
            kdgan_ggrads, kdgan_kgrads, _ = sess.run(fetches,
                                                     feed_dict=feed_dict)
            pickle.dump(kdgan_ggrads, open(ggrads_file, 'wb'))
            pickle.dump(kdgan_kgrads, open(kgrads_file, 'wb'))
          else:
            sess.run(tn_gen.kdgan_update, feed_dict=feed_dict)

          if flags.log_accuracy:
            feed_dict = {
                vd_gen.image_ph: gen_mnist.test.images,
                vd_gen.hard_label_ph: gen_mnist.test.labels,
            }
            acc = sess.run(vd_gen.accuracy, feed_dict=feed_dict)
            acc_history.append(acc)
            if (batch_g + 1) % eval_interval != 0:
              continue
          else:
            if (batch_g + 1) % eval_interval != 0:
              continue
            feed_dict = {
                vd_gen.image_ph: gen_mnist.test.images,
                vd_gen.hard_label_ph: gen_mnist.test.labels,
            }
            acc = sess.run(vd_gen.accuracy, feed_dict=feed_dict)

          if acc > bst_gen_acc:
            bst_gen_acc = max(acc, bst_gen_acc)
            bst_eph = epoch
          tot_time = time.time() - start
          global_step = sess.run(tn_gen.global_step)
          if flags.evaluate_tch:
            gen_tch_pct = 100 * bst_gen_acc / bst_tch_acc
            print('#%08d/%08d gencur=%.4f genbst=%.4f (%.2f) tot=%.0fs' %
                  (batch_g, tot_batch, acc, bst_gen_acc, gen_tch_pct,
                   tot_time))
          else:
            print('#%08d/%08d gencur=%.4f genbst=%.4f tot=%.0fs' %
                  (batch_g, tot_batch, acc, bst_gen_acc, tot_time))
          stdout.flush()
          if acc <= bst_gen_acc:
            continue
          # save gen parameters if necessary

    gumbel_temperature = sess.run(tn_gen.gumbel_temperature)
    print('gumbel_temperature=%.4f' % gumbel_temperature)

  tot_time = time.time() - start
  bst_gen_acc *= 100
  bst_eph += 1
  print('#mnist=%d kdgan@%d=%.2f et=%.0fs' %
        (tn_size, bst_eph, bst_gen_acc, tot_time))
  if flags.log_accuracy:
    utils.create_pardir(flags.acc_file)
    pickle.dump(acc_history, open(flags.acc_file, 'wb'))
  if flags.evaluate_tch:
    utils.create_pardir(flags.tch_file)
    pickle.dump(tch_history, open(flags.tch_file, 'wb'))
def create_test_set():
  utils.create_if_nonexist(precomputed_dir)

  user_list = []
  file_list = []
  text_list = []
  label_list = []
  fin = open(valid_file)
  valid_size = 0
  while True:
    line = fin.readline()
    if not line:
      break
    fields = line.strip().split(FIELD_SEPERATOR)
    user = fields[USER_INDEX]
    image = fields[IMAGE_INDEX]
    file = path.join(image_data_dir, '%s.jpg' % image)
    tokens = fields[TEXT_INDEX].split()
    labels = fields[LABEL_INDEX].split()
    user_list.append(user)
    file_list.append(file)
    text_list.append(tokens)
    label_list.append(labels)
    valid_size += 1
  fin.close()

  label_to_id = utils.load_sth_to_id(label_file)
  num_label = len(label_to_id)
  print('#label={}'.format(num_label))
  token_to_id = utils.load_sth_to_id(vocab_file)
  unk_token_id = token_to_id[config.unk_token]
  vocab_size = len(token_to_id)
  print('#vocab={}'.format(vocab_size))

  image_npy = np.zeros((valid_size, 4096), dtype=np.float32)
  label_npy = np.zeros((valid_size, 100), dtype=np.int32)
  imgid_npy = []
  text_npy = []
  reader = ImageReader()
  with tf.Session() as sess:
    init_fn(sess)
    for i, (user, file, text, labels) in enumerate(
        zip(user_list, file_list, text_list, label_list)):
      user = bytes(user, encoding='utf-8')
      image_np = np.array(Image.open(file))
      feed_dict = {image_ph: image_np}
      image, = sess.run([end_point_v], feed_dict)
      image = image.tolist()
      image_npy[i, :] = image
      text = [token_to_id.get(token, unk_token_id) for token in text]
      text_npy.append(text)
      label_ids = [label_to_id[label] for label in labels]
      label_vec = np.zeros((num_label,), dtype=np.int32)
      label_vec[label_ids] = 1
      label = label_vec.tolist()
      label_npy[i, :] = label
      image_id = path.basename(file).split('.')[0]
      imgid_npy.append(image_id)

  imgid_npy = np.asarray(imgid_npy)
  filename_tmpl = 'yfcc10k_%s.valid.%s'
  np.save(path.join(precomputed_dir,
                    filename_tmpl % (flags.model_name, 'image')), image_npy)
  np.save(path.join(precomputed_dir,
                    filename_tmpl % (flags.model_name, 'label')), label_npy)
  np.save(path.join(precomputed_dir,
                    filename_tmpl % (flags.model_name, 'imgid')), imgid_npy)

  def padding(x):
    # right-pad every token-id row with zeros to the longest row; the
    # original first built a ragged np.array of x, which is unnecessary
    max_length = max(len(row) for row in x)
    x = np.array([row + [0] * (max_length - len(row)) for row in x])
    print(x.shape)
    return x

  text_npy = padding(text_npy)
  np.save(path.join(precomputed_dir,
                    filename_tmpl % (flags.model_name, 'text')), text_npy)
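# A quick usage sketch for the saved arrays. The model name below is an
# assumed example ('vgg_16' would match the 4096-dim fc feature above):
#
# image_npy = np.load(path.join(precomputed_dir,
#     'yfcc10k_vgg_16.valid.image.npy'))  # (valid_size, 4096) float32
# label_npy = np.load(path.join(precomputed_dir,
#     'yfcc10k_vgg_16.valid.label.npy'))  # (valid_size, 100) 0/1 int32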
DESC_INDEX = 4
LABEL_INDEX = -1
FIELD_SEPERATOR = '\t'
EXPECTED_NUM_FIELD = 6

MIN_RND_LABEL = 10
NUM_RND_LABEL = 250
MIN_RND_POST = MIN_RND_LABEL
NUM_RND_POST = 10000
TRAIN_DATA_RATIO = 0.80
SHUFFLE_SEED = 100

dataset = 'yfcc_rnd'
dataset_dir = config.yfcc_rnd_dir
utils.create_if_nonexist(dataset_dir)
raw_file = path.join(dataset_dir, '%s.raw' % dataset)
data_file = path.join(dataset_dir, '%s.data' % dataset)
train_file = path.join(dataset_dir, '%s.train' % dataset)
valid_file = path.join(dataset_dir, '%s.valid' % dataset)
label_file = path.join(dataset_dir, '%s.label' % dataset)
vocab_file = path.join(dataset_dir, '%s.vocab' % dataset)
image_data_dir = path.join(dataset_dir, 'ImageData')

################################################################
#
# create kdgan data
#
################################################################
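# The utils helpers used throughout these scripts are thin filesystem
# wrappers defined elsewhere in the repo. A plausible sketch, assuming the
# obvious semantics of their names:
#
# import os
# import shutil
#
# def create_if_nonexist(outdir):
#   if not os.path.exists(outdir):
#     os.makedirs(outdir)
#
# def skip_if_exist(filepath):
#   # True when the file is already there, so the caller can skip the work
#   return os.path.isfile(filepath)
#
# def delete_if_exist(outdir):
#   if os.path.exists(outdir):
#     shutil.rmtree(outdir)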
def main(_):
  for variable in tf.trainable_variables():
    num_params = 1
    for dim in variable.shape:
      num_params *= dim.value
    print('%-50s (%d params)' % (variable.name, num_params))

  dis_summary_op = tf.summary.merge([
      tf.summary.scalar(dis_t.learning_rate.name, dis_t.learning_rate),
      tf.summary.scalar(dis_t.gan_loss.name, dis_t.gan_loss),
  ])
  gen_summary_op = tf.summary.merge([
      tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate),
      tf.summary.scalar(gen_t.kdgan_loss.name, gen_t.kdgan_loss),
  ])
  tch_summary_op = tf.summary.merge([
      tf.summary.scalar(tch_t.learning_rate.name, tch_t.learning_rate),
      tf.summary.scalar(tch_t.kdgan_loss.name, tch_t.kdgan_loss),
  ])
  init_op = tf.global_variables_initializer()

  data_sources_t = utils.get_data_sources(flags, is_training=True)
  data_sources_v = utils.get_data_sources(flags, is_training=False)
  print('tn: #tfrecord=%d\nvd: #tfrecord=%d' %
        (len(data_sources_t), len(data_sources_v)))
  # independent shuffled pipelines for the dis, gen, and tch updates
  ts_list_d = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_d = utils.generate_batch(ts_list_d, flags.batch_size)
  user_bt_d, image_bt_d, text_bt_d, label_bt_d, file_bt_d = bt_list_d
  ts_list_g = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_g = utils.generate_batch(ts_list_g, flags.batch_size)
  user_bt_g, image_bt_g, text_bt_g, label_bt_g, file_bt_g = bt_list_g
  ts_list_t = utils.decode_tfrecord(flags, data_sources_t, shuffle=True)
  bt_list_t = utils.generate_batch(ts_list_t, flags.batch_size)
  user_bt_t, image_bt_t, text_bt_t, label_bt_t, file_bt_t = bt_list_t
  ts_list_v = utils.decode_tfrecord(flags, data_sources_v, shuffle=False)
  bt_list_v = utils.generate_batch(ts_list_v, config.valid_batch_size)

  figure_data = []
  best_hit_v = -np.inf
  start = time.time()
  with tf.Session() as sess:
    sess.run(init_op)
    dis_t.saver.restore(sess, flags.dis_model_ckpt)
    gen_t.saver.restore(sess, flags.gen_model_ckpt)
    tch_t.saver.restore(sess, flags.tch_model_ckpt)
    writer = tf.summary.FileWriter(config.logs_dir,
                                   graph=tf.get_default_graph())
    with slim.queues.QueueRunners(sess):
      gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
      tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)
      print('hit gen=%.4f tch=%.4f' % (gen_hit, tch_hit))

      batch_d, batch_g, batch_t = -1, -1, -1
      for epoch in range(flags.num_epoch):
        for dis_epoch in range(flags.num_dis_epoch):
          print('epoch %03d dis_epoch %03d' % (epoch, dis_epoch))
          for _ in range(num_batch_per_epoch):
            batch_d += 1
            image_d, text_d, label_dat_d = sess.run(
                [image_bt_d, text_bt_d, label_bt_d])
            feed_dict = {gen_t.image_ph: image_d}
            label_gen_d, = sess.run([gen_t.labels], feed_dict=feed_dict)
            feed_dict = {tch_t.text_ph: text_d, tch_t.image_ph: image_d}
            label_tch_d, = sess.run([tch_t.labels], feed_dict=feed_dict)
            sample_d, label_d = utils.kdgan_dis_sample(
                flags, label_dat_d, label_gen_d, label_tch_d)
            feed_dict = {
                dis_t.image_ph: image_d,
                dis_t.sample_ph: sample_d,
                dis_t.dis_label_ph: label_d,
            }
            _, summary_d = sess.run([dis_t.gan_update, dis_summary_op],
                                    feed_dict=feed_dict)
            writer.add_summary(summary_d, batch_d)

        for tch_epoch in range(flags.num_tch_epoch):
          print('epoch %03d tch_epoch %03d' % (epoch, tch_epoch))
          for _ in range(num_batch_per_epoch):
            batch_t += 1
            image_t, text_t, label_dat_t = sess.run(
                [image_bt_t, text_bt_t, label_bt_t])
            feed_dict = {tch_t.text_ph: text_t, tch_t.image_ph: image_t}
            label_tch_t, = sess.run([tch_t.labels], feed_dict=feed_dict)
            sample_t = utils.generate_label(flags, label_dat_t, label_tch_t)
            feed_dict = {
                dis_t.image_ph: image_t,
                dis_t.sample_ph: sample_t,
            }
            reward_t, = sess.run([dis_t.rewards], feed_dict=feed_dict)
            feed_dict = {gen_t.image_ph: image_t}
            label_gen_g = sess.run(gen_t.logits, feed_dict=feed_dict)
            feed_dict = {
                tch_t.text_ph: text_t,
                tch_t.image_ph: image_t,
                tch_t.sample_ph: sample_t,
                tch_t.reward_ph: reward_t,
                tch_t.hard_label_ph: label_dat_t,
                tch_t.soft_label_ph: label_gen_g,
            }
            _, summary_t, tch_kdgan_loss = sess.run(
                [tch_t.kdgan_update, tch_summary_op, tch_t.kdgan_loss],
                feed_dict=feed_dict)
            writer.add_summary(summary_t, batch_t)

        for gen_epoch in range(flags.num_gen_epoch):
          print('epoch %03d gen_epoch %03d' % (epoch, gen_epoch))
          for _ in range(num_batch_per_epoch):
            batch_g += 1
            image_g, text_g, label_dat_g = sess.run(
                [image_bt_g, text_bt_g, label_bt_g])
            feed_dict = {tch_t.text_ph: text_g, tch_t.image_ph: image_g}
            label_tch_g, = sess.run([tch_t.labels], feed_dict=feed_dict)
            feed_dict = {gen_t.image_ph: image_g}
            label_gen_g, = sess.run([gen_t.labels], feed_dict=feed_dict)
            sample_g = utils.generate_label(flags, label_dat_g, label_gen_g)
            feed_dict = {
                dis_t.image_ph: image_g,
                dis_t.sample_ph: sample_g,
            }
            reward_g, = sess.run([dis_t.rewards], feed_dict=feed_dict)
            feed_dict = {
                gen_t.image_ph: image_g,
                gen_t.hard_label_ph: label_dat_g,
                gen_t.soft_label_ph: label_tch_g,
                gen_t.sample_ph: sample_g,
                gen_t.reward_ph: reward_g,
            }
            _, summary_g = sess.run([gen_t.kdgan_update, gen_summary_op],
                                    feed_dict=feed_dict)
            writer.add_summary(summary_g, batch_g)

        # per-batch evaluation (every eval_interval batches) was disabled
        # in favor of evaluating both models once per epoch
        gen_hit = utils.evaluate_image(flags, sess, gen_v, bt_list_v)
        tch_hit = utils.evaluate_text(flags, sess, tch_v, bt_list_v)
        tot_time = time.time() - start
        print('#%03d curgen=%.4f curtch=%.4f %.0fs' %
              (epoch, gen_hit, tch_hit, tot_time))
        figure_data.append((epoch, gen_hit, tch_hit))
        if gen_hit <= best_hit_v:
          continue
        best_hit_v = gen_hit
        print('epoch %d: new best validation hit=%.4f, saving...' %
              (epoch + 1, best_hit_v))
        gen_t.saver.save(sess, flags.kdgan_model_ckpt, global_step=epoch + 1)
        print('finish saving')
  print('best hit=%.4f' % (best_hit_v))

  utils.create_if_nonexist(os.path.dirname(flags.kdgan_figure_data))
  fout = open(flags.kdgan_figure_data, 'w')
  for epoch, gen_hit, tch_hit in figure_data:
    fout.write('%d\t%.4f\t%.4f\n' % (epoch, gen_hit, tch_hit))
  fout.close()