Beispiel #1
0
  def test_write_frozen_graph(self):
    """Exports a mock detector as a frozen graph and checks its detections.

    A checkpoint is saved from the mock model. A `tf_example`-input detection
    graph is then built around a patched `model_builder.build`, frozen against
    that checkpoint and written with `exporter.write_frozen_graph`. The
    written graph is reloaded and fed one example built from a 4x4x3 uint8
    image of ones. Every detection output is compared with hard-coded
    expected values.
    """
    scratch_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(scratch_dir, 'model.ckpt')
    # NOTE(review): the checkpoint is saved with moving averages enabled while
    # the eval config below disables them; presumably intentional -- confirm.
    self._save_checkpoint_from_mock_model(
        ckpt_prefix, use_moving_averages=True)

    export_dir = os.path.join(scratch_dir, 'output')
    frozen_graph_path = os.path.join(export_dir, 'frozen_inference_graph.pb')
    tf.gfile.MakeDirs(export_dir)

    # Route model construction to a FakeModel built with
    # add_detection_masks=True.
    with mock.patch.object(
        model_builder, 'build', autospec=True) as patched_build:
      patched_build.return_value = FakeModel(add_detection_masks=True)
      config = pipeline_pb2.TrainEvalPipelineConfig()
      config.eval_config.use_moving_averages = False
      fake_detector = model_builder.build(config.model, is_training=False)
      detection_outputs, _ = exporter._build_detection_graph(
          input_type='tf_example',
          detection_model=fake_detector,
          input_shape=None,
          output_collection_name='inference_op',
          graph_hook_fn=None)
      node_names = ','.join(detection_outputs.keys())

      # Created only after the graph is built: with no var_list, a
      # tf.train.Saver collects the graph's variables at construction time.
      checkpoint_saver = tf.train.Saver()
      saver_def = checkpoint_saver.as_saver_def()
      graph_def = tf.get_default_graph().as_graph_def()
      frozen_graph_def = exporter.freeze_graph_with_def_protos(
          input_graph_def=graph_def,
          input_saver_def=saver_def,
          input_checkpoint=ckpt_prefix,
          output_node_names=node_names,
          # Presumably the default op names of the Saver above ('save' scope).
          restore_op_name='save/restore_all',
          filename_tensor_name='save/Const:0',
          clear_devices=True,
          initializer_nodes='')
      exporter.write_frozen_graph(frozen_graph_path, frozen_graph_def)

    inference_graph = self._load_inference_graph(frozen_graph_path)
    ones_image = np.ones((4, 4, 3), dtype=np.uint8)
    # Prepend a batch axis of size 1. The helper's return value is presumably
    # a serialized tf.Example -- confirm against _create_tf_example.
    example_batch = np.expand_dims(self._create_tf_example(ones_image), axis=0)

    expected_boxes = [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]],
                      [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]]
    expected_scores = [[0.7, 0.6], [0.9, 0.0]]
    expected_classes = [[1, 2], [2, 1]]
    expected_masks = np.arange(64).reshape([2, 2, 4, 4])
    expected_num_detections = [2, 1]

    with self.test_session(graph=inference_graph) as sess:
      example_input = inference_graph.get_tensor_by_name('tf_example:0')
      output_tensors = [
          inference_graph.get_tensor_by_name(tensor_name)
          for tensor_name in ('detection_boxes:0', 'detection_scores:0',
                              'detection_classes:0', 'detection_masks:0',
                              'num_detections:0')
      ]
      got_boxes, got_scores, got_classes, got_masks, got_num = sess.run(
          output_tensors, feed_dict={example_input: example_batch})
      self.assertAllClose(got_boxes, expected_boxes)
      self.assertAllClose(got_scores, expected_scores)
      self.assertAllClose(got_classes, expected_classes)
      self.assertAllClose(got_masks, expected_masks)
      self.assertAllClose(got_num, expected_num_detections)
    def test_write_frozen_graph(self):
        """Freezes a patched FakeModel's detection graph and verifies it.

        Saves a checkpoint from the mock model and builds a `tf_example`-input
        detection graph. The graph is frozen against the checkpoint, written
        with `exporter.write_frozen_graph` and reloaded. Each detection output
        for one 4x4x3 uint8 all-ones example is compared with fixed values.
        """
        # NOTE(review): at this 4-space indentation this def sits inside the
        # body of the 2-space-indented method above it, so it is a nested
        # function that is never called. Confirm whether it was meant to be a
        # separate test; a second method with the same name would shadow the
        # first.
        base_dir = self.get_temp_dir()
        checkpoint_prefix = os.path.join(base_dir, 'model.ckpt')
        self._save_checkpoint_from_mock_model(checkpoint_prefix,
                                              use_moving_averages=True)
        out_dir = os.path.join(base_dir, 'output')
        graph_path = os.path.join(out_dir, 'frozen_inference_graph.pb')
        tf.gfile.MakeDirs(out_dir)

        with mock.patch.object(model_builder, 'build',
                               autospec=True) as build_mock:
            # model_builder.build calls inside this block return this
            # FakeModel instance.
            build_mock.return_value = FakeModel(add_detection_masks=True)
            train_eval_config = pipeline_pb2.TrainEvalPipelineConfig()
            train_eval_config.eval_config.use_moving_averages = False
            built_model = model_builder.build(train_eval_config.model,
                                              is_training=False)
            graph_outputs, _ = exporter._build_detection_graph(
                input_type='tf_example',
                detection_model=built_model,
                input_shape=None,
                output_collection_name='inference_op',
                graph_hook_fn=None)
            saver = tf.train.Saver()
            saver_def = saver.as_saver_def()
            frozen_def = exporter.freeze_graph_with_def_protos(
                input_graph_def=tf.get_default_graph().as_graph_def(),
                input_saver_def=saver_def,
                input_checkpoint=checkpoint_prefix,
                output_node_names=','.join(graph_outputs.keys()),
                restore_op_name='save/restore_all',
                filename_tensor_name='save/Const:0',
                clear_devices=True,
                initializer_nodes='')
            exporter.write_frozen_graph(graph_path, frozen_def)

        loaded_graph = self._load_inference_graph(graph_path)
        example_batch = np.expand_dims(
            self._create_tf_example(np.ones((4, 4, 3), dtype=np.uint8)),
            axis=0)
        # Output tensors to fetch, paired positionally with `expected` below.
        fetch_names = ('detection_boxes:0', 'detection_scores:0',
                       'detection_classes:0', 'detection_masks:0',
                       'num_detections:0')
        expected = (
            [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]],
             [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]],
            [[0.7, 0.6], [0.9, 0.0]],
            [[1, 2], [2, 1]],
            np.arange(64).reshape([2, 2, 4, 4]),
            [2, 1],
        )
        with self.test_session(graph=loaded_graph) as sess:
            example_input = loaded_graph.get_tensor_by_name('tf_example:0')
            fetches = [loaded_graph.get_tensor_by_name(name)
                       for name in fetch_names]
            results = sess.run(fetches,
                               feed_dict={example_input: example_batch})
            for actual, wanted in zip(results, expected):
                self.assertAllClose(actual, wanted)