Example #1
0
    def test_latest_module_exporter_with_eval_spec(self):
        """Exports three times and checks that only the newest two are kept.

        Despite its name, this variant calls `LatestModuleExporter.export`
        directly instead of going through `tf.estimator.EvalSpec`.
        """
        model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        estimator = tf.estimator.Estimator(_get_model_fn(register_module=True),
                                           model_dir=model_dir)
        exporter = hub.LatestModuleExporter("tf_hub",
                                            _serving_input_fn,
                                            exports_to_keep=2)
        estimator.train(_input_fn, max_steps=1)
        export_base_dir = os.path.join(model_dir, "export", "tf_hub")

        # First export: a single timestamped directory containing the module.
        exporter.export(estimator, export_base_dir)
        timestamp_dirs = tf.gfile.ListDirectory(export_base_dir)
        self.assertEqual(1, len(timestamp_dirs))
        oldest_timestamp = timestamp_dirs[0]

        expected_module_dir = os.path.join(export_base_dir, oldest_timestamp,
                                           _EXPORT_MODULE_NAME)
        self.assertTrue(tf.gfile.IsDirectory(expected_module_dir))

        # Second export: still within the `exports_to_keep=2` budget.
        exporter.export(estimator, export_base_dir)
        timestamp_dirs = tf.gfile.ListDirectory(export_base_dir)
        self.assertEqual(2, len(timestamp_dirs))

        # Triggering yet another export should clean the oldest export.
        exporter.export(estimator, export_base_dir)
        timestamp_dirs = tf.gfile.ListDirectory(export_base_dir)
        self.assertEqual(2, len(timestamp_dirs))
        self.assertNotIn(oldest_timestamp, timestamp_dirs)
Example #2
0
    def testLatestModuleExporterDirectly(self):
        """Calls `export` by hand and checks the returned path and its contents."""
        model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        export_base_dir = os.path.join(
            tempfile.mkdtemp(dir=self.get_temp_dir()), "export")

        estimator = tf.estimator.Estimator(_get_model_fn(register_module=True),
                                           model_dir=model_dir)
        estimator.train(input_fn=_input_fn, steps=1)

        exporter = hub.LatestModuleExporter("exporter_name", _serving_input_fn)
        export_dir = exporter.export(estimator=estimator,
                                     export_path=export_base_dir,
                                     eval_result=None,
                                     is_the_final_export=None)

        # Check that a timestamped directory is created in the expected location.
        timestamp_dirs = tf.gfile.ListDirectory(export_base_dir)
        self.assertEqual(1, len(timestamp_dirs))
        # Both sides are normalised to bytes so the comparison does not depend
        # on whether `export` hands back `str` or `bytes`.
        self.assertEqual(
            tf.compat.as_bytes(os.path.join(export_base_dir,
                                            timestamp_dirs[0])),
            tf.compat.as_bytes(export_dir))

        # Check the timestamped directory contains the exported modules inside.
        expected_module_dir = os.path.join(
            tf.compat.as_bytes(export_dir),
            tf.compat.as_bytes(_EXPORT_MODULE_NAME))
        self.assertTrue(tf.gfile.IsDirectory(expected_module_dir))
Example #3
0
def exporter():
    """Build the TF-Hub module exporter together with its serving inputs.

    Returns:
      A `hub.LatestModuleExporter` named "tf_hub" that keeps one export and
      is fed by a raw serving input receiver over six placeholders.
    """

    def _int_matrix():
        # int32 placeholder with two unknown dimensions.
        return tf.placeholder(tf.int32, [None, None])

    # Placeholders are created in the same order as the keys are listed.
    features = {
        "block_ids": _int_matrix(),
        "block_mask": _int_matrix(),
        "block_segment_ids": _int_matrix(),
        "query_ids": _int_matrix(),
        "query_mask": _int_matrix(),
        "mask_query": tf.placeholder(tf.bool, [None]),
    }
    receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        features=features, default_batch_size=8)
    return hub.LatestModuleExporter("tf_hub", receiver_fn, exports_to_keep=1)
