def refine_stage(self,
                     input_img_batch,
                     gtboxes_batch_r,
                     gt_smooth_label,
                     box_pred_list,
                     cls_prob_list,
                     proposal_list,
                     angle_cls_list,
                     feature_pyramid,
                     gpu_id,
                     pos_threshold,
                     neg_threshold,
                     stage,
                     proposal_filter=False):
        """Run one box-refinement stage on top of the previous stage's outputs.

        For every pyramid level (iterated in lockstep with ``cfgs.ANCHOR_STRIDE``
        and ``cfgs.LEVEL``) the previous stage's regression deltas are decoded
        against its proposals with ``bbox_transform.rbbox_transform_inv``. The
        decoded box centres, divided by the level stride, are used to resample
        that level's feature map through ``self.refine_feature_op``. The
        resampled pyramid is then passed to ``self.refine_net``. In training
        mode, classification, regression and angle-classification losses for
        this stage are added to ``self.losses_dict`` under keys suffixed with
        ``stage``. The dict must already exist, because this method only
        assigns items into it.

        Args:
            input_img_batch: input images; only used for the training summary.
            gtboxes_batch_r: rotated ground truth, forwarded to
                ``refinebox_target_layer``.
            gt_smooth_label: angle smooth labels, forwarded to
                ``refinebox_target_layer``.
            box_pred_list, cls_prob_list, proposal_list, angle_cls_list:
                per-level outputs of the previous stage.
            feature_pyramid: feature maps indexed by the entries of ``cfgs.LEVEL``.
            gpu_id: forwarded to ``refinebox_target_layer``.
            pos_threshold, neg_threshold: forwarded to ``refinebox_target_layer``.
            stage: suffix used in variable-scope names and loss-dict keys.
            proposal_filter: if True, keep only the highest-scoring anchor per
                feature-map location before decoding, and convert 4-value
                (horizontal) proposals to the 5-value rotated form.

        Returns:
            ``(refine_box_pred_list, refine_cls_prob_list, refine_boxes_list,
            refine_angle_cls_list)``: the per-level outputs of ``refine_net``
            plus the decoded boxes that this stage refined.
        """
        with tf.variable_scope('refine_feature_pyramid{}'.format(stage)):
            refine_feature_pyramid = {}
            refine_boxes_list = []
            # Rotated proposals carry 5 values, horizontal ones 4. The same
            # predicate chooses both the reshape width and whether the
            # horizontal -> rotated conversion below is required, so the two
            # can never disagree.
            proposal_dim = 5 if self.method == 'R' else 4

            # angle_cls_list stays in the zip only so the number of visited
            # levels is unchanged (zip stops at its shortest input). The angle
            # logits do not affect the sampling point, which is the box centre.
            for box_pred, cls_prob, proposal, _angle_prob, stride, level in \
                    zip(box_pred_list, cls_prob_list, proposal_list, angle_cls_list,
                        cfgs.ANCHOR_STRIDE, cfgs.LEVEL):

                if proposal_filter:
                    # Make the per-location anchor axis explicit.
                    box_pred = tf.reshape(
                        box_pred, [-1, self.num_anchors_per_location, 5])
                    proposal = tf.reshape(
                        proposal,
                        [-1, self.num_anchors_per_location, proposal_dim])
                    cls_prob = tf.reshape(
                        cls_prob,
                        [-1, self.num_anchors_per_location, cfgs.CLASS_NUM])

                    # For each location, pick the anchor whose best class
                    # probability is highest.
                    cls_max_prob = tf.reduce_max(cls_prob, axis=-1)
                    best_anchor = tf.cast(
                        tf.reshape(tf.argmax(cls_max_prob, axis=-1), [-1, 1]),
                        tf.int32)
                    location_idx = tf.reshape(
                        tf.range(tf.shape(best_anchor)[0], dtype=tf.int32),
                        [-1, 1])
                    indices = tf.concat([location_idx, best_anchor], axis=-1)

                    box_pred = tf.reshape(tf.gather_nd(box_pred, indices),
                                          [-1, 5])
                    proposal = tf.reshape(tf.gather_nd(proposal, indices),
                                          [-1, proposal_dim])

                    if proposal_dim == 4:
                        # Convert horizontal proposals to the 5-value rotated
                        # form. Assumes (x1, y1, x2, y2) corners - TODO confirm.
                        # w is taken from the y extent and h from the x extent,
                        # paired with theta = -90 (original convention kept).
                        x_c = (proposal[:, 2] + proposal[:, 0]) / 2
                        y_c = (proposal[:, 3] + proposal[:, 1]) / 2
                        h = proposal[:, 2] - proposal[:, 0] + 1
                        w = proposal[:, 3] - proposal[:, 1] + 1
                        theta = -90 * tf.ones_like(x_c)
                        proposal = tf.transpose(
                            tf.stack([x_c, y_c, w, h, theta]))
                else:
                    # Without filtering, proposals must already be rotated
                    # (5 values per box).
                    box_pred = tf.reshape(box_pred, [-1, 5])
                    proposal = tf.reshape(proposal, [-1, 5])

                bboxes = bbox_transform.rbbox_transform_inv(boxes=proposal,
                                                            deltas=box_pred)
                refine_boxes_list.append(bboxes)

                # Resample the level's features at the decoded box centres,
                # expressed in feature-map coordinates.
                center_point = bboxes[:, :2] / stride
                refine_feature_pyramid[level] = self.refine_feature_op(
                    points=center_point,
                    feature_map=feature_pyramid[level],
                    name=level)

            refine_box_pred_list, refine_cls_score_list, refine_cls_prob_list, refine_angle_cls_list = self.refine_net(
                refine_feature_pyramid, 'refine_net{}'.format(stage))

            refine_box_pred = tf.concat(refine_box_pred_list, axis=0)
            refine_cls_score = tf.concat(refine_cls_score_list, axis=0)
            refine_boxes = tf.concat(refine_boxes_list, axis=0)
            refine_angle_cls = tf.concat(refine_angle_cls_list, axis=0)

        if self.is_training:
            with tf.variable_scope('build_refine_loss{}'.format(stage)):
                refine_labels, refine_target_delta, refine_box_states, refine_target_boxes, refine_target_smooth_label = tf.py_func(
                    func=refinebox_target_layer,
                    inp=[
                        gtboxes_batch_r, gt_smooth_label, refine_boxes,
                        pos_threshold, neg_threshold, gpu_id
                    ],
                    Tout=[
                        tf.float32, tf.float32, tf.float32, tf.float32,
                        tf.float32
                    ])

                self.add_anchor_img_smry(input_img_batch, refine_boxes,
                                         refine_box_states, 1)

                refine_cls_loss = losses.focal_loss(refine_labels,
                                                    refine_cls_score,
                                                    refine_box_states)
                # The IoU-smooth-L1 variant (losses.iou_smooth_l1_loss) was
                # hard-disabled here; refinement uses plain smooth L1.
                refine_reg_loss = losses.smooth_l1_loss(
                    refine_target_delta, refine_box_pred,
                    refine_box_states)

                angle_cls_loss = losses.angle_focal_loss(
                    refine_target_smooth_label, refine_angle_cls,
                    refine_box_states)

                # Prefer a stage-specific ANGLE_CLS_WEIGHT when configured;
                # otherwise fall back to ANGLE_WEIGHT, which weights the
                # first-stage angle loss, instead of raising AttributeError.
                angle_weight = getattr(cfgs, 'ANGLE_CLS_WEIGHT', None)
                if angle_weight is None:
                    angle_weight = cfgs.ANGLE_WEIGHT

                self.losses_dict['refine_cls_loss{}'.format(
                    stage)] = refine_cls_loss * cfgs.CLS_WEIGHT
                self.losses_dict['refine_reg_loss{}'.format(
                    stage)] = refine_reg_loss * cfgs.REG_WEIGHT
                self.losses_dict['angle_cls_loss{}'.format(
                    stage)] = angle_cls_loss * angle_weight

        return refine_box_pred_list, refine_cls_prob_list, refine_boxes_list, refine_angle_cls_list
    def build_whole_detection_network(self, input_img_batch, gtboxes_batch_h, gtboxes_batch_r,
                                      gt_smooth_label, gpu_id=0):
        """Build the detector graph: backbone, RPN head, training losses and decoding.

        Args:
            input_img_batch: input image tensor.
            gtboxes_batch_h: horizontal ground truth, reshaped to ``[-1, 5]``
                (only used in training mode).
            gtboxes_batch_r: rotated ground truth, reshaped to ``[-1, 6]``
                (only used in training mode).
            gt_smooth_label: angle smooth labels, reshaped to
                ``[-1, cfgs.ANGLE_RANGE]`` (only used in training mode).
            gpu_id: forwarded to ``anchor_target_layer``.

        Returns:
            ``(boxes, scores, category, boxes_angle, losses_dict)`` in training
            mode, otherwise ``(boxes, scores, category, boxes_angle)``.
            ``boxes``, ``scores`` and ``category`` are wrapped in
            ``tf.stop_gradient``. In training mode, ``self.losses_dict`` is
            (re)created with the keys ``cls_loss``, ``reg_loss`` and
            ``angle_cls_loss``.
        """
        if self.is_training:
            # Flatten the ground truth to one row per box and cast to float32.
            gtboxes_batch_h = tf.cast(tf.reshape(gtboxes_batch_h, [-1, 5]), tf.float32)
            gtboxes_batch_r = tf.cast(tf.reshape(gtboxes_batch_r, [-1, 6]), tf.float32)
            gt_smooth_label = tf.cast(
                tf.reshape(gt_smooth_label, [-1, cfgs.ANGLE_RANGE]), tf.float32)

        # 1. build base network
        feature_pyramid = self.build_base_network(input_img_batch)

        # 2. build rpn
        rpn_box_pred, rpn_cls_score, rpn_cls_prob, rpn_angle_cls = self.rpn_net(feature_pyramid)

        # 3. generate_anchors
        anchors = self.make_anchors(feature_pyramid)

        # 4. training losses against targets assigned to the anchors
        if self.is_training:
            with tf.variable_scope('build_loss'):
                labels, target_delta, anchor_states, target_boxes, target_smooth_label = tf.py_func(
                    func=anchor_target_layer,
                    inp=[gtboxes_batch_h, gtboxes_batch_r,
                         gt_smooth_label, anchors, gpu_id],
                    Tout=[tf.float32, tf.float32, tf.float32,
                          tf.float32, tf.float32])

                # Summary flag: 0 for horizontal ('H') anchors, 1 otherwise.
                anchor_smry_flag = 0 if self.method == 'H' else 1
                self.add_anchor_img_smry(input_img_batch, anchors, anchor_states, anchor_smry_flag)

                cls_loss = losses.focal_loss(labels, rpn_cls_score, anchor_states)

                # REG_LOSS_MODE: 0 = IoU-smooth-L1, 1 = smooth-L1 (atan variant),
                # anything else = plain smooth-L1.
                if cfgs.REG_LOSS_MODE == 0:
                    reg_loss = losses.iou_smooth_l1_loss(target_delta, rpn_box_pred, anchor_states, target_boxes,
                                                         anchors)
                elif cfgs.REG_LOSS_MODE == 1:
                    reg_loss = losses.smooth_l1_loss_atan(target_delta, rpn_box_pred, anchor_states)
                else:
                    reg_loss = losses.smooth_l1_loss(target_delta, rpn_box_pred, anchor_states)

                angle_cls_loss = losses.angle_focal_loss(target_smooth_label, rpn_angle_cls, anchor_states)

                self.losses_dict = {'cls_loss': cls_loss * cfgs.CLS_WEIGHT,
                                    'reg_loss': reg_loss * cfgs.REG_WEIGHT,
                                    'angle_cls_loss': angle_cls_loss * cfgs.ANGLE_WEIGHT}

        # 5. decode, clip and filter detections. The scope name's spelling is
        # kept as-is: renaming it would change graph op names.
        with tf.variable_scope('postprocess_detctions'):
            boxes, scores, category, boxes_angle = postprocess_detctions(rpn_bbox_pred=rpn_box_pred,
                                                                         rpn_cls_prob=rpn_cls_prob,
                                                                         rpn_angle_prob=tf.sigmoid(rpn_angle_cls),
                                                                         anchors=anchors,
                                                                         is_training=self.is_training)
            boxes = tf.stop_gradient(boxes)
            scores = tf.stop_gradient(scores)
            category = tf.stop_gradient(category)

        if self.is_training:
            return boxes, scores, category, boxes_angle, self.losses_dict
        return boxes, scores, category, boxes_angle