def build_whole_detection_network(self, input_img_batch, gtboxes_batch_h, gtboxes_batch_r, gpu_id=0):
    """Assemble the end-to-end detection graph: backbone -> RPN -> anchors -> loss/detections.

    Args:
        input_img_batch: batched input image tensor fed to the backbone.
        gtboxes_batch_h: horizontal GT boxes; reshaped to [-1, 5] (presumably
            x1, y1, x2, y2, label — TODO confirm against the data pipeline).
        gtboxes_batch_r: rotated GT boxes; reshaped to [-1, 6] (presumably
            x, y, w, h, theta, label — TODO confirm).
        gpu_id: device index forwarded to the anchor-target py_func.

    Returns:
        Inference (``self.is_training`` False): ``(boxes, scores, category)``.
        Training: ``(boxes, scores, category, losses_dict)`` where the detection
        outputs are wrapped in ``tf.stop_gradient`` so only the loss terms
        propagate gradients.
    """
    if self.is_training:
        # Flatten and normalize GT dtypes before they reach the target layer.
        gtboxes_batch_h = tf.reshape(gtboxes_batch_h, [-1, 5])
        gtboxes_batch_h = tf.cast(gtboxes_batch_h, tf.float32)
        gtboxes_batch_r = tf.reshape(gtboxes_batch_r, [-1, 6])
        gtboxes_batch_r = tf.cast(gtboxes_batch_r, tf.float32)

    # 1. backbone feature extraction
    feature_pyramid = self.build_base_network(input_img_batch)

    # 2. RPN heads: box regression, class scores, class probabilities
    rpn_box_pred, rpn_cls_score, rpn_cls_prob = self.rpn_net(feature_pyramid)

    # 3. anchor generation over the pyramid levels
    anchors = self.make_anchors(feature_pyramid)

    # 4. inference-only fast path: decode/clip/filter proposals and exit early
    if not self.is_training:
        with tf.variable_scope('postprocess_detctions'):
            boxes, scores, category = postprocess_detctions(
                rpn_bbox_pred=rpn_box_pred,
                rpn_cls_prob=rpn_cls_prob,
                anchors=anchors,
                is_training=self.is_training)
            return boxes, scores, category

    # 5. training path: build classification + regression losses
    with tf.variable_scope('build_loss'):
        # anchor_target_layer runs on the host via py_func; it assigns each
        # anchor a label, regression target, and in/out-of-sampling state.
        labels, target_delta, anchor_states, target_boxes = tf.py_func(
            func=anchor_target_layer,
            inp=[gtboxes_batch_h, gtboxes_batch_r, anchors, gpu_id],
            Tout=[tf.float32, tf.float32, tf.float32, tf.float32])

        # Summary image of positive anchors; flag selects H vs R drawing mode.
        smry_mode = 0 if self.method == 'H' else 1
        self.add_anchor_img_smry(input_img_batch, anchors, anchor_states, smry_mode)

        cls_loss = losses.focal_loss(labels, rpn_cls_score, anchor_states)

        # Regression loss variant is selected by config.
        if cfgs.REG_LOSS_MODE == 0:
            reg_loss = losses.iou_smooth_l1_loss(
                target_delta, rpn_box_pred, anchor_states, target_boxes, anchors)
        elif cfgs.REG_LOSS_MODE == 1:
            reg_loss = losses.smooth_l1_loss_atan(
                target_delta, rpn_box_pred, anchor_states)
        else:
            reg_loss = losses.smooth_l1_loss(
                target_delta, rpn_box_pred, anchor_states)

        losses_dict = {
            'cls_loss': cls_loss * cfgs.CLS_WEIGHT,
            'reg_loss': reg_loss * cfgs.REG_WEIGHT
        }

    # Detections are also produced during training (e.g. for summaries), but
    # must not feed gradients back into the network.
    with tf.variable_scope('postprocess_detctions'):
        boxes, scores, category = postprocess_detctions(
            rpn_bbox_pred=rpn_box_pred,
            rpn_cls_prob=rpn_cls_prob,
            anchors=anchors,
            is_training=self.is_training)
        boxes = tf.stop_gradient(boxes)
        scores = tf.stop_gradient(scores)
        category = tf.stop_gradient(category)

    return boxes, scores, category, losses_dict
