def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
    """Build the bbox-head classification and regression losses.

    Args:
        conv_feat: feature input consumed by ``get_output``.
        cls_label: per-RoI class labels.
        bbox_target: regression targets matching the predicted deltas' layout.
        bbox_weight: per-element weights masking invalid regression entries.

    Returns:
        (cls_loss, reg_loss, cls_label) — the label is reshaped to
        (batch_image, -1) and detached from the gradient graph.
    """
    cfg = self.p
    n_image = cfg.batch_image
    n_roi = cfg.image_roi * n_image
    # fp16 training shifts the loss scale to keep gradients representable
    loss_scale = 128.0 if cfg.fp16 else 1.0

    cls_logit, bbox_delta = self.get_output(conv_feat)

    # classification loss
    cls_loss = X.softmax_output(
        data=cls_logit,
        label=cls_label,
        normalization='batch',
        grad_scale=1.0 * loss_scale,
        name='bbox_cls_loss')

    # bounding box regression: weighted smooth-L1 on the delta residual
    reg_residual = X.smooth_l1(bbox_delta - bbox_target,
                               scalar=1.0,
                               name='bbox_reg_l1')
    reg_loss = X.loss(
        bbox_weight * reg_residual,
        grad_scale=1.0 / n_roi * loss_scale,
        name='bbox_reg_loss',
    )

    # expose the (gradient-free) label alongside the losses
    cls_label = X.reshape(cls_label,
                          shape=(n_image, -1),
                          name='bbox_label_reshape')
    cls_label = X.block_grad(cls_label, name='bbox_label_blockgrad')
    return cls_loss, reg_loss, cls_label
def cls_pc_loss(self, logits, tsd_logits, gt_label, scale_loss_shift):
    """TSD classification progressive constraint.

    Pushes the TSD branch to score the ground-truth class at least
    ``pc_cls_margin`` higher than the (detached) sibling branch.

    Args:
        logits: [batch_image * image_roi, num_class]
        tsd_logits: [batch_image * image_roi, num_class]
        gt_label: [batch_image * image_roi]
        scale_loss_shift: float loss-scale multiplier (fp16 support)

    Returns:
        loss: [batch_image * image_roi]
    """
    cfg = self.p
    n_roi = cfg.batch_image * cfg.image_roi
    margin = cfg.TSD.pc_cls_margin

    base_prob = mx.sym.SoftmaxActivation(logits, mode='instance')
    tsd_prob = mx.sym.SoftmaxActivation(tsd_logits, mode='instance')
    base_score = mx.sym.pick(base_prob, gt_label, axis=1)
    tsd_score = mx.sym.pick(tsd_prob, gt_label, axis=1)
    # the sibling branch is a fixed reference: no gradient flows into it
    base_score = X.block_grad(base_score)

    # cap the required improvement so it never exceeds what is achievable
    gap = mx.sym.minimum(1. - base_score, margin)
    # hinge: penalize only when tsd_score < base_score + gap
    penalty = mx.sym.relu(-(tsd_score - base_score - gap))

    grad_scale = 1. / n_roi
    grad_scale *= scale_loss_shift
    return X.loss(penalty, grad_scale=grad_scale, name='cls_pc_loss')
def make_binary_cross_entropy_loss(logits, labels, nonignore_mask):
    """Masked sigmoid BCE with a hand-derived gradient.

    Both the forward loss value and the backward gradient are computed
    explicitly, detached, and routed through the 'compute_bce_loss' custom
    op, which controls how gradients flow back into ``logits``.

    Args:
        logits: raw (pre-sigmoid) scores.
        labels: binary targets, same shape as logits.
        nonignore_mask: 1 where the loss is counted, 0 where ignored.

    Returns:
        symbol produced by the 'compute_bce_loss' custom op.
    """
    prob = 1 / (1 + mx.sym.exp(-logits))  # sigmoid
    # clip probabilities away from 0 so log() stays finite
    pos_term = -labels * mx.sym.log(mx.sym.clip(prob, a_min=1e-5, a_max=1))
    neg_term = -(1 - labels) * mx.sym.log(
        mx.sym.clip(1 - prob, a_min=1e-5, a_max=1))
    elem_loss = pos_term + neg_term

    # normalize by the number of counted elements (+1e-30 guards 0/0)
    loss = mx.sym.sum(elem_loss * nonignore_mask) / (
        mx.sym.sum(nonignore_mask) + 1e-30)
    # analytic d(BCE)/d(logits) = p - y, masked and normalized the same way
    grad = mx.sym.broadcast_div(lhs=(prob - labels) * nonignore_mask,
                                rhs=mx.sym.sum(nonignore_mask) + 1e-30)

    loss = X.block_grad(loss)
    grad = X.block_grad(grad)
    return mx.sym.Custom(logits=logits,
                         loss=loss,
                         grad=grad,
                         op_type='compute_bce_loss',
                         name='sigmoid_bce_loss')
