def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
    """Build the bbox-head losses: softmax classification and smooth-L1
    box regression, plus the gradient-blocked class label for metrics.

    Returns (cls_loss, reg_loss, cls_label).
    """
    p = self.p
    n_image = p.batch_image
    n_roi = p.image_roi * p.batch_image

    cls_logit, bbox_delta = self.get_output(conv_feat)

    # fp16 training scales losses up so small gradients stay representable
    loss_scale = 128.0 if p.fp16 else 1.0

    # classification loss, normalized over the whole batch
    cls_loss = X.softmax_output(
        data=cls_logit,
        label=cls_label,
        normalization='batch',
        grad_scale=1.0 * loss_scale,
        name='bbox_cls_loss')

    # smooth-L1 box regression, masked by per-roi weights and
    # normalized by the total number of rois in the batch
    l1 = X.smooth_l1(bbox_delta - bbox_target, scalar=1.0, name='bbox_reg_l1')
    reg_loss = X.loss(
        bbox_weight * l1,
        grad_scale=1.0 / n_roi * loss_scale,
        name='bbox_reg_loss',
    )

    # expose the label per-image with gradients blocked (metric only)
    cls_label = X.reshape(cls_label, shape=(n_image, -1), name='bbox_label_reshape')
    cls_label = X.block_grad(cls_label, name='bbox_label_blockgrad')

    return cls_loss, reg_loss, cls_label
def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
    """RPN losses: fg/bg softmax (with ignore label) and smooth-L1 box
    regression normalized by the anchor budget per batch.

    Returns (cls_loss, reg_loss).
    """
    p = self.p
    n_image = p.batch_image
    n_anchor = p.anchor_generate.image_anchor

    cls_logit, bbox_delta = self.get_output(conv_feat)

    # fp16 training scales losses up so small gradients stay representable
    loss_scale = 128.0 if p.fp16 else 1.0

    # (N, C, H, W) -> (N, 2, C/2, H, W) so softmax runs over the fg/bg axis
    logit_2ch = X.reshape(
        cls_logit,
        shape=(0, -4, 2, -1, 0, 0),
        name="rpn_cls_logit_reshape")
    cls_loss = X.softmax_output(
        data=logit_2ch,
        label=cls_label,
        multi_output=True,
        normalization='valid',
        use_ignore=True,
        ignore_label=-1,  # anchors excluded from sampling
        grad_scale=1.0 * loss_scale,
        name="rpn_cls_loss")

    # weighted smooth-L1 regression over sampled anchors
    l1 = X.smooth_l1(bbox_delta - bbox_target, scalar=3.0, name='rpn_reg_l1')
    reg_loss = X.loss(
        bbox_weight * l1,
        grad_scale=1.0 / (n_image * n_anchor) * loss_scale,
        name='rpn_reg_loss')

    return cls_loss, reg_loss
def get_loss(self, conv_fpn_feat, cls_label, bbox_target, bbox_weight):
    """FPN-RPN losses: flatten each pyramid level, concatenate across
    levels, then apply softmax classification and smooth-L1 regression.

    Returns (cls_loss, reg_loss).
    """
    p = self.p
    n_image = p.batch_image
    n_anchor = p.anchor_generate.image_anchor
    strides = p.anchor_generate.stride

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_fpn_feat)

    # fp16 training scales losses up so small gradients stay representable
    loss_scale = 128.0 if p.fp16 else 1.0

    # flatten spatial dims per level so all levels concat along axis 2
    logit_per_level = [
        X.reshape(
            data=cls_logit_dict[s],
            shape=(0, 2, -1),
            name="rpn_cls_score_reshape_stride%s" % s)
        for s in strides
    ]
    delta_per_level = [
        X.reshape(
            data=bbox_delta_dict[s],
            shape=(0, 0, -1),
            name="rpn_bbox_pred_reshape_stride%s" % s)
        for s in strides
    ]

    # concat output of each level
    delta_concat = X.concat(delta_per_level, axis=2, name="rpn_bbox_pred_concat")
    logit_concat = X.concat(logit_per_level, axis=2, name="rpn_cls_score_concat")

    cls_loss = X.softmax_output(
        data=logit_concat,
        label=cls_label,
        multi_output=True,
        normalization='valid',
        use_ignore=True,
        ignore_label=-1,  # anchors excluded from sampling
        grad_scale=1.0 * loss_scale,
        name="rpn_cls_loss")

    # weighted smooth-L1 regression, normalized by the anchor budget
    l1 = X.smooth_l1(delta_concat - bbox_target, scalar=3.0, name='rpn_reg_l1')
    reg_loss = X.loss(
        bbox_weight * l1,
        grad_scale=1.0 / (n_image * n_anchor) * loss_scale,
        name='rpn_reg_loss')

    return cls_loss, reg_loss
