def graph_fn():
  """Builds the graph that prunes a small BoxList against a fixed window.

  Six boxes, each tagged with an integer 'extra_data' value (1..6), are
  passed through box_list_ops.prune_outside_window with the window
  [0, 0, 9, 14].

  Returns:
    A 3-tuple of tensors: the corners of the surviving boxes, their
    'extra_data' field, and the indices of the boxes that were kept.
  """
  clip_window = tf.constant([0, 0, 9, 14], tf.float32)
  box_corners = tf.constant([
      [5.0, 5.0, 6.0, 6.0],
      [-1.0, -2.0, 4.0, 5.0],
      [2.0, 3.0, 5.0, 9.0],
      [0.0, 0.0, 9.0, 14.0],
      [-10.0, -10.0, -9.0, -9.0],
      [-100.0, -100.0, 300.0, 600.0],
  ])
  boxlist = box_list.BoxList(box_corners)
  # Per-box payload, used to check that extra fields are gathered alongside
  # the boxes that survive pruning.
  boxlist.add_field('extra_data', tf.constant([[1], [2], [3], [4], [5], [6]]))
  kept_boxlist, kept_indices = box_list_ops.prune_outside_window(
      boxlist, clip_window)
  kept_corners = kept_boxlist.get()
  kept_extra_data = kept_boxlist.get_field('extra_data')
  return kept_corners, kept_extra_data, kept_indices
def test_prune_outside_window_filters_boxes_which_fall_outside_the_window(
    self):
  """Checks that prune_outside_window drops boxes not inside the window.

  Only boxes lying entirely within [0, 0, 9, 14] (boundary inclusive) are
  expected to survive. Boxes that are partially outside, entirely outside,
  or that contain the whole window are dropped. Their 'extra_data' entries
  are dropped with them.
  """
  clip_window = tf.constant([0, 0, 9, 14], tf.float32)
  box_corners = tf.constant([
      [5.0, 5.0, 6.0, 6.0],            # inside -> kept
      [-1.0, -2.0, 4.0, 5.0],          # partially outside -> dropped
      [2.0, 3.0, 5.0, 9.0],            # inside -> kept
      [0.0, 0.0, 9.0, 14.0],           # coincides with the window -> kept
      [-10.0, -10.0, -9.0, -9.0],      # entirely outside -> dropped
      [-100.0, -100.0, 300.0, 600.0],  # contains the window -> dropped
  ])
  boxlist = box_list.BoxList(box_corners)
  boxlist.add_field('extra_data', tf.constant([[1], [2], [3], [4], [5], [6]]))
  expected_corners = [[5.0, 5.0, 6.0, 6.0],
                      [2.0, 3.0, 5.0, 9.0],
                      [0.0, 0.0, 9.0, 14.0]]
  expected_indices = [0, 2, 3]
  expected_extra_data = [[1], [3], [4]]
  kept_boxlist, kept_indices = box_list_ops.prune_outside_window(
      boxlist, clip_window)
  with self.test_session() as session:
    self.assertAllClose(session.run(kept_boxlist.get()), expected_corners)
    self.assertAllEqual(session.run(kept_indices), expected_indices)
    self.assertAllEqual(
        session.run(kept_boxlist.get_field('extra_data')),
        expected_extra_data)
# NOTE: a byte-identical second copy of
# test_prune_outside_window_filters_boxes_which_fall_outside_the_window was
# removed here. Redefining the same method name silently replaces the earlier
# definition (flake8 F811), so the copy added no coverage and only invited the
# two versions to drift apart.
Faster R-CNN full process:
A: prediction_dict = detection_model.predict()
  a: Train RPN (first stage)
    0. Preprocessed inputs
    1. self._extract_rpn_feature_maps()
       I:   rpn_features_to_crop = self._feature_extractor.extract_proposal_features()
       II:  anchors = self._first_stage_anchor_generator.generate()
            --> GridAnchorGenerator._generate()
            --> GridAnchorGenerator.tile_anchors()
       III: rpn_box_predictor_features = slim.conv2d()
    2. self._predict_rpn_proposals()
       I:  box_predictions = self._first_stage_box_predictor.predict()
       II: objectness_predictions_with_background = box_predictions[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND]
    3. self._remove_invalid_anchors_and_predictions()
       I:  pruned_anchors_boxlist, keep_indices = box_list_ops.prune_outside_window()
       II: _batch_gather_kept_indices(box_encodings)
           _batch_gather_kept_indices(objectness_predictions_with_background)
       (Extremely hard to get through!)
  b: Classification (second stage)
    1. self._predict_second_stage()
       I:  flattened_proposal_feature_maps = self._postprocess_rpn()  (A very complicated function!)
           i:   self._format_groundtruth_data()
           ii:  decoded_boxes = self._box_coder.decode(rpn_box_encodings, box_list.BoxList(anchors))
                --> faster_rcnn_box_coder.FasterRcnnBoxCoder._decode()
                objectness_scores = tf.nn.softmax(rpn_objectness_predictions_with_background)
           iii: proposal_boxlist = post_processing.multiclass_non_max_suppression()
           iv:  padded_proposals = box_list_ops.pad_or_clip_box_list()
       II: self._compute_second_stage_input_feature_maps()