def get_loss(self, conv_fpn_feat, cls_label, bbox_target, bbox_weight):
    """Build RPN classification and box-regression losses over all FPN levels.

    Parameters
    ----------
    conv_fpn_feat : per-stride FPN feature symbols, fed to ``self.get_output``.
    cls_label : anchor class labels; -1 entries are ignored by the softmax loss.
    bbox_target : regression targets matching the concatenated delta layout.
    bbox_weight : per-anchor weights that zero out non-foreground regression terms.

    Returns
    -------
    (cls_loss, reg_loss) : the two loss symbols.
    """
    p = self.p
    batch_image = p.batch_image
    image_anchor = p.anchor_generate.image_anchor
    rpn_stride = p.anchor_generate.stride

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_fpn_feat)
    # fp16 training scales gradients up by 128 to avoid underflow
    scale_loss_shift = 128.0 if p.fp16 else 1.0

    rpn_cls_logit_list = []
    rpn_bbox_delta_list = []
    for stride in rpn_stride:
        rpn_cls_logit = cls_logit_dict[stride]
        rpn_bbox_delta = bbox_delta_dict[stride]
        # flatten spatial/anchor dims so all levels can be concatenated on axis 2;
        # cls logits are viewed as (N, 2, -1) for the 2-way fg/bg softmax
        rpn_cls_logit_reshape = X.reshape(
            data=rpn_cls_logit,
            shape=(0, 2, -1),
            name="rpn_cls_score_reshape_stride%s" % stride
        )
        rpn_bbox_delta_reshape = X.reshape(
            data=rpn_bbox_delta,
            shape=(0, 0, -1),
            name="rpn_bbox_pred_reshape_stride%s" % stride
        )
        rpn_bbox_delta_list.append(rpn_bbox_delta_reshape)
        rpn_cls_logit_list.append(rpn_cls_logit_reshape)

    # concat output of each level
    rpn_bbox_delta_concat = X.concat(rpn_bbox_delta_list, axis=2, name="rpn_bbox_pred_concat")
    rpn_cls_logit_concat = X.concat(rpn_cls_logit_list, axis=2, name="rpn_cls_score_concat")

    # classification loss: 2-way softmax, normalized over valid (non-ignored) anchors
    cls_loss = X.softmax_output(
        data=rpn_cls_logit_concat,
        label=cls_label,
        multi_output=True,
        normalization='valid',
        use_ignore=True,
        ignore_label=-1,
        grad_scale=1.0 * scale_loss_shift,
        name="rpn_cls_loss"
    )
    # regression loss: smooth-L1 on the delta residual, masked by bbox_weight,
    # averaged over the nominal anchor budget per batch
    reg_loss = X.smooth_l1(
        (rpn_bbox_delta_concat - bbox_target),
        scalar=3.0,
        name='rpn_reg_l1'
    )
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / (batch_image * image_anchor) * scale_loss_shift,
        name='rpn_reg_loss'
    )
    return cls_loss, reg_loss
def get_prediction(self, conv_feat, im_info):
    """Decode per-level FreeAnchor head outputs into scores and boxes.

    Parameters
    ----------
    conv_feat : per-stride feature symbols, fed to ``self.get_output``.
    im_info : image meta symbol used by ``_proposal_retina`` to clip boxes.

    Returns
    -------
    (cls_score, bbox_xyxy) : symbols concatenated across strides on axis 1.
    """
    import mxnet as mx
    p = self.p
    num_class = p.num_class
    stride = p.anchor_generate.stride
    if not isinstance(stride, tuple):
        # BUGFIX: was ``stride = (stride)``, which is a no-op (parentheses
        # alone do not make a tuple) and left a scalar stride un-iterable.
        stride = (stride,)
    pre_nms_top_n = p.proposal.pre_nms_top_n
    anchor_target_mean = p.head.mean or (0, 0, 0, 0)
    anchor_target_std = p.head.std or (1, 1, 1, 1)

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_feat)

    from models.FreeAnchor.ops import _proposal_retina

    cls_score_list = []
    bbox_xyxy_list = []
    for s in stride:
        cls_prob = X.sigmoid(data=cls_logit_dict["stride%s" % s])
        bbox_delta = bbox_delta_dict["stride%s" % s]
        anchors = self.anchor_dict["stride%s" % s]
        # keep every candidate (-1) on the coarsest level, where the anchor
        # count is already small; otherwise take the top pre_nms_top_n
        pre_nms_top_n_level = -1 if s == max(stride) else pre_nms_top_n
        bbox_xyxy, cls_score = _proposal_retina(
            F=mx.sym,
            cls_prob=cls_prob,
            bbox_pred=bbox_delta,
            anchors=anchors,
            im_info=im_info,
            batch_size=1,
            rpn_pre_nms_top_n=pre_nms_top_n_level,
            num_class=num_class,
            anchor_mean=anchor_target_mean,
            anchor_std=anchor_target_std
        )
        cls_score_list.append(cls_score)
        bbox_xyxy_list.append(bbox_xyxy)

    cls_score = X.concat(cls_score_list, axis=1, name="cls_score_concat")
    bbox_xyxy = X.concat(bbox_xyxy_list, axis=1, name="bbox_xyxy_concat")
    return cls_score, bbox_xyxy
