def test_fail_with_invalid_computation_shape(self):
        with self.assertRaisesRegexp(
                ValueError, 'computation_shape must be a list with length'
                ' 3 or None'):
            tpu_config_lib.TPUConfig(computation_shape=[2, 1])

        with self.assertRaisesRegexp(ValueError,
                                     'computation_shape elements can only be'):
            tpu_config_lib.TPUConfig(computation_shape=[1, 3, 1])
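
For contrast, a shape that should pass this validation (a sketch, assuming the legacy contrib API in which computation_shape must be None or a 3-element list of small per-dimension sizes, as the error messages above suggest):

# A minimal sketch of a valid computation_shape under the assumptions above.
valid_config = tpu_config_lib.TPUConfig(computation_shape=[1, 1, 2])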
Example #2
def create_estimator(master,
                     model_dir,
                     use_tpu,
                     iterations_per_loop,
                     num_shards,
                     model_params,
                     include_features_in_predictions=True,
                     decode_keys=(),
                     train_init_checkpoint=None,
                     train_warmup_steps=10000,
                     save_checkpoints_steps=1000,
                     keep_checkpoint_max=5):
    """Returns an tensorflow estimator."""

    run_config = tpu_config.RunConfig(
        master=master,
        model_dir=model_dir,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False),
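        # TPUConfig's first positional parameter is iterations_per_loop.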
        tpu_config=tpu_config.TPUConfig(iterations_per_loop),
        save_checkpoints_steps=save_checkpoints_steps,
        keep_checkpoint_max=keep_checkpoint_max)

    return tpu_estimator.TPUEstimator(
        model_fn=_estimator_model_fn(use_tpu, model_params, model_dir,
                                     include_features_in_predictions,
                                     decode_keys, train_init_checkpoint,
                                     train_warmup_steps),
        use_tpu=use_tpu,
        train_batch_size=model_params.batch_size * num_shards,
        eval_batch_size=model_params.batch_size * num_shards,
        predict_batch_size=model_params.batch_size * num_shards,
        config=run_config)
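
A hypothetical call for illustration; the endpoint and paths are placeholders, and model_params stands in for any object exposing a batch_size attribute:

est = create_estimator(
    master='grpc://10.0.0.1:8470',     # placeholder TPU gRPC endpoint
    model_dir='gs://my-bucket/model',  # placeholder output directory
    use_tpu=True,
    iterations_per_loop=1000,
    num_shards=8,
    model_params=model_params)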
Example #3
def main(unused_argv):
    del unused_argv

    start = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)

    run_config = tpu_config.RunConfig(
        master=FLAGS.master,
        evaluation_master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        save_checkpoints_secs=FLAGS.save_checkpoints_secs,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True),
        #tpu_config=tpu_config.TPUConfig(5, FLAGS.num_shards, per_host_input_for_training = True),
        tpu_config=tpu_config.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
    )
    estimator = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                           use_tpu=FLAGS.use_tpu,
                                           train_batch_size=128,
                                           eval_batch_size=128,
                                           config=run_config)
    estimator.train(input_fn=get_input_fn(FLAGS.train_file),
                    max_steps=FLAGS.train_steps)
    estimator.evaluate(input_fn=get_input_fn(FLAGS.eval_file), steps=100)

    total = time.time() - start
    print("Total time: " + str(total))
Example #4
def main(unused_argv):
  config = tpu_config.RunConfig(
      master=FLAGS.master,
      evaluation_master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards))
  resnet_classifier = tpu_estimator.TPUEstimator(
      model_fn=resnet_model_fn,
      config=config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)

  if FLAGS.enable_eval:
    for cycle in range(FLAGS.train_steps // FLAGS.steps_per_eval):
      tf.logging.info('Starting a training cycle.')
      resnet_classifier.train(
          input_fn=ImageNetInput(True), steps=FLAGS.steps_per_eval)

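      # 50000 is the size of the ImageNet validation split.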
      _EVAL_STEPS = 50000 // FLAGS.eval_batch_size
      tf.logging.info('Starting to evaluate.')
      eval_results = resnet_classifier.evaluate(
          input_fn=ImageNetInput(False), steps=_EVAL_STEPS)
      tf.logging.info('Eval results: %s' % eval_results)

  else:
    tf.logging.info('Starting training.')
    resnet_classifier.train(
        input_fn=ImageNetInput(True), steps=FLAGS.train_steps)
Example #5
File: amoeba_net.py Project: vinhngx/tpu
def build_run_config():
    """Return RunConfig for TPU estimator."""
    master = FLAGS.master if FLAGS.use_tpu else ''
    if master is None and FLAGS.tpu_name is None:
        raise RuntimeError('You must specify either --master or --tpu_name.')

    if master is not None:
        if FLAGS.tpu_name is not None:
            tf.logging.warn('Both --master and --tpu_name are set. Ignoring '
                            '--tpu_name and using --master.')
        tpu_grpc_url = master
    else:
        tpu_cluster_resolver = (tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
        tpu_grpc_url = tpu_cluster_resolver.get_master()
    eval_steps = model_lib.NUM_EVAL_IMAGES // FLAGS.eval_batch_size
    iterations_per_loop = (eval_steps if FLAGS.mode == 'eval' else
                           FLAGS.iterations_per_loop)
    save_checkpoints_steps = FLAGS.save_checkpoints_steps or iterations_per_loop
    run_config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        keep_checkpoint_max=None,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=iterations_per_loop,
            num_shards=FLAGS.num_shards,
            per_host_input_for_training=tpu_config.InputPipelineConfig.
            PER_HOST_V2))
    return run_config
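
With PER_HOST_V2, TPUEstimator calls the input function once per host and passes the per-host batch size through the params dict (the same contract visible in Example #26 below). A minimal sketch of an input_fn written against that convention, using placeholder data:

def input_fn(params):
    # TPUEstimator injects the per-host batch size as params['batch_size'].
    batch_size = params['batch_size']
    features = tf.random_uniform([1024, 8])    # placeholder features
    labels = tf.zeros([1024], dtype=tf.int32)  # placeholder labels
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    # TPUs require static shapes, so drop the final partial batch.
    return dataset.repeat().batch(batch_size, drop_remainder=True)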
Example #6
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.train_file:
        tf.logging.fatal(
            "Flag --train_file must be set for training. Aborting.")

    if FLAGS.eval_steps and not FLAGS.eval_file:
        tf.logging.fatal(
            "Flag --eval_file must be set for evaluation. Aborting.")

    run_config = tpu_config.RunConfig(
        master=FLAGS.master,
        evaluation_master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True),
        tpu_config=tpu_config.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
    )

    estimator = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                           use_tpu=FLAGS.use_tpu,
                                           train_batch_size=FLAGS.batch_size,
                                           eval_batch_size=FLAGS.batch_size,
                                           config=run_config)

    estimator.train(input_fn=get_input_fn(FLAGS.train_file),
                    max_steps=FLAGS.train_steps)

    if FLAGS.eval_steps:
        estimator.evaluate(input_fn=get_input_fn(FLAGS.eval_file),
                           steps=FLAGS.eval_steps)
Example #7
def _get_tpu_estimator():
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=None, project=None)
    tpu_grpc_url = tpu_cluster_resolver.get_master()

    run_config = contrib_tpu_python_tpu_tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.work_dir,
        save_checkpoints_steps=max(1000, FLAGS.iterations_per_loop),
        save_summary_steps=FLAGS.summary_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True),
        tpu_config=contrib_tpu_python_tpu_tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=contrib_tpu_python_tpu_tpu_config.
            InputPipelineConfig.PER_HOST_V2))

    return contrib_tpu_python_tpu_tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size * FLAGS.num_tpu_cores,
        eval_batch_size=FLAGS.train_batch_size * FLAGS.num_tpu_cores,
        params=FLAGS.flag_values_dict())
Example #8
def construct_estimator(model_fn, hparams, tpu=None):
    if hparams.use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            tpu=tpu.name)
        master = tpu_cluster_resolver.get_master()
        config = tpu_config.RunConfig(
            master=master,
            evaluation_master=master,
            model_dir=hparams.output_dir,
            session_config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=True),
            tpu_config=tpu_config.TPUConfig(
                iterations_per_loop=FLAGS.tpu_iterations_per_loop,
                num_shards=FLAGS.tpu_shards),
            save_checkpoints_steps=FLAGS.eval_every)
        estimator = tpu_estimator.TPUEstimator(
            use_tpu=hparams.use_tpu,
            model_fn=model_fn,
            model_dir=hparams.output_dir,
            config=config,
            train_batch_size=hparams.batch_size,
            eval_batch_size=hparams.batch_size)
    else:
        gpu_config = tf.ConfigProto(allow_soft_placement=True)
        gpu_config.gpu_options.allow_growth = True
        run_config = tf.estimator.RunConfig(
            save_checkpoints_steps=FLAGS.eval_every, session_config=gpu_config)

        estimator = tf.estimator.Estimator(
            model_fn=tf.contrib.estimator.replicate_model_fn(model_fn),
            model_dir=hparams.output_dir,
            config=run_config)

    return estimator
Example #9
def run_training(hparams):
    """For benchmarking convenience, run only the training job."""
    model_module = {
        MATRIX_FACTORIZATION: matrix_factorization_model,
        DNN_SOFTMAX: dnn_softmax_model
    }[hparams.model_type]

    features_padding_fn, model_fn, target_features_fn = (
        model_module.get_pad_and_model_fns(hparams))

    estimator = tpu_estimator.TPUEstimator(
        model_dir=hparams.output_path,
        model_fn=model_fn,
        train_batch_size=hparams.batch_size,
        use_tpu=hparams.use_tpu,
        config=tpu_config.RunConfig(master=hparams.master,
                                    tpu_config=tpu_config.TPUConfig(
                                        hparams.tpu_loop_steps,
                                        num_shards=hparams.tpu_cores)))

    train_data_paths = os.path.join(hparams.train_data_dir, 'features_train-*')
    train_input_fn = make_input_fn(hparams=hparams,
                                   mode=tf.contrib.learn.ModeKeys.TRAIN,
                                   data_file_pattern=train_data_paths,
                                   features_padding_fn=features_padding_fn,
                                   target_features_fn=target_features_fn,
                                   randomize_input=hparams.randomize_input,
                                   queue_capacity=4 * hparams.batch_size)

    estimator.train(input_fn=train_input_fn, steps=hparams.train_steps)
Example #10
def main(argv):
    del argv  # Unused.

    if FLAGS.use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        tpu_grpc_url = tpu_cluster_resolver.get_master()

    else:
        tpu_grpc_url = None

    run_config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop),
    )

    estimator = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                           use_tpu=FLAGS.use_tpu,
                                           config=run_config,
                                           train_batch_size=FLAGS.batch_size)
    estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)
Example #11
File: cifar_keras.py Project: tgrel/tpu
def main(argv):
    del argv  # Unused.

    if FLAGS.master is None and FLAGS.tpu_name is None:
        raise RuntimeError("You must specify either --master or --tpu_name.")

    if FLAGS.master is not None:
        if FLAGS.tpu_name is not None:
            tf.logging.warn("Both --master and --tpu_name are set. Ignoring "
                            "--tpu_name and using --master.")
        tpu_grpc_url = FLAGS.master
    else:
        tpu_cluster_resolver = (tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
        tpu_grpc_url = tpu_cluster_resolver.get_master()

    run_config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_checkpoints_secs=3600,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_shards),
    )

    estimator = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                           use_tpu=FLAGS.use_tpu,
                                           config=run_config,
                                           train_batch_size=FLAGS.batch_size)
    estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)
Example #12
def main(unused_argv):
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if FLAGS.master is None and FLAGS.tpu_name is None:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master is not None:
            if FLAGS.tpu_name is not None:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores))

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  data_dir=FLAGS.data_dir)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 data_dir=FLAGS.data_dir)

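    # Reads the current global step from the latest checkpoint in model_dir.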
    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                   batches_per_epoch, current_step))
    #start_timestamp = time.time()
    #while current_step < FLAGS.train_steps:
    # Train for up to steps_per_eval number of steps. At the end of training, a
    # checkpoint will be written to --model_dir.
    #  next_checkpoint = min(current_step + FLAGS.steps_per_eval,
    #                        FLAGS.train_steps)
    resnet_classifier.train(input_fn=imagenet_train.input_fn,
                            max_steps=FLAGS.train_steps)
Example #13
def main(_):
    # Resolve the TPU endpoint and assemble the run config.
    tpu_grpc_url = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=["demo-tpu"]).get_master()
    model_dir = os.path.join(FLAGS.out_dir, str(int(time.time()))) + "/"
    run_config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        model_dir=model_dir,
        save_checkpoints_secs=3600,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=100, num_shards=FLAGS.num_replica
        )  # when you skip num_shards, 8 is used.
    )
    cifar10_resnet_classifier = tpu_estimator.TPUEstimator(
        model_fn=_my_model_fn,
        use_tpu=True,
        config=run_config,
        train_batch_size=batch_size)

    # run !
    cifar10_resnet_classifier.train(
        input_fn=_my_input_fn,
        #max_steps=50000 * 10 // batch_size) # Full spec
        max_steps=5000)  # For benchmarking
Example #14
def run_toy_model_tpu():
  """Run a toy model on TPU."""
  iterations_per_loop = FLAGS.iterations
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  config = tpu_config.RunConfig(
      master=FLAGS.master,
      evaluation_master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=None,  # Disable the default saver
      save_checkpoints_secs=None,  # Disable the default saver
      log_step_count_steps=iterations_per_loop,
      tpu_config=tpu_config.TPUConfig(
          num_shards=mesh_shape.size,
          iterations_per_loop=iterations_per_loop,
          num_cores_per_replica=1,
          per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
  classifier = tpu_estimator.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size)
  current_step = estimator_lib._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
  logging.info('Current step %d', current_step)
  while current_step < FLAGS.train_steps:
    next_checkpoint = min(current_step + FLAGS.steps_per_checkpoint,
                          FLAGS.train_steps)
    classifier.train(input_fn=ToyModelInput(), max_steps=next_checkpoint)
    current_step = next_checkpoint

    tf.logging.info('Starting to evaluate.')
    eval_results = classifier.evaluate(
        input_fn=ToyModelInput(),
        steps=156)  # since we have 10000 examples and batch_size = 64 per host
    logging.info('Eval results: %s', eval_results)
Example #15
def main(argv):
    del argv
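    # 40000/10000 appears to be a train/eval split of the CIFAR training set.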
    training_examples = (FLAGS.train_epochs * 40000)
    eval_examples = 10000
    iterations_per_loop = ((training_examples // 10) // FLAGS.train_batch_size)

    if FLAGS.master is None and FLAGS.tpu_name is None:
        raise RuntimeError("You must specify either --master or --tpu_name.")

    if FLAGS.master is not None:
        if FLAGS.tpu_name is not None:
            tf.logging.warn("Both --master and --tpu_name are set. Ignoring "
                            "--tpu_name and using --master.")
        tpu_grpc_url = FLAGS.master
    else:
        tpu_cluster_resolver = (
            tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
                tpu_names=[FLAGS.tpu_name],
                zone=FLAGS.tpu_zone,
                project=FLAGS.gcp_project))
        tpu_grpc_url = tpu_cluster_resolver.get_master()

    run_config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.steps_per_checkpoint,
        log_step_count_steps=iterations_per_loop,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=iterations_per_loop,
            num_shards=FLAGS.num_shards,
        ),
    )

    estimator = tpu_estimator.TPUEstimator(
        model_fn=model_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        params=dict(CIFAR_SMALL_PARAMS, use_tpu=FLAGS.use_tpu),
    )

    # Evaluate the test set after 5% of training examples are finished.
    for cycle in range(10):
        tf.logging.info("Starting %d train steps" %
                        (training_examples // 10 // FLAGS.train_batch_size))
        estimator.train(input_fn=InputReader(FLAGS.train_file,
                                             is_training=True),
                        steps=training_examples // 10 //
                        FLAGS.train_batch_size)

        tf.logging.info("Starting evaluation cycle %d ." % cycle)
        print(
            estimator.evaluate(
                input_fn=InputReader(FLAGS.train_file, is_training=False),
                steps=eval_examples // FLAGS.eval_batch_size,
            ))
Example #16
def main(argv):
  del argv

  # Hyperparameters derived from the paper
  hparams = mobilenet_hparams()
  hparams.parse(FLAGS.hparams)

  params = dict(
      hparams.values(),
      num_eval_examples=FLAGS.num_eval_examples,
      num_examples_per_epoch=FLAGS.num_examples_per_epoch,
      num_shards=FLAGS.num_shards,
      num_batches_per_epoch=FLAGS.num_examples_per_epoch / FLAGS.batch_size,
  )

  tf.gfile.MakeDirs(FLAGS.model_dir)
  with tf.gfile.GFile(FLAGS.model_dir + "/hparams.json", "w") as f:
    f.write(hparams.to_json())

  num_training_examples = FLAGS.num_examples_per_epoch * params["num_epochs"]
  num_eval_batches = FLAGS.num_eval_examples // FLAGS.batch_size
  num_training_batches = num_training_examples // FLAGS.batch_size

  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=100,
          num_shards=FLAGS.num_shards,
      ),
  )

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      params=dict(params, use_tpu=FLAGS.use_tpu),
  )

  # Evaluate the test set after each epoch of the training set is processed.
  for _ in range(FLAGS.num_epochs):
    tf.logging.info("Training one epoch: %s steps",
                    num_training_batches // FLAGS.num_epochs)
    estimator.train(
        input_fn=data_pipeline.InputReader(FLAGS.data_dir, is_training=True),
        steps=num_training_batches // FLAGS.num_epochs)

    tf.logging.info("Running evaluation")
    tf.logging.info("%s",
                    estimator.evaluate(
                        input_fn=data_pipeline.InputReader(
                            FLAGS.data_dir, is_training=False),
                        steps=num_eval_batches,
                    ))
Example #17
def main(_):
    config = params_dict.ParamsDict(mask_rcnn_config.MASK_RCNN_CFG,
                                    mask_rcnn_config.MASK_RCNN_RESTRICTIONS)
    config = params_dict.override_params_dict(config,
                                              FLAGS.config,
                                              is_strict=True)
    config.is_training_bn = False
    config.train_batch_size = FLAGS.batch_size
    config.eval_batch_size = FLAGS.batch_size

    config.validate()
    config.lock()

    model_params = dict(config.as_dict().items(),
                        use_tpu=FLAGS.use_tpu,
                        mode=tf.estimator.ModeKeys.PREDICT,
                        transpose_input=False)

    print(' - Setting up TPUEstimator...')
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=serving.serving_model_fn_builder(
            FLAGS.output_source_id, FLAGS.output_image_info,
            FLAGS.output_box_features, FLAGS.output_normalized_coordinates,
            FLAGS.cast_num_detections_to_float),
        model_dir=FLAGS.model_dir,
        config=tpu_config.RunConfig(tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop),
                                    master='local',
                                    evaluation_master='local'),
        params=model_params,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.batch_size,
        predict_batch_size=FLAGS.batch_size,
        export_to_tpu=FLAGS.use_tpu,
        export_to_cpu=True)

    print(' - Exporting the model...')
    input_type = FLAGS.input_type
    export_path = estimator.export_saved_model(
        export_dir_base=FLAGS.export_dir,
        serving_input_receiver_fn=functools.partial(
            serving.serving_input_fn,
            batch_size=FLAGS.batch_size,
            desired_image_size=config.image_size,
            padding_stride=(2**config.max_level),
            input_type=input_type,
            input_name=FLAGS.input_name),
        checkpoint_path=FLAGS.checkpoint_path)

    if FLAGS.add_warmup_requests and input_type == 'image_bytes':
        inference_warmup.write_warmup_requests(
            export_path,
            FLAGS.model_name,
            config.image_size,
            batch_sizes=[FLAGS.batch_size],
            image_format='JPEG',
            input_signature=FLAGS.input_name)
    print(' - Done! path: %s' % export_path)
Example #18
def main(argv):
    del argv

    if FLAGS.master is None and FLAGS.tpu_name is None:
        raise RuntimeError("You must specify either --master or --tpu_name.")

    if FLAGS.master is not None:
        if FLAGS.tpu_name is not None:
            tf.logging.warn("Both --master and --tpu_name are set. Ignoring "
                            "--tpu_name and using --master.")
        tpu_grpc_url = FLAGS.master
    else:
        tpu_cluster_resolver = (tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
        tpu_grpc_url = tpu_cluster_resolver.get_master()

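    # ImageNet: roughly 1.3 million training images and 50000 validation images.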
    training_examples = 1300 * 1000 * FLAGS.num_epochs
    eval_examples = 50 * 1000

    params = {
        "num_classes": 1001,
        "lr": FLAGS.learning_rate,
        "min_lr": 0.005,
        "momentum": FLAGS.momentum,
        "optimizer": FLAGS.optimizer,
        "num_eval_examples": eval_examples,
        "num_shards": FLAGS.num_shards,
        "num_epochs": FLAGS.num_epochs,
    }

    run_config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_checkpoints_secs=FLAGS.save_checkpoints_secs,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations,
            num_shards=FLAGS.num_shards,
        ),
    )

    estimator = tpu_estimator.TPUEstimator(
        model_fn=squeezenet_model.model_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.batch_size,
        params=dict(params, use_tpu=FLAGS.use_tpu),
    )

    #num_evals = max(FLAGS.num_evals, 1)
    #examples_per_eval = training_examples // num_evals
    #for _ in range(num_evals):
    estimator.train(
        input_fn=data_pipeline.InputReader(FLAGS.data_dir, is_training=True),
        #steps=examples_per_eval // FLAGS.batch_size)
        steps=FLAGS.train_steps)
Example #19
def main(unused_argv):
    del unused_argv

    start = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)
    print('Tensorflow version: ' + str(tf.__version__))
    for k, v in iter(tf.app.flags.FLAGS.flag_values_dict().items()):
        print("***%s: %s" % (k, v))

    if FLAGS.use_tpu:
        if FLAGS.tpu_name is None:
            raise RuntimeError("You must specify --tpu_name.")

        else:
            if '1.6.0' in tf.__version__:
                tpu_cluster_resolver = (
                    tf.contrib.cluster_resolver.TPUClusterResolver(
                        tpu_names=[FLAGS.tpu_name],
                        zone=FLAGS.tpu_zone,
                        project=FLAGS.gcp_project))
            else:
                tpu_cluster_resolver = (
                    tf.contrib.cluster_resolver.TPUClusterResolver(
                        FLAGS.tpu_name,
                        zone=FLAGS.tpu_zone,
                        project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        tpu_grpc_url = ''

    run_config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_checkpoints_secs=None,
        session_config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            gpu_options=tf.GPUOptions(allow_growth=True)),
        tpu_config=tpu_config.TPUConfig(iterations_per_loop=FLAGS.iterations,
                                        num_shards=FLAGS.num_shards),
    )
    estimator = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                           params={
                                               "output_size": output_size,
                                               "input_size": input_size
                                           },
                                           use_tpu=FLAGS.use_tpu,
                                           train_batch_size=batch_size,
                                           config=run_config)
    estimator.train(input_fn=get_input_fn(input_size, output_size),
                    max_steps=FLAGS.train_steps)

    total_time = time.time() - start
    example_per_sec = batch_size * FLAGS.train_steps / total_time
    global_step_per_sec = FLAGS.train_steps / total_time
    print("Total time: " + str(total_time))
Example #20
def main(unused_argv):

    start = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.use_tpu:
        tf.logging.info("Using TPUs.")
    else:
        tf.logging.info("NOT using TPUs.")

    if FLAGS.use_tpu:
        tf.logging.info('tpu name: %s', FLAGS.tpu_name)
        if FLAGS.tpu_name is None:
            raise RuntimeError("You must specify --tpu_name.")

        else:
            if '1.6.0' in tf.__version__:
                tpu_cluster_resolver = (
                    tf.contrib.cluster_resolver.TPUClusterResolver(
                        tpu_names=[os.uname()[1]],
                        zone=FLAGS.tpu_zone,
                        project=FLAGS.gcp_project))
            else:
                tpu_cluster_resolver = (
                    tf.contrib.cluster_resolver.TPUClusterResolver(
                        os.uname()[1],
                        zone=FLAGS.tpu_zone,
                        project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        tpu_grpc_url = ''

    run_config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_checkpoints_secs=None,
        tpu_config=tpu_config.TPUConfig(iterations_per_loop=FLAGS.iterations,
                                        num_shards=FLAGS.num_shards),
    )

    estimator = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                           params={
                                               "bs": FLAGS.batch_size,
                                               "output_dim": output_dim,
                                               "input_dim": input_dim
                                           },
                                           use_tpu=FLAGS.use_tpu,
                                           train_batch_size=FLAGS.batch_size,
                                           config=run_config)
    estimator.train(input_fn=get_input_fn(FLAGS.batch_size, input_dim,
                                          output_dim),
                    max_steps=FLAGS.train_steps)

    total = time.time() - start
    tf.logging.info("Total time: " + str(total))
Example #21
def main(argv):
    del argv  # Unused.

    params = factory.config_generator(FLAGS.model)

    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)

    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.validate()
    params.lock()

    model_params = dict(params.as_dict(),
                        use_tpu=FLAGS.use_tpu,
                        mode=tf.estimator.ModeKeys.PREDICT,
                        transpose_input=False)

    print(' - Setting up TPUEstimator...')
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=serving.serving_model_fn_builder(
            FLAGS.use_tpu, FLAGS.output_image_info,
            FLAGS.output_normalized_coordinates,
            FLAGS.cast_num_detections_to_float),
        model_dir=None,
        config=tpu_config.RunConfig(
            tpu_config=tpu_config.TPUConfig(iterations_per_loop=1),
            master='local',
            evaluation_master='local'),
        params=model_params,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.batch_size,
        predict_batch_size=FLAGS.batch_size,
        export_to_tpu=FLAGS.use_tpu,
        export_to_cpu=True)

    print(' - Exporting the model...')
    input_type = FLAGS.input_type
    image_size = [int(x) for x in FLAGS.input_image_size.split(',')]
    export_path = estimator.export_saved_model(
        export_dir_base=FLAGS.export_dir,
        serving_input_receiver_fn=functools.partial(
            serving.serving_input_fn,
            batch_size=FLAGS.batch_size,
            desired_image_size=image_size,
            stride=(2**params.anchor.max_level),
            input_type=input_type,
            input_name=FLAGS.input_name),
        checkpoint_path=FLAGS.checkpoint_path)

    print(' - Done! path: %s' % export_path)
Example #22
def _make_run_config(model_dir):
    # Note: master, num_steps, and num_shards are free variables captured
    # from the enclosing scope in the original source.
    return tpu_config.RunConfig(
        master=master,
        model_dir=model_dir,
        save_checkpoints_secs=10000,
        session_config=config_pb2.ConfigProto(allow_soft_placement=True,
                                              log_device_placement=False),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=num_steps,
            num_shards=num_shards,
        ),
    )
Example #23
def main(unused_argv):
    flags.mark_flag_as_required('model_dir')
    flags.mark_flag_as_required('pipeline_config_path')

    tpu_cluster_resolver = (
        tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
            tpu_names=[FLAGS.tpu_name],
            zone=FLAGS.tpu_zone,
            project=FLAGS.gcp_project))
    tpu_grpc_url = tpu_cluster_resolver.get_master()

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_shards))

    kwargs = {}
    if FLAGS.train_batch_size:
        kwargs['batch_size'] = FLAGS.train_batch_size

    train_and_eval_dict = model_lib.create_estimator_and_inputs(
        run_config=config,
        hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
        pipeline_config_path=FLAGS.pipeline_config_path,
        train_steps=FLAGS.num_train_steps,
        eval_steps=FLAGS.num_eval_steps,
        use_tpu_estimator=True,
        use_tpu=FLAGS.use_tpu,
        num_shards=FLAGS.num_shards,
        **kwargs)
    estimator = train_and_eval_dict['estimator']
    train_input_fn = train_and_eval_dict['train_input_fn']
    eval_input_fn = train_and_eval_dict['eval_input_fn']
    eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
    train_steps = train_and_eval_dict['train_steps']
    eval_steps = train_and_eval_dict['eval_steps']

    if FLAGS.mode == 'train':
        estimator.train(input_fn=train_input_fn, max_steps=train_steps)

    # Continuously evaluating.
    if FLAGS.mode == 'eval':
        if FLAGS.eval_training_data:
            name = 'training_data'
            input_fn = eval_on_train_input_fn
        else:
            name = 'validation_data'
            input_fn = eval_input_fn
        model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn,
                                  eval_steps, train_steps, name)
Example #24
def get_estimator(model_dir, resolution):
    tpu_cluster_resolver = None

    if FLAGS.use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

        config = tpu_config.RunConfig(
            cluster=tpu_cluster_resolver,
            model_dir=model_dir,
            tpu_config=tpu_config.TPUConfig(
                num_shards=FLAGS.num_shards,
                iterations_per_loop=FLAGS.iterations_per_loop))

        est = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                         use_tpu=FLAGS.use_tpu,
                                         config=config,
                                         params={
                                             "data_dir": FLAGS.data_dir,
                                             "resolution": resolution
                                         },
                                         train_batch_size=FLAGS.batch_size,
                                         eval_batch_size=FLAGS.batch_size)

        local_est = tpu_estimator.TPUEstimator(
            model_fn=model_fn,
            use_tpu=False,
            config=config,
            params={
                "data_dir": FLAGS.data_dir,
                "resolution": resolution
            },
            predict_batch_size=FLAGS.num_eval_images)
    else:
        est = tf.estimator.Estimator(model_fn=model_fn,
                                     model_dir=model_dir,
                                     params={
                                         "data_dir": FLAGS.data_dir,
                                         "batch_size": FLAGS.batch_size,
                                         "resolution": resolution
                                     })

        local_est = tf.estimator.Estimator(model_fn=model_fn,
                                           model_dir=model_dir,
                                           params={
                                               "data_dir": FLAGS.data_dir,
                                               "batch_size":
                                               FLAGS.num_eval_images,
                                               "resolution": resolution
                                           })
    return est, local_est
Example #25
def main(argv):
  del argv
  training_examples = 1300 * 1000 * FLAGS.num_epochs
  eval_examples = 50 * 1000

  params = {
      "num_classes": 1001,
      "lr": 0.04,
      "min_lr": 0.0004,
      "momentum": FLAGS.momentum,
      "optimizer": FLAGS.optimizer,
      "num_eval_examples": eval_examples,
      "num_shards": FLAGS.num_shards,
      "num_epochs": FLAGS.num_epochs,
  }

  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=100,
          num_shards=FLAGS.num_shards,
      ),
  )

  estimator = tpu_estimator.TPUEstimator(
      model_fn=squeezenet_model.model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      params=dict(params, use_tpu=FLAGS.use_tpu),
  )

  # Evaluate the test set after 5% of training examples are finished.
  num_evals = 20
  for _ in range(num_evals):
    estimator.train(
        input_fn=data_pipeline.InputReader(FLAGS.data_dir, is_training=True),
        steps=training_examples // (num_evals * FLAGS.batch_size))

    tf.logging.info("Running evaluation")
    tf.logging.info("%s",
                    estimator.evaluate(
                        input_fn=data_pipeline.InputReader(
                            FLAGS.data_dir, is_training=False),
                        steps=eval_examples // FLAGS.batch_size,
                    ))
Example #26
def train(*tf_records, steps=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    if FLAGS.use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=None, project=None)
        tpu_grpc_url = tpu_cluster_resolver.get_master()

        config = tpu_config.RunConfig(
            master=tpu_grpc_url,
            evaluation_master=tpu_grpc_url,
            model_dir=FLAGS.model_dir,
            save_checkpoints_steps=max(800, FLAGS.iterations_per_loop),
            session_config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True),
            tpu_config=tpu_config.TPUConfig(
                iterations_per_loop=FLAGS.iterations_per_loop,
                num_shards=FLAGS.num_tpu_cores,
                per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2))  # pylint: disable=line-too-long

        estimator = tpu_estimator.TPUEstimator(
            use_tpu=FLAGS.use_tpu,
            model_fn=model_fn,
            config=config,
            train_batch_size=FLAGS.train_batch_size * FLAGS.num_tpu_cores,
            eval_batch_size=FLAGS.train_batch_size * FLAGS.num_tpu_cores)

        def input_fn(params):
            return preprocessing.get_tpu_input_tensors(params['batch_size'],
                                                       tf_records)

        # TODO: get hooks working again with TPUestimator.
        hooks = []
    else:
        estimator = get_estimator(FLAGS.model_dir)

        def input_fn():
            return preprocessing.get_input_tensors(
                FLAGS.train_batch_size,
                tf_records,
                filter_amount=1.0,
                shuffle_buffer_size=FLAGS.shuffle_buffer_size)

        hooks = [
            UpdateRatioSessionHook(FLAGS.model_dir),
            EchoStepCounterHook(output_dir=FLAGS.model_dir)
        ]

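    # EXAMPLES_PER_GENERATION is assumed to be a module-level constant in the
    # original source.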
    if steps is None:
        steps = EXAMPLES_PER_GENERATION // FLAGS.train_batch_size
    print("Training, steps = {}".format(steps))
    estimator.train(input_fn, steps=int(steps), hooks=hooks)
Example #27
    def __init__(self, observation_size_x, observation_size_y,
                 observation_size_z, action_size):
        super().__init__(observation_size_x, observation_size_y,
                         observation_size_z, action_size)

        if 'TPU_NAME' in os.environ:
            self.tpu = [os.environ['TPU_NAME']]
        else:
            self.tpu = ["demo-tpu"]
        self.tpu_zone = "us-central1-b"
        self.gcp_project = "alpha-zero-arvi"

        self.batch_size = 1024
        self.learning_rate = 0.001
        self.use_tpu = False
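        # TPU mode is disabled by default; the branch below runs only when
        # use_tpu is flipped to True.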
        self.iterations_per_loop = 10
        self.num_shards = 8

        self.temp_dir = "temp/tpu_estimator_data"
        os.makedirs(self.temp_dir, exist_ok=True)

        if self.use_tpu:
            print("TPU config: \n"
                  " tpu name: %s \n"
                  " project: %s \n"
                  " tpu_zone: %s" %
                  (self.tpu, self.gcp_project, self.tpu_zone))

            tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
                tpu=self.tpu, zone=self.tpu_zone, project=self.gcp_project)

            self.run_config = tpu_config.RunConfig(
                cluster=tpu_cluster_resolver,
                save_checkpoints_secs=1200,
                session_config=tf.ConfigProto(allow_soft_placement=True,
                                              log_device_placement=True),
                tpu_config=tpu_config.TPUConfig(
                    iterations_per_loop=self.iterations_per_loop,
                    num_shards=self.num_shards),
            )

        self.cur_model_dir = self.temp_dir

        self._build_estimator(self.cur_model_dir)
        self.fast_estimator = None

        self.saved = False

        self.export_dir = None
        self.predict_fn = None
Example #28
def main(_):
  """Run training/eval/inference."""
  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu=[FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  my_tpu_config = tpu_config.TPUConfig(
      iterations_per_loop=FLAGS.iterations_per_loop,
      num_cores_per_replica=1,
      per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
  )

  run_config = tpu_config.RunConfig(
      cluster=cluster,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=my_tpu_config)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=my_model_fn,
      config=run_config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      predict_batch_size=FLAGS.batch_size,
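      # Passing the TPU name makes use_tpu truthy exactly when --tpu is set.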
      use_tpu=FLAGS.tpu,
      export_to_tpu=False)

  def input_fn(params):
    del params
    return transformer_dataset.get_dataset(
        FLAGS.dataset,
        FLAGS.data_dir or None,
        train=(FLAGS.mode == "train"),
        batch_size=FLAGS.batch_size,
        length=length_from_flags())

  if FLAGS.mode == "train":
    estimator.train(
        input_fn=input_fn,
        max_steps=FLAGS.train_steps
    )
  elif FLAGS.mode == "evaluate":
    estimator.evaluate(
        input_fn=input_fn,
        steps=FLAGS.eval_steps,
    )
  elif FLAGS.mode == "infer":
    decode_from_file(estimator)
  else:
    raise ValueError(
        "unknown mode %s - must be train/evaluate/infer" % FLAGS.mode)
Example #29
def main(_):
  config = mask_rcnn_params.default_config()
  config = params_io.override_hparams(config, FLAGS.config)
  config.is_training_bn = False
  config.train_batch_size = FLAGS.batch_size
  config.eval_batch_size = FLAGS.batch_size

  model_params = dict(
      config.values(),
      use_tpu=FLAGS.use_tpu,
      mode=tf.estimator.ModeKeys.PREDICT,
      transpose_input=False)

  print(' - Setting up TPUEstimator...')
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=mask_rcnn_model.mask_rcnn_model_fn,
      model_dir=FLAGS.model_dir,
      config=tpu_config.RunConfig(
          tpu_config=tpu_config.TPUConfig(
              iterations_per_loop=FLAGS.iterations_per_loop),
          master='local',
          evaluation_master='local'),
      params=model_params,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      predict_batch_size=FLAGS.batch_size,
      export_to_tpu=FLAGS.use_tpu,
      export_to_cpu=True,
      experimental_exported_model_uses_all_cores=FLAGS.inference_with_all_cores)

  print(' - Exporting the model...')
  input_type = FLAGS.input_type
  export_path = estimator.export_saved_model(
      export_dir_base=FLAGS.export_dir,
      serving_input_receiver_fn=functools.partial(
          serving_inputs.serving_input_fn,
          batch_size=FLAGS.batch_size,
          desired_image_size=config.image_size,
          padding_stride=(2**config.max_level),
          input_type=input_type),
      checkpoint_path=FLAGS.checkpoint_path)
  if FLAGS.add_warmup_requests and input_type == 'image_bytes':
    inference_warmup.write_warmup_requests(
        export_path,
        FLAGS.model_name,
        config.image_size,
        batch_sizes=[FLAGS.batch_size],
        image_format='JPEG',
        input_signature=serving_inputs.INPUT_SIGNATURE)
Example #30
def main(unused_argv):
  del unused_argv  # Unused

  if FLAGS.input_layout not in ['NCHW', 'NHWC']:
    raise RuntimeError('--input_layout must be one of [NCHW, NHWC]')

  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      evaluation_master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_summary_steps=FLAGS.save_summary_steps,
      session_config=tf.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=FLAGS.log_device_placement),
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations,
          num_shards=FLAGS.num_shards))

  inception_classifier = tpu_estimator.TPUEstimator(
      model_fn=inception_model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      batch_axis=(get_batch_axis(
          FLAGS.train_batch_size // FLAGS.num_shards), 0))

  for cycle in range(FLAGS.train_steps // FLAGS.train_steps_per_eval):
    # tensors_to_log = {
    #     'learning_rate': 'learning_rate',
    #     'prediction_loss': 'prediction_loss',
    #     'train_accuracy': 'train_accuracy'
    # }

    # logging_hook = tf.train.LoggingTensorHook(
    #     tensors=tensors_to_log, every_n_iter=100)

    tf.logging.info('Starting training cycle %d.' % cycle)
    inception_classifier.train(
        input_fn=ImageNetInput(True), steps=FLAGS.train_steps_per_eval)

    if FLAGS.eval_enabled:
      eval_steps = (imagenet.get_split_size('validation') //
                    FLAGS.eval_batch_size)
      tf.logging.info('Starting evaluation cycle %d .' % cycle)
      eval_results = inception_classifier.evaluate(
          input_fn=ImageNetInput(False), steps=eval_steps)
      tf.logging.info('Evaluation results: %s' % eval_results)