def __init__(self, num_classes, **kwargs):
        """Configure hyper-parameters and build the Cascade R-CNN sub-modules.

        Sets the RPN / R-CNN configuration constants as instance attributes,
        then constructs, in this order: the three per-stage proposal-target
        generators, the ResNet-101 backbone, the FPN neck, the RPN head, the
        pyramid ROI-align layer and the three per-stage bounding-box heads.

        Args:
            num_classes: number of object classes, forwarded to every
                ``ProposalTarget`` and ``BBoxHead`` (presumably including the
                background class -- confirm against the dataset config).
            **kwargs: forwarded unchanged to the parent class constructor.
        """
        super(CascadeRCNN, self).__init__(**kwargs)

        self.NUM_CLASSES = num_classes

        # RPN configuration
        # Anchor attributes
        self.ANCHOR_SCALES = (32, 64, 128, 256, 512)
        self.ANCHOR_RATIOS = (0.5, 1, 2)
        self.ANCHOR_FEATURE_STRIDES = (4, 8, 16, 32, 64)

        # Bounding box refinement mean and standard deviation
        self.RPN_TARGET_MEANS = (0., 0., 0., 0.)
        self.RPN_TARGET_STDS = (0.1, 0.1, 0.2, 0.2)

        # RPN training configuration.
        # NOTE: the ``PRN_`` prefix is a typo for ``RPN_``; the names are kept
        # unchanged because other methods may read these attributes.
        self.PRN_BATCH_SIZE = 256
        self.RPN_POS_FRAC = 0.5
        self.RPN_POS_IOU_THR = 0.7
        self.RPN_NEG_IOU_THR = 0.3

        # ROIs kept configuration
        self.PRN_PROPOSAL_COUNT = 2000
        self.PRN_NMS_THRESHOLD = 0.7

        # RCNN configuration
        # Bounding box refinement mean and standard deviation
        self.RCNN_TARGET_MEANS = (0., 0., 0., 0.)
        self.RCNN_TARGET_STDS = (0.1, 0.1, 0.2, 0.2)

        # ROI Feat Size
        self.POOL_SIZE = (7, 7)

        # RCNN training configuration.
        # The IoU threshold lists hold one entry per cascade stage
        # (index 0 -> stage 1, index 2 -> stage 3).
        self.RCNN_BATCH_SIZE = 256
        self.RCNN_POS_FRAC = 0.25
        self.RCNN_POS_IOU_THR = [
            0.5, 0.6, 0.7
        ]  # This can be smaller like [0.3, 0.4, 0.5] / [0.4, 0.5, 0.6]
        # NOTE(review): the negative thresholds decrease while the positive
        # ones increase, so the gap between neg_iou_thr and pos_iou_thr widens
        # in later stages. Cascade R-CNN is usually trained with
        # neg == pos per stage -- confirm this ordering is intentional.
        self.RCNN_NEG_IOU_THR = [0.5, 0.4, 0.3]

        # Boxes kept configuration
        self.RCNN_MIN_CONFIDENCE = 0.05
        self.RCNN_NMS_THRESHOLD = 0.5
        self.RCNN_MAX_INSTANCES = 100

        def make_proposal_target(stage):
            # Target generator for cascade stage ``stage`` (0-based index
            # into the per-stage IoU threshold lists).
            return bbox_target.ProposalTarget(
                target_means=self.RCNN_TARGET_MEANS,
                # Use the R-CNN stds so the targets are normalised the same
                # way the bbox heads below denormalise their deltas.
                target_stds=self.RCNN_TARGET_STDS,
                num_rcnn_deltas=self.RCNN_BATCH_SIZE,
                positive_fraction=self.RCNN_POS_FRAC,
                pos_iou_thr=self.RCNN_POS_IOU_THR[stage],
                neg_iou_thr=self.RCNN_NEG_IOU_THR[stage],
                num_classes=self.NUM_CLASSES)

        # Target generators for the second (R-CNN) stage, one per cascade step.
        self.bbox_target1 = make_proposal_target(0)
        self.bbox_target2 = make_proposal_target(1)
        self.bbox_target3 = make_proposal_target(2)

        # Modules
        self.backbone = resnet.ResNet(depth=101, name='res_net')

        self.neck = fpn.FPN(name='fpn')

        self.rpn_head = rpn_head.RPNHead(
            anchor_scales=self.ANCHOR_SCALES,
            anchor_ratios=self.ANCHOR_RATIOS,
            anchor_feature_strides=self.ANCHOR_FEATURE_STRIDES,
            proposal_count=self.PRN_PROPOSAL_COUNT,
            nms_threshold=self.PRN_NMS_THRESHOLD,
            target_means=self.RPN_TARGET_MEANS,
            target_stds=self.RPN_TARGET_STDS,
            num_rpn_deltas=self.PRN_BATCH_SIZE,
            positive_fraction=self.RPN_POS_FRAC,
            pos_iou_thr=self.RPN_POS_IOU_THR,
            neg_iou_thr=self.RPN_NEG_IOU_THR,
            name='rpn_head')

        self.roi_align = roi_align.PyramidROIAlign(pool_shape=self.POOL_SIZE,
                                                   name='pyramid_roi_align')

        def make_bbox_head(name):
            # All cascade stages share the same head configuration and differ
            # only in their layer name.
            return bbox_head.BBoxHead(
                num_classes=self.NUM_CLASSES,
                pool_size=self.POOL_SIZE,
                target_means=self.RCNN_TARGET_MEANS,
                target_stds=self.RCNN_TARGET_STDS,
                min_confidence=self.RCNN_MIN_CONFIDENCE,
                nms_threshold=self.RCNN_NMS_THRESHOLD,
                max_instances=self.RCNN_MAX_INSTANCES,
                name=name)

        # first, second and third detection stages
        self.bbox_head1 = make_bbox_head('b_box_head1')
        self.bbox_head2 = make_bbox_head('b_box_head2')
        self.bbox_head3 = make_bbox_head('b_box_head3')
    def __init__(self, num_classes, **kwargs):
        """Configure hyper-parameters and build the Faster R-CNN sub-modules.

        Sets the anchor / RPN / R-CNN configuration constants as instance
        attributes. It then constructs the ResNet-101 backbone, FPN neck, RPN
        head, pyramid ROI-align layer, bounding-box head, anchor generator and
        anchor / proposal target generators. Finally it binds the four loss
        functions from ``losses``.

        Args:
            num_classes: number of object classes for the bounding-box head
                (presumably including the background class -- confirm).
            **kwargs: forwarded unchanged to the parent class constructor.
        """
        super(FasterRCNN, self).__init__(**kwargs)

        self.NUM_CLASSES = num_classes

        # Anchor attributes
        self.ANCHOR_SCALES = (32, 64, 128, 256, 512)
        self.ANCHOR_RATIOS = (0.5, 1, 2)

        # The strides of each layer of the FPN Pyramid.
        self.FEATURE_STRIDES = (4, 8, 16, 32, 64)

        # Bounding box refinement mean and standard deviation
        self.RPN_TARGET_MEANS = (0., 0., 0., 0.)
        self.RPN_TARGET_STDS = (0.1, 0.1, 0.2, 0.2)

        # NOTE: the ``PRN_`` prefix is a typo for ``RPN_``; the names are kept
        # unchanged because other methods may read these attributes.
        self.PRN_PROPOSAL_COUNT = 2000
        self.PRN_NMS_THRESHOLD = 0.7

        # Passed to ProposalTarget as ``num_rcnn_deltas``.
        self.ROI_BATCH_SIZE = 512

        # Bounding box refinement mean and standard deviation
        self.RCNN_TARGET_MEANS = (0., 0., 0., 0.)
        self.RCNN_TARGET_STDS = (0.1, 0.1, 0.2, 0.2)

        self.POOL_SIZE = (7, 7)

        self.backbone = resnet.ResNet(depth=101, name='res_net')
        self.neck = fpn.FPN(name='fpn')
        self.rpn_head = rpn_head.RPNHead(
            # One anchor per aspect ratio at each location; presumably the
            # generator assigns a single scale per pyramid level -- confirm.
            anchors_per_location=len(self.ANCHOR_RATIOS),
            proposal_count=self.PRN_PROPOSAL_COUNT,
            nms_threshold=self.PRN_NMS_THRESHOLD,
            target_means=self.RPN_TARGET_MEANS,
            target_stds=self.RPN_TARGET_STDS,
            name='rpn_head')

        self.roi_align = roi_align.PyramidROIAlign(pool_shape=self.POOL_SIZE,
                                                   name='pyramid_roi_align')
        # NOTE(review): BBoxHead is not given target_means / target_stds here,
        # so it relies on its own defaults -- confirm they match RCNN_TARGET_*.
        self.bbox_head = bbox_head.BBoxHead(num_classes=self.NUM_CLASSES,
                                            pool_size=self.POOL_SIZE,
                                            name='b_box_head')

        self.generator = anchor_generator.AnchorGenerator(
            scales=self.ANCHOR_SCALES,
            ratios=self.ANCHOR_RATIOS,
            feature_strides=self.FEATURE_STRIDES)

        self.anchor_target = anchor_target.AnchorTarget(
            target_means=self.RPN_TARGET_MEANS,
            target_stds=self.RPN_TARGET_STDS)

        # R-CNN (second-stage) targets are normalised with the R-CNN
        # statistics, not the RPN ones.
        self.bbox_target = bbox_target.ProposalTarget(
            target_means=self.RCNN_TARGET_MEANS,
            target_stds=self.RCNN_TARGET_STDS,
            num_rcnn_deltas=self.ROI_BATCH_SIZE)

        self.rpn_class_loss = losses.rpn_class_loss
        self.rpn_bbox_loss = losses.rpn_bbox_loss

        self.rcnn_class_loss = losses.rcnn_class_loss
        self.rcnn_bbox_loss = losses.rcnn_bbox_loss