Exemplo n.º 1
0
    def _retrieve_data(self, is_training, data_dir):
        """Create the ImageNet input pipeline for one split.

        Args:
          is_training: Forwarded as `is_training` to
            `imagenet_input.ImageNetInput`.
          data_dir: Forwarded as `data_dir`; the location the pipeline reads.

        Returns:
          An `imagenet_input.ImageNetInput` built with `transpose_input=False`,
          `num_parallel_calls=8` and `use_bfloat16=False`.
        """
        # Only the split and data location vary; everything else is fixed.
        pipeline_kwargs = dict(
            is_training=is_training,
            data_dir=data_dir,
            transpose_input=False,
            num_parallel_calls=8,
            use_bfloat16=False,
        )
        return imagenet_input.ImageNetInput(**pipeline_kwargs)
Exemplo n.º 2
0
def imagenet_test(transpose, params, use_bfloat16):
    """Build a non-training ImageNet input pipeline over the "test" split.

    Args:
      transpose: Forwarded as `transpose_input`.
      params: Mapping that must provide 'image_size' and 'num_parallel_calls'.
      use_bfloat16: Forwarded as `use_bfloat16`.

    Returns:
      An `imagenet_input.ImageNetInput` reading `FLAGS.data_dir`, batched by
      `FLAGS.predict_batch_size`, with caching disabled.
    """
    pipeline_kwargs = dict(
        is_training=False,
        dataset_split="test",
        batch_size=FLAGS.predict_batch_size,
        data_dir=FLAGS.data_dir,
        transpose_input=transpose,
        cache=False,
        image_size=params['image_size'],
        num_parallel_calls=params['num_parallel_calls'],
        use_bfloat16=use_bfloat16,
    )
    return imagenet_input.ImageNetInput(**pipeline_kwargs)
