Beispiel #1
0
    def _save_checkpoint_from_mock_model(self,
                                         checkpoint_path,
                                         use_moving_averages,
                                         quantize=False,
                                         num_channels=3):
        """Builds a `FakeModel` graph and saves its initialized variables.

        Args:
          checkpoint_path: Path handed to `tf.train.Saver.save`.
          use_moving_averages: If true, an `ExponentialMovingAverage` with
            decay 0.0 is applied in the graph before the checkpoint is saved.
          quantize: If true, an eval-mode graph rewriter (quantization delay
            500000) is built and run on the graph.
          num_channels: Size of the last dimension of the
            [1, 10, 10, num_channels] float32 input placeholder.
        """
        graph = tf.Graph()
        with graph.as_default():
            fake_model = FakeModel()
            image_placeholder = tf.placeholder(
                tf.float32, shape=[1, 10, 10, num_channels])
            fake_model.predict(image_placeholder, true_image_shapes=None)

            if use_moving_averages:
                tf.train.ExponentialMovingAverage(0.0).apply()
            tf.train.get_or_create_global_step()

            if quantize:
                rewriter_proto = graph_rewriter_pb2.GraphRewriter()
                rewriter_proto.quantization.delay = 500000
                rewrite_graph = graph_rewriter_builder.build(
                    rewriter_proto, is_training=False)
                rewrite_graph()

            # The saver and the init op are created only after every variable
            # above has been added to the graph.
            checkpoint_saver = tf.train.Saver()
            init_op = tf.global_variables_initializer()
            with self.test_session() as session:
                session.run(init_op)
                checkpoint_saver.save(session, checkpoint_path)
Beispiel #2
0
def main(unused_argv):
    """Evaluates a detection model as configured by the command-line FLAGS.

    The configuration file(s) are copied into `FLAGS.eval_dir`, the model and
    input builders are derived from them, and everything is handed to
    `evaluator.evaluate`.

    Args:
      unused_argv: Ignored.
    """
    assert FLAGS.checkpoint_dir, '`checkpoint_dir` is missing.'
    assert FLAGS.eval_dir, '`eval_dir` is missing.'
    tf.gfile.MakeDirs(FLAGS.eval_dir)

    def _copy_into_eval_dir(source_path, target_name):
        # Keeps a copy of the config next to the evaluation results.
        tf.gfile.Copy(source_path,
                      os.path.join(FLAGS.eval_dir, target_name),
                      overwrite=True)

    if FLAGS.pipeline_config_path:
        configs = config_util.get_configs_from_pipeline_file(
            FLAGS.pipeline_config_path)
        _copy_into_eval_dir(FLAGS.pipeline_config_path, 'pipeline.config')
    else:
        configs = config_util.get_configs_from_multiple_files(
            model_config_path=FLAGS.model_config_path,
            eval_config_path=FLAGS.eval_config_path,
            eval_input_config_path=FLAGS.input_config_path)
        _copy_into_eval_dir(FLAGS.model_config_path, 'model.config')
        _copy_into_eval_dir(FLAGS.eval_config_path, 'eval.config')
        _copy_into_eval_dir(FLAGS.input_config_path, 'input.config')

    model_config = configs['model']
    eval_config = configs['eval_config']
    input_config = configs['eval_input_config']
    if FLAGS.eval_training_data:
        # Evaluate on the training set instead of the eval set.
        input_config = configs['train_input_config']

    build_model = functools.partial(model_builder.build,
                                    model_config=model_config,
                                    is_training=False)

    def _next_input_dict(reader_config):
        dataset = dataset_builder.build(reader_config)
        iterator = dataset_builder.make_initializable_iterator(dataset)
        return iterator.get_next()

    create_input_dict_fn = functools.partial(_next_input_dict, input_config)

    categories = label_map_util.create_categories_from_labelmap(
        input_config.label_map_path)

    if FLAGS.run_once:
        eval_config.max_evals = 1

    if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=False)
    else:
        graph_rewriter_fn = None

    evaluator.evaluate(create_input_dict_fn,
                       build_model,
                       eval_config,
                       categories,
                       FLAGS.checkpoint_dir,
                       FLAGS.eval_dir,
                       graph_hook_fn=graph_rewriter_fn)