def _get_output(self, mask_pred_logits, conv_feat):
    """Build the MaskIoU head: fuse mask logits with RoI features and
    regress one IoU logit per class.

    Returns the (num_class)-way IoU prediction logits.
    """
    num_class = self.pBbox.num_class
    msra_init = mx.init.Xavier(rnd_type="gaussian", factor_type="out", magnitude=2)
    normal_init = mx.init.Normal(0.01)
    kaiming_uniform = mx.init.Xavier(rnd_type='uniform', factor_type='in', magnitude=3)

    # give the mask logits a channel axis, downsample them to the RoI feature
    # resolution, then stack them with conv_feat along the channel axis
    mask_pred_logits = mx.sym.expand_dims(mask_pred_logits, axis=1)
    iou_head_maxpool_1 = X.pool(
        mask_pred_logits,
        name='iou_head_maxpool_1',
        kernel=2,
        stride=2,
        pad=0,
    )
    feat = X.concat([conv_feat, iou_head_maxpool_1], axis=1, name='iou_head_input')

    # three 3x3 stride-1 convs followed by one 3x3 stride-2 conv, ReLU after each
    for idx in range(3):
        feat = X.conv(
            feat,
            filter=256,
            kernel=3,
            stride=1,
            name='iou_head_conv_%d' % idx,
            no_bias=False,
            init=msra_init,
        )
        feat = X.relu(feat)
    feat = X.conv(
        feat,
        filter=256,
        kernel=3,
        stride=2,
        name='iou_head_conv_3',
        no_bias=False,
        init=msra_init
    )
    feat = X.relu(feat)

    # two hidden FC layers, then the per-class IoU logits
    feat = X.flatten(data=feat)
    fc1 = X.relu(X.fc(feat, filter=1024, name='iou_head_FC1', init=kaiming_uniform))
    fc2 = X.relu(X.fc(fc1, filter=1024, name='iou_head_FC2', init=kaiming_uniform))
    iou_pred_logits = X.fc(fc2, filter=num_class, name='iou_head_pred', init=normal_init)
    return iou_pred_logits
def get_all_proposal(self, conv_fpn_feat, im_info):
    """Generate RoI proposals from every FPN level, then keep the global top-N.

    The result is cached on ``self._proposal`` so repeated calls reuse the
    same symbol instead of rebuilding the graph.

    Returns the (batch, post_nms_top_n, 4)-style proposal symbol produced by
    ``get_top_proposal``.
    """
    if self._proposal is not None:
        return self._proposal

    p = self.p
    rpn_stride = p.anchor_generate.stride
    anchor_scale = p.anchor_generate.scale
    anchor_ratio = p.anchor_generate.ratio
    pre_nms_top_n = p.proposal.pre_nms_top_n
    post_nms_top_n = p.proposal.post_nms_top_n
    nms_thr = p.proposal.nms_thr
    min_bbox_side = p.proposal.min_bbox_side
    num_anchors = len(p.anchor_generate.ratio) * len(
        p.anchor_generate.scale)
    batch_size = p.batch_image

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_fpn_feat)

    # rpn rois for multi level feature
    proposal_list = []
    proposal_scores_list = []
    for stride in rpn_stride:
        rpn_cls_logit = cls_logit_dict[stride]
        rpn_bbox_delta = bbox_delta_dict[stride]
        # ROI Proposal: softmax over the 2-way fg/bg channel, then fold the
        # anchor dimension back in for the proposal op
        rpn_cls_logit_reshape = X.reshape(
            data=rpn_cls_logit,
            shape=(0, 2, -1, 0),
            name="rpn_cls_logit_reshape_stride%s" % stride)
        rpn_cls_score = mx.symbol.SoftmaxActivation(
            data=rpn_cls_logit_reshape,
            mode="channel",
            name="rpn_cls_score_stride%s" % stride)
        rpn_cls_score_reshape = X.reshape(
            data=rpn_cls_score,
            shape=(0, 2 * num_anchors, -1, 0),
            name="rpn_cls_score_reshape_stride%s" % stride)
        rpn_proposal, rpn_proposal_scores = mx.sym.contrib.Proposal_v3(
            cls_prob=rpn_cls_score_reshape,
            bbox_pred=rpn_bbox_delta,
            im_info=im_info,
            rpn_pre_nms_top_n=pre_nms_top_n,
            rpn_post_nms_top_n=post_nms_top_n,
            feature_stride=stride,
            output_score=True,
            scales=tuple(anchor_scale),
            ratios=tuple(anchor_ratio),
            rpn_min_size=min_bbox_side,
            threshold=nms_thr,
            iou_loss=False)
        # When the tvm proposal is enabled, the Proposal_v3 symbol built above
        # is replaced (and pruned from the final graph) for the finer strides.
        # NOTE(review): the ``stride < rpn_stride[-2]`` cutoff looks like it
        # deliberately keeps the contrib op for the two coarsest levels —
        # confirm against the tvm proposal's stride constraints.
        if p.nnvm_proposal and stride < rpn_stride[-2]:
            max_side = p.anchor_generate.max_side
            assert max_side is not None, "nnvm proposal requires max_side of image"
            from mxnext.tvm.proposal import proposal as Proposal
            anchors = self.anchor_dict["stride%s" % stride]
            rpn_proposal, rpn_proposal_scores = Proposal(
                cls_prob=rpn_cls_score_reshape,
                bbox_pred=rpn_bbox_delta,
                im_info=im_info,
                anchors=anchors,
                name='proposal',
                feature_stride=stride,
                scales=tuple(anchor_scale),
                ratios=tuple(anchor_ratio),
                rpn_pre_nms_top_n=pre_nms_top_n,
                rpn_post_nms_top_n=post_nms_top_n,
                threshold=nms_thr,
                batch_size=batch_size,
                max_side=max_side,
                output_score=True,
                variant="simpledet")
        proposal_list.append(rpn_proposal)
        proposal_scores_list.append(rpn_proposal_scores)

    # concat output rois of each level
    proposal_concat = X.concat(proposal_list, axis=1, name="proposal_concat")
    proposal_scores_concat = X.concat(proposal_scores_list, axis=1, name="proposal_scores_concat")

    # select the overall top post_nms_top_n proposals across all levels
    from mxnext.tvm.get_top_proposal import get_top_proposal
    proposal = get_top_proposal(
        mx.symbol,
        bbox=proposal_concat,
        score=proposal_scores_concat,
        top_n=post_nms_top_n,
        batch_size=batch_size)

    self._proposal = proposal
    return proposal