def IoULoss(x_box, y_box, ignore_offset, centerness_label, name='iouloss'):
    """Centerness-weighted IoU loss on FCOS-style (l, t, r, b) offsets.

    Args:
        x_box: predicted offsets, shape (batch, 4, num_points).
        y_box: target offsets, same layout; treated as constants.
        ignore_offset: sentinel value marking locations with no target.
        centerness_label: per-location weights; non-positive values disable
            the location.
        name: name for the returned loss symbol.

    Returns:
        scalar loss symbol: -log(IoU), centerness-weighted and averaged.
    """
    centerness_label = mx.sym.reshape(centerness_label, shape=(0, 1, -1))
    # targets are constants w.r.t. the optimizer
    y_box = X.block_grad(y_box)
    target_left = mx.sym.slice_axis(y_box, axis=1, begin=0, end=1)
    target_top = mx.sym.slice_axis(y_box, axis=1, begin=1, end=2)
    target_right = mx.sym.slice_axis(y_box, axis=1, begin=2, end=3)
    target_bottom = mx.sym.slice_axis(y_box, axis=1, begin=3, end=4)
    # filter out out-of-bbox area: loss is only computed where a target
    # exists AND the centerness weight is positive
    nonignore_mask = mx.sym.broadcast_logical_and(
        lhs=mx.sym.broadcast_not_equal(lhs=target_left, rhs=ignore_offset),
        rhs=mx.sym.broadcast_greater(lhs=centerness_label,
                                     rhs=mx.sym.full((1, 1, 1), 0)))
    nonignore_mask = X.block_grad(nonignore_mask)
    # clamp predictions to a sane range, then zero out ignored locations
    x_box = mx.sym.clip(x_box, a_min=0, a_max=1e4)
    x_box = mx.sym.broadcast_mul(lhs=x_box, rhs=nonignore_mask)
    centerness_label = centerness_label * nonignore_mask
    pred_left = mx.sym.slice_axis(x_box, axis=1, begin=0, end=1)
    pred_top = mx.sym.slice_axis(x_box, axis=1, begin=1, end=2)
    pred_right = mx.sym.slice_axis(x_box, axis=1, begin=2, end=3)
    pred_bottom = mx.sym.slice_axis(x_box, axis=1, begin=3, end=4)
    target_area = (target_left + target_right) * (target_top + target_bottom)
    pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)
    # IDIOM FIX: use the direct elementwise mx.sym.minimum instead of the
    # original stack-then-reduce (mx.sym.min over a temporary axis) — same
    # values, one fewer intermediate tensor per term.
    w_intersect = mx.sym.minimum(pred_left, target_left) + \
        mx.sym.minimum(pred_right, target_right)
    h_intersect = mx.sym.minimum(pred_bottom, target_bottom) + \
        mx.sym.minimum(pred_top, target_top)
    area_intersect = w_intersect * h_intersect
    area_union = (target_area + pred_area - area_intersect)
    # -log IoU; the +1.0 smoothing makes fully-zeroed (ignored) locations
    # contribute exactly 0
    loss = -mx.sym.log((area_intersect + 1.0) / (area_union + 1.0))
    loss = mx.sym.broadcast_mul(lhs=loss, rhs=centerness_label)
    loss = mx.sym.sum(loss) / (mx.sym.sum(centerness_label) + 1e-30)
    return X.loss(loss, grad_scale=1, name=name)
def make_sigmoid_focal_loss(gamma, alpha, logits, labels, nonignore_mask):
    """Masked sigmoid focal loss with an analytically derived gradient.

    Most of the arithmetic is built with plain symbols; the forward loss
    value and the backward gradient are both computed explicitly, detached,
    and routed through the 'compute_focal_loss' custom op, which controls
    gradient flow back into ``logits``.

    Args:
        gamma: focal-loss focusing exponent.
        alpha: positive/negative balancing factor.
        logits: raw (pre-sigmoid) scores.
        labels: binary targets, same shape as logits.
        nonignore_mask: 1 where the loss is counted, 0 where ignored.

    Returns:
        symbol produced by the 'compute_focal_loss' custom op.
    """
    # conduct most of calculations using symbol and control gradient flow with custom op
    p = 1 / (1 + mx.sym.exp(-logits))  # sigmoid
    # indicator mask [logits >= 0], used for a numerically stable log-sigmoid
    mask_logits_GE_zero = mx.sym.broadcast_greater_equal(lhs=logits,
                                                         rhs=mx.sym.zeros(
                                                             (1, 1)))  # logits>=0
    minus_logits_mask = -1. * logits * mask_logits_GE_zero  # -1 * logits * [logits>=0]
    # -|logits|: logits - 2 * logits * [logits>=0]
    negative_abs_logits = logits - 2 * logits * mask_logits_GE_zero
    log_one_exp_minus_abs = mx.sym.log(1. + mx.sym.exp(negative_abs_logits))
    # stable log(1 - p): for logits>=0 this is -logits - log(1+exp(-logits)),
    # for logits<0 it is -log(1+exp(logits)) — both equal log(1 - sigmoid)
    minus_log = minus_logits_mask - log_one_exp_minus_abs
    alpha_one_p_gamma_labels = alpha * (1 - p)**gamma * labels
    # clip p away from 0 so log() stays finite
    log_p_clip = mx.sym.log(mx.sym.clip(p, a_min=1e-5, a_max=1))
    one_alpha_p_gamma_one_labels = (1 - alpha) * p**gamma * (1 - labels)
    # normalize by the number of positive, non-ignored targets (+1 guards /0)
    norm = mx.sym.sum(labels * nonignore_mask) + 1
    forward_term1 = alpha_one_p_gamma_labels * log_p_clip
    forward_term2 = one_alpha_p_gamma_one_labels * minus_log
    loss = mx.sym.sum(-1 * (forward_term1 + forward_term2) * nonignore_mask) / norm
    # hand-derived d(loss)/d(logits) terms for the custom op's backward pass
    backward_term1 = alpha_one_p_gamma_labels * (1 - p - p * gamma * log_p_clip)
    backward_term2 = one_alpha_p_gamma_one_labels * (minus_log * (1 - p) * gamma - p)
    grad = mx.sym.broadcast_div(lhs=-1 * (backward_term1 + backward_term2) * nonignore_mask,
                                rhs=norm.reshape((1, 1)))
    # detach both so the only gradient path is the one the custom op defines
    loss = X.block_grad(loss)
    grad = X.block_grad(grad)
    loss = mx.sym.Custom(logits=logits,
                         loss=loss,
                         grad=grad,
                         op_type='compute_focal_loss',
                         name='focal_loss')
    return loss
