def test_write_graph_and_checkpoint(self):
    """Export a graph plus checkpoint, re-import it, and verify the outputs.

    A mock detection model (keypoints and masks enabled) is built into the
    default graph and written out with `exporter.write_graph_and_checkpoint`.
    The emitted `.meta` graph is then imported into a fresh graph, the
    exported checkpoint is restored into it, and two serialized tf.Example
    protos are fed through; every detection output is compared against
    fixed expected values.
    """
    scratch_dir = self.get_temp_dir()
    source_prefix = os.path.join(scratch_dir, 'model.ckpt')
    self._save_checkpoint_from_mock_model(source_prefix,
                                          use_moving_averages=False)
    export_dir = os.path.join(scratch_dir, 'output')
    export_prefix = os.path.join(export_dir, 'model.ckpt')
    tf.gfile.MakeDirs(export_dir)

    with mock.patch.object(model_builder, 'build',
                           autospec=True) as build_stub:
        build_stub.return_value = FakeModel(add_detection_keypoints=True,
                                            add_detection_masks=True)
        config = pipeline_pb2.TrainEvalPipelineConfig()
        config.eval_config.use_moving_averages = False
        model = model_builder.build(config.model, is_training=False)
        exporter._build_detection_graph(
            input_type='tf_example',
            detection_model=model,
            input_shape=None,
            output_collection_name='inference_op',
            graph_hook_fn=None)
        # The Saver is built only after the detection graph exists; with no
        # var_list it covers the variables present at construction time.
        saver_def = tf.train.Saver().as_saver_def()
        exporter.write_graph_and_checkpoint(
            inference_graph_def=tf.get_default_graph().as_graph_def(),
            model_path=export_prefix,
            input_saver_def=saver_def,
            trained_checkpoint_prefix=source_prefix)

    serialized = self._create_tf_example(np.ones((4, 4, 3)).astype(np.uint8))
    two_examples = np.hstack([serialized, serialized])
    output_names = ('detection_boxes', 'detection_scores',
                    'detection_classes', 'detection_keypoints',
                    'detection_masks', 'num_detections')

    with tf.Graph().as_default() as restored_graph:
        with self.test_session(graph=restored_graph) as sess:
            loader = tf.train.import_meta_graph(export_prefix + '.meta')
            loader.restore(sess, export_prefix)
            example_input = restored_graph.get_tensor_by_name('tf_example:0')
            fetches = {name: restored_graph.get_tensor_by_name(name + ':0')
                       for name in output_names}
            got = sess.run(fetches, feed_dict={example_input: two_examples})

            self.assertAllClose(got['detection_boxes'],
                                [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]],
                                 [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]])
            self.assertAllClose(got['detection_scores'],
                                [[0.7, 0.6], [0.9, 0.0]])
            self.assertAllClose(got['detection_classes'], [[1, 2], [2, 1]])
            self.assertAllClose(got['detection_keypoints'],
                                np.arange(48).reshape([2, 2, 6, 2]))
            self.assertAllClose(got['detection_masks'],
                                np.arange(64).reshape([2, 2, 4, 4]))
            self.assertAllClose(got['num_detections'], [2, 1])
