Example #1
0
    def get_train_symbol(cls, backbone, neck, rpn_head, roi_extractor, bbox_head, num_branch, scaleaware):
        """Build the grouped training-loss symbol for the TridentNet detector.

        ``gt_bbox`` and ``im_info`` are replicated ``num_branch`` times and
        stacked with ``TridentResNetV2Builder.stack_branch_symbols``. The RPN
        targets are passed through ``X.reshape(..., (-3, -2))``, and so is
        ``valid_ranges`` when ``scaleaware`` is set.

        Args:
            backbone: supplies ``get_rpn_feature()`` and ``get_rcnn_feature()``.
            neck: refines the backbone's RPN and RCNN features.
            rpn_head: computes the RPN loss and samples proposals.
            roi_extractor: pools RCNN features for the sampled proposals.
            bbox_head: computes the box-head loss.
            num_branch: number of trident branches to replicate inputs for.
            scaleaware: when truthy, also declare the ``valid_ranges`` input and
                sample proposals via ``get_sampled_proposal_with_filter``.

        Returns:
            ``X.group`` over the RPN losses concatenated with the box-head losses.
        """

        def merge_leading_axes(sym):
            # Assuming X.reshape follows mx.sym.reshape's special codes:
            # -3 merges the first two axes, -2 copies the remaining ones.
            return X.reshape(sym, (-3, -2))

        # Input variables (declared in the same order as before).
        gt_bbox = X.var("gt_bbox")
        im_info = X.var("im_info")
        valid_ranges = X.var("valid_ranges") if scaleaware else None
        rpn_cls_label, rpn_reg_target, rpn_reg_weight = [
            X.var(var_name)
            for var_name in ("rpn_cls_label", "rpn_reg_target", "rpn_reg_weight")
        ]

        # Replicate the per-image inputs once per trident branch.
        stack_branches = TridentResNetV2Builder.stack_branch_symbols
        im_info = stack_branches([im_info] * num_branch)
        gt_bbox = stack_branches([gt_bbox] * num_branch)
        if scaleaware:
            valid_ranges = merge_leading_axes(valid_ranges)
        rpn_cls_label, rpn_reg_target, rpn_reg_weight = [
            merge_leading_axes(target)
            for target in (rpn_cls_label, rpn_reg_target, rpn_reg_weight)
        ]

        # Backbone features are fetched before either neck call, as originally.
        backbone_rpn, backbone_rcnn = backbone.get_rpn_feature(), backbone.get_rcnn_feature()
        rpn_feat = neck.get_rpn_feature(backbone_rpn)
        rcnn_feat = neck.get_rcnn_feature(backbone_rcnn)

        rpn_loss = rpn_head.get_loss(rpn_feat, rpn_cls_label, rpn_reg_target, rpn_reg_weight)
        if scaleaware:
            sampled = rpn_head.get_sampled_proposal_with_filter(rpn_feat, gt_bbox, im_info, valid_ranges)
        else:
            sampled = rpn_head.get_sampled_proposal(rpn_feat, gt_bbox, im_info)
        proposal, bbox_cls, bbox_target, bbox_weight = sampled

        roi_feat = roi_extractor.get_roi_feature(rcnn_feat, proposal)
        bbox_loss = bbox_head.get_loss(roi_feat, bbox_cls, bbox_target, bbox_weight)

        # Concatenate the two loss collections into one output group.
        return X.group(rpn_loss + bbox_loss)
