def print_img(img, i):
    """Log one image as an HTML snippet together with current training stats.

    Args:
      img: image name relative to FLAGS.image_url_prefix, or an absolute
        ``http://`` / ``https://`` URL.
      i: position of the image; passed through to the ``img_html`` template.

    The remaining template fields are filled with melt's current epoch, step,
    train loss, eval loss and duration, plus ``gezi.now_time()``.
    """
    # Accept both plain and TLS URLs. Previously only "http://" was recognised,
    # so an https:// URL got the prefix prepended and became a broken link.
    if img.startswith(('http://', 'https://')):
        img_url = img
    else:
        img_url = FLAGS.image_url_prefix + img
    logging.info(
        img_html.format(img_url, i, img, melt.epoch(), melt.step(),
                        melt.train_loss(), melt.eval_loss(), melt.duration(),
                        gezi.now_time()))
def print_img(img, i):
    """Emit one HTML log line for image ``img`` shown at position ``i``.

    The image URL is resolved with ``get_img_url``. The other ``img_html``
    fields are melt's epoch, step, train loss, eval loss and duration, followed
    by ``gezi.now_time()``.
    """
    url = get_img_url(img)
    snapshot = (melt.epoch(), melt.step(), melt.train_loss(),
                melt.eval_loss(), melt.duration(), gezi.now_time())
    html = img_html.format(url, i, img, *snapshot)
    logging.info(html)
def evaluate_scores(predictor, random=False):
    """Compute image->text recall metrics over (a subset of) the eval images.

    Args:
      predictor: forwarded unchanged to ``predicts`` for every batch.
      random: if True, evaluate a sample drawn without replacement of
        ``min(FLAGS.num_metric_eval_examples, #images)`` images. Otherwise the
        first that many images are used, in order.

    Returns:
      ``(metric_values, metric_names)`` from the accumulated RecallMetrics.
    """
    timer = gezi.Timer('evaluate_scores')
    init()
    imgs, img_features = get_image_names_and_features()
    total = min(FLAGS.num_metric_eval_examples, len(imgs))
    batch_size = FLAGS.metric_eval_batch_size
    if random:
        sample = np.random.choice(len(imgs), total, replace=False)
        imgs, img_features = imgs[sample], img_features[sample]

    text_max_words = all_distinct_texts.shape[1]
    recall = gezi.rank_metrics.RecallMetrics()
    print('text_max_words:', text_max_words)

    lo = 0
    while lo < total:
        # The last batch is clipped to the number of requested examples.
        hi = min(lo + batch_size, total)
        print('predicts start:', lo, 'end:', hi, file=sys.stderr)
        predicts(imgs[lo:hi], img_features[lo:hi], predictor, recall)
        lo = hi

    values = recall.get_metrics()
    names = recall.get_names()
    tag = 'evaluate: epoch:{} step:{} train:{} eval:{}'.format(
        melt.epoch(), melt.step(), melt.train_loss(), melt.eval_loss())
    melt.logging_results(values, names, tag=tag)
    timer.print()
    return recall.get_metrics(), recall.get_names()