Exemplo n.º 3
0
def main(unused_argv):
    """Train, evaluate and/or export a ResNet model with a TPUEstimator.

    Hyperparameters are loaded from `FLAGS.param_file`, overridden by
    `FLAGS.param_overrides`, and logged to `FLAGS.model_dir`. The action is
    selected by `FLAGS.mode`:

      * 'eval': evaluate every checkpoint yielded by
        `evaluation.checkpoints_iterator` (called with
        `timeout=FLAGS.eval_timeout`), stopping after the first checkpoint
        whose step is >= `params['train_steps']`. A checkpoint that raises
        `tf.errors.NotFoundError` during evaluation is logged and skipped.
      * 'train': train up to `params['train_steps']`, optionally with an
        async checkpoint hook and a TPU profiler hook.
      * 'train_and_eval': alternate `FLAGS.steps_per_eval` training steps with
        a full evaluation until `params['train_steps']` is reached.

    After 'train' or 'train_and_eval' the model is exported when
    `FLAGS.export_dir` is set, optionally followed by warmup requests.

    Args:
      unused_argv: Ignored.

    Raises:
      ValueError: If `params['precision']` is not 'bfloat16' or 'float32'.
    """
    params = resnet_params.from_file(FLAGS.param_file)
    params = resnet_params.override(params, FLAGS.param_overrides)
    resnet_params.log_hparams_to_model_dir(params, FLAGS.model_dir)
    tf.logging.info('Model params: {}'.format(params))

    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu if (FLAGS.tpu or params['use_tpu']) else '',
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)

    # Step interval shared by RunConfig checkpointing and the async checkpoint
    # hook, so the two paths cannot drift apart.
    checkpoint_interval = max(100, params['iterations_per_loop'])
    if params['use_async_checkpointing']:
        # RunConfig gets no step interval; in 'train' mode an
        # AsyncCheckpointSaverHook is registered instead (see below).
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = checkpoint_interval
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=params['iterations_per_loop'],
            num_shards=params['num_cores'],
            per_host_input_for_training=(
                tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2)))

    # The estimator variants differ only in whether the exported model is told
    # to use all cores, so the shared arguments are built once.
    estimator_kwargs = dict(
        use_tpu=params['use_tpu'],
        model_fn=resnet_model_fn,
        config=config,
        params=params,
        train_batch_size=params['train_batch_size'],
        eval_batch_size=params['eval_batch_size'],
        export_to_tpu=FLAGS.export_to_tpu)
    if FLAGS.inference_with_all_cores:
        estimator_kwargs['experimental_exported_model_uses_all_cores'] = (
            FLAGS.inference_with_all_cores)
    resnet_classifier = tf.contrib.tpu.TPUEstimator(**estimator_kwargs)

    # Validate explicitly: an `assert` is stripped under `python -O`, which
    # would silently fall through to float32 for an invalid value.
    if params['precision'] not in ('bfloat16', 'float32'):
        raise ValueError('Invalid value for precision parameter; '
                         'must be bfloat16 or float32.')
    tf.logging.info('Precision: %s', params['precision'])
    use_bfloat16 = params['precision'] == 'bfloat16'

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=use_bfloat16,
                transpose_input=params['transpose_input'],
                selection=selection)
            for (is_training,
                 selection) in [(True, select_train), (False, select_eval)]
        ]
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=params['transpose_input'],
                cache=params['use_cache'] and is_training,
                image_size=params['image_size'],
                num_parallel_calls=params['num_parallel_calls'],
                use_bfloat16=use_bfloat16) for is_training in [True, False]
        ]

    # Computed once and reused by every mode below.
    steps_per_epoch = params['num_train_images'] // params['train_batch_size']
    eval_steps = params['num_eval_images'] // params['eval_batch_size']

    if FLAGS.mode == 'eval':

        # Run evaluation when there's a new checkpoint.
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                # This time will include compilation time.
                start_timestamp = time.time()
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)

                # Terminate the eval job when the final checkpoint is reached.
                # The step is the second '-'-separated field of the basename.
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= params['train_steps']:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU
                # worker, sometimes the TPU worker does not finish initializing
                # until long after the CPU job tells it to start evaluating. In
                # this case, the checkpoint file could have been deleted
                # already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(  # pylint: disable=protected-access
            FLAGS.model_dir)
        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', params['train_steps'],
            params['train_steps'] / steps_per_epoch, current_step)

        start_timestamp = time.time()  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if params['use_async_checkpointing']:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=checkpoint_interval))
            if FLAGS.profile_every_n_steps > 0:
                hooks.append(
                    tpu_profiler_hook.TPUProfilerHook(
                        save_steps=FLAGS.profile_every_n_steps,
                        output_dir=FLAGS.model_dir,
                        tpu=FLAGS.tpu))
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=params['train_steps'],
                                    hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < params['train_steps']:
                # Train for up to steps_per_eval number of steps. At the end of
                # training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      params['train_steps'])
                resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                        max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size,
                # some images may be excluded modulo the batch size. As long as
                # the batch size is consistent, the evaluated images are also
                # consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps)
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                params['train_steps'], elapsed_time)

        if FLAGS.export_dir is not None:
            # The guide to serve an exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            export_path = resnet_classifier.export_saved_model(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
            if FLAGS.add_warmup_requests:
                inference_warmup.write_warmup_requests(
                    export_path,
                    FLAGS.model_name,
                    params['image_size'],
                    batch_sizes=FLAGS.inference_batch_sizes,
                    image_format='JPEG')