def get_loss(self, conv_fpn_feat, gt_bbox, im_info):
    """Build FPN RPN losses, optionally computing anchor targets on-graph.

    When ``p.nnvm_rpn_target`` is set, cls/bbox targets are produced inside
    the graph by ``_fpn_rpn_target_batch``; otherwise they are read from the
    external variables ``rpn_cls_label`` / ``rpn_reg_target`` /
    ``rpn_reg_weight``.

    Returns
    -------
    (cls_loss, reg_loss, cls_label_blockgrad) : the third output exposes the
    labels (gradient-blocked) e.g. for metrics.
    """
    p = self.p
    batch_image = p.batch_image
    image_anchor = p.anchor_assign.image_anchor
    rpn_stride = p.anchor_generate.stride
    anchor_scale = p.anchor_generate.scale
    anchor_ratio = p.anchor_generate.ratio
    num_anchor = len(p.anchor_generate.ratio) * len(
        p.anchor_generate.scale)

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_fpn_feat)
    # fp16 training scales gradients up by 128 to avoid underflow
    scale_loss_shift = 128.0 if p.fp16 else 1.0

    rpn_cls_logit_list = []
    rpn_bbox_delta_list = []
    feat_list = []
    for stride in rpn_stride:
        rpn_cls_logit = cls_logit_dict[stride]
        rpn_bbox_delta = bbox_delta_dict[stride]
        # (N, 2A, H, W) -> (N, 2, A, H*W) so levels concat on the last axis
        rpn_cls_logit_reshape = X.reshape(
            data=rpn_cls_logit,
            shape=(0, 2, num_anchor, -1),
            name="rpn_cls_score_reshape_stride%s" % stride)
        rpn_bbox_delta_reshape = X.reshape(
            data=rpn_bbox_delta,
            shape=(0, 0, -1),
            name="rpn_bbox_pred_reshape_stride%s" % stride)
        rpn_bbox_delta_list.append(rpn_bbox_delta_reshape)
        rpn_cls_logit_list.append(rpn_cls_logit_reshape)
        # raw per-level logits are kept so the target op can infer feature sizes
        feat_list.append(rpn_cls_logit)

    if p.nnvm_rpn_target:
        # compute anchor targets inside the graph
        from mxnext.tvm.rpn_target import _fpn_rpn_target_batch
        anchor_list = [
            self.anchor_dict["stride%s" % s] for s in rpn_stride
        ]
        # keep only (x1, y1, x2, y2), dropping the class column
        gt_bbox = mx.sym.slice_axis(gt_bbox, axis=-1, begin=0, end=4)
        max_side = p.anchor_generate.max_side
        allowed_border = p.anchor_assign.allowed_border
        fg_fraction = p.anchor_assign.pos_fraction
        fg_thr = p.anchor_assign.pos_thr
        bg_thr = p.anchor_assign.neg_thr
        cls_label, bbox_target, bbox_weight = _fpn_rpn_target_batch(
            mx.sym, feat_list, anchor_list, gt_bbox, im_info, batch_image,
            num_anchor, max_side, rpn_stride, allowed_border, image_anchor,
            fg_fraction, fg_thr, bg_thr)
    else:
        # targets are provided by the data pipeline as external variables
        cls_label = X.var("rpn_cls_label")
        bbox_target = X.var("rpn_reg_target")
        bbox_weight = X.var("rpn_reg_weight")

    # concat output of each level
    rpn_bbox_delta_concat = X.concat(rpn_bbox_delta_list, axis=2, name="rpn_bbox_pred_concat")
    rpn_cls_logit_concat = X.concat(rpn_cls_logit_list, axis=-1, name="rpn_cls_score_concat")

    # classification loss: 2-way softmax, ignoring anchors labeled -1
    cls_loss = X.softmax_output(
        data=rpn_cls_logit_concat,
        label=cls_label,
        multi_output=True,
        normalization='valid',
        use_ignore=True,
        ignore_label=-1,
        grad_scale=1.0 * scale_loss_shift,
        name="rpn_cls_loss")
    # regression loss: masked smooth-L1 averaged over the anchor budget
    reg_loss = X.smooth_l1(
        (rpn_bbox_delta_concat - bbox_target),
        scalar=3.0,
        name='rpn_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / (batch_image * image_anchor) * scale_loss_shift,
        name='rpn_reg_loss')
    return cls_loss, reg_loss, X.stop_grad(cls_label, "rpn_cls_label_blockgrad")
def get_all_proposal(self, conv_fpn_feat, im_info):
    """Generate RoI proposals from every FPN level via ``Proposal_v3`` and
    keep the global top-N with the ``get_top_proposal`` custom op.

    The result is cached on ``self._proposal`` so repeated calls reuse the
    same symbol instead of rebuilding the graph.
    """
    if self._proposal is not None:
        return self._proposal

    p = self.p
    rpn_stride = p.anchor_generate.stride
    anchor_scale = p.anchor_generate.scale
    anchor_ratio = p.anchor_generate.ratio
    pre_nms_top_n = p.proposal.pre_nms_top_n
    post_nms_top_n = p.proposal.post_nms_top_n
    nms_thr = p.proposal.nms_thr
    min_bbox_side = p.proposal.min_bbox_side
    num_anchors = len(p.anchor_generate.ratio) * len(
        p.anchor_generate.scale)

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_fpn_feat)

    # rpn rois for multi level feature
    proposal_list = []
    proposal_scores_list = []
    for stride in rpn_stride:
        rpn_cls_logit = cls_logit_dict[stride]
        rpn_bbox_delta = bbox_delta_dict[stride]
        # ROI Proposal: softmax over the 2-way fg/bg channel, then fold the
        # anchor dimension back in for the proposal op
        rpn_cls_logit_reshape = X.reshape(
            data=rpn_cls_logit,
            shape=(0, 2, -1, 0),
            name="rpn_cls_logit_reshape_stride%s" % stride)
        rpn_cls_score = mx.symbol.SoftmaxActivation(
            data=rpn_cls_logit_reshape,
            mode="channel",
            name="rpn_cls_score_stride%s" % stride)
        rpn_cls_score_reshape = X.reshape(
            data=rpn_cls_score,
            shape=(0, 2 * num_anchors, -1, 0),
            name="rpn_cls_score_reshape_stride%s" % stride)
        rpn_proposal, rpn_proposal_scores = mx.sym.contrib.Proposal_v3(
            cls_prob=rpn_cls_score_reshape,
            bbox_pred=rpn_bbox_delta,
            im_info=im_info,
            rpn_pre_nms_top_n=pre_nms_top_n,
            rpn_post_nms_top_n=post_nms_top_n,
            feature_stride=stride,
            output_score=True,
            scales=tuple(anchor_scale),
            ratios=tuple(anchor_ratio),
            rpn_min_size=min_bbox_side,
            threshold=nms_thr,
            iou_loss=False)
        proposal_list.append(rpn_proposal)
        proposal_scores_list.append(rpn_proposal_scores)

    # concat output rois of each level
    proposal_concat = X.concat(proposal_list, axis=1, name="proposal_concat")
    proposal_scores_concat = X.concat(proposal_scores_list, axis=1, name="proposal_scores_concat")

    # select the overall top post_nms_top_n proposals across all levels
    (proposal, proposal_score) = mx.symbol.Custom(
        op_type="get_top_proposal",
        bbox=proposal_concat,
        score=proposal_scores_concat,
        top_n=post_nms_top_n,
        name="get_top_proposal")

    self._proposal = proposal
    return proposal