def tf_train_flow(train_once_fn,
                  model_dir='./model',
                  max_models_keep=1,
                  save_interval_seconds=600,
                  save_interval_steps=1000,
                  num_epochs=None,
                  num_steps=None,
                  save_model=True,
                  save_interval_epochs=1,
                  num_steps_per_epoch=0,
                  restore_from_latest=True,
                  metric_eval_fn=None,
                  init_fn=None,
                  restore_fn=None,
                  restore_scope=None,
                  save_all_scope=False,  # TODO save/load from restore scope only but save all
                  variables_to_restore=None,
                  variables_to_save=None,  # by default the same as variables_to_restore
                  sess=None):
    """Run a TF1 training loop with checkpoint reload/save and per-epoch eval bookkeeping.

    Same flow as tf_flow, but first tries to reload a model from ``model_dir``
    and saves checkpoints while training.

    Args:
      train_once_fn: called once per step as
        ``train_once_fn(sess, step, is_start=..., fixed_step=...)``. Returning
        True stops training early.
      model_dir: checkpoint directory. A checkpoint path is reduced to its
        directory via ``gezi.get_dir``.
      max_models_keep: ``max_to_keep`` of the periodic step saver.
      save_interval_seconds: converted to ``keep_checkpoint_every_n_hours`` for
        that saver.
      save_interval_steps: save a step checkpoint every this many steps.
      num_epochs: stop once ``step // num_steps_per_epoch`` reaches it. A
        negative value runs a single step.
      num_steps: stop after this many steps counted from the start step. A
        negative value, or one below the restored step, only re-saves the loaded
        model and exits.
      save_model: whether checkpoints are saved. Also passed to _get_model_path.
      save_interval_epochs: save an epoch checkpoint under ``<model_dir>/epoch``
        every this many epochs.
      num_steps_per_epoch: steps per epoch. 0 disables all epoch-based logic.
      restore_from_latest: if False, restore ``melt.recent_checkpoint(model_dir)``.
      metric_eval_fn: called with no arguments once training stops.
      init_fn: ``init_fn(sess)``, only called when no checkpoint is restored
        (e.g. to load pretrained weights for finetuning).
      restore_fn: ``restore_fn(sess)``, called just before the checkpoint restore.
      restore_scope: if set, restore/save only global variables under this scope.
      save_all_scope: if True, save all variables regardless of restore_scope.
      variables_to_restore: explicit restore list. Defaults to the restore_scope
        variables, else every variable named in the checkpoint.
      variables_to_save: explicit save list. Defaults to variables_to_restore.
      sess: session to use. Defaults to ``melt.get_session()``.

    Exits the process with ``sys.exit(0)`` when training finishes as requested.
    Otherwise the terminating ``tf.errors.OutOfRangeError`` is re-raised.
    """
    if sess is None:
        # TODO melt.get_session is a global session and may not be closed at the end
        sess = melt.get_session()
    logging.info('tf_train_flow start')
    print('max_models_keep:', max_models_keep, file=sys.stderr)
    print('save_interval_seconds:', save_interval_seconds, file=sys.stderr)

    # Useful when another model lives in another scope of the same graph: only
    # the variables of restore_scope are loaded/saved here.
    var_list = None if not restore_scope else tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
    if not variables_to_restore:
        variables_to_restore = var_list
    if not variables_to_save:
        variables_to_save = variables_to_restore
    if save_all_scope:
        variables_to_save = None

    if variables_to_restore is None:
        # Restore every variable present in the checkpoint. Saving covers all
        # variables (possibly more than the checkpoint holds) unless
        # variables_to_save was given.
        varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
        variables_to_restore = slim.get_variables_to_restore(
            include=varnames_in_checkpoint)

    loader = tf.train.Saver(var_list=variables_to_restore)
    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
        var_list=variables_to_save)
    epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
    best_epoch_saver = tf.train.Saver(var_list=variables_to_save)

    # Always initialize everything before restoring. If an assistant predictor
    # shares this graph, it must use another session; otherwise this would
    # re-initialize its variables.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    # pre_step is the step of the last saved model; -1 means train from scratch.
    pre_step = -1
    # fixed_pre_step keeps the epoch count correct if the batch size changed.
    fixed_pre_step = -1
    model_path = _get_model_path(model_dir, save_model)
    # In case a checkpoint path such as ./model/model.ckpt-1000 was passed -> ./model
    model_dir = gezi.get_dir(model_dir)
    if model_path is not None:
        if not restore_from_latest:
            print('using recent but not latest model', file=sys.stderr)
            model_path = melt.recent_checkpoint(model_dir)
        timer = gezi.Timer('Loading and training from existing model [%s]' % model_path)
        if restore_fn is not None:
            restore_fn(sess)
        loader.restore(sess, model_path)
        timer.print()
        pre_step = melt.get_model_step(model_path)
        pre_epoch = melt.get_model_epoch(model_path)
        fixed_pre_step = pre_step
        # E.g. trained with batch size 32, now resumed with batch size 64.
        # Skipped when num_steps_per_epoch is 0 to avoid a ZeroDivisionError.
        if pre_epoch is not None and num_steps_per_epoch:
            if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
                fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
                logging.info('Warning, epoch is diff with pre_step / num_steps_per_epoch:{}, pre_epoch:{},maybe you change batch size and we will adjust to set pre_step as {}'
                             .format(pre_step / num_steps_per_epoch, pre_epoch, fixed_pre_step))
    else:
        print('Train all start step 0', file=sys.stderr)
        # For finetuning: initialize from another model on the first run only.
        # Later runs restore everything from this model's own checkpoint.
        if init_fn is not None:
            init_fn(sess)

    if save_interval_epochs and num_steps_per_epoch:
        epoch_dir = os.path.join(model_dir, 'epoch')
        gezi.try_mkdir(epoch_dir)

    coord = tf.train.Coordinator()
    # Queue-runner threads are never joined; see the FIXME in the handler below.
    tf.train.start_queue_runners(sess=sess, coord=coord)

    checkpoint_path = os.path.join(model_dir, 'model.ckpt')
    tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
    only_one_step = False
    try:
        step = start = pre_step + 1
        fixed_step = fixed_pre_step + 1
        # Hack: a negative num_steps (or one below the restored step) means
        # "load, re-save once and exit". Only compare when num_steps is set:
        # on Python 2 `None < 0` is True, so the default would always exit.
        if num_steps is not None and (num_steps < 0 or (num_steps and num_steps < step)):
            print('just load and resave then exit', file=sys.stderr)
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                       global_step=step)
            sess.close()
            sys.exit(0)

        # A negative num_epochs means run a single step (same None caveat).
        if num_epochs is not None and num_epochs < 0:
            only_one_step = True
            print('just run one step', file=sys.stderr)

        # Per-epoch eval-loss bookkeeping. Actual early stopping is disabled;
        # only a warning and a one-off snapshot are produced.
        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        num_allowed_bad_epochs = 4  # warn after 5 non-decreasing eval-loss epochs
        while not coord.should_stop():
            stop = train_once_fn(sess, step, is_start=(step == start),
                                 fixed_step=fixed_step)
            if only_one_step:
                stop = True
            # step 0 is excluded by the truthiness test.
            if save_model and step:
                if step % save_interval_steps == 0:
                    timer = gezi.Timer('save model step %d to %s' % (step, checkpoint_path))
                    saver.save(sess,
                               _get_checkpoint_path(checkpoint_path, fixed_step, num_steps_per_epoch),
                               global_step=step)
                    timer.print()
                # Epoch boundaries use fixed_step so they stay right after a
                # batch-size change.
                if save_interval_epochs and num_steps_per_epoch and fixed_step % num_steps_per_epoch == 0:
                    epoch = fixed_step // num_steps_per_epoch
                    eval_loss = melt.eval_loss()
                    if eval_loss:
                        # e.g. "['eval_loss:3.2','eal_accuracy:4.3']" -> 3.2
                        eval_loss = float(
                            eval_loss.strip('[]').split(',')[0].strip("'").split(':')[-1])
                        best_loss_file = os.path.join(epoch_dir, 'best_eval_loss.txt')
                        if os.path.exists(best_loss_file):
                            with open(best_loss_file) as f:
                                best_epoch_eval_loss = float(f.readline().split()[-1].strip())
                        if eval_loss < best_epoch_eval_loss:
                            best_epoch_eval_loss = eval_loss
                            logging.info('Now best eval loss is epoch %d eval_loss:%f' % (epoch, eval_loss))
                            with open(best_loss_file, 'w') as f:
                                f.write('%d %d %f\n' % (epoch, step, best_epoch_eval_loss))
                            best_epoch_saver.save(sess, os.path.join(epoch_dir, 'model.ckpt-best'))
                        with open(os.path.join(epoch_dir, 'eval_loss.txt'), 'a') as f:
                            f.write('%d %d %f\n' % (epoch, step, eval_loss))
                        if eval_loss >= pre_epoch_eval_loss:
                            num_bad_epochs += 1
                            if num_bad_epochs > num_allowed_bad_epochs:
                                logging.warning('Evaluate loss not decrease for last %d epochs' % (num_allowed_bad_epochs + 1))
                                noimprove_path = os.path.join(epoch_dir, 'model.ckpt-noimprove')
                                # V2 checkpoints are written as <prefix>.index/.data-*,
                                # so also check the .index file. Otherwise the
                                # snapshot is re-saved on every bad epoch.
                                if not (os.path.exists(noimprove_path)
                                        or os.path.exists(noimprove_path + '.index')):
                                    best_epoch_saver.save(sess, noimprove_path)
                                # Early stopping (stop = True) is deliberately disabled here.
                        else:
                            num_bad_epochs = 0
                        pre_epoch_eval_loss = eval_loss
                    if step % (num_steps_per_epoch * save_interval_epochs) == 0:
                        epoch_saver.save(sess,
                                         os.path.join(epoch_dir, 'model.ckpt-%d' % epoch),
                                         global_step=step)
            if stop is True:
                print('Early stop running %d stpes' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d stpes' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None, 'Reached max num steps')
            max_num_epochs = num_epochs
            if max_num_epochs and num_steps_per_epoch and step // num_steps_per_epoch >= max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None, 'Reached max num epochs of %d' % max_num_epochs)
            step += 1
            fixed_step += 1
    except tf.errors.OutOfRangeError:
        if step != start and save_model and step % save_interval_steps != 0:
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                       global_step=step)
        if only_one_step:
            print('Done one step', file=sys.stderr)
            sys.exit(0)
        if metric_eval_fn is not None:
            metric_eval_fn()
        # num_steps_per_epoch may be 0. Avoid a ZeroDivisionError here that
        # would mask the exception being handled.
        epochs_done = step / num_steps_per_epoch if num_steps_per_epoch else float('nan')
        if (num_epochs and num_steps_per_epoch and epochs_done >= num_epochs) or (
                num_steps and (step + 1) == start + num_steps):
            print('Done training for %.3f epochs, %d steps.' % (epochs_done, step + 1),
                  file=sys.stderr)
            # FIXME coord.join does not seem to work:
            # RuntimeError: Coordinator stopped with threads still running
            sys.exit(0)
        else:
            print('Should not stop, but stopped at epoch: %.3f' % epochs_done,
                  file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            # Bare raise keeps the original traceback (`raise e` drops it on Py2).
            raise
def tf_train_flow(train_once_fn,
                  model_dir='./model',
                  max_models_keep=1,
                  save_interval_seconds=600,
                  save_interval_steps=1000,
                  num_epochs=None,
                  num_steps=None,
                  save_model=True,
                  save_interval_epochs=1,
                  num_steps_per_epoch=0,
                  restore_from_latest=True,
                  metric_eval_fn=None,
                  init_fn=None,
                  sess=None):
    """Run a TF1 training loop that reloads the last checkpoint and saves while training.

    Same flow as tf_flow, but first tries to reload a model and saves
    checkpoints while training.

    Args:
      train_once_fn: called once per step as ``train_once_fn(sess, step, is_start=...)``.
        Returning True stops training early.
      model_dir: checkpoint directory. A checkpoint path is reduced to its
        directory via ``gezi.get_dir``.
      max_models_keep / save_interval_seconds: ``max_to_keep`` and
        ``keep_checkpoint_every_n_hours`` (seconds / 3600) of the step saver.
      save_interval_steps: save a step checkpoint every this many steps.
      num_epochs: stop once ``step // num_steps_per_epoch`` reaches it.
      num_steps: stop after this many steps counted from the start step. If it is
        below the restored step, only re-save the loaded model and exit.
      save_model: whether checkpoints are saved. Also passed to _get_model_path.
      save_interval_epochs: save an epoch checkpoint under ``<model_dir>/epoch``
        every this many epochs.
      num_steps_per_epoch: steps per epoch. 0 disables all epoch-based logic.
      restore_from_latest: if False, restore ``melt.recent_checkpoint(model_dir)``.
      metric_eval_fn: called with no arguments once training stops.
      init_fn: ``init_fn(sess)``, only called when no checkpoint is restored.
      sess: session to use. Defaults to ``melt.get_session()``.

    Exits the process with ``sys.exit(0)`` when training finishes as requested.
    Otherwise the terminating ``tf.errors.OutOfRangeError`` is re-raised.
    """
    if sess is None:
        # TODO melt.get_session is a global session and may not be closed at the end
        sess = melt.get_session()
    logging.info('tf_train_flow start')
    print('max_models_keep:', max_models_keep)
    print('save_interval_seconds:', save_interval_seconds)

    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0)
    epoch_saver = tf.train.Saver()
    best_epoch_saver = tf.train.Saver()

    # pre_step is the step of the last saved model; -1 means train from scratch.
    pre_step = -1
    model_path = _get_model_path(model_dir, save_model)
    # In case a checkpoint path such as ./model/model.ckpt-1000 was passed -> ./model
    model_dir = gezi.get_dir(model_dir)
    if model_path is not None:
        if not restore_from_latest:
            print('using recent but not latest model', file=sys.stderr)
            model_path = melt.recent_checkpoint(model_dir)
        model_name = os.path.basename(model_path)
        timer = gezi.Timer('Loading and training from existing model [%s]' % model_path)
        saver.restore(sess, model_path)
        timer.print()
        pre_step = melt.get_model_step(model_path)
        if 'epoch' in model_name:
            # Epoch checkpoints carry an epoch number rather than a step.
            pre_step *= num_steps_per_epoch
        # Without this, resuming at a non-zero epoch fails with
        # "Attempting to use uninitialized value input/input_producer/limit_epochs/epochs".
        sess.run(tf.local_variables_initializer())
    else:
        print('Train all start step 0', file=sys.stderr)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        if init_fn is not None:
            init_fn(sess)

    if save_interval_epochs and num_steps_per_epoch:
        epoch_dir = os.path.join(model_dir, 'epoch')
        gezi.try_mkdir(epoch_dir)

    coord = tf.train.Coordinator()
    # Queue-runner threads are never joined; see the FIXME in the handler below.
    tf.train.start_queue_runners(sess=sess, coord=coord)

    checkpoint_path = os.path.join(model_dir, 'model.ckpt')
    tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
    try:
        step = start = pre_step + 1
        # Hack: just re-save the loaded model once and exit.
        if num_steps and num_steps < step:
            print('just load and resave then exit', file=sys.stderr)
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                       global_step=step)
            sess.close()
            sys.exit(0)

        # Per-epoch eval-loss bookkeeping. Actual early stopping is disabled;
        # only a warning and a one-off snapshot are produced.
        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        num_allowed_bad_epochs = 4  # warn after 5 non-decreasing eval-loss epochs
        while not coord.should_stop():
            stop = train_once_fn(sess, step, is_start=(step == start))
            # step 0 is excluded by the truthiness test.
            if save_model and step:
                if step % save_interval_steps == 0:
                    timer = gezi.Timer('save model step %d to %s' % (step, checkpoint_path))
                    saver.save(sess,
                               _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                               global_step=step)
                    timer.print()
                if save_interval_epochs and num_steps_per_epoch and step % num_steps_per_epoch == 0:
                    epoch = step // num_steps_per_epoch
                    eval_loss = melt.eval_loss()
                    if eval_loss:
                        # e.g. "['eval_loss:3.2','eal_accuracy:4.3']" -> 3.2
                        eval_loss = float(
                            eval_loss.strip('[]').split(',')[0].strip("'").split(':')[-1])
                        best_loss_file = os.path.join(epoch_dir, 'best_eval_loss.txt')
                        if os.path.exists(best_loss_file):
                            with open(best_loss_file) as f:
                                best_epoch_eval_loss = float(f.readline().split()[-1].strip())
                        if eval_loss < best_epoch_eval_loss:
                            best_epoch_eval_loss = eval_loss
                            logging.info('Now best eval loss is epoch %d eval_loss:%f' % (epoch, eval_loss))
                            with open(best_loss_file, 'w') as f:
                                f.write('%d %d %f\n' % (epoch, step, best_epoch_eval_loss))
                            # NOTE: 'cpkt' (sic) file names are kept so existing
                            # checkpoints stay discoverable.
                            best_epoch_saver.save(sess, os.path.join(epoch_dir, 'model.cpkt-best'))
                        with open(os.path.join(epoch_dir, 'eval_loss.txt'), 'a') as f:
                            f.write('%d %d %f\n' % (epoch, step, eval_loss))
                        if eval_loss >= pre_epoch_eval_loss:
                            num_bad_epochs += 1
                            if num_bad_epochs > num_allowed_bad_epochs:
                                logging.warning('Evaluate loss not decrease for last %d epochs' % (num_allowed_bad_epochs + 1))
                                noimprove_path = os.path.join(epoch_dir, 'model.cpkt-noimprove')
                                # V2 checkpoints are written as <prefix>.index/.data-*,
                                # so also check the .index file. Otherwise the
                                # snapshot is re-saved on every bad epoch.
                                if not (os.path.exists(noimprove_path)
                                        or os.path.exists(noimprove_path + '.index')):
                                    best_epoch_saver.save(sess, noimprove_path)
                                # Early stopping (stop = True) is deliberately disabled here.
                        else:
                            num_bad_epochs = 0
                        pre_epoch_eval_loss = eval_loss
                    if step % (num_steps_per_epoch * save_interval_epochs) == 0:
                        epoch_saver.save(sess,
                                         os.path.join(epoch_dir, 'model.cpkt-%d' % epoch),
                                         global_step=step)
            if stop is True:
                print('Early stop running %d stpes' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d stpes' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None, 'Reached max num steps')
            max_num_epochs = num_epochs
            # Guard on max_num_epochs: with the None default, Python 2 evaluates
            # `int >= None` as True (stops at the first check); Python 3 raises.
            if max_num_epochs and num_steps_per_epoch and step // num_steps_per_epoch >= max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None, 'Reached max num epochs of %d' % max_num_epochs)
            step += 1
    except tf.errors.OutOfRangeError:
        if step != start and save_model and step % save_interval_steps != 0:
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                       global_step=step)
        if metric_eval_fn is not None:
            metric_eval_fn()
        # num_steps_per_epoch may be 0. Avoid a ZeroDivisionError here that
        # would mask the exception being handled.
        epochs_done = step / num_steps_per_epoch if num_steps_per_epoch else float('nan')
        if (num_epochs and num_steps_per_epoch and epochs_done >= num_epochs) or (
                num_steps and (step + 1) == start + num_steps):
            print('Done training for %.3f epochs, %d steps.' % (epochs_done, step + 1),
                  file=sys.stderr)
            # FIXME coord.join does not seem to work:
            # RuntimeError: Coordinator stopped with threads still running
            sys.exit(0)
        else:
            print('Should not stop, but stopped at epoch: %.3f' % epochs_done,
                  file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            # Bare raise keeps the original traceback (`raise e` drops it on Py2).
            raise
def evaluate_scores(predictor, random=False):
    """Compute recall metrics for image->text and/or text->image retrieval.

    FLAGS.eval_img2text and FLAGS.eval_text2img select the directions. Each
    direction evaluates ``min(FLAGS.num_metric_eval_examples, #candidates)``
    examples in batches of FLAGS.metric_eval_batch_size: a random sample
    without replacement when ``random`` is True, otherwise the first ones in
    order.

    Returns:
      ``(metric_values, metric_names)``. Text->image names are prefixed with
      ``'t2i'`` and follow the image->text ones when both run. ``([], [])`` if
      both directions are disabled.
    """
    timer = gezi.Timer('evaluate_scores')
    init()
    # Both directions need the batch size. Previously it was only defined in
    # the img2text branch, so text2img-only evaluation raised NameError.
    step = FLAGS.metric_eval_batch_size

    if FLAGS.eval_img2text:
        imgs, img_features = get_image_names_and_features()
        num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs))
        if random:
            index = np.random.choice(len(imgs), num_metric_eval_examples, replace=False)
            imgs = imgs[index]
            img_features = img_features[index]
        rank_metrics = gezi.rank_metrics.RecallMetrics()
        start = 0
        while start < num_metric_eval_examples:
            end = min(start + step, num_metric_eval_examples)
            print('predicts image start:', start, 'end:', end, file=sys.stderr)
            predicts(imgs[start:end], img_features[start:end], predictor, rank_metrics)
            start = end
        melt.logging_results(
            rank_metrics.get_metrics(),
            rank_metrics.get_names(),
            tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
                melt.epoch(), melt.step(), melt.train_loss(), melt.eval_loss()))

    if FLAGS.eval_text2img:
        num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(all_distinct_texts))
        if random:
            index = np.random.choice(len(all_distinct_texts), num_metric_eval_examples, replace=False)
            text_strs = all_distinct_text_strs[index]
            texts = all_distinct_texts[index]
        else:
            # Previously unset on this path (NameError). The loop below only
            # reads the first num_metric_eval_examples entries.
            text_strs = all_distinct_text_strs
            texts = all_distinct_texts
        rank_metrics2 = gezi.rank_metrics.RecallMetrics()
        start = 0
        while start < num_metric_eval_examples:
            end = min(start + step, num_metric_eval_examples)
            print('predicts start:', start, 'end:', end, file=sys.stderr)
            predicts_txt2im(text_strs[start:end], texts[start:end], predictor, rank_metrics2)
            start = end
        melt.logging_results(rank_metrics2.get_metrics(),
                             ['t2i' + x for x in rank_metrics2.get_names()],
                             tag='text2img')

    timer.print()
    if FLAGS.eval_img2text and FLAGS.eval_text2img:
        return (rank_metrics.get_metrics() + rank_metrics2.get_metrics(),
                rank_metrics.get_names() + ['t2i' + x for x in rank_metrics2.get_names()])
    elif FLAGS.eval_img2text:
        return rank_metrics.get_metrics(), rank_metrics.get_names()
    elif FLAGS.eval_text2img:
        return rank_metrics2.get_metrics(), rank_metrics2.get_names()
    # Neither direction enabled: nothing was evaluated.
    return [], []