def get_loss(self, conv_feat, gt_bboxes, im_infos, rpn_groups):
    """RPN losses with optional group-softmax classification.

    Targets come in as external symbol variables (`rpn_cls_label`,
    `rpn_reg_target`, `rpn_reg_weight`) bound by the data pipeline;
    `gt_bboxes` / `im_infos` are accepted for interface compatibility
    but not read here.

    Returns (cls_loss, reg_loss).
    """
    p = self.p
    num_class = p.num_class
    batch_image = p.batch_image
    image_anchor = p.anchor_generate.image_anchor

    cls_logit, bbox_delta = self.get_output(conv_feat)
    # fp16 training scales losses up so small gradients stay representable
    scale_loss_shift = 128.0 if p.fp16 else 1.0

    cls_label = X.var("rpn_cls_label")
    bbox_target = X.var("rpn_reg_target")
    bbox_weight = X.var("rpn_reg_weight")

    # classification loss
    cls_logit_reshape = X.reshape(
        cls_logit,
        shape=(0, -4, num_class, -1, 0, 0),  # (N,C,H,W) -> (N,num_class,C/num_class,H,W)
        name="rpn_cls_logit_reshape")

    # NOTE: removed a dead `cls_loss = None` — both branches assign it.
    if p.use_groupsoftmax:
        # group-aware softmax: labels within the same group compete
        cls_loss = mx.sym.contrib.GroupSoftmaxOutput(
            data=cls_logit_reshape,
            label=cls_label,
            group=rpn_groups,
            multi_output=True,
            normalization='valid',
            use_ignore=True,
            ignore_label=-1,
            grad_scale=1.0 * scale_loss_shift,
            name="rpn_cls_loss")
    else:
        cls_loss = X.softmax_output(
            data=cls_logit_reshape,
            label=cls_label,
            multi_output=True,
            normalization='valid',
            use_ignore=True,
            ignore_label=-1,
            grad_scale=1.0 * scale_loss_shift,
            name="rpn_cls_loss")

    # regression loss, normalized by the anchor budget per batch
    reg_loss = X.smooth_l1(
        (bbox_delta - bbox_target), scalar=3.0, name='rpn_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / (batch_image * image_anchor) * scale_loss_shift,
        name='rpn_reg_loss')

    return cls_loss, reg_loss