def _refine_pts(self, cls_feat, reg_feat, dcn_offset, pts_init_out):
    """Refine stage of the RepPoints head.

    Runs a deformable conv (sampling driven by ``dcn_offset``) over both
    the classification and regression towers, then predicts class scores
    and a residual point offset added onto the detached init-stage points.

    Returns:
        (pts_refine_out, cls_out)
    """
    cfg = self.p
    n_channel = cfg.head.point_conv_channel
    # one output channel fewer than num_class — presumably excludes
    # background; TODO confirm against the label convention
    n_cls_out = cfg.num_class - 1
    n_pts_out = cfg.point_generate.num_points * 2

    def _deform(data, weight, bias, op_name):
        # shared deformable-conv settings for both towers
        return mx.symbol.contrib.DeformableConvolution(
            data=data,
            offset=dcn_offset,
            kernel=(self.dcn_kernel, self.dcn_kernel),
            pad=(self.dcn_pad, self.dcn_pad),
            stride=(1, 1),
            dilate=(1, 1),
            num_filter=n_channel,
            weight=weight,
            bias=bias,
            no_bias=False,
            name=op_name)

    cls_tower = X.relu(_deform(cls_feat,
                               self.cls_conv_weight,
                               self.cls_conv_bias,
                               "cls_conv"))
    cls_out = X.conv(data=cls_tower,
                     kernel=1,
                     filter=n_cls_out,
                     weight=self.cls_out_weight,
                     bias=self.cls_out_bias,
                     no_bias=False,
                     name="cls_out")

    reg_tower = X.relu(_deform(reg_feat,
                               self.pts_refine_conv_weight,
                               self.pts_refine_conv_bias,
                               "pts_refine_conv"))
    pts_refine_out = X.conv(data=reg_tower,
                            kernel=1,
                            filter=n_pts_out,
                            weight=self.pts_refine_out_weight,
                            bias=self.pts_refine_out_bias,
                            no_bias=False,
                            name="pts_refine_out")
    # refine predicts a residual on top of the (detached) init points
    pts_refine_out = pts_refine_out + X.block_grad(pts_init_out)
    return pts_refine_out, cls_out
def get_loss(self, conv_fpn_feat, gt_bbox, im_info):
    """Build the three FCOS losses: centerness (BCE), classification
    (focal), and offset regression (IoU).

    Args:
        conv_fpn_feat: FPN features consumed by ``get_output``.
        gt_bbox: ground-truth boxes (NOTE: immediately shadowed, see below).
        im_info: image meta info (NOTE: immediately shadowed, see below).

    Returns:
        (centerness_loss, cls_loss, offset_loss)
    """
    p = self.p
    bs = p.batch_image  # batch_size on a single gpu
    centerness_logit_dict, cls_logit_dict, offset_logit_dict = self.get_output(conv_fpn_feat)
    # NOTE(review): these three lists are never appended to or read — dead
    # locals, presumably left over from a per-level loss formulation.
    centerness_loss_list = []
    cls_loss_list = []
    offset_loss_list = []
    # prepare gt: constants for the "ignore" sentinels, gradient-blocked
    ignore_label = X.block_grad(X.var('ignore_label',
                                      init=X.constant(p.loss_setting.ignore_label),
                                      shape=(1, 1)))
    ignore_offset = X.block_grad(X.var('ignore_offset',
                                       init=X.constant(p.loss_setting.ignore_offset),
                                       shape=(1, 1, 1)))
    # NOTE(review): the gt_bbox/im_info parameters are shadowed here by
    # freshly declared variables of the same name — presumably the symbols
    # are bound by name at bind time; confirm callers rely on this.
    gt_bbox = X.var('gt_bbox')
    im_info = X.var('im_info')
    centerness_labels, cls_labels, offset_labels = make_fcos_gt(
        gt_bbox, im_info, p.loss_setting.ignore_offset,
        p.loss_setting.ignore_label, p.FCOSParam.num_classifier)
    # targets are constants w.r.t. the optimizer
    centerness_labels = X.block_grad(centerness_labels)
    cls_labels = X.block_grad(cls_labels)
    offset_labels = X.block_grad(offset_labels)
    # gather output logits
    cls_logit_dict_list = []
    centerness_logit_dict_list = []
    offset_logit_dict_list = []
    for idx, stride in enumerate(p.FCOSParam.stride):
        # (c,H1,W1), (c,H2,W2), ..., (c,H5,W5) -> (H1W1+H2W2+...+H5W5), ...c..., (H1W1+H2W2+...+H5W5)
        cls_logit_dict_list.append(
            mx.sym.reshape(cls_logit_dict[stride], shape=(0, 0, -1)))
        centerness_logit_dict_list.append(
            mx.sym.reshape(centerness_logit_dict[stride], shape=(0, 0, -1)))
        offset_logit_dict_list.append(
            mx.sym.reshape(offset_logit_dict[stride], shape=(0, 0, -1)))
    # concat all levels along the spatial axis, then flatten per image
    cls_logits = mx.sym.reshape(mx.sym.concat(*cls_logit_dict_list, dim=2),
                                shape=(0, -1))
    centerness_logits = mx.sym.reshape(
        mx.sym.concat(*centerness_logit_dict_list, dim=2), shape=(0, -1))
    offset_logits = mx.sym.reshape(
        mx.sym.concat(*offset_logit_dict_list, dim=2), shape=(0, 4, -1))
    # make losses
    # classification: count every location whose label is not the sentinel
    nonignore_mask = mx.sym.broadcast_not_equal(lhs=cls_labels, rhs=ignore_label)
    nonignore_mask = X.block_grad(nonignore_mask)
    cls_loss = make_sigmoid_focal_loss(gamma=p.loss_setting.focal_loss_gamma,
                                       alpha=p.loss_setting.focal_loss_alpha,
                                       logits=cls_logits,
                                       labels=cls_labels,
                                       nonignore_mask=nonignore_mask)
    cls_loss = X.loss(cls_loss, grad_scale=1)
    # centerness: count locations that are not ignored AND have a positive
    # centerness target
    nonignore_mask = mx.sym.broadcast_logical_and(
        lhs=mx.sym.broadcast_not_equal(
            lhs=X.block_grad(centerness_labels),
            rhs=ignore_label),
        rhs=mx.sym.broadcast_greater(
            lhs=centerness_labels,
            rhs=mx.sym.full((1, 1), 0)))
    nonignore_mask = X.block_grad(nonignore_mask)
    centerness_loss = make_binary_cross_entropy_loss(centerness_logits,
                                                     centerness_labels,
                                                     nonignore_mask)
    centerness_loss = X.loss(centerness_loss, grad_scale=1)
    # offsets: IoU loss weighted by the centerness target
    offset_loss = IoULoss(offset_logits, offset_labels, ignore_offset,
                          centerness_labels, name='offset_loss')
    return centerness_loss, cls_loss, offset_loss
