def train_once(sess,
               step,
               ops,
               names=None,
               gen_feed_dict_fn=None,
               deal_results_fn=melt.print_results,
               interval_steps=100,
               eval_ops=None,
               eval_names=None,
               gen_eval_feed_dict_fn=None,
               deal_eval_results_fn=melt.print_results,
               eval_interval_steps=100,
               print_time=True,
               print_avg_loss=True,
               model_dir=None,
               log_dir=None,
               is_start=False,
               num_steps_per_epoch=None,
               metric_eval_fn=None,
               metric_eval_interval_steps=0,
               summary_excls=None,
               fixed_step=None):  #fixed_step is for epoch only, in case you change batch size
  timer = gezi.Timer()
  if print_time:
    if not hasattr(train_once, 'timer'):
      train_once.timer = Timer()
      train_once.eval_timer = Timer()
      train_once.metric_eval_timer = Timer()

  melt.set_global('step', step)
  epoch = (fixed_step or step) / num_steps_per_epoch if num_steps_per_epoch else -1
  epoch_str = 'epoch:%.4f' % (epoch) if num_steps_per_epoch else ''
  melt.set_global('epoch', '%.4f' % (epoch))

  info = BytesIO()
  stop = False

  if is_start or eval_interval_steps and step % eval_interval_steps == 0:
    if eval_ops is not None:
      if deal_eval_results_fn is None and eval_names is not None:
        deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)

      eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn()
      #eval_feed_dict.update(feed_dict)

      #------show how to perf debug
      ##timer_ = gezi.Timer('sess run generate')
      ##sess.run(eval_ops[-2], feed_dict=None)
      ##timer_.print()

      timer_ = gezi.Timer('sess run eval_ops')
      eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
      eval_loss = gezi.get_singles(eval_results)
      timer_.print()

      if deal_eval_results_fn is not None:
        #@TODO user print should also use logging as a must?
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
        logging.info2('{} eval_step: {} eval_metrics:{}'.format(epoch_str, step, eval_loss))
        eval_stop = deal_eval_results_fn(eval_results)
        assert len(eval_loss) > 0
        if eval_stop is True:
          stop = True
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        melt.set_global('eval_loss', melt.parse_results(eval_loss, eval_names_))
    elif interval_steps != eval_interval_steps:
      #print()
      pass

  metric_evaluate = False
  # if metric_eval_function is not None \
  #   and ((is_start and (step or ops is None)) \
  #     or (step and ((num_steps_per_epoch and step % num_steps_per_epoch == 0) \
  #       or (metric_eval_interval_steps \
  #         and step % metric_eval_interval_steps == 0)))):
  #   metric_evaluate = True
  if metric_eval_fn is not None \
    and (is_start \
      or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
      or (metric_eval_interval_steps \
        and step % metric_eval_interval_steps == 0)):
    metric_evaluate = True

  if metric_evaluate:
    evaluate_results, evaluate_names = metric_eval_fn()

  if ops is not None:
    if deal_results_fn is None and names is not None:
      deal_results_fn = lambda x: melt.print_results(x, names)
    if eval_names is None:
      eval_names = names

    feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
    results = sess.run(ops, feed_dict=feed_dict)

    # #--------trace debug
    # if step == 210:
    #   run_metadata = tf.RunMetadata()
    #   results = sess.run(
    #       ops,
    #       feed_dict=feed_dict,
    #       options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #       run_metadata=run_metadata)
    #   from tensorflow.python.client import timeline
    #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    #   trace_file = open('timeline.ctf.json', 'w')
    #   trace_file.write(trace.generate_chrome_trace_format())

    #results[0] is assumed to be the train_op result, drop it
    results = results[1:]

    #@TODO should support avg loss and other avg evaluations like test..
    if print_avg_loss:
      if not hasattr(train_once, 'avg_loss'):
        train_once.avg_loss = AvgScore()
        if interval_steps != eval_interval_steps:
          train_once.avg_loss2 = AvgScore()
      #assume results[0] as train_op return, results[1] as loss
      loss = gezi.get_singles(results)
      train_once.avg_loss.add(loss)
      if interval_steps != eval_interval_steps:
        train_once.avg_loss2.add(loss)

    if is_start or interval_steps and step % interval_steps == 0:
      train_average_loss = train_once.avg_loss.avg_score()
      if print_time:
        duration = timer.elapsed()
        duration_str = 'duration:{:.3f} '.format(duration)
        melt.set_global('duration', '%.3f' % duration)
        info.write(duration_str)
        elapsed = train_once.timer.elapsed()
        steps_per_second = interval_steps / elapsed
        batch_size = melt.batch_size()
        num_gpus = melt.num_gpus()
        instances_per_second = interval_steps * batch_size * num_gpus / elapsed
        gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(num_gpus)
        if num_steps_per_epoch is None:
          epoch_time_info = ''
        else:
          epoch_time_info = ' 1epoch:[{:.2f}h]'.format(
              num_steps_per_epoch / interval_steps * elapsed / 3600)
        info.write('elapsed:[{:.3f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} '.format(
            elapsed, batch_size, gpu_info, steps_per_second, instances_per_second, epoch_time_info))
      if print_avg_loss:
        #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
        names_ = melt.adjust_names(train_average_loss, names)
        info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
        #info.write('train_avg_loss: {} '.format(train_average_loss))
      #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
      logging.info2('{} {} {}'.format(epoch_str, 'train_step:%d' % step, info.getvalue()))
      if deal_results_fn is not None:
        stop = deal_results_fn(results)

  if is_start or eval_interval_steps and step % eval_interval_steps == 0:
    if ops is not None:
      if interval_steps != eval_interval_steps:
        train_average_loss = train_once.avg_loss2.avg_score()
      info = BytesIO()
      names_ = melt.adjust_names(results, names)
      train_average_loss_str = ''
      if print_avg_loss and interval_steps != eval_interval_steps:
        train_average_loss_str = melt.value_name_list_str(train_average_loss, names_)
        melt.set_global('train_loss', train_average_loss_str)
        train_average_loss_str = 'train_avg_metric:{} '.format(train_average_loss_str)
      if interval_steps != eval_interval_steps:
        #end = '' if eval_ops is None else '\n'
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, train_average_loss_str, end=end)
        logging.info2('{} eval_step: {} {}'.format(epoch_str, step, train_average_loss_str))

    if log_dir:
      if not hasattr(train_once, 'summary_op'):
        #melt.print_summary_ops()
        if summary_excls is None:
          train_once.summary_op = tf.summary.merge_all()
        else:
          summary_ops = []
          for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
            #keep an op only if its name matches none of the excluded patterns
            if not any(summary_excl in op.name for summary_excl in summary_excls):
              summary_ops.append(op)
          print('filtered summary_ops:')
          for op in summary_ops:
            print(op)
          train_once.summary_op = tf.summary.merge(summary_ops)
        train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
        train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
            train_once.summary_writer, projector_config)

      summary = tf.Summary()
      #the strategy is: on eval_interval_steps, if there is an eval dataset, tensorboard evaluates on it;
      #if there is no eval dataset, evaluate on the train set; with an eval dataset we also monitor train loss
      if train_once.summary_train_op is not None:
        summary_str = sess.run(train_once.summary_train_op, feed_dict=feed_dict)
        train_once.summary_writer.add_summary(summary_str, step)

      if eval_ops is None:
        #get train loss, for every batch train
        if train_once.summary_op is not None:
          #timer2 = gezi.Timer('sess run')
          try:
            summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict)
          except Exception:
            summary_str = ''
          #timer2.print()
          train_once.summary_writer.add_summary(summary_str, step)
      else:
        #get eval loss for every batch eval, then add train loss for eval step average loss
        try:
          summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) \
              if train_once.summary_op is not None else ''
        except Exception:
          logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) fail')
          #logging.warning(traceback.format_exc())
          summary_str = ''
        #all single value results will be added to the summary here, not using tf.scalar_summary..
        summary.ParseFromString(summary_str)
        suffix = 'eval' if not eval_names else ''
        melt.add_summarys(summary, eval_results, eval_names_, suffix=suffix)
        melt.add_summarys(summary, train_average_loss, names_,
                          suffix='train_avg%dsteps' % eval_interval_steps)

      if metric_evaluate:
        melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')

      train_once.summary_writer.add_summary(summary, step)
      train_once.summary_writer.flush()
      #timer_.print()

    if print_time:
      full_duration = train_once.eval_timer.elapsed()
      if metric_evaluate:
        metric_full_duration = train_once.metric_eval_timer.elapsed()
      full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
      #info.write('duration:{:.3f} '.format(timer.elapsed()))
      duration = timer.elapsed()
      info.write('duration:{:.3f} '.format(duration))
      info.write(full_duration_str)
      info.write('eval_time_ratio:{:.3f} '.format(duration / full_duration))
      if metric_evaluate:
        info.write('metric_time_ratio:{:.3f} '.format(duration / metric_full_duration))
      #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, info.getvalue())
      logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d' % step, info.getvalue()))

  return stop
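
# Callback shapes assumed by train_once (illustrative sketch only; the my_*
# names below are hypothetical and not part of this module):
#   gen_feed_dict_fn / gen_eval_feed_dict_fn: zero-arg callables returning a
#     feed dict for sess.run, e.g.
#       def my_gen_feed_dict():
#         return {}
#   deal_results_fn / deal_eval_results_fn: receive the sess.run results (with
#     the leading train_op result dropped for deal_results_fn) and may return
#     True to request early stop, e.g.
#       def my_deal_results(results):
#         melt.print_results(results, ['loss'])
#         return False
#   metric_eval_fn: zero-arg callable returning (values, names); the values
#     are written to the summary with prefix 'eval', e.g.
#       def my_metric_eval():
#         return [0.0], ['dummy_metric']
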
def tf_train_flow(train_once_fn,
                  model_dir='./model',
                  max_models_keep=1,
                  save_interval_seconds=600,
                  save_interval_steps=1000,
                  num_epochs=None,
                  num_steps=None,
                  save_model=True,
                  save_interval_epochs=1,
                  num_steps_per_epoch=0,
                  restore_from_latest=True,
                  metric_eval_fn=None,
                  init_fn=None,
                  restore_fn=None,
                  restore_scope=None,
                  save_all_scope=False,  #TODO save/load from restore scope only, but save all
                  variables_to_restore=None,
                  variables_to_save=None,  #by default will be the same as variables_to_restore
                  sess=None):
  """
  Similar flow to tf_flow, but also tries to reload an existing model and save checkpoints.
  """
  if sess is None:
    #TODO melt.get_session is a global session and may not be closed at the end
    sess = melt.get_session()
  logging.info('tf_train_flow start')
  print('max_models_keep:', max_models_keep, file=sys.stderr)
  print('save_interval_seconds:', save_interval_seconds, file=sys.stderr)

  #this is useful when you use another model with another scope, and just load/restore/save the variables of your own scope!
  #this is not for finetune, but mainly for using another model in predict, which introduces another model scope into the graph that is ignored here
  var_list = None if not restore_scope else tf.get_collection(
      tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
  if not variables_to_restore:
    variables_to_restore = var_list
  if not variables_to_save:
    variables_to_save = variables_to_restore
  if save_all_scope:
    variables_to_save = None

  if variables_to_restore is None:
    #load all vars in the checkpoint, and try to save all vars (possibly more than the original checkpoint) if variables_to_save is not specified
    varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
    #print(varnames_in_checkpoint)
    variables_to_restore = slim.get_variables_to_restore(include=varnames_in_checkpoint)
  #logging.info('variables_to_restore:{}'.format(variables_to_restore))
  loader = tf.train.Saver(var_list=variables_to_restore)

  saver = tf.train.Saver(
      max_to_keep=max_models_keep,
      keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
      var_list=variables_to_save)
  epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
  best_epoch_saver = tf.train.Saver(var_list=variables_to_save)

  ##TODO for safe restore, will running all init be ok?
  #if variables_to_restore is None:
  init_op = tf.group(
      tf.global_variables_initializer(),  #variables_initializer(global_variables())
      tf.local_variables_initializer())   #variables_initializer(local_variables())
  # else:
  #   init_op = tf.group(tf.variables_initializer(variables_to_restore),
  #                      tf.local_variables_initializer())

  ##--mostly this will be fine, except when using an assistant predictor: initializing again will make the assistant predictor wrong
  ##so we always run the init op! if using an assistant predictor, make sure it uses another session
  sess.run(init_op)
  #melt.init_uninitialized_variables(sess)

  #pre_step is the step last saved; when training without a pretrained model it is -1
  pre_step = -1
  fixed_pre_step = -1  #fixed_pre_step keeps the epoch number correct if you change batch size
  model_path = _get_model_path(model_dir, save_model)
  model_dir = gezi.get_dir(model_dir)  #in case you pass ./model/model-ckpt1000 -> ./model
  if model_path is not None:
    if not restore_from_latest:
      print('using recent but not latest model', file=sys.stderr)
      model_path = melt.recent_checkpoint(model_dir)
    model_name = os.path.basename(model_path)
    timer = gezi.Timer('Loading and training from existing model [%s]' % model_path)
    if restore_fn is not None:
      restore_fn(sess)
    loader.restore(sess, model_path)
    timer.print()
    pre_step = melt.get_model_step(model_path)
    pre_epoch = melt.get_model_epoch(model_path)
    fixed_pre_step = pre_step
    if pre_epoch is not None:
      #like using batch size 32, then reloading and training with batch size 64
      if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
        fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
        logging.info('Warning, epoch differs from pre_step / num_steps_per_epoch:{}, pre_epoch:{}, '
                     'maybe you changed batch size; adjusting pre_step to {}'.format(
                         pre_step / num_steps_per_epoch, pre_epoch, fixed_pre_step))
  else:
    print('Train all start step 0', file=sys.stderr)
    #https://stackoverflow.com/questions/40220201/tensorflow-tf-initialize-all-variables-vs-tf-initialize-local-variables
    #tf.initialize_all_variables() is a shortcut to tf.initialize_variables(tf.all_variables()),
    #tf.initialize_local_variables() is a shortcut to tf.initialize_variables(tf.local_variables()),
    #which initializes variables in GraphKeys.VARIABLES and GraphKeys.LOCAL_VARIABLE collections, respectively.
    #init_op = tf.group(tf.global_variables_initializer(),
    #                   tf.local_variables_initializer())
    #[var for var in tf.all_variables() if var.op.name.startswith(restore_scope)] will be the same as
    #tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
    #sess.run(init_op)

    #like using an image model: build the image graph, reload, first train; later all variables are in the same checkpoint so plain restore is ok
    #for finetune, init from loading the other model
    if init_fn is not None:
      init_fn(sess)

  if save_interval_epochs and num_steps_per_epoch:
    epoch_dir = os.path.join(model_dir, 'epoch')
    gezi.try_mkdir(epoch_dir)

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  checkpoint_path = os.path.join(model_dir, 'model.ckpt')

  tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
  only_one_step = False
  try:
    step = start = pre_step + 1
    fixed_step = fixed_pre_step + 1
    #hack just to save one model after load
    if num_steps < 0 or (num_steps and num_steps < step):
      print('just load and resave then exit', file=sys.stderr)
      saver.save(sess,
                 _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                 global_step=step)
      sess.close()
      exit(0)

    if num_epochs < 0:
      only_one_step = True
      print('just run one step', file=sys.stderr)

    early_stop = True  #TODO allow config
    num_bad_epochs = 0
    pre_epoch_eval_loss = 1e20
    best_epoch_eval_loss = 1e20
    num_allowed_bad_epochs = 4  #allow 5 epochs of non-decreasing eval loss before stopping

    while not coord.should_stop():
      stop = train_once_fn(sess, step, is_start=(step == start), fixed_step=fixed_step)
      if only_one_step:
        stop = True
      if save_model and step:
        #step 0 is also saved! actually train one step and save
        if step % save_interval_steps == 0:
          timer = gezi.Timer('save model step %d to %s' % (step, checkpoint_path))
          saver.save(sess,
                     _get_checkpoint_path(checkpoint_path, fixed_step, num_steps_per_epoch),
                     global_step=step)
          timer.print()
        #if save_interval_epochs and num_steps_per_epoch and step % (num_steps_per_epoch * save_interval_epochs) == 0:
        #if save_interval_epochs and num_steps_per_epoch and step % num_steps_per_epoch == 0:
        if save_interval_epochs and num_steps_per_epoch and fixed_step % num_steps_per_epoch == 0:
          #epoch = step // num_steps_per_epoch
          epoch = fixed_step // num_steps_per_epoch
          eval_loss = melt.eval_loss()
          if eval_loss:
            #e.g. ['eval_loss:3.2', 'eval_accuracy:4.3']
            eval_loss = float(eval_loss.strip('[]').split(',')[0].strip("'").split(':')[-1])
            if os.path.exists(os.path.join(epoch_dir, 'best_eval_loss.txt')):
              with open(os.path.join(epoch_dir, 'best_eval_loss.txt')) as f:
                best_epoch_eval_loss = float(f.readline().split()[-1].strip())
            if eval_loss < best_epoch_eval_loss:
              best_epoch_eval_loss = eval_loss
              logging.info('Now best eval loss is epoch %d eval_loss:%f' % (epoch, eval_loss))
              with open(os.path.join(epoch_dir, 'best_eval_loss.txt'), 'w') as f:
                f.write('%d %d %f\n' % (epoch, step, best_epoch_eval_loss))
              best_epoch_saver.save(sess, os.path.join(epoch_dir, 'model.ckpt-best'))
            with open(os.path.join(epoch_dir, 'eval_loss.txt'), 'a') as f:
              f.write('%d %d %f\n' % (epoch, step, eval_loss))
            if eval_loss >= pre_epoch_eval_loss:
              num_bad_epochs += 1
              if num_bad_epochs > num_allowed_bad_epochs:
                logging.warning('Evaluate loss not decrease for last %d epochs' % (num_allowed_bad_epochs + 1))
                if not os.path.exists(os.path.join(epoch_dir, 'model.ckpt-noimprove')):
                  best_epoch_saver.save(sess, os.path.join(epoch_dir, 'model.ckpt-noimprove'))
                ##-------well remove it since
                #if early_stop:
                #  stop = True
            else:
              num_bad_epochs = 0
            pre_epoch_eval_loss = eval_loss
          if step % (num_steps_per_epoch * save_interval_epochs) == 0:
            epoch_saver.save(sess,
                             os.path.join(epoch_dir, 'model.ckpt-%d' % epoch),
                             global_step=step)
          #--------do not add step
          # epoch_saver.save(sess,
          #                  os.path.join(epoch_dir, 'model.ckpt-%d' % epoch))
      if stop is True:
        print('Early stop running %d steps' % (step), file=sys.stderr)
        raise tf.errors.OutOfRangeError(None, None, 'Early stop running %d steps' % (step))
      if num_steps and (step + 1) == start + num_steps:
        raise tf.errors.OutOfRangeError(None, None, 'Reached max num steps')
      #max_num_epochs = 1000
      max_num_epochs = num_epochs
      if max_num_epochs and num_steps_per_epoch and step // num_steps_per_epoch >= max_num_epochs:
        raise tf.errors.OutOfRangeError(None, None, 'Reached max num epochs of %d' % max_num_epochs)
      step += 1
      fixed_step += 1
  except tf.errors.OutOfRangeError as e:
    if not (step == start) and save_model and step % save_interval_steps != 0:
      saver.save(sess,
                 _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                 global_step=step)
    if only_one_step:
      print('Done one step', file=sys.stderr)
      exit(0)
    if metric_eval_fn is not None:
      metric_eval_fn()
    if (num_epochs and step / num_steps_per_epoch >= num_epochs) or \
       (num_steps and (step + 1) == start + num_steps):
      print('Done training for %.3f epochs, %d steps.' % (step / num_steps_per_epoch, step + 1),
            file=sys.stderr)
      #FIXME because coord.join seems not to work:
      #RuntimeError: Coordinator stopped with threads still running: Thread-9
      exit(0)
    else:
      print('Should not stop, but stopped at epoch: %.3f' % (step / num_steps_per_epoch),
            file=sys.stderr)
      print(traceback.format_exc(), file=sys.stderr)
      raise e
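
# Usage sketch (hypothetical): tf_train_flow drives train_once by calling
# train_once_fn(sess, step, is_start=..., fixed_step=...), so the remaining
# train_once arguments can be bound with functools.partial. The train_op and
# loss tensors below are placeholders from the caller's graph, not defined here:
#
#   import functools
#   train_once_fn = functools.partial(
#       train_once,
#       ops=[train_op, loss],  #ops[0] is treated as train_op and dropped from results
#       names=['loss'],
#       log_dir='./model',
#       num_steps_per_epoch=num_steps_per_epoch)
#   tf_train_flow(train_once_fn,
#                 model_dir='./model',
#                 num_epochs=10,
#                 num_steps_per_epoch=num_steps_per_epoch)
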
def tf_train_flow(train_once_fn,
                  model_dir='./model',
                  max_models_keep=1,
                  save_interval_seconds=600,
                  save_interval_steps=1000,
                  num_epochs=None,
                  num_steps=None,
                  save_model=True,
                  save_interval_epochs=1,
                  num_steps_per_epoch=0,
                  restore_from_latest=True,
                  metric_eval_fn=None,
                  init_fn=None,
                  sess=None):
  """
  Similar flow to tf_flow, but also tries to reload an existing model and save checkpoints.
  """
  if sess is None:
    #TODO melt.get_session is a global session and may not be closed at the end
    sess = melt.get_session()
  logging.info('tf_train_flow start')
  print('max_models_keep:', max_models_keep)
  print('save_interval_seconds:', save_interval_seconds)

  saver = tf.train.Saver(
      max_to_keep=max_models_keep,
      keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0)
  epoch_saver = tf.train.Saver()
  best_epoch_saver = tf.train.Saver()

  #pre_step is the step last saved; when training without a pretrained model it is -1
  pre_step = -1
  model_path = _get_model_path(model_dir, save_model)
  model_dir = gezi.get_dir(model_dir)  #in case you pass ./model/model-ckpt1000 -> ./model
  if model_path is not None:
    if not restore_from_latest:
      print('using recent but not latest model', file=sys.stderr)
      model_path = melt.recent_checkpoint(model_dir)
    model_name = os.path.basename(model_path)
    timer = gezi.Timer('Loading and training from existing model [%s]' % model_path)
    saver.restore(sess, model_path)
    timer.print()
    pre_step = melt.get_model_step(model_path)
    if 'epoch' in model_name:
      pre_step *= num_steps_per_epoch
    #for non-zero epochs, without this you get:
    #Attempting to use uninitialized value input/input_producer/limit_epochs/epochs
    sess.run(tf.local_variables_initializer())
  else:
    print('Train all start step 0', file=sys.stderr)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    if init_fn is not None:
      init_fn(sess)

  if save_interval_epochs and num_steps_per_epoch:
    epoch_dir = os.path.join(model_dir, 'epoch')
    gezi.try_mkdir(epoch_dir)

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  checkpoint_path = os.path.join(model_dir, 'model.ckpt')

  tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
  try:
    step = start = pre_step + 1
    #hack just to save one model after load
    if num_steps and num_steps < step:
      print('just load and resave then exit', file=sys.stderr)
      saver.save(sess,
                 _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                 global_step=step)
      sess.close()
      exit(0)

    early_stop = True  #TODO allow config
    num_bad_epochs = 0
    pre_epoch_eval_loss = 1e20
    best_epoch_eval_loss = 1e20
    num_allowed_bad_epochs = 4  #allow 5 epochs of non-decreasing eval loss before stopping
    while not coord.should_stop():
      stop = train_once_fn(sess, step, is_start=(step == start))
      if save_model and step:
        #step 0 is also saved! actually train one step and save
        if step % save_interval_steps == 0:
          timer = gezi.Timer('save model step %d to %s' % (step, checkpoint_path))
          saver.save(sess,
                     _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                     global_step=step)
          timer.print()
        #if save_interval_epochs and num_steps_per_epoch and step % (num_steps_per_epoch * save_interval_epochs) == 0:
        if save_interval_epochs and num_steps_per_epoch and step % num_steps_per_epoch == 0:
          epoch = step // num_steps_per_epoch
          eval_loss = melt.eval_loss()
          if eval_loss:
            #e.g. ['eval_loss:3.2', 'eval_accuracy:4.3']
            eval_loss = float(eval_loss.strip('[]').split(',')[0].strip("'").split(':')[-1])
            if os.path.exists(os.path.join(epoch_dir, 'best_eval_loss.txt')):
              with open(os.path.join(epoch_dir, 'best_eval_loss.txt')) as f:
                best_epoch_eval_loss = float(f.readline().split()[-1].strip())
            if eval_loss < best_epoch_eval_loss:
              best_epoch_eval_loss = eval_loss
              logging.info('Now best eval loss is epoch %d eval_loss:%f' % (epoch, eval_loss))
              with open(os.path.join(epoch_dir, 'best_eval_loss.txt'), 'w') as f:
                f.write('%d %d %f\n' % (epoch, step, best_epoch_eval_loss))
              best_epoch_saver.save(sess, os.path.join(epoch_dir, 'model.ckpt-best'))
            with open(os.path.join(epoch_dir, 'eval_loss.txt'), 'a') as f:
              f.write('%d %d %f\n' % (epoch, step, eval_loss))
            if eval_loss >= pre_epoch_eval_loss:
              num_bad_epochs += 1
              if num_bad_epochs > num_allowed_bad_epochs:
                logging.warning('Evaluate loss not decrease for last %d epochs' % (num_allowed_bad_epochs + 1))
                if not os.path.exists(os.path.join(epoch_dir, 'model.ckpt-noimprove')):
                  best_epoch_saver.save(sess, os.path.join(epoch_dir, 'model.ckpt-noimprove'))
                ##-------well remove it since
                #if early_stop:
                #  stop = True
            else:
              num_bad_epochs = 0
            pre_epoch_eval_loss = eval_loss
          if step % (num_steps_per_epoch * save_interval_epochs) == 0:
            epoch_saver.save(sess,
                             os.path.join(epoch_dir, 'model.ckpt-%d' % epoch),
                             global_step=step)
          #--------do not add step
          # epoch_saver.save(sess,
          #                  os.path.join(epoch_dir, 'model.ckpt-%d' % epoch))
      if stop is True:
        print('Early stop running %d steps' % (step), file=sys.stderr)
        raise tf.errors.OutOfRangeError(None, None, 'Early stop running %d steps' % (step))
      if num_steps and (step + 1) == start + num_steps:
        raise tf.errors.OutOfRangeError(None, None, 'Reached max num steps')
      #max_num_epochs = 1000
      max_num_epochs = num_epochs
      if num_steps_per_epoch and step // num_steps_per_epoch >= max_num_epochs:
        raise tf.errors.OutOfRangeError(None, None, 'Reached max num epochs of %d' % max_num_epochs)
      step += 1
  except tf.errors.OutOfRangeError as e:
    if not (step == start) and save_model and step % save_interval_steps != 0:
      saver.save(sess,
                 _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                 global_step=step)
    if metric_eval_fn is not None:
      metric_eval_fn()
    if (num_epochs and step / num_steps_per_epoch >= num_epochs) or \
       (num_steps and (step + 1) == start + num_steps):
      print('Done training for %.3f epochs, %d steps.' % (step / num_steps_per_epoch, step + 1),
            file=sys.stderr)
      #FIXME because coord.join seems not to work:
      #RuntimeError: Coordinator stopped with threads still running: Thread-9
      exit(0)
    else:
      print('Should not stop, but stopped at epoch: %.3f' % (step / num_steps_per_epoch),
            file=sys.stderr)
      print(traceback.format_exc(), file=sys.stderr)
      raise e