def build_whole_detection_network(self, input_img_batch, gtboxes_batch_h, gtboxes_batch_r, gthead_quadrant, gt_smooth_label, gpu_id=0):
    """Assemble the detection graph with extra head-quadrant and angle-classification branches.

    Args:
        input_img_batch: batched input image tensor fed to the backbone.
        gtboxes_batch_h: horizontal GT boxes; reshaped to [-1, 5].
        gtboxes_batch_r: rotated GT boxes; reshaped to [-1, 6].
        gthead_quadrant: per-GT head-direction quadrant; reshaped to [-1, 1]
            and cast to int32.
        gt_smooth_label: smoothed angle-classification label; reshaped to
            [-1, self.angle_range].
        gpu_id: device index forwarded to the anchor-target py_func.

    Returns:
        Training: ``(boxes, scores, category, boxes_head, boxes_angle,
        self.losses_dict)``; the detection outputs are wrapped in
        ``tf.stop_gradient`` so only the loss terms propagate gradients.
        Inference: the same five detection tensors without the loss dict.

    Note:
        Fix vs previous revision: removed the unused ``img_shape =
        tf.shape(input_img_batch)`` dead local.
    """
    if self.is_training:
        # Flatten and normalize GT dtypes before they reach the target layer.
        gtboxes_batch_h = tf.reshape(gtboxes_batch_h, [-1, 5])
        gtboxes_batch_h = tf.cast(gtboxes_batch_h, tf.float32)
        gtboxes_batch_r = tf.reshape(gtboxes_batch_r, [-1, 6])
        gtboxes_batch_r = tf.cast(gtboxes_batch_r, tf.float32)
        gthead_quadrant = tf.reshape(gthead_quadrant, [-1, 1])
        gthead_quadrant = tf.cast(gthead_quadrant, tf.int32)
        gt_smooth_label = tf.reshape(gt_smooth_label, [-1, self.angle_range])
        gt_smooth_label = tf.cast(gt_smooth_label, tf.float32)

    # 1. backbone feature extraction
    feature_pyramid = self.build_base_network(input_img_batch)

    # 2. RPN heads; this variant adds head-quadrant and angle-class logits.
    rpn_box_pred, rpn_cls_score, rpn_cls_prob, rpn_head_cls, rpn_angle_cls = self.rpn_net(
        feature_pyramid)

    # 3. anchor generation over the pyramid levels
    anchors = self.make_anchors(feature_pyramid)

    # 4. training losses (skipped entirely at inference time)
    if self.is_training:
        with tf.variable_scope('build_loss'):
            # anchor_target_layer runs on the host via py_func; it produces
            # per-anchor labels, regression deltas, sampling states, and the
            # head/angle classification targets.
            (labels, target_delta, anchor_states, target_boxes,
             target_head_quadrant, target_smooth_label) = tf.py_func(
                func=anchor_target_layer,
                inp=[gtboxes_batch_h, gtboxes_batch_r, gthead_quadrant,
                     gt_smooth_label, anchors, gpu_id],
                Tout=[tf.float32, tf.float32, tf.float32,
                      tf.float32, tf.float32, tf.float32])

            # Summary image of positive anchors; flag selects H vs R drawing mode.
            if self.method == 'H':
                self.add_anchor_img_smry(input_img_batch, anchors, anchor_states, 0)
            else:
                self.add_anchor_img_smry(input_img_batch, anchors, anchor_states, 1)

            cls_loss = losses.focal_loss(labels, rpn_cls_score, anchor_states)

            # Regression loss variant is selected by config.
            if cfgs.REG_LOSS_MODE == 0:
                reg_loss = losses.iou_smooth_l1_loss(
                    target_delta, rpn_box_pred, anchor_states, target_boxes, anchors)
            elif cfgs.REG_LOSS_MODE == 1:
                reg_loss = losses.smooth_l1_loss_atan(
                    target_delta, rpn_box_pred, anchor_states)
            else:
                reg_loss = losses.smooth_l1_loss(
                    target_delta, rpn_box_pred, anchor_states)

            if cfgs.DATASET_NAME.startswith('DOTA'):
                # On DOTA only these class ids get head-direction supervision
                # (presumably the classes with a well-defined "front" — verify
                # against the dataset label map).
                head_cls_loss = losses.head_specific_cls_focal_loss(
                    target_head_quadrant, rpn_head_cls, anchor_states, labels,
                    specific_cls=[6, 7, 8, 9, 10, 11])
            else:
                head_cls_loss = losses.head_focal_loss(
                    target_head_quadrant, rpn_head_cls, anchor_states)

            angle_cls_loss = losses.angle_focal_loss(
                target_smooth_label, rpn_angle_cls, anchor_states)

            self.losses_dict = {
                'cls_loss': cls_loss * cfgs.CLS_WEIGHT,
                'reg_loss': reg_loss * cfgs.REG_WEIGHT,
                'head_cls_loss': head_cls_loss * cfgs.HEAD_WEIGHT,
                'angle_cls_loss': angle_cls_loss * cfgs.ANGLE_WEIGHT
            }

    # 5. decode/clip/filter detections; gradients are blocked so these outputs
    # never feed back into the network during training.
    with tf.variable_scope('postprocess_detctions'):
        boxes, scores, category, boxes_head, boxes_angle = postprocess_detctions(
            rpn_bbox_pred=rpn_box_pred,
            rpn_cls_prob=rpn_cls_prob,
            rpn_angle_prob=tf.sigmoid(rpn_angle_cls),
            rpn_head_prob=tf.sigmoid(rpn_head_cls),
            anchors=anchors,
            is_training=self.is_training)
        boxes = tf.stop_gradient(boxes)
        scores = tf.stop_gradient(scores)
        category = tf.stop_gradient(category)
        boxes_head = tf.stop_gradient(boxes_head)
        boxes_angle = tf.stop_gradient(boxes_angle)

    if self.is_training:
        return boxes, scores, category, boxes_head, boxes_angle, self.losses_dict
    return boxes, scores, category, boxes_head, boxes_angle