def get_loss(self, conv_fpn_feat, gt_bbox, im_info):
    """FPN-RPN losses with optional on-graph (nnvm/TVM) target generation.

    When `p.nnvm_rpn_target` is set, anchor targets are computed inside
    the symbolic graph via `_fpn_rpn_target_batch`; otherwise they are
    read from external variables bound by the data pipeline.

    Returns (cls_loss, reg_loss, blocked cls_label).
    """
    p = self.p
    batch_image = p.batch_image
    image_anchor = p.anchor_assign.image_anchor
    rpn_stride = p.anchor_generate.stride
    # NOTE(review): anchor_scale/anchor_ratio are assigned but never used below
    anchor_scale = p.anchor_generate.scale
    anchor_ratio = p.anchor_generate.ratio
    num_anchor = len(p.anchor_generate.ratio) * len(
        p.anchor_generate.scale)

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_fpn_feat)
    # fp16 training scales losses up so small gradients stay representable
    scale_loss_shift = 128.0 if p.fp16 else 1.0

    rpn_cls_logit_list = []
    rpn_bbox_delta_list = []
    feat_list = []
    for stride in rpn_stride:
        rpn_cls_logit = cls_logit_dict[stride]
        rpn_bbox_delta = bbox_delta_dict[stride]
        # flatten spatial dims, keeping a 2-way (fg/bg) softmax axis
        rpn_cls_logit_reshape = X.reshape(
            data=rpn_cls_logit,
            shape=(0, 2, num_anchor, -1),
            name="rpn_cls_score_reshape_stride%s" % stride)
        rpn_bbox_delta_reshape = X.reshape(
            data=rpn_bbox_delta,
            shape=(0, 0, -1),
            name="rpn_bbox_pred_reshape_stride%s" % stride)
        rpn_bbox_delta_list.append(rpn_bbox_delta_reshape)
        rpn_cls_logit_list.append(rpn_cls_logit_reshape)
        # raw per-level logits are kept so the target op can infer feature sizes
        feat_list.append(rpn_cls_logit)

    if p.nnvm_rpn_target:
        # compute anchor targets inside the graph (no Python-side sampling)
        from mxnext.tvm.rpn_target import _fpn_rpn_target_batch
        anchor_list = [
            self.anchor_dict["stride%s" % s] for s in rpn_stride
        ]
        # keep only (x1, y1, x2, y2); drop any trailing class column
        gt_bbox = mx.sym.slice_axis(gt_bbox, axis=-1, begin=0, end=4)
        max_side = p.anchor_generate.max_side
        allowed_border = p.anchor_assign.allowed_border
        fg_fraction = p.anchor_assign.pos_fraction
        fg_thr = p.anchor_assign.pos_thr
        bg_thr = p.anchor_assign.neg_thr
        cls_label, bbox_target, bbox_weight = _fpn_rpn_target_batch(
            mx.sym, feat_list, anchor_list, gt_bbox, im_info, batch_image,
            num_anchor, max_side, rpn_stride, allowed_border, image_anchor,
            fg_fraction, fg_thr, bg_thr)
    else:
        # targets precomputed by the data pipeline, fed as variables
        cls_label = X.var("rpn_cls_label")
        bbox_target = X.var("rpn_reg_target")
        bbox_weight = X.var("rpn_reg_weight")

    # concat output of each level
    rpn_bbox_delta_concat = X.concat(rpn_bbox_delta_list,
                                     axis=2,
                                     name="rpn_bbox_pred_concat")
    rpn_cls_logit_concat = X.concat(rpn_cls_logit_list,
                                    axis=-1,
                                    name="rpn_cls_score_concat")

    cls_loss = X.softmax_output(data=rpn_cls_logit_concat,
                                label=cls_label,
                                multi_output=True,
                                normalization='valid',
                                use_ignore=True,
                                ignore_label=-1,
                                grad_scale=1.0 * scale_loss_shift,
                                name="rpn_cls_loss")

    # regression loss, normalized by the anchor budget per batch
    reg_loss = X.smooth_l1((rpn_bbox_delta_concat - bbox_target),
                           scalar=3.0,
                           name='rpn_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / (batch_image * image_anchor) * scale_loss_shift,
        name='rpn_reg_loss')

    # label is returned with gradients blocked (metric use only)
    return cls_loss, reg_loss, X.stop_grad(cls_label, "rpn_cls_label_blockgrad")
