Example #1
File: util.py Project: meng-jia/wenzheng
def create_restore_fn(checkpoint, model_name, restore_model_name):
  model_name = gezi.pascal2gnu(model_name)
  restore_model_name = gezi.pascal2gnu(restore_model_name)
  
  variables_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=model_name)
  assert variables_to_restore, tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

  prefix = '%s/%s' % (model_name, restore_model_name)

  # remove model_name
  def name_in_checkpoint(var):
    return var.op.name.replace(prefix, restore_model_name)

  variables_to_restore = {name_in_checkpoint(var):var for var in variables_to_restore if var.op.name.startswith(prefix)}
  
  varnames_in_checkpoint = melt.get_checkpoint_varnames(checkpoint)
  # keep only variables whose mapped name actually exists in the checkpoint
  variables_to_restore = {name: var for name, var in variables_to_restore.items()
                          if name in varnames_in_checkpoint}

  saver = tf.train.Saver(variables_to_restore)

  def restore_fn(sess):
    timer = gezi.Timer('restore var from %s %s' % (restore_model_name, checkpoint))
    saver.restore(sess, checkpoint)
    timer.print()

  return restore_fn
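
A minimal usage sketch of the restore_fn returned above. The checkpoint path and scope names are placeholders; melt and gezi are assumed to be importable as in the example itself.

restore_fn = create_restore_fn('/path/to/model.ckpt',
                               model_name='TextModel',
                               restore_model_name='text_encoder')
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # Variables named 'text_model/text_encoder/...' in the current graph are
  # restored from 'text_encoder/...' entries in the checkpoint.
  restore_fn(sess)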
Example #2
File: util.py Project: meng-jia/wenzheng
def restore_from_path(sess, model_path, var_list=None):
  if var_list is None:
    varnames_in_checkpoint = melt.get_checkpoint_varnames(model_path)
    #logging.info('varnames_in_checkpoint: {}'.format(varnames_in_checkpoint))
    var_list = slim.get_variables_to_restore(include=varnames_in_checkpoint)
  saver = tf.train.Saver(var_list)
  saver.restore(sess, model_path)
  print('restore ok:', model_path, file=sys.stderr)
  sess.run(tf.local_variables_initializer())
  return saver
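
A minimal usage sketch (the checkpoint path is a placeholder); with var_list=None every variable found in the checkpoint is restored:

with tf.Session() as sess:
  saver = restore_from_path(sess, '/path/to/model.ckpt-1000')
  # ... run inference or continue training with the restored variables ...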
Example #3
File: util.py Project: meng-jia/wenzheng
def restore(sess, model_dir, var_list=None, model_name=None):
  assert model_dir
  if var_list is None:
    varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
    #logging.info('varnames_in_checkpoint: {}'.format(varnames_in_checkpoint))
    var_list = slim.get_variables_to_restore(include=varnames_in_checkpoint)
  saver = tf.train.Saver(var_list)
  model_path = get_model_path(model_dir, model_name)
  #assert model_path and os.path.exists(model_path), model_path
  saver.restore(sess, model_path)
  #@TODO still write to file ? using >
  print('restore ok:', model_path, file=sys.stderr)
  sess.run(tf.local_variables_initializer())
  return saver
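
A minimal usage sketch (the directory path is a placeholder); model_dir is resolved to a concrete checkpoint via get_model_path, so passing a directory restores from the checkpoint found there:

with tf.Session() as sess:
  saver = restore(sess, './model')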
Example #4
def tf_train_flow(
        train_once_fn,
        model_dir=None,
        log_dir=None,
        max_models_keep=1,
        save_interval_seconds=600,
        save_interval_steps=1000,
        num_epochs=None,
        num_steps=None,
        save_model=True,
        save_interval_epochs=None,
        freeze_graph=False,
        num_steps_per_epoch=0,
        restore_from_latest=True,
        metric_eval_fn=None,
        valid_interval_epochs=0,
        inference_fn=None,
        inference_interval_epochs=0,
        init_fn=None,
        restore_fn=None,
        restore_include=None,
        restore_exclude=None,
        save_all_scope=False,  #TODO save/load from restore scope only, but currently save all
        variables_to_restore=None,
        variables_to_save=None,  #by default will be the same as variables_to_restore
        output_collection_names=None,
        output_node_names=None,
        learning_rate=None,  #not used yet, only used in train_once
        learning_rate_patience=None,
        learning_rate_decay_factor=None,
        write_during_train=True,
        model=None,
        sess=None):
    """
  similary flow as tf_flow, but add model try reload and save
  """
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

    model_dir_ = model_dir
    if use_horovod and hvd.rank() != 0:
        model_dir = None

    if sess is None:
        #TODO melt.get_session is a global session and may not be closed at the end
        sess = melt.get_session()

    if FLAGS.use_tpu:
        sess.run(tpu.initialize_system())
    #logging.info('tf_train_flow start')
    #logging.info('max_models_keep:', max_models_keep)
    #logging.info('save_interval_seconds:', save_interval_seconds)

    if model_dir:
        if model:
            checkpoint = tf.train.Checkpoint(model=model)
            ckpt_dir = model_dir + '/ckpt'
            checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')

        #This is useful when you use another model under another scope: load, then restore/save/initialize only your scope's vars.
        #This is not for finetuning, but mainly for cases like using another model at predict time, where that model's scope is introduced into the graph and should be ignored here.

        # var_list = None if not restore_scope else tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
        # #logging.info('-------------var_list', var_list)

        # if not variables_to_restore:
        #   variables_to_restore = var_list

        if not variables_to_restore:
            variables_to_restore = slim.get_variables_to_restore(
                include=restore_include, exclude=restore_exclude)

        if not variables_to_save:
            variables_to_save = variables_to_restore
        if save_all_scope:
            variables_to_save = None

        #if variables_to_restore is None:
        logging.info('variables_to_restore from %s' % model_dir)
        #load all vars in the checkpoint; try to save all vars (possibly more than in the original checkpoint) if variables_to_save is not specified
        varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
        #logging.info('varnames_in_checkpoint: {}'.format(varnames_in_checkpoint))

        # TODO there is a problem: it reports tf.Variable 'r_net/text_encoder/cudnn_rnn/cu_dnngru/recurrent_kernel/adam_v:0' even though in the checkpoint it has been renamed to ignore/rnet
        variables_to_restore_from_model = slim.get_variables_to_restore(
            include=varnames_in_checkpoint)
        #logging.info('variables_to_restore_from_model: {}'.format(variables_to_restore_from_model))
        if not variables_to_restore:
            variables_to_restore = variables_to_restore_from_model
        else:
            variables_to_restore = [
                v for v in variables_to_restore
                if v in variables_to_restore_from_model
            ]
        if restore_exclude:
            for excl in restore_exclude:
                variables_to_restore = [
                    v for v in variables_to_restore if not excl in v.name
                ]
        #--tf 1.6 adadelta will have same vars...
        variables_to_restore = list(set(variables_to_restore))
        #logging.info('variables_to_restore', variables_to_restore[:100])
        logging.info('variables_to_restore', [
            x for x in variables_to_restore if not 'OptimizeLoss' in x.name
        ][:100])

    ##finally remove global_step since melt.apps.train will handle it!
    global_step = tf.train.get_or_create_global_step()

    #variables_to_restore = [v for v in variables_to_restore if not tf.GraphKeys.GLOBAL_STEP in v.name]
    #variables_to_restore = [v for v in variables_to_restore if not 'learning_rate' in v.name]

    # TODO FIXME if the graph has step, step2, ... and the checkpoint has step, then here it will map to step2...
    #print('------------', [v for v in variables_to_restore if 'step' in v.name])
    loader = tf.train.Saver(var_list=variables_to_restore)

    logging.info('max models to keep {}, keep every {} hours'.format(
        max_models_keep, save_interval_seconds / 3600.0))
    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
        var_list=variables_to_save)
    epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
    best_epoch_saver = tf.train.Saver(var_list=variables_to_save)
    #logging.info('variables_to_save:{}'.format(variables_to_save))

    # # #TODO for safe restore all init will be ok ?
    # # if variables_to_restore is None:
    init_op = tf.group(
        tf.global_variables_initializer(
        ),  #variables_initializer(global_variables())
        tf.local_variables_initializer()
    )  #variables_initializer(local_variables())
    # # else:
    # #   init_op = tf.group(tf.variables_initializer(variables_to_restore),
    # #                      tf.local_variables_initializer())

    ##--Mostly this is fine, except when using an assistant predictor: initializing again will make the assistant predictor wrong.
    ##So we assume it is OK to always run init_op; if you use an assistant predictor, make sure it uses another session.

    # https://stackoverflow.com/questions/35164529/in-tensorflow-is-there-any-way-to-just-initialize-uninitialised-variables
    # def guarantee_initialized_variables(session, list_of_variables = None):
    #   if list_of_variables is None:
    #       list_of_variables = tf.global_variables()
    #   uninitialized_variables = list(tf.get_variable(name) for name in
    #                                  session.run(tf.report_uninitialized_variables(list_of_variables)))
    #   return unintialized_variables

    # unintialized_variables = guarantee_initialized_variables(sess)
    # init_op = tf.group(tf.initialize_variables(uninitialized_vars), tf.local_variables_initializer())

    timer = gezi.Timer('sess run init_op in melt.tf_train_flow')
    #model.save('./weights')

    # notice
    sess.run(init_op)

    timer.print_elapsed()

    #melt.init_uninitialized_variables(sess)

    #pre_step is the step last saved; when training without a pretrained model it is -1
    pre_step = -1
    fixed_pre_step = -1  #fixed_pre_step keeps the epoch number correct if you change batch size
    #print(model_dir)
    pre_epoch = None
    if model_dir:
        model_path = _get_model_path(model_dir, save_model)
        # if not model_path:
        #   model_path = _get_model_path(os.path.join(model_dir, 'epoch'))
        #print(model_path)
        model_dir = gezi.get_dir(
            model_dir)  #in case you pass ./model/model-ckpt1000 -> ./model

        if model_path is not None:
            if not restore_from_latest:
                logging.info('using recent but not latest model')
                model_path = melt.recent_checkpoint(model_dir)
            model_name = os.path.basename(model_path)
            timer = gezi.Timer(
                'Loading and training from existing model [%s]' % model_path)
            if restore_fn is not None:
                restore_fn(sess)
            loader.restore(sess, model_path)
            ## not supported
            #model.save()
            #model.save_weights('./weights')
            timer.print()
            #pre_step = melt.get_model_step(model_path) - 1 if FLAGS.global_step is None else FLAGS.global_step -1
            # TODO check ..
            pre_step = sess.run(tf.train.get_global_step()) - 1
            pre_epoch = melt.get_model_epoch(
                model_path
            ) if FLAGS.global_epoch is None else FLAGS.global_epoch
            fixed_pre_step = pre_step
            # if pre_epoch is not None:
            #   #like using batch size 32, then reload train using batch size 64
            #   if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
            #     fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
            #     logging.info('Warning, epoch is diff with pre_step / num_steps_per_epoch:{}, pre_epoch:{},maybe you change batch size and we will adjust to set pre_step as {}'\
            #       .format(pre_step / num_steps_per_epoch, pre_epoch, fixed_pre_step))
        else:
            latest_checkpoint = None
            if not use_horovod:  #now will hang
                try:
                    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
                    if latest_checkpoint:
                        logging.info(
                            'Trying to start from an eager-trained model, latest checkpoint:',
                            latest_checkpoint)
                        checkpoint.restore(latest_checkpoint).run_restore_ops(
                            session=sess)

                        pre_epoch = int(latest_checkpoint.split('-')[-1])
                        #pre_step = pre_epoch * num_steps_per_epoch - 1
                        # TODO check
                        pre_step = sess.run(tf.train.get_global_step()) - 1
                        fixed_pre_step = pre_step
                        logging.info('Start step is:', pre_step)
                except Exception:
                    logging.info(
                        'Something went wrong restoring from the eager-trained model'
                    )
                if latest_checkpoint is None:
                    logging.info('Train all start step 0')
                    #https://stackoverflow.com/questions/40220201/tensorflow-tf-initialize-all-variables-vs-tf-initialize-local-variables
                    #tf.initialize_all_variables() is a shortcut to tf.initialize_variables(tf.all_variables()),
                    #tf.initialize_local_variables() is a shortcut to tf.initialize_variables(tf.local_variables()),
                    #which initializes variables in GraphKeys.VARIABLES and GraphKeys.LOCAL_VARIABLE collections, respectively.
                    #init_op = tf.group(tf.global_variables_initializer(),
                    #                   tf.local_variables_initializer())
                    #[var for var in tf.all_variables() if var.op.name.startswith(restore_scope)] will be the same as tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)

                    #sess.run(init_op)

                    #e.g. using an image model: build the image graph, reload it on the first training run, and later runs hit the same checkpoint, so simply restoring all variables is fine
                    #for finetuning, initialize by loading from another model
                    if init_fn is not None:
                        init_fn(sess)

    if gezi.env_has('METRIC'):
        l = metric_eval_fn(model_path)
        print(list(zip(l[1], l[0])))
        exit(0)

    #sess.run(tf.assign(global_step, tf.constant(global_step_val, dtype=tf.int64)))
    try:
        learning_rate = tf.get_collection('learning_rate')[-1]
        learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
        sess.run(tf.assign(learning_rate,
                           learning_rate * learning_rate_weight))
    except Exception:
        # if not using weight_decay but using optimizer decay we end up here, since learning_rate is then a tensor and cannot be assigned
        pass

    try:
        logging.info('Actual start global step:',
                     sess.run(global_step), 'learning rate:',
                     sess.run(learning_rate), 'learning_rate_weight:',
                     sess.run(learning_rate_weight))
    except Exception:
        pass

    if model_dir_:
        #if save_interval_epochs and num_steps_per_epoch and num_steps >= 0:
        epoch_dir = os.path.join(model_dir_, 'epoch')
        gezi.try_mkdir(epoch_dir)
        checkpoint_path = os.path.join(model_dir_, 'model.ckpt')

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if use_horovod:
        bcast = hvd.broadcast_global_variables(0)
        sess.run(bcast)

    #tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
    only_one_step = False
    try:
        if use_horovod:
            ## TODO FIXME why does bcast not work here? a simple test works, see tests/bcast.py
            #comm.bcast(pre_step, root=0)
            temp = np.array([pre_step, fixed_pre_step])
            comm.Bcast(temp, root=0)
            pre_step = temp[0]
            fixed_pre_step = temp[1]

        step = start = pre_step + 1
        fixed_step = fixed_pre_step + 1

        #first = True

        #hack: just save one model after loading
        if num_steps and (num_steps < 0 or num_steps < step):
            logging.info('just load and resave then exit')
            model_path_ = _get_checkpoint_path(checkpoint_path,
                                               step,
                                               num_steps_per_epoch,
                                               epoch=pre_epoch)
            saver.save(sess, model_path_, global_step=step + 1)
            if freeze_graph:
                melt.freeze_graph(sess, model_path_, step + 1,
                                  output_collection_names, output_node_names)
            sess.close()
            exit(0)

        if num_epochs is not None and num_epochs < 0:
            only_one_step = True
            logging.info('just run one step')

        if FLAGS.work_mode != 'train':
            assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
            if 'valid' in FLAGS.work_mode:
                vals, names = metric_eval_fn(FLAGS.model_dir)
                logging.info(list(zip(names, vals)))
            if 'test' in FLAGS.work_mode:
                inference_fn(FLAGS.model_dir)
            exit(0)

        #early_stop = True #TODO allow config
        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        num_allowed_bad_epochs = 4  #allow this many epochs of non-decreasing eval loss before warning/stopping
        epoch_saved_step = 0
        while not coord.should_stop():
            model_step_path = None
            if model_dir_:
                model_path_ = os.path.join(
                    epoch_dir, 'model.ckpt-%.2f' %
                    (fixed_step / float(num_steps_per_epoch)))
                model_step_path_ = model_path_ + '-' + str(step)
                if (write_during_train and metric_eval_fn is not None
                        and valid_interval_epochs and fixed_step %
                        int(num_steps_per_epoch * valid_interval_epochs) == 0):
                    model_step_path = model_step_path_
                else:
                    model_step_path = None

            if step == 0:
                model_step_path = None

            #print('--------------------step', step)
            stop = train_once_fn(
                sess,
                step,
                is_start=(step == start),
                fixed_step=fixed_step,
                num_epochs=num_epochs,
                model_path=model_step_path,
                use_horovod=use_horovod,
                ## TODO FIXME this line will cause   tensorflow.python.framework.errors_impl.NotFoundError: Resource localhost/save_counter/N10tensorflow3VarE does not exist.
            )

            #first = False

            if only_one_step:
                stop = True

            step += 1
            fixed_step += 1

            if save_model and step and model_dir:
                #step 0 is also saved! actually train one step and save
                if step % save_interval_steps == 0:
                    timer = gezi.Timer(
                        'save model step %d to %s' % (step, checkpoint_path),
                        False)
                    model_path_ = _get_checkpoint_path(checkpoint_path,
                                                       fixed_step,
                                                       num_steps_per_epoch)
                    saver.save(sess, model_path_, global_step=step)
                    if freeze_graph:
                        melt.freeze_graph(sess, model_path_, step,
                                          output_collection_names,
                                          output_node_names)
                    #if log_dir != model_dir:
                    #  assert log_dir
                    #  command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
                    #  print(command, file=sys.stderr)
                    #  os.system(command)
                    timer.print_elapsed()

                if save_interval_steps and num_steps_per_epoch and fixed_step % int(
                        num_steps_per_epoch * save_interval_epochs) == 0:
                    # TODO only epoch in name not sep ?
                    epoch_saved_step = step
                    model_path_ = os.path.join(
                        epoch_dir, 'model.ckpt-%.2f' %
                        (fixed_step / float(num_steps_per_epoch)))
                    model_step_path = model_path_ + '-' + str(step)
                    epoch_saver.save(sess, model_path_, global_step=step)
                    #epoch_saver.save(sess, model_path_)

                    ## TODO FIXME do not support tf.keras save currently with horovod
                    # if model:
                    #   #model.save_weights(epoch_dir + '/ckpt-%.2f' % (fixed_step / float(num_steps_per_epoch)))
                    #   # TODO FIXME if restart will save from 1... again..
                    #   checkpoint.save(checkpoint_prefix, session=sess)
                    #   #print(sess.run(checkpoint.save_counter))

                    if freeze_graph:
                        melt.freeze_graph(sess, model_path_, step,
                                          output_collection_names,
                                          output_node_names)

                if write_during_train:
                    if inference_fn is not None and inference_interval_epochs and fixed_step % int(
                            num_steps_per_epoch *
                            inference_interval_epochs) == 0:
                        model_step_path = model_path_ + '-' + str(step)
                        try:
                            #print('--------------inference fn')
                            inference_fn(model_path=model_step_path)
                        except Exception:
                            logging.info(traceback.format_exc())

                    # if metric_eval_fn is not None and valid_interval_epochs and fixed_step % int(num_steps_per_epoch * valid_interval_epochs) == 0:
                    #   model_step_path = model_path_ + '-' + str(step)
                    #   try:
                    #     metric_eval_fn(model_path=model_step_path)
                    #   except Exception:
                    #     logging.info(traceback.format_exc())

            if stop is True:
                print('Early stop running %d steps' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d steps' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
            #max_num_epochs = 1000
            max_num_epochs = num_epochs
            #if max_num_epochs and num_steps_per_epoch and fixed_step // num_steps_per_epoch >= max_num_epochs:
            if max_num_epochs and num_steps_per_epoch and fixed_step / num_steps_per_epoch > max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None,
                    'Reached max num epochs of %d' % max_num_epochs)
    #except tf.errors.OutOfRangeError, e:
    except tf.errors.OutOfRangeError:
        # if we ran e.g. 2 epochs and have just saved at the epoch boundary, there is no need to save a model that is only 1 step newer
        if (step - epoch_saved_step > 1) and not (
                step == start
        ) and save_model and step % save_interval_steps != 0 and model_dir:
            model_path_ = _get_checkpoint_path(checkpoint_path, step,
                                               num_steps_per_epoch)
            saver.save(sess, model_path_, global_step=step)
            if freeze_graph:
                melt.freeze_graph(sess, model_path_, step,
                                  output_collection_names, output_node_names)
            if log_dir != model_dir:
                assert log_dir
                command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
                print(command, file=sys.stderr)
                os.system(command)
        if only_one_step:
            logging.info('Done one step')
            exit(0)

        # if (step - epoch_saved_step > 1) and metric_eval_fn is not None:
        #   metric_eval_fn(model_path=model_step_path)

        if (num_epochs and fixed_step / num_steps_per_epoch >= num_epochs) or (
                num_steps and step == start + num_steps):
            logging.info('Done training for %.3f epochs, %d steps.' %
                         (fixed_step / num_steps_per_epoch, step))
            #FIXME because coord.join seems not to work: RuntimeError: Coordinator stopped with threads still running: Thread-9
            exit(0)
        else:
            logging.info('Should not stop, but stopped at epoch: %.3f' %
                         (fixed_step / num_steps_per_epoch))
            logging.info(traceback.format_exc())
            #raise e
    finally:
        coord.request_stop()

    coord.join(threads, stop_grace_period_secs=5)
    #FIXME due to use of melt.get_session (the global session does not handle __del__ well)
    #Done training for 3090020 steps.
    #Exception TypeError: "'NoneType' object is not callable" in <bound method Session.__del__ of <tensorflow.python.client.session.Session object at 0x7f6cf33cd450>> ignored
    if FLAGS.use_tpu:
        sess.run(tpu.shutdown_system())
    sess.close()
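
A minimal usage sketch of tf_train_flow. train_once, train_op and loss are hypothetical: train_once matches the callback signature used above (it receives sess, step, is_start, fixed_step, num_epochs, model_path and use_horovod, and returns True to request early stopping), while train_op and loss are assumed to be built elsewhere in the graph.

def train_once(sess, step, is_start=False, fixed_step=None,
               num_epochs=None, model_path=None, use_horovod=False):
  _, loss_val = sess.run([train_op, loss])
  if step % 100 == 0:
    print('step %d loss %f' % (step, loss_val), file=sys.stderr)
  return False  # return True to stop early

tf_train_flow(train_once,
              model_dir='./model',
              num_epochs=10,
              num_steps_per_epoch=1000,
              save_interval_steps=1000)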
Example #5
def tf_train_flow(
        train_once_fn,
        model_dir='./model',
        max_models_keep=1,
        save_interval_seconds=600,
        save_interval_steps=1000,
        num_epochs=None,
        num_steps=None,
        save_model=True,
        save_interval_epochs=1,
        num_steps_per_epoch=0,
        restore_from_latest=True,
        metric_eval_fn=None,
        init_fn=None,
        restore_fn=None,
        restore_scope=None,
        save_all_scope=False,  #TODO save/load from restore scope only, but currently save all
        variables_to_restore=None,
        variables_to_save=None,  #by default will be the same as variables_to_restore
        sess=None):
    """
  similary flow as tf_flow, but add model try reload and save
  """
    if sess is None:
        #TODO melt.get_session is a global session and may not be closed at the end
        sess = melt.get_session()
    logging.info('tf_train_flow start')
    print('max_models_keep:', max_models_keep, file=sys.stderr)
    print('save_interval_seconds:', save_interval_seconds, file=sys.stderr)

    #This is useful when you use another model under another scope: load, then restore/save/initialize only your scope's vars.
    #This is not for finetuning, but mainly for cases like using another model at predict time, where that model's scope is introduced into the graph and should be ignored here.

    var_list = None if not restore_scope else tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
    if not variables_to_restore:
        variables_to_restore = var_list
    if not variables_to_save:
        variables_to_save = variables_to_restore
    if save_all_scope:
        variables_to_save = None

    if variables_to_restore is None:
        #load all vars in the checkpoint; try to save all vars (possibly more than in the original checkpoint) if variables_to_save is not specified
        varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
        #print(varnames_in_checkpoint)
        variables_to_restore = slim.get_variables_to_restore(
            include=varnames_in_checkpoint)

    #logging.info('variables_to_restore:{}'.format(variables_to_restore))
    loader = tf.train.Saver(var_list=variables_to_restore)

    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
        var_list=variables_to_save)
    epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
    best_epoch_saver = tf.train.Saver(var_list=variables_to_save)

    ##TODO for safe restore all init will be ok ?
    #if variables_to_restore is None:
    init_op = tf.group(
        tf.global_variables_initializer(
        ),  #variables_initializer(global_variables())
        tf.local_variables_initializer()
    )  #variables_initializer(local_variables())
    # else:
    #   init_op = tf.group(tf.variables_initializer(variables_to_restore),
    #                      tf.local_variables_initializer())

    ##--mostly this will be fine except for using assistant predictor, initialize again! will make assistant predictor wrong
    ##so assume to all run init op! if using assistant predictor, make sure it use another session

    sess.run(init_op)

    #melt.init_uninitialized_variables(sess)

    #pre_step is the step last saved; when training without a pretrained model it is -1
    pre_step = -1
    fixed_pre_step = -1  #fixed_pre_step keeps the epoch number correct if you change batch size
    model_path = _get_model_path(model_dir, save_model)
    model_dir = gezi.get_dir(
        model_dir)  #in case you pass ./model/model-ckpt1000 -> ./model
    if model_path is not None:
        if not restore_from_latest:
            print('using recent but not latest model', file=sys.stderr)
            model_path = melt.recent_checkpoint(model_dir)
        model_name = os.path.basename(model_path)
        timer = gezi.Timer('Loading and training from existing model [%s]' %
                           model_path)
        if restore_fn is not None:
            restore_fn(sess)
        loader.restore(sess, model_path)
        timer.print()
        pre_step = melt.get_model_step(model_path)
        pre_epoch = melt.get_model_epoch(model_path)
        fixed_pre_step = pre_step
        if pre_epoch is not None:
            #like using batch size 32, then reload train using batch size 64
            if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
                fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
                logging.info('Warning: epoch differs from pre_step / num_steps_per_epoch:{}, pre_epoch:{}; maybe you changed batch size, so pre_step is adjusted to {}'\
                  .format(pre_step / num_steps_per_epoch, pre_epoch, fixed_pre_step))
    else:
        print('Train all start step 0', file=sys.stderr)
        #https://stackoverflow.com/questions/40220201/tensorflow-tf-initialize-all-variables-vs-tf-initialize-local-variables
        #tf.initialize_all_variables() is a shortcut to tf.initialize_variables(tf.all_variables()),
        #tf.initialize_local_variables() is a shortcut to tf.initialize_variables(tf.local_variables()),
        #which initializes variables in GraphKeys.VARIABLES and GraphKeys.LOCAL_VARIABLE collections, respectively.
        #init_op = tf.group(tf.global_variables_initializer(),
        #                   tf.local_variables_initializer())
        #[var for var in tf.all_variables() if var.op.name.startswith(restore_scope)] will be the same as tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)

        #sess.run(init_op)

        #e.g. using an image model: build the image graph, reload it on the first training run, and later runs hit the same checkpoint, so simply restoring all variables is fine
        #for finetuning, initialize by loading from another model
        if init_fn is not None:
            init_fn(sess)

    if save_interval_epochs and num_steps_per_epoch:
        epoch_dir = os.path.join(model_dir, 'epoch')
        gezi.try_mkdir(epoch_dir)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    checkpoint_path = os.path.join(model_dir, 'model.ckpt')

    tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
    only_one_step = False
    try:
        step = start = pre_step + 1
        fixed_step = fixed_pre_step + 1
        #hack: just save one model after loading
        if num_steps and (num_steps < 0 or num_steps < step):
            print('just load and resave then exit', file=sys.stderr)
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step,
                                            num_steps_per_epoch),
                       global_step=step)
            sess.close()
            exit(0)

        if num_epochs is not None and num_epochs < 0:
            only_one_step = True
            print('just run one step', file=sys.stderr)

        early_stop = True  #TODO allow config
        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        num_allowed_bad_epochs = 4  #allow this many epochs of non-decreasing eval loss before warning/stopping
        while not coord.should_stop():
            stop = train_once_fn(sess,
                                 step,
                                 is_start=(step == start),
                                 fixed_step=fixed_step)
            if only_one_step:
                stop = True
            if save_model and step:
                #step 0 is also saved! actually train one step and save
                if step % save_interval_steps == 0:
                    timer = gezi.Timer('save model step %d to %s' %
                                       (step, checkpoint_path))
                    saver.save(sess,
                               _get_checkpoint_path(checkpoint_path,
                                                    fixed_step,
                                                    num_steps_per_epoch),
                               global_step=step)
                    timer.print()
                #if save_interval_epochs and num_steps_per_epoch and step % (num_steps_per_epoch * save_interval_epochs) == 0:
                #if save_interval_epochs and num_steps_per_epoch and step % num_steps_per_epoch == 0:
                if save_interval_epochs and num_steps_per_epoch and fixed_step % num_steps_per_epoch == 0:
                    #epoch = step // num_steps_per_epoch
                    epoch = fixed_step // num_steps_per_epoch
                    eval_loss = melt.eval_loss()
                    if eval_loss:
                        #e.g. ['eval_loss:3.2', 'eval_accuracy:4.3']
                        eval_loss = float(
                            eval_loss.strip('[]').split(',')[0].strip(
                                "'").split(':')[-1])
                        if os.path.exists(
                                os.path.join(epoch_dir, 'best_eval_loss.txt')):
                            with open(
                                    os.path.join(epoch_dir,
                                                 'best_eval_loss.txt')) as f:
                                best_epoch_eval_loss = float(
                                    f.readline().split()[-1].strip())
                        if eval_loss < best_epoch_eval_loss:
                            best_epoch_eval_loss = eval_loss
                            logging.info(
                                'Now best eval loss is epoch %d eval_loss:%f' %
                                (epoch, eval_loss))
                            with open(
                                    os.path.join(epoch_dir,
                                                 'best_eval_loss.txt'),
                                    'w') as f:
                                f.write('%d %d %f\n' %
                                        (epoch, step, best_epoch_eval_loss))
                            best_epoch_saver.save(
                                sess, os.path.join(epoch_dir,
                                                   'model.ckpt-best'))

                        with open(os.path.join(epoch_dir, 'eval_loss.txt'),
                                  'a') as f:
                            f.write('%d %d %f\n' % (epoch, step, eval_loss))
                        if eval_loss >= pre_epoch_eval_loss:
                            num_bad_epochs += 1
                            if num_bad_epochs > num_allowed_bad_epochs:
                                logging.warning(
                                    'Eval loss has not decreased for the last %d epochs'
                                    % (num_allowed_bad_epochs + 1))
                                if not os.path.exists(
                                        os.path.join(epoch_dir,
                                                     'model.ckpt-noimprove')):
                                    best_epoch_saver.save(
                                        sess,
                                        os.path.join(epoch_dir,
                                                     'model.ckpt-noimprove'))
                                ##-------well remove it since
                                #if early_stop:
                                #  stop = True
                        else:
                            num_bad_epochs = 0
                        pre_epoch_eval_loss = eval_loss
                    if step % (num_steps_per_epoch *
                               save_interval_epochs) == 0:
                        epoch_saver.save(sess,
                                         os.path.join(epoch_dir,
                                                      'model.ckpt-%d' % epoch),
                                         global_step=step)
                    #--------do not add step
                    # epoch_saver.save(sess,
                    #        os.path.join(epoch_dir,'model.ckpt-%d'%epoch))
            if stop is True:
                print('Early stop running %d steps' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d steps' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
            #max_num_epochs = 1000
            max_num_epochs = num_epochs
            if max_num_epochs and num_steps_per_epoch and step // num_steps_per_epoch >= max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None,
                    'Reached max num epochs of %d' % max_num_epochs)
            step += 1
            fixed_step += 1
    except tf.errors.OutOfRangeError as e:
        if not (step
                == start) and save_model and step % save_interval_steps != 0:
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step,
                                            num_steps_per_epoch),
                       global_step=step)
        if only_one_step:
            print('Done one step', file=sys.stderr)
            exit(0)
        if metric_eval_fn is not None:
            metric_eval_fn()
        if (num_epochs and step / num_steps_per_epoch >= num_epochs) or (
                num_steps and (step + 1) == start + num_steps):
            print('Done training for %.3f epochs, %d steps.' %
                  (step / num_steps_per_epoch, step + 1),
                  file=sys.stderr)
            #FIXME because coord.join seems not to work: RuntimeError: Coordinator stopped with threads still running: Thread-9
            exit(0)
        else:
            print('Should not stop, but stopped at epoch: %.3f' %
                  (step / num_steps_per_epoch),
                  file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            raise e
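
A hypothetical sketch tying the examples together: the restore_fn built by create_restore_fn (Example #1) can be passed as the restore_fn argument of tf_train_flow, which calls it just before loading the main checkpoint when resuming from model_dir. Paths and scope names are placeholders, and train_once is assumed to be defined as in the sketch after Example #4.

restore_fn = create_restore_fn('/path/to/pretrained.ckpt',
                               model_name='TextModel',
                               restore_model_name='text_encoder')
tf_train_flow(train_once,
              model_dir='./model',
              num_epochs=10,
              num_steps_per_epoch=1000,
              restore_fn=restore_fn)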