def tf_test_flow(test_once, model_dir='./model', model_name=None,
                 num_epochs=1, num_steps=0, sess=None):
    """Restore a model and call `test_once` repeatedly over queue-based input.

    Basic flow for tf records that allows most freedom for usage; if the input
    is not fed from tfrecords/queue runners there is no need for this flow.

    Args:
        test_once: callable taking (sess, step); called once per step.
        model_dir: either a directory like ./model (passed to melt.restore
            together with model_name, which picks the model inside it) or a
            concrete checkpoint path like ./model/model.0.ckpt, in which case
            its parent directory is used for the summary writer.
        model_name: forwarded to melt.restore.
        num_epochs: only used in the final log message.
        num_steps: when non-zero, stop after this many calls to test_once.
        sess: session to use; a new tf.InteractiveSession is created when None.
    """
    if sess is None:
        sess = tf.InteractiveSession()

    melt.restore(sess, model_dir, model_name)

    # A concrete checkpoint path was given: write events next to it.
    if not os.path.isdir(model_dir):
        model_dir = os.path.dirname(model_dir)
    # Nothing is added to this writer below; constructing it with sess.graph
    # only records the graph.  tf.train.SummaryWriter was removed in TF 1.0,
    # tf.summary.FileWriter is its replacement.
    summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    step = 0
    try:
        while not coord.should_stop():
            test_once(sess, step)
            step += 1
            if num_steps and step == num_steps:
                # Reuse the end-of-input path so both endings log identically.
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
    except tf.errors.OutOfRangeError:
        print('Done testing for %d epochs, %d steps.' % (num_epochs, step))
    finally:
        try:
            # When done, ask the threads to stop and wait for them to finish.
            coord.request_stop()
            coord.join(threads)
        finally:
            summary_writer.close()
def tf_flow(process_once, model_dir=None, num_steps=None, sess=None):
    """Run `process_once` step by step until it asks to stop or input ends.

    Basic flow for tf records that allows most freedom for usage; if the input
    is not fed from tfrecords/datasets there is no need for this flow.

    Args:
        process_once: callable taking (sess, step).  Returning exactly `True`
            (identity check, not truthiness) stops the loop early.  Raising
            tf.errors.OutOfRangeError (e.g. an exhausted input pipeline) also
            ends the loop normally.
        model_dir: if empty/None, all global and local variables are
            initialized; otherwise the model is restored with melt.restore.
        num_steps: when truthy, stop after this many completed steps.
        sess: session to use; a new tf.InteractiveSession is created when None.
            NOTE: the session is closed when the flow ends (also on error),
            even if it was supplied by the caller.

    Returns:
        The step counter at exit; the call that requested an early stop is
        not counted.
    """
    if sess is None:
        sess = tf.InteractiveSession()

    if not model_dir:
        # Fresh run: nothing to restore, so initialize every variable.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
    else:
        melt.restore(sess, model_dir)

    step = 0
    try:
        try:
            while True:
                stop = process_once(sess, step)
                if stop is True:
                    print('Early stop running %d steps' % step)
                    break
                step += 1
                if num_steps and step == num_steps:
                    break
        except tf.errors.OutOfRangeError:
            # The input is exhausted: a normal end of the flow.
            pass
        print('Done training for %d steps.' % step)
    finally:
        # Close in `finally` so an unexpected error in process_once does not
        # leak the session.
        sess.close()
    return step
def test_flow(ops, names=None, gen_feed_dict=None, deal_results=None,
              model_dir='./model', model_name=None, num_epochs=1,
              num_interval_steps=100, eval_times=0, print_avg_loss=True,
              sess=None):
    """Restore a model and repeatedly evaluate `ops`, reporting averaged results.

    @TODO improve list result print

    Args:
        ops: eval op, or list/tuple of ops, passed to `sess.run` every step.
        names: optional eval names, one per result; used when printing and as
            the suffix of the `metric<name>` summary tags (the result index is
            used when `names` is None).
        gen_feed_dict: optional zero-argument callable returning the feed dict
            for each step; `{}` is used when None.
        deal_results: optional callback invoked with each step's results (a
            non-list/tuple result is first wrapped in a one-element list).
        model_dir: either a directory like ./model (passed to melt.restore
            together with model_name, which picks the model inside it) or a
            concrete checkpoint path like ./model/model.0.ckpt, in which case
            its parent directory is used for summaries.
        model_name: forwarded to melt.restore.
        num_epochs: only used in the final log message.
            @TODO num_epochs should be 1, but now has problem of loading model
            if set @FIXME, so now 0
        num_interval_steps: every this many steps the running average is
            printed and written as a summary (only when print_avg_loss).
        eval_times: when > 0, stop right after the interval report whose
            0-based index equals eval_times (i.e. eval_times + 1 reports);
            0 means run until the input is exhausted or the coordinator stops.
        print_avg_loss: when True, results are reduced with
            gezi.get_singles, averaged, and periodically reported.
        sess: session to use; a new tf.InteractiveSession is created when None.
    """
    if sess is None:
        sess = tf.InteractiveSession()

    melt.restore(sess, model_dir, model_name)

    # A concrete checkpoint path was given: write summaries next to it.
    if not os.path.isdir(model_dir):
        model_dir = os.path.dirname(model_dir)
    # tf.merge_all_summaries / tf.train.SummaryWriter were removed in TF 1.0;
    # use their tf.summary replacements.
    summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        step = 0
        eval_step = 0
        avg_eval = AvgScore()
        total_avg_eval = AvgScore()
        try:
            while not coord.should_stop():
                feed_dict = {} if gen_feed_dict is None else gen_feed_dict()
                results = sess.run(ops, feed_dict=feed_dict)
                if not isinstance(results, (list, tuple)):
                    results = [results]
                if deal_results is not None:
                    # @TODO may need to pass summary_writer and step
                    # (use **kwargs?)
                    deal_results(results)

                if print_avg_loss:
                    results = gezi.get_singles(results)
                    avg_eval.add(results)
                    total_avg_eval.add(results)
                    if step % num_interval_steps == 0:
                        average_eval = avg_eval.avg_score()
                        print('{}: average evals = {}'.format(
                            gezi.now_time(),
                            melt.value_name_list_str(average_eval, names)),
                            'step:', step)
                        summary = tf.Summary()
                        # merge_all returns None when the graph has no summary
                        # ops, and sess.run(None) raises TypeError.
                        if summary_op is not None:
                            summary.ParseFromString(
                                sess.run(summary_op, feed_dict=feed_dict))
                        for i in range(len(results)):
                            name = i if names is None else names[i]
                            summary.value.add(tag='metric{}'.format(name),
                                              simple_value=average_eval[i])
                        summary_writer.add_summary(summary, step)
                        if eval_step and eval_step == eval_times:
                            break
                        eval_step += 1
                step += 1
        except tf.errors.OutOfRangeError:
            # Input queues exhausted: a normal end of testing.
            pass
        print('Done testing for {} epochs, {} steps. AverageEvals:{}'.format(
            num_epochs, step, gezi.pretty_floats(total_avg_eval.avg_score())))
    finally:
        try:
            # When done, ask the threads to stop and wait for them to finish.
            coord.request_stop()
            coord.join(threads)
        finally:
            summary_writer.close()