def get_output(self, conv_feat):
    """Run init/refine point prediction over all strides (memoized).

    Results are cached on the instance, so repeated calls build the
    symbol graph only once.

    Returns:
        (pts_out_inits, pts_out_refines, cls_outs) — dicts keyed "stride%s".
    """
    cached = (self._pts_out_inits, self._pts_out_refines, self._cls_outs)
    if all(c is not None for c in cached):
        return cached

    from models.RepPoints.point_ops import _gen_offsets

    strides = self.p.point_generate.stride
    # base sampling grid of the deformable conv; predicted offsets are
    # expressed relative to it
    dcn_base_offset = _gen_offsets(mx.symbol,
                                   dcn_kernel=self.dcn_kernel,
                                   dcn_pad=self.dcn_pad)

    pts_out_inits = dict()
    pts_out_refines = dict()
    cls_outs = dict()
    for s in strides:
        key = "stride%s" % s
        # cls / reg subnets share their parameters across strides
        cls_feat = self._cls_subnet(conv_feat=conv_feat[key], stride=s)
        reg_feat = self._reg_subnet(conv_feat=conv_feat[key], stride=s)
        # init stage: predict offsets at each center point
        pts_out_init = self._init_pts(reg_feat)
        # scale the gradient flowing into the offsets subnet by 0.1
        pts_out_init_grad_mul = 0.9 * X.block_grad(pts_out_init) \
            + 0.1 * pts_out_init
        # the deformable conv consumes offsets relative to its base grid,
        # so subtract the base dcn offsets before feeding them in
        pts_out_init_offset = mx.symbol.broadcast_sub(pts_out_init_grad_mul,
                                                      dcn_base_offset)
        # refine stage: use the feature-space offsets to refine box and cls
        pts_out_refine, cls_out = self._refine_pts(cls_feat, reg_feat,
                                                   pts_out_init_offset,
                                                   pts_out_init)
        pts_out_inits[key] = pts_out_init
        pts_out_refines[key] = pts_out_refine
        cls_outs[key] = cls_out

    self._pts_out_inits = pts_out_inits
    self._pts_out_refines = pts_out_refines
    self._cls_outs = cls_outs
    return self._pts_out_inits, self._pts_out_refines, self._cls_outs
def _box_decode(rois, deltas, means, stds):
    """Decode normalized regression deltas against their RoIs.

    Deltas are de-normalized with (means, stds), then applied in the
    standard (dx, dy, dw, dh) parameterization using the legacy +1
    width/height convention.

    Returns:
        decoded boxes stacked as [N, 4] in (x1, y1, x2, y2) order.
    """
    rois = X.block_grad(rois)  # RoIs are constants w.r.t. the optimizer
    rois = mx.sym.reshape(rois, [-1, 4])
    deltas = mx.sym.reshape(deltas, [-1, 4])
    x1, y1, x2, y2 = mx.sym.split(rois,
                                  axis=-1,
                                  num_outputs=4,
                                  squeeze_axis=True)
    dx, dy, dw, dh = mx.sym.split(deltas,
                                  axis=-1,
                                  num_outputs=4,
                                  squeeze_axis=True)
    # undo target normalization
    dx = dx * stds[0] + means[0]
    dy = dy * stds[1] + means[1]
    dw = dw * stds[2] + means[2]
    dh = dh * stds[3] + means[3]
    # center and size of the source RoI
    cx = (x1 + x2) * 0.5
    cy = (y1 + y2) * 0.5
    w = x2 - x1 + 1
    h = y2 - y1 + 1
    # shift the center and rescale the size
    new_cx = cx + dx * w
    new_cy = cy + dy * h
    new_w = w * mx.sym.exp(dw)
    new_h = h * mx.sym.exp(dh)
    return mx.sym.stack(new_cx - 0.5 * new_w,
                        new_cy - 0.5 * new_h,
                        new_cx + 0.5 * new_w,
                        new_cy + 0.5 * new_h,
                        axis=1,
                        name='pc_reg_loss_decoded_roi')