Example #2
0
  def test_write_frozen_graph(self):
    """Freeze a mock detection graph to a .pb file and verify its outputs.

    The checkpoint is produced by the mock model with moving averages, the
    mock model (masks enabled) is frozen against it via
    `exporter.freeze_graph_with_def_protos`, written with
    `exporter.write_frozen_graph`, reloaded through `_load_inference_graph`
    and run on one serialized tf.Example expanded to a batch axis.
    NOTE(review): the expectations describe a batch of two even though one
    example is fed — presumably FakeModel emits fixed outputs; confirm.
    """
    scratch_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(scratch_dir, 'model.ckpt')
    self._save_checkpoint_from_mock_model(ckpt_prefix, use_moving_averages=True)
    export_dir = os.path.join(scratch_dir, 'output')
    frozen_path = os.path.join(export_dir, 'frozen_inference_graph.pb')
    tf.gfile.MakeDirs(export_dir)
    with mock.patch.object(model_builder, 'build', autospec=True) as build_stub:
      build_stub.return_value = FakeModel(add_detection_masks=True)
      config = pipeline_pb2.TrainEvalPipelineConfig()
      config.eval_config.use_moving_averages = False
      model = model_builder.build(config.model, is_training=False)
      output_tensors, _ = exporter._build_detection_graph(
          input_type='tf_example',
          detection_model=model,
          input_shape=None,
          output_collection_name='inference_op',
          graph_hook_fn=None)
      saver_def = tf.train.Saver().as_saver_def()
      frozen_def = exporter.freeze_graph_with_def_protos(
          input_graph_def=tf.get_default_graph().as_graph_def(),
          input_saver_def=saver_def,
          input_checkpoint=ckpt_prefix,
          output_node_names=','.join(output_tensors.keys()),
          restore_op_name='save/restore_all',
          filename_tensor_name='save/Const:0',
          clear_devices=True,
          initializer_nodes='')
      exporter.write_frozen_graph(frozen_path, frozen_def)

    frozen_graph = self._load_inference_graph(frozen_path)
    one_example = np.expand_dims(
        self._create_tf_example(np.ones((4, 4, 3)).astype(np.uint8)), axis=0)
    fetch_names = ('detection_boxes', 'detection_scores', 'detection_classes',
                   'detection_masks', 'num_detections')
    with self.test_session(graph=frozen_graph) as sess:
      example_input = frozen_graph.get_tensor_by_name('tf_example:0')
      fetches = {name: frozen_graph.get_tensor_by_name(name + ':0')
                 for name in fetch_names}
      got = sess.run(fetches, feed_dict={example_input: one_example})
      self.assertAllClose(got['detection_boxes'],
                          [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]],
                           [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]])
      self.assertAllClose(got['detection_scores'], [[0.7, 0.6], [0.9, 0.0]])
      self.assertAllClose(got['detection_classes'], [[1, 2], [2, 1]])
      self.assertAllClose(got['detection_masks'],
                          np.arange(64).reshape([2, 2, 4, 4]))
      self.assertAllClose(got['num_detections'], [2, 1])
    def test_write_frozen_graph(self):
        """Round-trip a mock detector through a frozen GraphDef on disk.

        A checkpoint is saved from the mock model (with moving averages), the
        mock model with masks is frozen against it and written as
        `frozen_inference_graph.pb`, then the file is reloaded and queried
        with one serialized tf.Example expanded to a batch axis.
        NOTE(review): expectations are for a batch of two despite a single
        input — presumably FakeModel returns fixed tensors; confirm.
        """
        work_dir = self.get_temp_dir()
        checkpoint_prefix = os.path.join(work_dir, 'model.ckpt')
        self._save_checkpoint_from_mock_model(checkpoint_prefix,
                                              use_moving_averages=True)
        out_dir = os.path.join(work_dir, 'output')
        pb_path = os.path.join(out_dir, 'frozen_inference_graph.pb')
        tf.gfile.MakeDirs(out_dir)

        with mock.patch.object(model_builder, 'build',
                               autospec=True) as patched_build:
            patched_build.return_value = FakeModel(add_detection_masks=True)
            pipeline = pipeline_pb2.TrainEvalPipelineConfig()
            pipeline.eval_config.use_moving_averages = False
            fake_model = model_builder.build(pipeline.model, is_training=False)
            graph_outputs, _ = exporter._build_detection_graph(
                input_type='tf_example',
                detection_model=fake_model,
                input_shape=None,
                output_collection_name='inference_op',
                graph_hook_fn=None)
            node_names = ','.join(graph_outputs.keys())
            saver_def = tf.train.Saver().as_saver_def()
            graph_def = tf.get_default_graph().as_graph_def()
            frozen = exporter.freeze_graph_with_def_protos(
                input_graph_def=graph_def,
                input_saver_def=saver_def,
                input_checkpoint=checkpoint_prefix,
                output_node_names=node_names,
                restore_op_name='save/restore_all',
                filename_tensor_name='save/Const:0',
                clear_devices=True,
                initializer_nodes='')
            exporter.write_frozen_graph(pb_path, frozen)

        graph = self._load_inference_graph(pb_path)
        image = np.ones((4, 4, 3)).astype(np.uint8)
        examples = np.expand_dims(self._create_tf_example(image), axis=0)
        # (output tensor name, expected value) in assertion order.
        expectations = [
            ('detection_boxes',
             [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]],
              [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]]),
            ('detection_scores', [[0.7, 0.6], [0.9, 0.0]]),
            ('detection_classes', [[1, 2], [2, 1]]),
            ('detection_masks', np.arange(64).reshape([2, 2, 4, 4])),
            ('num_detections', [2, 1]),
        ]
        with self.test_session(graph=graph) as sess:
            feed = {graph.get_tensor_by_name('tf_example:0'): examples}
            tensors = [graph.get_tensor_by_name(name + ':0')
                       for name, _ in expectations]
            values = sess.run(tensors, feed_dict=feed)
            for (_, expected), actual in zip(expectations, values):
                self.assertAllClose(actual, expected)
