def get_train_symbol(cls, backbone, neck, rpn_head, roi_extractor, bbox_head, num_branch, scaleaware):
    """Build the grouped training-loss symbol for the trident detector.

    Declares the input variables, replicates ``im_info`` / ``gt_bbox`` once
    per trident branch, and folds the RPN target inputs with a ``(-3, -2)``
    reshape. It then runs the backbone -> neck -> RPN -> RoI -> bbox-head
    pipeline and groups the RPN and R-CNN losses.

    Args:
        cls: owning class (unused in this body).
        backbone: object exposing ``get_rpn_feature()`` / ``get_rcnn_feature()``.
        neck: object exposing ``get_rpn_feature(feat)`` / ``get_rcnn_feature(feat)``.
        rpn_head: supplies ``get_loss`` and the proposal-sampling methods.
        roi_extractor: supplies ``get_roi_feature(rcnn_feat, proposal)``.
        bbox_head: supplies ``get_loss(roi_feat, cls, target, weight)``.
        num_branch: number of copies handed to
            ``TridentResNetV2Builder.stack_branch_symbols``.
        scaleaware: when truthy, also declares ``valid_ranges`` and samples
            proposals with ``get_sampled_proposal_with_filter``.

    Returns:
        ``X.group(rpn_loss + bbox_loss)``. The two losses are concatenated
        with ``+``, so presumably both are lists/tuples of symbols.
    """
    # Input variables.
    gt_bbox = X.var("gt_bbox")
    im_info = X.var("im_info")
    valid_ranges = X.var("valid_ranges") if scaleaware else None
    rpn_targets = [X.var(name) for name in ("rpn_cls_label", "rpn_reg_target", "rpn_reg_weight")]

    # Replicate the image-level inputs once per trident branch.
    stack_branches = TridentResNetV2Builder.stack_branch_symbols
    im_info = stack_branches([im_info] * num_branch)
    gt_bbox = stack_branches([gt_bbox] * num_branch)

    # Reshape code (-3, -2) merges the two leading axes and keeps the rest.
    # NOTE(review): what those leading axes mean is not visible here; confirm.
    # Creation order is preserved because unnamed ops are auto-named in order.
    if scaleaware:
        valid_ranges = X.reshape(valid_ranges, (-3, -2))
    rpn_cls_label, rpn_reg_target, rpn_reg_weight = [
        X.reshape(target, (-3, -2)) for target in rpn_targets
    ]

    # Both backbone features are requested before either goes through the neck.
    backbone_rpn_feat, backbone_rcnn_feat = backbone.get_rpn_feature(), backbone.get_rcnn_feature()
    rpn_feat = neck.get_rpn_feature(backbone_rpn_feat)
    rcnn_feat = neck.get_rcnn_feature(backbone_rcnn_feat)

    rpn_loss = rpn_head.get_loss(rpn_feat, rpn_cls_label, rpn_reg_target, rpn_reg_weight)

    if scaleaware:
        sampled = rpn_head.get_sampled_proposal_with_filter(rpn_feat, gt_bbox, im_info, valid_ranges)
    else:
        sampled = rpn_head.get_sampled_proposal(rpn_feat, gt_bbox, im_info)
    proposal, bbox_cls, bbox_target, bbox_weight = sampled

    roi_feat = roi_extractor.get_roi_feature(rcnn_feat, proposal)
    bbox_loss = bbox_head.get_loss(roi_feat, bbox_cls, bbox_target, bbox_weight)
    return X.group(rpn_loss + bbox_loss)
def get_test_symbol(backbone, neck, rpn_head, roi_extractor, mask_roi_extractor, bbox_head, mask_head, bbox_post_processor, num_branch):
    """Build the grouped inference symbol for the trident mask pipeline.

    Reuses the shared RPN test symbol from
    ``TridentFasterRcnn.get_rpn_test_symbol`` and predicts boxes on the RPN
    proposals. It then post-processes those boxes and predicts masks on the
    post-processed boxes.

    There is no ``cls``/``self`` parameter, so this is presumably a
    ``@staticmethod``; the decorator is not visible here.

    Args:
        backbone, neck, rpn_head: forwarded to the RPN test symbol; backbone
            and neck also provide the R-CNN feature.
        roi_extractor: pools R-CNN features for the proposals.
        mask_roi_extractor: pools R-CNN features for the post-processed boxes.
        bbox_head: supplies ``get_prediction(roi_feat, im_info, proposal)``.
        mask_head: supplies ``get_prediction(mask_roi_feat)``.
        bbox_post_processor: supplies ``get_post_processing(cls_score, bbox)``.
        num_branch: number of ``im_info`` copies stacked for the bbox head.

    Returns:
        ``X.group`` of ``[rec_id, im_id, im_info, post_cls_score,
        post_bbox_xyxy, post_cls, mask]``, where the three post-processed
        outputs are reshaped with ``(-3, -2)``.
    """
    rpn_outputs = TridentFasterRcnn.get_rpn_test_symbol(backbone, neck, rpn_head, num_branch)
    rec_id, im_id, im_info, proposal, _proposal_score = rpn_outputs
    im_info_branches = TridentResNetV2Builder.stack_branch_symbols([im_info] * num_branch)

    rcnn_feat = neck.get_rcnn_feature(backbone.get_rcnn_feature())

    roi_feat = roi_extractor.get_roi_feature(rcnn_feat, proposal)
    cls_score, bbox_xyxy = bbox_head.get_prediction(roi_feat, im_info_branches, proposal)
    post_cls_score, post_bbox_xyxy, post_cls = bbox_post_processor.get_post_processing(cls_score, bbox_xyxy)

    # Masks are predicted from the (unfolded) post-processed boxes.
    mask_roi_feat = mask_roi_extractor.get_roi_feature(rcnn_feat, post_bbox_xyxy)
    mask = mask_head.get_prediction(mask_roi_feat)

    # Trident only: fold the batch axis into the roi axis. The order
    # matches the original reshape sequence.
    folded = [
        X.reshape(symbol, (-3, -2), name=fold_name)
        for symbol, fold_name in (
            (post_cls_score, "post_cls_score_fold"),
            (post_bbox_xyxy, "post_bbox_xyxy_fold"),
            (post_cls, "post_cls_fold"),
        )
    ]
    return X.group([rec_id, im_id, im_info] + folded + [mask])
