def main(argv):
  """Assembles U-Net hyperparameters and hands them to `run_executer`.

  Parameters start from `unet_config.UNET_CONFIG` (checked against
  `unet_config.UNET_RESTRICTIONS`). They are then overridden, non-strictly
  and in this order, by `FLAGS.config_file`, by `FLAGS.params_overrides`,
  and finally by the individual command-line flags below. The result is
  validated and locked before any input pipeline is created.

  Args:
    argv: Command-line arguments; ignored.
  """
  del argv  # Unused.

  config = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                  unet_config.UNET_RESTRICTIONS)
  # Config file first, then the free-form overrides string; later wins.
  for override_source in (FLAGS.config_file, FLAGS.params_overrides):
    config = params_dict.override_params_dict(
        config, override_source, is_strict=False)

  flag_values = {
      'training_file_pattern': FLAGS.training_file_pattern,
      'eval_file_pattern': FLAGS.eval_file_pattern,
      'model_dir': FLAGS.model_dir,
      'min_eval_interval': FLAGS.min_eval_interval,
      'eval_timeout': FLAGS.eval_timeout,
      'tpu_config': tpu_executor.get_tpu_flags(),
  }
  config.override(flag_values, is_strict=False)
  config.validate()
  config.lock()

  mode = FLAGS.mode
  # Only build the readers the selected mode actually needs; the other
  # stays None.
  train_reader = (
      input_reader.LiverInputFn(
          config.training_file_pattern,
          config,
          mode=tf.estimator.ModeKeys.TRAIN)
      if mode in ('train', 'train_and_eval') else None)
  eval_reader = (
      input_reader.LiverInputFn(
          config.eval_file_pattern,
          config,
          mode=tf.estimator.ModeKeys.EVAL)
      if mode in ('eval', 'train_and_eval') else None)

  run_executer(
      config, train_input_fn=train_reader, eval_input_fn=eval_reader)
def main(_):
  """Exports the U-Net model as a SavedModel via `TPUEstimator`.

  Hyperparameters come from `unet_config.UNET_CONFIG`, overridden
  (non-strictly) by `FLAGS.config_file`. Train and eval batch sizes are
  forced to `FLAGS.batch_size` and bfloat16 is disabled. The estimator is
  built around `serving_model_fn` in PREDICT mode. The export reads weights
  from `FLAGS.checkpoint_path` and writes under `FLAGS.export_dir`. The
  resulting export path is printed.
  """
  hparams = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                   unet_config.UNET_RESTRICTIONS)
  hparams = params_dict.override_params_dict(
      hparams, FLAGS.config_file, is_strict=False)

  batch = FLAGS.batch_size
  hparams.train_batch_size = batch
  hparams.eval_batch_size = batch
  hparams.use_bfloat16 = False

  # Flattened params plus the extra keys the model_fn expects at export time.
  estimator_params = dict(hparams.as_dict())
  estimator_params.update(
      use_tpu=FLAGS.use_tpu,
      mode=tf.estimator.ModeKeys.PREDICT,
      transpose_input=False)

  print(' - Setting up TPUEstimator...')
  run_config = tf.estimator.tpu.RunConfig(
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop),
      master='local',
      evaluation_master='local')
  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=serving_model_fn,
      model_dir=FLAGS.model_dir,
      config=run_config,
      params=estimator_params,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=batch,
      predict_batch_size=batch,
      export_to_tpu=FLAGS.use_tpu,
      export_to_cpu=True)

  print(' - Exporting the model...')
  receiver_fn = functools.partial(
      serving_input_fn,
      batch_size=batch,
      input_type=FLAGS.input_type,
      params=hparams,
      input_name=FLAGS.input_name)
  saved_path = estimator.export_saved_model(
      export_dir_base=FLAGS.export_dir,
      serving_input_receiver_fn=receiver_fn,
      checkpoint_path=FLAGS.checkpoint_path)
  print(' - Done! path: %s' % saved_path)