コード例 #1
0
 def graph_fn():
   """Build a six-box BoxList and prune it against a fixed window.

   The box list carries an 'extra_data' field with one row per box so the
   caller can check that the field is pruned together with the boxes.

   Returns:
     A 3-tuple of (pruned box tensor, pruned 'extra_data' field tensor,
     indices tensor returned by box_list_ops.prune_outside_window).
   """
   clip_window = tf.constant([0, 0, 9, 14], tf.float32)
   box_corners = tf.constant([
       [5.0, 5.0, 6.0, 6.0],
       [-1.0, -2.0, 4.0, 5.0],
       [2.0, 3.0, 5.0, 9.0],
       [0.0, 0.0, 9.0, 14.0],
       [-10.0, -10.0, -9.0, -9.0],
       [-100.0, -100.0, 300.0, 600.0],
   ])
   boxlist = box_list.BoxList(box_corners)
   per_box_labels = tf.constant([[1], [2], [3], [4], [5], [6]])
   boxlist.add_field('extra_data', per_box_labels)
   pruned_boxlist, kept = box_list_ops.prune_outside_window(
       boxlist, clip_window)
   pruned_corners = pruned_boxlist.get()
   pruned_labels = pruned_boxlist.get_field('extra_data')
   return pruned_corners, pruned_labels, kept
コード例 #2
0
 def test_prune_outside_window_filters_boxes_which_fall_outside_the_window(
         self):
     """Check prune_outside_window on six boxes and a [0, 0, 9, 14] window.

     Expects boxes 0, 2 and 3 to be kept, and the 'extra_data' field to be
     reduced to the rows belonging to those same boxes.
     """
     clip_window = tf.constant([0, 0, 9, 14], tf.float32)
     box_corners = tf.constant([
         [5.0, 5.0, 6.0, 6.0],
         [-1.0, -2.0, 4.0, 5.0],
         [2.0, 3.0, 5.0, 9.0],
         [0.0, 0.0, 9.0, 14.0],
         [-10.0, -10.0, -9.0, -9.0],
         [-100.0, -100.0, 300.0, 600.0],
     ])
     boxlist = box_list.BoxList(box_corners)
     per_box_labels = tf.constant([[1], [2], [3], [4], [5], [6]])
     boxlist.add_field('extra_data', per_box_labels)

     # Expected values, listed in the order they are asserted below.
     expected_boxes = [
         [5.0, 5.0, 6.0, 6.0],
         [2.0, 3.0, 5.0, 9.0],
         [0.0, 0.0, 9.0, 14.0],
     ]
     expected_indices = [0, 2, 3]
     expected_labels = [[1], [3], [4]]

     pruned_boxlist, kept = box_list_ops.prune_outside_window(
         boxlist, clip_window)
     with self.test_session() as sess:
         actual_boxes = sess.run(pruned_boxlist.get())
         self.assertAllClose(actual_boxes, expected_boxes)
         actual_indices = sess.run(kept)
         self.assertAllEqual(actual_indices, expected_indices)
         actual_labels = sess.run(pruned_boxlist.get_field('extra_data'))
         self.assertAllEqual(actual_labels, expected_labels)
コード例 #3
0
 def test_prune_outside_window_filters_boxes_which_fall_outside_the_window(
     self):
   """Verify that pruning against a window keeps only boxes 0, 2 and 3.

   Six boxes with an attached 'extra_data' field are pruned against the
   window [0, 0, 9, 14]; the surviving boxes, their indices and their
   'extra_data' rows are each compared with hard-coded expectations.
   """
   bounds = tf.constant([0, 0, 9, 14], tf.float32)
   all_corners = tf.constant(
       [[5.0, 5.0, 6.0, 6.0], [-1.0, -2.0, 4.0, 5.0],
        [2.0, 3.0, 5.0, 9.0], [0.0, 0.0, 9.0, 14.0],
        [-10.0, -10.0, -9.0, -9.0],
        [-100.0, -100.0, 300.0, 600.0]])
   candidates = box_list.BoxList(all_corners)
   candidates.add_field(
       'extra_data', tf.constant([[1], [2], [3], [4], [5], [6]]))
   want_boxes = [[5.0, 5.0, 6.0, 6.0], [2.0, 3.0, 5.0, 9.0],
                 [0.0, 0.0, 9.0, 14.0]]
   survivors, survivor_indices = box_list_ops.prune_outside_window(
       candidates, bounds)
   with self.test_session() as sess:
     got_boxes = sess.run(survivors.get())
     self.assertAllClose(got_boxes, want_boxes)
     got_indices = sess.run(survivor_indices)
     self.assertAllEqual(got_indices, [0, 2, 3])
     got_extra = sess.run(survivors.get_field('extra_data'))
     self.assertAllEqual(got_extra, [[1], [3], [4]])
コード例 #4
0
ファイル: faster_rcnn_index.py プロジェクト: e271141/models
Faster R-CNN full process:

A: prediction_dict = detection_model.predict()
	a: Train RPN (first stage)
	  0. preprocessed inputs
	  1. self._extract_rpn_feature_maps()
	    I: rpn_features_to_crop = self._feature_extractor.extract_proposal_features()
	    II: anchors = self._first_stage_anchor_generator.generate()
          -->GridAnchorGenerator._generate() -->GridAnchorGenerator.tile_anchors()
	    III: rpn_box_predictor_features = slim.conv2d()
	  2. self._predict_rpn_proposals()
	    I: box_predictions = self._first_stage_box_predictor.predict()
	    II: objectness_predictions_with_background = 
	                      box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND()
	  3. _remove_invalid_anchors_and_predictions
	    I: pruned_anchors_boxlist, keep_indices = box_list_ops.prune_outside_window()
	    II: _batch_gather_kept_indices(box_encodings)
	        _batch_gather_kept_indices(objectness_predictions_with_background)
	        'Extremely hard to get through!!!!!!'

	b: Classification (second stage)
	  1. _predict_second_stage
	  	I: flattened_proposal_feature_maps = self._postprocess_rpn()
      'Very complicated function!!!!!!'
        i: self._format_groundtruth_data(): 
        ii: decoded_boxes = self._box_coder.decode(rpn_box_encodings, box_list.BoxList(anchors))
                            --> faster_rcnn_box_coder.FasterRcnnBoxCoder._decode()
            objectness_scores = tf.nn.softmax(rpn_objectness_predictions_with_background)
        iii:proposal_boxlist = post_processing.multiclass_non_max_suppression()
        iv: padded_proposals = box_list_ops.pad_or_clip_box_list()
	  	II: self._compute_second_stage_input_feature_maps()