def __init__(self, pBackbone):
    """Build the "c4c5" trident ResNet-v2 backbone.

    Reads its settings from ``self.p``. That attribute is presumably
    populated from ``pBackbone`` by the base-class ``__init__``, which is
    not visible here. It asks ``TridentResNetV2Builder.get_backbone`` for
    the "c4c5" layout and stores the two returned symbols on ``self.c4``
    and ``self.c5``.
    """
    super().__init__(pBackbone)
    cfg = self.p
    builder = TridentResNetV2Builder()
    c4, c5 = builder.get_backbone(
        "mxnet", cfg.depth, "c4c5", cfg.normalizer, cfg.fp16,
        cfg.num_branch, cfg.branch_dilates, cfg.branch_ids,
        cfg.branch_bn_shared, cfg.branch_conv_shared, cfg.branch_deform,
    )
    self.c4 = c4
    self.c5 = c5
def __init__(self, pBackbone):
    """Build the single-output "c4" trident ResNet-v2 backbone.

    Reads settings directly from ``pBackbone`` and stores the symbol
    returned by ``TridentResNetV2Builder.get_backbone`` on ``self.symbol``.
    """
    # Explicit two-argument super() is kept: the enclosing class body is
    # not visible here.
    super(TridentMXNetResNetV2, self).__init__(pBackbone)
    cfg = pBackbone
    builder = TridentResNetV2Builder()
    self.symbol = builder.get_backbone(
        "mxnet", cfg.depth, "c4", cfg.normalizer, cfg.fp16,
        cfg.num_branch, cfg.branch_dilates, cfg.branch_ids,
        cfg.branch_bn_shared, cfg.branch_conv_shared, cfg.branch_deform,
    )
def get_test_symbol(cls, backbone, neck, rpn_head, roi_extractor, bbox_head, num_branch):
    """Build the grouped inference symbol for the trident box detector.

    Reuses the shared RPN test symbol from
    ``TridentFasterRcnn.get_rpn_test_symbol`` and pools R-CNN features for
    its proposals. It predicts class scores and boxes, then reshapes both
    with ``(-3, -2)`` so the two leading axes are merged.

    Args:
        cls: owning class (unused in this body; the RPN symbol is
            deliberately fetched from ``TridentFasterRcnn``).
        backbone, neck, rpn_head: forwarded to the RPN test symbol; backbone
            and neck also provide the R-CNN feature.
        roi_extractor: pools R-CNN features for the proposals.
        bbox_head: supplies ``get_prediction(roi_feat, im_info, proposal)``.
        num_branch: number of ``im_info`` copies stacked for the bbox head.

    Returns:
        ``X.group([rec_id, im_id, im_info, cls_score, bbox_xyxy])``.
    """
    rpn_outputs = TridentFasterRcnn.get_rpn_test_symbol(backbone, neck, rpn_head, num_branch)
    rec_id, im_id, im_info, proposal, _proposal_score = rpn_outputs
    im_info_branches = TridentResNetV2Builder.stack_branch_symbols([im_info] * num_branch)

    rcnn_feat = neck.get_rcnn_feature(backbone.get_rcnn_feature())
    roi_feat = roi_extractor.get_roi_feature(rcnn_feat, proposal)
    raw_score, raw_bbox = bbox_head.get_prediction(roi_feat, im_info_branches, proposal)

    # Merge the two leading axes; score is reshaped before bbox, as before.
    cls_score, bbox_xyxy = [X.reshape(pred, (-3, -2)) for pred in (raw_score, raw_bbox)]
    return X.group([rec_id, im_id, im_info, cls_score, bbox_xyxy])
def get_rpn_test_symbol(cls, backbone, neck, rpn_head, num_branch):
    """Return the grouped RPN inference symbol, building it on first use.

    The result, ``X.group([rec_id, im_id, im_info, proposal,
    proposal_score])``, is memoised on ``cls._rpn_output``. This lets the
    box/mask test symbols share one set of RPN inputs and outputs.

    NOTE(review): the cache is not keyed on the arguments. Once it is
    populated, later calls return the first result regardless of
    ``backbone``/``neck``/``rpn_head``/``num_branch``.

    Args:
        cls: class whose ``_rpn_output`` attribute holds the cache.
        backbone: supplies ``get_rpn_feature()``.
        neck: refines the backbone RPN feature.
        rpn_head: supplies ``get_all_proposal(rpn_feat, im_info)``.
        num_branch: number of ``im_info`` copies stacked for the RPN head.
    """
    if cls._rpn_output is None:
        im_info = X.var("im_info")
        im_id = X.var("im_id")
        rec_id = X.var("rec_id")

        rpn_feat = neck.get_rpn_feature(backbone.get_rpn_feature())
        im_info_branches = TridentResNetV2Builder.stack_branch_symbols([im_info] * num_branch)
        proposal, proposal_score = rpn_head.get_all_proposal(rpn_feat, im_info_branches)

        cls._rpn_output = X.group([rec_id, im_id, im_info, proposal, proposal_score])
    return cls._rpn_output