def main(argv):
    """Assemble the U-Net run parameters from flags and start the executor.

    Parameter layers are applied in this order, each one overriding the
    previous: the ``unet_config.UNET_CONFIG`` defaults, ``FLAGS.config_file``,
    the file-pattern flags (only when they are set), the flag-provided and
    derived values (model dir, eval settings, TPU config, step counts), and
    finally ``FLAGS.params_override``. The result is validated and locked
    before any input function is created.

    Args:
        argv: Positional command-line arguments; unused.

    Raises:
        ValueError: If neither training nor evaluation input shapes were
            obtained, which is what happens when ``FLAGS.mode`` is not one of
            ``'train'``, ``'eval'`` or ``'train_and_eval'``.
    """
    del argv  # Unused.

    params = params_dict.ParamsDict(
        unet_config.UNET_CONFIG, unet_config.UNET_RESTRICTIONS)
    params = params_dict.override_params_dict(
        params, FLAGS.config_file, is_strict=False)

    # Only let the file-pattern flags win when they were actually provided,
    # so that values coming from the config file are not clobbered.
    if FLAGS.training_file_pattern:
        params.override(
            {'training_file_pattern': FLAGS.training_file_pattern},
            is_strict=True)

    if FLAGS.eval_file_pattern:
        params.override(
            {'eval_file_pattern': FLAGS.eval_file_pattern},
            is_strict=True)

    # Step counts are derived from item counts and batch sizes using floor
    # division.
    train_epoch_steps = params.train_item_count // params.train_batch_size
    eval_epoch_steps = params.eval_item_count // params.eval_batch_size

    params.override(
        {
            'model_dir': FLAGS.model_dir,
            'min_eval_interval': FLAGS.min_eval_interval,
            'eval_timeout': FLAGS.eval_timeout,
            'tpu_config': tpu_executor.get_tpu_flags(),
            'lr_decay_steps': train_epoch_steps,
            'train_steps': params.train_epochs * train_epoch_steps,
            'eval_steps': eval_epoch_steps,
        },
        is_strict=False)

    # Applied last so explicit overrides take precedence over the derived
    # values above.
    params = params_dict.override_params_dict(
        params, FLAGS.params_override, is_strict=True)

    params.validate()
    params.lock()

    train_input_fn = None
    eval_input_fn = None
    train_input_shapes = None
    eval_input_shapes = None

    if FLAGS.mode in ('train', 'train_and_eval'):
        train_input_fn = input_reader.LiverInputFn(
            params.training_file_pattern,
            params,
            mode=tf.estimator.ModeKeys.TRAIN)
        train_input_shapes = train_input_fn.get_input_shapes(params)

    if FLAGS.mode in ('eval', 'train_and_eval'):
        eval_input_fn = input_reader.LiverInputFn(
            params.eval_file_pattern,
            params,
            mode=tf.estimator.ModeKeys.EVAL)
        eval_input_shapes = eval_input_fn.get_input_shapes(params)

    # An explicit exception instead of `assert`: asserts are removed when
    # Python runs with -O, which would let the executor start with no inputs.
    if train_input_shapes is None and eval_input_shapes is None:
        raise ValueError(
            'No training or evaluation input shapes are available for mode '
            '%r; expected --mode to be one of "train", "eval" or '
            '"train_and_eval".' % (FLAGS.mode,))

    run_executer(
        params,
        train_input_shapes=train_input_shapes,
        eval_input_shapes=eval_input_shapes,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn)
def main(argv):
    """Assemble the U-Net run parameters from flags and start the executor.

    Parameter layers are applied in this order, each one overriding the
    previous: the ``unet_config.UNET_CONFIG`` defaults, ``FLAGS.config_file``,
    ``FLAGS.params_overrides``, and then the flag-provided values. The result
    is validated and locked before any input function is created.

    Args:
        argv: Positional command-line arguments; unused.
    """
    del argv  # Unused.

    params = params_dict.ParamsDict(
        unet_config.UNET_CONFIG, unet_config.UNET_RESTRICTIONS)
    params = params_dict.override_params_dict(
        params, FLAGS.config_file, is_strict=False)
    params = params_dict.override_params_dict(
        params, FLAGS.params_overrides, is_strict=False)

    # Only forward the file-pattern flags when they were actually provided.
    # Overriding unconditionally would replace a pattern supplied through the
    # config file or the params overrides with an unset (falsy) flag value.
    flag_overrides = {}
    if FLAGS.training_file_pattern:
        flag_overrides['training_file_pattern'] = FLAGS.training_file_pattern
    if FLAGS.eval_file_pattern:
        flag_overrides['eval_file_pattern'] = FLAGS.eval_file_pattern
    flag_overrides.update({
        'model_dir': FLAGS.model_dir,
        'min_eval_interval': FLAGS.min_eval_interval,
        'eval_timeout': FLAGS.eval_timeout,
        'tpu_config': tpu_executor.get_tpu_flags(),
    })
    params.override(flag_overrides, is_strict=False)

    params.validate()
    params.lock()

    # Each input function is only built for the modes that need it; the
    # other one is passed to the executor as None.
    train_input_fn = None
    eval_input_fn = None

    if FLAGS.mode in ('train', 'train_and_eval'):
        train_input_fn = input_reader.LiverInputFn(
            params.training_file_pattern,
            params,
            mode=tf.estimator.ModeKeys.TRAIN)

    if FLAGS.mode in ('eval', 'train_and_eval'):
        eval_input_fn = input_reader.LiverInputFn(
            params.eval_file_pattern,
            params,
            mode=tf.estimator.ModeKeys.EVAL)

    run_executer(
        params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)