コード例 #1
0
ファイル: eval_ckpt_main.py プロジェクト: zysxjtu/tpu
def get_eval_driver(model_name, include_background_label=False):
  """Build an `EvalCkptDriver` for the requested EfficientNet variant.

  The image size is taken from the params function of the builder whose
  prefix matches `model_name`.

  Args:
    model_name: Model name string; must start with 'efficientnet'.
    include_background_label: Forwarded unchanged to `EvalCkptDriver`.

  Returns:
    An `EvalCkptDriver` created with batch size 1 and the image size reported
    by the matching builder.

  Raises:
    ValueError: If `model_name` does not start with 'efficientnet'.
  """
  # Every supported family shares the 'efficientnet' prefix, so anything else
  # can be rejected before a builder is chosen.
  if not model_name.startswith('efficientnet'):
    raise ValueError(
        'Model must be either efficientnet-b* or efficientnet-edgetpu* or '
        'efficientnet-condconv*')

  # The more specific prefixes are tested first; the plain family is the
  # fallback. The condconv params tuple carries one extra trailing element.
  if model_name.startswith('efficientnet-edgetpu'):
    model_params = efficientnet_edgetpu_builder.efficientnet_edgetpu_params(
        model_name)
    _, _, image_size, _ = model_params
  elif model_name.startswith('efficientnet-condconv'):
    model_params = efficientnet_condconv_builder.efficientnet_condconv_params(
        model_name)
    _, _, image_size, _, _ = model_params
  else:
    model_params = efficientnet_builder.efficientnet_params(model_name)
    _, _, image_size, _ = model_params

  driver_kwargs = dict(
      model_name=model_name,
      batch_size=1,
      image_size=image_size,
      include_background_label=include_background_label)
  return EvalCkptDriver(**driver_kwargs)
コード例 #2
0
def get_model_input_size(model_name):
  """Return the input image size for the given EfficientNet model name.

  The size is the third element of the params tuple returned by the builder
  whose prefix matches `model_name`.

  Args:
    model_name: Model name string; must start with 'efficientnet'.

  Returns:
    The image size reported by the matching builder's params function.

  Raises:
    ValueError: If `model_name` does not start with 'efficientnet'.
  """
  # All supported families share the 'efficientnet' prefix; reject the rest
  # before touching any builder.
  if not model_name.startswith('efficientnet'):
    raise ValueError(
        'Model must be either efficientnet-b* or efficientnet-x-b* or efficientnet-edgetpu* or '
        'efficientnet-condconv*, efficientnet-lite*')

  # Specific families first (note the trailing '-' on the edgetpu and condconv
  # prefixes); anything left over is handled by the base builder.
  if model_name.startswith('efficientnet-lite'):
    lite_params = efficientnet_lite_builder.efficientnet_lite_params(
        model_name)
    _, _, image_size, _ = lite_params
    return image_size

  if model_name.startswith('efficientnet-edgetpu-'):
    edgetpu_params = efficientnet_edgetpu_builder.efficientnet_edgetpu_params(
        model_name)
    _, _, image_size, _ = edgetpu_params
    return image_size

  if model_name.startswith('efficientnet-condconv-'):
    condconv_params = (
        efficientnet_condconv_builder.efficientnet_condconv_params(model_name))
    _, _, image_size, _, _ = condconv_params
    return image_size

  if model_name.startswith('efficientnet-x'):
    x_params = efficientnet_x_builder.efficientnet_x_params(model_name)
    _, _, image_size, _, _ = x_params
    return image_size

  base_params = efficientnet_builder.efficientnet_params(model_name)
  _, _, image_size, _ = base_params
  return image_size