Example #4
0
  def test_write_saved_model(self):
    """Export a SavedModel from a frozen mock graph and query its signature.

    The mock model (masks enabled) is frozen against a checkpoint saved
    without moving averages, exported with `exporter.write_saved_model`,
    loaded back under the SERVING tag, and run through the
    'serving_default' signature on two serialized tf.Example protos.
    """
    scratch_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(scratch_dir, 'model.ckpt')
    self._save_checkpoint_from_mock_model(ckpt_prefix, use_moving_averages=False)
    export_dir = os.path.join(scratch_dir, 'output')
    export_path = os.path.join(export_dir, 'saved_model')
    tf.gfile.MakeDirs(export_dir)
    with mock.patch.object(model_builder, 'build', autospec=True) as build_stub:
      build_stub.return_value = FakeModel(add_detection_masks=True)
      config = pipeline_pb2.TrainEvalPipelineConfig()
      config.eval_config.use_moving_averages = False
      model = model_builder.build(config.model, is_training=False)
      output_tensors, input_placeholder = exporter._build_detection_graph(
          input_type='tf_example',
          detection_model=model,
          input_shape=None,
          output_collection_name='inference_op',
          graph_hook_fn=None)
      saver_def = tf.train.Saver().as_saver_def()
      frozen_def = exporter.freeze_graph_with_def_protos(
          input_graph_def=tf.get_default_graph().as_graph_def(),
          input_saver_def=saver_def,
          input_checkpoint=ckpt_prefix,
          output_node_names=','.join(output_tensors.keys()),
          restore_op_name='save/restore_all',
          filename_tensor_name='save/Const:0',
          clear_devices=True,
          initializer_nodes='')
      exporter.write_saved_model(
          saved_model_path=export_path,
          frozen_graph_def=frozen_def,
          inputs=input_placeholder,
          outputs=output_tensors)

    serialized = self._create_tf_example(np.ones((4, 4, 3)).astype(np.uint8))
    two_examples = np.hstack([serialized, serialized])
    fetch_keys = ('detection_boxes', 'detection_scores', 'detection_classes',
                  'detection_masks', 'num_detections')
    with tf.Graph().as_default() as loaded_graph:
      with self.test_session(graph=loaded_graph) as sess:
        meta_graph = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], export_path)
        signature = meta_graph.signature_def['serving_default']
        example_input = loaded_graph.get_tensor_by_name(
            signature.inputs['inputs'].name)
        fetches = {
            key: loaded_graph.get_tensor_by_name(signature.outputs[key].name)
            for key in fetch_keys
        }
        got = sess.run(fetches, feed_dict={example_input: two_examples})
        self.assertAllClose(got['detection_boxes'],
                            [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]],
                             [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]])
        self.assertAllClose(got['detection_scores'], [[0.7, 0.6], [0.9, 0.0]])
        self.assertAllClose(got['detection_classes'], [[1, 2], [2, 1]])
        self.assertAllClose(got['detection_masks'],
                            np.arange(64).reshape([2, 2, 4, 4]))
        self.assertAllClose(got['num_detections'], [2, 1])
    def test_write_saved_model(self):
        """Export a SavedModel with keypoints and masks, then run its signature.

        A checkpoint is saved from the mock model without moving averages; the
        mock model (keypoints and masks enabled) is frozen against it and
        exported with `exporter.write_saved_model`. The export is loaded under
        the SERVING tag and every 'serving_default' output is checked after
        feeding two serialized tf.Example protos.
        """
        work_dir = self.get_temp_dir()
        checkpoint_prefix = os.path.join(work_dir, 'model.ckpt')
        self._save_checkpoint_from_mock_model(checkpoint_prefix,
                                              use_moving_averages=False)
        out_dir = os.path.join(work_dir, 'output')
        export_path = os.path.join(out_dir, 'saved_model')
        tf.gfile.MakeDirs(out_dir)

        with mock.patch.object(model_builder, 'build',
                               autospec=True) as patched_build:
            patched_build.return_value = FakeModel(add_detection_keypoints=True,
                                                   add_detection_masks=True)
            pipeline = pipeline_pb2.TrainEvalPipelineConfig()
            pipeline.eval_config.use_moving_averages = False
            fake_model = model_builder.build(pipeline.model, is_training=False)
            graph_outputs, graph_input = exporter._build_detection_graph(
                input_type='tf_example',
                detection_model=fake_model,
                input_shape=None,
                output_collection_name='inference_op',
                graph_hook_fn=None)
            node_names = ','.join(graph_outputs.keys())
            saver_def = tf.train.Saver().as_saver_def()
            graph_def = tf.get_default_graph().as_graph_def()
            frozen = exporter.freeze_graph_with_def_protos(
                input_graph_def=graph_def,
                input_saver_def=saver_def,
                input_checkpoint=checkpoint_prefix,
                output_node_names=node_names,
                restore_op_name='save/restore_all',
                filename_tensor_name='save/Const:0',
                clear_devices=True,
                initializer_nodes='')
            exporter.write_saved_model(saved_model_path=export_path,
                                       frozen_graph_def=frozen,
                                       inputs=graph_input,
                                       outputs=graph_outputs)

        image = np.ones((4, 4, 3)).astype(np.uint8)
        examples = np.hstack([self._create_tf_example(image)] * 2)
        # (signature output key, expected value) in assertion order.
        expectations = [
            ('detection_boxes',
             [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]],
              [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]]),
            ('detection_scores', [[0.7, 0.6], [0.9, 0.0]]),
            ('detection_classes', [[1, 2], [2, 1]]),
            ('detection_keypoints', np.arange(48).reshape([2, 2, 6, 2])),
            ('detection_masks', np.arange(64).reshape([2, 2, 4, 4])),
            ('num_detections', [2, 1]),
        ]
        with tf.Graph().as_default() as graph:
            with self.test_session(graph=graph) as sess:
                meta_graph = tf.saved_model.loader.load(
                    sess, [tf.saved_model.tag_constants.SERVING], export_path)
                signature = meta_graph.signature_def['serving_default']
                input_tensor = graph.get_tensor_by_name(
                    signature.inputs['inputs'].name)
                tensors = [
                    graph.get_tensor_by_name(signature.outputs[key].name)
                    for key, _ in expectations
                ]
                values = sess.run(tensors,
                                  feed_dict={input_tensor: examples})
                for (_, expected), actual in zip(expectations, values):
                    self.assertAllClose(actual, expected)
