示例#1
0
def train_flow(ops,
               names=None,
               gen_feed_dict_fn=None,
               deal_results_fn=melt.print_results,
               interval_steps=100,
               eval_ops=None,
               eval_names=None,
               gen_eval_feed_dict_fn=None,
               deal_eval_results_fn=melt.print_results,
               valid_interval_steps=100,
               eval_loops=1,
               print_time=True,
               print_avg_loss=True,
               model_dir=None,
               max_models_keep=5,
               save_interval_seconds=600,
               save_interval_steps=1000,
               freeze_graph=False,
               log_dir=None,
               no_log=False,
               num_epochs=None,
               num_steps=None,
               num_steps_per_epoch=None,
               optimizer=None,
               learning_rate=None,
               learning_rate_patience=None,
               learning_rate_decay_factor=None,
               save_model=True,
               save_interval_epochs=None,
               add_train_var_histogram=False,
               restore_from_latest=True,
               metric_eval_fn=None,
               metric_eval_interval_steps=0,
               valid_interval_epochs=0,
               inference_fn=None,
               inference_interval_epochs=0,
               summary_excls=None,
               init_fn=None,
               restore_fn=None,
               restore_include=None,
               restore_exclude=None,
               save_all_scope=False,
               variables_to_restore=None,
               variables_to_save=None,
               output_collection_names=None,
               output_node_names=None,
               write_during_train=True,
               use_horovod=False,
               model=None,
               sess=None):
    """Top-level training loop with model saving/reloading and summary logging.

    Summary logs are also written under ``model_dir`` (unless disabled via
    ``no_log``).  See examples/sparse-tensor-classification/train-melt-savemodel.py
    for a usage example.

    NOTICE: when ``optimizer`` is None, the caller must pass the train op
    (optimizer-related) as ``ops[0]``; when ``optimizer`` is given, ``ops[0]``
    is taken to be the loss and a train op is created and prepended here.

    Args:
      ops: ops to run each training step; see NOTICE above for ops[0].
      optimizer: None (train op supplied by caller), a string name resolved
        via melt.gen_train_op_byname, or a callable taking ``learning_rate``
        and exposing ``.minimize(loss)``.
      model_dir: directory for checkpoints; if falsy and not using horovod,
        falls back to simple_train_flow (no model saving).
      use_horovod: if True, skips the simple_train_flow shortcut and goes
        through tf_train_flow even without a model_dir.
      sess: optional pre-built session; otherwise created downstream.

    #@TODO allow adding momentum for optimizer
    Allows multiple GPUs.
    @TODO can we show epoch num info?
    """
    if optimizer is not None:
        # ops[0] is the loss; build the train op and prepend it so that
        # downstream code sees the train op first (matching the NOTICE above).
        loss = ops[0]
        if isinstance(optimizer, str):
            train_op = melt.gen_train_op_byname(loss, learning_rate, optimizer)
        else:
            train_op = optimizer(learning_rate).minimize(loss)
        # Copy before insert so the caller's sequence is not mutated.
        ops = list(ops)
        ops.insert(0, train_op)

    if not model_dir:
        if log_dir and no_log:
            log_dir = None
        # No model dir: nothing to checkpoint, so use the lightweight loop.
        # NOTE(review): when use_horovod is True, execution falls through to
        # the tf_train_flow path below with a falsy model_dir — confirm this
        # is intended for the horovod case.
        if not use_horovod:
            return simple_train_flow(
                ops,
                names,
                gen_feed_dict_fn,
                deal_results_fn,
                interval_steps,
                eval_ops,
                eval_names,
                gen_eval_feed_dict_fn,
                deal_eval_results_fn,
                valid_interval_steps,
                print_time,
                print_avg_loss,
                log_dir,
                num_steps,
                num_steps_per_epoch=num_steps_per_epoch,
                metric_eval_fn=metric_eval_fn,
                metric_eval_interval_steps=metric_eval_interval_steps,
                sess=sess)

    # If log_dir is not set, default to storing logs under model_dir.
    # So the default is to write logs; to save the model but disable
    # logging, pass no_log=True.
    if save_model:
        print('Will save model to %s' % model_dir)
    else:
        # Not saving implies not logging either (existing checkpoints in
        # model_dir may still be restored).
        no_log = True
        print('Will not save model, only read model from %s if exists' %
              model_dir)

    if not log_dir and not no_log:
        log_dir = gezi.get_dir(model_dir)
    if log_dir:
        print('Will save log to %s' % log_dir)
        if add_train_var_histogram:
            # Add histograms for trainable variables — useful for inspecting
            # all trainable variables on TensorBoard.
            # NOTICE: for big models this is too slow!
            melt.monitor_train_vars()
    else:
        print('Will not save log')

    # Closure handed to tf_train_flow; binds all per-step configuration so
    # tf_train_flow only needs to supply the step-varying arguments.
    def train_once_(sess,
                    step,
                    is_start=False,
                    fixed_step=None,
                    num_epochs=None,
                    model_path=None,
                    use_horovod=False):
        train_once(sess,
                   step,
                   ops,
                   names,
                   gen_feed_dict_fn,
                   deal_results_fn,
                   interval_steps,
                   eval_ops,
                   eval_names,
                   gen_eval_feed_dict_fn,
                   deal_eval_results_fn,
                   valid_interval_steps,
                   print_time,
                   print_avg_loss,
                   model_dir,
                   log_dir,
                   is_start,
                   num_steps_per_epoch,
                   metric_eval_fn=metric_eval_fn,
                   metric_eval_interval_steps=metric_eval_interval_steps,
                   summary_excls=summary_excls,
                   fixed_step=fixed_step,
                   eval_loops=eval_loops,
                   learning_rate=learning_rate,
                   learning_rate_patience=learning_rate_patience,
                   learning_rate_decay_factor=learning_rate_decay_factor,
                   num_epochs=num_epochs,
                   model_path=model_path,
                   use_horovod=use_horovod)

    # Delegate the actual loop (session management, checkpointing, restore,
    # periodic eval/inference) to tf_train_flow.
    tf_train_flow(train_once_,
                  model_dir,
                  log_dir,
                  max_models_keep,
                  save_interval_seconds,
                  save_interval_steps,
                  num_epochs,
                  num_steps,
                  save_model=save_model,
                  save_interval_epochs=save_interval_epochs,
                  freeze_graph=freeze_graph,
                  num_steps_per_epoch=num_steps_per_epoch,
                  restore_from_latest=restore_from_latest,
                  metric_eval_fn=metric_eval_fn,
                  valid_interval_epochs=valid_interval_epochs,
                  inference_fn=inference_fn,
                  inference_interval_epochs=inference_interval_epochs,
                  init_fn=init_fn,
                  restore_fn=restore_fn,
                  restore_include=restore_include,
                  restore_exclude=restore_exclude,
                  save_all_scope=save_all_scope,
                  variables_to_restore=variables_to_restore,
                  variables_to_save=variables_to_save,
                  output_collection_names=output_collection_names,
                  output_node_names=output_node_names,
                  learning_rate=learning_rate,
                  learning_rate_patience=learning_rate_patience,
                  learning_rate_decay_factor=learning_rate_decay_factor,
                  write_during_train=write_during_train,
                  model=model,
                  sess=sess)