Example #4
0
    def __init__(self, name, serving_input_fn, compare_fn, exports_to_keep=5):
        """Sets up a best-result exporter for use with tf.estimator.EvalSpec.

        The actual exporting is delegated to a `hub.LatestModuleExporter`
        built from `name`, `serving_input_fn` and `exports_to_keep`; no best
        evaluation result is recorded yet.

        Args:
          name: unique name of this Exporter, used in the export path.
          serving_input_fn: a function with no arguments that returns a
            ServingInputReceiver. It is handed to the wrapped
            LatestModuleExporter unchanged.
          compare_fn: a function comparing two evaluation results. It takes
            `best_eval_result` and `current_eval_result` and returns True if
            the current result is better, False otherwise.
          exports_to_keep: number of exports to keep; older exports are
            garbage collected. Set to None to disable.
        """
        delegate = hub.LatestModuleExporter(
            name, serving_input_fn, exports_to_keep=exports_to_keep)
        self._compare_fn = compare_fn
        self._best_eval_result = None
        self._latest_module_exporter = delegate
Example #5
0
  def test_latest_module_exporter_with_eval_spec(self):
    """Runs train_and_evaluate three times and checks export retention."""
    model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    estimator = tf.estimator.Estimator(_get_model_fn(register_module=True),
                                       model_dir=model_dir)

    train_spec = tf.estimator.TrainSpec(
        input_fn=_input_fn,
        max_steps=1)

    eval_spec = tf.estimator.EvalSpec(
        input_fn=_input_fn,
        exporters=[
            hub.LatestModuleExporter(
                "tf_hub",
                _serving_input_fn,
                exports_to_keep=2),
        ])

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    # The exporter writes below <model_dir>/export/<exporter name>.
    export_base_dir = os.path.join(model_dir, "export", "tf_hub")
    timestamp_dirs = tf.gfile.ListDirectory(export_base_dir)
    self.assertEqual(1, len(timestamp_dirs))
    oldest_timestamp = timestamp_dirs[0]

    expected_module_dir = os.path.join(export_base_dir,
                                       oldest_timestamp,
                                       _EXPORT_MODULE_NAME)
    self.assertTrue(tf.gfile.IsDirectory(expected_module_dir))

    # Triggering a new train and evaluate should create a new timestamped
    # exported directory inside tf_hub exporter.
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    timestamp_dirs = tf.gfile.ListDirectory(export_base_dir)
    self.assertEqual(2, len(timestamp_dirs))

    # Triggering yet another export should clean the oldest export.
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    timestamp_dirs = tf.gfile.ListDirectory(export_base_dir)
    self.assertEqual(2, len(timestamp_dirs))
    self.assertNotIn(oldest_timestamp, timestamp_dirs)
Example #6
0
def train_local(
    name,
    src_dir,
    data_dir,
    model_dir,
    run_config,
    train_spec,
    eval_spec,
    params):
  """Trains and evaluates the `doodle` model locally, then prints its metrics.

  Args:
    name: sub-directory of `model_dir` used as the estimator's model dir.
    src_dir: directory appended to `sys.path` so `doodle` can be imported.
    data_dir: forwarded to the train and eval input functions.
    model_dir: parent directory for the model directory.
    run_config: extra keyword arguments for `tf.estimator.RunConfig`.
    train_spec: extra keyword arguments for `tf.estimator.TrainSpec`.
    eval_spec: extra keyword arguments for `tf.estimator.EvalSpec`.
    params: passed to the estimator and to every input function.
  """
  # The project package only becomes importable once src_dir is on sys.path.
  sys.path.append(src_dir)
  from doodle.inputs import train_input_fn, eval_input_fn, serving_input_fn
  from doodle.model import model_fn

  def _train_inputs():
    return train_input_fn(data_dir, params)

  def _eval_inputs():
    return eval_input_fn(data_dir, params)

  config = tf.estimator.RunConfig(
      model_dir=os.path.join(model_dir, name),
      **run_config)

  training = tf.estimator.TrainSpec(input_fn=_train_inputs, **train_spec)

  # One SavedModel exporter and one TF-Hub module exporter per evaluation.
  exporters = [
      tf.estimator.LatestExporter('savedmodel', serving_input_fn(params)),
      tfhub.LatestModuleExporter('hub', serving_input_fn(params)),
  ]
  evaluation = tf.estimator.EvalSpec(
      input_fn=_eval_inputs,
      exporters=exporters,
      **eval_spec)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=config,
      params=params)

  tf.estimator.train_and_evaluate(estimator, training, evaluation)

  # Final evaluation pass whose metrics are printed sorted by name.
  metrics = estimator.evaluate(evaluation.input_fn, steps=evaluation.steps)
  print('###### metrics ' + '#' * 65)
  for metric_name, metric_value in sorted(metrics.items()):
    print('{:<30}: {}'.format(metric_name, metric_value))
