Ejemplo n.º 1
0
def main(argv):
  """Assembles UNet params from flags and hands them to the executor.

  Parameter precedence, lowest to highest:
    1. `unet_config.UNET_CONFIG` (checked against `UNET_RESTRICTIONS`),
    2. `FLAGS.config_file`,
    3. `FLAGS.params_overrides`,
    4. the individual path/timing/TPU flags set explicitly below.
  The resulting params are validated and locked before use.

  Args:
    argv: Positional command-line arguments; ignored.
  """
  del argv  # Unused.

  params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                  unet_config.UNET_RESTRICTIONS)
  # Layer the config file first, then the inline overrides, so the latter win.
  for override_source in (FLAGS.config_file, FLAGS.params_overrides):
    params = params_dict.override_params_dict(
        params, override_source, is_strict=False)

  flag_values = dict(
      training_file_pattern=FLAGS.training_file_pattern,
      eval_file_pattern=FLAGS.eval_file_pattern,
      model_dir=FLAGS.model_dir,
      min_eval_interval=FLAGS.min_eval_interval,
      eval_timeout=FLAGS.eval_timeout,
      tpu_config=tpu_executor.get_tpu_flags(),
  )
  params.override(flag_values, is_strict=False)
  params.validate()
  params.lock()  # No further mutation once the executor receives it.

  run_mode = FLAGS.mode
  wants_train = run_mode in ('train', 'train_and_eval')
  wants_eval = run_mode in ('eval', 'train_and_eval')

  # Only build the input pipelines the selected mode actually needs.
  train_input_fn = (
      input_reader.LiverInputFn(
          params.training_file_pattern,
          params,
          mode=tf.estimator.ModeKeys.TRAIN) if wants_train else None)
  eval_input_fn = (
      input_reader.LiverInputFn(
          params.eval_file_pattern,
          params,
          mode=tf.estimator.ModeKeys.EVAL) if wants_eval else None)

  run_executer(
      params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
Ejemplo n.º 2
0
def main(_):
  """Exports the UNet model as a SavedModel for serving.

  Builds the UNet params from `unet_config.UNET_CONFIG` plus
  `FLAGS.config_file`. It then wraps `serving_model_fn` in a TPUEstimator and
  exports it under `FLAGS.export_dir`, using `serving_input_fn` as the serving
  input receiver. Weights are restored from `FLAGS.checkpoint_path`. The
  resulting export directory is printed.

  Args:
    _: Positional command-line arguments; unused.
  """
  params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                  unet_config.UNET_RESTRICTIONS)
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=False)
  # Train and eval batch sizes both follow the export batch size.
  params.train_batch_size = FLAGS.batch_size
  params.eval_batch_size = FLAGS.batch_size
  # NOTE(review): bfloat16 is forced off, presumably so the exported graph is
  # usable on CPU as well (export_to_cpu=True below) -- confirm.
  params.use_bfloat16 = False

  model_params = dict(
      params.as_dict(),
      use_tpu=FLAGS.use_tpu,
      mode=tf.estimator.ModeKeys.PREDICT,
      transpose_input=False)

  print(' - Setting up TPUEstimator...')
  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=serving_model_fn,
      model_dir=FLAGS.model_dir,
      config=tf.estimator.tpu.RunConfig(
          tpu_config=tf.estimator.tpu.TPUConfig(
              iterations_per_loop=FLAGS.iterations_per_loop),
          master='local',
          evaluation_master='local'),
      params=model_params,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      predict_batch_size=FLAGS.batch_size,
      export_to_tpu=FLAGS.use_tpu,
      export_to_cpu=True)

  print(' - Exporting the model...')
  export_path = estimator.export_saved_model(
      export_dir_base=FLAGS.export_dir,
      serving_input_receiver_fn=functools.partial(
          serving_input_fn,
          batch_size=FLAGS.batch_size,
          input_type=FLAGS.input_type,
          params=params,
          input_name=FLAGS.input_name),
      checkpoint_path=FLAGS.checkpoint_path)

  # `export_saved_model` returns the export directory as bytes. Decode it so
  # the message shows the plain path instead of a b'...' literal.
  if isinstance(export_path, bytes):
    export_path = export_path.decode('utf-8')
  print(' - Done! path: %s' % export_path)