def reg_pc_loss(self, bbox_delta, tsd_bbox_delta, rois, tsd_rois, gt_bbox,
                gt_label, scale_loss_shift):
    '''
    TSD regression progressive constraint: pushes the IoU of the
    TSD-decoded box (against its matched gt) to exceed the IoU of the
    sibling-head box by at least a margin, for foreground RoIs.
    Args:
        bbox_delta: [batch_image * image_roi, num_class*4]
        tsd_bbox_delta: [batch_image * image_roi, num_class*4]
        rois: [batch_image, image_roi, 4]
        tsd_rois: [batch_image, image_roi, 4]
        gt_bbox: [batch_image, max_gt_num, 4]
        gt_label: [batch_image * image_roi]
        scale_loss_shift: float
    Returns:
        loss: [batch_image * image_roi]
    '''
    def _box_decode(rois, deltas, means, stds):
        # Decode normalized (dx, dy, dw, dh) deltas against RoIs,
        # returning [N, 4] (x1, y1, x2, y2) boxes. RoIs are detached.
        rois = X.block_grad(rois)
        rois = mx.sym.reshape(rois, [-1, 4])
        deltas = mx.sym.reshape(deltas, [-1, 4])
        x1, y1, x2, y2 = mx.sym.split(rois,
                                      axis=-1,
                                      num_outputs=4,
                                      squeeze_axis=True)
        dx, dy, dw, dh = mx.sym.split(deltas,
                                      axis=-1,
                                      num_outputs=4,
                                      squeeze_axis=True)
        # undo target normalization
        dx = dx * stds[0] + means[0]
        dy = dy * stds[1] + means[1]
        dw = dw * stds[2] + means[2]
        dh = dh * stds[3] + means[3]
        # center/size of the RoI (legacy +1 convention)
        x = (x1 + x2) * 0.5
        y = (y1 + y2) * 0.5
        w = x2 - x1 + 1
        h = y2 - y1 + 1
        # shift center, rescale size, convert back to corners
        nx = x + dx * w
        ny = y + dy * h
        nw = w * mx.sym.exp(dw)
        nh = h * mx.sym.exp(dh)
        nx1 = nx - 0.5 * nw
        ny1 = ny - 0.5 * nh
        nx2 = nx + 0.5 * nw
        ny2 = ny + 0.5 * nh
        return mx.sym.stack(nx1, ny1, nx2, ny2, axis=1,
                            name='pc_reg_loss_decoded_roi')

    def _gather_3d(data, indices, n):
        # Select, for each row of [N, C, n] data, the [n]-vector at the
        # class index given by `indices` -> [N, n].
        datas = mx.sym.split(data, axis=-1, num_outputs=n, squeeze_axis=True)
        outputs = []
        for d in datas:
            outputs.append(mx.sym.pick(d, indices, axis=1))
        return mx.sym.stack(*outputs, axis=1)

    batch_image = self.p.batch_image
    image_roi = self.p.image_roi
    batch_roi = batch_image * image_roi
    num_class = self.p.num_class
    bbox_mean = self.p.regress_target.mean
    bbox_std = self.p.regress_target.std
    margin = self.p.TSD.pc_reg_margin
    gt_label = mx.sym.reshape(gt_label, (-1, ))
    # keep only the deltas of each RoI's ground-truth class
    bbox_delta = mx.sym.reshape(bbox_delta,
                                (batch_image * image_roi, num_class, 4))
    tsd_bbox_delta = mx.sym.reshape(
        tsd_bbox_delta, (batch_image * image_roi, num_class, 4))
    bbox_delta = _gather_3d(bbox_delta, gt_label, n=4)
    tsd_bbox_delta = _gather_3d(tsd_bbox_delta, gt_label, n=4)
    # decode both branches' boxes
    boxes = _box_decode(rois, bbox_delta, bbox_mean, bbox_std)
    tsd_bboxes = _box_decode(tsd_rois, tsd_bbox_delta, bbox_mean, bbox_std)
    rois = mx.sym.reshape(rois, [batch_image, -1, 4])
    tsd_rois = mx.sym.reshape(tsd_rois, [batch_image, -1, 4])
    boxes = mx.sym.reshape(boxes, [batch_image, -1, 4])
    tsd_bboxes = mx.sym.reshape(tsd_bboxes, [batch_image, -1, 4])
    # process images one at a time for the IoU matching below
    rois_group = mx.sym.split(rois,
                              axis=0,
                              num_outputs=batch_image,
                              squeeze_axis=True)
    tsd_rois_group = mx.sym.split(tsd_rois,
                                  axis=0,
                                  num_outputs=batch_image,
                                  squeeze_axis=True)
    boxes_group = mx.sym.split(boxes,
                               axis=0,
                               num_outputs=batch_image,
                               squeeze_axis=True)
    tsd_bboxes_group = mx.sym.split(tsd_bboxes,
                                    axis=0,
                                    num_outputs=batch_image,
                                    squeeze_axis=True)
    gt_group = mx.sym.split(gt_bbox,
                            axis=0,
                            num_outputs=batch_image,
                            squeeze_axis=True)
    ious = []
    tsd_ious = []
    for i, (rois_i, tsd_rois_i, boxes_i, tsd_boxes_i, gt_i) in \
            enumerate(zip(rois_group, tsd_rois_group, boxes_group,
                          tsd_bboxes_group, gt_group)):
        # match each RoI (per branch) to its best-overlapping gt box
        iou_mat = get_iou_mat(rois_i, gt_i, image_roi)
        tsd_iou_mat = get_iou_mat(tsd_rois_i, gt_i, image_roi)
        matched_gt = mx.sym.gather_nd(
            gt_i, X.reshape(mx.sym.argmax(iou_mat, axis=1), [1, -1]))
        tsd_matched_gt = mx.sym.gather_nd(
            gt_i, X.reshape(mx.sym.argmax(tsd_iou_mat, axis=1), [1, -1]))
        # keep only the 4 coordinate columns of the matched gt rows
        matched_gt = mx.sym.slice_axis(matched_gt, axis=-1, begin=0, end=4)
        tsd_matched_gt = mx.sym.slice_axis(tsd_matched_gt,
                                           axis=-1,
                                           begin=0,
                                           end=4)
        # IoU of each decoded box against its matched gt
        ious.append(get_iou(boxes_i, matched_gt))
        tsd_ious.append(get_iou(tsd_boxes_i, tsd_matched_gt))
    iou = mx.sym.concat(*ious, dim=0)
    tsd_iou = mx.sym.concat(*tsd_ious, dim=0)
    # only foreground RoIs (label != 0) contribute; weight is detached
    weight = X.block_grad(gt_label != 0)
    # the sibling branch's IoU is a fixed reference
    iou = X.block_grad(iou)
    # cap the required improvement so it never exceeds what is achievable
    reg_pc_margin = mx.sym.minimum(1. - iou, margin)
    # hinge: penalize only when tsd_iou < iou + reg_pc_margin
    loss = mx.sym.relu(-(tsd_iou - iou - reg_pc_margin))
    grad_scale = 1. / batch_roi
    grad_scale *= scale_loss_shift
    loss = X.loss(weight * loss, grad_scale=grad_scale, name='reg_pc_loss')
    return loss