Example #7
0
    def test_latest_module_exporter_with_no_modules(self):
        """Exporting a model without registered modules must create nothing."""
        model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        export_parent = tempfile.mkdtemp(dir=self.get_temp_dir())
        export_base_dir = os.path.join(export_parent, "export")
        # Precondition: the export location does not exist yet.
        self.assertFalse(tf.gfile.Exists(export_base_dir))

        model_fn = _get_model_fn(register_module=False)
        estimator = tf.estimator.Estimator(model_fn, model_dir=model_dir)
        estimator.train(input_fn=_input_fn, steps=1)

        export_dir = hub.LatestModuleExporter(
            "exporter_name", _serving_input_fn).export(
                estimator=estimator,
                export_path=export_base_dir,
                eval_result=None,
                is_the_final_export=None)

        # Nothing was exported, so no path is returned ...
        self.assertIsNone(export_dir)

        # ... and no directory has been created in the expected location.
        self.assertFalse(tf.gfile.Exists(export_base_dir))
Example #8
0
def save_module(model, savepath, max_steps):
    """Exports `model` as a TF-Hub module and runs `check_module` on the result.

    Args:
      model: estimator to export; its `latest_checkpoint()` selects the
        checkpoint that is exported.
      savepath: path prefix. The literal 'module' is appended by plain string
        concatenation, so include a trailing separator if a sub-directory is
        intended.
      max_steps: unused; kept so existing callers keep working.

    Raises:
      RuntimeError: if the exporter did not produce an export directory.
    """
    print('Save module')

    # `nchannels` and `ntargets` are module-level names defined outside this
    # function.
    features = tf.placeholder(tf.float32,
                              shape=[None, None, None, None, nchannels],
                              name='input')
    labels = tf.placeholder(tf.float32,
                            shape=[None, None, None, None, ntargets],
                            name='labels')
    exporter = hub.LatestModuleExporter(
        "tf_hub",
        tf.estimator.export.build_raw_serving_input_receiver_fn(
            {
                'features': features,
                'labels': labels
            }, default_batch_size=None))
    modpath = exporter.export(model, savepath + 'module',
                              model.latest_checkpoint())
    # The exporter can return None when nothing was exported; fail with a
    # clear message instead of an AttributeError on `.decode`.
    if modpath is None:
        raise RuntimeError(
            'LatestModuleExporter.export returned no export directory')
    # as_text accepts both bytes and str, so the return type does not matter.
    modpath = tf.compat.as_text(modpath)
    check_module(modpath)