Beispiel #3
0
 def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
     """Checks the calls made by an eval-mode quantization graph rewriter.

     Both `tf.contrib.quantize.create_eval_graph` and
     `tf.contrib.layers.summarize_collection` are replaced by mocks, so only
     the arguments they receive are verified.
     """
     quantize_patch = mock.patch.object(tf.contrib.quantize,
                                        'create_eval_graph')
     summarize_patch = mock.patch.object(tf.contrib.layers,
                                         'summarize_collection')
     with quantize_patch as mock_quant_fn, \
             summarize_patch as mock_summarize_col:
         rewriter_proto = graph_rewriter_pb2.GraphRewriter()
         rewriter_proto.quantization.delay = 10
         rewrite_graph = graph_rewriter_builder.build(rewriter_proto,
                                                      is_training=False)
         rewrite_graph()

         # Only the keyword arguments of the last call are inspected.
         _, call_kwargs = mock_quant_fn.call_args
         self.assertEqual(call_kwargs['input_graph'],
                          tf.get_default_graph())
         mock_summarize_col.assert_called_with('quant_vars')
Beispiel #4
0
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           write_inference_graph=False):
    """Exports inference graph for the model specified in the pipeline config.

    After the export, `pipeline_config.eval_config.use_moving_averages` is set
    to False on the passed-in proto (it is modified in place) and the config
    is saved to `output_directory`.

    Args:
      input_type: Type of input for the graph. Can be one of ['image_tensor',
        'encoded_image_string_tensor', 'tf_example'].
      pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
      trained_checkpoint_prefix: Path to the trained checkpoint file.
      output_directory: Path to write outputs.
      input_shape: Sets a fixed shape for an `image_tensor` input. If not
        specified, will default to [None, None, None, 3].
      output_collection_name: Name of collection to add output tensors to.
        If None, does not add output tensors to a collection.
      additional_output_tensor_names: list of additional output
        tensors to include in the frozen graph.
      write_inference_graph: If true, writes inference graph to disk.
    """
    model = model_builder.build(pipeline_config.model, is_training=False)

    # A graph rewriter is only built when the config explicitly carries one.
    if pipeline_config.HasField('graph_rewriter'):
        rewriter_fn = graph_rewriter_builder.build(
            pipeline_config.graph_rewriter, is_training=False)
    else:
        rewriter_fn = None

    use_moving_averages = pipeline_config.eval_config.use_moving_averages
    _export_inference_graph(input_type,
                            model,
                            use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            additional_output_tensor_names,
                            input_shape,
                            output_collection_name,
                            graph_hook_fn=rewriter_fn,
                            write_inference_graph=write_inference_graph)

    # The flag is cleared before the config is written next to the export.
    pipeline_config.eval_config.use_moving_averages = False
    config_util.save_pipeline_config(pipeline_config, output_directory)