def get_reg_target(self, rois, gt_bbox):
    '''
    Build regression targets for the TSD branch by matching each RoI to
    its highest-IoU ground-truth box and encoding the delta.
    Args:
        rois: [batch_image, image_roi, 4]
        gt_bbox: [batch_image, max_gt_num, 4]
            NOTE(review): the docstring says 4, but get_transform below
            splits gt rows into 5 fields — presumably
            [x1, y1, x2, y2, label]; confirm against the caller.
    Returns:
        reg_target: [batch_image * image_roi, num_class * 4]
    '''
    def get_transform(rois, gt_boxes):
        # Encode (dx, dy, dw, dh) from each RoI to its matched gt box,
        # normalized by the configured mean/std. Uses the legacy +1
        # width/height convention.
        bbox_mean = self.p.regress_target.mean
        bbox_std = self.p.regress_target.std
        xmin1, ymin1, xmax1, ymax1 = mx.sym.split(rois,
                                                  axis=-1,
                                                  num_outputs=4,
                                                  squeeze_axis=True)
        # gt rows carry a 5th field (dropped here)
        xmin2, ymin2, xmax2, ymax2, _ = mx.sym.split(gt_boxes,
                                                     axis=-1,
                                                     num_outputs=5,
                                                     squeeze_axis=True)
        w1 = xmax1 - xmin1 + 1.0
        h1 = ymax1 - ymin1 + 1.0
        x1 = xmin1 + 0.5 * (w1 - 1.0)
        y1 = ymin1 + 0.5 * (h1 - 1.0)
        w2 = xmax2 - xmin2 + 1.0
        h2 = ymax2 - ymin2 + 1.0
        x2 = xmin2 + 0.5 * (w2 - 1.0)
        y2 = ymin2 + 0.5 * (h2 - 1.0)
        # center shift normalized by RoI size (+eps guards /0)
        dx = (x2 - x1) / (w1 + 1e-14)
        dy = (y2 - y1) / (h1 + 1e-14)
        dw = mx.sym.log(w2 / w1)
        dh = mx.sym.log(h2 / h1)
        # standardize with the configured target statistics
        dx = (dx - bbox_mean[0]) / bbox_std[0]
        dy = (dy - bbox_mean[1]) / bbox_std[1]
        dw = (dw - bbox_mean[2]) / bbox_std[2]
        dh = (dh - bbox_mean[3]) / bbox_std[3]
        return mx.sym.stack(dx, dy, dw, dh, axis=1,
                            name='delta_r_roi_transform')

    batch_image = self.p.batch_image
    image_roi = self.p.image_roi  # RoIs per image
    num_class = self.p.num_class  # classes incl. whatever convention p uses
    reg_target = []
    # process images one at a time for the IoU matching
    rois_group = mx.sym.split(rois,
                              axis=0,
                              num_outputs=batch_image,
                              squeeze_axis=True)
    gt_group = mx.sym.split(gt_bbox,
                            axis=0,
                            num_outputs=batch_image,
                            squeeze_axis=True)
    for i, (rois_i, gt_box_i) in enumerate(zip(rois_group, gt_group)):
        iou_mat = get_iou_mat(rois_i, gt_box_i, image_roi)  # [image_roi, max_gt_num]
        # index of the best-overlapping gt for each RoI
        idxs = mx.sym.argmax(iou_mat, axis=1)  # [image_roi]
        match_gt_boxes = mx.sym.gather_nd(gt_box_i,
                                          X.reshape(idxs, [1, -1]))
        delta_i = get_transform(rois_i, match_gt_boxes)  # [image_roi, 4]
        # tile the same delta across every class slot:
        # [image_roi, 4] -> [image_roi, num_class * 4]
        delta_i = mx.sym.reshape(
            mx.sym.repeat(delta_i, repeats=num_class, axis=0),
            (image_roi, -1))
        reg_target.append(delta_i)
    # stack per-image targets, flatten, and detach (targets are constants)
    reg_target = X.block_grad(
        mx.sym.reshape(mx.sym.stack(*reg_target, axis=0),
                       [batch_image * image_roi, -1],
                       name='TSD_reg_target'))  # [batch_roi, num_class*4]
    return reg_target