def train_and_eval():
    """Trains a network on (self) supervised data, or evaluates checkpoints.

    With `FLAGS.run_eval` set, every checkpoint appearing in the model
    directory is evaluated and exported as a TF-Hub module until the
    checkpoint iterator times out or a `TRAINING_IS_DONE` file shows up in
    `FLAGS.workdir`. Otherwise the estimator is trained.

    Returns:
      The result of the final `estimator.evaluate` call in evaluation mode,
      `None` in training mode.
    """
    checkpoint_dir = os.path.join(FLAGS.workdir)

    master = ''
    if FLAGS.use_tpu:
        master = TPUClusterResolver(tpu=[os.environ['TPU_NAME']]).get_master()

    tpu_config = tf.contrib.tpu.TPUConfig(
        iterations_per_loop=TPU_ITERATIONS_PER_LOOP,
        tpu_job_name=FLAGS.tpu_worker_name)
    config = tf.contrib.tpu.RunConfig(
        model_dir=checkpoint_dir,
        tf_random_seed=FLAGS.get_flag_value('random_seed', None),
        master=master,
        evaluation_master=master,
        keep_checkpoint_every_n_hours=FLAGS.get_flag_value(
            'keep_checkpoint_every_n_hours', 4),
        save_checkpoints_secs=FLAGS.get_flag_value('save_checkpoints_secs',
                                                   600),
        tpu_config=tpu_config)

    # The global batch-sizes are passed to the TPU estimator, and it will pass
    # along the local batch size in the model_fn's `params` argument dict.
    eval_batch_size = FLAGS.get_flag_value('eval_batch_size',
                                           FLAGS.batch_size)
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=get_self_supervision_model(FLAGS.task),
        model_dir=checkpoint_dir,
        config=config,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=eval_batch_size)

    if not FLAGS.run_eval:
        train_split = FLAGS.get_flag_value('train_split', 'train')
        train_data_fn = functools.partial(
            datasets.get_data,
            split_name=train_split,
            is_training=True,
            num_epochs=int(math.ceil(FLAGS.epochs)),
            drop_remainder=True)

        # The step count is computed here and given to the estimator as
        # max_steps, rather than letting the dataset iterator run out, so
        # that 'fractional' epochs work (used by regression tests, and needed
        # by TPUEstimator anyway).
        # Depending on whether the last batch is dropped each epoch or only
        # at the very end, this should be ordered differently for rounding.
        updates_per_epoch = datasets.get_count(train_split) // FLAGS.batch_size
        num_steps = int(math.ceil(FLAGS.epochs * updates_per_epoch))
        tf.logging.info('train_steps: %d', num_steps)

        estimator.train(train_data_fn, max_steps=num_steps)
        return None

    val_split = FLAGS.get_flag_value('val_split', 'val')
    data_fn = functools.partial(datasets.get_data,
                                split_name=val_split,
                                is_training=False,
                                shuffle=False,
                                num_epochs=1,
                                drop_remainder=FLAGS.use_tpu)

    # Contrary to what the documentation claims, `train` and `evaluate` NEED
    # `max_steps` and/or `steps` and cannot rely on the iterator's
    # end-of-input exception, so the step count is derived from the split size.
    num_steps = datasets.get_count(val_split) // eval_batch_size
    tf.logging.info('val_steps: %d', num_steps)

    export_dir = os.path.join(checkpoint_dir, 'export/hub')
    done_marker = os.path.join(FLAGS.workdir, 'TRAINING_IS_DONE')
    checkpoints = tf.contrib.training.checkpoints_iterator(
        estimator.model_dir, timeout=10 * 60)
    for checkpoint in checkpoints:
        estimator.evaluate(checkpoint_path=checkpoint,
                           input_fn=data_fn,
                           steps=num_steps)

        # Export the checkpoint that was just evaluated as a hub module.
        hub.LatestModuleExporter('hub', serving_input_fn).export(
            estimator, export_dir, checkpoint)

        if tf.gfile.Exists(done_marker):
            break

    # Evaluates the latest checkpoint on validation set.
    return estimator.evaluate(input_fn=data_fn, steps=num_steps)
