def get_loss(self, conv_fpn_feat, gt_bbox, im_info):
    p = self.p
    bs = p.batch_image  # batch size on a single GPU

    centerness_logit_dict, cls_logit_dict, offset_logit_dict = self.get_output(conv_fpn_feat)

    # prepare gt
    ignore_label = X.block_grad(X.var(
        'ignore_label', init=X.constant(p.loss_setting.ignore_label), shape=(1, 1)))
    ignore_offset = X.block_grad(X.var(
        'ignore_offset', init=X.constant(p.loss_setting.ignore_offset), shape=(1, 1, 1)))
    gt_bbox = X.var('gt_bbox')
    im_info = X.var('im_info')
    centerness_labels, cls_labels, offset_labels = make_fcos_gt(
        gt_bbox, im_info,
        p.loss_setting.ignore_offset,
        p.loss_setting.ignore_label,
        p.FCOSParam.num_classifier)
    centerness_labels = X.block_grad(centerness_labels)
    cls_labels = X.block_grad(cls_labels)
    offset_labels = X.block_grad(offset_labels)

    # gather output logits across FPN levels:
    # per-level (N, c, Hi, Wi) -> (N, c, HiWi), then concat along the spatial axis
    cls_logit_list = []
    centerness_logit_list = []
    offset_logit_list = []
    for stride in p.FCOSParam.stride:
        cls_logit_list.append(mx.sym.reshape(cls_logit_dict[stride], shape=(0, 0, -1)))
        centerness_logit_list.append(mx.sym.reshape(centerness_logit_dict[stride], shape=(0, 0, -1)))
        offset_logit_list.append(mx.sym.reshape(offset_logit_dict[stride], shape=(0, 0, -1)))
    cls_logits = mx.sym.reshape(mx.sym.concat(*cls_logit_list, dim=2), shape=(0, -1))
    centerness_logits = mx.sym.reshape(mx.sym.concat(*centerness_logit_list, dim=2), shape=(0, -1))
    offset_logits = mx.sym.reshape(mx.sym.concat(*offset_logit_list, dim=2), shape=(0, 4, -1))

    # classification loss: sigmoid focal loss over non-ignored positions
    nonignore_mask = mx.sym.broadcast_not_equal(lhs=cls_labels, rhs=ignore_label)
    nonignore_mask = X.block_grad(nonignore_mask)
    cls_loss = make_sigmoid_focal_loss(
        gamma=p.loss_setting.focal_loss_gamma,
        alpha=p.loss_setting.focal_loss_alpha,
        logits=cls_logits,
        labels=cls_labels,
        nonignore_mask=nonignore_mask)
    cls_loss = X.loss(cls_loss, grad_scale=1)

    # centerness loss: binary cross entropy, computed only inside gt boxes
    nonignore_mask = mx.sym.broadcast_logical_and(
        lhs=mx.sym.broadcast_not_equal(lhs=X.block_grad(centerness_labels), rhs=ignore_label),
        rhs=mx.sym.broadcast_greater(lhs=centerness_labels, rhs=mx.sym.full((1, 1), 0)))
    nonignore_mask = X.block_grad(nonignore_mask)
    centerness_loss = make_binary_cross_entropy_loss(centerness_logits, centerness_labels, nonignore_mask)
    centerness_loss = X.loss(centerness_loss, grad_scale=1)

    # regression loss: centerness-weighted IoU loss on the (l, t, r, b) offsets
    offset_loss = IoULoss(offset_logits, offset_labels, ignore_offset, centerness_labels, name='offset_loss')

    return centerness_loss, cls_loss, offset_loss

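# --- Illustrative sketch, not part of the original code ---
# A minimal NumPy sketch of the per-position sigmoid focal loss that
# `make_sigmoid_focal_loss` above is assumed to compute, following
# FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) from the RetinaNet paper.
# The symbolic version additionally masks ignored positions.
import numpy as np

def sigmoid_focal_loss_np(logits, labels, gamma=2.0, alpha=0.25):
    # logits, labels: arrays of the same shape; labels in {0, 1}
    p = 1.0 / (1.0 + np.exp(-logits))            # sigmoid probability
    p_t = np.where(labels == 1, p, 1.0 - p)      # probability of the true class
    alpha_t = np.where(labels == 1, alpha, 1.0 - alpha)
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(np.clip(p_t, 1e-12, 1.0))

# usage: well-classified positions (first entry) contribute almost nothing
logits = np.array([4.0, -1.0, 0.5])
labels = np.array([1.0, 0.0, 1.0])
print(sigmoid_focal_loss_np(logits, labels))
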
def cls_pc_loss(self, logits, tsd_logits, gt_label, scale_loss_shift):
    '''
    TSD classification progressive constraint
    Args:
        logits: [batch_image * image_roi, num_class]
        tsd_logits: [batch_image * image_roi, num_class]
        gt_label: [batch_image * image_roi]
        scale_loss_shift: float
    Returns:
        loss: [batch_image * image_roi]
    '''
    p = self.p
    batch_image = p.batch_image
    image_roi = p.image_roi
    batch_roi = batch_image * image_roi
    margin = p.TSD.pc_cls_margin

    cls_prob = mx.sym.SoftmaxActivation(logits, mode='instance')
    tsd_prob = mx.sym.SoftmaxActivation(tsd_logits, mode='instance')
    # probability of the ground-truth class for each roi
    cls_score = mx.sym.pick(cls_prob, gt_label, axis=1)
    tsd_score = mx.sym.pick(tsd_prob, gt_label, axis=1)
    cls_score = X.block_grad(cls_score)

    # hinge: the TSD score must exceed the sibling score by an adaptive margin
    cls_pc_margin = mx.sym.minimum(1. - cls_score, margin)
    loss = mx.sym.relu(-(tsd_score - cls_score - cls_pc_margin))

    grad_scale = 1. / batch_roi
    grad_scale *= scale_loss_shift
    loss = X.loss(loss, grad_scale=grad_scale, name='cls_pc_loss')
    return loss

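# --- Illustrative sketch, not part of the original code ---
# A small NumPy illustration of the classification progressive constraint above:
# the TSD branch's gt-class probability must exceed the sibling head's by an
# adaptive margin min(1 - cls_score, margin), otherwise a linear penalty applies.
import numpy as np

def cls_pc_loss_np(cls_score, tsd_score, margin=0.2):
    pc_margin = np.minimum(1.0 - cls_score, margin)
    return np.maximum(0.0, -(tsd_score - cls_score - pc_margin))  # relu hinge

print(cls_pc_loss_np(np.array([0.9, 0.5]), np.array([0.95, 0.55])))
# -> [0.05 0.15]: the first pair nearly satisfies its clipped margin of 0.1,
#    the second still owes 0.15 of its 0.2 margin.
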
def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
    p = self.p
    batch_image = p.batch_image
    image_anchor = p.anchor_generate.image_anchor

    cls_logit, bbox_delta = self.get_output(conv_feat)

    scale_loss_shift = 128.0 if p.fp16 else 1.0

    # classification loss
    cls_logit_reshape = X.reshape(
        cls_logit,
        shape=(0, -4, 2, -1, 0, 0),  # (N,C,H,W) -> (N,2,C/2,H,W)
        name="rpn_cls_logit_reshape")
    cls_loss = X.softmax_output(
        data=cls_logit_reshape,
        label=cls_label,
        multi_output=True,
        normalization='valid',
        use_ignore=True,
        ignore_label=-1,
        grad_scale=1.0 * scale_loss_shift,
        name="rpn_cls_loss")

    # regression loss
    reg_loss = X.smooth_l1(
        (bbox_delta - bbox_target),
        scalar=3.0,
        name='rpn_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / (batch_image * image_anchor) * scale_loss_shift,
        name='rpn_reg_loss')

    return cls_loss, reg_loss

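# --- Illustrative sketch, not part of the original code ---
# A NumPy sketch of MXNet's smooth_l1 with the `scalar` (sigma) convention used
# above: quadratic within |x| < 1/sigma^2, linear outside. With sigma = 3 the
# transition point is 1/9, which keeps the RPN regression loss from being
# dominated by outlier anchors.
import numpy as np

def smooth_l1_np(x, sigma=3.0):
    t = 1.0 / (sigma ** 2)
    return np.where(np.abs(x) < t,
                    0.5 * (sigma * x) ** 2,
                    np.abs(x) - 0.5 * t)

print(smooth_l1_np(np.array([0.05, 0.5, -2.0])))
# small residuals fall in the quadratic zone, large ones in the linear zone
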
def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
    p = self.p
    batch_roi = p.image_roi * p.batch_image
    batch_image = p.batch_image

    cls_logit, bbox_delta = self.get_output(conv_feat)

    scale_loss_shift = 128.0 if p.fp16 else 1.0

    # classification loss
    cls_loss = X.softmax_output(
        data=cls_logit,
        label=cls_label,
        normalization='batch',
        grad_scale=1.0 * scale_loss_shift,
        name='bbox_cls_loss')

    # bounding box regression
    reg_loss = X.smooth_l1(
        bbox_delta - bbox_target,
        scalar=1.0,
        name='bbox_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / batch_roi * scale_loss_shift,
        name='bbox_reg_loss')

    # append label
    cls_label = X.reshape(
        cls_label,
        shape=(batch_image, -1),
        name='bbox_label_reshape')
    cls_label = X.block_grad(cls_label, name='bbox_label_blockgrad')

    # output
    return cls_loss, reg_loss, cls_label

def get_loss(self, conv_fpn_feat, cls_label, bbox_target, bbox_weight):
    p = self.p
    batch_image = p.batch_image
    image_anchor = p.anchor_generate.image_anchor
    rpn_stride = p.anchor_generate.stride

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_fpn_feat)

    scale_loss_shift = 128.0 if p.fp16 else 1.0

    rpn_cls_logit_list = []
    rpn_bbox_delta_list = []
    for stride in rpn_stride:
        rpn_cls_logit = cls_logit_dict[stride]
        rpn_bbox_delta = bbox_delta_dict[stride]
        rpn_cls_logit_reshape = X.reshape(
            data=rpn_cls_logit,
            shape=(0, 2, -1),
            name="rpn_cls_score_reshape_stride%s" % stride)
        rpn_bbox_delta_reshape = X.reshape(
            data=rpn_bbox_delta,
            shape=(0, 0, -1),
            name="rpn_bbox_pred_reshape_stride%s" % stride)
        rpn_bbox_delta_list.append(rpn_bbox_delta_reshape)
        rpn_cls_logit_list.append(rpn_cls_logit_reshape)

    # concat output of each level
    rpn_bbox_delta_concat = X.concat(rpn_bbox_delta_list, axis=2, name="rpn_bbox_pred_concat")
    rpn_cls_logit_concat = X.concat(rpn_cls_logit_list, axis=2, name="rpn_cls_score_concat")

    # classification loss
    cls_loss = X.softmax_output(
        data=rpn_cls_logit_concat,
        label=cls_label,
        multi_output=True,
        normalization='valid',
        use_ignore=True,
        ignore_label=-1,
        grad_scale=1.0 * scale_loss_shift,
        name="rpn_cls_loss")

    # regression loss
    reg_loss = X.smooth_l1(
        (rpn_bbox_delta_concat - bbox_target),
        scalar=3.0,
        name='rpn_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / (batch_image * image_anchor) * scale_loss_shift,
        name='rpn_reg_loss')

    return cls_loss, reg_loss

def get_loss(self, conv_feat, gt_bboxes, im_infos, rpn_groups):
    p = self.p
    num_class = p.num_class
    batch_image = p.batch_image
    image_anchor = p.anchor_generate.image_anchor

    cls_logit, bbox_delta = self.get_output(conv_feat)

    scale_loss_shift = 128.0 if p.fp16 else 1.0

    cls_label = X.var("rpn_cls_label")
    bbox_target = X.var("rpn_reg_target")
    bbox_weight = X.var("rpn_reg_weight")

    # classification loss
    cls_logit_reshape = X.reshape(
        cls_logit,
        shape=(0, -4, num_class, -1, 0, 0),  # (N,C,H,W) -> (N,num_class,C/num_class,H,W)
        name="rpn_cls_logit_reshape")
    if p.use_groupsoftmax:
        cls_loss = mx.sym.contrib.GroupSoftmaxOutput(
            data=cls_logit_reshape,
            label=cls_label,
            group=rpn_groups,
            multi_output=True,
            normalization='valid',
            use_ignore=True,
            ignore_label=-1,
            grad_scale=1.0 * scale_loss_shift,
            name="rpn_cls_loss")
    else:
        cls_loss = X.softmax_output(
            data=cls_logit_reshape,
            label=cls_label,
            multi_output=True,
            normalization='valid',
            use_ignore=True,
            ignore_label=-1,
            grad_scale=1.0 * scale_loss_shift,
            name="rpn_cls_loss")

    # regression loss
    reg_loss = X.smooth_l1(
        (bbox_delta - bbox_target),
        scalar=3.0,
        name='rpn_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / (batch_image * image_anchor) * scale_loss_shift,
        name='rpn_reg_loss')

    return cls_loss, reg_loss

def IoULoss(x_box, y_box, ignore_offset, centerness_label, name='iouloss'):
    centerness_label = mx.sym.reshape(centerness_label, shape=(0, 1, -1))
    y_box = X.block_grad(y_box)

    # targets are (l, t, r, b) distances from a location to the four box sides
    target_left = mx.sym.slice_axis(y_box, axis=1, begin=0, end=1)
    target_top = mx.sym.slice_axis(y_box, axis=1, begin=1, end=2)
    target_right = mx.sym.slice_axis(y_box, axis=1, begin=2, end=3)
    target_bottom = mx.sym.slice_axis(y_box, axis=1, begin=3, end=4)

    # filter out out-of-bbox area, loss is only computed inside gt boxes
    nonignore_mask = mx.sym.broadcast_logical_and(
        lhs=mx.sym.broadcast_not_equal(lhs=target_left, rhs=ignore_offset),
        rhs=mx.sym.broadcast_greater(lhs=centerness_label, rhs=mx.sym.full((1, 1, 1), 0)))
    nonignore_mask = X.block_grad(nonignore_mask)
    x_box = mx.sym.clip(x_box, a_min=0, a_max=1e4)
    x_box = mx.sym.broadcast_mul(lhs=x_box, rhs=nonignore_mask)
    centerness_label = centerness_label * nonignore_mask

    pred_left = mx.sym.slice_axis(x_box, axis=1, begin=0, end=1)
    pred_top = mx.sym.slice_axis(x_box, axis=1, begin=1, end=2)
    pred_right = mx.sym.slice_axis(x_box, axis=1, begin=2, end=3)
    pred_bottom = mx.sym.slice_axis(x_box, axis=1, begin=3, end=4)

    target_area = (target_left + target_right) * (target_top + target_bottom)
    pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)

    # intersection of two (l, t, r, b) boxes anchored at the same location:
    # each intersection side is the sum of the elementwise minima of the offsets
    w_intersect = mx.sym.minimum(pred_left, target_left) + mx.sym.minimum(pred_right, target_right)
    h_intersect = mx.sym.minimum(pred_bottom, target_bottom) + mx.sym.minimum(pred_top, target_top)

    area_intersect = w_intersect * h_intersect
    area_union = target_area + pred_area - area_intersect

    # -log(IoU), weighted by centerness and normalized by the total centerness mass
    loss = -mx.sym.log((area_intersect + 1.0) / (area_union + 1.0))
    loss = mx.sym.broadcast_mul(lhs=loss, rhs=centerness_label)
    loss = mx.sym.sum(loss) / (mx.sym.sum(centerness_label) + 1e-30)
    return X.loss(loss, grad_scale=1, name=name)

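# --- Illustrative sketch, not part of the original code ---
# A NumPy sketch of the -log(IoU) computation above for (l, t, r, b) offsets
# measured from the same location. The symbolic version additionally applies
# the centerness weighting and the ignore mask.
import numpy as np

def iou_loss_np(pred, target):
    # pred, target: [N, 4] arrays of (l, t, r, b) distances, all >= 0
    pl, pt, pr, pb = pred.T
    tl, tt, tr, tb = target.T
    pred_area = (pl + pr) * (pt + pb)
    target_area = (tl + tr) * (tt + tb)
    w_i = np.minimum(pl, tl) + np.minimum(pr, tr)
    h_i = np.minimum(pt, tt) + np.minimum(pb, tb)
    inter = w_i * h_i
    union = pred_area + target_area - inter
    return -np.log((inter + 1.0) / (union + 1.0))

# perfect overlap gives zero loss
print(iou_loss_np(np.array([[1., 1., 1., 1.]]), np.array([[1., 1., 1., 1.]])))  # -> [0.]
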
def get_loss(self, conv_fpn_feat, gt_bbox, im_info):
    p = self.p
    batch_image = p.batch_image
    image_anchor = p.anchor_assign.image_anchor
    rpn_stride = p.anchor_generate.stride
    anchor_scale = p.anchor_generate.scale
    anchor_ratio = p.anchor_generate.ratio
    num_anchor = len(anchor_ratio) * len(anchor_scale)

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_fpn_feat)

    scale_loss_shift = 128.0 if p.fp16 else 1.0

    rpn_cls_logit_list = []
    rpn_bbox_delta_list = []
    feat_list = []
    for stride in rpn_stride:
        rpn_cls_logit = cls_logit_dict[stride]
        rpn_bbox_delta = bbox_delta_dict[stride]
        rpn_cls_logit_reshape = X.reshape(
            data=rpn_cls_logit,
            shape=(0, 2, num_anchor, -1),
            name="rpn_cls_score_reshape_stride%s" % stride)
        rpn_bbox_delta_reshape = X.reshape(
            data=rpn_bbox_delta,
            shape=(0, 0, -1),
            name="rpn_bbox_pred_reshape_stride%s" % stride)
        rpn_bbox_delta_list.append(rpn_bbox_delta_reshape)
        rpn_cls_logit_list.append(rpn_cls_logit_reshape)
        feat_list.append(rpn_cls_logit)

    if p.nnvm_rpn_target:
        # compute the rpn target on device
        from mxnext.tvm.rpn_target import _fpn_rpn_target_batch
        anchor_list = [self.anchor_dict["stride%s" % s] for s in rpn_stride]
        gt_bbox = mx.sym.slice_axis(gt_bbox, axis=-1, begin=0, end=4)
        max_side = p.anchor_generate.max_side
        allowed_border = p.anchor_assign.allowed_border
        fg_fraction = p.anchor_assign.pos_fraction
        fg_thr = p.anchor_assign.pos_thr
        bg_thr = p.anchor_assign.neg_thr
        cls_label, bbox_target, bbox_weight = _fpn_rpn_target_batch(
            mx.sym, feat_list, anchor_list, gt_bbox, im_info, batch_image,
            num_anchor, max_side, rpn_stride, allowed_border, image_anchor,
            fg_fraction, fg_thr, bg_thr)
    else:
        # rpn target is precomputed by the data loader
        cls_label = X.var("rpn_cls_label")
        bbox_target = X.var("rpn_reg_target")
        bbox_weight = X.var("rpn_reg_weight")

    # concat output of each level
    rpn_bbox_delta_concat = X.concat(rpn_bbox_delta_list, axis=2, name="rpn_bbox_pred_concat")
    rpn_cls_logit_concat = X.concat(rpn_cls_logit_list, axis=-1, name="rpn_cls_score_concat")

    cls_loss = X.softmax_output(
        data=rpn_cls_logit_concat,
        label=cls_label,
        multi_output=True,
        normalization='valid',
        use_ignore=True,
        ignore_label=-1,
        grad_scale=1.0 * scale_loss_shift,
        name="rpn_cls_loss")

    # regression loss
    reg_loss = X.smooth_l1(
        (rpn_bbox_delta_concat - bbox_target),
        scalar=3.0,
        name='rpn_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / (batch_image * image_anchor) * scale_loss_shift,
        name='rpn_reg_loss')

    return cls_loss, reg_loss, X.stop_grad(cls_label, "rpn_cls_label_blockgrad")

def reg_pc_loss(self, bbox_delta, tsd_bbox_delta, rois, tsd_rois, gt_bbox, gt_label, scale_loss_shift):
    '''
    TSD regression progressive constraint
    Args:
        bbox_delta: [batch_image * image_roi, num_class*4]
        tsd_bbox_delta: [batch_image * image_roi, num_class*4]
        rois: [batch_image, image_roi, 4]
        tsd_rois: [batch_image, image_roi, 4]
        gt_bbox: [batch_image, max_gt_num, 4]
        gt_label: [batch_image * image_roi]
        scale_loss_shift: float
    Returns:
        loss: [batch_image * image_roi]
    '''
    def _box_decode(rois, deltas, means, stds):
        # decode (dx, dy, dw, dh) deltas against rois into (x1, y1, x2, y2) boxes
        rois = X.block_grad(rois)
        rois = mx.sym.reshape(rois, [-1, 4])
        deltas = mx.sym.reshape(deltas, [-1, 4])
        x1, y1, x2, y2 = mx.sym.split(rois, axis=-1, num_outputs=4, squeeze_axis=True)
        dx, dy, dw, dh = mx.sym.split(deltas, axis=-1, num_outputs=4, squeeze_axis=True)
        dx = dx * stds[0] + means[0]
        dy = dy * stds[1] + means[1]
        dw = dw * stds[2] + means[2]
        dh = dh * stds[3] + means[3]
        x = (x1 + x2) * 0.5
        y = (y1 + y2) * 0.5
        w = x2 - x1 + 1
        h = y2 - y1 + 1
        nx = x + dx * w
        ny = y + dy * h
        nw = w * mx.sym.exp(dw)
        nh = h * mx.sym.exp(dh)
        nx1 = nx - 0.5 * nw
        ny1 = ny - 0.5 * nh
        nx2 = nx + 0.5 * nw
        ny2 = ny + 0.5 * nh
        return mx.sym.stack(nx1, ny1, nx2, ny2, axis=1, name='pc_reg_loss_decoded_roi')

    def _gather_3d(data, indices, n):
        # pick the per-class entry indicated by indices along axis 1
        datas = mx.sym.split(data, axis=-1, num_outputs=n, squeeze_axis=True)
        outputs = []
        for d in datas:
            outputs.append(mx.sym.pick(d, indices, axis=1))
        return mx.sym.stack(*outputs, axis=1)

    batch_image = self.p.batch_image
    image_roi = self.p.image_roi
    batch_roi = batch_image * image_roi
    num_class = self.p.num_class
    bbox_mean = self.p.regress_target.mean
    bbox_std = self.p.regress_target.std
    margin = self.p.TSD.pc_reg_margin

    gt_label = mx.sym.reshape(gt_label, (-1, ))
    bbox_delta = mx.sym.reshape(bbox_delta, (batch_roi, num_class, 4))
    tsd_bbox_delta = mx.sym.reshape(tsd_bbox_delta, (batch_roi, num_class, 4))
    # keep only the deltas of each roi's ground-truth class
    bbox_delta = _gather_3d(bbox_delta, gt_label, n=4)
    tsd_bbox_delta = _gather_3d(tsd_bbox_delta, gt_label, n=4)

    boxes = _box_decode(rois, bbox_delta, bbox_mean, bbox_std)
    tsd_bboxes = _box_decode(tsd_rois, tsd_bbox_delta, bbox_mean, bbox_std)

    rois = mx.sym.reshape(rois, [batch_image, -1, 4])
    tsd_rois = mx.sym.reshape(tsd_rois, [batch_image, -1, 4])
    boxes = mx.sym.reshape(boxes, [batch_image, -1, 4])
    tsd_bboxes = mx.sym.reshape(tsd_bboxes, [batch_image, -1, 4])

    rois_group = mx.sym.split(rois, axis=0, num_outputs=batch_image, squeeze_axis=True)
    tsd_rois_group = mx.sym.split(tsd_rois, axis=0, num_outputs=batch_image, squeeze_axis=True)
    boxes_group = mx.sym.split(boxes, axis=0, num_outputs=batch_image, squeeze_axis=True)
    tsd_bboxes_group = mx.sym.split(tsd_bboxes, axis=0, num_outputs=batch_image, squeeze_axis=True)
    gt_group = mx.sym.split(gt_bbox, axis=0, num_outputs=batch_image, squeeze_axis=True)

    ious = []
    tsd_ious = []
    for rois_i, tsd_rois_i, boxes_i, tsd_boxes_i, gt_i in \
            zip(rois_group, tsd_rois_group, boxes_group, tsd_bboxes_group, gt_group):
        # match each roi to its highest-IoU ground-truth box, then score the
        # decoded prediction against that matched gt
        iou_mat = get_iou_mat(rois_i, gt_i, image_roi)
        tsd_iou_mat = get_iou_mat(tsd_rois_i, gt_i, image_roi)
        matched_gt = mx.sym.gather_nd(
            gt_i, X.reshape(mx.sym.argmax(iou_mat, axis=1), [1, -1]))
        tsd_matched_gt = mx.sym.gather_nd(
            gt_i, X.reshape(mx.sym.argmax(tsd_iou_mat, axis=1), [1, -1]))
        matched_gt = mx.sym.slice_axis(matched_gt, axis=-1, begin=0, end=4)
        tsd_matched_gt = mx.sym.slice_axis(tsd_matched_gt, axis=-1, begin=0, end=4)
        ious.append(get_iou(boxes_i, matched_gt))
        tsd_ious.append(get_iou(tsd_boxes_i, tsd_matched_gt))
    iou = mx.sym.concat(*ious, dim=0)
    tsd_iou = mx.sym.concat(*tsd_ious, dim=0)

    # hinge: push the TSD IoU above the sibling IoU by an adaptive margin,
    # computed on foreground rois only
    weight = X.block_grad(gt_label != 0)
    iou = X.block_grad(iou)
    reg_pc_margin = mx.sym.minimum(1. - iou, margin)
    loss = mx.sym.relu(-(tsd_iou - iou - reg_pc_margin))

    grad_scale = 1. / batch_roi
    grad_scale *= scale_loss_shift
    loss = X.loss(weight * loss, grad_scale=grad_scale, name='reg_pc_loss')
    return loss

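# --- Illustrative sketch, not part of the original code ---
# The regression constraint mirrors the classification one, but on IoU with the
# matched ground truth instead of class probability. A NumPy illustration:
import numpy as np

def reg_pc_loss_np(iou, tsd_iou, margin=0.2):
    pc_margin = np.minimum(1.0 - iou, margin)
    return np.maximum(0.0, -(tsd_iou - iou - pc_margin))

# the TSD box must beat the sibling box's IoU by the margin (or by 1 - iou when
# the sibling is already above 1 - margin) before its penalty reaches zero
print(reg_pc_loss_np(np.array([0.6]), np.array([0.7])))  # -> [0.1]
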
def get_loss(self, rois, roi_feat, fpn_conv_feats, cls_label, bbox_target, bbox_weight, gt_bbox):
    '''
    Args:
        rois: [batch_image, image_roi, 4]
        roi_feat: [batch_image * image_roi, 256, roi_size, roi_size]
        fpn_conv_feats: dict of FPN features, each [batch_image, in_channels, fh, fw]
        cls_label: [batch_image * image_roi]
        bbox_target: [batch_image * image_roi, num_class * 4]
        bbox_weight: [batch_image * image_roi, num_class * 4]
        gt_bbox: [batch_image, max_gt_num, 4]
    Returns:
        cls_loss: [batch_image * image_roi, num_class]
        reg_loss: [batch_image * image_roi, num_class * 4]
        tsd_cls_loss: [batch_image * image_roi, num_class]
        tsd_reg_loss: [batch_image * image_roi, num_class * 4]
        tsd_cls_pc_loss: [batch_image * image_roi], if p.TSD.pc_cls
        tsd_reg_pc_loss: [batch_image * image_roi], if p.TSD.pc_reg
        cls_label: [batch_image, image_roi]
    '''
    p = self.p
    assert not p.regress_target.class_agnostic
    batch_image = p.batch_image
    image_roi = p.image_roi
    batch_roi = batch_image * image_roi
    smooth_l1_scalar = p.regress_target.smooth_l1_scalar or 1.0

    cls_logit, bbox_delta, tsd_cls_logit, tsd_bbox_delta, delta_c, delta_r = self.get_output(
        fpn_conv_feats, roi_feat, rois, is_train=True)
    # regression targets for the TSD branch are computed against the shifted rois
    rois_r = self._get_delta_r_box(delta_r, rois)
    tsd_reg_target = self.get_reg_target(rois_r, gt_bbox)  # [batch_roi, num_class*4]

    scale_loss_shift = 128.0 if p.fp16 else 1.0

    # sibling-head losses
    cls_loss = X.softmax_output(
        data=cls_logit,
        label=cls_label,
        normalization='batch',
        grad_scale=1.0 * scale_loss_shift,
        name='bbox_cls_loss')
    reg_loss = X.smooth_l1(
        bbox_delta - bbox_target,
        scalar=smooth_l1_scalar,
        name='bbox_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / batch_roi * scale_loss_shift,
        name='bbox_reg_loss')

    # tsd losses
    tsd_cls_loss = X.softmax_output(
        data=tsd_cls_logit,
        label=cls_label,
        normalization='batch',
        grad_scale=1.0 * scale_loss_shift,
        name='tsd_bbox_cls_loss')
    tsd_reg_loss = X.smooth_l1(
        tsd_bbox_delta - tsd_reg_target,
        scalar=smooth_l1_scalar,
        name='tsd_bbox_reg_l1')
    tsd_reg_loss = bbox_weight * tsd_reg_loss
    tsd_reg_loss = X.loss(
        tsd_reg_loss,
        grad_scale=1.0 / batch_roi * scale_loss_shift,
        name='tsd_bbox_reg_loss')

    losses = [cls_loss, reg_loss, tsd_cls_loss, tsd_reg_loss]

    # progressive constraints
    if p.TSD.pc_cls:
        losses.append(self.cls_pc_loss(cls_logit, tsd_cls_logit, cls_label, scale_loss_shift))
    if p.TSD.pc_reg:
        losses.append(self.reg_pc_loss(bbox_delta, tsd_bbox_delta, rois, rois_r,
                                       gt_bbox, cls_label, scale_loss_shift))

    # append label
    cls_label = X.reshape(cls_label, shape=(batch_image, -1), name='bbox_label_reshape')
    cls_label = X.block_grad(cls_label, name='bbox_label_blockgrad')
    losses.append(cls_label)

    return tuple(losses)

def emd_loss(self, cls_logit, cls_label, cls_sec_logit, cls_sec_label,
             bbox_delta, bbox_target, bbox_sec_delta, bbox_sec_target,
             bbox_weight, bbox_sec_weight, prefix=""):
    p = self.p
    smooth_l1_scalar = p.regress_target.smooth_l1_scalar or 1.0
    scale_loss_shift = 128.0 if p.fp16 else 1.0

    # classification loss under the two possible (prediction, label) assignments
    cls_loss11 = self.softmax_entropy(cls_logit, cls_label, prefix=prefix + 'cls_loss11')
    cls_loss12 = self.softmax_entropy(cls_sec_logit, cls_sec_label, prefix=prefix + 'cls_loss12')
    cls_loss1 = cls_loss11 + cls_loss12
    cls_loss21 = self.softmax_entropy(cls_logit, cls_sec_label, prefix=prefix + 'cls_loss21')
    cls_loss22 = self.softmax_entropy(cls_sec_logit, cls_label, prefix=prefix + 'cls_loss22')
    cls_loss2 = cls_loss21 + cls_loss22

    # bounding box regression under the same two assignments
    reg_loss11 = X.smooth_l1(bbox_delta - bbox_target,
                             scalar=smooth_l1_scalar,
                             name=prefix + 'bbox_reg_l1_11')
    reg_loss11 = bbox_weight * reg_loss11
    reg_loss12 = X.smooth_l1(bbox_sec_delta - bbox_sec_target,
                             scalar=smooth_l1_scalar,
                             name=prefix + 'bbox_reg_l1_12')
    reg_loss12 = bbox_sec_weight * reg_loss12
    reg_loss1 = reg_loss11 + reg_loss12
    reg_loss21 = X.smooth_l1(bbox_delta - bbox_sec_target,
                             scalar=smooth_l1_scalar,
                             name=prefix + 'bbox_reg_l1_21')
    reg_loss21 = bbox_sec_weight * reg_loss21
    reg_loss22 = X.smooth_l1(bbox_sec_delta - bbox_target,
                             scalar=smooth_l1_scalar,
                             name=prefix + 'bbox_reg_l1_22')
    reg_loss22 = bbox_weight * reg_loss22
    reg_loss2 = reg_loss21 + reg_loss22

    # keep the cheaper of the two assignments for each proposal
    cls_reg_loss1 = mx.sym.sum(cls_loss1, axis=-1) + mx.sym.sum(reg_loss1, axis=-1)
    cls_reg_loss2 = mx.sym.sum(cls_loss2, axis=-1) + mx.sym.sum(reg_loss2, axis=-1)
    cls_reg_loss = mx.sym.minimum(cls_reg_loss1, cls_reg_loss2)
    cls_reg_loss = mx.sym.mean(cls_reg_loss)
    cls_reg_loss = X.loss(cls_reg_loss,
                          grad_scale=1.0 * scale_loss_shift,
                          name=prefix + 'cls_reg_loss')
    return cls_reg_loss

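# --- Illustrative sketch, not part of the original code ---
# A NumPy sketch of the EMD (set) loss idea above: with two predictions and two
# targets per proposal, compute the total loss under both possible assignments
# and keep the cheaper one, so the two heads are free to specialize. Squared
# error stands in for softmax CE + smooth L1 here, purely for illustration.
import numpy as np

def emd_min_np(pred1, pred2, tgt1, tgt2):
    loss1 = ((pred1 - tgt1) ** 2).sum(-1) + ((pred2 - tgt2) ** 2).sum(-1)
    loss2 = ((pred1 - tgt2) ** 2).sum(-1) + ((pred2 - tgt1) ** 2).sum(-1)
    return np.minimum(loss1, loss2)

p1 = np.array([[0.0, 0.0]]); p2 = np.array([[1.0, 1.0]])
t1 = np.array([[1.0, 1.0]]); t2 = np.array([[0.0, 0.0]])
print(emd_min_np(p1, p2, t1, t2))  # -> [0.], the swapped assignment matches exactly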