示例#2
0
def train_flow(ops,
               names=None,
               gen_feed_dict_fn=None,
               deal_results_fn=melt.print_results,
               interval_steps=100,
               eval_ops=None,
               eval_names=None,
               gen_eval_feed_dict_fn=None,
               deal_eval_results_fn=melt.print_results,
               eval_interval_steps=100,
               print_time=True,
               print_avg_loss=True,
               model_dir='./model',
               max_models_keep=5,
               save_interval_seconds=600,
               save_interval_steps=1000,
               log_dir=None,
               no_log=False,
               num_epochs=None,
               num_steps=None,
               num_steps_per_epoch=None,
               optimizer=None,
               learning_rate=0.1,
               save_model=True,
               save_interval_epochs=True,
               add_train_var_histogram=False,
               restore_from_latest=True,
               metric_eval_fn=None,
               metric_eval_interval_steps=0,
               summary_excls=None,
               init_fn=None,
               sess=None):
    """Training loop with model saving/reloading and summary logging.

    Summary logs are also written under ``model_dir`` (unless disabled via
    ``no_log``).  See examples/sparse-tensor-classification/train-melt-savemodel.py
    for a usage example.

    NOTICE: when ``optimizer`` is None, the caller must pass the train op
    (optimizer-related) as ``ops[0]``; when ``optimizer`` is given, ``ops[0]``
    is taken to be the loss and a train op is created and prepended here.

    Args:
      ops: ops to run each training step; see NOTICE above for ops[0].
      optimizer: None (train op supplied by caller), a string name resolved
        via melt.gen_train_op_byname, or a callable taking ``learning_rate``
        and exposing ``.minimize(loss)``.
      model_dir: checkpoint directory (default './model'); pass a falsy value
        to fall back to simple_train_flow (no model saving).
      sess: optional pre-built session; otherwise created downstream.

    #@TODO allow adding momentum for optimizer
    Allows multiple GPUs.
    @TODO can we show epoch num info?
    """
    if optimizer is not None:
        # ops[0] is the loss; build the train op and prepend it so that
        # downstream code sees the train op first (matching the NOTICE above).
        loss = ops[0]
        if isinstance(optimizer, str):
            train_op = melt.gen_train_op_byname(loss, learning_rate, optimizer)
        else:
            train_op = optimizer(learning_rate).minimize(loss)
        # Copy before insert so the caller's sequence is not mutated.
        ops = list(ops)
        ops.insert(0, train_op)

    if not model_dir:
        if log_dir and no_log:
            log_dir = None
        # No model dir: nothing to checkpoint, so use the lightweight loop.
        return simple_train_flow(
            ops,
            names,
            gen_feed_dict_fn,
            deal_results_fn,
            interval_steps,
            eval_ops,
            eval_names,
            gen_eval_feed_dict_fn,
            deal_eval_results_fn,
            eval_interval_steps,
            print_time,
            print_avg_loss,
            log_dir,
            num_steps,
            num_steps_per_epoch=num_steps_per_epoch,
            metric_eval_fn=metric_eval_fn,
            metric_eval_interval_steps=metric_eval_interval_steps,
            sess=sess)

    # If log_dir is not set, default to storing logs under model_dir.
    # So the default is to write logs; to save the model but disable
    # logging, pass no_log=True.
    if save_model:
        print('Will save model to %s' % model_dir)
    else:
        # Not saving implies not logging either (existing checkpoints in
        # model_dir may still be restored).
        no_log = True
        print('Will not save model, only read model from %s if exists' %
              model_dir)

    if not log_dir and not no_log:
        log_dir = gezi.get_dir(model_dir)
    if log_dir:
        print('Will save log to %s' % log_dir)
        if add_train_var_histogram:
            # Add histograms for trainable variables — useful for inspecting
            # all trainable variables on TensorBoard.
            # NOTICE: for big models this is too slow!
            melt.monitor_train_vars()
    else:
        print('Will not save log')

    # Closure handed to tf_train_flow; binds all per-step configuration so
    # tf_train_flow only needs to supply the step-varying arguments.
    def train_once_(sess, step, is_start=False):
        train_once(sess,
                   step,
                   ops,
                   names,
                   gen_feed_dict_fn,
                   deal_results_fn,
                   interval_steps,
                   eval_ops,
                   eval_names,
                   gen_eval_feed_dict_fn,
                   deal_eval_results_fn,
                   eval_interval_steps,
                   print_time,
                   print_avg_loss,
                   model_dir,
                   log_dir,
                   is_start,
                   num_steps_per_epoch,
                   metric_eval_fn=metric_eval_fn,
                   metric_eval_interval_steps=metric_eval_interval_steps,
                   summary_excls=summary_excls)

    # Delegate the actual loop (session management, checkpointing, restore,
    # periodic eval) to tf_train_flow.
    tf_train_flow(train_once_,
                  model_dir,
                  max_models_keep,
                  save_interval_seconds,
                  save_interval_steps,
                  num_epochs,
                  num_steps,
                  save_model=save_model,
                  save_interval_epochs=save_interval_epochs,
                  num_steps_per_epoch=num_steps_per_epoch,
                  restore_from_latest=restore_from_latest,
                  metric_eval_fn=metric_eval_fn,
                  init_fn=init_fn,
                  sess=sess)