def get_prediction(self, conv_feat, im_info):
    """Decode RetinaNet head outputs into per-anchor scores and boxes.

    Uses the fused ``GenProposalRetina`` contrib op when the installed MXNet
    provides it; otherwise falls back to the Python ``decode_retina`` custom
    op.

    Returns
    -------
    (cls_score, bbox_xyxy) : decoded scores and boxes, concatenated across
    strides on axis 1 in the fused path.
    """
    p = self.p
    stride = p.anchor_generate.stride
    if not isinstance(stride, tuple):
        # BUGFIX: was ``stride = (stride)``, which is a no-op (parentheses
        # alone do not make a tuple) and left a scalar stride un-iterable.
        stride = (stride,)
    ratios = p.anchor_generate.ratio
    scales = p.anchor_generate.scale
    pre_nms_top_n = p.proposal.pre_nms_top_n
    min_bbox_side = p.proposal.min_bbox_side or 0
    min_det_score = p.proposal.min_det_score
    anchor_target_mean = p.head.mean or (0, 0, 0, 0)
    anchor_target_std = p.head.std or (1, 1, 1, 1)
    num_anchors = len(ratios) * len(scales)

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_feat)

    import mxnet as mx
    if "GenProposalRetina" in mx.sym.contrib.__all__:
        # fused C++ decode path
        cls_score_list = []
        bbox_xyxy_list = []
        for s in stride:
            cls_prob = X.sigmoid(data=cls_logit_dict["stride%s" % s])
            bbox_delta = bbox_delta_dict["stride%s" % s]
            anchors = mx.sym.contrib.GenAnchor(
                cls_prob=cls_prob,
                feature_stride=s,
                scales=tuple(scales),
                ratios=tuple(ratios),
                name='anchors_stride%s' % s)
            # do not score-threshold the coarsest level (few anchors there)
            thresh_level = 0 if s == max(stride) else min_det_score
            bbox_xyxy, cls_score = mx.sym.contrib.GenProposalRetina(
                cls_prob=cls_prob,
                bbox_pred=bbox_delta,
                im_info=im_info,
                anchors=anchors,
                feature_stride=s,
                anchor_mean=anchor_target_mean,
                anchor_std=anchor_target_std,
                num_anchors=num_anchors,
                rpn_pre_nms_top_n=pre_nms_top_n,
                rpn_min_size=min_bbox_side,
                thresh=thresh_level,
                workspace=512,
                name="proposal_pre_nms_stride%s" % s)
            cls_score_list.append(cls_score)
            bbox_xyxy_list.append(bbox_xyxy)
        cls_score = X.concat(cls_score_list, axis=1, name="cls_score_concat")
        bbox_xyxy = X.concat(bbox_xyxy_list, axis=1, name="bbox_xyxy_concat")
    else:
        # Python custom-op fallback: pass every level's score/delta as a
        # separately named keyword input to decode_retina
        cls_score_dict = dict()
        for s in stride:
            cls_score = X.sigmoid(data=cls_logit_dict["stride%s" % s])
            bbox_delta = bbox_delta_dict.pop("stride%s" % s)
            cls_score_dict["cls_score_stride%s" % s] = cls_score
            bbox_delta_dict["bbox_delta_stride%s" % s] = bbox_delta
        import models.retinanet.decode_retina  # noqa: F401
        bbox_xyxy, cls_score = mx.sym.Custom(
            op_type="decode_retina",
            im_info=im_info,
            stride=stride,
            scales=scales,
            ratios=ratios,
            per_level_top_n=pre_nms_top_n,
            thresh=min_det_score,
            **cls_score_dict,
            **bbox_delta_dict,
            name="rois")
    return cls_score, bbox_xyxy