def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
    """RetinaNet-style losses: focal classification + smooth-L1 regression,
    with an optional cross-device-synchronized foreground count
    (`p.sync_loss`) used as the normalizer.

    Returns (cls_loss, reg_loss).
    """
    import mxnet as mx
    p = self.p
    stride = p.anchor_generate.stride
    # BUG FIX: `(stride)` is not a tuple — the missing comma left a bare
    # int, so `for s in stride` below raised TypeError for scalar strides.
    if not isinstance(stride, tuple):
        stride = (stride,)
    num_class = p.num_class
    num_base_anchor = len(p.anchor_generate.ratio) * len(
        p.anchor_generate.scale)
    image_per_device = p.batch_image
    sync_loss = p.sync_loss or False

    cls_logit_dict, bbox_delta_dict = self.get_output(conv_feat)
    cls_logit_reshape_list = []
    bbox_delta_reshape_list = []
    # fp16 training scales losses up so small gradients stay representable
    scale_loss_shift = 128.0 if p.fp16 else 1.0

    if sync_loss:
        # fg count is summed across devices; one scalar normalizes both losses
        fg_count = X.var("rpn_fg_count") * image_per_device
        fg_count = mx.sym.slice_axis(fg_count, axis=0, begin=0, end=1)

    # reshape logit and delta per stride
    for s in stride:
        # (N, A * C, H, W) -> (N, A, C, H * W)
        cls_logit = X.reshape(data=cls_logit_dict["stride%s" % s],
                              shape=(0, num_base_anchor, num_class - 1, -1),
                              name="cls_stride%s_reshape" % s)
        # (N, A, C, H * W) -> (N, A, H * W, C)
        cls_logit = X.transpose(data=cls_logit,
                                axes=(0, 1, 3, 2),
                                name="cls_stride%s_transpose" % s)
        # (N, A, H * W, C) -> (N, A * H * W, C)
        cls_logit = X.reshape(data=cls_logit,
                              shape=(0, -3, 0),
                              name="cls_stride%s_transpose_reshape" % s)
        # (N, A * 4, H, W) -> (N, A * 4, H * W)
        bbox_delta = X.reshape(data=bbox_delta_dict["stride%s" % s],
                               shape=(0, 0, -1),
                               name="bbox_stride%s_reshape" % s)
        cls_logit_reshape_list.append(cls_logit)
        bbox_delta_reshape_list.append(bbox_delta)

    cls_logit_concat = X.concat(cls_logit_reshape_list,
                                axis=1,
                                name="bbox_logit_concat")
    bbox_delta_concat = X.concat(bbox_delta_reshape_list,
                                 axis=2,
                                 name="bbox_delta_concat")

    # classification loss
    if sync_loss:
        # raw per-element focal loss, divided by the synchronized fg count
        cls_loss = X.focal_loss(data=cls_logit_concat,
                                label=cls_label,
                                alpha=p.focal_loss.alpha,
                                gamma=p.focal_loss.gamma,
                                workspace=1500,
                                out_grad=True)
        cls_loss = mx.sym.broadcast_div(cls_loss, fg_count)
        cls_loss = X.make_loss(cls_loss,
                               grad_scale=scale_loss_shift,
                               name="cls_loss")
    else:
        cls_loss = X.focal_loss(data=cls_logit_concat,
                                label=cls_label,
                                normalization='valid',
                                alpha=p.focal_loss.alpha,
                                gamma=p.focal_loss.gamma,
                                grad_scale=1.0 * scale_loss_shift,
                                workspace=1024,
                                name="cls_loss")

    # smooth-L1 transition point; sqrt(1/0.11) ~= 3.0
    scalar = 0.11
    # regression loss
    bbox_loss = bbox_weight * X.smooth_l1(
        data=bbox_delta_concat - bbox_target,
        scalar=math.sqrt(1 / scalar),
        name="bbox_loss")
    if sync_loss:
        bbox_loss = mx.sym.broadcast_div(bbox_loss, fg_count)
    else:
        # normalize by the per-device foreground count derived from labels
        bbox_loss = X.bbox_norm(data=bbox_loss,
                                label=cls_label,
                                name="bbox_norm")
    reg_loss = X.make_loss(data=bbox_loss,
                           grad_scale=1.0 * scale_loss_shift,
                           name="reg_loss")
    return cls_loss, reg_loss