Example #10
0
def train_and_eval():
  """Trains a network on (self) supervised data, or evaluates its checkpoints.

  With `FLAGS.run_eval` set, every checkpoint appearing in the model directory
  is evaluated on the validation split and exported as a TF-Hub module; the
  loop ends when the checkpoint iterator times out or a `TRAINING_IS_DONE`
  file exists in `FLAGS.workdir`. The latest checkpoint is then evaluated once
  more on the validation split and, if `FLAGS.test_split` is set, on the test
  split. Without `FLAGS.run_eval` the estimator is trained.

  Returns:
    The final validation metrics in evaluation mode, otherwise whatever
    `estimator.train` returns.
  """
  checkpoint_dir = FLAGS.get_flag_value("checkpoint", FLAGS.workdir)
  tf.gfile.MakeDirs(checkpoint_dir)

  # A single truthiness test decides both whether a cluster resolver is built
  # and whether the estimator targets a TPU. Previously an empty (but not
  # None) tpu_name produced use_tpu=True without any cluster.
  use_tpu = bool(FLAGS.tpu_name)
  cluster = TPUClusterResolver(tpu=[FLAGS.tpu_name]) if use_tpu else None

  config = RunConfig(
      model_dir=checkpoint_dir,
      tf_random_seed=FLAGS.random_seed,
      cluster=cluster,
      keep_checkpoint_max=None,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=TPUConfig(iterations_per_loop=TPU_ITERATIONS_PER_LOOP))

  # Optionally resume from a stored checkpoint.
  if FLAGS.path_to_initial_ckpt:
    warm_start_from = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=FLAGS.path_to_initial_ckpt,
        # The square bracket is important for loading all the
        # variables from GLOBAL_VARIABLES collection.
        # See https://www.tensorflow.org/api_docs/python/tf/estimator/WarmStartSettings  # pylint: disable=line-too-long
        # section vars_to_warm_start for more details.
        vars_to_warm_start=[FLAGS.vars_to_restore]
    )
  else:
    warm_start_from = None

  eval_batch_size = FLAGS.get_flag_value("eval_batch_size", FLAGS.batch_size)

  # The global batch-sizes are passed to the TPU estimator, and it will pass
  # along the local batch size in the model_fn's `params` argument dict.
  estimator = TPUEstimator(
      model_fn=semi_supervised.get_model(FLAGS.task),
      model_dir=checkpoint_dir,
      config=config,
      use_tpu=use_tpu,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=eval_batch_size,
      warm_start_from=warm_start_from
  )

  if FLAGS.run_eval:

    def eval_data_fn(split_name):
      # Single-pass, unshuffled input over `split_name` for evaluation.
      return functools.partial(
          datasets.get_data,
          split_name=split_name,
          preprocessing=FLAGS.get_flag_value("preprocessing_eval",
                                             FLAGS.preprocessing),
          is_training=False,
          shuffle=False,
          num_epochs=1,
          drop_remainder=True)

    data_fn = eval_data_fn(FLAGS.val_split)

    # Contrary to what the documentation claims, the `train` and the
    # `evaluate` functions NEED to have `max_steps` and/or `steps` set and
    # cannot make use of the iterator's end-of-input exception, so we need
    # to do some math for that here.
    num_steps = datasets.get_count(FLAGS.val_split) // eval_batch_size
    tf.logging.info("val_steps: %d", num_steps)

    # The exporter and both paths do not change between checkpoints.
    hub_exporter = hub.LatestModuleExporter("hub", serving_input_fn)
    export_dir = os.path.join(checkpoint_dir, "export/hub")
    done_marker = os.path.join(FLAGS.workdir, "TRAINING_IS_DONE")

    for checkpoint in checkpoints_iterator(
        estimator.model_dir, timeout=FLAGS.eval_timeout_mins * 60):

      estimator.evaluate(
          checkpoint_path=checkpoint, input_fn=data_fn, steps=num_steps)
      hub_exporter.export(estimator, export_dir, checkpoint)

      # This is here instead of using the above `checkpoints_iterator`'s
      # `timeout_fn` param, because that would wait forever on failed
      # trainers which will never create this file.
      if tf.gfile.Exists(done_marker):
        break

    # Evaluates the latest checkpoint on validation set.
    result_dict_val = estimator.evaluate(input_fn=data_fn, steps=num_steps)
    tf.logging.info(result_dict_val)

    # Optionally evaluates the latest checkpoint on test set.
    if FLAGS.test_split:
      test_steps = datasets.get_count(FLAGS.test_split) // eval_batch_size
      result_dict_test = estimator.evaluate(
          input_fn=eval_data_fn(FLAGS.test_split), steps=test_steps)
      tf.logging.info(result_dict_test)
    return result_dict_val

  else:
    train_data_fn = functools.partial(
        datasets.get_data,
        split_name=FLAGS.train_split,
        preprocessing=FLAGS.preprocessing,
        is_training=True,
        num_epochs=None,  # read data indefinitely for training
        drop_remainder=True)

    # We compute the number of steps and make use of Estimator's max_steps
    # arguments instead of relying on the Dataset's iterator to run out after
    # a number of epochs so that we can use "fractional" epochs, which are
    # used by regression tests. (And because TPUEstimator needs it anyways.)
    num_samples = datasets.get_count(FLAGS.train_split)
    if FLAGS.num_supervised_examples:
      num_samples = FLAGS.num_supervised_examples
    # Depending on whether we drop the last batch each epoch or only at the
    # very end, this should be ordered differently for rounding.
    updates_per_epoch = num_samples // FLAGS.batch_size
    # The last entry of the schedule is used as the total number of epochs.
    epochs = utils.str2intlist(FLAGS.schedule, strict_int=False)[-1]
    num_steps = int(math.ceil(epochs * updates_per_epoch))
    tf.logging.info("train_steps: %d", num_steps)

    return estimator.train(
        train_data_fn,
        max_steps=num_steps)
