Example #1
def override_params_from_input_flags(params, input_flags):
    """Update params dictionary with input flags.

  Args:
    params: ParamsDict object containing dictionary of model parameters.
    input_flags: All the flags with non-null value of overridden model
    parameters.

  Returns:
    ParamsDict object containing dictionary of model parameters.
  """
    if not isinstance(params, params_dict.ParamsDict):
        raise ValueError(
            'The base parameter set must be a ParamsDict, was: {}'.format(
                type(params)))

    essential_flag_dict = {}
    for key in ESSENTIAL_FLAGS:
        flag_value = input_flags.get_flag_value(key, None)

        if flag_value is None:
            logging.warning('Flag %s is None.', key)
        else:
            essential_flag_dict[key] = flag_value

    params_dict.override_params_dict(params,
                                     essential_flag_dict,
                                     is_strict=False)

    normal_flag_dict = get_dictionary_from_flags(params.as_dict(), input_flags)

    params_dict.override_params_dict(params, normal_flag_dict, is_strict=False)

    return params
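A minimal sketch of how this helper is typically wired up, assuming absl flags and that it lives in the same module as the ESSENTIAL_FLAGS list it references (the flag names and import path below are hypothetical):

from absl import app
from absl import flags
from hyperparameters import params_dict  # assumed import path

FLAGS = flags.FLAGS
flags.DEFINE_string('model_dir', None, 'Directory for checkpoints.')
flags.DEFINE_integer('train_batch_size', None, 'Optional batch size override.')


def run(argv):
    del argv  # Unused.
    params = params_dict.ParamsDict({'model_dir': '', 'train_batch_size': 64})
    # FLAGS plays the role of `input_flags`: flags the user actually set are
    # folded into the ParamsDict; unset (None) flags are skipped.
    params = override_params_from_input_flags(params, FLAGS)
    print(params.as_dict())


if __name__ == '__main__':
    app.run(run)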
Example #2
def main(argv):
    del argv  # Unused.

    params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                    unet_config.UNET_RESTRICTIONS)
    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=False)

    if FLAGS.training_file_pattern:
        params.override({'training_file_pattern': FLAGS.training_file_pattern},
                        is_strict=True)

    if FLAGS.eval_file_pattern:
        params.override({'eval_file_pattern': FLAGS.eval_file_pattern},
                        is_strict=True)

    train_epoch_steps = params.train_item_count // params.train_batch_size
    eval_epoch_steps = params.eval_item_count // params.eval_batch_size

    params.override(
        {
            'model_dir': FLAGS.model_dir,
            'min_eval_interval': FLAGS.min_eval_interval,
            'eval_timeout': FLAGS.eval_timeout,
            'tpu_config': tpu_executor.get_tpu_flags(),
            'lr_decay_steps': train_epoch_steps,
            'train_steps': params.train_epochs * train_epoch_steps,
            'eval_steps': eval_epoch_steps,
        },
        is_strict=False)

    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)

    params.validate()
    params.lock()

    train_input_fn = None
    eval_input_fn = None
    train_input_shapes = None
    eval_input_shapes = None
    if FLAGS.mode in ('train', 'train_and_eval'):
        train_input_fn = input_reader.LiverInputFn(
            params.training_file_pattern,
            params,
            mode=tf.estimator.ModeKeys.TRAIN)
        train_input_shapes = train_input_fn.get_input_shapes(params)
    if FLAGS.mode in ('eval', 'train_and_eval'):
        eval_input_fn = input_reader.LiverInputFn(
            params.eval_file_pattern, params, mode=tf.estimator.ModeKeys.EVAL)
        eval_input_shapes = eval_input_fn.get_input_shapes(params)

    assert train_input_shapes is not None or eval_input_shapes is not None
    run_executer(params,
                 train_input_shapes=train_input_shapes,
                 eval_input_shapes=eval_input_shapes,
                 train_input_fn=train_input_fn,
                 eval_input_fn=eval_input_fn)
Example #3
def main(argv):
    del argv  # Unused.

    params = factory.config_generator(FLAGS.model)

    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)

    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.validate()
    params.lock()

    model_params = dict(params.as_dict(),
                        use_tpu=FLAGS.use_tpu,
                        mode=tf.estimator.ModeKeys.PREDICT,
                        transpose_input=False)

    print(' - Setting up TPUEstimator...')
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=serving.serving_model_fn_builder(
            FLAGS.use_tpu, FLAGS.output_image_info,
            FLAGS.output_normalized_coordinates,
            FLAGS.cast_num_detections_to_float),
        model_dir=None,
        config=tpu_config.RunConfig(
            tpu_config=tpu_config.TPUConfig(iterations_per_loop=1),
            master='local',
            evaluation_master='local'),
        params=model_params,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.batch_size,
        predict_batch_size=FLAGS.batch_size,
        export_to_tpu=FLAGS.use_tpu,
        export_to_cpu=True)

    print(' - Exporting the model...')
    input_type = FLAGS.input_type
    image_size = [int(x) for x in FLAGS.input_image_size.split(',')]
    export_path = estimator.export_saved_model(
        export_dir_base=FLAGS.export_dir,
        serving_input_receiver_fn=functools.partial(
            serving.serving_input_fn,
            batch_size=FLAGS.batch_size,
            desired_image_size=image_size,
            stride=(2**params.anchor.max_level),
            input_type=input_type,
            input_name=FLAGS.input_name),
        checkpoint_path=FLAGS.checkpoint_path)

    print(' - Done! path: %s' % export_path)
Example #4
def main(argv):
    del argv  # Unused.

    # Configure parameters.
    params = params_dict.ParamsDict(mask_rcnn_config.MASK_RCNN_CFG,
                                    mask_rcnn_config.MASK_RCNN_RESTRICTIONS)
    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=True)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params = flags_to_params.override_params_from_input_flags(params, FLAGS)

    params.validate()
    params.lock()

    # Check data path
    train_input_fn = None
    eval_input_fn = None
    if (FLAGS.mode in ('train', 'train_and_eval')
            and not params.training_file_pattern):
        raise RuntimeError(
            'You must specify `training_file_pattern` for training.')
    if FLAGS.mode in ('eval', 'train_and_eval'):
        if not params.validation_file_pattern:
            raise RuntimeError('You must specify `validation_file_pattern` '
                               'for evaluation.')
        if not params.val_json_file and not params.include_groundtruth_in_features:
            raise RuntimeError(
                'You must specify `val_json_file` or '
                'include_groundtruth_in_features=True for evaluation.')

    if FLAGS.mode in ('train', 'train_and_eval'):
        train_input_fn = dataloader.InputReader(
            params.training_file_pattern,
            mode=tf.estimator.ModeKeys.TRAIN,
            use_fake_data=FLAGS.use_fake_data,
            use_instance_mask=params.include_mask)
    if (FLAGS.mode in ('eval', 'train_and_eval')
            or (FLAGS.mode == 'train' and FLAGS.eval_after_training)):
        eval_input_fn = dataloader.InputReader(
            params.validation_file_pattern,
            mode=tf.estimator.ModeKeys.PREDICT,
            num_examples=params.eval_samples,
            use_instance_mask=params.include_mask)

    run_executer(params, train_input_fn, eval_input_fn)
Example #5
def main(_):
  config = params_dict.ParamsDict(mask_rcnn_config.MASK_RCNN_CFG,
                                  mask_rcnn_config.MASK_RCNN_RESTRICTIONS)
  config = params_dict.override_params_dict(
      config, FLAGS.config, is_strict=True)
  config.is_training_bn = False
  config.train_batch_size = FLAGS.batch_size
  config.eval_batch_size = FLAGS.batch_size

  config.validate()
  config.lock()

  model_params = dict(
      list(config.as_dict().items()),
      use_tpu=FLAGS.use_tpu,
      mode=tf.estimator.ModeKeys.PREDICT,
      transpose_input=False)

  print(' - Setting up TPUEstimator...')
  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=serving.serving_model_fn_builder(
          FLAGS.output_source_id, FLAGS.output_image_info,
          FLAGS.output_box_features, FLAGS.output_normalized_coordinates,
          FLAGS.cast_num_detections_to_float),
      model_dir=FLAGS.model_dir,
      config=tpu_config.RunConfig(
          tpu_config=tpu_config.TPUConfig(
              iterations_per_loop=FLAGS.iterations_per_loop),
          master='local',
          evaluation_master='local'),
      params=model_params,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      predict_batch_size=FLAGS.batch_size,
      export_to_tpu=FLAGS.use_tpu,
      export_to_cpu=True)

  print(' - Exporting the model...')
  input_type = FLAGS.input_type
  export_path = estimator.export_saved_model(
      export_dir_base=FLAGS.export_dir,
      serving_input_receiver_fn=functools.partial(
          serving.serving_input_fn,
          batch_size=FLAGS.batch_size,
          desired_image_size=config.image_size,
          padding_stride=(2**config.max_level),
          input_type=input_type,
          input_name=FLAGS.input_name),
      checkpoint_path=FLAGS.checkpoint_path)

  if FLAGS.add_warmup_requests and input_type == 'image_bytes':
    inference_warmup.write_warmup_requests(
        export_path,
        FLAGS.model_name,
        config.image_size,
        batch_sizes=[FLAGS.batch_size],
        image_format='JPEG',
        input_signature=FLAGS.input_name)
  print(' - Done! path: %s' % export_path)
Example #6
def main(_):
    params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                    unet_config.UNET_RESTRICTIONS)
    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=False)
    params.train_batch_size = FLAGS.batch_size
    params.eval_batch_size = FLAGS.batch_size
    params.use_bfloat16 = False

    model_params = dict(params.as_dict(),
                        use_tpu=FLAGS.use_tpu,
                        mode=tf.estimator.ModeKeys.PREDICT,
                        transpose_input=False)

    print(' - Setting up TPUEstimator...')
    estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=serving_model_fn,
        model_dir=FLAGS.model_dir,
        config=tf.estimator.tpu.RunConfig(
            tpu_config=tf.estimator.tpu.TPUConfig(
                iterations_per_loop=FLAGS.iterations_per_loop),
            master='local',
            evaluation_master='local'),
        params=model_params,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.batch_size,
        predict_batch_size=FLAGS.batch_size,
        export_to_tpu=FLAGS.use_tpu,
        export_to_cpu=True)

    print(' - Exporting the model...')
    input_type = FLAGS.input_type
    export_path = estimator.export_saved_model(
        export_dir_base=FLAGS.export_dir,
        serving_input_receiver_fn=functools.partial(
            serving_input_fn,
            batch_size=FLAGS.batch_size,
            input_type=input_type,
            params=params,
            input_name=FLAGS.input_name),
        checkpoint_path=FLAGS.checkpoint_path)

    print(' - Done! path: %s' % export_path)
Example #7
    def __init__(
        self,
        config_file: str,
        checkpoint_path: str,
        batch_size: int,
        resize_shape: tuple[int, int],
        cache_dir: str,
        device: int | None = None,
    ):
        self.device = device
        self.batch_size = batch_size
        self.resize_shape = resize_shape

        params = config_factory.config_generator("mask_rcnn")
        if config_file:
            params = params_dict.override_params_dict(params,
                                                      config_file,
                                                      is_strict=True)
        params.validate()
        params.lock()
        self.max_level = params.architecture.max_level

        self._model = model_factory.model_generator(params)
        estimator = tf.estimator.Estimator(model_fn=self._model_fn)

        # Use a SavedModel instead of Estimator.predict()
        # because it is difficult to download images from GCS
        # when executing this code on Vertex Pipelines.

        with tempfile.TemporaryDirectory() as tmpdir:
            export_dir_parent = cache_dir or tmpdir
            children = list(Path(export_dir_parent).glob("*"))
            if not children:
                logger.info(f"export saved_model: {export_dir_parent}")
                estimator.export_saved_model(
                    export_dir_base=export_dir_parent,
                    serving_input_receiver_fn=self._serving_input_receiver_fn,
                    checkpoint_path=checkpoint_path,
                )

            children = list(Path(export_dir_parent).glob("*"))
            export_dir = str(children[0])
            logger.info(f"load saved_model from {export_dir}")
            self.saved_model = tf.saved_model.load(export_dir=export_dir)
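Once loaded, the SavedModel is driven through its serving signature rather than Estimator.predict(). A rough sketch of a predict method on the same class, assuming eager (TF2) execution; 'serving_default' is the conventional signature name, and the input key 'images' is an assumption that depends on _serving_input_receiver_fn:

    def predict(self, images):
        # Hedged sketch: inspect infer.structured_input_signature for the
        # real input name, shape, and dtype before relying on 'images'.
        infer = self.saved_model.signatures['serving_default']
        outputs = infer(images=tf.convert_to_tensor(images, dtype=tf.float32))
        return {name: tensor.numpy() for name, tensor in outputs.items()}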
Example #8
def main(argv):
    del argv  # Unused.

    params = factory.config_generator(FLAGS.model)

    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)

    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.override(
        {
            'platform': {
                'eval_master': FLAGS.eval_master,
                'tpu': FLAGS.tpu,
                'tpu_zone': FLAGS.tpu_zone,
                'gcp_project': FLAGS.gcp_project,
            },
            'tpu_job_name': FLAGS.tpu_job_name,
            'use_tpu': FLAGS.use_tpu,
            'model_dir': FLAGS.model_dir,
            'train': {
                'num_shards': FLAGS.num_cores,
            },
        },
        is_strict=False)
    # Only run spatial partitioning in training mode.
    if FLAGS.mode != 'train':
        params.train.input_partition_dims = None
        params.train.num_cores_per_replica = None

    params.validate()
    params.lock()
    pp = pprint.PrettyPrinter()
    params_str = pp.pformat(params.as_dict())
    tf.logging.info('Model Parameters: {}'.format(params_str))

    # Builds detection model on TPUs.
    model_fn = model_builder.ModelFn(params)
    executor = tpu_executor.TpuExecutor(model_fn, params)

    # Prepares input functions for train and eval.
    train_input_fn = input_reader.InputFn(params.train.train_file_pattern,
                                          params,
                                          mode=ModeKeys.TRAIN)
    eval_input_fn = input_reader.InputFn(params.eval.eval_file_pattern,
                                         params,
                                         mode=ModeKeys.PREDICT_WITH_GT)

    # Runs the model.
    if FLAGS.mode == 'train':
        save_config(params, params.model_dir)
        executor.train(train_input_fn, params.train.total_steps)
        if FLAGS.eval_after_training:
            executor.prepare_evaluation()
            executor.evaluate(
                eval_input_fn,
                params.eval.eval_samples // params.eval.eval_batch_size)

    elif FLAGS.mode == 'eval':

        def terminate_eval():
            tf.logging.info(
                'Terminating eval after %d seconds of no checkpoints' %
                params.eval.eval_timeout)
            return True

        executor.prepare_evaluation()
        # Runs evaluation when there's a new checkpoint.
        for ckpt in tf.contrib.training.checkpoints_iterator(
                params.model_dir,
                min_interval_secs=params.eval.min_eval_interval,
                timeout=params.eval.eval_timeout,
                timeout_fn=terminate_eval):
            # Terminates eval job when final checkpoint is reached.
            current_step = int(os.path.basename(ckpt).split('-')[1])

            tf.logging.info('Starting to evaluate.')
            try:
                executor.evaluate(
                    eval_input_fn,
                    params.eval.eval_samples // params.eval.eval_batch_size,
                    ckpt)

                if current_step >= params.train.total_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        current_step)
                    break
            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint' %
                    ckpt)

    elif FLAGS.mode == 'train_and_eval':
        save_config(params, params.model_dir)
        executor.prepare_evaluation()
        num_cycles = int(params.train.total_steps /
                         params.eval.num_steps_per_eval)
        for cycle in range(num_cycles):
            tf.logging.info('Start training cycle %d.' % cycle)
            current_cycle_last_train_step = ((cycle + 1) *
                                             params.eval.num_steps_per_eval)
            executor.train(train_input_fn, current_cycle_last_train_step)
            executor.evaluate(
                eval_input_fn,
                params.eval.eval_samples // params.eval.eval_batch_size)
    else:
        tf.logging.info('Mode not found.')
Example #9
def main(argv):
    del argv  # Unused.

    params = factory.config_generator(FLAGS.model)

    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)

    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.override({
        'use_tpu': FLAGS.use_tpu,
        'model_dir': FLAGS.model_dir,
    },
                    is_strict=True)
    if not FLAGS.use_tpu:
        params.override(
            {
                'architecture': {
                    'use_bfloat16': False,
                },
                'batch_norm_activation': {
                    'use_sync_bn': False,
                },
            },
            is_strict=True)
    # Only run spatial partitioning in training mode.
    if FLAGS.mode != 'train':
        params.train.input_partition_dims = None
        params.train.num_cores_per_replica = None
    params_to_save = params_dict.ParamsDict(params)
    params.override(
        {
            'platform': {
                'eval_master': FLAGS.eval_master,
                'tpu': FLAGS.tpu,
                'tpu_zone': FLAGS.tpu_zone,
                'gcp_project': FLAGS.gcp_project,
            },
            'tpu_job_name': FLAGS.tpu_job_name,
            'train': {
                'num_shards': FLAGS.num_cores,
            },
        },
        is_strict=False)

    params.validate()
    params.lock()
    pp = pprint.PrettyPrinter()
    params_str = pp.pformat(params.as_dict())
    logging.info('Model Parameters: %s', params_str)

    # Builds detection model on TPUs.
    model_fn = model_builder.ModelFn(params)
    executor = tpu_executor.TpuExecutor(model_fn, params)

    # Prepares input functions for train and eval.
    train_input_fn = input_reader.InputFn(
        params.train.train_file_pattern,
        params,
        mode=ModeKeys.TRAIN,
        dataset_type=params.train.train_dataset_type)
    if params.eval.type == 'customized':
        eval_input_fn = input_reader.InputFn(
            params.eval.eval_file_pattern,
            params,
            mode=ModeKeys.EVAL,
            dataset_type=params.eval.eval_dataset_type)
    else:
        eval_input_fn = input_reader.InputFn(
            params.eval.eval_file_pattern,
            params,
            mode=ModeKeys.PREDICT_WITH_GT,
            dataset_type=params.eval.eval_dataset_type)

    if params.eval.eval_samples:
        eval_times = params.eval.eval_samples // params.eval.eval_batch_size
    else:
        eval_times = None

    # Runs the model.
    if FLAGS.mode == 'train':
        config_utils.save_config(params_to_save, params.model_dir)
        executor.train(train_input_fn, params.train.total_steps)
        if FLAGS.eval_after_training:
            executor.evaluate(eval_input_fn, eval_times)

    elif FLAGS.mode == 'eval':

        def terminate_eval():
            logging.info('Terminating eval after %d seconds of no checkpoints',
                         params.eval.eval_timeout)
            return True

        # Runs evaluation when there's a new checkpoint.
        for ckpt in tf.train.checkpoints_iterator(
                params.model_dir,
                min_interval_secs=params.eval.min_eval_interval,
                timeout=params.eval.eval_timeout,
                timeout_fn=terminate_eval):
            # Terminates eval job when final checkpoint is reached.
            current_step = int(
                six.ensure_str(os.path.basename(ckpt)).split('-')[1])

            logging.info('Starting to evaluate.')
            try:
                executor.evaluate(eval_input_fn, eval_times, ckpt)

                if current_step >= params.train.total_steps:
                    logging.info('Evaluation finished after training step %d',
                                 current_step)
                    break
            except tf.errors.NotFoundError as e:
                logging.info(
                    'Error occurred during evaluation: NotFoundError: %s', e)

    elif FLAGS.mode == 'train_and_eval':
        config_utils.save_config(params_to_save, params.model_dir)
        num_cycles = int(params.train.total_steps /
                         params.eval.num_steps_per_eval)
        for cycle in range(num_cycles):
            logging.info('Start training cycle %d.', cycle)
            current_cycle_last_train_step = ((cycle + 1) *
                                             params.eval.num_steps_per_eval)
            executor.train(train_input_fn, current_cycle_last_train_step)
            executor.evaluate(eval_input_fn, eval_times)
    else:
        logging.info('Mode not found.')
Example #10
def main(unused_argv):
    del unused_argv  # Unused

    params = params_dict.ParamsDict({},
                                    mobilenet_config.MOBILENET_RESTRICTIONS)
    params = flags_to_params.override_params_from_input_flags(params, FLAGS)
    params = params_dict.override_params_dict(params,
                                              mobilenet_config.MOBILENET_CFG,
                                              is_strict=False)
    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=True)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)

    input_perm = [0, 1, 2, 3]
    output_perm = [0, 1, 2, 3]

    batch_axis = 0
    batch_size_per_shard = params.train_batch_size // params.num_cores
    if params.transpose_enabled:
        if batch_size_per_shard >= 64:
            input_perm = [3, 0, 1, 2]
            output_perm = [1, 2, 3, 0]
            batch_axis = 3
        else:
            input_perm = [2, 0, 1, 3]
            output_perm = [1, 2, 0, 3]
            batch_axis = 2

    additional_params = {
        'input_perm': input_perm,
        'output_perm': output_perm,
    }
    params = params_dict.override_params_dict(params,
                                              additional_params,
                                              is_strict=False)

    params.validate()
    params.lock()

    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu if (FLAGS.tpu or params.use_tpu) else '',
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)

    if params.eval_total_size > 0:
        eval_size = params.eval_total_size
    else:
        eval_size = params.num_eval_images
    eval_steps = eval_size // params.eval_batch_size

    iterations = (eval_steps
                  if FLAGS.mode == 'eval' else params.iterations_per_loop)

    eval_batch_size = (None
                       if FLAGS.mode == 'train' else params.eval_batch_size)

    per_host_input_for_training = (params.num_cores <= 8
                                   if FLAGS.mode == 'train' else True)

    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_secs=FLAGS.save_checkpoints_secs,
        save_summary_steps=FLAGS.save_summary_steps,
        session_config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=iterations,
            per_host_input_for_training=per_host_input_for_training))

    inception_classifier = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        use_tpu=params.use_tpu,
        config=run_config,
        params=params.as_dict(),
        train_batch_size=params.train_batch_size,
        eval_batch_size=eval_batch_size,
        batch_axis=(batch_axis, 0))

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = supervised_images.InputPipeline(is_training=True,
                                                     data_dir=FLAGS.data_dir)
    imagenet_eval = supervised_images.InputPipeline(is_training=False,
                                                    data_dir=FLAGS.data_dir)

    if params.moving_average:
        eval_hooks = [LoadEMAHook(FLAGS.model_dir)]
    else:
        eval_hooks = []

    if FLAGS.mode == 'eval':

        def terminate_eval():
            tf.logging.info('%d seconds without new checkpoints have elapsed '
                            '... terminating eval' % FLAGS.eval_timeout)
            return True

        def get_next_checkpoint():
            return evaluation.checkpoints_iterator(
                FLAGS.model_dir,
                min_interval_secs=params.min_eval_interval,
                timeout=FLAGS.eval_timeout,
                timeout_fn=terminate_eval)

        for checkpoint in get_next_checkpoint():
            tf.logging.info('Starting to evaluate.')
            try:
                eval_results = inception_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    hooks=eval_hooks,
                    checkpoint_path=checkpoint)
                tf.logging.info('Evaluation results: %s' % eval_results)
            except tf.errors.NotFoundError:
                # skip checkpoint if it gets deleted prior to evaluation
                tf.logging.info(
                    'Checkpoint %s no longer exists ... skipping' % checkpoint)

    elif FLAGS.mode == 'train_and_eval':
        for cycle in range(params.train_steps // params.train_steps_per_eval):
            tf.logging.info('Starting training cycle %d.' % cycle)
            inception_classifier.train(input_fn=imagenet_train.input_fn,
                                       steps=params.train_steps_per_eval)

            tf.logging.info('Starting evaluation cycle %d.' % cycle)
            eval_results = inception_classifier.evaluate(
                input_fn=imagenet_eval.input_fn,
                steps=eval_steps,
                hooks=eval_hooks)
            tf.logging.info('Evaluation results: %s' % eval_results)

    else:
        tf.logging.info('Starting training ...')
        inception_classifier.train(input_fn=imagenet_train.input_fn,
                                   steps=params.train_steps)

    if FLAGS.export_dir:
        tf.logging.info('Starting to export model with image input.')
        inception_classifier.export_saved_model(
            export_dir_base=FLAGS.export_dir,
            serving_input_receiver_fn=image_serving_input_fn)

    if FLAGS.tflite_export_dir:
        tf.logging.info('Starting to export default TensorFlow model.')
        savedmodel_dir = inception_classifier.export_saved_model(
            export_dir_base=FLAGS.tflite_export_dir,
            serving_input_receiver_fn=functools.partial(tensor_serving_input_fn, params))  # pylint: disable=line-too-long
        tf.logging.info('Starting to export TFLite.')
        converter = tf.lite.TFLiteConverter.from_saved_model(
            savedmodel_dir, output_arrays=['softmax_tensor'])
        tflite_file_name = 'mobilenet.tflite'
        if params.post_quantize:
            converter.post_training_quantize = True
            tflite_file_name = 'quantized_' + tflite_file_name
        tflite_file = os.path.join(savedmodel_dir, tflite_file_name)
        tflite_model = converter.convert()
        tf.gfile.GFile(tflite_file, 'wb').write(tflite_model)
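To smoke-test the exported file, the converted model can be run with the standard TFLite interpreter. A brief, hedged sketch (not part of the original script), assuming the tflite_file path written above:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path=tflite_file)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Random data is only a smoke test; shape and dtype come from the model itself.
dummy = np.random.random_sample(tuple(input_details[0]['shape'])).astype(
    input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
softmax = interpreter.get_tensor(output_details[0]['index'])
print(softmax.shape)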
Example #11
def main(argv):
  del argv  # Unused.

  params = factory.config_generator(FLAGS.model)

  if FLAGS.config_file:
    params = params_dict.override_params_dict(
        params, FLAGS.config_file, is_strict=True)

  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  if not FLAGS.use_tpu:
    params.override({
        'architecture': {
            'use_bfloat16': False,
        },
        'batch_norm_activation': {
            'use_sync_bn': False,
        },
    }, is_strict=True)
  params.override({
      'platform': {
          'eval_master': FLAGS.eval_master,
          'tpu': FLAGS.tpu,
          'tpu_zone': FLAGS.tpu_zone,
          'gcp_project': FLAGS.gcp_project,
      },
      'tpu_job_name': FLAGS.tpu_job_name,
      'use_tpu': FLAGS.use_tpu,
      'model_dir': FLAGS.model_dir,
      'train': {
          'num_shards': FLAGS.num_cores,
      },
  }, is_strict=False)
  # Only run spatial partitioning in training mode.
  if FLAGS.mode != 'train':
    params.train.input_partition_dims = None
    params.train.num_cores_per_replica = None

  params.validate()
  params.lock()
  pp = pprint.PrettyPrinter()
  params_str = pp.pformat(params.as_dict())
  logging.info('Model Parameters: %s', params_str)

  # Builds detection model on TPUs.
  model_fn = model_builder.ModelFn(params)
  executor = tpu_executor.TpuExecutor(model_fn, params)

  # Prepares input functions for train and eval.
  train_input_fn = input_reader.InputFn(
      params.train.train_file_pattern, params, mode=ModeKeys.TRAIN,
      dataset_type=params.train.train_dataset_type)
  if params.eval.type == 'customized':
    eval_input_fn = input_reader.InputFn(
        params.eval.eval_file_pattern, params, mode=ModeKeys.EVAL,
        dataset_type=params.eval.eval_dataset_type)
  else:
    eval_input_fn = input_reader.InputFn(
        params.eval.eval_file_pattern, params, mode=ModeKeys.PREDICT_WITH_GT,
        dataset_type=params.eval.eval_dataset_type)

  # Runs the model.
  if FLAGS.mode == 'train':
    config_utils.save_config(params, params.model_dir)
    executor.train(train_input_fn, params.train.total_steps)
    if FLAGS.eval_after_training:
      executor.evaluate(
          eval_input_fn,
          params.eval.eval_samples // params.eval.eval_batch_size)

  elif FLAGS.mode == 'eval':
    def terminate_eval():
      logging.info('Terminating eval after %d seconds of no checkpoints',
                   params.eval.eval_timeout)
      return True
    # Runs evaluation when there's a new checkpoint.
    for ckpt in tf.train.checkpoints_iterator(
        params.model_dir,
        min_interval_secs=params.eval.min_eval_interval,
        timeout=params.eval.eval_timeout,
        timeout_fn=terminate_eval):
      # Terminates eval job when final checkpoint is reached.
      current_step = int(os.path.basename(ckpt).split('-')[1])

      logging.info('Starting to evaluate.')
      try:
        executor.evaluate(
            eval_input_fn,
            params.eval.eval_samples // params.eval.eval_batch_size, ckpt)

        if current_step >= params.train.total_steps:
          logging.info('Evaluation finished after training step %d',
                       current_step)
          break
      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        logging.info('Checkpoint %s no longer exists, skipping checkpoint',
                     ckpt)

  elif FLAGS.mode == 'train_and_eval':
    config_utils.save_config(params, params.model_dir)
    num_cycles = int(params.train.total_steps / params.eval.num_steps_per_eval)
    for cycle in range(num_cycles):
      logging.info('Start training cycle %d.', cycle)
      current_cycle_last_train_step = ((cycle + 1)
                                       * params.eval.num_steps_per_eval)
      executor.train(train_input_fn, current_cycle_last_train_step)
      executor.evaluate(
          eval_input_fn,
          params.eval.eval_samples // params.eval.eval_batch_size)

  elif FLAGS.mode == 'predict':
    file_pattern = FLAGS.predict_file_pattern
    if not file_pattern:
      raise ValueError('"predict_file_pattern" parameter is required.')

    output_dir = FLAGS.predict_output_dir
    if not output_dir:
      raise ValueError('"predict_output_dir" parameter is required.')

    test_input_fn = input_reader.InputFn(
        file_pattern, params, mode=ModeKeys.PREDICT_WITH_GT,
        dataset_type=params.eval.eval_dataset_type)

    checkpoint_prefix = 'model.ckpt-' + FLAGS.predict_checkpoint_step
    checkpoint_path = os.path.join(FLAGS.model_dir, checkpoint_prefix)
    if not tf.train.checkpoint_exists(checkpoint_path):
      checkpoint_path = os.path.join(
          FLAGS.model_dir, 'best_checkpoints', checkpoint_prefix)
      if not tf.train.checkpoint_exists(checkpoint_path):
        raise ValueError('Checkpoint not found: %s/%s' %
                         (FLAGS.model_dir, checkpoint_prefix))

    executor.predict(test_input_fn, checkpoint_path, output_dir=output_dir)

  else:
    logging.info('Mode not found.')
Example #12
def export(export_dir,
           checkpoint_path,
           model,
           config_file='',
           params_override='',
           use_tpu=False,
           batch_size=1,
           image_size=(1024, 1024),
           input_type='raw_image_tensor',
           input_name='input',
           output_image_info=True,
           output_normalized_coordinates=False,
           cast_num_detections_to_float=False,
           cast_detection_classes_to_float=False):
  """Exports the SavedModel."""
  control_flow_util.enable_control_flow_v2()

  params = factory.config_generator(model)
  if config_file:
    params = params_dict.override_params_dict(
        params, config_file, is_strict=True)
  # Use `is_strict=False` to load params_override with runtime variables like
  # `train.num_shards`.
  params = params_dict.override_params_dict(
      params, params_override, is_strict=False)
  if not use_tpu:
    params.override({
        'architecture': {
            'use_bfloat16': use_tpu,
        },
    }, is_strict=True)
  if batch_size is None:
    params.override({
        'postprocess': {
            'use_batched_nms': True,
        }
    })
  params.validate()
  params.lock()

  model_params = dict(
      params.as_dict(),
      use_tpu=use_tpu,
      mode=tf.estimator.ModeKeys.PREDICT,
      transpose_input=False)
  tf.logging.info('model_params is:\n %s', model_params)

  if model in ['attribute_mask_rcnn']:
    model_fn = serving.serving_model_fn_builder(
        use_tpu, output_image_info, output_normalized_coordinates,
        cast_num_detections_to_float, cast_detection_classes_to_float)
    serving_input_receiver_fn = functools.partial(
        serving.serving_input_fn,
        batch_size=batch_size,
        desired_image_size=image_size,
        stride=(2 ** params.architecture.max_level),
        input_type=input_type,
        input_name=input_name)
  else:
    raise ValueError('The model type `{}` is not supported.'.format(model))

  print(' - Setting up TPUEstimator...')
  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      model_dir=None,
      config=tf.estimator.tpu.RunConfig(
          tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=1),
          master='local',
          evaluation_master='local'),
      params=model_params,
      use_tpu=use_tpu,
      train_batch_size=batch_size,
      predict_batch_size=batch_size,
      export_to_tpu=use_tpu,
      export_to_cpu=True)

  print(' - Exporting the model...')

  dir_name = os.path.dirname(export_dir)

  if not tf.gfile.Exists(dir_name):
    tf.logging.info('Creating base dir: %s', dir_name)
    tf.gfile.MakeDirs(dir_name)

  export_path = estimator.export_saved_model(
      export_dir_base=dir_name,
      serving_input_receiver_fn=serving_input_receiver_fn,
      checkpoint_path=checkpoint_path)

  tf.logging.info(
      'Exported SavedModel to %s, renaming to %s',
      export_path, export_dir)

  if tf.gfile.Exists(export_dir):
    tf.logging.info('Deleting existing SavedModel dir: %s', export_dir)
    tf.gfile.DeleteRecursively(export_dir)

  tf.gfile.Rename(export_path, export_dir)
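A call site for this helper might look as follows; every path and the model name below are placeholders:

export(
    export_dir='/tmp/amrcnn/saved_model',
    checkpoint_path='/tmp/amrcnn/model.ckpt-90000',
    model='attribute_mask_rcnn',
    config_file='/tmp/amrcnn/config.yaml',
    batch_size=1,
    image_size=(1024, 1024),
    input_type='image_bytes',
    output_normalized_coordinates=True)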
Example #13
def main(argv):
    del argv  # Unused.

    params = factory.config_generator(FLAGS.model)
    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)
    # Use `is_strict=False` to load params_override with runtime variables like
    # `train.num_shards`.
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=False)
    params.validate()
    params.lock()

    image_size = [int(x) for x in FLAGS.input_image_size.split(',')]

    g = tf.Graph()
    with g.as_default():
        # Build the input.
        _, features = inputs.build_serving_input(
            input_type=FLAGS.input_type,
            batch_size=FLAGS.batch_size,
            desired_image_size=image_size,
            stride=(2**params.anchor.max_level))

        # Build the model.
        print(' - Building the graph...')
        if FLAGS.model in ['retinanet', 'mask_rcnn', 'shapemask']:
            graph_fn = detection.serving_model_graph_builder(
                FLAGS.output_image_info, FLAGS.output_normalized_coordinates,
                FLAGS.cast_num_detections_to_float)
        else:
            raise ValueError('The model type `{}` is not supported.'.format(
                FLAGS.model))

        predictions = graph_fn(features, params)

        # Add a saver for checkpoint loading.
        tf.train.Saver()

        inference_graph_def = g.as_graph_def()
        optimized_graph_def = inference_graph_def

        if FLAGS.optimize_graph:
            print(' - Optimizing the graph...')
            # Trim the unused nodes in the graph.
            output_nodes = [
                output_node.op.name for output_node in predictions.values()
            ]
            # TODO(pengchong): Consider to use `strip_unused_lib.strip_unused` and/or
            # `optimize_for_inference_lib.optimize_for_inference` to trim the graph.
            # Use `optimize_for_inference` if we decide to export the frozen graph
            # (graph + checkpoint) and want to explicitly fold in batchnorm variables.
            optimized_graph_def = graph_util.remove_training_nodes(
                optimized_graph_def, output_nodes)

    print(' - Saving the graph...')
    tf.train.write_graph(optimized_graph_def, FLAGS.export_dir,
                         'inference_graph.pbtxt')
    print(' - Done!')
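The written pbtxt can later be re-imported for inference. A minimal, hedged sketch (the export path is a placeholder, and node names are model-specific):

from google.protobuf import text_format
import tensorflow as tf

with tf.gfile.GFile('/tmp/export/inference_graph.pbtxt', 'r') as f:
    graph_def = tf.GraphDef()
    text_format.Merge(f.read(), graph_def)

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    # Input/output node names are model-specific; list them with
    # [op.name for op in graph.get_operations()].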
Example #14
def main(unused_argv):
  del unused_argv

  params = config_factory.config_generator(FLAGS.model)
  if FLAGS.config_file:
    params = params_dict.override_params_dict(
        params, FLAGS.config_file, is_strict=True)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  # We currently only support batch_size = 1, evaluating images one by one.
  # Override `eval_batch_size` to 1 here.
  params.override({
      'eval': {
          'eval_batch_size': 1,
      },
  })
  params.validate()
  params.lock()

  model = model_factory.model_generator(params)
  evaluator = evaluator_factory.evaluator_generator(params.eval)

  parse_fn = functools.partial(parse_single_example, params=params)
  with tf.Graph().as_default():
    dataset = tf.data.Dataset.list_files(
        params.eval.eval_file_pattern, shuffle=False)
    dataset = dataset.apply(
        tf.data.experimental.parallel_interleave(
            lambda filename: tf.data.TFRecordDataset(filename).prefetch(1),
            cycle_length=32,
            sloppy=False))
    dataset = dataset.map(parse_fn, num_parallel_calls=64)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(1, drop_remainder=False)

    images, labels, groundtruths = dataset.make_one_shot_iterator().get_next()
    images.set_shape([
        1,
        params.retinanet_parser.output_size[0],
        params.retinanet_parser.output_size[1],
        3])

    # model inference
    outputs = model.build_outputs(images, labels, mode=mode_keys.PREDICT)

    predictions = outputs
    predictions.update({
        'source_id': groundtruths['source_id'],
        'image_info': labels['image_info'],
    })

    # Create a saver in order to load the pre-trained checkpoint.
    saver = tf.train.Saver()

    with tf.Session() as sess:
      saver.restore(sess, FLAGS.checkpoint_path)

      num_batches = params.eval.eval_samples // params.eval.eval_batch_size
      for i in range(num_batches):
        if i % 100 == 0:
          print('{}/{} batches...'.format(i, num_batches))
        predictions_np, groundtruths_np = sess.run([predictions, groundtruths])
        evaluator.update(predictions_np, groundtruths_np)

    if FLAGS.dump_predictions_only:
      print('Dumping the prediction results...')
      evaluator.dump_predictions(FLAGS.predictions_path)
      print('Done!')
    else:
      print('Evaluating the prediction results...')
      metrics = evaluator.evaluate()
      print('Eval results: {}'.format(metrics))
Example #15
def main(argv):
    del argv  # Unused.

    params = factory.config_generator(FLAGS.model)

    if FLAGS.config_file:
        params = params_dict.override_params_dict(
            params, FLAGS.config_file, is_strict=True)

    params = params_dict.override_params_dict(
        params, FLAGS.params_override, is_strict=True)

    params.train.input_partition_dims = None
    params.train.num_cores_per_replica = None
    params.architecture.use_bfloat16 = False
    # params.maskrcnn_parser.use_autoaugment = False

    params.validate()
    params.lock()

    # Prepares input functions for train and eval.
    train_input_fn = input_reader.InputFnTest(
        params.train.train_file_pattern, params, mode=ModeKeys.TRAIN,
        dataset_type=params.train.train_dataset_type)

    batch_size = 1
    dataset = train_input_fn({'batch_size': batch_size})

    category_index = {}
    for i in range(50):
        category_index[i] = {
            'name': 'test_%d' % i,
        }

    for i, (image_batch, labels_batch) in enumerate(dataset.take(10)):
        image_batch = tf.transpose(image_batch, [3, 0, 1, 2])
        image_batch = tf.map_fn(denormalize_image, image_batch, dtype=tf.uint8, back_prop=False)

        image_shape = tf.shape(image_batch)[1:3]

        masks_batch = []
        for image, bboxes, masks in zip(image_batch, labels_batch['gt_boxes'], labels_batch['gt_masks']):
            # extract masks
            bboxes = tf.numpy_function(box_utils.yxyx_to_xywh, [bboxes], tf.float32)
            binary_masks = tf.numpy_function(mask_utils.paste_instance_masks,
                                             [masks, bboxes, image_shape[0], image_shape[1]],
                                             tf.uint8)

            masks_batch.append(binary_masks)

        masks_batch = tf.stack(masks_batch, axis=0)

        scores_mask = tf.cast(tf.greater(labels_batch['gt_classes'], -1), tf.float32)
        scores = tf.ones_like(labels_batch['gt_classes'], dtype=tf.float32) * scores_mask

        images = draw_bounding_boxes_on_image_tensors(image_batch,
                                                      labels_batch['gt_boxes'],
                                                      labels_batch['gt_classes'],
                                                      scores,
                                                      category_index,
                                                      instance_masks=masks_batch,
                                                      use_normalized_coordinates=False)

        for j, image in enumerate(images):
            image_bytes = tf.io.encode_jpeg(image)
            tf.io.write_file(root_dir('data/visualizations/aug_%d.jpg' % (i * batch_size + j)), image_bytes)
Example #16
def main(unused_argv):
    params = params_dict.ParamsDict(squeezenet_config.SQUEEZENET_CFG,
                                    squeezenet_config.SQUEEZENET_RESTRICTIONS)
    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=True)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)

    params = flags_to_params.override_params_from_input_flags(params, FLAGS)

    total_steps = (
        (params.train.num_epochs * params.train.num_examples_per_epoch) //
        params.train.train_batch_size)
    params.override(
        {
            "train": {
                "total_steps": total_steps
            },
            "eval": {
                "num_steps_per_eval": (total_steps // params.eval.num_evals)
            },
        },
        is_strict=False)

    params.validate()
    params.lock()

    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    if params.use_async_checkpointing:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(5000, params.train.iterations_per_loop)

    run_config = contrib_tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False),
        tpu_config=contrib_tpu.TPUConfig(
            iterations_per_loop=params.train.iterations_per_loop,
            num_shards=params.train.num_cores_per_replica,
        ),
    )

    estimator = contrib_tpu.TPUEstimator(
        model_fn=squeezenet_model.model_fn,
        use_tpu=params.use_tpu,
        config=run_config,
        train_batch_size=params.train.train_batch_size,
        eval_batch_size=params.eval.eval_batch_size,
        params=params.as_dict(),
    )

    for eval_cycle in range(params.eval.num_evals):
        current_cycle_last_train_step = ((eval_cycle + 1) *
                                         params.eval.num_steps_per_eval)
        estimator.train(input_fn=data_pipeline.InputReader(FLAGS.data_dir,
                                                           is_training=True),
                        steps=current_cycle_last_train_step)

        tf.logging.info("Running evaluation")
        tf.logging.info(
            "%s",
            estimator.evaluate(input_fn=data_pipeline.InputReader(
                FLAGS.data_dir, is_training=False),
                               steps=(params.eval.num_eval_examples //
                                      params.eval.eval_batch_size)))
Example #17
def initiate():
    # Load the label map.
    print(' - Loading the label map...')
    label_map_dict = {}
    label_map_format = 'csv'  # The label map format is hard-coded here.
    if label_map_format == 'csv':
        with tf.gfile.Open('dataset/fashionpedia_label_map.csv',
                           'r') as csv_file:
            reader = csv.reader(csv_file, delimiter=':')
            for row in reader:
                if len(row) != 2:
                    raise ValueError(
                        'Each row of the csv label map file must be in '
                        '`id:name` format.')
                id_index = int(row[0])
                name = row[1]
                label_map_dict[id_index] = {
                    'id': id_index,
                    'name': name,
                }
    else:
        raise ValueError(
            'Unsupported label map format: {}.'.format(label_map_format))

    params = config_factory.config_generator('attribute_mask_rcnn')
    config_file = 'configs/yaml/spinenet49_amrcnn.yaml'  # hard-coded config path
    if config_file:
        params = params_dict.override_params_dict(
            params, config_file, is_strict=True)
    params = params_dict.override_params_dict(params, '', is_strict=True)
    params.override(
        {
            'architecture': {
                'use_bfloat16': False,  # The inference runs on CPU/GPU.
            },
        },
        is_strict=True)
    params.validate()
    params.lock()

    model = model_factory.model_generator(params)

    with tf.Graph().as_default():
        image_input = tf.placeholder(shape=(), dtype=tf.string)
        image = tf.io.decode_image(image_input, channels=3)
        image.set_shape([None, None, 3])

        image = input_utils.normalize_image(image)
        image_size = [640, 640]
        image, image_info = input_utils.resize_and_crop_image(
            image,
            image_size,
            image_size,
            aug_scale_min=1.0,
            aug_scale_max=1.0)
        image.set_shape([image_size[0], image_size[1], 3])

        # batching.
        images = tf.reshape(image, [1, image_size[0], image_size[1], 3])
        images_info = tf.expand_dims(image_info, axis=0)

        # model inference
        outputs = model.build_outputs(images, {'image_info': images_info},
                                      mode=mode_keys.PREDICT)

        outputs['detection_boxes'] = (
            outputs['detection_boxes'] /
            tf.tile(images_info[:, 2:3, :], [1, 1, 2]))

        predictions = outputs

        # Create a saver in order to load the pre-trained checkpoint.
        saver = tf.train.Saver()
        sess = tf.Session()
        print(' - Loading the checkpoint...')
        saver.restore(sess, 'fashionpedia-spinenet-49/model.ckpt')
        print(' - Checkpoint Loaded...')
        return sess, predictions, image_input
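The returned handles are used the same way as in Example #18 below: feed encoded image bytes and fetch the prediction dict. A short sketch (the image path is a placeholder):

sess, predictions, image_input = initiate()
with tf.gfile.GFile('test.jpg', 'rb') as f:
    image_bytes = f.read()
predictions_np = sess.run(predictions, feed_dict={image_input: image_bytes})
# Keys mirror the outputs built above, e.g. the rescaled detection boxes.
print(predictions_np['detection_boxes'][0])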
Example #18
def main(unused_argv):
    del unused_argv
    # Load the label map.
    print(' - Loading the label map...')
    label_map_dict = {}
    if FLAGS.label_map_format == 'csv':
        with tf.gfile.Open(FLAGS.label_map_file, 'r') as csv_file:
            reader = csv.reader(csv_file, delimiter=':')
            for row in reader:
                if len(row) != 2:
                    raise ValueError(
                        'Each row of the csv label map file must be in '
                        '`id:name` format.')
                id_index = int(row[0])
                name = row[1]
                label_map_dict[id_index] = {
                    'id': id_index,
                    'name': name,
                }
    else:
        raise ValueError('Unsupported label map format: {}.'.format(
            FLAGS.label_map_format))

    params = config_factory.config_generator(FLAGS.model)
    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.override(
        {
            'architecture': {
                'use_bfloat16': False,  # The inference runs on CPU/GPU.
            },
        },
        is_strict=True)
    params.validate()
    params.lock()

    model = model_factory.model_generator(params)

    with tf.Graph().as_default():
        image_input = tf.placeholder(shape=(), dtype=tf.string)
        image = tf.io.decode_image(image_input, channels=3)
        image.set_shape([None, None, 3])

        image = input_utils.normalize_image(image)
        image_size = [FLAGS.image_size, FLAGS.image_size]
        image, image_info = input_utils.resize_and_crop_image(
            image,
            image_size,
            image_size,
            aug_scale_min=1.0,
            aug_scale_max=1.0)
        image.set_shape([image_size[0], image_size[1], 3])

        # batching.
        images = tf.reshape(image, [1, image_size[0], image_size[1], 3])
        images_info = tf.expand_dims(image_info, axis=0)

        # model inference
        outputs = model.build_outputs(images, {'image_info': images_info},
                                      mode=mode_keys.PREDICT)

        # outputs['detection_boxes'] = (
        #     outputs['detection_boxes'] / tf.tile(images_info[:, 2:3, :], [1, 1, 2]))

        predictions = outputs

        # Create a saver in order to load the pre-trained checkpoint.
        saver = tf.train.Saver()

        image_with_detections_list = []
        with tf.Session() as sess:
            print(' - Loading the checkpoint...')
            saver.restore(sess, FLAGS.checkpoint_path)

            image_files = tf.gfile.Glob(FLAGS.image_file_pattern)
            for i, image_file in enumerate(image_files):
                print(' - Processing image %d...' % i)

                with tf.gfile.GFile(image_file, 'rb') as f:
                    image_bytes = f.read()

                image = Image.open(image_file)
                # Conversion is needed for images with 4 channels (e.g. RGBA).
                image = image.convert('RGB')
                width, height = image.size
                np_image = (np.array(image.getdata()).reshape(
                    height, width, 3).astype(np.uint8))
                print(np_image.shape)

                predictions_np = sess.run(predictions,
                                          feed_dict={image_input: image_bytes})

                logits = predictions_np['logits'][0]
                print(logits.shape)

                labels = np.argmax(logits.squeeze(), -1)
                print(labels.shape)
                print(labels)
                labels = np.array(Image.fromarray(labels.astype('uint8')))
                print(labels.shape)

                plt.imshow(labels)
                plt.savefig(f"temp-{i}.png")
Example #19
def main(unused_argv):
    params = params_dict.ParamsDict(mnasnet_config.MNASNET_CFG,
                                    mnasnet_config.MNASNET_RESTRICTIONS)
    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=True)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)

    params = flags_to_params.override_params_from_input_flags(params, FLAGS)

    additional_params = {
        'steps_per_epoch': params.num_train_images / params.train_batch_size,
        'quantized_training': FLAGS.quantized_training,
    }

    params = params_dict.override_params_dict(params,
                                              additional_params,
                                              is_strict=False)

    params.validate()
    params.lock()

    if FLAGS.tpu or params.use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    if params.use_async_checkpointing:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(100, params.iterations_per_loop)
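    # Note (added): with async checkpointing the built-in periodic saver is
    # disabled by passing None; the AsyncCheckpointSaverHook added in train()
    # saves checkpoints instead.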
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=params.iterations_per_loop,
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
            .PER_HOST_V2))  # pylint: disable=line-too-long

    # Validates Flags.
    if params.precision == 'bfloat16' and params.use_keras:
        raise ValueError(
            'Keras layers do not have full support to bfloat16 activation training.'
            ' You have set precision as %s and use_keras as %s' %
            (params.precision, params.use_keras))

    # Initializes model parameters.
    mnasnet_est = tf.contrib.tpu.TPUEstimator(
        use_tpu=params.use_tpu,
        model_fn=mnasnet_model_fn,
        config=config,
        train_batch_size=params.train_batch_size,
        eval_batch_size=params.eval_batch_size,
        export_to_tpu=FLAGS.export_to_tpu,
        params=params.as_dict())

    if FLAGS.mode == 'export_only':
        export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
        return

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=False,
                transpose_input=params.transpose_input,
                selection=selection)
            for (is_training,
                 selection) in [(True, select_train), (False, select_eval)]
        ]
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=params.transpose_input,
                cache=params.use_cache and is_training,
                image_size=params.input_image_size,
                num_parallel_calls=params.num_parallel_calls,
                use_bfloat16=(params.precision == 'bfloat16'))
            for is_training in [True, False]
        ]

    if FLAGS.mode == 'eval':
        eval_steps = params.num_eval_images // params.eval_batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time()  # This time will include compilation time.
                eval_results = mnasnet_est.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= params.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

        if FLAGS.export_dir:
            export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(  # pylint: disable=protected-access
            FLAGS.model_dir)

        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', params.train_steps,
            params.train_steps / params.steps_per_epoch, current_step)

        start_timestamp = time.time()  # This time will include compilation time.

        if FLAGS.mode == 'train':
            hooks = []
            if params.use_async_checkpointing:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, params.iterations_per_loop)))
            mnasnet_est.train(input_fn=imagenet_train.input_fn,
                              max_steps=params.train_steps,
                              hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < params.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      params.train_steps)
                mnasnet_est.train(input_fn=imagenet_train.input_fn,
                                  max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = mnasnet_est.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=params.num_eval_images // params.eval_batch_size)
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)
                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                params.train_steps, elapsed_time)
            if FLAGS.export_dir:
                export(mnasnet_est, FLAGS.export_dir, params,
                       FLAGS.post_quantize)
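
Both eval paths above recover the global step by parsing the checkpoint filename. As a quick illustration of that convention (hypothetical path, not from the original):

import os

# TensorFlow checkpoints are named 'model.ckpt-<global_step>'.
ckpt = '/tmp/model_dir/model.ckpt-12500'
current_step = int(os.path.basename(ckpt).split('-')[1])  # -> 12500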
Example #20
def main(unused_argv):
  params = params_dict.ParamsDict(
      resnet_config.RESNET_CFG, resnet_config.RESNET_RESTRICTIONS)
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)

  params = flags_to_params.override_params_from_input_flags(params, FLAGS)

  params.validate()
  params.lock()

  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu if (FLAGS.tpu or params.use_tpu) else '',
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project)

  if params.use_async_checkpointing:
    save_checkpoints_steps = None
  else:
    save_checkpoints_steps = max(5000, params.iterations_per_loop)
  config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      session_config=tf.ConfigProto(
          graph_options=tf.GraphOptions(
              rewrite_options=rewriter_config_pb2.RewriterConfig(
                  disable_meta_optimizer=True))),
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=params.iterations_per_loop,
          num_shards=params.num_cores,
          per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
          .PER_HOST_V2))  # pylint: disable=line-too-long

  resnet_classifier = tf.estimator.tpu.TPUEstimator(
      use_tpu=params.use_tpu,
      model_fn=resnet_model_fn,
      config=config,
      params=params.as_dict(),
      train_batch_size=params.train_batch_size,
      eval_batch_size=params.eval_batch_size,
      export_to_tpu=FLAGS.export_to_tpu)

  assert (params.precision == 'bfloat16' or
          params.precision == 'float32'), (
              'Invalid value for precision parameter; '
              'must be bfloat16 or float32.')
  tf.logging.info('Precision: %s', params.precision)
  use_bfloat16 = params.precision == 'bfloat16'
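  # Note (added): the same flag is threaded into the input pipelines below so
  # the data pipeline and the model agree on the activation dtype.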

  # Input pipelines are slightly different (with regards to shuffling and
  # preprocessing) between training and evaluation.
  if FLAGS.bigtable_instance:
    tf.logging.info('Using Bigtable dataset, table %s', FLAGS.bigtable_table)
    select_train, select_eval = _select_tables_from_flags()
    imagenet_train, imagenet_eval = [
        imagenet_input.ImageNetBigtableInput(  # pylint: disable=g-complex-comprehension
            is_training=is_training,
            use_bfloat16=use_bfloat16,
            transpose_input=params.transpose_input,
            selection=selection,
            augment_name=FLAGS.augment_name,
            randaug_num_layers=FLAGS.randaug_num_layers,
            randaug_magnitude=FLAGS.randaug_magnitude)
        for (is_training, selection) in [(True,
                                          select_train), (False, select_eval)]
    ]
  else:
    if FLAGS.data_dir == FAKE_DATA_DIR:
      tf.logging.info('Using fake dataset.')
    else:
      tf.logging.info('Using dataset: %s', FLAGS.data_dir)
    imagenet_train, imagenet_eval = [
        imagenet_input.ImageNetInput(  # pylint: disable=g-complex-comprehension
            is_training=is_training,
            data_dir=FLAGS.data_dir,
            transpose_input=params.transpose_input,
            cache=params.use_cache and is_training,
            image_size=params.image_size,
            num_parallel_calls=params.num_parallel_calls,
            include_background_label=(params.num_label_classes == 1001),
            use_bfloat16=use_bfloat16,
            augment_name=FLAGS.augment_name,
            randaug_num_layers=FLAGS.randaug_num_layers,
            randaug_magnitude=FLAGS.randaug_magnitude)
        for is_training in [True, False]
    ]

  steps_per_epoch = params.num_train_images // params.train_batch_size
  eval_steps = params.num_eval_images // params.eval_batch_size

  if FLAGS.mode == 'eval':

    # Run evaluation when there's a new checkpoint
    for ckpt in tf.train.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info('Starting to evaluate.')
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps,
            checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                        eval_results, elapsed_time)

        # Terminate eval job when final checkpoint is reached
        current_step = int(os.path.basename(ckpt).split('-')[1])
        if current_step >= params.train_steps:
          tf.logging.info(
              'Evaluation finished after training step %d', current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info(
            'Checkpoint %s no longer exists, skipping checkpoint', ckpt)

  else:   # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    try:
      current_step = tf.train.load_variable(FLAGS.model_dir,
                                            tf.GraphKeys.GLOBAL_STEP)
    except (TypeError, ValueError, tf.errors.NotFoundError):
      current_step = 0
    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.',
                    params.train_steps,
                    params.train_steps / steps_per_epoch,
                    current_step)

    start_timestamp = time.time()  # This time will include compilation time

    if FLAGS.mode == 'train':
      hooks = []
      if params.use_async_checkpointing:
        try:
          from tensorflow.contrib.tpu.python.tpu import async_checkpoint  # pylint: disable=g-import-not-at-top
        except ImportError as e:
          logging.exception(
              'Async checkpointing is not supported in TensorFlow 2.x')
          raise e

        hooks.append(
            async_checkpoint.AsyncCheckpointSaverHook(
                checkpoint_dir=FLAGS.model_dir,
                save_steps=max(5000, params.iterations_per_loop)))
      if FLAGS.profile_every_n_steps > 0:
        hooks.append(
            tpu_profiler_hook.TPUProfilerHook(
                save_steps=FLAGS.profile_every_n_steps,
                output_dir=FLAGS.model_dir, tpu=FLAGS.tpu)
            )
      resnet_classifier.train(
          input_fn=imagenet_train.input_fn,
          max_steps=params.train_steps,
          hooks=hooks)

    else:
      assert FLAGS.mode == 'train_and_eval'
      while current_step < params.train_steps:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              params.train_steps)
        resnet_classifier.train(
            input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                        next_checkpoint, int(time.time() - start_timestamp))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be excluded modulo the batch size. As long as the batch size is
        # consistent, the evaluated images are also consistent.
        tf.logging.info('Starting to evaluate.')
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=params.num_eval_images // params.eval_batch_size)
        tf.logging.info('Eval results at step %d: %s',
                        next_checkpoint, eval_results)

      elapsed_time = int(time.time() - start_timestamp)
      tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                      params.train_steps, elapsed_time)

    if FLAGS.export_dir is not None:
      # The guide to serving an exported TensorFlow model is at:
      #    https://www.tensorflow.org/serving/serving_basic
      tf.logging.info('Starting to export model.')
      export_path = resnet_classifier.export_saved_model(
          export_dir_base=FLAGS.export_dir,
          serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
      if FLAGS.add_warmup_requests:
        inference_warmup.write_warmup_requests(
            export_path,
            FLAGS.model_name,
            params.image_size,
            batch_sizes=FLAGS.inference_batch_sizes,
            image_format='JPEG')
Example #21
def main(unused_argv):
    del unused_argv
    # Load the label map.
    print(' - Loading the label map...')
    label_map_dict = {}
    if FLAGS.label_map_format == 'csv':
        with tf.gfile.Open(FLAGS.label_map_file, 'r') as csv_file:
            reader = csv.reader(csv_file, delimiter=':')
            for row in reader:
                if len(row) != 2:
                    raise ValueError(
                        'Each row of the csv label map file must be in '
                        '`id:name` format.')
                id_index = int(row[0])
                name = row[1]
                label_map_dict[id_index] = {
                    'id': id_index,
                    'name': name,
                }
    else:
        raise ValueError('Unsupported label map format: {}.'.format(
            FLAGS.label_map_format))
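    # For reference (added): a CSV label map in the `id:name` format above
    # might look like (hypothetical contents):
    #   1:person
    #   2:bicycle
    # yielding label_map_dict = {1: {'id': 1, 'name': 'person'}, ...}.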

    params = config_factory.config_generator(FLAGS.model)
    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.validate()
    params.lock()

    model = model_factory.model_generator(params)

    with tf.Graph().as_default():
        image_input = tf.placeholder(shape=(), dtype=tf.string)
        image = tf.io.decode_image(image_input, channels=3)
        image.set_shape([None, None, 3])

        image = input_utils.normalize_image(image)
        image_size = [FLAGS.image_size, FLAGS.image_size]
        image, image_info = input_utils.resize_and_crop_image(
            image,
            image_size,
            image_size,
            aug_scale_min=1.0,
            aug_scale_max=1.0)
        image.set_shape([image_size[0], image_size[1], 3])

        # Batching: add a leading batch dimension of 1.
        images = tf.reshape(image, [1, image_size[0], image_size[1], 3])
        images_info = tf.expand_dims(image_info, axis=0)

        # Model inference.
        outputs = model.build_outputs(images, {'image_info': images_info},
                                      mode=mode_keys.PREDICT)

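        # Note (added): row 2 of image_info holds the (y, x) scale applied
        # during resizing; dividing by it (tiled across y1, x1, y2, x2) maps
        # the boxes back to the original image's coordinate frame.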
        outputs['detection_boxes'] = (
            outputs['detection_boxes'] /
            tf.tile(images_info[:, 2:3, :], [1, 1, 2]))

        predictions = outputs

        # Create a saver in order to load the pre-trained checkpoint.
        saver = tf.train.Saver()

        image_with_detections_list = []
        with tf.Session() as sess:
            print(' - Loading the checkpoint...')
            saver.restore(sess, FLAGS.checkpoint_path)

            image_files = tf.gfile.Glob(FLAGS.image_file_pattern)
            for i, image_file in enumerate(image_files):
                print(' - Processing image %d...' % i)

                with tf.gfile.GFile(image_file, 'rb') as f:
                    image_bytes = f.read()

                image = Image.open(image_file)
                image = image.convert(
                    'RGB')  # Needed for images with an alpha channel.
                width, height = image.size
                np_image = (np.array(image.getdata()).reshape(
                    height, width, 3).astype(np.uint8))

                predictions_np = sess.run(predictions,
                                          feed_dict={image_input: image_bytes})

                num_detections = int(predictions_np['num_detections'][0])
                np_boxes = predictions_np['detection_boxes'][
                    0, :num_detections]
                np_scores = predictions_np['detection_scores'][
                    0, :num_detections]
                np_classes = predictions_np['detection_classes'][
                    0, :num_detections]
                np_classes = np_classes.astype(np.int32)
                np_masks = None
                if 'detection_masks' in predictions_np:
                    instance_masks = predictions_np['detection_masks'][
                        0, :num_detections]
                    np_masks = mask_utils.paste_instance_masks(
                        instance_masks, box_utils.yxyx_to_xywh(np_boxes),
                        height, width)

                image_with_detections = (
                    visualization_utils.
                    visualize_boxes_and_labels_on_image_array(
                        np_image,
                        np_boxes,
                        np_classes,
                        np_scores,
                        label_map_dict,
                        instance_masks=np_masks,
                        use_normalized_coordinates=False,
                        max_boxes_to_draw=FLAGS.max_boxes_to_draw,
                        min_score_thresh=FLAGS.min_score_threshold))
                image_with_detections_list.append(image_with_detections)

    print(' - Saving the outputs...')
    formatted_image_with_detections_list = [
        Image.fromarray(image.astype(np.uint8))
        for image in image_with_detections_list
    ]
    html_str = '<html>'
    image_strs = []
    for formatted_image in formatted_image_with_detections_list:
        with io.BytesIO() as stream:
            formatted_image.save(stream, format='JPEG')
            data_uri = base64.b64encode(stream.getvalue()).decode('utf-8')
        image_strs.append(
            '<img src="data:image/jpeg;base64,{}", height=800>'.format(
                data_uri))
    images_str = ' '.join(image_strs)
    html_str += images_str
    html_str += '</html>'
    with tf.gfile.GFile(FLAGS.output_html, 'w') as f:
        f.write(html_str)
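
If separate image files are preferred over a single self-contained HTML page, the same list can be written out directly (a minimal sketch under the same imports, plus `os`; the output directory is hypothetical):

out_dir = '/tmp/detections'
tf.gfile.MakeDirs(out_dir)
for i, formatted_image in enumerate(formatted_image_with_detections_list):
    # PIL writes the JPEG bytes straight into the GFile handle.
    with tf.gfile.GFile(os.path.join(out_dir, 'image-%d.jpg' % i), 'wb') as f:
        formatted_image.save(f, format='JPEG')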
Example #22
def main(argv):
    del argv  # Unused.

    params = factory.config_generator(FLAGS.model)
    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)
    # Use `is_strict=False` to load params_override with runtime variables like
    # `train.num_shards`.
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=False)
    params.validate()
    params.lock()

    model_params = dict(params.as_dict(),
                        use_tpu=FLAGS.use_tpu,
                        mode=tf.estimator.ModeKeys.PREDICT,
                        transpose_input=False)

    tf.logging.info('model_params is:\n %s', model_params)

    image_size = [int(x) for x in FLAGS.input_image_size.split(',')]

    if FLAGS.model == 'retinanet':
        model_fn = serving.serving_model_fn_builder(
            FLAGS.use_tpu, FLAGS.output_image_info,
            FLAGS.output_normalized_coordinates,
            FLAGS.cast_num_detections_to_float)
        serving_input_receiver_fn = functools.partial(
            serving.serving_input_fn,
            batch_size=FLAGS.batch_size,
            desired_image_size=image_size,
            stride=(2**params.anchor.max_level),
            input_type=FLAGS.input_type,
            input_name=FLAGS.input_name)
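        # Note (added): the stride argument means the serving input fn
        # resizes/pads so each image side is a multiple of 2**max_level, the
        # coarsest feature stride the model uses.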
    else:
        raise ValueError('Model %s is not supported.' % FLAGS.model)

    print(' - Setting up TPUEstimator...')
    estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=model_fn,
        model_dir=None,
        config=tf.estimator.tpu.RunConfig(
            tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=1),
            master='local',
            evaluation_master='local'),
        params=model_params,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.batch_size,
        predict_batch_size=FLAGS.batch_size,
        export_to_tpu=FLAGS.use_tpu,
        export_to_cpu=True)

    print(' - Exporting the model...')

    dir_name = os.path.dirname(FLAGS.export_dir)

    if not tf.gfile.Exists(dir_name):
        tf.logging.info('Creating base dir: %s', dir_name)
        tf.gfile.MakeDirs(dir_name)

    export_path = estimator.export_saved_model(
        export_dir_base=dir_name,
        serving_input_receiver_fn=serving_input_receiver_fn,
        checkpoint_path=FLAGS.checkpoint_path)

    tf.logging.info('Exported SavedModel to %s, renaming to %s', export_path,
                    FLAGS.export_dir)

    if tf.gfile.Exists(FLAGS.export_dir):
        tf.logging.info('Deleting existing SavedModel dir: %s',
                        FLAGS.export_dir)
        tf.gfile.DeleteRecursively(FLAGS.export_dir)

    tf.gfile.Rename(export_path, FLAGS.export_dir)
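
Once exported, the SavedModel can be smoke-tested from Python. A minimal sketch, assuming a TF 1.x runtime, the default `serving_default` signature, and a single input tensor (actual tensor names depend on the export flags above; `image_bytes` is a hypothetical encoded image):

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    meta_graph = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], FLAGS.export_dir)
    signature = meta_graph.signature_def['serving_default']
    input_name = list(signature.inputs.values())[0].name
    fetches = {k: v.name for k, v in signature.outputs.items()}
    # `image_bytes` stands in for one encoded image when the model was
    # exported with --input_type=image_bytes.
    outputs = sess.run(fetches, feed_dict={input_name: [image_bytes]})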