Beispiel #5
0
def main(_):
    """Trains a detection model, optionally as one task of a cluster.

    Configuration comes from the command-line FLAGS; the cluster layout and
    this process's role are read from the JSON in the `TF_CONFIG` environment
    variable. Task 0 copies the configuration file(s) into `FLAGS.train_dir`.
    A task of type 'ps' joins its server and returns without training.

    Raises:
      ValueError: if the cluster has worker replicas but no 'ps' task.
    """
    assert FLAGS.train_dir, '`train_dir` is missing.'
    is_first_task = FLAGS.task == 0
    if is_first_task:
        tf.gfile.MakeDirs(FLAGS.train_dir)

    # Load the configs and remember which files should be archived.
    if FLAGS.pipeline_config_path:
        configs = config_util.get_configs_from_pipeline_file(
            FLAGS.pipeline_config_path)
        config_copies = [('pipeline.config', FLAGS.pipeline_config_path)]
    else:
        configs = config_util.get_configs_from_multiple_files(
            model_config_path=FLAGS.model_config_path,
            train_config_path=FLAGS.train_config_path,
            train_input_config_path=FLAGS.input_config_path)
        config_copies = [('model.config', FLAGS.model_config_path),
                         ('train.config', FLAGS.train_config_path),
                         ('input.config', FLAGS.input_config_path)]
    if is_first_task:
        for target_name, source_path in config_copies:
            tf.gfile.Copy(source_path,
                          os.path.join(FLAGS.train_dir, target_name),
                          overwrite=True)

    model_config = configs['model']
    train_config = configs['train_config']
    input_config = configs['train_input_config']

    build_model = functools.partial(model_builder.build,
                                    model_config=model_config,
                                    is_training=True)

    def _next_input_dict(reader_config):
        dataset = dataset_builder.build(reader_config)
        iterator = dataset_builder.make_initializable_iterator(dataset)
        return iterator.get_next()

    create_input_dict_fn = functools.partial(_next_input_dict, input_config)

    tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
    cluster_data = tf_config.get('cluster', None)
    cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
    task_data = tf_config.get('task', None) or {'type': 'master', 'index': 0}
    # Turns the task dict into an object with `.type` and `.index` attributes.
    task_info = type('TaskSpec', (object, ), task_data)

    worker_replicas = 1
    ps_tasks = 0
    if cluster_data:
        if 'worker' in cluster_data:
            # Number of total worker replicas include "worker"s and the
            # "master".
            worker_replicas = len(cluster_data['worker']) + 1
        if 'ps' in cluster_data:
            ps_tasks = len(cluster_data['ps'])

    if worker_replicas > 1 and ps_tasks < 1:
        raise ValueError(
            'At least 1 ps task is needed for distributed training.')

    if worker_replicas >= 1 and ps_tasks > 0:
        # Set up distributed training.
        server = tf.train.Server(tf.train.ClusterSpec(cluster),
                                 protocol='grpc',
                                 job_name=task_info.type,
                                 task_index=task_info.index)
        if task_info.type == 'ps':
            server.join()
            return

        worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
        task = task_info.index
        is_chief = (task_info.type == 'master')
        master = server.target
    else:
        # Parameters for a single worker.
        worker_job_name = 'lonely_worker'
        task = 0
        is_chief = True
        master = ''

    if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=True)
    else:
        graph_rewriter_fn = None

    trainer.train(create_input_dict_fn,
                  build_model,
                  train_config,
                  master,
                  task,
                  FLAGS.num_clones,
                  worker_replicas,
                  FLAGS.clone_on_cpu,
                  ps_tasks,
                  worker_job_name,
                  is_chief,
                  FLAGS.train_dir,
                  graph_hook_fn=graph_rewriter_fn)