def evaluate_translation(predictor, random=False, index=None):
    """Score generated captions with BLEU-1..4, METEOR, ROUGE-L and CIDEr.

    Args:
      predictor: forwarded to ``translation_predicts`` for every batch.
      random: evaluate a random sample of images (ignored when all images are
        used anyway).
      index: optional precomputed sample indices, used when ``random`` is True.

    FLAGS.num_metric_eval_examples <= 0 means "use every image". If
    FLAGS.eval_result_dir is set, per-example scores are written to
    ``<eval_result_dir>/<metric>.txt``.

    Returns:
      ``(scores, names)`` for bleu_4, meteor, rouge_l, cider and their
      average, with names prefixed by ``'trans_'``.
    """
    global tokenizer
    timer = gezi.Timer('evaluate_translation')
    refs = prepare_refs()
    imgs, img_features = get_image_names_and_features()
    num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs))
    if num_metric_eval_examples <= 0:
        num_metric_eval_examples = len(imgs)
    if num_metric_eval_examples == len(imgs):
        # Sampling every image only permutes them.
        random = False
    step = FLAGS.metric_eval_batch_size
    if random:
        if index is None:
            index = np.random.choice(len(imgs), num_metric_eval_examples, replace=False)
        imgs = imgs[index]
        img_features = img_features[index]
    else:
        img_features = img_features[:num_metric_eval_examples]

    results = {}
    start = 0
    while start < num_metric_eval_examples:
        end = min(start + step, num_metric_eval_examples)
        print('predicts image start:', start, 'end:', end, file=sys.stderr, end='\r')
        translation_predicts(imgs[start:end], img_features[start:end], predictor, results)
        start = end

    scorers = [
        (Bleu(4), ["bleu_1", "bleu_2", "bleu_3", "bleu_4"]),
        (Meteor(), "meteor"),
        (Rouge(), "rouge_l"),
        (Cider(), "cider"),
    ]

    score_list = []
    metric_list = []

    # Restrict the references to predicted keys so both dicts share .keys().
    selected_refs = {}
    selected_results = {}
    for key in results:
        selected_refs[key] = refs[key]
        selected_results[key] = results[key]
        assert len(selected_results[key]) == 1, selected_results[key]
    assert selected_results.keys() == selected_refs.keys(), '%d %d' % (
        len(selected_results.keys()), len(selected_refs.keys()))

    if FLAGS.eval_translation_reseg:
        print('tokenization...', file=sys.stderr)
        if tokenizer is None:
            tokenizer = PTBTokenizer()
        selected_refs = tokenizer.tokenize(selected_refs)
        selected_results = tokenizer.tokenize(selected_results)

    # Materialize the keys once. Indexing .keys()/.items() only works on
    # Python 2, and rebuilding the list for every row was quadratic.
    # NOTE(review): per-example scores are assumed to follow this key order;
    # confirm against the scorers' iteration order.
    result_keys = list(selected_results.keys())
    if result_keys:
        # Log one prediction next to its own references.
        first_key = result_keys[0]
        logging.info('predict&label:{}{}{}'.format(
            '|'.join(selected_results[first_key]), '---', '|'.join(selected_refs[first_key])))

    def write_per_example_scores(name, per_example_scores):
        # One tab-separated line per example: key, prediction, references
        # joined by \x01, score.
        with open(os.path.join(FLAGS.eval_result_dir, name + '.txt'), 'w') as out:
            for key, sc in zip(result_keys, per_example_scores):
                print(key, selected_results[key], '\x01'.join(selected_refs[key]), sc,
                      sep='\t', file=out)

    for scorer, method in scorers:
        print('computing %s score...' % (scorer.method()), file=sys.stderr)
        score, scores = scorer.compute_score(selected_refs, selected_results)
        if isinstance(method, list):
            # Multi-metric scorer (BLEU-n): one corpus score and one per-example
            # list per metric name.
            for sc, scs, m in zip(score, scores, method):
                score_list.append(sc)
                metric_list.append(m)
                if FLAGS.eval_result_dir:
                    write_per_example_scores(m, scs)
        else:
            score_list.append(score)
            metric_list.append(method)
            if FLAGS.eval_result_dir:
                # Previously used the stale loop variable `m` (bleu_4) here,
                # overwriting bleu_4.txt with meteor/rouge/cider results.
                write_per_example_scores(method, scores)

    # Exclude "bleu_1", "bleu_2", "bleu_3".
    score_list, metric_list = score_list[3:], metric_list[3:]
    assert len(score_list) == 4
    avg_score = sum(score_list) / len(score_list)
    score_list.append(avg_score)
    metric_list.append('avg')
    metric_list = ['trans_' + x for x in metric_list]
    melt.logging_results(
        score_list,
        metric_list,
        tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
            melt.epoch(), melt.step(), melt.train_loss(), melt.eval_loss()))
    timer.print()
    return score_list, metric_list