def get_loss(self, rois, roi_feat, fpn_conv_feats, cls_label, bbox_target,
             bbox_weight, gt_bbox):
    '''
    TSD head losses: the sibling (origin) head plus the task-aware
    spatial disentanglement head, with optional progressive-constraint
    (PC) losses.

    Args:
        rois: [batch_image, image_roi, 4]
        roi_feat: [batch_image * image_roi, 256, roi_size, roi_size]
        fpn_conv_feats: dict of FPN features, each [batch_image, in_channels, fh, fw]
        cls_label: [batch_image * image_roi]
        bbox_target: [batch_image * image_roi, num_class * 4]
        bbox_weight: [batch_image * image_roi, num_class * 4]
        gt_bbox: [batch_image, max_gt_num, 4]
    Returns:
        cls_loss: [batch_image * image_roi, num_class]
        reg_loss: [batch_image * image_roi, num_class * 4]
        tsd_cls_loss: [batch_image * image_roi, num_class]
        tsd_reg_loss: [batch_image * image_roi, num_class * 4]
        tsd_cls_pc_loss: [batch_image * image_roi] (only when p.TSD.pc_cls)
        tsd_reg_pc_loss: [batch_image * image_roi] (only when p.TSD.pc_reg)
        cls_label: [batch_image, image_roi]
    '''
    p = self.p
    assert not p.regress_target.class_agnostic
    batch_image = p.batch_image
    image_roi = p.image_roi
    batch_roi = batch_image * image_roi
    smooth_l1_scalar = p.regress_target.smooth_l1_scalar or 1.0

    cls_logit, bbox_delta, tsd_cls_logit, tsd_bbox_delta, delta_c, delta_r = \
        self.get_output(fpn_conv_feats, roi_feat, rois, is_train=True)
    # TSD regression targets are recomputed against the delta_r-shifted rois
    rois_r = self._get_delta_r_box(delta_r, rois)
    tsd_reg_target = self.get_reg_target(
        rois_r, gt_bbox)  # [batch_roi, num_class*4]

    # fp16 training scales losses up so small gradients stay representable
    scale_loss_shift = 128 if self.p.fp16 else 1.0

    # origin loss
    cls_loss = X.softmax_output(data=cls_logit,
                                label=cls_label,
                                normalization='batch',
                                grad_scale=1.0 * scale_loss_shift,
                                name='bbox_cls_loss')
    reg_loss = X.smooth_l1(bbox_delta - bbox_target,
                           scalar=smooth_l1_scalar,
                           name='bbox_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / batch_roi * scale_loss_shift,
        name='bbox_reg_loss',
    )

    # tsd loss
    tsd_cls_loss = X.softmax_output(data=tsd_cls_logit,
                                    label=cls_label,
                                    normalization='batch',
                                    grad_scale=1.0 * scale_loss_shift,
                                    name='tsd_bbox_cls_loss')
    tsd_reg_loss = X.smooth_l1(tsd_bbox_delta - tsd_reg_target,
                               scalar=smooth_l1_scalar,
                               name='tsd_bbox_reg_l1')
    tsd_reg_loss = bbox_weight * tsd_reg_loss
    tsd_reg_loss = X.loss(
        tsd_reg_loss,
        grad_scale=1.0 / batch_roi * scale_loss_shift,
        name='tsd_bbox_reg_loss',
    )

    # BUG FIX: the base list previously included `tsd_cls_pc_loss`, a name
    # never defined in this scope (PC losses are only appended below,
    # conditionally) -> NameError at graph-construction time.
    losses = [cls_loss, reg_loss, tsd_cls_loss, tsd_reg_loss]
    if p.TSD.pc_cls:
        losses.append(
            self.cls_pc_loss(cls_logit, tsd_cls_logit, cls_label,
                             scale_loss_shift))
    if p.TSD.pc_reg:
        losses.append(
            self.reg_pc_loss(bbox_delta, tsd_bbox_delta, rois, rois_r,
                             gt_bbox, cls_label, scale_loss_shift))

    # append label, reshaped per-image and gradient-blocked (metric only)
    cls_label = X.reshape(cls_label,
                          shape=(batch_image, -1),
                          name='bbox_label_reshape')
    cls_label = X.block_grad(cls_label, name='bbox_label_blockgrad')
    losses.append(cls_label)
    return tuple(losses)
