Example #1
def main(argv):
    FLAGS = argv[0]  # pylint:disable=invalid-name,redefined-outer-name
    tf.logging.set_verbosity(tf.logging.INFO)

    # RevNet specific configuration
    config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)

    # Estimator specific configuration
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.train_dir,  # Directory for storing checkpoints
        tf_random_seed=config.seed,
        save_summary_steps=config.log_every,
        save_checkpoints_steps=config.log_every,
        session_config=None,  # Using default
        keep_checkpoint_max=100,
        keep_checkpoint_every_n_hours=10000,  # Using default
        log_step_count_steps=config.log_every,
        train_distribute=None  # No distribution strategy by default
    )

    # Construct estimator
    revnet_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                              model_dir=FLAGS.train_dir,
                                              config=run_config,
                                              params={"config": config})

    # Construct input functions
    train_input_fn = get_input_fn(config=config,
                                  data_dir=FLAGS.data_dir,
                                  split="train_all")
    eval_input_fn = get_input_fn(config=config,
                                 data_dir=FLAGS.data_dir,
                                 split="test")

    # Train and evaluate estimator
    revnet_estimator.train(input_fn=train_input_fn)
    revnet_estimator.evaluate(input_fn=eval_input_fn)

    if FLAGS.export:
        input_shape = (None, ) + config.input_shape
        inputs = tf.placeholder(tf.float32, shape=input_shape)
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
            {"image": inputs})
        revnet_estimator.export_savedmodel(FLAGS.train_dir, input_fn)
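Note that main() expects the parsed flags object to arrive as argv[0]. Below is a minimal driver sketch, assuming TF 1.x tf.flags / tf.app.run and that main() above lives in the same module; the flag names mirror the FLAGS attributes the example reads (config, dataset, train_dir, data_dir, export), while the defaults and help strings are assumptions.

# Sketch only: defines the flags main() reads and forwards the module-level
# FLAGS object as argv[0], matching the `FLAGS = argv[0]` convention above.
import tensorflow as tf

flags = tf.flags
flags.DEFINE_string("config", "revnet-38", "Config name passed to main_.get_config.")
flags.DEFINE_string("dataset", "cifar-10", "Dataset name passed to main_.get_config.")
flags.DEFINE_string("train_dir", "/tmp/revnet", "Directory for checkpoints and summaries.")
flags.DEFINE_string("data_dir", "/tmp/data", "Directory containing the input data.")
flags.DEFINE_bool("export", False, "Whether to export a SavedModel after evaluation.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
    # argv[0] is passed through unparsed, so main() receives FLAGS directly.
    tf.app.run(main=main, argv=[FLAGS])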
Example #2
def main(argv):
  FLAGS = argv[0]  # pylint:disable=invalid-name,redefined-outer-name
  tf.logging.set_verbosity(tf.logging.INFO)

  # RevNet specific configuration
  config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)

  # Estimator specific configuration
  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.train_dir,  # Directory for storing checkpoints
      tf_random_seed=config.seed,
      save_summary_steps=config.log_every,
      save_checkpoints_steps=config.log_every,
      session_config=None,  # Using default
      keep_checkpoint_max=100,
      keep_checkpoint_every_n_hours=10000,  # Using default
      log_step_count_steps=config.log_every,
      train_distribute=None  # No distribution strategy by default
  )

  # Construct estimator
  revnet_estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=FLAGS.train_dir,
      config=run_config,
      params={"config": config})

  # Construct input functions
  train_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="train_all")
  eval_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="test")

  # Train and evaluate estimator
  revnet_estimator.train(input_fn=train_input_fn)
  revnet_estimator.evaluate(input_fn=eval_input_fn)

  if FLAGS.export:
    input_shape = (None,) + config.input_shape
    inputs = tf.placeholder(tf.float32, shape=input_shape)
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        "image": inputs
    })
    revnet_estimator.export_savedmodel(FLAGS.train_dir, input_fn)
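Example #2 ends by exporting a SavedModel whose serving input receiver is keyed "image". A minimal sketch of querying that export with TF 1.x tf.contrib.predictor follows; the export directory (a timestamped subfolder of FLAGS.train_dir), the input shape, and the structure of the outputs are assumptions, since they depend on the project's model_fn and config.

# Sketch only: loading and calling the SavedModel written by export_savedmodel above.
import numpy as np
import tensorflow as tf

export_dir = "/tmp/revnet/1234567890"  # hypothetical timestamped export path
predict_fn = tf.contrib.predictor.from_saved_model(export_dir)

# The feed key "image" matches the raw serving input receiver built above;
# the batch shape below is a CIFAR-style assumption, not taken from the source.
batch = np.zeros((1, 32, 32, 3), dtype=np.float32)
outputs = predict_fn({"image": batch})
print(outputs)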
Example #3
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  # RevNet specific configuration
  config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)

  if FLAGS.use_tpu:
    tf.logging.info("Using TPU.")
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  else:
    tpu_cluster_resolver = None

  # TPU specific configuration
  tpu_config = tf.contrib.tpu.TPUConfig(
      # Recommended to be set to the number of global steps between checkpoints
      iterations_per_loop=FLAGS.iterations_per_loop,
      num_shards=FLAGS.num_shards)

  # Estimator specific configuration
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=tpu_config,
  )

  # Construct TPU Estimator
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=config.tpu_batch_size,
      eval_batch_size=config.tpu_eval_batch_size,
      config=run_config,
      params={"config": config})

  # Construct input functions
  train_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="train_all")
  eval_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="test")

  # Disabling a range within an else block currently doesn't work
  # due to https://github.com/PyCQA/pylint/issues/872
  # pylint: disable=protected-access
  if FLAGS.mode == "eval":
    # TPUEstimator.evaluate *requires* a steps argument.
    # Note that the number of examples used during evaluation is
    # --eval_steps * --batch_size.
    # So if you change --batch_size then change --eval_steps too.
    eval_steps = 10000 // config.tpu_eval_batch_size

    # Run evaluation when there's a new checkpoint
    for ckpt in evaluation.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info("Starting to evaluate.")
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = estimator.evaluate(
            input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info("Eval results: %s. Elapsed seconds: %d" %
                        (eval_results, elapsed_time))

        # Terminate eval job when final checkpoint is reached
        current_step = int(os.path.basename(ckpt).split("-")[1])
        if current_step >= config.max_train_iter:
          tf.logging.info(
              "Evaluation finished after training step %d" % current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info(
            "Checkpoint %s no longer exists, skipping checkpoint" % ckpt)

  else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    current_step = estimator_._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)
    tf.logging.info("Training for %d steps . Current"
                    " step %d." % (config.max_train_iter, current_step))

    start_timestamp = time.time()  # This time will include compilation time
    if FLAGS.mode == "train":
      estimator.train(input_fn=train_input_fn, max_steps=config.max_train_iter)
    else:
      eval_steps = 10000 // config.tpu_eval_batch_size
      assert FLAGS.mode == "train_and_eval"
      while current_step < config.max_train_iter:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              config.max_train_iter)
        estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be consistently excluded modulo the batch size.
        tf.logging.info("Starting to evaluate.")
        eval_results = estimator.evaluate(
            input_fn=eval_input_fn, steps=eval_steps)
        tf.logging.info("Eval results: %s" % eval_results)

    elapsed_time = int(time.time() - start_timestamp)
    tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
                    (config.max_train_iter, elapsed_time))
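Examples #3 and #4 hand get_input_fn's result to TPUEstimator, which calls the returned input_fn with a params dict containing "batch_size" (the per-shard batch size), so the batch size must come from params rather than from config. A minimal sketch of a get_input_fn that satisfies this contract is below; the file layout, the parse_example decoder, and the shuffle buffer are assumptions, not the project's actual dataset code.

# Sketch only: the general shape of a TPU-compatible get_input_fn.
import os
import tensorflow as tf

def get_input_fn(config, data_dir, split):
  def input_fn(params):
    batch_size = params["batch_size"]  # supplied by TPUEstimator per shard
    pattern = os.path.join(data_dir, "{}*.tfrecord".format(split))  # hypothetical layout
    dataset = tf.data.TFRecordDataset(tf.data.Dataset.list_files(pattern))
    if split.startswith("train"):
      dataset = dataset.repeat().shuffle(buffer_size=10000)
    dataset = dataset.map(parse_example)  # parse_example: project-specific decoder, assumed
    # drop_remainder=True keeps batch shapes static, which TPUs require.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
  return input_fn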
Example #4
def main(argv):
  FLAGS = argv[0]  # pylint:disable=invalid-name,redefined-outer-name
  tf.logging.set_verbosity(tf.logging.INFO)

  # RevNet specific configuration
  config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)

  if FLAGS.use_tpu:
    tf.logging.info("Using TPU.")
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  else:
    tpu_cluster_resolver = None

  # TPU specific configuration
  tpu_config = tf.contrib.tpu.TPUConfig(
      # Recommended to be set to the number of global steps between checkpoints
      iterations_per_loop=FLAGS.iterations_per_loop,
      num_shards=FLAGS.num_shards)

  # Estimator specific configuration
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=tpu_config,
  )

  # Construct TPU Estimator
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=config.tpu_batch_size,
      eval_batch_size=config.tpu_eval_batch_size,
      config=run_config,
      params={
          "FLAGS": FLAGS,
          "config": config,
      })

  # Construct input functions
  train_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="train_all")
  eval_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="test")

  # Disabling a range within an else block currently doesn't work
  # due to https://github.com/PyCQA/pylint/issues/872
  # pylint: disable=protected-access
  if FLAGS.mode == "eval":
    # TPUEstimator.evaluate *requires* a steps argument.
    # Note that the number of examples used during evaluation is
    # --eval_steps * --batch_size.
    # So if you change --batch_size then change --eval_steps too.
    eval_steps = 10000 // config.tpu_eval_batch_size

    # Run evaluation when there's a new checkpoint
    for ckpt in evaluation.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info("Starting to evaluate.")
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = estimator.evaluate(
            input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info("Eval results: %s. Elapsed seconds: %d" %
                        (eval_results, elapsed_time))

        # Terminate eval job when final checkpoint is reached
        current_step = int(os.path.basename(ckpt).split("-")[1])
        if current_step >= config.max_train_iter:
          tf.logging.info(
              "Evaluation finished after training step %d" % current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info(
            "Checkpoint %s no longer exists, skipping checkpoint" % ckpt)

  else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    current_step = estimator_._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)
    tf.logging.info("Training for %d steps . Current"
                    " step %d." % (config.max_train_iter, current_step))

    start_timestamp = time.time()  # This time will include compilation time
    if FLAGS.mode == "train":
      estimator.train(input_fn=train_input_fn, max_steps=config.max_train_iter)
    else:
      eval_steps = 10000 // config.tpu_eval_batch_size
      assert FLAGS.mode == "train_and_eval"
      while current_step < config.max_train_iter:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              config.max_train_iter)
        estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be consistently excluded modulo the batch size.
        tf.logging.info("Starting to evaluate.")
        eval_results = estimator.evaluate(
            input_fn=eval_input_fn, steps=eval_steps)
        tf.logging.info("Eval results: %s" % eval_results)

    elapsed_time = int(time.time() - start_timestamp)
    tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
                    (config.max_train_iter, elapsed_time))
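The TPU examples read a number of command-line flags (tpu, tpu_zone, gcp_project, model_dir, data_dir, config, dataset, mode, iterations_per_loop, num_shards, steps_per_eval, eval_timeout, use_tpu). A hedged sketch of how these might be declared with TF 1.x tf.flags follows; the flag names come from the FLAGS attributes used above, while the defaults and help strings are assumptions.

# Sketch only: flag declarations matching the FLAGS attributes used in
# Examples #3 and #4; defaults and descriptions are assumptions.
import tensorflow as tf

flags = tf.flags
flags.DEFINE_string("tpu", None, "Name or gRPC address of the TPU to use.")
flags.DEFINE_string("tpu_zone", None, "GCE zone of the TPU (optional).")
flags.DEFINE_string("gcp_project", None, "GCP project owning the TPU (optional).")
flags.DEFINE_string("model_dir", None, "Directory for checkpoints and summaries.")
flags.DEFINE_string("data_dir", None, "Directory containing the input data.")
flags.DEFINE_string("config", "revnet-38", "Config name for main_.get_config.")
flags.DEFINE_string("dataset", "cifar-10", "Dataset name for main_.get_config.")
flags.DEFINE_string("mode", "train_and_eval", "One of 'train', 'eval', 'train_and_eval'.")
flags.DEFINE_integer("iterations_per_loop", 100, "Steps per TPU training loop.")
flags.DEFINE_integer("num_shards", 8, "Number of TPU shards (cores).")
flags.DEFINE_integer("steps_per_eval", 1000, "Training steps between evaluations.")
flags.DEFINE_integer("eval_timeout", None, "Seconds to wait for a new checkpoint.")
flags.DEFINE_bool("use_tpu", True, "Whether to run on TPU or fall back to CPU/GPU.")
FLAGS = flags.FLAGS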