def _parse_input_shape(spec):
    """Parse a comma-separated shape flag such as ``'-1,300,300,3'``.

    Args:
      spec: The shape string. Each entry is an integer; ``-1`` marks an
        unknown dimension. Surrounding whitespace per entry is ignored.

    Returns:
      A list with one entry per dimension (``None`` for ``-1``), or ``None``
      when ``spec`` is empty/None.
    """
    if not spec:
        return None
    dims = [dim.strip() for dim in spec.split(',')]
    return [None if dim == '-1' else int(dim) for dim in dims]


def _find_latest_checkpoint_prefix(checkpoint_dir):
    """Return the prefix of the highest-numbered ``model.ckpt-N`` in a dir.

    Only files whose *entire* name is ``model.ckpt-<N>.index`` are counted,
    so stray or temporary files cannot yield a prefix that does not exist.

    Args:
      checkpoint_dir: Local directory to scan.

    Returns:
      ``os.path.join(checkpoint_dir, 'model.ckpt-<max N>')``.

    Raises:
      SystemExit: with status 1 if no checkpoint index file is found.
    """
    pattern = re.compile(r"model\.ckpt-([0-9]+)\.index\Z")
    matches = (pattern.match(name) for name in os.listdir(checkpoint_dir))
    steps = [int(m.group(1)) for m in matches if m]
    if not steps:
        # Non-zero exit status so calling scripts can detect the failure.
        raise SystemExit('No checkpoint found!')
    return os.path.join(checkpoint_dir, 'model.ckpt-{}'.format(max(steps)))