def get_loss(self, conv_feat, gt_bbox):
    """RepPoints losses.

    Builds the two-stage RepPoints training graph: an init stage whose
    targets come from point assignment, and a refine stage whose targets
    come from max-IoU assignment against boxes decoded from the init
    offsets. Classification uses focal loss on the refine-stage labels.

    Returns (cls_loss, pts_init_loss, pts_refine_loss,
             points_init_labels, points_refine_labels).
    """
    from models.RepPoints.point_ops import (_gen_points, _offset_to_pts,
                                            _point_target, _offset_to_boxes,
                                            _points2bbox)
    p = self.p
    batch_image = p.batch_image
    num_points = p.point_generate.num_points
    scale = p.point_generate.scale
    stride = p.point_generate.stride
    transform = p.point_generate.transform
    target_scale = p.point_target.target_scale
    num_pos = p.point_target.num_pos
    pos_iou_thr = p.bbox_target.pos_iou_thr
    neg_iou_thr = p.bbox_target.neg_iou_thr
    min_pos_iou = p.bbox_target.min_pos_iou

    pts_out_inits, pts_out_refines, cls_outs = self.get_output(conv_feat)

    points = dict()
    bboxes = dict()
    pts_coordinate_preds_inits = dict()
    pts_coordinate_preds_refines = dict()
    for s in stride:
        # generate points on base coordinate according to stride and size of feature map
        points["stride%s" % s] = _gen_points(mx.symbol,
                                             pts_out_inits["stride%s" % s],
                                             s)
        # generate bbox after init stage (init offsets are gradient-blocked
        # here so refine-stage box assignment does not backprop into init)
        bboxes["stride%s" % s] = _offset_to_boxes(
            mx.symbol,
            points["stride%s" % s],
            X.block_grad(pts_out_inits["stride%s" % s]),
            s,
            transform,
            moment_transfer=self.moment_transfer)
        # generate final offsets in init stage
        pts_coordinate_preds_inits["stride%s" % s] = _offset_to_pts(
            mx.symbol, points["stride%s" % s], pts_out_inits["stride%s" % s],
            s, num_points)
        # generate final offsets in refine stage
        pts_coordinate_preds_refines["stride%s" % s] = _offset_to_pts(
            mx.symbol, points["stride%s" % s], pts_out_refines["stride%s" % s],
            s, num_points)

    # for init stage, use points assignment
    point_proposals = mx.symbol.tile(X.concat(
        [points["stride%s" % s] for s in stride],
        axis=1,
        name="point_concat"),
                                     reps=(batch_image, 1, 1))
    points_labels_init, points_gts_init, points_weight_init = _point_target(
        mx.symbol,
        point_proposals,
        gt_bbox,
        batch_image,
        "point",
        scale=target_scale,
        num_pos=num_pos)

    # for refine stage, use max iou assignment
    box_proposals = X.concat([bboxes["stride%s" % s] for s in stride],
                             axis=1,
                             name="box_concat")
    points_labels_refine, points_gts_refine, points_weight_refine = _point_target(
        mx.symbol,
        box_proposals,
        gt_bbox,
        batch_image,
        "box",
        pos_iou_thr=pos_iou_thr,
        neg_iou_thr=neg_iou_thr,
        min_pos_iou=min_pos_iou)

    bboxes_out_strides = dict()
    for s in stride:
        # (N, C, H, W) -> (N, H*W, C) so all levels concat along axis 1
        cls_outs["stride%s" % s] = X.reshape(
            X.transpose(data=cls_outs["stride%s" % s], axes=(0, 2, 3, 1)),
            (0, -3, -2))
        # per-location stride constant, repeated over the 4 box coords;
        # used below to normalize regression residuals per level
        bboxes_out_strides["stride%s" % s] = mx.symbol.repeat(
            mx.symbol.ones_like(
                mx.symbol.slice_axis(
                    cls_outs["stride%s" % s], begin=0, end=1, axis=-1)),
            repeats=4,
            axis=-1) * s

    # cls branch
    cls_outs_concat = X.concat([cls_outs["stride%s" % s] for s in stride],
                               axis=1,
                               name="cls_concat")
    cls_loss = X.focal_loss(data=cls_outs_concat,
                            label=points_labels_refine,
                            normalization='valid',
                            alpha=p.focal_loss.alpha,
                            gamma=p.focal_loss.gamma,
                            grad_scale=1.0,
                            workspace=1500,
                            name="cls_loss")

    # init box branch
    pts_inits_concat_ = X.concat(
        [pts_coordinate_preds_inits["stride%s" % s] for s in stride],
        axis=1,
        name="pts_init_concat_")
    pts_inits_concat = X.reshape(pts_inits_concat_, (-3, -2),
                                 name="pts_inits_concat")
    bboxes_inits_concat_ = _points2bbox(
        mx.symbol,
        pts_inits_concat,
        transform,
        y_first=False,
        moment_transfer=self.moment_transfer)
    bboxes_inits_concat = X.reshape(bboxes_inits_concat_,
                                    (-4, batch_image, -1, -2))
    # stride-proportional normalizer keeps residuals comparable across levels
    normalize_term = X.concat(
        [bboxes_out_strides["stride%s" % s] for s in stride],
        axis=1,
        name="normalize_term") * scale
    pts_init_loss = X.smooth_l1(
        data=(bboxes_inits_concat - points_gts_init) / normalize_term,
        scalar=3.0,
        name="pts_init_l1_loss")
    pts_init_loss = pts_init_loss * points_weight_init
    pts_init_loss = X.bbox_norm(data=pts_init_loss,
                                label=points_labels_init,
                                name="pts_init_norm_loss")
    # init stage is down-weighted relative to refine (0.5 vs 1.0)
    pts_init_loss = X.make_loss(data=pts_init_loss,
                                grad_scale=0.5,
                                name="pts_init_loss")
    # NOTE(review): init labels are taken from the *refine* assignment
    # (points_labels_refine) — confirm this is intentional and not a typo
    # for points_labels_init.
    points_init_labels = X.block_grad(points_labels_refine,
                                      name="points_init_labels")

    # refine box branch
    pts_refines_concat_ = X.concat(
        [pts_coordinate_preds_refines["stride%s" % s] for s in stride],
        axis=1,
        name="pts_refines_concat_")
    pts_refines_concat = X.reshape(pts_refines_concat_, (-3, -2),
                                   name="pts_refines_concat")
    bboxes_refines_concat_ = _points2bbox(
        mx.symbol,
        pts_refines_concat,
        transform,
        y_first=False,
        moment_transfer=self.moment_transfer)
    bboxes_refines_concat = X.reshape(bboxes_refines_concat_,
                                      (-4, batch_image, -1, -2))
    pts_refine_loss = X.smooth_l1(
        data=(bboxes_refines_concat - points_gts_refine) / normalize_term,
        scalar=3.0,
        name="pts_refine_l1_loss")
    pts_refine_loss = pts_refine_loss * points_weight_refine
    pts_refine_loss = X.bbox_norm(data=pts_refine_loss,
                                  label=points_labels_refine,
                                  name="pts_refine_norm_loss")
    pts_refine_loss = X.make_loss(data=pts_refine_loss,
                                  grad_scale=1.0,
                                  name="pts_refine_loss")
    points_refine_labels = X.block_grad(points_labels_refine,
                                        name="point_refine_labels")

    return cls_loss, pts_init_loss, pts_refine_loss, points_init_labels, points_refine_labels