def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
    """Build RetinaNet focal-loss classification and smooth-L1 regression
    losses, optionally normalized by a cross-device foreground count.

    Parameters
    ----------
    conv_feat : per-stride feature symbols, fed to ``self.get_output``.
    cls_label : per-anchor class labels in (N, A*H*W) layout.
    bbox_target : regression targets matching the concatenated delta layout.
    bbox_weight : per-anchor weights masking non-foreground regression terms.

    Returns
    -------
    (cls_loss, reg_loss) : the two loss symbols.
    """
    import mxnet as mx
    p = self.p
    stride = p.anchor_generate.stride
    if not isinstance(stride, tuple):
        # BUGFIX: was ``stride = (stride)``, which is a no-op (parentheses
        # alone do not make a tuple) and left a scalar stride un-iterable.
        stride = (stride,)
    num_class = p.num_class
    num_base_anchor = len(p.anchor_generate.ratio) * len(
        p.anchor_generate.scale)
    image_per_device = p.batch_image
    sync_loss = p.sync_loss or False

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_feat)
    cls_logit_reshape_list = []
    bbox_delta_reshape_list = []
    # fp16 training scales gradients up by 128 to avoid underflow
    scale_loss_shift = 128.0 if p.fp16 else 1.0

    if sync_loss:
        # normalize by the cross-device foreground count instead of the
        # per-device 'valid' count; rpn_fg_count is supplied by the pipeline
        fg_count = X.var("rpn_fg_count") * image_per_device
        fg_count = mx.sym.slice_axis(fg_count, axis=0, begin=0, end=1)

    # reshape logit and delta
    for s in stride:
        # (N, A * C, H, W) -> (N, A, C, H * W)
        cls_logit = X.reshape(
            data=cls_logit_dict["stride%s" % s],
            shape=(0, num_base_anchor, num_class - 1, -1),
            name="cls_stride%s_reshape" % s)
        # (N, A, C, H * W) -> (N, A, H * W, C)
        cls_logit = X.transpose(
            data=cls_logit,
            axes=(0, 1, 3, 2),
            name="cls_stride%s_transpose" % s)
        # (N, A, H * W, C) -> (N, A * H * W, C)
        cls_logit = X.reshape(
            data=cls_logit,
            shape=(0, -3, 0),
            name="cls_stride%s_transpose_reshape" % s)
        # (N, A * 4, H, W) -> (N, A * 4, H * W)
        bbox_delta = X.reshape(
            data=bbox_delta_dict["stride%s" % s],
            shape=(0, 0, -1),
            name="bbox_stride%s_reshape" % s)
        cls_logit_reshape_list.append(cls_logit)
        bbox_delta_reshape_list.append(bbox_delta)

    cls_logit_concat = X.concat(cls_logit_reshape_list, axis=1, name="bbox_logit_concat")
    bbox_delta_concat = X.concat(bbox_delta_reshape_list, axis=2, name="bbox_delta_concat")

    # classification loss
    if sync_loss:
        cls_loss = X.focal_loss(
            data=cls_logit_concat,
            label=cls_label,
            alpha=p.focal_loss.alpha,
            gamma=p.focal_loss.gamma,
            workspace=1500,
            out_grad=True)
        cls_loss = mx.sym.broadcast_div(cls_loss, fg_count)
        cls_loss = X.make_loss(cls_loss, grad_scale=scale_loss_shift, name="cls_loss")
    else:
        cls_loss = X.focal_loss(
            data=cls_logit_concat,
            label=cls_label,
            normalization='valid',
            alpha=p.focal_loss.alpha,
            gamma=p.focal_loss.gamma,
            grad_scale=1.0 * scale_loss_shift,
            workspace=1024,
            name="cls_loss")

    # smooth-L1 transition point; sqrt(1/0.11) ~= 3.0
    scalar = 0.11
    # regression loss
    bbox_loss = bbox_weight * X.smooth_l1(
        data=bbox_delta_concat - bbox_target,
        scalar=math.sqrt(1 / scalar),
        name="bbox_loss")
    if sync_loss:
        bbox_loss = mx.sym.broadcast_div(bbox_loss, fg_count)
    else:
        bbox_loss = X.bbox_norm(
            data=bbox_loss,
            label=cls_label,
            name="bbox_norm")
    reg_loss = X.make_loss(
        data=bbox_loss,
        grad_scale=1.0 * scale_loss_shift,
        name="reg_loss")
    return cls_loss, reg_loss
def get_loss(self, conv_feat, gt_bbox, im_info):
    """Build the FreeAnchor positive/negative losses.

    Reshapes per-level logits and deltas to (N, total_anchors, C) / (N,
    total_anchors, 4), prepares the matching anchors, and feeds everything to
    the FreeAnchor ``_positive_loss`` / ``_negative_loss`` ops.

    Returns
    -------
    (positive_loss, negative_loss) : the two loss symbols.
    """
    import mxnet as mx
    p = self.p
    stride = p.anchor_generate.stride
    if not isinstance(stride, tuple):
        # BUGFIX: was ``stride = (stride)``, which is a no-op (parentheses
        # alone do not make a tuple) and left a scalar stride un-iterable.
        stride = (stride,)
    num_class = p.num_class
    num_base_anchor = len(p.anchor_generate.ratio) * len(p.anchor_generate.scale)
    image_per_device = p.batch_image

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_feat)
    cls_logit_reshape_list = []
    bbox_delta_reshape_list = []
    feat_list = []
    # fp16 training scales gradients up by 128 to avoid underflow
    scale_loss_shift = 128.0 if p.fp16 else 1.0

    # reshape logit and delta
    for s in stride:
        # (N, A * C, H, W) -> (N, A * C, H * W)
        cls_logit = X.reshape(
            data=cls_logit_dict["stride%s" % s],
            shape=(0, 0, -1),
            name="cls_stride%s_reshape" % s
        )
        # (N, A * 4, H, W) -> (N, A * 4, H * W)
        bbox_delta = X.reshape(
            data=bbox_delta_dict["stride%s" % s],
            shape=(0, 0, -1),
            name="bbox_stride%s_reshape" % s
        )
        cls_logit_reshape_list.append(cls_logit)
        bbox_delta_reshape_list.append(bbox_delta)
        # raw per-level logits are kept so _prepare_anchors can infer sizes
        feat_list.append(cls_logit_dict["stride%s" % s])

    # cls_logits -> (N, H' * W' * A, C)
    cls_logits = X.concat(cls_logit_reshape_list, axis=2, name="cls_logit_concat")
    cls_logits = X.transpose(cls_logits, axes=(0, 2, 1), name="cls_logit_transpose")
    cls_logits = X.reshape(cls_logits, shape=(0, -1, num_class - 1), name="cls_logit_reshape")
    cls_prob = X.sigmoid(cls_logits)

    # bbox_deltas -> (N, H' * W' * A, 4)
    bbox_deltas = X.concat(bbox_delta_reshape_list, axis=2, name="bbox_delta_concat")
    bbox_deltas = X.transpose(bbox_deltas, axes=(0, 2, 1), name="bbox_delta_transpose")
    bbox_deltas = X.reshape(bbox_deltas, shape=(0, -1, 4), name="bbox_delta_reshape")

    anchor_list = [self.anchor_dict["stride%s" % s] for s in stride]
    bbox_thr = p.anchor_assign.bbox_thr
    pre_anchor_top_n = p.anchor_assign.pre_anchor_top_n
    alpha = p.focal_loss.alpha
    gamma = p.focal_loss.gamma
    anchor_target_mean = p.head.mean or (0, 0, 0, 0)
    anchor_target_std = p.head.std or (1, 1, 1, 1)

    from models.FreeAnchor.ops import _prepare_anchors, _positive_loss, _negative_loss
    anchors = _prepare_anchors(
        mx.sym,
        feat_list,
        anchor_list,
        image_per_device,
        num_base_anchor)
    positive_loss = _positive_loss(
        mx.sym,
        anchors,
        gt_bbox,
        cls_prob,
        bbox_deltas,
        image_per_device,
        alpha,
        pre_anchor_top_n,
        anchor_target_mean,
        anchor_target_std
    )
    positive_loss = X.make_loss(
        data=positive_loss,
        grad_scale=1.0 * scale_loss_shift,
        name="positive_loss"
    )
    negative_loss = _negative_loss(
        mx.sym,
        anchors,
        gt_bbox,
        cls_prob,
        bbox_deltas,
        im_info,
        image_per_device,
        num_class,
        alpha,
        gamma,
        pre_anchor_top_n,
        bbox_thr,
        anchor_target_mean,
        anchor_target_std
    )
    negative_loss = X.make_loss(
        data=negative_loss,
        grad_scale=1.0 * scale_loss_shift,
        name="negative_loss"
    )
    return positive_loss, negative_loss