Exemplo n.º 4
0
def main(unused_argv):
    """Train, evaluate and/or export a ResNet model with a TPUEstimator.

    All configuration comes from command-line flags. The action is selected by
    `FLAGS.mode`:

      * 'eval': evaluate every checkpoint yielded by
        `evaluation.checkpoints_iterator` (called with
        `timeout=FLAGS.eval_timeout`), stopping after the first checkpoint
        whose step is >= `FLAGS.train_steps`. A checkpoint that raises
        `tf.errors.NotFoundError` during evaluation is logged and skipped.
      * 'train': train up to `FLAGS.train_steps`.
      * 'train_and_eval': alternate `FLAGS.steps_per_eval` training steps with
        a full evaluation until `FLAGS.train_steps` is reached.

    After 'train' or 'train_and_eval' the model is exported when
    `FLAGS.export_dir` is set.

    Args:
      unused_argv: Ignored.

    Raises:
      ValueError: If `FLAGS.precision` is not 'bfloat16' or 'float32'.
    """
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu if (FLAGS.tpu or FLAGS.use_tpu) else '',
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)

    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=max(600, FLAGS.iterations_per_loop),
        log_step_count_steps=FLAGS.log_step_count_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores,
            per_host_input_for_training=(
                tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2)))

    resnet_classifier = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        export_to_tpu=False)

    # Validate explicitly: an `assert` is stripped under `python -O`, which
    # would silently fall through to float32 for an invalid flag value.
    if FLAGS.precision not in ('bfloat16', 'float32'):
        raise ValueError(
            'Invalid value for --precision flag; must be bfloat16 or float32.')
    tf.logging.info('Precision: %s', FLAGS.precision)
    use_bfloat16 = FLAGS.precision == 'bfloat16'

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=use_bfloat16,
                transpose_input=FLAGS.transpose_input,
                selection=selection)
            for (is_training,
                 selection) in [(True, select_train), (False, select_eval)]
        ]
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=FLAGS.transpose_input,
                cache=FLAGS.use_cache and is_training,
                num_parallel_calls=FLAGS.num_parallel_calls,
                use_bfloat16=use_bfloat16) for is_training in [True, False]
        ]

    # Computed once and reused by every mode below.
    steps_per_epoch = FLAGS.num_train_images // FLAGS.train_batch_size
    eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size

    if FLAGS.mode == 'eval':

        # Run evaluation when there's a new checkpoint.
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                # This time will include compilation time.
                start_timestamp = time.time()
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)

                # Terminate the eval job when the final checkpoint is reached.
                # The step is the second '-'-separated field of the basename.
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU
                # worker, sometimes the TPU worker does not finish initializing
                # until long after the CPU job tells it to start evaluating. In
                # this case, the checkpoint file could have been deleted
                # already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(  # pylint: disable=protected-access
            FLAGS.model_dir)
        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_steps,
            FLAGS.train_steps / steps_per_epoch, current_step)

        start_timestamp = time.time()  # This time will include compilation time

        if FLAGS.mode == 'train':
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=FLAGS.train_steps)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps. At the end of
                # training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                        max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size,
                # some images may be excluded modulo the batch size. As long as
                # the batch size is consistent, the evaluated images are also
                # consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps)
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                FLAGS.train_steps, elapsed_time)

        if FLAGS.export_dir is not None:
            # The guide to serve an exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            # NOTE(review): `export_savedmodel` is the older spelling of
            # `export_saved_model`; it is kept because this tf.contrib-era
            # script may run on a TF release without the new name. Confirm
            # the TF version before switching.
            resnet_classifier.export_savedmodel(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
Exemplo n.º 5
0
def main(unused_argv):
  """Train, evaluate and/or export a ResNet model with a TPUEstimator.

  Parameters start from `resnet_config.RESNET_CFG` (with
  `RESNET_RESTRICTIONS`), are overridden by `FLAGS.config_file`,
  `FLAGS.params_override` and the input flags, and are then validated and
  locked. The action is selected by `FLAGS.mode`:

    * 'eval': evaluate every checkpoint yielded by
      `tf.train.checkpoints_iterator` (called with
      `timeout=FLAGS.eval_timeout`), stopping after the first checkpoint whose
      step is >= `params.train_steps`. A checkpoint that raises
      `tf.errors.NotFoundError` during evaluation is logged and skipped.
    * 'train': train up to `params.train_steps`, optionally with an async
      checkpoint hook and a TPU profiler hook.
    * 'train_and_eval': alternate `FLAGS.steps_per_eval` training steps with a
      full evaluation until `params.train_steps` is reached.

  After 'train' or 'train_and_eval' the model is exported when
  `FLAGS.export_dir` is set, optionally followed by warmup requests.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If `params.precision` is not 'bfloat16' or 'float32'.
    ImportError: In 'train' mode with async checkpointing enabled, if the
      `tensorflow.contrib` async_checkpoint module cannot be imported.
  """
  params = params_dict.ParamsDict(
      resnet_config.RESNET_CFG, resnet_config.RESNET_RESTRICTIONS)
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)

  params = flags_to_params.override_params_from_input_flags(params, FLAGS)

  params.validate()
  params.lock()

  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu if (FLAGS.tpu or params.use_tpu) else '',
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project)

  # Step interval shared by RunConfig checkpointing and the async checkpoint
  # hook, so the two paths cannot drift apart.
  checkpoint_interval = max(5000, params.iterations_per_loop)
  if params.use_async_checkpointing:
    # RunConfig gets no step interval; in 'train' mode an
    # AsyncCheckpointSaverHook is registered instead (see below).
    save_checkpoints_steps = None
  else:
    save_checkpoints_steps = checkpoint_interval
  config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      session_config=tf.ConfigProto(
          graph_options=tf.GraphOptions(
              rewrite_options=rewriter_config_pb2.RewriterConfig(
                  disable_meta_optimizer=True))),
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=params.iterations_per_loop,
          num_shards=params.num_cores,
          per_host_input_for_training=(
              tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2)))

  resnet_classifier = tf.estimator.tpu.TPUEstimator(
      use_tpu=params.use_tpu,
      model_fn=resnet_model_fn,
      config=config,
      params=params.as_dict(),
      train_batch_size=params.train_batch_size,
      eval_batch_size=params.eval_batch_size,
      export_to_tpu=FLAGS.export_to_tpu)

  # Validate explicitly: an `assert` is stripped under `python -O`, which
  # would silently fall through to float32 for an invalid value.
  if params.precision not in ('bfloat16', 'float32'):
    raise ValueError('Invalid value for precision parameter; '
                     'must be bfloat16 or float32.')
  tf.logging.info('Precision: %s', params.precision)
  use_bfloat16 = params.precision == 'bfloat16'

  # Input pipelines are slightly different (with regards to shuffling and
  # preprocessing) between training and evaluation.
  if FLAGS.bigtable_instance:
    tf.logging.info('Using Bigtable dataset, table %s', FLAGS.bigtable_table)
    select_train, select_eval = _select_tables_from_flags()
    imagenet_train, imagenet_eval = [
        imagenet_input.ImageNetBigtableInput(  # pylint: disable=g-complex-comprehension
            is_training=is_training,
            use_bfloat16=use_bfloat16,
            transpose_input=params.transpose_input,
            selection=selection,
            augment_name=FLAGS.augment_name,
            randaug_num_layers=FLAGS.randaug_num_layers,
            randaug_magnitude=FLAGS.randaug_magnitude)
        for (is_training, selection) in [(True, select_train),
                                         (False, select_eval)]
    ]
  else:
    if FLAGS.data_dir == FAKE_DATA_DIR:
      tf.logging.info('Using fake dataset.')
    else:
      tf.logging.info('Using dataset: %s', FLAGS.data_dir)
    imagenet_train, imagenet_eval = [
        imagenet_input.ImageNetInput(  # pylint: disable=g-complex-comprehension
            is_training=is_training,
            data_dir=FLAGS.data_dir,
            transpose_input=params.transpose_input,
            cache=params.use_cache and is_training,
            image_size=params.image_size,
            num_parallel_calls=params.num_parallel_calls,
            include_background_label=(params.num_label_classes == 1001),
            use_bfloat16=use_bfloat16,
            augment_name=FLAGS.augment_name,
            randaug_num_layers=FLAGS.randaug_num_layers,
            randaug_magnitude=FLAGS.randaug_magnitude)
        for is_training in [True, False]
    ]

  # Computed once and reused by every mode below.
  steps_per_epoch = params.num_train_images // params.train_batch_size
  eval_steps = params.num_eval_images // params.eval_batch_size

  if FLAGS.mode == 'eval':

    # Run evaluation when there's a new checkpoint.
    for ckpt in tf.train.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info('Starting to evaluate.')
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps,
            checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                        eval_results, elapsed_time)

        # Terminate the eval job when the final checkpoint is reached. The
        # step is the second '-'-separated field of the checkpoint basename.
        current_step = int(os.path.basename(ckpt).split('-')[1])
        if current_step >= params.train_steps:
          tf.logging.info(
              'Evaluation finished after training step %d', current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info(
            'Checkpoint %s no longer exists, skipping checkpoint', ckpt)

  else:   # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    try:
      current_step = tf.train.load_variable(FLAGS.model_dir,
                                            tf.GraphKeys.GLOBAL_STEP)
    except (TypeError, ValueError, tf.errors.NotFoundError):
      # The global step could not be read (presumably no checkpoint exists
      # yet), so training starts from step 0.
      current_step = 0
    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.',
                    params.train_steps,
                    params.train_steps / steps_per_epoch,
                    current_step)

    start_timestamp = time.time()  # This time will include compilation time

    if FLAGS.mode == 'train':
      hooks = []
      if params.use_async_checkpointing:
        try:
          from tensorflow.contrib.tpu.python.tpu import async_checkpoint  # pylint: disable=g-import-not-at-top
        except ImportError:
          logging.exception(
              'Async checkpointing is not supported in TensorFlow 2.x')
          # Bare `raise` re-raises with the original traceback intact.
          raise

        hooks.append(
            async_checkpoint.AsyncCheckpointSaverHook(
                checkpoint_dir=FLAGS.model_dir,
                save_steps=checkpoint_interval))
      if FLAGS.profile_every_n_steps > 0:
        hooks.append(
            tpu_profiler_hook.TPUProfilerHook(
                save_steps=FLAGS.profile_every_n_steps,
                output_dir=FLAGS.model_dir, tpu=FLAGS.tpu))
      resnet_classifier.train(
          input_fn=imagenet_train.input_fn,
          max_steps=params.train_steps,
          hooks=hooks)

    else:
      assert FLAGS.mode == 'train_and_eval'
      while current_step < params.train_steps:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              params.train_steps)
        resnet_classifier.train(
            input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                        next_checkpoint, int(time.time() - start_timestamp))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be excluded modulo the batch size. As long as the batch size is
        # consistent, the evaluated images are also consistent.
        tf.logging.info('Starting to evaluate.')
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps)
        tf.logging.info('Eval results at step %d: %s',
                        next_checkpoint, eval_results)

      elapsed_time = int(time.time() - start_timestamp)
      tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                      params.train_steps, elapsed_time)

    if FLAGS.export_dir is not None:
      # The guide to serve an exported TensorFlow model is at:
      #    https://www.tensorflow.org/serving/serving_basic
      tf.logging.info('Starting to export model.')
      export_path = resnet_classifier.export_saved_model(
          export_dir_base=FLAGS.export_dir,
          serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
      if FLAGS.add_warmup_requests:
        inference_warmup.write_warmup_requests(
            export_path,
            FLAGS.model_name,
            params.image_size,
            batch_sizes=FLAGS.inference_batch_sizes,
            image_format='JPEG')
Exemplo n.º 6
0
def main(unused_argv):
    """Train or evaluate ResNet on TPU, optionally with progressive resizing.

    With `FLAGS.use_fast_lr`, `resnet_main.LR_SCHEDULE` is replaced and
    training runs in three stages: 128px images from `FLAGS.data_dir_small`,
    then images at `ImageNetInput`'s default size from `FLAGS.data_dir`, then
    288px images from `FLAGS.data_dir`.

    'train' mode writes a START marker holding the start timestamp into
    `FLAGS.model_dir` unless one already exists. 'eval' mode evaluates, in
    step order, every file in `FLAGS.model_dir` matching `CKPT_PATTERN`. It
    writes epoch, hours between the START marker and the checkpoint's
    `.index` modification time, and top-1/top-5 accuracy (in percent) to
    `results.tsv`.

    Args:
      unused_argv: Ignored.
    """
    # Epoch boundaries of the --use_fast_lr progressive-resizing schedule.
    # The training stages and the eval-pipeline selection must agree on them.
    small_image_epochs = 18
    default_image_epochs = 41
    large_image_epochs = 50

    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.iterations_per_loop,
        # Presumably set so every checkpoint is retained for the eval sweep.
        keep_checkpoint_max=None,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores,
            per_host_input_for_training=(
                tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2)))

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(
        is_training=True,
        data_dir=FLAGS.data_dir,
        use_bfloat16=True,
        transpose_input=FLAGS.transpose_input)
    imagenet_eval = imagenet_input.ImageNetInput(
        is_training=False,
        data_dir=FLAGS.data_dir,
        use_bfloat16=True,
        transpose_input=FLAGS.transpose_input)

    if FLAGS.use_fast_lr:
        resnet_main.LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
            (1.0, 4), (0.1, 21), (0.01, 35), (0.001, 43)
        ]
        imagenet_train_small = imagenet_input.ImageNetInput(
            is_training=True,
            image_size=128,
            data_dir=FLAGS.data_dir_small,
            num_parallel_calls=FLAGS.num_parallel_calls,
            use_bfloat16=True,
            transpose_input=FLAGS.transpose_input,
            cache=True)
        imagenet_eval_small = imagenet_input.ImageNetInput(
            is_training=False,
            image_size=128,
            data_dir=FLAGS.data_dir_small,
            num_parallel_calls=FLAGS.num_parallel_calls,
            use_bfloat16=True,
            transpose_input=FLAGS.transpose_input,
            cache=True)
        imagenet_train_large = imagenet_input.ImageNetInput(
            is_training=True,
            image_size=288,
            data_dir=FLAGS.data_dir,
            num_parallel_calls=FLAGS.num_parallel_calls,
            use_bfloat16=True,
            transpose_input=FLAGS.transpose_input)
        imagenet_eval_large = imagenet_input.ImageNetInput(
            is_training=False,
            image_size=288,
            data_dir=FLAGS.data_dir,
            num_parallel_calls=FLAGS.num_parallel_calls,
            use_bfloat16=True,
            transpose_input=FLAGS.transpose_input)

    resnet_classifier = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_main.resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.mode == 'train':
        current_step = estimator._load_global_step_from_checkpoint_dir(  # pylint: disable=protected-access
            FLAGS.model_dir)
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                        ' step %d.', FLAGS.train_steps,
                        FLAGS.train_steps / batches_per_epoch, current_step)

        start_timestamp = time.time()  # This time will include compilation time

        # Write a marker file at the start of training so that the runtime at
        # each checkpoint can be measured from the file write time. A marker
        # left by an earlier run is kept. MakeDirs (unlike MkDir) succeeds
        # when the directory already exists and creates missing parents.
        tf.gfile.MakeDirs(FLAGS.model_dir)
        start_file = os.path.join(FLAGS.model_dir, 'START')
        if not tf.gfile.Exists(start_file):
            with tf.gfile.GFile(start_file, 'w') as f:
                f.write(str(start_timestamp))

        if FLAGS.use_fast_lr:
            small_steps = int(small_image_epochs * NUM_TRAIN_IMAGES /
                              FLAGS.train_batch_size)
            normal_steps = int(default_image_epochs * NUM_TRAIN_IMAGES /
                               FLAGS.train_batch_size)
            large_steps = int(
                min(large_image_epochs * NUM_TRAIN_IMAGES /
                    FLAGS.train_batch_size,
                    FLAGS.train_steps))

            resnet_classifier.train(input_fn=imagenet_train_small.input_fn,
                                    max_steps=small_steps)
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=normal_steps)
            resnet_classifier.train(input_fn=imagenet_train_large.input_fn,
                                    max_steps=large_steps)
        else:
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=FLAGS.train_steps)

    else:
        assert FLAGS.mode == 'eval'

        start_timestamp = tf.gfile.Stat(os.path.join(FLAGS.model_dir,
                                                     'START')).mtime_nsec
        results = []
        eval_steps = NUM_EVAL_IMAGES // FLAGS.eval_batch_size

        # Collect the distinct global steps of all checkpoints on disk.
        ckpt_steps = set()
        for filename in tf.gfile.ListDirectory(FLAGS.model_dir):
            match = re.match(CKPT_PATTERN, filename)
            if match is not None:
                ckpt_steps.add(int(match.group('gs')))
        ckpt_steps = sorted(ckpt_steps)
        tf.logging.info('Steps to be evaluated: %s', ckpt_steps)

        # Loop-invariant; integer steps per epoch for epoch bucketing.
        steps_per_epoch = NUM_TRAIN_IMAGES // FLAGS.train_batch_size

        for step in ckpt_steps:
            ckpt = os.path.join(FLAGS.model_dir, 'model.ckpt-%d' % step)
            current_epoch = step // steps_per_epoch

            # Evaluate at the image size the checkpoint was trained on.
            if not FLAGS.use_fast_lr:
                eval_input_fn = imagenet_eval.input_fn
            elif current_epoch < small_image_epochs:
                eval_input_fn = imagenet_eval_small.input_fn
            elif current_epoch < default_image_epochs:
                eval_input_fn = imagenet_eval.input_fn
            else:
                eval_input_fn = imagenet_eval_large.input_fn

            # Hours between the START marker and this checkpoint's .index
            # file, both taken from modification times in nanoseconds.
            end_timestamp = tf.gfile.Stat(ckpt + '.index').mtime_nsec
            elapsed_hours = (end_timestamp - start_timestamp) / (1e9 * 3600.0)

            tf.logging.info('Starting to evaluate.')
            eval_start = time.time()  # This time will include compilation time
            eval_results = resnet_classifier.evaluate(input_fn=eval_input_fn,
                                                      steps=eval_steps,
                                                      checkpoint_path=ckpt)
            eval_time = int(time.time() - eval_start)
            tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                            eval_results, eval_time)
            results.append([
                current_epoch,
                elapsed_hours,
                '%.2f' % (eval_results['top_1_accuracy'] * 100),
                '%.2f' % (eval_results['top_5_accuracy'] * 100),
            ])

            # NOTE(review): fixed 60 s pause between evaluations; its purpose
            # is not evident from this code.
            time.sleep(60)

        with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'results.tsv'), 'wb') as tsv_file:  # pylint: disable=line-too-long
            writer = csv.writer(tsv_file, delimiter='\t')
            writer.writerow(['epoch', 'hours', 'top1Accuracy', 'top5Accuracy'])
            writer.writerows(results)