def emd_loss(self, cls_logit, cls_label, cls_sec_logit, cls_sec_label,
             bbox_delta, bbox_target, bbox_sec_delta, bbox_sec_target,
             bbox_weight, bbox_sec_weight, prefix=""):
    """EMD loss for pairwise predictions: evaluate both label/target
    assignment permutations and keep, per sample, the cheaper one
    (classification + regression summed), then average over the batch.
    """
    p = self.p
    l1_scalar = p.regress_target.smooth_l1_scalar or 1.0
    # fp16 training scales losses up so small gradients stay representable
    loss_scale = 128.0 if p.fp16 else 1.0

    def _reg_term(delta, target, weight, tag):
        # weighted smooth-L1 term for one (prediction, target) pairing
        return weight * X.smooth_l1(delta - target,
                                    scalar=l1_scalar,
                                    name=prefix + tag)

    # permutation 1: straight assignment
    cls_loss1 = (
        self.softmax_entropy(cls_logit, cls_label, prefix=prefix + 'cls_loss11')
        + self.softmax_entropy(cls_sec_logit, cls_sec_label, prefix=prefix + 'cls_loss12'))
    # permutation 2: crossed assignment
    cls_loss2 = (
        self.softmax_entropy(cls_logit, cls_sec_label, prefix=prefix + 'cls_loss21')
        + self.softmax_entropy(cls_sec_logit, cls_label, prefix=prefix + 'cls_loss22'))

    reg_loss1 = (_reg_term(bbox_delta, bbox_target, bbox_weight, 'bbox_reg_l1_11')
                 + _reg_term(bbox_sec_delta, bbox_sec_target, bbox_sec_weight, 'bbox_reg_l1_12'))
    reg_loss2 = (_reg_term(bbox_delta, bbox_sec_target, bbox_sec_weight, 'bbox_reg_l1_21')
                 + _reg_term(bbox_sec_delta, bbox_target, bbox_weight, 'bbox_reg_l1_22'))

    # per-sample totals for each permutation
    total1 = mx.sym.sum(cls_loss1, axis=-1) + mx.sym.sum(reg_loss1, axis=-1)
    total2 = mx.sym.sum(cls_loss2, axis=-1) + mx.sym.sum(reg_loss2, axis=-1)

    # keep the cheaper permutation per sample, then average over the batch
    emd = mx.sym.minimum(total1, total2)
    emd = mx.sym.mean(emd)
    return X.loss(emd,
                  grad_scale=1.0 * loss_scale,
                  name=prefix + 'cls_reg_loss')
