Example #1
    def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
        p = self.p
        batch_roi = p.image_roi * p.batch_image
        batch_image = p.batch_image

        cls_logit, bbox_delta = self.get_output(conv_feat)

        scale_loss_shift = 128.0 if p.fp16 else 1.0

        # classification loss
        cls_loss = X.softmax_output(data=cls_logit,
                                    label=cls_label,
                                    normalization='batch',
                                    grad_scale=1.0 * scale_loss_shift,
                                    name='bbox_cls_loss')

        # bounding box regression
        reg_loss = X.smooth_l1(bbox_delta - bbox_target,
                               scalar=1.0,
                               name='bbox_reg_l1')
        reg_loss = bbox_weight * reg_loss
        reg_loss = X.loss(
            reg_loss,
            grad_scale=1.0 / batch_roi * scale_loss_shift,
            name='bbox_reg_loss',
        )

        # append label
        cls_label = X.reshape(cls_label,
                              shape=(batch_image, -1),
                              name='bbox_label_reshape')
        cls_label = X.block_grad(cls_label, name='bbox_label_blockgrad')

        # output
        return cls_loss, reg_loss, cls_label
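
A NumPy reference for the regression term above, assuming X.smooth_l1 with scalar sigma follows the usual Fast R-CNN form (quadratic for |x| < 1/sigma^2, linear outside) and that grad_scale = 1/batch_roi amounts to averaging the summed loss over RoIs:

import numpy as np

def smooth_l1(x, sigma=1.0):
    # quadratic near zero, linear outside the |x| < 1/sigma^2 region
    beta = 1.0 / sigma ** 2
    ax = np.abs(x)
    return np.where(ax < beta, 0.5 * sigma ** 2 * x ** 2, ax - 0.5 * beta)

bbox_delta = np.array([[0.2, -0.1, 0.05, 0.3]])   # hypothetical prediction
bbox_target = np.array([[0.0, 0.0, 0.0, 0.0]])    # hypothetical target
bbox_weight = np.array([[1.0, 1.0, 1.0, 1.0]])    # 1 for the matched class, else 0

reg_loss = bbox_weight * smooth_l1(bbox_delta - bbox_target)
batch_roi = 1
print(reg_loss.sum() / batch_roi)  # value the symbolic loss reduces to
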
Example #2
    def cls_pc_loss(self, logits, tsd_logits, gt_label, scale_loss_shift):
        '''
        TSD classification progressive constraint
        Args:
            logits: [batch_image * image_roi, num_class]
            tsd_logits: [batch_image * image_roi, num_class]
            gt_label:  [batch_image * image_roi]
            scale_loss_shift: float
        Returns:
            loss: [batch_image * image_roi]
        '''
        p = self.p
        batch_image = p.batch_image
        image_roi = p.image_roi
        batch_roi = batch_image * image_roi
        margin = self.p.TSD.pc_cls_margin

        cls_prob = mx.sym.SoftmaxActivation(logits, mode='instance')
        tsd_prob = mx.sym.SoftmaxActivation(tsd_logits, mode='instance')

        cls_score = mx.sym.pick(cls_prob, gt_label, axis=1)
        tsd_score = mx.sym.pick(tsd_prob, gt_label, axis=1)

        cls_score = X.block_grad(cls_score)
        cls_pc_margin = mx.sym.minimum(1. - cls_score, margin)
        loss = mx.sym.relu(-(tsd_score - cls_score - cls_pc_margin))

        grad_scale = 1. / batch_roi
        grad_scale *= scale_loss_shift
        loss = X.loss(loss, grad_scale=grad_scale, name='cls_pc_loss')
        return loss
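
In plain terms, the constraint asks the TSD head's probability on the ground-truth class to beat the (gradient-blocked) sibling head's probability by an adaptive margin min(1 - cls_score, margin). A minimal NumPy sketch with made-up scores:

import numpy as np

margin = 0.2                         # p.TSD.pc_cls_margin (hypothetical value)
cls_score = np.array([0.6, 0.5])     # sibling head prob. of the gt class (grad blocked)
tsd_score = np.array([0.85, 0.55])   # TSD head prob. of the gt class

pc_margin = np.minimum(1.0 - cls_score, margin)
loss = np.maximum(0.0, -(tsd_score - cls_score - pc_margin))
print(loss)  # [0.   0.15]: zero once tsd_score >= cls_score + pc_margin
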
Example #3
def make_binary_cross_entropy_loss(logits, labels, nonignore_mask):
    p = 1 / (1 + mx.sym.exp(-logits))
    loss = -labels * mx.sym.log(mx.sym.clip(p, a_min=1e-5, a_max=1)) - (
        1 - labels) * mx.sym.log(mx.sym.clip(1 - p, a_min=1e-5, a_max=1))
    loss = mx.sym.sum(
        loss * nonignore_mask) / (mx.sym.sum(nonignore_mask) + 1e-30)
    grad = mx.sym.broadcast_div(lhs=(p - labels) * nonignore_mask,
                                rhs=mx.sym.sum(nonignore_mask) + 1e-30)

    loss = X.block_grad(loss)
    grad = X.block_grad(grad)

    return mx.sym.Custom(logits=logits,
                         loss=loss,
                         grad=grad,
                         op_type='compute_bce_loss',
                         name='sigmoid_bce_loss')
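
The grad passed to the custom op is the analytic derivative of the per-element BCE with respect to the logits, d loss / d z = sigmoid(z) - y (before masking and normalization); the custom op 'compute_bce_loss' is assumed to be registered elsewhere to forward loss and emit grad in backward. A quick NumPy finite-difference check of that derivative:

import numpy as np

def bce(z, y):
    p = 1.0 / (1.0 + np.exp(-z))
    return -y * np.log(p) - (1.0 - y) * np.log(1.0 - p)

z, y, eps = 0.3, 1.0, 1e-6
numeric = (bce(z + eps, y) - bce(z - eps, y)) / (2 * eps)
analytic = 1.0 / (1.0 + np.exp(-z)) - y      # p - y, as used for grad above
print(numeric, analytic)                     # both ~ -0.4256
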
Example #4
def IoULoss(x_box, y_box, ignore_offset, centerness_label, name='iouloss'):
    centerness_label = mx.sym.reshape(centerness_label, shape=(0, 1, -1))
    y_box = X.block_grad(y_box)

    target_left = mx.sym.slice_axis(y_box, axis=1, begin=0, end=1)
    target_top = mx.sym.slice_axis(y_box, axis=1, begin=1, end=2)
    target_right = mx.sym.slice_axis(y_box, axis=1, begin=2, end=3)
    target_bottom = mx.sym.slice_axis(y_box, axis=1, begin=3, end=4)

    # mask out locations that fall outside every gt box; the loss is only computed inside boxes
    nonignore_mask = mx.sym.broadcast_logical_and(
        lhs=mx.sym.broadcast_not_equal(lhs=target_left, rhs=ignore_offset),
        rhs=mx.sym.broadcast_greater(lhs=centerness_label,
                                     rhs=mx.sym.full((1, 1, 1), 0)))
    nonignore_mask = X.block_grad(nonignore_mask)
    x_box = mx.sym.clip(x_box, a_min=0, a_max=1e4)
    x_box = mx.sym.broadcast_mul(lhs=x_box, rhs=nonignore_mask)
    centerness_label = centerness_label * nonignore_mask

    pred_left = mx.sym.slice_axis(x_box, axis=1, begin=0, end=1)
    pred_top = mx.sym.slice_axis(x_box, axis=1, begin=1, end=2)
    pred_right = mx.sym.slice_axis(x_box, axis=1, begin=2, end=3)
    pred_bottom = mx.sym.slice_axis(x_box, axis=1, begin=3, end=4)

    target_area = (target_left + target_right) * (target_top + target_bottom)
    pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)

    w_intersect = mx.sym.min(
        mx.sym.stack(pred_left, target_left, axis=0), axis=0) + mx.sym.min(
            mx.sym.stack(pred_right, target_right, axis=0), axis=0)
    h_intersect = mx.sym.min(
        mx.sym.stack(pred_bottom, target_bottom, axis=0), axis=0) + mx.sym.min(
            mx.sym.stack(pred_top, target_top, axis=0), axis=0)

    area_intersect = w_intersect * h_intersect
    area_union = (target_area + pred_area - area_intersect)

    loss = -mx.sym.log((area_intersect + 1.0) / (area_union + 1.0))

    loss = mx.sym.broadcast_mul(lhs=loss, rhs=centerness_label)
    loss = mx.sym.sum(loss) / (mx.sym.sum(centerness_label) + 1e-30)

    return X.loss(loss, grad_scale=1, name=name)
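
Here x_box and y_box encode, for each location, the distances (left, top, right, bottom) to the box sides, as in FCOS. A single-location NumPy sketch of the -log IoU term computed above:

import numpy as np

pred   = np.array([2.0, 3.0, 4.0, 1.0])   # hypothetical (l, t, r, b) prediction
target = np.array([3.0, 3.0, 3.0, 2.0])   # hypothetical (l, t, r, b) target

pred_area   = (pred[0] + pred[2]) * (pred[1] + pred[3])
target_area = (target[0] + target[2]) * (target[1] + target[3])

w_inter = min(pred[0], target[0]) + min(pred[2], target[2])
h_inter = min(pred[1], target[1]) + min(pred[3], target[3])
inter = w_inter * h_inter
union = pred_area + target_area - inter

print(-np.log((inter + 1.0) / (union + 1.0)))  # 0 when the boxes coincide
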
Example #5
def make_sigmoid_focal_loss(gamma, alpha, logits, labels, nonignore_mask):
    # do most of the computation with symbols and control the gradient flow with a custom op
    p = 1 / (1 + mx.sym.exp(-logits))  # sigmoid
    mask_logits_GE_zero = mx.sym.broadcast_greater_equal(lhs=logits,
                                                         rhs=mx.sym.zeros(
                                                             (1, 1)))
    # logits>=0
    minus_logits_mask = -1. * logits * mask_logits_GE_zero  # -1 * logits * [logits>=0]
    negative_abs_logits = logits - 2 * logits * mask_logits_GE_zero  # logits - 2 * logits * [logits>=0] == -|logits|
    log_one_exp_minus_abs = mx.sym.log(1. + mx.sym.exp(negative_abs_logits))
    minus_log = minus_logits_mask - log_one_exp_minus_abs

    alpha_one_p_gamma_labels = alpha * (1 - p)**gamma * labels
    log_p_clip = mx.sym.log(mx.sym.clip(p, a_min=1e-5, a_max=1))
    one_alpha_p_gamma_one_labels = (1 - alpha) * p**gamma * (1 - labels)
    norm = mx.sym.sum(labels * nonignore_mask) + 1

    forward_term1 = alpha_one_p_gamma_labels * log_p_clip
    forward_term2 = one_alpha_p_gamma_one_labels * minus_log
    loss = mx.sym.sum(-1 *
                      (forward_term1 + forward_term2) * nonignore_mask) / norm

    backward_term1 = alpha_one_p_gamma_labels * (1 - p -
                                                 p * gamma * log_p_clip)
    backward_term2 = one_alpha_p_gamma_one_labels * (minus_log *
                                                     (1 - p) * gamma - p)
    grad = mx.sym.broadcast_div(lhs=-1 * (backward_term1 + backward_term2) *
                                nonignore_mask,
                                rhs=norm.reshape((1, 1)))

    loss = X.block_grad(loss)
    grad = X.block_grad(grad)

    loss = mx.sym.Custom(logits=logits,
                         loss=loss,
                         grad=grad,
                         op_type='compute_focal_loss',
                         name='focal_loss')
    return loss
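
The minus_log term is a numerically stable log(1 - sigmoid(z)): it evaluates -z - log(1 + exp(-z)) for z >= 0 and -log(1 + exp(z)) for z < 0, so a large positive logit is never exponentiated. A NumPy check of the identity:

import numpy as np

z = np.array([-5.0, -1.0, 0.0, 1.0, 5.0])
ge_zero = (z >= 0).astype(z.dtype)

# as in the symbol code: -z*[z>=0] - log(1 + exp(-|z|)), with -|z| = z - 2*z*[z>=0]
minus_log = -z * ge_zero - np.log1p(np.exp(z - 2.0 * z * ge_zero))

naive = np.log(1.0 - 1.0 / (1.0 + np.exp(-z)))   # unstable for large |z|
print(np.allclose(minus_log, naive))             # True
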
Example #6
    def _refine_pts(self, cls_feat, reg_feat, dcn_offset, pts_init_out):
        p = self.p
        point_conv_channel = p.head.point_conv_channel
        num_class = p.num_class
        output_channel = num_class - 1
        pts_output_channel = p.point_generate.num_points * 2

        cls_conv = mx.symbol.contrib.DeformableConvolution(
            data=cls_feat,
            offset=dcn_offset,
            kernel=(self.dcn_kernel, self.dcn_kernel),
            pad=(self.dcn_pad, self.dcn_pad),
            stride=(1, 1),
            dilate=(1, 1),
            num_filter=point_conv_channel,
            weight=self.cls_conv_weight,
            bias=self.cls_conv_bias,
            no_bias=False,
            name="cls_conv")
        cls_conv_relu = X.relu(cls_conv)
        cls_out = X.conv(data=cls_conv_relu,
                         kernel=1,
                         filter=output_channel,
                         weight=self.cls_out_weight,
                         bias=self.cls_out_bias,
                         no_bias=False,
                         name="cls_out")

        pts_refine_conv = mx.symbol.contrib.DeformableConvolution(
            data=reg_feat,
            offset=dcn_offset,
            kernel=(self.dcn_kernel, self.dcn_kernel),
            pad=(self.dcn_pad, self.dcn_pad),
            stride=(1, 1),
            dilate=(1, 1),
            num_filter=point_conv_channel,
            weight=self.pts_refine_conv_weight,
            bias=self.pts_refine_conv_bias,
            no_bias=False,
            name="pts_refine_conv")
        pts_refine_conv_relu = X.relu(pts_refine_conv)
        pts_refine_out = X.conv(data=pts_refine_conv_relu,
                                kernel=1,
                                filter=pts_output_channel,
                                weight=self.pts_refine_out_weight,
                                bias=self.pts_refine_out_bias,
                                no_bias=False,
                                name="pts_refine_out")
        pts_refine_out = pts_refine_out + X.block_grad(pts_init_out)
        return pts_refine_out, cls_out
Example #7
    def get_loss(self, conv_fpn_feat, gt_bbox, im_info):
        p = self.p
        bs = p.batch_image  # batch size on a single GPU
        centerness_logit_dict, cls_logit_dict, offset_logit_dict = self.get_output(conv_fpn_feat)

        centerness_loss_list = []
        cls_loss_list = []
        offset_loss_list = []

        # prepare gt
        ignore_label = X.block_grad(X.var('ignore_label', init=X.constant(p.loss_setting.ignore_label), shape=(1,1)))
        ignore_offset = X.block_grad(X.var('ignore_offset', init=X.constant(p.loss_setting.ignore_offset), shape=(1,1,1)))
        gt_bbox = X.var('gt_bbox')
        im_info = X.var('im_info')
        centerness_labels, cls_labels, offset_labels = make_fcos_gt(gt_bbox, im_info,
                                                                    p.loss_setting.ignore_offset,
                                                                    p.loss_setting.ignore_label,
                                                                    p.FCOSParam.num_classifier)
        centerness_labels = X.block_grad(centerness_labels)
        cls_labels = X.block_grad(cls_labels)
        offset_labels = X.block_grad(offset_labels)

        # gather output logits
        cls_logit_dict_list = []
        centerness_logit_dict_list = []
        offset_logit_dict_list = []
        for idx, stride in enumerate(p.FCOSParam.stride):
            # reshape each level (N, c, Hi, Wi) -> (N, c, Hi*Wi); the levels are then concatenated along the last axis to (N, c, H1W1+...+H5W5)
            cls_logit_dict_list.append(mx.sym.reshape(cls_logit_dict[stride], shape=(0,0,-1)))
            centerness_logit_dict_list.append(mx.sym.reshape(centerness_logit_dict[stride], shape=(0,0,-1)))
            offset_logit_dict_list.append(mx.sym.reshape(offset_logit_dict[stride], shape=(0,0,-1)))
        cls_logits = mx.sym.reshape(mx.sym.concat(*cls_logit_dict_list, dim=2), shape=(0,-1))
        centerness_logits = mx.sym.reshape(mx.sym.concat(*centerness_logit_dict_list, dim=2), shape=(0,-1))
        offset_logits = mx.sym.reshape(mx.sym.concat(*offset_logit_dict_list, dim=2), shape=(0,4,-1))

        # make losses
        nonignore_mask = mx.sym.broadcast_not_equal(lhs=cls_labels, rhs=ignore_label)
        nonignore_mask = X.block_grad(nonignore_mask)
        cls_loss = make_sigmoid_focal_loss(gamma=p.loss_setting.focal_loss_gamma, alpha=p.loss_setting.focal_loss_alpha,
                                           logits=cls_logits, labels=cls_labels, nonignore_mask=nonignore_mask)
        cls_loss = X.loss(cls_loss, grad_scale=1)

        nonignore_mask = mx.sym.broadcast_logical_and(lhs=mx.sym.broadcast_not_equal( lhs=X.block_grad(centerness_labels), rhs=ignore_label ),
                                                      rhs=mx.sym.broadcast_greater( lhs=centerness_labels, rhs=mx.sym.full((1,1), 0) )
                                                     )
        nonignore_mask = X.block_grad(nonignore_mask)
        centerness_loss = make_binary_cross_entropy_loss(centerness_logits, centerness_labels, nonignore_mask)
        centerness_loss = X.loss(centerness_loss, grad_scale=1)

        offset_loss = IoULoss(offset_logits, offset_labels, ignore_offset, centerness_labels, name='offset_loss')
        return centerness_loss, cls_loss, offset_loss
Example #8
    def get_output(self, conv_feat):
        if self._pts_out_inits is not None and self._pts_out_refines is not None and \
                self._cls_outs is not None:
            return self._pts_out_inits, self._pts_out_refines, self._cls_outs

        p = self.p
        stride = p.point_generate.stride
        # init base offset for dcn
        from models.RepPoints.point_ops import _gen_offsets
        dcn_base_offset = _gen_offsets(mx.symbol,
                                       dcn_kernel=self.dcn_kernel,
                                       dcn_pad=self.dcn_pad)

        pts_out_inits = dict()
        pts_out_refines = dict()
        cls_outs = dict()

        for s in stride:
            # cls subnet with shared params across multiple strides
            cls_feat = self._cls_subnet(conv_feat=conv_feat["stride%s" % s],
                                        stride=s)
            # reg subnet with shared params across multiple strides
            reg_feat = self._reg_subnet(conv_feat=conv_feat["stride%s" % s],
                                        stride=s)
            # predict offsets on each center points
            pts_out_init = self._init_pts(reg_feat)
            # gradient is scaled by 0.1 for the offsets subnet
            pts_out_init_grad_mul = 0.9 * X.block_grad(
                pts_out_init) + 0.1 * pts_out_init
            # dcn takes offsets relative to the sampling grid as input,
            # so the base dcn offsets are subtracted from the predicted offsets before the dcn.
            pts_out_init_offset = mx.symbol.broadcast_sub(
                pts_out_init_grad_mul, dcn_base_offset)
            # use offsets on features to refine box and cls
            pts_out_refine, cls_out = self._refine_pts(cls_feat, reg_feat,
                                                       pts_out_init_offset,
                                                       pts_out_init)
            pts_out_inits["stride%s" % s] = pts_out_init
            pts_out_refines["stride%s" % s] = pts_out_refine
            cls_outs["stride%s" % s] = cls_out

        self._pts_out_inits = pts_out_inits
        self._pts_out_refines = pts_out_refines
        self._cls_outs = cls_outs

        return self._pts_out_inits, self._pts_out_refines, self._cls_outs
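
The expression 0.9 * X.block_grad(pts_out_init) + 0.1 * pts_out_init keeps the forward value of pts_out_init unchanged while letting only 10% of the gradient flow back into the init-points subnet. A minimal NDArray sketch of the same trick, assuming MXNet 1.x autograd:

import mxnet as mx

x = mx.nd.array([1.0, 2.0, 3.0])
x.attach_grad()
with mx.autograd.record():
    y = 0.9 * mx.nd.BlockGrad(x) + 0.1 * x   # same forward value as x
    z = y.sum()
z.backward()
print(y)       # [1. 2. 3.]
print(x.grad)  # [0.1 0.1 0.1] -- gradient scaled by 0.1
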
Example #9
        def _box_decode(rois, deltas, means, stds):
            rois = X.block_grad(rois)
            rois = mx.sym.reshape(rois, [-1, 4])
            deltas = mx.sym.reshape(deltas, [-1, 4])

            x1, y1, x2, y2 = mx.sym.split(rois,
                                          axis=-1,
                                          num_outputs=4,
                                          squeeze_axis=True)
            dx, dy, dw, dh = mx.sym.split(deltas,
                                          axis=-1,
                                          num_outputs=4,
                                          squeeze_axis=True)

            dx = dx * stds[0] + means[0]
            dy = dy * stds[1] + means[1]
            dw = dw * stds[2] + means[2]
            dh = dh * stds[3] + means[3]

            x = (x1 + x2) * 0.5
            y = (y1 + y2) * 0.5
            w = x2 - x1 + 1
            h = y2 - y1 + 1

            nx = x + dx * w
            ny = y + dy * h
            nw = w * mx.sym.exp(dw)
            nh = h * mx.sym.exp(dh)

            nx1 = nx - 0.5 * nw
            ny1 = ny - 0.5 * nh
            nx2 = nx + 0.5 * nw
            ny2 = ny + 0.5 * nh

            return mx.sym.stack(nx1,
                                ny1,
                                nx2,
                                ny2,
                                axis=1,
                                name='pc_reg_loss_decoded_roi')
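
A single-box NumPy walk-through of the decode above; means = (0, 0, 0, 0) and stds = (0.1, 0.1, 0.2, 0.2) are hypothetical normalization constants:

import numpy as np

roi = np.array([10.0, 10.0, 50.0, 30.0])          # x1, y1, x2, y2
delta = np.array([0.5, -0.2, 0.1, 0.0])           # dx, dy, dw, dh (normalized)
means, stds = (0., 0., 0., 0.), (0.1, 0.1, 0.2, 0.2)

dx, dy, dw, dh = [d * s + m for d, s, m in zip(delta, stds, means)]
x, y = (roi[0] + roi[2]) * 0.5, (roi[1] + roi[3]) * 0.5
w, h = roi[2] - roi[0] + 1, roi[3] - roi[1] + 1

nx, ny = x + dx * w, y + dy * h
nw, nh = w * np.exp(dw), h * np.exp(dh)
print([nx - 0.5 * nw, ny - 0.5 * nh, nx + 0.5 * nw, ny + 0.5 * nh])
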
Example #10
    def reg_pc_loss(self, bbox_delta, tsd_bbox_delta, rois, tsd_rois, gt_bbox,
                    gt_label, scale_loss_shift):
        '''
        TSD regression progressive constraint
        Args:
            bbox_delta: [batch_image * image_roi, num_class*4]
            tsd_bbox_delta: [batch_image * image_roi, num_class*4]
            rois: [batch_image, image_roi, 4]
            tsd_rois: [batch_image, image_roi, 4]
            gt_bbox: [batch_image, max_gt_num, 4]
            gt_label:  [batch_image * image_roi]
            scale_loss_shift: float
        Returns:
            loss: [batch_image * image_roi]
        '''
        def _box_decode(rois, deltas, means, stds):
            rois = X.block_grad(rois)
            rois = mx.sym.reshape(rois, [-1, 4])
            deltas = mx.sym.reshape(deltas, [-1, 4])

            x1, y1, x2, y2 = mx.sym.split(rois,
                                          axis=-1,
                                          num_outputs=4,
                                          squeeze_axis=True)
            dx, dy, dw, dh = mx.sym.split(deltas,
                                          axis=-1,
                                          num_outputs=4,
                                          squeeze_axis=True)

            dx = dx * stds[0] + means[0]
            dy = dy * stds[1] + means[1]
            dw = dw * stds[2] + means[2]
            dh = dh * stds[3] + means[3]

            x = (x1 + x2) * 0.5
            y = (y1 + y2) * 0.5
            w = x2 - x1 + 1
            h = y2 - y1 + 1

            nx = x + dx * w
            ny = y + dy * h
            nw = w * mx.sym.exp(dw)
            nh = h * mx.sym.exp(dh)

            nx1 = nx - 0.5 * nw
            ny1 = ny - 0.5 * nh
            nx2 = nx + 0.5 * nw
            ny2 = ny + 0.5 * nh

            return mx.sym.stack(nx1,
                                ny1,
                                nx2,
                                ny2,
                                axis=1,
                                name='pc_reg_loss_decoded_roi')

        def _gather_3d(data, indices, n):
            datas = mx.sym.split(data,
                                 axis=-1,
                                 num_outputs=n,
                                 squeeze_axis=True)
            outputs = []
            for d in datas:
                outputs.append(mx.sym.pick(d, indices, axis=1))
            return mx.sym.stack(*outputs, axis=1)

        batch_image = self.p.batch_image
        image_roi = self.p.image_roi
        batch_roi = batch_image * image_roi
        num_class = self.p.num_class
        bbox_mean = self.p.regress_target.mean
        bbox_std = self.p.regress_target.std
        margin = self.p.TSD.pc_reg_margin

        gt_label = mx.sym.reshape(gt_label, (-1, ))

        bbox_delta = mx.sym.reshape(bbox_delta,
                                    (batch_image * image_roi, num_class, 4))
        tsd_bbox_delta = mx.sym.reshape(
            tsd_bbox_delta, (batch_image * image_roi, num_class, 4))

        bbox_delta = _gather_3d(bbox_delta, gt_label, n=4)
        tsd_bbox_delta = _gather_3d(tsd_bbox_delta, gt_label, n=4)

        boxes = _box_decode(rois, bbox_delta, bbox_mean, bbox_std)
        tsd_bboxes = _box_decode(tsd_rois, tsd_bbox_delta, bbox_mean, bbox_std)

        rois = mx.sym.reshape(rois, [batch_image, -1, 4])
        tsd_rois = mx.sym.reshape(tsd_rois, [batch_image, -1, 4])

        boxes = mx.sym.reshape(boxes, [batch_image, -1, 4])
        tsd_bboxes = mx.sym.reshape(tsd_bboxes, [batch_image, -1, 4])

        rois_group = mx.sym.split(rois,
                                  axis=0,
                                  num_outputs=batch_image,
                                  squeeze_axis=True)
        tsd_rois_group = mx.sym.split(tsd_rois,
                                      axis=0,
                                      num_outputs=batch_image,
                                      squeeze_axis=True)
        boxes_group = mx.sym.split(boxes,
                                   axis=0,
                                   num_outputs=batch_image,
                                   squeeze_axis=True)
        tsd_bboxes_group = mx.sym.split(tsd_bboxes,
                                        axis=0,
                                        num_outputs=batch_image,
                                        squeeze_axis=True)
        gt_group = mx.sym.split(gt_bbox,
                                axis=0,
                                num_outputs=batch_image,
                                squeeze_axis=True)

        ious = []
        tsd_ious = []
        for i, (rois_i, tsd_rois_i, boxes_i, tsd_boxes_i, gt_i) in \
            enumerate(zip(rois_group, tsd_rois_group, boxes_group, tsd_bboxes_group, gt_group)):

            iou_mat = get_iou_mat(rois_i, gt_i, image_roi)
            tsd_iou_mat = get_iou_mat(tsd_rois_i, gt_i, image_roi)

            matched_gt = mx.sym.gather_nd(
                gt_i, X.reshape(mx.sym.argmax(iou_mat, axis=1), [1, -1]))
            tsd_matched_gt = mx.sym.gather_nd(
                gt_i, X.reshape(mx.sym.argmax(tsd_iou_mat, axis=1), [1, -1]))

            matched_gt = mx.sym.slice_axis(matched_gt, axis=-1, begin=0, end=4)
            tsd_matched_gt = mx.sym.slice_axis(tsd_matched_gt,
                                               axis=-1,
                                               begin=0,
                                               end=4)

            ious.append(get_iou(boxes_i, matched_gt))
            tsd_ious.append(get_iou(tsd_boxes_i, tsd_matched_gt))

        iou = mx.sym.concat(*ious, dim=0)
        tsd_iou = mx.sym.concat(*tsd_ious, dim=0)

        weight = X.block_grad(gt_label != 0)
        iou = X.block_grad(iou)

        reg_pc_margin = mx.sym.minimum(1. - iou, margin)
        loss = mx.sym.relu(-(tsd_iou - iou - reg_pc_margin))

        grad_scale = 1. / batch_roi
        grad_scale *= scale_loss_shift
        loss = X.loss(weight * loss, grad_scale=grad_scale, name='reg_pc_loss')
        return loss
Example #11
    def get_reg_target(self, rois, gt_bbox):
        '''
        Args:
            rois: [batch_image, image_roi, 4]
            gt_bbox: [batch_image, max_gt_num, 4]
        Returns:
            reg_target: [batch_image * image_roi, num_class * 4]
        '''
        def get_transform(rois, gt_boxes):
            bbox_mean = self.p.regress_target.mean
            bbox_std = self.p.regress_target.std

            xmin1, ymin1, xmax1, ymax1 = mx.sym.split(rois,
                                                      axis=-1,
                                                      num_outputs=4,
                                                      squeeze_axis=True)
            xmin2, ymin2, xmax2, ymax2, _ = mx.sym.split(gt_boxes,
                                                         axis=-1,
                                                         num_outputs=5,
                                                         squeeze_axis=True)

            w1 = xmax1 - xmin1 + 1.0
            h1 = ymax1 - ymin1 + 1.0
            x1 = xmin1 + 0.5 * (w1 - 1.0)
            y1 = ymin1 + 0.5 * (h1 - 1.0)

            w2 = xmax2 - xmin2 + 1.0
            h2 = ymax2 - ymin2 + 1.0
            x2 = xmin2 + 0.5 * (w2 - 1.0)
            y2 = ymin2 + 0.5 * (h2 - 1.0)

            dx = (x2 - x1) / (w1 + 1e-14)
            dy = (y2 - y1) / (h1 + 1e-14)
            dw = mx.sym.log(w2 / w1)
            dh = mx.sym.log(h2 / h1)

            dx = (dx - bbox_mean[0]) / bbox_std[0]
            dy = (dy - bbox_mean[1]) / bbox_std[1]
            dw = (dw - bbox_mean[2]) / bbox_std[2]
            dh = (dh - bbox_mean[3]) / bbox_std[3]

            return mx.sym.stack(dx,
                                dy,
                                dw,
                                dh,
                                axis=1,
                                name='delta_r_roi_transform')

        batch_image = self.p.batch_image
        image_roi = self.p.image_roi
        num_class = self.p.num_class

        reg_target = []

        rois_group = mx.sym.split(rois,
                                  axis=0,
                                  num_outputs=batch_image,
                                  squeeze_axis=True)
        gt_group = mx.sym.split(gt_bbox,
                                axis=0,
                                num_outputs=batch_image,
                                squeeze_axis=True)

        for i, (rois_i, gt_box_i) in enumerate(zip(rois_group, gt_group)):
            iou_mat = get_iou_mat(rois_i, gt_box_i,
                                  image_roi)  # [image_roi, max_gt_num]
            idxs = mx.sym.argmax(iou_mat, axis=1)  # [image_roi]
            match_gt_boxes = mx.sym.gather_nd(gt_box_i, X.reshape(
                idxs, [1, -1]))  # [image_roi, 5], gt rows carry (x1, y1, x2, y2, class)
            delta_i = get_transform(rois_i, match_gt_boxes)  # [image_roi, 4]
            delta_i = mx.sym.reshape(
                mx.sym.repeat(delta_i, repeats=num_class, axis=0),
                (image_roi, -1))  #[image_roi, num_class * 4]
            reg_target.append(delta_i)

        reg_target = X.block_grad(
            mx.sym.reshape(mx.sym.stack(*reg_target, axis=0),
                           [batch_image * image_roi, -1],
                           name='TSD_reg_target'))  # [batch_roi, num_class*4]
        return reg_target
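
get_transform encodes a matched gt box relative to its RoI as normalized (dx, dy, dw, dh), the (approximate) inverse of the decode in Example #9. A one-pair NumPy sketch, again with hypothetical mean = (0, 0, 0, 0) and std = (0.1, 0.1, 0.2, 0.2):

import numpy as np

roi = np.array([10.0, 10.0, 50.0, 30.0])   # x1, y1, x2, y2
gt  = np.array([12.0, 8.0, 60.0, 32.0])    # matched gt box coordinates
mean, std = (0., 0., 0., 0.), (0.1, 0.1, 0.2, 0.2)

w1, h1 = roi[2] - roi[0] + 1, roi[3] - roi[1] + 1
x1, y1 = roi[0] + 0.5 * (w1 - 1), roi[1] + 0.5 * (h1 - 1)
w2, h2 = gt[2] - gt[0] + 1, gt[3] - gt[1] + 1
x2, y2 = gt[0] + 0.5 * (w2 - 1), gt[1] + 0.5 * (h2 - 1)

raw = [(x2 - x1) / w1, (y2 - y1) / h1, np.log(w2 / w1), np.log(h2 / h1)]
delta = [(d - m) / s for d, m, s in zip(raw, mean, std)]
print(delta)   # the regression target for this RoI before class-wise tiling
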
Example #12
    def get_loss(self, rois, roi_feat, fpn_conv_feats, cls_label, bbox_target,
                 bbox_weight, gt_bbox):
        '''
        Args:
            rois: [batch_image, image_roi, 4]
            roi_feat: [batch_image * image_roi, 256, roi_size, roi_size]
            fpn_conv_feats: dict of FPN features, each [batch_image, in_channels, fh, fw]
            cls_label: [batch_image * image_roi]
            bbox_target: [batch_image * image_roi, num_class * 4]
            bbox_weight: [batch_image * image_roi, num_class * 4]
            gt_bbox: [batch_image, max_gt_num, 4]
        Returns:
            cls_loss: [batch_image * image_roi, num_class]
            reg_loss: [batch_image * image_roi, num_class * 4]
            tsd_cls_loss: [batch_image * image_roi, num_class]
            tsd_reg_loss: [batch_image * image_roi, num_class * 4]
            tsd_cls_pc_loss: [batch_image * image_roi]
            tsd_reg_pc_loss: [batch_image * image_roi]
            cls_label: [batch_image, image_roi]
        '''
        p = self.p
        assert not p.regress_target.class_agnostic
        batch_image = p.batch_image
        image_roi = p.image_roi
        batch_roi = batch_image * image_roi
        smooth_l1_scalar = p.regress_target.smooth_l1_scalar or 1.0

        cls_logit, bbox_delta, tsd_cls_logit, tsd_bbox_delta, delta_c, delta_r = self.get_output(
            fpn_conv_feats, roi_feat, rois, is_train=True)

        rois_r = self._get_delta_r_box(delta_r, rois)
        tsd_reg_target = self.get_reg_target(
            rois_r, gt_bbox)  # [batch_roi, num_class*4]

        scale_loss_shift = 128 if self.p.fp16 else 1.0
        # origin loss
        cls_loss = X.softmax_output(data=cls_logit,
                                    label=cls_label,
                                    normalization='batch',
                                    grad_scale=1.0 * scale_loss_shift,
                                    name='bbox_cls_loss')
        reg_loss = X.smooth_l1(bbox_delta - bbox_target,
                               scalar=smooth_l1_scalar,
                               name='bbox_reg_l1')
        reg_loss = bbox_weight * reg_loss
        reg_loss = X.loss(
            reg_loss,
            grad_scale=1.0 / batch_roi * scale_loss_shift,
            name='bbox_reg_loss',
        )
        # tsd loss
        tsd_cls_loss = X.softmax_output(data=tsd_cls_logit,
                                        label=cls_label,
                                        normalization='batch',
                                        grad_scale=1.0 * scale_loss_shift,
                                        name='tsd_bbox_cls_loss')
        tsd_reg_loss = X.smooth_l1(tsd_bbox_delta - tsd_reg_target,
                                   scalar=smooth_l1_scalar,
                                   name='tsd_bbox_reg_l1')
        tsd_reg_loss = bbox_weight * tsd_reg_loss
        tsd_reg_loss = X.loss(
            tsd_reg_loss,
            grad_scale=1.0 / batch_roi * scale_loss_shift,
            name='tsd_bbox_reg_loss',
        )

        losses = [cls_loss, reg_loss, tsd_cls_loss, tsd_reg_loss]  # pc losses are appended below when enabled
        if p.TSD.pc_cls:
            losses.append(
                self.cls_pc_loss(cls_logit, tsd_cls_logit, cls_label,
                                 scale_loss_shift))
        if p.TSD.pc_reg:
            losses.append(
                self.reg_pc_loss(bbox_delta, tsd_bbox_delta, rois, rois_r,
                                 gt_bbox, cls_label, scale_loss_shift))
        # append label
        cls_label = X.reshape(cls_label,
                              shape=(batch_image, -1),
                              name='bbox_label_reshape')
        cls_label = X.block_grad(cls_label, name='bbox_label_blockgrad')
        losses.append(cls_label)

        return tuple(losses)
Example #13
    def get_loss(self, conv_feat, gt_bbox):
        from models.RepPoints.point_ops import (_gen_points, _offset_to_pts,
                                                _point_target,
                                                _offset_to_boxes, _points2bbox)
        p = self.p
        batch_image = p.batch_image
        num_points = p.point_generate.num_points
        scale = p.point_generate.scale
        stride = p.point_generate.stride
        transform = p.point_generate.transform
        target_scale = p.point_target.target_scale
        num_pos = p.point_target.num_pos
        pos_iou_thr = p.bbox_target.pos_iou_thr
        neg_iou_thr = p.bbox_target.neg_iou_thr
        min_pos_iou = p.bbox_target.min_pos_iou

        pts_out_inits, pts_out_refines, cls_outs = self.get_output(conv_feat)

        points = dict()
        bboxes = dict()
        pts_coordinate_preds_inits = dict()
        pts_coordinate_preds_refines = dict()
        for s in stride:
            # generate points in base coordinates according to the stride and feature map size
            points["stride%s" % s] = _gen_points(mx.symbol,
                                                 pts_out_inits["stride%s" % s],
                                                 s)
            # generate bbox after init stage
            bboxes["stride%s" % s] = _offset_to_boxes(
                mx.symbol,
                points["stride%s" % s],
                X.block_grad(pts_out_inits["stride%s" % s]),
                s,
                transform,
                moment_transfer=self.moment_transfer)
            # generate final offsets in init stage
            pts_coordinate_preds_inits["stride%s" % s] = _offset_to_pts(
                mx.symbol, points["stride%s" % s],
                pts_out_inits["stride%s" % s], s, num_points)
            # generate final offsets in refine stage
            pts_coordinate_preds_refines["stride%s" % s] = _offset_to_pts(
                mx.symbol, points["stride%s" % s],
                pts_out_refines["stride%s" % s], s, num_points)

        # for init stage, use points assignment
        point_proposals = mx.symbol.tile(X.concat(
            [points["stride%s" % s] for s in stride],
            axis=1,
            name="point_concat"),
                                         reps=(batch_image, 1, 1))
        points_labels_init, points_gts_init, points_weight_init = _point_target(
            mx.symbol,
            point_proposals,
            gt_bbox,
            batch_image,
            "point",
            scale=target_scale,
            num_pos=num_pos)
        # for refine stage, use max iou assignment
        box_proposals = X.concat([bboxes["stride%s" % s] for s in stride],
                                 axis=1,
                                 name="box_concat")
        points_labels_refine, points_gts_refine, points_weight_refine = _point_target(
            mx.symbol,
            box_proposals,
            gt_bbox,
            batch_image,
            "box",
            pos_iou_thr=pos_iou_thr,
            neg_iou_thr=neg_iou_thr,
            min_pos_iou=min_pos_iou)

        bboxes_out_strides = dict()
        for s in stride:
            cls_outs["stride%s" % s] = X.reshape(
                X.transpose(data=cls_outs["stride%s" % s], axes=(0, 2, 3, 1)),
                (0, -3, -2))
            bboxes_out_strides["stride%s" % s] = mx.symbol.repeat(
                mx.symbol.ones_like(
                    mx.symbol.slice_axis(
                        cls_outs["stride%s" % s], begin=0, end=1, axis=-1)),
                repeats=4,
                axis=-1) * s

        # cls branch
        cls_outs_concat = X.concat([cls_outs["stride%s" % s] for s in stride],
                                   axis=1,
                                   name="cls_concat")
        cls_loss = X.focal_loss(data=cls_outs_concat,
                                label=points_labels_refine,
                                normalization='valid',
                                alpha=p.focal_loss.alpha,
                                gamma=p.focal_loss.gamma,
                                grad_scale=1.0,
                                workspace=1500,
                                name="cls_loss")

        # init box branch
        pts_inits_concat_ = X.concat(
            [pts_coordinate_preds_inits["stride%s" % s] for s in stride],
            axis=1,
            name="pts_init_concat_")
        pts_inits_concat = X.reshape(pts_inits_concat_, (-3, -2),
                                     name="pts_inits_concat")
        bboxes_inits_concat_ = _points2bbox(
            mx.symbol,
            pts_inits_concat,
            transform,
            y_first=False,
            moment_transfer=self.moment_transfer)
        bboxes_inits_concat = X.reshape(bboxes_inits_concat_,
                                        (-4, batch_image, -1, -2))
        normalize_term = X.concat(
            [bboxes_out_strides["stride%s" % s] for s in stride],
            axis=1,
            name="normalize_term") * scale
        pts_init_loss = X.smooth_l1(
            data=(bboxes_inits_concat - points_gts_init) / normalize_term,
            scalar=3.0,
            name="pts_init_l1_loss")
        pts_init_loss = pts_init_loss * points_weight_init
        pts_init_loss = X.bbox_norm(data=pts_init_loss,
                                    label=points_labels_init,
                                    name="pts_init_norm_loss")
        pts_init_loss = X.make_loss(data=pts_init_loss,
                                    grad_scale=0.5,
                                    name="pts_init_loss")
        points_init_labels = X.block_grad(points_labels_refine,
                                          name="points_init_labels")

        # refine box branch
        pts_refines_concat_ = X.concat(
            [pts_coordinate_preds_refines["stride%s" % s] for s in stride],
            axis=1,
            name="pts_refines_concat_")
        pts_refines_concat = X.reshape(pts_refines_concat_, (-3, -2),
                                       name="pts_refines_concat")
        bboxes_refines_concat_ = _points2bbox(
            mx.symbol,
            pts_refines_concat,
            transform,
            y_first=False,
            moment_transfer=self.moment_transfer)
        bboxes_refines_concat = X.reshape(bboxes_refines_concat_,
                                          (-4, batch_image, -1, -2))
        pts_refine_loss = X.smooth_l1(
            data=(bboxes_refines_concat - points_gts_refine) / normalize_term,
            scalar=3.0,
            name="pts_refine_l1_loss")
        pts_refine_loss = pts_refine_loss * points_weight_refine
        pts_refine_loss = X.bbox_norm(data=pts_refine_loss,
                                      label=points_labels_refine,
                                      name="pts_refine_norm_loss")
        pts_refine_loss = X.make_loss(data=pts_refine_loss,
                                      grad_scale=1.0,
                                      name="pts_refine_loss")
        points_refine_labels = X.block_grad(points_labels_refine,
                                            name="point_refine_labels")

        return cls_loss, pts_init_loss, pts_refine_loss, points_init_labels, points_refine_labels