def test_write_graph_and_checkpoint(self):
    """Round-trip check for `exporter.write_graph_and_checkpoint`.

    A checkpoint is first saved from the mock model (without moving
    averages). A detection graph is then built with `model_builder.build`
    patched to return a `FakeModel` that emits keypoints and masks, and the
    graph plus checkpoint are written under `<tmp>/output/model.ckpt`.
    Finally the written meta graph is imported into a fresh graph, the
    checkpoint is restored, and the detection outputs for a batch of two
    identical tf.Examples are compared against fixed expected values.
    """
    work_dir = self.get_temp_dir()
    source_ckpt = os.path.join(work_dir, 'model.ckpt')
    self._save_checkpoint_from_mock_model(
        source_ckpt, use_moving_averages=False)

    export_dir = os.path.join(work_dir, 'output')
    exported_ckpt = os.path.join(export_dir, 'model.ckpt')
    tf.gfile.MakeDirs(export_dir)

    with mock.patch.object(
            model_builder, 'build', autospec=True) as patched_build:
        patched_build.return_value = FakeModel(
            add_detection_keypoints=True, add_detection_masks=True)
        config = pipeline_pb2.TrainEvalPipelineConfig()
        config.eval_config.use_moving_averages = False
        fake_model = model_builder.build(config.model, is_training=False)
        exporter._build_detection_graph(
            input_type='tf_example',
            detection_model=fake_model,
            input_shape=None,
            output_collection_name='inference_op',
            graph_hook_fn=None)
        saver_def = tf.train.Saver().as_saver_def()
        exporter.write_graph_and_checkpoint(
            inference_graph_def=tf.get_default_graph().as_graph_def(),
            model_path=exported_ckpt,
            input_saver_def=saver_def,
            trained_checkpoint_prefix=source_ckpt)

    # The same serialized example is repeated to form a batch of two.
    serialized_example = self._create_tf_example(
        np.ones((4, 4, 3)).astype(np.uint8))
    example_batch = np.hstack([serialized_example] * 2)

    with tf.Graph().as_default() as restored_graph, \
            self.test_session(graph=restored_graph) as sess:
        restorer = tf.train.import_meta_graph(exported_ckpt + '.meta')
        restorer.restore(sess, exported_ckpt)

        input_tensor = restored_graph.get_tensor_by_name('tf_example:0')
        # Fetch order matters: it is unpacked positionally below.
        fetch_names = (
            'detection_boxes:0',
            'detection_scores:0',
            'detection_classes:0',
            'detection_keypoints:0',
            'detection_masks:0',
            'num_detections:0',
        )
        fetches = [
            restored_graph.get_tensor_by_name(name) for name in fetch_names
        ]
        (boxes_out, scores_out, classes_out, keypoints_out, masks_out,
         num_detections_out) = sess.run(
             fetches, feed_dict={input_tensor: example_batch})

        expected_boxes = [[[0.0, 0.0, 0.5, 0.5],
                           [0.5, 0.5, 0.8, 0.8]],
                          [[0.5, 0.5, 1.0, 1.0],
                           [0.0, 0.0, 0.0, 0.0]]]
        expected_scores = [[0.7, 0.6],
                           [0.9, 0.0]]
        expected_classes = [[1, 2],
                            [2, 1]]
        self.assertAllClose(boxes_out, expected_boxes)
        self.assertAllClose(scores_out, expected_scores)
        self.assertAllClose(classes_out, expected_classes)
        self.assertAllClose(keypoints_out,
                            np.arange(48).reshape([2, 2, 6, 2]))
        self.assertAllClose(masks_out, np.arange(64).reshape([2, 2, 4, 4]))
        self.assertAllClose(num_detections_out, [2, 1])