def evaluate_scores(predictor, random=False, index=None, exact_predictor=None, exact_ratio=1.):
    """Rank-metric evaluation (by default recall@1,2,5,10,50).

    FLAGS.eval_img2text and FLAGS.eval_text2img select the directions. In each
    direction, FLAGS.num_metric_eval_examples caps the number of evaluated
    examples, and a value <= 0 means "all". Batches have
    FLAGS.metric_eval_batch_size examples.

    Args:
      predictor: forwarded to ``predicts`` / ``predicts_txt2im``.
      random: evaluate a random sample (without replacement). Sampling is
        skipped, per direction, when every candidate is evaluated anyway.
      index: optional precomputed image sample indices (image->text only).
      exact_predictor: forwarded to ``predicts`` / ``predicts_txt2im``.
      exact_ratio: forwarded to ``predicts``.

    Returns:
      ``(metric_values, metric_names)``. Text->image names are prefixed with
      ``'t2i'`` and follow the image->text ones when both run. ``([], [])`` if
      both directions are disabled.
    """
    timer = gezi.Timer('evaluate_scores')
    init()
    # Both directions need the batch size. Previously it was only defined in
    # the img2text branch, so text2img-only evaluation raised NameError.
    step = FLAGS.metric_eval_batch_size

    if FLAGS.eval_img2text:
        imgs, img_features = get_image_names_and_features()
        num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs))
        if num_metric_eval_examples <= 0:
            num_metric_eval_examples = len(imgs)
        # Decide per direction: turning `random` off here used to disable
        # sampling for the text2img direction as well.
        img_random = random and num_metric_eval_examples != len(imgs)
        if img_random:
            if index is None:
                index = np.random.choice(len(imgs), num_metric_eval_examples, replace=False)
            imgs = imgs[index]
            img_features = img_features[index]
        else:
            img_features = img_features[:num_metric_eval_examples]
        rank_metrics = gezi.rank_metrics.RecallMetrics()
        start = 0
        while start < num_metric_eval_examples:
            end = min(start + step, num_metric_eval_examples)
            print('predicts image start:', start, 'end:', end, file=sys.stderr, end='\r')
            predicts(imgs[start:end], img_features[start:end], predictor, rank_metrics,
                     exact_predictor=exact_predictor, exact_ratio=exact_ratio)
            start = end
        melt.logging_results(
            rank_metrics.get_metrics(),
            rank_metrics.get_names(),
            tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
                melt.epoch(), melt.step(), melt.train_loss(), melt.eval_loss()))

    if FLAGS.eval_text2img:
        num_texts = len(all_distinct_texts)
        num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, num_texts)
        # Same convention as the image branch. Previously a negative value
        # crashed in np.random.choice, and 0 evaluated nothing.
        if num_metric_eval_examples <= 0:
            num_metric_eval_examples = num_texts
        if random and num_metric_eval_examples != num_texts:
            text_index = np.random.choice(num_texts, num_metric_eval_examples, replace=False)
            text_strs = all_distinct_text_strs[text_index]
            texts = all_distinct_texts[text_index]
        else:
            text_strs = all_distinct_text_strs
            texts = all_distinct_texts
        rank_metrics2 = gezi.rank_metrics.RecallMetrics()
        start = 0
        while start < num_metric_eval_examples:
            end = min(start + step, num_metric_eval_examples)
            print('predicts start:', start, 'end:', end, file=sys.stderr, end='\r')
            predicts_txt2im(text_strs[start:end], texts[start:end], predictor, rank_metrics2,
                            exact_predictor=exact_predictor)
            start = end
        melt.logging_results(
            rank_metrics2.get_metrics(),
            ['t2i' + x for x in rank_metrics2.get_names()],
            tag='text2img')

    timer.print()
    if FLAGS.eval_img2text and FLAGS.eval_text2img:
        return (rank_metrics.get_metrics() + rank_metrics2.get_metrics(),
                rank_metrics.get_names() + ['t2i' + x for x in rank_metrics2.get_names()])
    elif FLAGS.eval_img2text:
        return rank_metrics.get_metrics(), rank_metrics.get_names()
    elif FLAGS.eval_text2img:
        return rank_metrics2.get_metrics(), rank_metrics2.get_names()
    # Neither direction enabled: nothing was evaluated.
    return [], []