Example #2
0
    def get_test_symbol(backbone, neck, rpn_head, roi_extractor,
                        mask_roi_extractor, bbox_head, mask_head,
                        bbox_post_processor, num_branch):
        """Build the grouped inference symbol for the TridentNet variant with a mask head.

        RPN outputs come from ``TridentFasterRcnn.get_rpn_test_symbol``. Box
        predictions are post-processed, and the mask head runs on RoI features
        of the post-processed boxes. The three post-processed detection outputs
        are then reshaped with ``(-3, -2)``; the mask output is returned as-is.

        Returns:
            ``X.group([rec_id, im_id, im_info, post_cls_score, post_bbox_xyxy,
            post_cls, mask])``.
        """
        rpn_outputs = TridentFasterRcnn.get_rpn_test_symbol(backbone, neck, rpn_head, num_branch)
        rec_id, im_id, im_info, proposal, _proposal_score = rpn_outputs

        # One copy of the image info per trident branch.
        branch_im_info = TridentResNetV2Builder.stack_branch_symbols([im_info] * num_branch)

        rcnn_feat = neck.get_rcnn_feature(backbone.get_rcnn_feature())

        cls_score, bbox_xyxy = bbox_head.get_prediction(
            roi_extractor.get_roi_feature(rcnn_feat, proposal), branch_im_info, proposal)
        post_cls_score, post_bbox_xyxy, post_cls = \
            bbox_post_processor.get_post_processing(cls_score, bbox_xyxy)

        mask = mask_head.get_prediction(
            mask_roi_extractor.get_roi_feature(rcnn_feat, post_bbox_xyxy))

        # Trident only: fold the batch axis into the RoI axis of the detections.
        post_cls_score, post_bbox_xyxy, post_cls = (
            X.reshape(detection, (-3, -2), name=fold_name)
            for detection, fold_name in (
                (post_cls_score, "post_cls_score_fold"),
                (post_bbox_xyxy, "post_bbox_xyxy_fold"),
                (post_cls, "post_cls_fold"),
            )
        )

        outputs = [rec_id, im_id, im_info, post_cls_score, post_bbox_xyxy, post_cls, mask]
        return X.group(outputs)
Example #3
0
    def get_test_symbol(cls, backbone, neck, rpn_head, roi_extractor, bbox_head, num_branch):
        """Build the grouped inference symbol for TridentNet Faster R-CNN.

        Reuses the RPN test outputs from ``TridentFasterRcnn.get_rpn_test_symbol``
        and runs the box head on RoI features pooled for every proposal. The
        class scores and boxes are then reshaped with ``(-3, -2)``.

        Returns:
            ``X.group([rec_id, im_id, im_info, cls_score, bbox_xyxy])``.
        """
        rec_id, im_id, im_info, proposal, _ = TridentFasterRcnn.get_rpn_test_symbol(
            backbone, neck, rpn_head, num_branch)
        per_branch_info = TridentResNetV2Builder.stack_branch_symbols([im_info] * num_branch)

        rcnn_feat = neck.get_rcnn_feature(backbone.get_rcnn_feature())
        roi_feat = roi_extractor.get_roi_feature(rcnn_feat, proposal)
        cls_score, bbox_xyxy = bbox_head.get_prediction(roi_feat, per_branch_info, proposal)

        # Assuming mx.sym.reshape semantics: -3 merges the two leading axes and
        # -2 keeps the rest. Scores are reshaped before boxes, as originally.
        folded = [X.reshape(prediction, (-3, -2)) for prediction in (cls_score, bbox_xyxy)]
        return X.group([rec_id, im_id, im_info] + folded)
Example #4
0
    def get_rpn_test_symbol(cls, backbone, neck, rpn_head, num_branch):
        """Build, once, the grouped RPN inference symbol and memoize it on the class.

        The first call builds the group and stores it in ``cls._rpn_output``.
        Any later call that finds a non-None ``cls._rpn_output`` returns it
        immediately and ignores its own arguments.

        NOTE(review): the memoized value is not keyed on backbone/neck/rpn_head/
        num_branch, so a later call with different arguments receives the first
        call's graph. Confirm this sharing is intended.

        Returns:
            ``X.group([rec_id, im_id, im_info, proposal, proposal_score])``.
        """
        cached = cls._rpn_output
        if cached is not None:
            return cached

        # Declared in the original order: im_info, im_id, rec_id.
        im_info, im_id, rec_id = (X.var(var_name) for var_name in ("im_info", "im_id", "rec_id"))

        rpn_feat = neck.get_rpn_feature(backbone.get_rpn_feature())

        # One copy of the image info per trident branch.
        per_branch_info = TridentResNetV2Builder.stack_branch_symbols([im_info] * num_branch)
        proposal, proposal_score = rpn_head.get_all_proposal(rpn_feat, per_branch_info)

        cls._rpn_output = X.group([rec_id, im_id, im_info, proposal, proposal_score])
        return cls._rpn_output