Beispiel #6
0
def export_tflite_graph(pipeline_config, trained_checkpoint_prefix, output_dir,
                        add_postprocessing_op, max_detections,
                        max_classes_per_detection):
    """Exports a tflite compatible graph and anchors for ssd detection model.

    Anchors are written to a tensor and the tflite compatible graph is written
    to output_dir/tflite_graph.pb, together with a text version in
    output_dir/tflite_graph.pbtxt.

    Args:
      pipeline_config: a pipeline.proto object containing the configuration
        for SSD model to export.
      trained_checkpoint_prefix: a file prefix for the checkpoint containing
        the trained parameters of the SSD model.
      output_dir: A directory to write the tflite graph and anchor file to.
      add_postprocessing_op: If add_postprocessing_op is true: frozen graph
        adds a TFLite_Detection_PostProcess custom op
      max_detections: Maximum number of detections (boxes) to show
      max_classes_per_detection: Number of classes to display per detection

    Raises:
      ValueError: if the pipeline config contains a model other than ssd, or
        if the ssd model uses an image resizer other than
        fixed_shape_resizer.
    """
    tf.gfile.MakeDirs(output_dir)
    if pipeline_config.model.WhichOneof('model') != 'ssd':
        raise ValueError('Only ssd models are supported in tflite. '
                         'Found {} in config'.format(
                             pipeline_config.model.WhichOneof('model')))

    ssd_config = pipeline_config.model.ssd
    num_classes = ssd_config.num_classes

    # NOTE(review): the thresholds and scales below are deliberately kept as
    # single-element set literals, exactly as before; presumably
    # append_postprocessing_op expects that form -- confirm before changing
    # them to plain scalars.
    nms_config = ssd_config.post_processing.batch_non_max_suppression
    nms_score_threshold = {nms_config.score_threshold}
    nms_iou_threshold = {nms_config.iou_threshold}
    box_coder_config = ssd_config.box_coder.faster_rcnn_box_coder
    scale_values = {
        'y_scale': {box_coder_config.y_scale},
        'x_scale': {box_coder_config.x_scale},
        'h_scale': {box_coder_config.height_scale},
        'w_scale': {box_coder_config.width_scale},
    }

    image_resizer_config = ssd_config.image_resizer
    image_resizer = image_resizer_config.WhichOneof('image_resizer_oneof')
    num_channels = _DEFAULT_NUM_CHANNELS
    if image_resizer != 'fixed_shape_resizer':
        # The two literals used to be concatenated without a separating
        # space, which produced "fixed_shape_resizeris" in the message.
        raise ValueError(
            'Only fixed_shape_resizer '
            'is supported with tflite. Found {}'.format(image_resizer))
    height = image_resizer_config.fixed_shape_resizer.height
    width = image_resizer_config.fixed_shape_resizer.width
    if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
        num_channels = 1
    shape = [1, height, width, num_channels]

    image = tf.placeholder(tf.float32,
                           shape=shape,
                           name='normalized_input_image_tensor')

    detection_model = model_builder.build(pipeline_config.model,
                                          is_training=False)
    predicted_tensors = detection_model.predict(image, true_image_shapes=None)
    # The score conversion occurs before the post-processing custom op
    _, score_conversion_fn = post_processing_builder.build(
        ssd_config.post_processing)
    class_predictions = score_conversion_fn(
        predicted_tensors['class_predictions_with_background'])

    with tf.name_scope('raw_outputs'):
        # 'raw_outputs/box_encodings': a float32 tensor of shape
        #  [1, num_anchors, 4] containing the encoded box predictions. Note
        #  that these are raw predictions and no Non-Max suppression is
        #  applied on them and no decode center size boxes is applied to them.
        tf.identity(predicted_tensors['box_encodings'], name='box_encodings')
        # 'raw_outputs/class_predictions': a float32 tensor of shape
        #  [1, num_anchors, num_classes] containing the class scores for each
        #  anchor after applying score conversion.
        tf.identity(class_predictions, name='class_predictions')
    # 'anchors': a float32 tensor of shape
    #   [4, num_anchors] containing the anchors as a constant node.
    tf.identity(get_const_center_size_encoded_anchors(
        predicted_tensors['anchors']),
                name='anchors')

    # Add global step to the graph, so we know the training step number when we
    # evaluate the model.
    tf.train.get_or_create_global_step()

    # graph rewriter
    if pipeline_config.HasField('graph_rewriter'):
        graph_rewriter_config = pipeline_config.graph_rewriter
        graph_rewriter_fn = graph_rewriter_builder.build(graph_rewriter_config,
                                                         is_training=False)
        graph_rewriter_fn()

    # freeze the graph
    saver_kwargs = {}
    if pipeline_config.eval_config.use_moving_averages:
        saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
        # NOTE(review): this temporary file is never closed explicitly; the
        # local reference keeps it around while the graph is frozen below.
        # Left unchanged because the checkpoint writer is not visible here.
        moving_average_checkpoint = tempfile.NamedTemporaryFile()
        exporter.replace_variable_values_with_moving_averages(
            tf.get_default_graph(), trained_checkpoint_prefix,
            moving_average_checkpoint.name)
        checkpoint_to_use = moving_average_checkpoint.name
    else:
        checkpoint_to_use = trained_checkpoint_prefix

    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()
    frozen_graph_def = exporter.freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=checkpoint_to_use,
        output_node_names=','.join([
            'raw_outputs/box_encodings', 'raw_outputs/class_predictions',
            'anchors'
        ]),
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        output_graph='',
        initializer_nodes='')

    # Add new operation to do post processing in a custom op (TF Lite only)
    if add_postprocessing_op:
        transformed_graph_def = append_postprocessing_op(
            frozen_graph_def, max_detections, max_classes_per_detection,
            nms_score_threshold, nms_iou_threshold, num_classes, scale_values)
    else:
        # Return frozen without adding post-processing custom op
        transformed_graph_def = frozen_graph_def

    # The same graph is written twice: serialized binary and text form.
    binary_graph = os.path.join(output_dir, 'tflite_graph.pb')
    with tf.gfile.GFile(binary_graph, 'wb') as f:
        f.write(transformed_graph_def.SerializeToString())
    txt_graph = os.path.join(output_dir, 'tflite_graph.pbtxt')
    with tf.gfile.GFile(txt_graph, 'w') as f:
        f.write(str(transformed_graph_def))