def main(_):
    """Export the trained detector, its anchors and its label list.

    1. Load the pipeline config from ``FLAGS.result_base`` and merge
       ``FLAGS.config_override`` into it.
    2. Resolve the checkpoint: ``FLAGS.trained_checkpoint_prefix`` if set,
       otherwise the highest-numbered ``model.ckpt-N`` under
       ``FLAGS.result_base/FLAGS.trained_checkpoint_path``.
    3. Recreate ``FLAGS.model_dir`` and export the inference graph into it.
    4. Rebuild the detection graph, evaluate ``detection_model.anchors`` and
       hand the result to ``output_anchors_as_swift``.
    5. Write the label-map keys, sorted by their label-map value, as a JSON
       list to ``FLAGS.model_dir/FLAGS.output_label_path``.
    """
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    config_path = os.path.join(FLAGS.result_base, FLAGS.pipeline_config_path)
    with tf.gfile.GFile(config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    # Apply command-line overrides on top of the file contents.
    text_format.Merge(FLAGS.config_override, pipeline_config)
    input_shape = _parse_input_shape(FLAGS.input_shape)

    # Resolve the checkpoint before deleting model_dir, so a missing
    # checkpoint does not wipe out a previous export.
    if FLAGS.trained_checkpoint_prefix:
        trained_checkpoint_prefix = FLAGS.trained_checkpoint_prefix
    else:
        trained_checkpoint_prefix = _find_latest_checkpoint_prefix(
            os.path.join(FLAGS.result_base, FLAGS.trained_checkpoint_path))

    # Start the export from an empty directory.
    if os.path.isdir(FLAGS.model_dir):
        shutil.rmtree(FLAGS.model_dir)

    exporter.export_inference_graph(
        FLAGS.input_type,
        pipeline_config,
        trained_checkpoint_prefix,
        FLAGS.model_dir,
        input_shape=input_shape,
        write_inference_graph=FLAGS.write_inference_graph)

    # Rebuild the detection graph in a fresh default graph so the model's
    # anchors tensor can be evaluated.
    tf.reset_default_graph()
    detection_model = model_builder.build(pipeline_config.model,
                                          is_training=False)
    exporter._build_detection_graph(input_type=FLAGS.input_type,
                                    detection_model=detection_model,
                                    input_shape=input_shape,
                                    output_collection_name='inference_op',
                                    graph_hook_fn=None)

    with tf.Session() as sess:
        anchors = detection_model.anchors.get().eval(session=sess)
        output_anchors_as_swift(anchors)

    label_map = get_label_map_dict(
        os.path.join(FLAGS.result_base, FLAGS.label_map_path))
    # Keys ordered by their mapped value (presumably the class id — confirm
    # against get_label_map_dict).
    label_array = sorted(label_map, key=label_map.get)
    label_path = os.path.join(FLAGS.model_dir, FLAGS.output_label_path)
    with open(label_path, 'w') as f:
        json.dump(label_array, f)