def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
    """RetinaNet-style losses: focal classification over all levels plus
    foreground-normalized smooth-L1 box regression.

    Returns (cls_loss, reg_loss).
    """
    p = self.p
    stride = p.anchor_generate.stride
    # BUG FIX: `(stride)` is not a tuple — the missing comma left a bare
    # int, so `enumerate(stride)` below raised TypeError for scalar strides.
    if not isinstance(stride, tuple):
        stride = (stride,)
    num_class = p.num_class
    num_base_anchor = len(p.anchor_generate.ratio) * len(
        p.anchor_generate.scale)

    cls_logit_list, bbox_delta_list = self.get_output(conv_feat)

    # reshape logit and delta per stride
    for i, s in enumerate(stride):
        # (N, A * C, H, W) -> (N, A, C, H * W)
        cls_logit = X.reshape(data=cls_logit_list[i],
                              shape=(0, num_base_anchor, num_class - 1, -1),
                              name="cls_stride%s_reshape" % s)
        # (N, A, C, H * W) -> (N, A, H * W, C)
        cls_logit = X.transpose(data=cls_logit,
                                axes=(0, 1, 3, 2),
                                name="cls_stride%s_transpose" % s)
        # (N, A, H * W, C) -> (N, A * H * W, C)
        cls_logit = X.reshape(data=cls_logit,
                              shape=(0, -3, 0),
                              name="cls_stride%s_transpose_reshape" % s)
        # (N, A * 4, H, W) -> (N, A * 4, H * W)
        bbox_delta = X.reshape(data=bbox_delta_list[i],
                               shape=(0, 0, -1),
                               name="bbox_stride%s_reshape" % s)
        cls_logit_list[i] = cls_logit
        bbox_delta_list[i] = bbox_delta

    cls_logit_concat = X.concat(cls_logit_list,
                                axis=1,
                                name="bbox_logit_concat")
    bbox_delta_concat = X.concat(bbox_delta_list,
                                 axis=2,
                                 name="bbox_delta_concat")

    # classification loss
    cls_loss = X.focal_loss(data=cls_logit_concat,
                            label=cls_label,
                            normalization='valid',
                            alpha=p.focal_loss.alpha,
                            gamma=p.focal_loss.gamma,
                            grad_scale=1.0,
                            workspace=1024,
                            name="cls_loss")

    # smooth-L1 transition point; sqrt(1/0.11) ~= 3.0
    scalar = 0.11
    # regression loss: residuals normalized by foreground count first,
    # then smooth-L1 masked by the per-anchor weights
    bbox_norm = X.bbox_norm(data=bbox_delta_concat - bbox_target,
                            label=cls_label,
                            name="bbox_norm")
    bbox_loss = bbox_weight * X.smooth_l1(
        data=bbox_norm, scalar=math.sqrt(1 / scalar), name="bbox_loss")
    reg_loss = X.make_loss(data=bbox_loss, grad_scale=1.0, name="reg_loss")
    return cls_loss, reg_loss