Beispiel #7
0
    def model_fn(features, labels, mode, params=None):
        """Constructs the object detection model.

        This function is a closure: besides its arguments it reads
        `detection_model_fn`, `use_tpu`, `hparams`, `configs`, `train_config`,
        `eval_config` and `eval_input_config` from the enclosing scope.

        Args:
          features: Dictionary of feature tensors, returned from `input_fn`.
          labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
            otherwise None.
          mode: Mode key from tf.estimator.ModeKeys.
          params: Parameter dictionary passed from the estimator.

        Returns:
          An `EstimatorSpec` that encapsulates the model and its serving
            configurations. When `use_tpu` is set and the mode is not EVAL, a
            `TPUEstimatorSpec` is returned instead.
        """
        # `params` is normalized to a dict here; it is not read again below.
        params = params or {}
        # Every output starts as None and is filled in only by the modes that
        # produce it.
        total_loss, train_op, detections, export_outputs = None, None, None, None
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Make sure to set the Keras learning phase. True during training,
        # False for inference.
        tf.keras.backend.set_learning_phase(is_training)
        detection_model = detection_model_fn(is_training=is_training,
                                             add_summaries=(not use_tpu))
        # Only set below when fine-tuning from a checkpoint on TPU.
        scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:
            labels = unstack_batch(labels,
                                   unpad_groundtruth_tensors=train_config.
                                   unpad_groundtruth_tensors)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # For evaling on train data, it is necessary to check whether groundtruth
            # must be unpadded.
            boxes_shape = (labels[fields.InputDataFields.groundtruth_boxes].
                           get_shape().as_list())
            # Unpad only when dimension 1 of the boxes tensor is statically
            # known.
            unpad_groundtruth_tensors = True if boxes_shape[
                1] is not None else False
            labels = unstack_batch(
                labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            # Boxes and classes are always read; the remaining groundtruth
            # fields are optional and stay None when absent from `labels`.
            gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
            gt_classes_list = labels[
                fields.InputDataFields.groundtruth_classes]
            gt_masks_list = None
            if fields.InputDataFields.groundtruth_instance_masks in labels:
                gt_masks_list = labels[
                    fields.InputDataFields.groundtruth_instance_masks]
            gt_keypoints_list = None
            if fields.InputDataFields.groundtruth_keypoints in labels:
                gt_keypoints_list = labels[
                    fields.InputDataFields.groundtruth_keypoints]
            gt_weights_list = None
            if fields.InputDataFields.groundtruth_weights in labels:
                gt_weights_list = labels[
                    fields.InputDataFields.groundtruth_weights]
            gt_is_crowd_list = None
            if fields.InputDataFields.groundtruth_is_crowd in labels:
                gt_is_crowd_list = labels[
                    fields.InputDataFields.groundtruth_is_crowd]
            detection_model.provide_groundtruth(
                groundtruth_boxes_list=gt_boxes_list,
                groundtruth_classes_list=gt_classes_list,
                groundtruth_masks_list=gt_masks_list,
                groundtruth_keypoints_list=gt_keypoints_list,
                groundtruth_weights_list=gt_weights_list,
                groundtruth_is_crowd_list=gt_is_crowd_list)

        preprocessed_images = features[fields.InputDataFields.image]
        if use_tpu and train_config.use_bfloat16:
            # Predict inside the bfloat16 scope, then cast any bfloat16
            # outputs back to float32 for the rest of the graph.
            with tf.contrib.tpu.bfloat16_scope():
                prediction_dict = detection_model.predict(
                    preprocessed_images,
                    features[fields.InputDataFields.true_image_shape])
                for k, v in prediction_dict.items():
                    if v.dtype == tf.bfloat16:
                        prediction_dict[k] = tf.cast(v, tf.float32)
        else:
            prediction_dict = detection_model.predict(
                preprocessed_images,
                features[fields.InputDataFields.true_image_shape])
        # Postprocessed detections are only produced for EVAL and PREDICT.
        if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
            detections = detection_model.postprocess(
                prediction_dict,
                features[fields.InputDataFields.true_image_shape],
            )

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Optionally initialize variables from a fine-tune checkpoint.
            if train_config.fine_tune_checkpoint and hparams.load_pretrained:
                if not train_config.fine_tune_checkpoint_type:
                    # train_config.from_detection_checkpoint field is deprecated. For
                    # backward compatibility, set train_config.fine_tune_checkpoint_type
                    # based on train_config.from_detection_checkpoint.
                    if train_config.from_detection_checkpoint:
                        train_config.fine_tune_checkpoint_type = 'detection'
                    else:
                        train_config.fine_tune_checkpoint_type = 'classification'
                asg_map = detection_model.restore_map(
                    fine_tune_checkpoint_type=train_config.
                    fine_tune_checkpoint_type,
                    load_all_detection_checkpoint_vars=(
                        train_config.load_all_detection_checkpoint_vars))
                available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        asg_map,
                        train_config.fine_tune_checkpoint,
                        include_global_step=False))
                if use_tpu:

                    # On TPU the checkpoint initialization is deferred into a
                    # scaffold function handed to the TPUEstimatorSpec.
                    def tpu_scaffold():
                        tf.train.init_from_checkpoint(
                            train_config.fine_tune_checkpoint,
                            available_var_map)
                        return tf.train.Scaffold()

                    scaffold_fn = tpu_scaffold
                else:
                    tf.train.init_from_checkpoint(
                        train_config.fine_tune_checkpoint, available_var_map)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            losses_dict = detection_model.loss(
                prediction_dict,
                features[fields.InputDataFields.true_image_shape])
            losses = [loss_tensor for loss_tensor in losses_dict.values()]
            if train_config.add_regularization_loss:
                regularization_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if regularization_losses:
                    regularization_loss = tf.add_n(regularization_losses,
                                                   name='regularization_loss')
                    losses.append(regularization_loss)
                    losses_dict[
                        'Loss/regularization_loss'] = regularization_loss
            total_loss = tf.add_n(losses, name='total_loss')
            losses_dict['Loss/total_loss'] = total_loss

            # The graph rewriter, when configured, runs after the losses have
            # been added to the graph.
            if 'graph_rewriter_config' in configs:
                graph_rewriter_fn = graph_rewriter_builder.build(
                    configs['graph_rewriter_config'], is_training=is_training)
                graph_rewriter_fn()

            # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
            # can write learning rate summaries on TPU without host calls.
            global_step = tf.train.get_or_create_global_step()
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)

        if mode == tf.estimator.ModeKeys.TRAIN:
            if use_tpu:
                training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
                    training_optimizer)

            # Optionally freeze some layers by setting their gradients to be zero.
            trainable_variables = None
            # Empty pattern settings are turned into None before filtering.
            include_variables = (train_config.update_trainable_variables
                                 if train_config.update_trainable_variables
                                 else None)
            exclude_variables = (train_config.freeze_variables
                                 if train_config.freeze_variables else None)
            trainable_variables = tf.contrib.framework.filter_variables(
                tf.trainable_variables(),
                include_patterns=include_variables,
                exclude_patterns=exclude_variables)

            # Gradient clipping is only enabled for a positive norm setting.
            clip_gradients_value = None
            if train_config.gradient_clipping_by_norm > 0:
                clip_gradients_value = train_config.gradient_clipping_by_norm

            if not use_tpu:
                for var in optimizer_summary_vars:
                    tf.summary.scalar(var.op.name, var)
            # An empty summaries list is passed on TPU and None otherwise
            # (presumably None selects optimize_loss's defaults -- confirm).
            summaries = [] if use_tpu else None
            train_op = tf.contrib.layers.optimize_loss(
                loss=total_loss,
                global_step=global_step,
                learning_rate=None,
                clip_gradients=clip_gradients_value,
                optimizer=training_optimizer,
                variables=trainable_variables,
                summaries=summaries,
                name='')  # Preventing scope prefix on all variables.

        if mode == tf.estimator.ModeKeys.PREDICT:
            export_outputs = {
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput(detections)
            }

        eval_metric_ops = None
        scaffold = None
        if mode == tf.estimator.ModeKeys.EVAL:
            # The model is treated as class agnostic when the detections
            # carry no class field.
            class_agnostic = (fields.DetectionResultFields.detection_classes
                              not in detections)
            groundtruth = _prepare_groundtruth_for_eval(
                detection_model, class_agnostic)
            use_original_images = fields.InputDataFields.original_image in features
            if use_original_images:
                # Only the first image of the batch is resized and used.
                eval_images = tf.cast(
                    tf.image.resize_bilinear(
                        features[fields.InputDataFields.original_image][0:1],
                        features[fields.InputDataFields.
                                 original_image_spatial_shape][0]), tf.uint8)
            else:
                eval_images = features[fields.InputDataFields.image]

            eval_dict = eval_util.result_dict_for_single_example(
                eval_images[0:1],
                features[inputs.HASH_KEY][0],
                detections,
                groundtruth,
                class_agnostic=class_agnostic,
                scale_to_absolute=True)

            if class_agnostic:
                category_index = label_map_util.create_class_agnostic_category_index(
                )
            else:
                category_index = label_map_util.create_category_index_from_labelmap(
                    eval_input_config.label_map_path)
            # Visualization metrics are only built off-TPU and when original
            # images are available.
            vis_metric_ops = None
            if not use_tpu and use_original_images:
                eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
                    category_index,
                    max_examples_to_draw=eval_config.num_visualizations,
                    max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
                    min_score_thresh=eval_config.min_score_threshold,
                    use_normalized_coordinates=False)
                vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
                    eval_dict)

            # Eval metrics on a single example.
            eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
                eval_config, category_index.values(), eval_dict)
            for loss_key, loss_tensor in iter(losses_dict.items()):
                eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
            # Optimizer summary variables are exposed as metrics whose update
            # op is a no-op.
            for var in optimizer_summary_vars:
                eval_metric_ops[var.op.name] = (var, tf.no_op())
            if vis_metric_ops is not None:
                eval_metric_ops.update(vis_metric_ops)
            # All metric keys are coerced to `str`.
            eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

            if eval_config.use_moving_averages:
                # Build the saver from the moving-average restore mapping and
                # pass it to the estimator through a Scaffold.
                variable_averages = tf.train.ExponentialMovingAverage(0.0)
                variables_to_restore = variable_averages.variables_to_restore()
                keep_checkpoint_every_n_hours = (
                    train_config.keep_checkpoint_every_n_hours)
                saver = tf.train.Saver(
                    variables_to_restore,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours
                )
                scaffold = tf.train.Scaffold(saver=saver)

        # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
        if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                scaffold_fn=scaffold_fn,
                predictions=detections,
                loss=total_loss,
                train_op=train_op,
                eval_metrics=eval_metric_ops,
                export_outputs=export_outputs)
        else:
            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=detections,
                                              loss=total_loss,
                                              train_op=train_op,
                                              eval_metric_ops=eval_metric_ops,
                                              export_outputs=export_outputs,
                                              scaffold=scaffold)