Example #1
    def evaluate_checkpoints(self):
        print("Loading model... ")

        # Load pipeline config and build a detection model
        configs = config_util.get_configs_from_pipeline_file(
            self._pipeline_config_path)
        self._detection_model = model_builder.build(
            model_config=configs["model"], is_training=False)

        # Load evaluation inputs
        strategy = tf.compat.v2.distribute.get_strategy()
        eval_input = strategy.experimental_distribute_dataset(
            inputs.eval_input(
                eval_config=configs['eval_config'],
                eval_input_config=configs['eval_input_configs'][0],
                model_config=configs['model'],
                model=self._detection_model))

        self._ckpt_paths = tf.train.get_checkpoint_state(
            self._training_loop_path).all_model_checkpoint_paths

        global_step = tf.compat.v2.Variable(0,
                                            trainable=False,
                                            dtype=tf.compat.v2.dtypes.int64)

        results = []
        if os.path.exists(self._eval_results_filename):
            with open(self._eval_results_filename, "rb") as f:
                results = pickle.load(f)

        calculated_steps = [result[0] for result in results]

        for ckpt_path in self._ckpt_paths:
            # Restore checkpoint
            ckpt = tf.compat.v2.train.Checkpoint(model=self._detection_model,
                                                 step=global_step)
            ckpt.restore(ckpt_path).expect_partial()

            step_value = int(global_step.read_value())
            if step_value in calculated_steps:
                print(
                    "Evaluation for global step {:d} already done, skipping..."
                    .format(step_value))
                continue

            print("Running evaluation for checkpoint {}...".format(ckpt_path))
            evaluation = model_lib_v2.eager_eval_loop(
                detection_model=self._detection_model,
                configs=configs,
                eval_dataset=eval_input,
                global_step=global_step)
            results.append((step_value, evaluation))
            calculated_steps.append(step_value)

        with open(self._eval_results_filename, "wb") as f:
            pickle.dump(obj=results, file=f)
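# --- Sketch (added): reading back the results pickled by evaluate_checkpoints above.
# The file name stands in for self._eval_results_filename, and the metric key shown
# depends on the evaluator configured in eval_config; treat both as illustrative.
import pickle

with open("eval_results.pkl", "rb") as f:
    results = pickle.load(f)  # list of (global_step, metrics_dict) tuples

for step, metrics in results:
    # e.g. COCO-style detection mAP, if the COCO evaluator is configured
    print(step, metrics.get("DetectionBoxes_Precision/mAP"))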
def eval_continuously(
    pipeline_config_path,
    config_override=None,
    train_steps=None,
    sample_1_of_n_eval_examples=1,
    sample_1_of_n_eval_on_train_examples=1,
    use_tpu=False,
    override_eval_num_epochs=True,
    postprocess_on_cpu=False,
    model_dir=None,
    checkpoint_dir=None,
    wait_interval=180,
    timeout=3600,
    eval_index=0,
    save_final_config=False,
    **kwargs):
  """Run continuous evaluation of a detection model eagerly.

  This method builds the model, and continuously restores it from the most
  recent training checkpoint in the checkpoint directory and evaluates it
  on the evaluation data.

  Args:
    pipeline_config_path: A path to a pipeline config file.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None, the number of training steps
      is set from the `TrainConfig` proto.
    sample_1_of_n_eval_examples: Integer representing how often an eval example
      should be sampled. If 1, will sample all examples.
    sample_1_of_n_eval_on_train_examples: Similar to
      `sample_1_of_n_eval_examples`, except controls the sampling of training
      data for evaluation.
    use_tpu: Boolean, whether training and evaluation should run on TPU.
    override_eval_num_epochs: Whether to overwrite the number of epochs to 1 for
      eval_input.
    postprocess_on_cpu: When use_tpu and postprocess_on_cpu are true,
      postprocess is scheduled on the host cpu.
    model_dir: Directory to output resulting evaluation summaries to.
    checkpoint_dir: Directory that contains the training checkpoints.
    wait_interval: The minimum number of seconds to wait before checking for a
      new checkpoint.
    timeout: The maximum number of seconds to wait for a checkpoint. Execution
      will terminate if no new checkpoint is found within that many seconds.
    eval_index: int, if given, only evaluate the dataset at the given
      index. Defaults to the dataset at index 0.
    save_final_config: Whether to save the pipeline config file to the model
      directory.
    **kwargs: Additional keyword arguments for configuration override.
  """
  get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
      'get_configs_from_pipeline_file']
  create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
      'create_pipeline_proto_from_configs']
  merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
      'merge_external_params_with_configs']

  configs = get_configs_from_pipeline_file(
      pipeline_config_path, config_override=config_override)
  kwargs.update({
      'sample_1_of_n_eval_examples': sample_1_of_n_eval_examples,
      'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu
  })
  if train_steps is not None:
    kwargs['train_steps'] = train_steps
  if override_eval_num_epochs:
    kwargs.update({'eval_num_epochs': 1})
    tf.logging.warning(
        'Forced number of epochs for all eval validations to be 1.')
  configs = merge_external_params_with_configs(
      configs, None, kwargs_dict=kwargs)
  if model_dir and save_final_config:
    tf.logging.info('Saving pipeline config file to directory {}'.format(
        model_dir))
    pipeline_config_final = create_pipeline_proto_from_configs(configs)
    config_util.save_pipeline_config(pipeline_config_final, model_dir)

  model_config = configs['model']
  train_input_config = configs['train_input_config']
  eval_config = configs['eval_config']
  eval_input_configs = configs['eval_input_configs']
  eval_on_train_input_config = copy.deepcopy(train_input_config)
  eval_on_train_input_config.sample_1_of_n_examples = (
      sample_1_of_n_eval_on_train_examples)
  if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1:
    tf.logging.warning('Expected number of evaluation epochs is 1, but '
                       'instead encountered `eval_on_train_input_config'
                       '.num_epochs` = '
                       '{}. Overwriting `num_epochs` to 1.'.format(
                           eval_on_train_input_config.num_epochs))
    eval_on_train_input_config.num_epochs = 1

  if kwargs['use_bfloat16']:
    tf.compat.v2.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')

  eval_input_config = eval_input_configs[eval_index]
  strategy = tf.compat.v2.distribute.get_strategy()
  with strategy.scope():
    detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
        model_config=model_config, is_training=True)

  eval_input = strategy.experimental_distribute_dataset(
      inputs.eval_input(
          eval_config=eval_config,
          eval_input_config=eval_input_config,
          model_config=model_config,
          model=detection_model))

  global_step = tf.compat.v2.Variable(
      0, trainable=False, dtype=tf.compat.v2.dtypes.int64)

  optimizer, _ = optimizer_builder.build(
      configs['train_config'].optimizer, global_step=global_step)

  for latest_checkpoint in tf.train.checkpoints_iterator(
      checkpoint_dir, timeout=timeout, min_interval_secs=wait_interval):
    ckpt = tf.compat.v2.train.Checkpoint(
        step=global_step, model=detection_model, optimizer=optimizer)

    # We run the detection_model on dummy inputs in order to ensure that the
    # model and all its variables have been properly constructed. Specifically,
    # this is currently necessary prior to (potentially) creating shadow copies
    # of the model variables for the EMA optimizer.
    if eval_config.use_moving_averages:
      unpad_groundtruth_tensors = (eval_config.batch_size == 1 and not use_tpu)
      _ensure_model_is_built(detection_model, eval_input,
                             unpad_groundtruth_tensors)
      optimizer.shadow_copy(detection_model)

    ckpt.restore(latest_checkpoint).expect_partial()

    if eval_config.use_moving_averages:
      optimizer.swap_weights()

    summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(model_dir, 'eval', eval_input_config.name))
    with summary_writer.as_default():
      eval_metrics = eager_eval_loop(
          detection_model,
          configs,
          eval_input,
          use_tpu=use_tpu,
          postprocess_on_cpu=postprocess_on_cpu,
          global_step=global_step,
          )
    return eval_metrics
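# --- Usage sketch (added): driving the eval_continuously above from a small script,
# in the spirit of the TF Object Detection API's model_main_tf2.py. The paths and the
# wait/timeout values are placeholders, not part of the original example.
eval_continuously(
    pipeline_config_path='path/to/pipeline.config',
    model_dir='path/to/model_dir',       # eval summaries go to <model_dir>/eval/<name>
    checkpoint_dir='path/to/model_dir',  # polled for new training checkpoints
    sample_1_of_n_eval_examples=1,       # evaluate every eval example
    wait_interval=300,                   # seconds between checkpoint polls
    timeout=3600)                        # stop if no new checkpoint appears in time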
def eval_continuously(hparams,
                      pipeline_config_path,
                      config_override=None,
                      train_steps=None,
                      sample_1_of_n_eval_examples=1,
                      sample_1_of_n_eval_on_train_examples=1,
                      use_tpu=False,
                      override_eval_num_epochs=True,
                      postprocess_on_cpu=False,
                      export_to_tpu=None,
                      model_dir=None,
                      checkpoint_dir=None,
                      wait_interval=180,
                      timeout=3600,
                      **kwargs):
    """Run continuous evaluation of a detection model eagerly.

  This method builds the model, and continuously restores it from the most
  recent training checkpoint in the checkpoint directory and evaluates it
  on the evaluation data.

  Args:
    hparams: A `HParams`.
    pipeline_config_path: A path to a pipeline config file.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None, the number of training steps
      is set from the `TrainConfig` proto.
    sample_1_of_n_eval_examples: Integer representing how often an eval example
      should be sampled. If 1, will sample all examples.
    sample_1_of_n_eval_on_train_examples: Similar to
      `sample_1_of_n_eval_examples`, except controls the sampling of training
      data for evaluation.
    use_tpu: Boolean, whether training and evaluation should run on TPU.
    override_eval_num_epochs: Whether to overwrite the number of epochs to 1 for
      eval_input.
    postprocess_on_cpu: When use_tpu and postprocess_on_cpu are true,
      postprocess is scheduled on the host cpu.
    export_to_tpu: When use_tpu and export_to_tpu are true,
      `export_savedmodel()` exports a metagraph for serving on TPU besides the
      one on CPU. If export_to_tpu is not provided, we will look for it in
      hparams too.
    model_dir: Directory to output resulting evaluation summaries to.
    checkpoint_dir: Directory that contains the training checkpoints.
    wait_interval: The minimum number of seconds to wait before checking for a
      new checkpoint.
    timeout: The maximum number of seconds to wait for a checkpoint. Execution
      will terminate if no new checkpoint is found within that many seconds.
    **kwargs: Additional keyword arguments for configuration override.
  """
    get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
        'get_configs_from_pipeline_file']
    merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
        'merge_external_params_with_configs']

    configs = get_configs_from_pipeline_file(pipeline_config_path,
                                             config_override=config_override)
    kwargs.update({
        'sample_1_of_n_eval_examples':
        sample_1_of_n_eval_examples,
        'use_bfloat16':
        configs['train_config'].use_bfloat16 and use_tpu
    })
    if train_steps is not None:
        kwargs['train_steps'] = train_steps
    if override_eval_num_epochs:
        kwargs.update({'eval_num_epochs': 1})
        tf.logging.warning(
            'Forced number of epochs for all eval validations to be 1.')
    configs = merge_external_params_with_configs(configs,
                                                 hparams,
                                                 kwargs_dict=kwargs)
    model_config = configs['model']
    train_input_config = configs['train_input_config']
    eval_config = configs['eval_config']
    eval_input_configs = configs['eval_input_configs']
    eval_on_train_input_config = copy.deepcopy(train_input_config)
    eval_on_train_input_config.sample_1_of_n_examples = (
        sample_1_of_n_eval_on_train_examples)
    if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1:
        tf.logging.warning('Expected number of evaluation epochs is 1, but '
                           'instead encountered `eval_on_train_input_config'
                           '.num_epochs` = '
                           '{}. Overwriting `num_epochs` to 1.'.format(
                               eval_on_train_input_config.num_epochs))
        eval_on_train_input_config.num_epochs = 1

    if kwargs['use_bfloat16']:
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(
            'mixed_bfloat16')

    detection_model = model_builder.build(model_config=model_config,
                                          is_training=True)

    # Create the inputs.
    eval_inputs = []
    for eval_input_config in eval_input_configs:
        next_eval_input = inputs.eval_input(
            eval_config=eval_config,
            eval_input_config=eval_input_config,
            model_config=model_config,
            model=detection_model)
        eval_inputs.append((eval_input_config.name, next_eval_input))

    # Read export_to_tpu from hparams if not passed.
    if export_to_tpu is None:
        export_to_tpu = hparams.get('export_to_tpu', False)
    tf.logging.info('eval_continuously: use_tpu %s, export_to_tpu %s', use_tpu,
                    export_to_tpu)

    global_step = tf.compat.v2.Variable(0,
                                        trainable=False,
                                        dtype=tf.compat.v2.dtypes.int64)

    for latest_checkpoint in tf.train.checkpoints_iterator(
            checkpoint_dir, timeout=timeout, min_interval_secs=wait_interval):
        ckpt = tf.compat.v2.train.Checkpoint(step=global_step,
                                             model=detection_model)

        ckpt.restore(latest_checkpoint).expect_partial()

        for eval_name, eval_input in eval_inputs:
            summary_writer = tf.compat.v2.summary.create_file_writer(
                model_dir + '/eval' + eval_name)
            with summary_writer.as_default():
                eager_eval_loop(detection_model,
                                configs,
                                eval_input,
                                use_tpu=use_tpu,
                                postprocess_on_cpu=postprocess_on_cpu,
                                global_step=global_step)
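# --- Sketch (added): constructing the `hparams` argument expected by the variant above.
# Assumes object_detection.model_hparams.create_hparams is available, as in the TF1-style
# model_main.py driver; the paths are placeholders.
from object_detection import model_hparams

hparams = model_hparams.create_hparams(hparams_overrides=None)
eval_continuously(
    hparams=hparams,
    pipeline_config_path='path/to/pipeline.config',
    model_dir='path/to/model_dir',
    checkpoint_dir='path/to/model_dir',
    wait_interval=180)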
def main(unused_argv):
    MODEL_BUILD_UTIL_MAP = model_lib.MODEL_BUILD_UTIL_MAP
    
    get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
          'get_configs_from_pipeline_file']
    configs = get_configs_from_pipeline_file(
          FLAGS.pipeline_config_path, config_override=None)
    
    model_config = configs['model']
    eval_config = configs['eval_config']
    eval_input_configs = configs['eval_input_configs']
    detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
          model_config=model_config, is_training=True)
    
    # Create the inputs.
    eval_inputs = []
    for eval_input_config in eval_input_configs:
        next_eval_input = inputs.eval_input(
            eval_config=eval_config,
            eval_input_config=eval_input_config,
            model_config=model_config,
            model=detection_model)
        eval_inputs.append((eval_input_config.name, next_eval_input))
    
    if FLAGS.mode != '3':
        # Scan for new model files
        model_list = generate_model_list(FLAGS.model_dir)
        timeout_count = 0
        latest_model = 'ckpt-0' if len(model_list) == 0 else model_list[-1]
        start_model_num = int(latest_model.split('-')[-1])
        while timeout_count <= FLAGS.eval_timeout:
            print('Wait for {} seconds before starting a new evaluation.'.format(FLAGS.wait_interval))
            time.sleep(FLAGS.wait_interval)
            timeout_count += FLAGS.wait_interval
            model_list = generate_model_list(FLAGS.model_dir)
            
            # Guard against an empty checkpoint directory before comparing.
            if model_list and model_list[-1] != latest_model:
                timeout_count = 0
                print('\nEvaluating {}'.format(model_list[-1]))
                tStart = time.time()
                latest_model = model_list[-1]
                eval_metric = generate_eval_metric(detection_model, FLAGS.model_dir, latest_model, eval_inputs, configs)
                print('Takes {} to evaluate {}'.format(show_time_taken(tStart), latest_model))
                
                # Automatic labeling restart after at least 10 models produced
                model_num = int(latest_model.split('-')[-1])
                if model_num >= start_model_num+10:
                    eval_index = eval_metric[FLAGS.eval_index]
                    if eval_index >= float(FLAGS.eval_threshold) and FLAGS.mode == '2':
                        break
                
    else:
        model_list = generate_model_list(FLAGS.model_dir)
        for model_step in model_list:
            print('\nEvaluating {}'.format(model_step))
            tStart = time.time()
            eval_metric = generate_eval_metric(detection_model, FLAGS.model_dir, model_step, eval_inputs, configs)
            print('Takes {} to evaluate {}'.format(show_time_taken(tStart), model_step))
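# --- Sketch (added): one plausible implementation of generate_model_list, which main()
# above calls but does not define here. It returns checkpoint prefixes (e.g. 'ckpt-3')
# found in model_dir, sorted by step number; illustrative only, not the original helper.
import os

def generate_model_list(model_dir):
    # Each checkpoint is identified by a '<prefix>.index' file.
    prefixes = [os.path.splitext(f)[0]
                for f in os.listdir(model_dir) if f.endswith('.index')]
    return sorted(prefixes, key=lambda name: int(name.split('-')[-1]))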
Example #5
def eval_all_checkpoints(pipeline_config_path,
                         config_override=None,
                         train_steps=None,
                         sample_1_of_n_eval_examples=1,
                         sample_1_of_n_eval_on_train_examples=1,
                         use_tpu=False,
                         override_eval_num_epochs=True,
                         postprocess_on_cpu=False,
                         model_dir=None,
                         checkpoint_dir=None,
                         eval_index=None,
                         **kwargs):
    get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
        'get_configs_from_pipeline_file']
    merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
        'merge_external_params_with_configs']

    configs = get_configs_from_pipeline_file(pipeline_config_path,
                                             config_override=config_override)
    kwargs.update({
        'sample_1_of_n_eval_examples':
        sample_1_of_n_eval_examples,
        'use_bfloat16':
        configs['train_config'].use_bfloat16 and use_tpu
    })
    if train_steps is not None:
        kwargs['train_steps'] = train_steps
    if override_eval_num_epochs:
        kwargs.update({'eval_num_epochs': 1})
        tf.logging.warning(
            'Forced number of epochs for all eval validations to be 1.')
    configs = merge_external_params_with_configs(configs,
                                                 None,
                                                 kwargs_dict=kwargs)
    model_config = configs['model']
    train_input_config = configs['train_input_config']
    eval_config = configs['eval_config']
    eval_input_configs = configs['eval_input_configs']
    eval_on_train_input_config = copy.deepcopy(train_input_config)
    eval_on_train_input_config.sample_1_of_n_examples = (
        sample_1_of_n_eval_on_train_examples)
    if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1:
        tf.logging.warning('Expected number of evaluation epochs is 1, but '
                           'instead encountered `eval_on_train_input_config'
                           '.num_epochs` = '
                           '{}. Overwriting `num_epochs` to 1.'.format(
                               eval_on_train_input_config.num_epochs))
        eval_on_train_input_config.num_epochs = 1

    if kwargs['use_bfloat16']:
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(
            'mixed_bfloat16')

    detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
        model_config=model_config, is_training=True)

    # Create the inputs.
    eval_inputs = []
    for eval_input_config in eval_input_configs:
        next_eval_input = inputs.eval_input(
            eval_config=eval_config,
            eval_input_config=eval_input_config,
            model_config=model_config,
            model=detection_model)
        eval_inputs.append((eval_input_config.name, next_eval_input))

    if eval_index is not None:
        eval_inputs = [eval_inputs[eval_index]]

    global_step = tf.compat.v2.Variable(0,
                                        trainable=False,
                                        dtype=tf.compat.v2.dtypes.int64)
    ckpt = tf.compat.v2.train.Checkpoint(step=global_step,
                                         model=detection_model)
    ckpt_list = []

    import re  # used for natural (numeric-aware) sorting of checkpoint names

    def atoi(text):
        return int(text) if text.isdigit() else text

    def natural_keys(text):
        # Natural sort key so that e.g. 'ckpt-2' orders before 'ckpt-10'.
        return [atoi(c) for c in re.split(r'(\d+)', text)]

    # Each checkpoint is identified by its '.index' file; collect the prefixes.
    for c in os.listdir(checkpoint_dir):
        if c.endswith('.index'):
            ckpt_list.append(os.path.splitext(c)[0])

    ckpt_list.sort(key=natural_keys)
    for eval_name, eval_input in eval_inputs:
        summary_writer = tf.compat.v2.summary.create_file_writer(
            os.path.join(model_dir, 'eval', eval_name))
        for c in ckpt_list:
            ckpt.restore(os.path.join(checkpoint_dir, c)).expect_partial()
            print(c)
            with summary_writer.as_default():
                eager_eval_loop(detection_model,
                                configs,
                                eval_input,
                                use_tpu=use_tpu,
                                postprocess_on_cpu=postprocess_on_cpu,
                                global_step=global_step)
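# --- Usage sketch (added): evaluating every checkpoint of a finished training run with
# eval_all_checkpoints above. Paths are placeholders; summaries are written under
# <model_dir>/eval/<eval input name>.
eval_all_checkpoints(
    pipeline_config_path='path/to/pipeline.config',
    model_dir='path/to/model_dir',
    checkpoint_dir='path/to/model_dir',
    eval_index=0)  # evaluate only the first eval input; None evaluates all of them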
Example #6
def eval_continuously(hparams,
                      pipeline_config_path,
                      config_override=None,
                      train_steps=None,
                      sample_1_of_n_eval_examples=1,
                      sample_1_of_n_eval_on_train_examples=1,
                      use_tpu=False,
                      override_eval_num_epochs=True,
                      postprocess_on_cpu=False,
                      export_to_tpu=None,
                      model_dir=None,
                      checkpoint_dir=None,
                      wait_interval=180,
                      **kwargs):

    get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
        'get_configs_from_pipeline_file']
    merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
        'merge_external_params_with_configs']

    configs = get_configs_from_pipeline_file(pipeline_config_path,
                                             config_override=config_override)
    kwargs.update({
        'sample_1_of_n_eval_examples':
        sample_1_of_n_eval_examples,
        'use_bfloat16':
        configs['train_config'].use_bfloat16 and use_tpu
    })
    if train_steps is not None:
        kwargs['train_steps'] = train_steps
    if override_eval_num_epochs:
        kwargs.update({'eval_num_epochs': 1})
        tf.logging.warning(
            'Forced number of epochs for all eval validations to be 1.')
    configs = merge_external_params_with_configs(configs,
                                                 hparams,
                                                 kwargs_dict=kwargs)
    model_config = configs['model']
    train_input_config = configs['train_input_config']
    eval_config = configs['eval_config']
    eval_input_configs = configs['eval_input_configs']
    eval_on_train_input_config = copy.deepcopy(train_input_config)
    eval_on_train_input_config.sample_1_of_n_examples = (
        sample_1_of_n_eval_on_train_examples)
    if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1:
        tf.logging.warning('Expected number of evaluation epochs is 1, but '
                           'instead encountered `eval_on_train_input_config'
                           '.num_epochs` = '
                           '{}. Overwriting `num_epochs` to 1.'.format(
                               eval_on_train_input_config.num_epochs))
        eval_on_train_input_config.num_epochs = 1

    detection_model = model_builder.build(model_config=model_config,
                                          is_training=True)

    # Create the inputs.
    eval_inputs = []
    for eval_input_config in eval_input_configs:
        next_eval_input = inputs.eval_input(
            eval_config=eval_config,
            eval_input_config=eval_input_config,
            model_config=model_config,
            model=detection_model)
        eval_inputs.append((eval_input_config.name, next_eval_input))

    # Read export_to_tpu from hparams if not passed.
    if export_to_tpu is None:
        export_to_tpu = hparams.get('export_to_tpu', False)
    tf.logging.info('eval_continuously: use_tpu %s, export_to_tpu %s', use_tpu,
                    export_to_tpu)

    global_step = tf.compat.v2.Variable(0,
                                        trainable=False,
                                        dtype=tf.compat.v2.dtypes.int64)

    prev_checkpoint = None
    waiting = False
    while True:
        ckpt = tf.compat.v2.train.Checkpoint(step=global_step,
                                             model=detection_model)
        manager = tf.compat.v2.train.CheckpointManager(ckpt,
                                                       checkpoint_dir,
                                                       max_to_keep=3)

        latest_checkpoint = manager.latest_checkpoint
        if prev_checkpoint == latest_checkpoint:
            if prev_checkpoint is None:
                tf.logging.info(
                    'No checkpoints found yet. Trying again in %s seconds.' %
                    wait_interval)
                time.sleep(wait_interval)
            else:
                if waiting:
                    tf.logging.info(
                        'Terminating eval after %s seconds of no new '
                        'checkpoints.' % wait_interval)
                    break
                else:
                    tf.logging.info(
                        'No new checkpoint found. Will try again '
                        'in %s seconds and terminate if no checkpoint '
                        'appears.' % wait_interval)
                    waiting = True
                    time.sleep(wait_interval)
        else:
            tf.logging.info('New checkpoint found. Starting evaluation.')
            waiting = False
            prev_checkpoint = latest_checkpoint
            ckpt.restore(latest_checkpoint)

            for eval_name, eval_input in eval_inputs:
                summary_writer = tf.compat.v2.summary.create_file_writer(
                    model_dir + '/eval' + eval_name)
                with summary_writer.as_default():
                    eager_eval_loop(detection_model,
                                    configs,
                                    eval_input,
                                    use_tpu=use_tpu,
                                    postprocess_on_cpu=postprocess_on_cpu,
                                    global_step=global_step)
def run(callbacks=None):
    configs = config_util.get_configs_from_pipeline_file(config_file)
    model_config = configs['model']
    detection_model = model_builder.build(model_config=model_config,
                                          is_training=False)

    #tf.config.experimental_run_functions_eagerly(True)

    checkpoint_dir = app.FLAGS.checkpoint_dir
    checkpoint_path = tf.train.latest_checkpoint(
        checkpoint_dir if checkpoint_dir else app.FLAGS.model_dir,
        latest_filename=None)
    # Restore checkpoint
    ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
    ckpt.restore(checkpoint_path).expect_partial()

    def get_model_detection_function(model):
        """Get a tf.function for detection."""
        @tf.function
        def detect_fn(image):
            """Detect objects in image."""

            image, shapes = model.preprocess(image)
            prediction_dict = model.predict(image, shapes)
            detections = model.postprocess(prediction_dict, shapes)
            return detections, prediction_dict, tf.reshape(shapes, [-1])

        return detect_fn

    detect_fn = get_model_detection_function(detection_model)

    eval_config = configs['eval_config']
    eval_input_configs = configs['eval_input_configs']

    eval_inputs = []
    for eval_input_config in eval_input_configs:
        next_eval_input = inputs.eval_input(
            eval_config=eval_config,
            eval_input_config=eval_input_config,
            model_config=model_config,
            model=detection_model)
        eval_inputs.append((eval_input_config.name, next_eval_input))

    index = 0
    annotation_min_score = 0.5
    annotations_center = []
    annotations_box = []
    empty_json = "{}"

    for eval_name, eval_input in eval_inputs:
        for images_dict, labels in eval_input:
            orig_image = images_dict['original_image']
            images = tf.cast(orig_image, tf.float32)

            detections, prediction_outputs, shapes = detect_fn(images)
            images_shape = shapes.numpy()[0:2]
            image_shape_multiplier = tf.cast(tf.tile(shapes[0:2], [2]),
                                             tf.float32)

            def shape_boxes(boxes):
                boxes_px = boxes * image_shape_multiplier
                return boxes_px

            def get_boxes(numpy=True):
                boxes = detections['detection_boxes'][i]
                shaped = shape_boxes(boxes)
                if numpy:
                    shaped = shaped.numpy()
                return shaped

            for i in range(images.shape[0]):
                index += 1
                boxes = get_boxes(numpy=False)
                classes = detections['detection_classes'][i] + label_offset
                scores = detections['detection_scores'][i]
                y_angles = detections['detection_y_rotation_angles'][i]
                source_id = images_dict['source_id'][i]
                cv_camera_matrix = images_dict['camera_intrinsic'][i].numpy(
                ).reshape(3, 3)
                cv_distortion_coefficients = images_dict['camera_distortion'][
                    i].numpy()
                if get_poses:
                    from DicePoseFinding import get_dice_pose_results
                    dice_pose_results = get_dice_pose_results(
                        boxes, classes, scores, y_angles, cv_camera_matrix,
                        cv_distortion_coefficients)
                else:
                    dice_pose_results = []
                boxes = boxes.numpy()
                classes = classes.numpy()
                scores = scores.numpy()
                source_id = source_id.numpy()
                if show_images:
                    import matplotlib.patches as patches
                    image = images[i]
                    image_np = image.numpy() / 255.0
                    rotation_angles = detections[
                        'detection_y_rotation_angles'][i].numpy()
                    matplotlib.use('TkAgg')
                    plt.figure(figsize=(12, 16))
                    plt.imshow(image_np)
                    ax = plt.gca()
                    for dice_pose_result in dice_pose_results:
                        approx_pose_result = dice_pose_result.additional_data[
                            'approx_pose_result']
                        from TransformUtil import transform_points_3d
                        import DiceConfig, DiceProjection
                        import cv2
                        local_dots_visible_in_eye_space = DiceProjection.get_local_dots_facing_camera_from_eye_space_pose(
                            dice_pose_result.pose_pyrender)
                        #NB We should use cv-space pose when using cv2.projectPoints
                        if local_dots_visible_in_eye_space.size > 0:
                            pose_points, pose_points_jacobian = cv2.projectPoints(
                                local_dots_visible_in_eye_space,
                                dice_pose_result.pose_cv.rotation_rodrigues,
                                dice_pose_result.pose_cv.translation,
                                cv_camera_matrix, cv_distortion_coefficients)
                            pose_points = np.squeeze(pose_points)
                            ax.scatter(pose_points[:, 0],
                                       pose_points[:, 1],
                                       s=4)
                            comp_pts = dice_pose_result.comparison_points_cv[
                                dice_pose_result.comparison_indices]
                            proj_pts = dice_pose_result.projected_points[
                                dice_pose_result.projected_indices]
                            ax.plot(
                                np.vstack([comp_pts[:, 0], proj_pts[:, 0]]),
                                np.vstack([comp_pts[:, 1], proj_pts[:, 1]]),
                                'g-')
                        local_dots_visible_in_eye_space_approx = DiceProjection.get_local_dots_facing_camera_from_eye_space_pose(
                            approx_pose_result.pose_pyrender)
                        if show_approx and local_dots_visible_in_eye_space_approx.size > 0:
                            pose_points, pose_points_jacobian = cv2.projectPoints(
                                local_dots_visible_in_eye_space_approx,
                                approx_pose_result.pose_cv.rotation_rodrigues,
                                approx_pose_result.pose_cv.translation,
                                cv_camera_matrix, cv_distortion_coefficients)
                            pose_points = np.squeeze(pose_points)
                            ax.scatter(pose_points[:, 0],
                                       pose_points[:, 1],
                                       s=4)
                            comp_pts = approx_pose_result.comparison_points_cv[
                                approx_pose_result.comparison_indices]
                            proj_pts = approx_pose_result.projected_points[
                                approx_pose_result.projected_indices]
                            ax.plot(
                                np.vstack([comp_pts[:, 0], proj_pts[:, 0]]),
                                np.vstack([comp_pts[:, 1], proj_pts[:, 1]]),
                                'y-')
                        try:
                            local_dots_projected = dice_pose_result.additional_data[
                                'local_dots_projected']
                            dot_centers_transformed = dice_pose_result.additional_data[
                                'dot_centers_transformed']
                            #ax.scatter(local_dots_projected[:, 0], local_dots_projected[:, 1], marker='+')
                            #ax.scatter(dot_centers_transformed[:, 0], dot_centers_transformed[:, 1], marker='s')
                        except KeyError:
                            pass
                    for j in range(boxes.shape[0]):
                        box = tuple(boxes[j].tolist())
                        score = scores[j]
                        class_id = classes[j]
                        rot_angle = (rotation_angles[j] %
                                     (2 * np.pi)) * 180. / np.pi
                        if score > min_score:
                            rect = patches.Rectangle((box[1], box[0]),
                                                     box[3] - box[1],
                                                     box[2] - box[0],
                                                     linewidth=1,
                                                     edgecolor='r',
                                                     facecolor='none')
                            ax.add_patch(rect)
                            if class_id > 0:
                                plt.text(
                                    box[1], box[0],
                                    '<{:.2f}>:{}:{:.2f}\n@ {:.0f} {:.0f}'.
                                    format(score, class_id, rot_angle, box[1],
                                           box[0]), {'color': 'r'})
                            else:
                                #plt.text(box[1], box[0], ' {:.2f}:D'.format(score), {'color': 'r'})
                                pass
                    label_rotation_angles = labels['y_rotation_angle'][
                        i].numpy()
                    label_boxes = labels['groundtruth_boxes'][i].numpy()
                    label_classes = labels['groundtruth_classes'][i].numpy()
                    label_classes_id = np.argmax(label_classes, axis=-1)
                    label_box_px = shape_boxes(label_boxes).numpy()
                    num_gt_boxes = labels['num_groundtruth_boxes'].numpy()[0]
                    for j in range(num_gt_boxes):
                        box = tuple(label_box_px[j].tolist())
                        rect = patches.Rectangle((box[1], box[0]),
                                                 box[3] - box[1],
                                                 box[2] - box[0],
                                                 linewidth=1,
                                                 linestyle=':',
                                                 edgecolor='g',
                                                 facecolor='none')
                        ax.add_patch(rect)
                        class_id = label_classes_id[j]
                        rot_angle = (label_rotation_angles[j] %
                                     (2 * np.pi)) * 180. / np.pi
                        if class_id > 0:
                            plt.text(box[3], box[0],
                                     '{}:{:.2f}'.format(class_id, rot_angle),
                                     {'color': 'g'})
                        else:
                            #plt.text(box[3], box[0], 'D', {'color': 'g'})
                            pass
                    plt.show()
                for j in range(boxes.shape[0]):
                    box = boxes[j].tolist()
                    score = scores[j]
                    class_id = classes[j]
                    if score > annotation_min_score:
                        x_min = box[1]
                        y_min = box[0]
                        width = box[3] - box[1]
                        height = box[2] - box[0]
                        x_mid = (box[1] + box[3]) / 2
                        y_mid = (box[0] + box[2]) / 2
                        region_attributes_string = "{\"type\":" + str(
                            class_id) + "}"
                        region_shape_string_box = "{\"name\":\"rect\",\"x\":" + str(
                            x_min) + ",\"y\":" + str(
                                y_min) + ",\"width\":" + str(
                                    width) + ",\"height\":" + str(height) + "}"
                        annotations_box.append([
                            source_id, empty_json, empty_json, empty_json, j,
                            region_shape_string_box, region_attributes_string
                        ])
                        region_shape_string_center = "{\"name\":\"point\",\"cx\":" + str(
                            x_mid) + ",\"cy\":" + str(y_mid) + "}"
                        annotations_center.append([
                            source_id, empty_json, empty_json, empty_json, j,
                            region_shape_string_center,
                            region_attributes_string
                        ])
                print(index)

    if write_annotations:
        import csv
        with open('output_annotation_box.csv', 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            for row in annotations_box:
                writer.writerow(row)
        with open('output_annotation_point.csv', 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            for row in annotations_center:
                writer.writerow(row)

    print("DONE")