def get_prediction(self, conv_feat, im_info):
    """Decode RepPoints outputs into per-image, per-level scores and boxes.

    For each stride: converts refined point offsets to boxes, takes the
    top-scoring locations, clips boxes to the image, and pads a background
    column onto the scores. Per-image results are stacked back into a batch
    and all levels are concatenated.

    Returns
    -------
    (cls_score, bbox_xyxy) : score symbol with a leading background column,
    and clipped boxes, both concatenated across strides on axis 1.
    """
    from models.RepPoints.point_ops import _gen_points, _points2bbox
    p = self.p
    batch_image = p.batch_image
    stride = p.point_generate.stride
    transform = p.point_generate.transform
    pre_nms_top_n = p.proposal.pre_nms_top_n

    pts_out_inits, pts_out_refines, cls_outs = self.get_output(conv_feat)

    cls_score_dict = dict()
    bbox_xyxy_dict = dict()
    for s in stride:
        # NOTE: pre_nms_top_n_ is hard-coded as -1 because the number of proposals is less
        # than pre_nms_top_n in these low-resolution feature maps. Also note that one should
        # select the appropriate params here if using low-resolution images as input.
        pre_nms_top_n_ = pre_nms_top_n if s <= 32 else -1
        # base point locations for this level
        points_ = _gen_points(mx.symbol, pts_out_inits["stride%s" % s], s)
        # refined offsets -> per-location box predictions
        preds_refines_ = _points2bbox(mx.symbol,
                                      pts_out_refines["stride%s" % s],
                                      transform,
                                      moment_transfer=self.moment_transfer)
        # (N, C, H, W) -> (N, H*W, C)
        preds_refines_ = X.reshape(
            X.transpose(data=preds_refines_, axes=(0, 2, 3, 1)), (0, -3, -2))
        cls_ = X.reshape(
            X.transpose(data=cls_outs["stride%s" % s], axes=(0, 2, 3, 1)), (0, -3, -2))
        scores_ = X.sigmoid(cls_)
        # rank locations by their best class score
        max_scores_ = mx.symbol.max(scores_, axis=-1)
        max_index_ = mx.symbol.topk(max_scores_, axis=1, k=pre_nms_top_n_)

        scores_dict = dict()
        bboxes_dict = dict()
        # decode each image in the batch separately (take() needs flat indices)
        for i in range(batch_image):
            max_index_i = X.reshape(
                mx.symbol.slice_axis(max_index_, axis=0, begin=i, end=i + 1), (-1, ))
            scores_i = X.reshape(
                mx.symbol.slice_axis(scores_, axis=0, begin=i, end=i + 1), (-3, -2))
            points_i = X.reshape(points_, (-3, -2))
            preds_refines_i = X.reshape(
                mx.symbol.slice_axis(preds_refines_, axis=0, begin=i, end=i + 1), (-3, -2))
            # gather the top locations
            scores_i = mx.symbol.take(scores_i, max_index_i)
            points_i = mx.symbol.take(points_i, max_index_i)
            preds_refines_i = mx.symbol.take(preds_refines_i, max_index_i)
            # (x, y) of each location, duplicated to (x, y, x, y) so the
            # predicted ltrb offsets can be added in one shot
            points_i = mx.symbol.slice_axis(points_i, axis=-1, begin=0, end=2)
            points_xyxy_i = X.concat([points_i, points_i],
                                     axis=-1,
                                     name="points_xyxy_b{}_s{}".format(i, s))
            # offsets are in stride units; scale to pixels and shift by the point
            bboxes_i = preds_refines_i * s + points_xyxy_i
            # clip each coordinate into [0, side - 1]
            im_info_i = mx.symbol.slice_axis(im_info, axis=0, begin=i, end=i + 1)
            h_i, w_i, _ = mx.symbol.split(im_info_i, num_outputs=3, axis=1)
            l_i, t_i, r_i, b_i = mx.symbol.split(bboxes_i, num_outputs=4, axis=1)
            clip_l_i = mx.symbol.maximum(
                mx.symbol.broadcast_minimum(l_i, w_i - 1.0), 0.0)
            clip_t_i = mx.symbol.maximum(
                mx.symbol.broadcast_minimum(t_i, h_i - 1.0), 0.0)
            clip_r_i = mx.symbol.maximum(
                mx.symbol.broadcast_minimum(r_i, w_i - 1.0), 0.0)
            clip_b_i = mx.symbol.maximum(
                mx.symbol.broadcast_minimum(b_i, h_i - 1.0), 0.0)
            clip_bboxes_i = X.concat(
                [clip_l_i, clip_t_i, clip_r_i, clip_b_i],
                axis=1,
                name="clip_bboxes_b{}_s{}".format(i, s))
            scores_dict["img%s" % i] = scores_i
            bboxes_dict["img%s" % i] = clip_bboxes_i

        # re-stack per-image results into a batch
        cls_score_ = mx.symbol.stack(
            *[scores_dict["img%s" % i] for i in range(batch_image)], axis=0)
        # prepend a zero background column so downstream code can index class 0
        pad_zeros_ = mx.symbol.zeros_like(
            mx.symbol.slice_axis(cls_score_, axis=-1, begin=0, end=1))
        cls_score_ = X.concat([pad_zeros_, cls_score_],
                              axis=-1,
                              name="cls_score_s{}".format(s))
        bboxes_ = mx.symbol.stack(
            *[bboxes_dict["img%s" % i] for i in range(batch_image)], axis=0)
        cls_score_dict["stride%s" % s] = cls_score_
        bbox_xyxy_dict["stride%s" % s] = bboxes_

    cls_score = X.concat([cls_score_dict["stride%s" % s] for s in stride],
                         axis=1,
                         name="cls_score_concat")
    bbox_xyxy = X.concat([bbox_xyxy_dict["stride%s" % s] for s in stride],
                         axis=1,
                         name="bbox_xyxy_concat")
    return cls_score, bbox_xyxy