def get_loss(self, rois, roi_feat, fpn_conv_feats, cls_label, bbox_target,
             bbox_weight, gt_bbox):
    '''
    Build the TSD head losses: the sibling-head losses, the TSD-branch
    losses, and (when enabled) the progressive-constraint (PC) losses.
    Args:
        rois: [batch_image, image_roi, 4]
        roi_feat: [batch_image * image_roi, 256, roi_size, roi_size]
        fpn_conv_feats: dict of FPN features, each [batch_image, in_channels, fh, fw]
        cls_label: [batch_image * image_roi]
        bbox_target: [batch_image * image_roi, num_class * 4]
        bbox_weight: [batch_image * image_roi, num_class * 4]
        gt_bbox: [batch_image, max_gt_num, 4]
    Returns:
        tuple of (cls_loss, reg_loss, tsd_cls_loss, tsd_reg_loss,
        [tsd_cls_pc_loss if p.TSD.pc_cls], [tsd_reg_pc_loss if p.TSD.pc_reg],
        cls_label) — the label is reshaped to [batch_image, image_roi] and
        gradient-blocked.
    '''
    p = self.p
    assert not p.regress_target.class_agnostic
    batch_image = p.batch_image
    image_roi = p.image_roi
    batch_roi = batch_image * image_roi
    smooth_l1_scalar = p.regress_target.smooth_l1_scalar or 1.0

    cls_logit, bbox_delta, tsd_cls_logit, tsd_bbox_delta, delta_c, delta_r = self.get_output(
        fpn_conv_feats, roi_feat, rois, is_train=True)

    # TSD regression targets are recomputed against the delta-r translated
    # proposals
    rois_r = self._get_delta_r_box(delta_r, rois)
    tsd_reg_target = self.get_reg_target(
        rois_r, gt_bbox)  # [batch_roi, num_class*4]

    # fp16 training shifts the loss scale to keep gradients representable
    scale_loss_shift = 128 if self.p.fp16 else 1.0

    # origin (sibling head) loss
    cls_loss = X.softmax_output(data=cls_logit,
                                label=cls_label,
                                normalization='batch',
                                grad_scale=1.0 * scale_loss_shift,
                                name='bbox_cls_loss')
    reg_loss = X.smooth_l1(bbox_delta - bbox_target,
                           scalar=smooth_l1_scalar,
                           name='bbox_reg_l1')
    reg_loss = bbox_weight * reg_loss
    reg_loss = X.loss(
        reg_loss,
        grad_scale=1.0 / batch_roi * scale_loss_shift,
        name='bbox_reg_loss',
    )

    # tsd loss
    tsd_cls_loss = X.softmax_output(data=tsd_cls_logit,
                                    label=cls_label,
                                    normalization='batch',
                                    grad_scale=1.0 * scale_loss_shift,
                                    name='tsd_bbox_cls_loss')
    tsd_reg_loss = X.smooth_l1(tsd_bbox_delta - tsd_reg_target,
                               scalar=smooth_l1_scalar,
                               name='tsd_bbox_reg_l1')
    tsd_reg_loss = bbox_weight * tsd_reg_loss
    tsd_reg_loss = X.loss(
        tsd_reg_loss,
        grad_scale=1.0 / batch_roi * scale_loss_shift,
        name='tsd_bbox_reg_loss',
    )

    # BUG FIX: the original list included the undefined name
    # `tsd_cls_pc_loss` (a NameError at graph-build time); the PC losses
    # are only appended below when their flags are enabled.
    losses = [cls_loss, reg_loss, tsd_cls_loss, tsd_reg_loss]
    if p.TSD.pc_cls:
        losses.append(
            self.cls_pc_loss(cls_logit, tsd_cls_logit, cls_label,
                             scale_loss_shift))
    if p.TSD.pc_reg:
        losses.append(
            self.reg_pc_loss(bbox_delta, tsd_bbox_delta, rois, rois_r,
                             gt_bbox, cls_label, scale_loss_shift))

    # append label (reshaped per image, gradient blocked)
    cls_label = X.reshape(cls_label,
                          shape=(batch_image, -1),
                          name='bbox_label_reshape')
    cls_label = X.block_grad(cls_label, name='bbox_label_blockgrad')
    losses.append(cls_label)
    return tuple(losses)