コード例 #3
0
def main(unused_argv):
    """Run EfficientNet evaluation or training as selected by `FLAGS.mode`.

    Steps, in order:
      1. Determine the input image size, either from
         `FLAGS.input_image_size` or from the builder matching the
         `FLAGS.model_name` prefix.
      2. Build the `RunConfig` and the `tf.estimator.Estimator`.
      3. Build the training and evaluation ImageNet input pipelines.
      4. Run the 'eval' loop, or the 'train' / 'train_and_eval' path.
      5. Export the model when `FLAGS.export_dir` is set.

    Args:
      unused_argv: Not read by this function.

    Raises:
      ValueError: If `FLAGS.input_image_size` is unset and
        `FLAGS.model_name` does not start with 'efficientnet'.
      ImportError: Re-raised in 'train' mode when async checkpointing is
        requested but the `async_checkpoint` module cannot be imported.
    """
    input_image_size = FLAGS.input_image_size
    # Fall back to the per-model default size only when the flag is unset.
    # More specific prefixes are tested before the generic 'efficientnet'.
    if not input_image_size:
        if FLAGS.model_name.startswith('efficientnet-edgetpu'):
            _, _, input_image_size, _ = efficientnet_edgetpu_builder.efficientnet_edgetpu_params(
                FLAGS.model_name)
        elif FLAGS.model_name.startswith('efficientnet-tpu'):
            _, _, input_image_size, _ = efficientnet_tpu_builder.efficientnet_tpu_params(
                FLAGS.model_name)
        elif FLAGS.model_name.startswith('efficientnet-condconv'):
            # The condconv params tuple has five elements instead of four.
            _, _, input_image_size, _, _ = efficientnet_condconv_builder.efficientnet_condconv_params(
                FLAGS.model_name)
        elif FLAGS.model_name.startswith('efficientnet'):
            _, _, input_image_size, _ = efficientnet_builder.efficientnet_params(
                FLAGS.model_name)
        else:
            raise ValueError(
                'input_image_size must be set except for EfficientNet')

    # For imagenet dataset, include background label if number of output classes
    # is 1001
    include_background_label = (FLAGS.num_label_classes == 1001)

    # A cluster resolver is only created when a TPU is named or requested.
    if FLAGS.tpu or FLAGS.use_tpu:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    # With async checkpointing the RunConfig gets no step-based save interval;
    # the AsyncCheckpointSaverHook added in 'train' mode below saves instead.
    if FLAGS.use_async_checkpointing:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)
    config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
            .PER_HOST_V2))  # pylint: disable=line-too-long
    # Initializes model parameters.
    # NOTE: `/` is true division under Python 3, so steps_per_epoch may be
    # fractional. `batch_size` is always the eval batch size here, including
    # in the training modes.
    params = dict(steps_per_epoch=FLAGS.num_train_images /
                  FLAGS.train_batch_size,
                  use_bfloat16=FLAGS.use_bfloat16,
                  batch_size=FLAGS.eval_batch_size)
    est = tf.estimator.Estimator(  # used to be tf.estimator.tpu.TPUEstimator
        model_fn=model_fn,
        config=config,
        params=params)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    def build_imagenet_input(is_training):
        """Generate ImageNetInput for training and eval."""
        # The Bigtable source is used whenever a Bigtable instance is given;
        # otherwise data is read from FLAGS.data_dir.
        if FLAGS.bigtable_instance:
            logging.info('Using Bigtable dataset, table %s',
                         FLAGS.bigtable_table)
            select_train, select_eval = _select_tables_from_flags()
            return imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=FLAGS.use_bfloat16,
                transpose_input=FLAGS.transpose_input,
                selection=select_train if is_training else select_eval,
                num_label_classes=FLAGS.num_label_classes,
                include_background_label=include_background_label,
                augment_name=FLAGS.augment_name,
                mixup_alpha=FLAGS.mixup_alpha,
                use_randaug=FLAGS.use_randaug,
                randaug_num_layers=FLAGS.randaug_num_layers,
                randaug_magnitude=FLAGS.randaug_magnitude)
        else:
            if FLAGS.data_dir == FAKE_DATA_DIR:
                logging.info('Using fake dataset.')
            else:
                logging.info('Using dataset: %s', FLAGS.data_dir)

            # Caching is only requested for the training pipeline.
            return imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=FLAGS.transpose_input,
                cache=FLAGS.use_cache and is_training,
                image_size=input_image_size,
                num_parallel_calls=FLAGS.num_parallel_calls,
                use_bfloat16=FLAGS.use_bfloat16,
                num_label_classes=FLAGS.num_label_classes,
                include_background_label=include_background_label,
                augment_name=FLAGS.augment_name,
                mixup_alpha=FLAGS.mixup_alpha,
                randaug_num_layers=FLAGS.randaug_num_layers,
                randaug_magnitude=FLAGS.randaug_magnitude)

    # Both pipelines are built up front, regardless of the selected mode.
    imagenet_train = build_imagenet_input(is_training=True)
    imagenet_eval = build_imagenet_input(is_training=False)

    if FLAGS.mode == 'eval':
        # Integer division: images beyond the last full batch are not evaluated.
        eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in tf.train.checkpoints_iterator(FLAGS.model_dir,
                                                  timeout=FLAGS.eval_timeout):
            logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time(
                )  # This time will include compilation time
                eval_results = est.evaluate(input_fn=imagenet_eval.input_fn,
                                            steps=eval_steps,
                                            checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                logging.info('Eval results: %s. Elapsed seconds: %d',
                             eval_results, elapsed_time)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

                # NOTE(review): assuming `quit` is the interpreter's site
                # builtin, this exits the process after the first successful
                # evaluation. The remaining statements in this `try` body are
                # then unreachable, and the export step at the end of main()
                # is never reached in eval mode — confirm this is intended.
                quit()  # only eval once
                # Terminate eval job when final checkpoint is reached
                # NOTE(review): splitting the checkpoint basename on '-b'
                # looks suspicious (a plain '-' separator would be expected);
                # verify before re-enabling this code path.
                current_step = int(os.path.basename(ckpt).split('-b')[1])
                if current_step >= FLAGS.train_steps:
                    logging.info('Evaluation finished after training step %d',
                                 current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)
    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        # Read the global step from the checkpoint directory so the log line
        # and the train_and_eval loop start from the saved progress.
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

        logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_steps,
            FLAGS.train_steps / params['steps_per_epoch'], current_step)

        start_timestamp = time.time(
        )  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if FLAGS.use_async_checkpointing:
                # Imported lazily because the module lives under
                # tensorflow.contrib; a failed import is logged and re-raised.
                try:
                    from tensorflow.contrib.tpu.python.tpu import \
                        async_checkpoint  # pylint: disable=g-import-not-at-top
                except ImportError as e:
                    logging.exception(
                        'Async checkpointing is not supported in TensorFlow 2.x'
                    )
                    raise e

                # Same save interval the RunConfig would otherwise have used.
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, FLAGS.iterations_per_loop)))
            est.train(input_fn=imagenet_train.input_fn,
                      max_steps=FLAGS.train_steps,
                      hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            # Alternate between training for up to FLAGS.steps_per_eval steps
            # and evaluating, until FLAGS.train_steps is reached.
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                est.train(input_fn=imagenet_train.input_fn,
                          max_steps=next_checkpoint)
                current_step = next_checkpoint

                logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                logging.info('Starting to evaluate.')
                eval_results = est.evaluate(input_fn=imagenet_eval.input_fn,
                                            steps=FLAGS.num_eval_images //
                                            FLAGS.eval_batch_size)
                logging.info('Eval results at step %d: %s', next_checkpoint,
                             eval_results)
                # Archive the latest checkpoint together with its top-1
                # accuracy from this evaluation.
                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

            elapsed_time = int(time.time() - start_timestamp)
            logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                FLAGS.train_steps, elapsed_time)
    # Optional export of the estimator using the resolved input image size.
    if FLAGS.export_dir:
        export(est, FLAGS.export_dir, input_image_size)