def get_loss(self, conv_feat, gt_bbox):
    """Build the three RepPoints losses.

    Two-stage targets: the init stage assigns targets by point-in-box
    ("point" assignment), the refine stage by max-IoU against boxes decoded
    from the init-stage offsets ("box" assignment).

    Returns
    -------
    (cls_loss, pts_init_loss, pts_refine_loss, points_init_labels,
    points_refine_labels) : the losses plus gradient-blocked labels (both
    label outputs are derived from the refine-stage labels).
    """
    from models.RepPoints.point_ops import (_gen_points, _offset_to_pts,
                                            _point_target, _offset_to_boxes,
                                            _points2bbox)
    p = self.p
    batch_image = p.batch_image
    num_points = p.point_generate.num_points
    scale = p.point_generate.scale
    stride = p.point_generate.stride
    transform = p.point_generate.transform
    target_scale = p.point_target.target_scale
    num_pos = p.point_target.num_pos
    pos_iou_thr = p.bbox_target.pos_iou_thr
    neg_iou_thr = p.bbox_target.neg_iou_thr
    min_pos_iou = p.bbox_target.min_pos_iou

    pts_out_inits, pts_out_refines, cls_outs = self.get_output(conv_feat)

    points = dict()
    bboxes = dict()
    pts_coordinate_preds_inits = dict()
    pts_coordinate_preds_refines = dict()
    for s in stride:
        # generate points on base coordinate according to stride and size of feature map
        points["stride%s" % s] = _gen_points(mx.symbol,
                                             pts_out_inits["stride%s" % s], s)
        # generate bbox after init stage (gradient-blocked: these boxes only
        # serve as proposals for the refine-stage target assignment)
        bboxes["stride%s" % s] = _offset_to_boxes(
            mx.symbol,
            points["stride%s" % s],
            X.block_grad(pts_out_inits["stride%s" % s]),
            s,
            transform,
            moment_transfer=self.moment_transfer)
        # generate final offsets in init stage
        pts_coordinate_preds_inits["stride%s" % s] = _offset_to_pts(
            mx.symbol, points["stride%s" % s],
            pts_out_inits["stride%s" % s], s, num_points)
        # generate final offsets in refine stage
        pts_coordinate_preds_refines["stride%s" % s] = _offset_to_pts(
            mx.symbol, points["stride%s" % s],
            pts_out_refines["stride%s" % s], s, num_points)

    # for init stage, use points assignment
    point_proposals = mx.symbol.tile(X.concat(
        [points["stride%s" % s] for s in stride],
        axis=1,
        name="point_concat"), reps=(batch_image, 1, 1))
    points_labels_init, points_gts_init, points_weight_init = _point_target(
        mx.symbol,
        point_proposals,
        gt_bbox,
        batch_image,
        "point",
        scale=target_scale,
        num_pos=num_pos)

    # for refine stage, use max iou assignment
    box_proposals = X.concat([bboxes["stride%s" % s] for s in stride],
                             axis=1,
                             name="box_concat")
    points_labels_refine, points_gts_refine, points_weight_refine = _point_target(
        mx.symbol,
        box_proposals,
        gt_bbox,
        batch_image,
        "box",
        pos_iou_thr=pos_iou_thr,
        neg_iou_thr=neg_iou_thr,
        min_pos_iou=min_pos_iou)

    bboxes_out_strides = dict()
    for s in stride:
        # (N, C, H, W) -> (N, H*W, C)
        cls_outs["stride%s" % s] = X.reshape(
            X.transpose(data=cls_outs["stride%s" % s], axes=(0, 2, 3, 1)),
            (0, -3, -2))
        # per-location (4,) tensor filled with this level's stride, used to
        # normalize box regression residuals below
        bboxes_out_strides["stride%s" % s] = mx.symbol.repeat(
            mx.symbol.ones_like(
                mx.symbol.slice_axis(
                    cls_outs["stride%s" % s], begin=0, end=1, axis=-1)),
            repeats=4,
            axis=-1) * s

    # cls branch: focal loss against the refine-stage labels
    cls_outs_concat = X.concat([cls_outs["stride%s" % s] for s in stride],
                               axis=1,
                               name="cls_concat")
    cls_loss = X.focal_loss(data=cls_outs_concat,
                            label=points_labels_refine,
                            normalization='valid',
                            alpha=p.focal_loss.alpha,
                            gamma=p.focal_loss.gamma,
                            grad_scale=1.0,
                            workspace=1500,
                            name="cls_loss")

    # init box branch: smooth-L1 on stride-normalized box residuals
    pts_inits_concat_ = X.concat(
        [pts_coordinate_preds_inits["stride%s" % s] for s in stride],
        axis=1,
        name="pts_init_concat_")
    pts_inits_concat = X.reshape(pts_inits_concat_, (-3, -2),
                                 name="pts_inits_concat")
    bboxes_inits_concat_ = _points2bbox(
        mx.symbol,
        pts_inits_concat,
        transform,
        y_first=False,
        moment_transfer=self.moment_transfer)
    bboxes_inits_concat = X.reshape(bboxes_inits_concat_,
                                    (-4, batch_image, -1, -2))
    normalize_term = X.concat(
        [bboxes_out_strides["stride%s" % s] for s in stride],
        axis=1,
        name="normalize_term") * scale
    pts_init_loss = X.smooth_l1(
        data=(bboxes_inits_concat - points_gts_init) / normalize_term,
        scalar=3.0,
        name="pts_init_l1_loss")
    pts_init_loss = pts_init_loss * points_weight_init
    pts_init_loss = X.bbox_norm(data=pts_init_loss,
                                label=points_labels_init,
                                name="pts_init_norm_loss")
    # init loss is down-weighted relative to the refine loss
    pts_init_loss = X.make_loss(data=pts_init_loss,
                                grad_scale=0.5,
                                name="pts_init_loss")
    points_init_labels = X.block_grad(points_labels_refine,
                                      name="points_init_labels")

    # refine box branch: same residual form against refine-stage targets
    pts_refines_concat_ = X.concat(
        [pts_coordinate_preds_refines["stride%s" % s] for s in stride],
        axis=1,
        name="pts_refines_concat_")
    pts_refines_concat = X.reshape(pts_refines_concat_, (-3, -2),
                                   name="pts_refines_concat")
    bboxes_refines_concat_ = _points2bbox(
        mx.symbol,
        pts_refines_concat,
        transform,
        y_first=False,
        moment_transfer=self.moment_transfer)
    bboxes_refines_concat = X.reshape(bboxes_refines_concat_,
                                      (-4, batch_image, -1, -2))
    pts_refine_loss = X.smooth_l1(
        data=(bboxes_refines_concat - points_gts_refine) / normalize_term,
        scalar=3.0,
        name="pts_refine_l1_loss")
    pts_refine_loss = pts_refine_loss * points_weight_refine
    pts_refine_loss = X.bbox_norm(data=pts_refine_loss,
                                  label=points_labels_refine,
                                  name="pts_refine_norm_loss")
    pts_refine_loss = X.make_loss(data=pts_refine_loss,
                                  grad_scale=1.0,
                                  name="pts_refine_loss")
    points_refine_labels = X.block_grad(points_labels_refine,
                                        name="point_refine_labels")

    return cls_loss, pts_init_loss, pts_refine_loss, points_init_labels, points_refine_labels