# Example #2
0
def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            additional_output_tensor_names=None,
                            input_shape=None,
                            output_collection_name='inference_op',
                            graph_hook_fn=None,
                            write_inference_graph=False,
                            temp_checkpoint_prefix=''):
    """Build the detection graph and export a checkpoint and a SavedModel.

    The function creates `output_directory`, builds the detection graph in
    the default TensorFlow graph, drops several optional detection outputs,
    writes the graph and checkpoint to `<output_directory>/model.ckpt`,
    optionally dumps the graph as text, and finally writes a SavedModel to
    `<output_directory>/saved_model/00001`.

    Args:
        input_type: Forwarded to `build_detection_graph`.
        detection_model: Forwarded to `build_detection_graph`.
        use_moving_averages: If true, variable values are replaced with
            their moving averages and written to a temporary checkpoint,
            which is then the one passed to `write_graph_and_checkpoint`.
        trained_checkpoint_prefix: Path prefix of the trained checkpoint.
        output_directory: Directory that receives all exported artifacts.
        additional_output_tensor_names: Optional iterable of extra tensor
            names appended to the output node name list.
        input_shape: Forwarded to `build_detection_graph`.
        output_collection_name: Forwarded to `build_detection_graph`.
        graph_hook_fn: Forwarded to `build_detection_graph`.
        write_inference_graph: If true, also write
            `inference_graph.pbtxt` (with node devices cleared) into
            `output_directory`.
        temp_checkpoint_prefix: Where to write the moving-average
            checkpoint. When empty, a temporary location is generated.
    """
    tf.gfile.MakeDirs(output_directory)
    saved_model_path = os.path.join(output_directory, 'saved_model', '00001')
    model_path = os.path.join(output_directory, 'model.ckpt')

    outputs, placeholder_tensor = build_detection_graph(
        input_type=input_type,
        detection_model=detection_model,
        input_shape=input_shape,
        output_collection_name=output_collection_name,
        graph_hook_fn=graph_hook_fn)

    # OpenTPOD: popping unnecessary outputs for object detection inference.
    # see
    # https://github.com/tensorflow/models/blob/master/research/object_detection/core/standard_fields.py
    result_fields = fields.DetectionResultFields
    for unneeded_key in (
            result_fields.detection_multiclass_scores,
            result_fields.detection_features,
            result_fields.detection_masks,
            result_fields.detection_boundaries,
            result_fields.detection_keypoints,
            result_fields.raw_detection_boxes,
            result_fields.raw_detection_scores,
            result_fields.detection_anchor_indices,
    ):
        # A default of None makes the pop a no-op when the key is absent.
        outputs.pop(unneeded_key, None)

    profile_inference_graph(tf.get_default_graph())
    saver_kwargs = {}
    if use_moving_averages:
        if not temp_checkpoint_prefix:
            # This check is to be compatible with both versions of SaverDef:
            # a prefix that is itself a file selects the V1 write format.
            if os.path.isfile(trained_checkpoint_prefix):
                saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
                # Only the generated name is used; the temporary file object
                # itself is discarded immediately.
                temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
            else:
                temp_checkpoint_prefix = tempfile.mkdtemp()
        replace_variable_values_with_moving_averages(
            tf.get_default_graph(), trained_checkpoint_prefix,
            temp_checkpoint_prefix)
        checkpoint_to_use = temp_checkpoint_prefix
    else:
        checkpoint_to_use = trained_checkpoint_prefix

    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()

    write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=model_path,
        input_saver_def=input_saver_def,
        trained_checkpoint_prefix=checkpoint_to_use)
    if write_inference_graph:
        inference_graph_def = tf.get_default_graph().as_graph_def()
        inference_graph_path = os.path.join(output_directory,
                                            'inference_graph.pbtxt')
        # Clear device placements so the text graph is device-agnostic.
        for node in inference_graph_def.node:
            node.device = ''
        with tf.gfile.GFile(inference_graph_path, 'wb') as f:
            f.write(str(inference_graph_def))

    # Materialize the keys as a list first: on Python 3 `dict.keys()` is a
    # view and cannot be concatenated to a list with `+`.
    output_names = list(outputs)
    if additional_output_tensor_names is not None:
        output_names.extend(additional_output_tensor_names)
    # NOTE(review): output_node_names is not consumed by anything below in
    # this function; confirm whether it can be dropped.
    output_node_names = ','.join(output_names)

    # NOTE(review): this passes the original trained checkpoint even when
    # moving averages were written to `checkpoint_to_use` above -- confirm
    # against write_saved_model that this is intended.
    write_saved_model(saved_model_path, trained_checkpoint_prefix,
                      placeholder_tensor, outputs)