Example #11
0
def main(argv):
    """Trains, evaluates and exports the model configured through FLAGS.

    The flag values are turned into a `params` dict, augmented with
    data-shape information and output directories, pickled to `./params/`,
    and then used to build the estimator. Training runs in chunks of
    `FLAGS.n_steps`; after each chunk the model is evaluated and exported as a
    TF-Hub module into `params['module_dir']`.

    Args:
      argv: unused command line arguments.

    Returns:
      True once all training chunks have completed.
    """
    del argv

    params = FLAGS.flag_values_dict()
    DATA_SHAPE = DATA_SHAPES[FLAGS.data_set]

    # Resolve the activation name to the actual function in tf.nn.
    params['activation'] = getattr(tf.nn, params['activation'])

    if len(DATA_SHAPE) > 2:
        params['width'] = DATA_SHAPE[0]
        params['height'] = DATA_SHAPE[1]
        params['n_channels'] = DATA_SHAPE[2]
    else:
        params['length'] = DATA_SHAPE[0]
        params['n_channels'] = DATA_SHAPE[1]

    params['data_shape'] = DATA_SHAPE

    # Convolutional networks keep the data shape; every other network type
    # works on flattened inputs.
    if params['network_type'] == 'conv':
        flatten = False
        params['output_size'] = DATA_SHAPE
        params['full_size'] = [None, params['width'], params['height'],
                               params['n_channels']]
    else:
        flatten = True
        params['output_size'] = np.prod(DATA_SHAPE)
        params['full_size'] = [None, params['output_size']]

    params['label'] = os.path.join(
        '%s' % params['data_set'],
        '%s' % params['likelihood'],
        'class%d' % params['class_label'],
        'latent_size%d' % params['latent_size'],
        'net_type_%s' % params['network_type'],
        params['tag'])
    if params['AE']:
        params['label'] += 'AE'

    params['model_dir'] = os.path.join(params['model_dir'], params['label'])
    params['module_dir'] = os.path.join(params['module_dir'], params['label'])

    for dd in ['model_dir', 'module_dir', 'data_dir']:
        os.makedirs(params[dd], exist_ok=True)
    os.makedirs('./params', exist_ok=True)

    # Persist the run configuration; the file is closed even if pickling fails.
    params_path = './params/params_%s_%s_%d_%d_%s%s.pkl' % (
        params['data_set'], params['likelihood'], params['class_label'],
        params['latent_size'], params['network_type'],
        '-AE' if params['AE'] else '')
    with open(params_path, 'wb') as params_file:
        pkl.dump(params, params_file)

    if params['data_set'] == 'celeba':
        input_fns = crd.build_input_fn_celeba(params)
        train_input_fn = input_fns['train']
        eval_input_fn = input_fns['validation']
    else:
        train_input_fn, eval_input_fn = crd.build_input_fns(
            params, label=FLAGS.class_label, flatten=flatten)

    estimator = tf.estimator.Estimator(
        model_fn, params=params,
        config=tf.estimator.RunConfig(model_dir=params['model_dir']))
    c = tf.placeholder(tf.float32, params['full_size'])
    serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        features=dict(x=c))

    exporter = hub.LatestModuleExporter("tf_hub", serving_input_fn)

    n_steps = FLAGS.n_steps
    # Train in chunks so that an evaluation and an export happen every
    # n_steps steps.
    for _ in range(FLAGS.max_steps // n_steps):
        estimator.train(train_input_fn, steps=n_steps)
        eval_results = estimator.evaluate(eval_input_fn)
        print('model evaluation on test set:', eval_results)
        exporter.export(estimator, params['module_dir'],
                        estimator.latest_checkpoint())

    return True
Example #12
0
File: main.py Project: kjjjjjjj/PAE
def main(argv):
    """Trains, evaluates and exports the model configured through FLAGS.

    The flag values are turned into a `params` dict, augmented with
    data-shape information, a descriptive tag and output directories, pickled
    to `./params/`, and then used to build the estimator. Training runs in
    chunks of `FLAGS.n_steps`; after each chunk the model is evaluated and
    exported as a TF-Hub module into `params['module_dir']`.

    Args:
      argv: unused command line arguments.

    Returns:
      True once all training chunks have completed.
    """
    del argv
    DATA_SHAPES = dict(mnist=[28, 28, 1],
                       fmnist=[28, 28, 1],
                       cifar10=[32, 32, 3],
                       celeba=[FLAGS.celeba_dim, FLAGS.celeba_dim, 3],
                       banana=[32, 1])
    params = FLAGS.flag_values_dict()
    DATA_SHAPE = DATA_SHAPES[FLAGS.data_set]
    # Resolve the activation name to the actual function in tf.nn.
    params['activation'] = getattr(tf.nn, params['activation'])
    print(DATA_SHAPE)
    if len(DATA_SHAPE) > 2:
        params['width'] = DATA_SHAPE[0]
        params['height'] = DATA_SHAPE[1]
        params['n_channels'] = DATA_SHAPE[2]
    else:
        params['length'] = DATA_SHAPE[0]
        params['n_channels'] = DATA_SHAPE[1]

    params['data_shape'] = DATA_SHAPE

    # Convolution-style networks keep the data shape; every other network
    # type works on flattened inputs.
    if params['network_type'] in ['conv', 'infoGAN', 'resnet_conv']:
        flatten = False
        params['output_size'] = DATA_SHAPE
        params['full_size'] = [params['batch_size'], params['width'],
                               params['height'], params['n_channels']]
    else:
        flatten = True
        params['output_size'] = np.prod(DATA_SHAPE)
        params['full_size'] = [params['batch_size'], params['output_size']]

    # Encode the main hyper-parameter choices in the tag, which ends up in
    # the output directories and in the params file name.
    params['tag'] += '_full_sigma' if params['full_sigma'] else '_mean_sigma'
    if params['beta'] and params['loss'] == 'VAE':
        params['tag'] += '_beta%d' % params['beta']
    if params['C_annealing'] and params['loss'] == 'VAE':
        params['tag'] += '_C%d' % params['C']

    params['label'] = os.path.join(
        '%s' % params['data_set'],
        'class%d' % params['class_label'],
        'latent_size%d' % params['latent_size'],
        'net_type_%s' % params['network_type'],
        'loss_%s' % params['loss'],
        params['tag'])

    params['model_dir'] = os.path.join(params['model_dir'], params['label'])
    params['module_dir'] = os.path.join(params['module_dir'], params['label'])
    print(params['module_dir'])

    for dd in ['model_dir', 'module_dir', 'data_dir']:
        os.makedirs(params[dd], exist_ok=True)
    os.makedirs('./params', exist_ok=True)

    # Persist the run configuration; the file is closed even if pickling fails.
    params_path = './params/params_%s_%d_%d_%s_%s_%s.pkl' % (
        params['data_set'], params['class_label'], params['latent_size'],
        params['network_type'], params['loss'], params['tag'])
    with open(params_path, 'wb') as params_file:
        pkl.dump(params, params_file)

    train_input_fn, eval_input_fn = crd.build_input_fns(
        params, label=FLAGS.class_label, flatten=flatten)

    estimator = tf.estimator.Estimator(
        model_fn, params=params,
        config=tf.estimator.RunConfig(model_dir=params['model_dir']))
    c = tf.placeholder(tf.float32, params['full_size'])
    serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        features=dict(x=c))

    exporter = hub.LatestModuleExporter("tf_hub", serving_input_fn)
    # NOTE(review): n_epoch is never incremented, so the loop always prints 0.
    # Left as-is because the intended epoch length is not visible here.
    n_epoch = 0
    n_steps = FLAGS.n_steps
    # Train in chunks so that an evaluation and an export happen every
    # n_steps steps.
    for _ in range(FLAGS.max_steps // n_steps):
        estimator.train(train_input_fn, steps=n_steps)
        eval_results = estimator.evaluate(eval_input_fn)
        print('model evaluation on test set:', eval_results)
        print('n_epoch', n_epoch)
        exporter.export(estimator, params['module_dir'],
                        estimator.latest_checkpoint())
    return True