def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
    """Build RetinaNet focal-loss classification and smooth-L1 regression
    losses from per-level output lists.

    Parameters
    ----------
    conv_feat : feature symbols fed to ``self.get_output``, which returns one
        cls logit and one bbox delta symbol per stride (as lists).
    cls_label : per-anchor class labels in (N, A*H*W) layout.
    bbox_target : regression targets matching the concatenated delta layout.
    bbox_weight : per-anchor weights masking non-foreground regression terms.

    Returns
    -------
    (cls_loss, reg_loss) : the two loss symbols.
    """
    p = self.p
    stride = p.anchor_generate.stride
    if not isinstance(stride, tuple):
        # BUGFIX: was ``stride = (stride)``, which is a no-op (parentheses
        # alone do not make a tuple) and left a scalar stride un-iterable.
        stride = (stride,)
    num_class = p.num_class
    num_base_anchor = len(p.anchor_generate.ratio) * len(
        p.anchor_generate.scale)

    cls_logit_list, bbox_delta_list = self.get_output(conv_feat)

    # reshape logit and delta
    for i, s in enumerate(stride):
        # (N, A * C, H, W) -> (N, A, C, H * W)
        cls_logit = X.reshape(
            data=cls_logit_list[i],
            shape=(0, num_base_anchor, num_class - 1, -1),
            name="cls_stride%s_reshape" % s)
        # (N, A, C, H * W) -> (N, A, H * W, C)
        cls_logit = X.transpose(
            data=cls_logit,
            axes=(0, 1, 3, 2),
            name="cls_stride%s_transpose" % s)
        # (N, A, H * W, C) -> (N, A * H * W, C)
        cls_logit = X.reshape(
            data=cls_logit,
            shape=(0, -3, 0),
            name="cls_stride%s_transpose_reshape" % s)
        # (N, A * 4, H, W) -> (N, A * 4, H * W)
        bbox_delta = X.reshape(
            data=bbox_delta_list[i],
            shape=(0, 0, -1),
            name="bbox_stride%s_reshape" % s)
        cls_logit_list[i] = cls_logit
        bbox_delta_list[i] = bbox_delta

    cls_logit_concat = X.concat(cls_logit_list, axis=1, name="bbox_logit_concat")
    bbox_delta_concat = X.concat(bbox_delta_list, axis=2, name="bbox_delta_concat")

    # classification loss, normalized over valid anchors
    cls_loss = X.focal_loss(
        data=cls_logit_concat,
        label=cls_label,
        normalization='valid',
        alpha=p.focal_loss.alpha,
        gamma=p.focal_loss.gamma,
        grad_scale=1.0,
        workspace=1024,
        name="cls_loss")

    # smooth-L1 transition point; sqrt(1/0.11) ~= 3.0
    scalar = 0.11
    # regression loss: residual is normalized by the foreground count before
    # the smooth-L1, then masked by bbox_weight
    bbox_norm = X.bbox_norm(
        data=bbox_delta_concat - bbox_target,
        label=cls_label,
        name="bbox_norm")
    bbox_loss = bbox_weight * X.smooth_l1(
        data=bbox_norm,
        scalar=math.sqrt(1 / scalar),
        name="bbox_loss")
    reg_loss = X.make_loss(
        data=bbox_loss,
        grad_scale=1.0,
        name="reg_loss")
    return cls_loss, reg_loss