Example #1
0
def main(argv):
    """Assemble the UNet params and hand off to `run_executer`.

    Params are layered in this order (later layers win):
      1. `unet_config.UNET_CONFIG` defaults (checked against
         `unet_config.UNET_RESTRICTIONS` by `params.validate()`),
      2. `FLAGS.config_file`,
      3. `FLAGS.training_file_pattern` / `FLAGS.eval_file_pattern`, only when set,
      4. values taken from flags and derived from the item counts and batch
         sizes (`lr_decay_steps`, `train_steps`, `eval_steps`),
      5. `FLAGS.params_override`.
    The params are then validated and locked, and input functions are built
    for the mode(s) selected by `FLAGS.mode`.

    Args:
      argv: Command-line arguments; unused.

    Raises:
      ValueError: if no train or eval input was configured, e.g. because
        `FLAGS.mode` is not one of 'train', 'eval' or 'train_and_eval'.
    """
    del argv  # Unused.

    params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                    unet_config.UNET_RESTRICTIONS)
    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=False)

    # Only override the file patterns when the flags were supplied, so values
    # coming from the config file are not replaced by an empty flag.
    if FLAGS.training_file_pattern:
        params.override({'training_file_pattern': FLAGS.training_file_pattern},
                        is_strict=True)

    if FLAGS.eval_file_pattern:
        params.override({'eval_file_pattern': FLAGS.eval_file_pattern},
                        is_strict=True)

    # Number of steps in one pass over each dataset (integer division, so a
    # trailing partial batch is not counted).
    train_epoch_steps = params.train_item_count // params.train_batch_size
    eval_epoch_steps = params.eval_item_count // params.eval_batch_size

    params.override(
        {
            'model_dir': FLAGS.model_dir,
            'min_eval_interval': FLAGS.min_eval_interval,
            'eval_timeout': FLAGS.eval_timeout,
            'tpu_config': tpu_executor.get_tpu_flags(),
            'lr_decay_steps': train_epoch_steps,
            'train_steps': params.train_epochs * train_epoch_steps,
            'eval_steps': eval_epoch_steps,
        },
        is_strict=False)

    # Applied last so explicit overrides take precedence over the derived
    # step counts above.
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)

    params.validate()
    params.lock()

    train_input_fn = None
    eval_input_fn = None
    train_input_shapes = None
    eval_input_shapes = None
    if FLAGS.mode in ('train', 'train_and_eval'):
        train_input_fn = input_reader.LiverInputFn(
            params.training_file_pattern,
            params,
            mode=tf.estimator.ModeKeys.TRAIN)
        train_input_shapes = train_input_fn.get_input_shapes(params)
    if FLAGS.mode in ('eval', 'train_and_eval'):
        eval_input_fn = input_reader.LiverInputFn(
            params.eval_file_pattern, params, mode=tf.estimator.ModeKeys.EVAL)
        eval_input_shapes = eval_input_fn.get_input_shapes(params)

    # An explicit raise (not `assert`) so the check still runs under `python -O`.
    if train_input_shapes is None and eval_input_shapes is None:
        raise ValueError(
            'No train or eval input was configured for --mode={!r}; expected '
            "one of 'train', 'eval' or 'train_and_eval'.".format(FLAGS.mode))

    run_executer(params,
                 train_input_shapes=train_input_shapes,
                 eval_input_shapes=eval_input_shapes,
                 train_input_fn=train_input_fn,
                 eval_input_fn=eval_input_fn)
Example #2
0
def main(argv):
  """Assemble the UNet params and hand off to `run_executer`.

  Params are layered in this order (later layers win):
    1. `unet_config.UNET_CONFIG` defaults (checked against
       `unet_config.UNET_RESTRICTIONS` by `params.validate()`),
    2. `FLAGS.config_file`,
    3. `FLAGS.params_overrides`,
    4. run settings taken from flags (`model_dir`, `min_eval_interval`,
       `eval_timeout`, `tpu_config`), plus the train/eval file patterns when
       their flags are set.
  The params are then validated and locked, and an input function is built
  for each mode selected by `FLAGS.mode`.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # Unused.

  params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                  unet_config.UNET_RESTRICTIONS)
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=False)
  params = params_dict.override_params_dict(
      params, FLAGS.params_overrides, is_strict=False)

  flag_overrides = {
      'model_dir': FLAGS.model_dir,
      'min_eval_interval': FLAGS.min_eval_interval,
      'eval_timeout': FLAGS.eval_timeout,
      'tpu_config': tpu_executor.get_tpu_flags(),
  }
  # Override the file patterns only when their flags were supplied.
  # `ParamsDict.override` stores whatever value it is given, so an unset flag
  # would otherwise wipe a pattern set via the config file or
  # --params_overrides.
  if FLAGS.training_file_pattern:
    flag_overrides['training_file_pattern'] = FLAGS.training_file_pattern
  if FLAGS.eval_file_pattern:
    flag_overrides['eval_file_pattern'] = FLAGS.eval_file_pattern
  params.override(flag_overrides, is_strict=False)

  params.validate()
  params.lock()

  train_input_fn = None
  eval_input_fn = None
  if FLAGS.mode in ('train', 'train_and_eval'):
    train_input_fn = input_reader.LiverInputFn(
        params.training_file_pattern, params, mode=tf.estimator.ModeKeys.TRAIN)
  if FLAGS.mode in ('eval', 'train_and_eval'):
    eval_input_fn = input_reader.LiverInputFn(
        params.eval_file_pattern, params, mode=tf.estimator.ModeKeys.EVAL)

  run_executer(
      params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)