def get_loss(self, conv_feat, gt_bbox):
    """Build the RepPoints losses: focal classification loss plus
    smooth-L1 box losses for the init and refine stages.

    Args:
        conv_feat: per-stride feature dict consumed by ``get_output``.
        gt_bbox: ground-truth boxes passed to ``_point_target``.

    Returns:
        (cls_loss, pts_init_loss, pts_refine_loss,
         points_init_labels, points_refine_labels)
    """
    from models.RepPoints.point_ops import (_gen_points, _offset_to_pts,
                                            _point_target, _offset_to_boxes,
                                            _points2bbox)
    p = self.p
    batch_image = p.batch_image
    num_points = p.point_generate.num_points
    scale = p.point_generate.scale
    stride = p.point_generate.stride
    transform = p.point_generate.transform
    target_scale = p.point_target.target_scale
    num_pos = p.point_target.num_pos
    pos_iou_thr = p.bbox_target.pos_iou_thr
    neg_iou_thr = p.bbox_target.neg_iou_thr
    min_pos_iou = p.bbox_target.min_pos_iou
    pts_out_inits, pts_out_refines, cls_outs = self.get_output(conv_feat)
    points = dict()
    bboxes = dict()
    pts_coordinate_preds_inits = dict()
    pts_coordinate_preds_refines = dict()
    for s in stride:
        # generate points on base coordinate according to stride and size of feature map
        points["stride%s" % s] = _gen_points(mx.symbol,
                                             pts_out_inits["stride%s" % s],
                                             s)
        # generate bbox after init stage (init offsets detached)
        bboxes["stride%s" % s] = _offset_to_boxes(
            mx.symbol,
            points["stride%s" % s],
            X.block_grad(pts_out_inits["stride%s" % s]),
            s,
            transform,
            moment_transfer=self.moment_transfer)
        # generate final offsets in init stage
        pts_coordinate_preds_inits["stride%s" % s] = _offset_to_pts(
            mx.symbol, points["stride%s" % s],
            pts_out_inits["stride%s" % s], s, num_points)
        # generate final offsets in refine stage
        pts_coordinate_preds_refines["stride%s" % s] = _offset_to_pts(
            mx.symbol, points["stride%s" % s],
            pts_out_refines["stride%s" % s], s, num_points)
    # for init stage, use points assignment
    point_proposals = mx.symbol.tile(X.concat(
        [points["stride%s" % s] for s in stride],
        axis=1,
        name="point_concat"),
                                     reps=(batch_image, 1, 1))
    points_labels_init, points_gts_init, points_weight_init = _point_target(
        mx.symbol,
        point_proposals,
        gt_bbox,
        batch_image,
        "point",
        scale=target_scale,
        num_pos=num_pos)
    # for refine stage, use max iou assignment
    box_proposals = X.concat([bboxes["stride%s" % s] for s in stride],
                             axis=1,
                             name="box_concat")
    points_labels_refine, points_gts_refine, points_weight_refine = _point_target(
        mx.symbol,
        box_proposals,
        gt_bbox,
        batch_image,
        "box",
        pos_iou_thr=pos_iou_thr,
        neg_iou_thr=neg_iou_thr,
        min_pos_iou=min_pos_iou)
    bboxes_out_strides = dict()
    for s in stride:
        # (N,C,H,W) -> (N, H*W, C) so all strides can be concatenated
        cls_outs["stride%s" % s] = X.reshape(
            X.transpose(data=cls_outs["stride%s" % s], axes=(0, 2, 3, 1)),
            (0, -3, -2))
        # per-location (…,4) tensor filled with the stride value, used below
        # as the box-loss normalization term
        bboxes_out_strides["stride%s" % s] = mx.symbol.repeat(
            mx.symbol.ones_like(
                mx.symbol.slice_axis(
                    cls_outs["stride%s" % s], begin=0, end=1, axis=-1)),
            repeats=4,
            axis=-1) * s
    # cls branch
    cls_outs_concat = X.concat([cls_outs["stride%s" % s] for s in stride],
                               axis=1,
                               name="cls_concat")
    cls_loss = X.focal_loss(data=cls_outs_concat,
                            label=points_labels_refine,
                            normalization='valid',
                            alpha=p.focal_loss.alpha,
                            gamma=p.focal_loss.gamma,
                            grad_scale=1.0,
                            workspace=1500,
                            name="cls_loss")
    # init box branch
    pts_inits_concat_ = X.concat(
        [pts_coordinate_preds_inits["stride%s" % s] for s in stride],
        axis=1,
        name="pts_init_concat_")
    pts_inits_concat = X.reshape(pts_inits_concat_, (-3, -2),
                                 name="pts_inits_concat")
    bboxes_inits_concat_ = _points2bbox(
        mx.symbol,
        pts_inits_concat,
        transform,
        y_first=False,
        moment_transfer=self.moment_transfer)
    bboxes_inits_concat = X.reshape(bboxes_inits_concat_,
                                    (-4, batch_image, -1, -2))
    # residuals are divided by (stride * scale) so all levels are comparable
    normalize_term = X.concat(
        [bboxes_out_strides["stride%s" % s] for s in stride],
        axis=1,
        name="normalize_term") * scale
    pts_init_loss = X.smooth_l1(
        data=(bboxes_inits_concat - points_gts_init) / normalize_term,
        scalar=3.0,
        name="pts_init_l1_loss")
    pts_init_loss = pts_init_loss * points_weight_init
    pts_init_loss = X.bbox_norm(data=pts_init_loss,
                                label=points_labels_init,
                                name="pts_init_norm_loss")
    pts_init_loss = X.make_loss(data=pts_init_loss,
                                grad_scale=0.5,
                                name="pts_init_loss")
    # NOTE(review): this wraps points_labels_refine, not points_labels_init,
    # despite the name — confirm this is intentional (cls also trains
    # against the refine labels above).
    points_init_labels = X.block_grad(points_labels_refine,
                                      name="points_init_labels")
    # refine box branch
    pts_refines_concat_ = X.concat(
        [pts_coordinate_preds_refines["stride%s" % s] for s in stride],
        axis=1,
        name="pts_refines_concat_")
    pts_refines_concat = X.reshape(pts_refines_concat_, (-3, -2),
                                   name="pts_refines_concat")
    bboxes_refines_concat_ = _points2bbox(
        mx.symbol,
        pts_refines_concat,
        transform,
        y_first=False,
        moment_transfer=self.moment_transfer)
    bboxes_refines_concat = X.reshape(bboxes_refines_concat_,
                                      (-4, batch_image, -1, -2))
    pts_refine_loss = X.smooth_l1(
        data=(bboxes_refines_concat - points_gts_refine) / normalize_term,
        scalar=3.0,
        name="pts_refine_l1_loss")
    pts_refine_loss = pts_refine_loss * points_weight_refine
    pts_refine_loss = X.bbox_norm(data=pts_refine_loss,
                                  label=points_labels_refine,
                                  name="pts_refine_norm_loss")
    pts_refine_loss = X.make_loss(data=pts_refine_loss,
                                  grad_scale=1.0,
                                  name="pts_refine_loss")
    points_refine_labels = X.block_grad(points_labels_refine,
                                        name="point_refine_labels")
    return cls_loss, pts_init_loss, pts_refine_loss, points_init_labels, points_refine_labels