Example #1
0
    def get_loss(self, conv_fpn_feat, gt_bbox, im_info):
        """Build the three FCOS training losses.

        Args:
            conv_fpn_feat: dict of FPN feature symbols keyed by stride.
            gt_bbox: ground-truth boxes. NOTE(review): the argument is
                immediately shadowed by X.var('gt_bbox') below, so only the
                variable name is effectively used — confirm this is intended.
            im_info: image meta info; shadowed the same way by X.var('im_info').

        Returns:
            (centerness_loss, cls_loss, offset_loss) loss symbols.
        """
        p = self.p

        centerness_logit_dict, cls_logit_dict, offset_logit_dict = self.get_output(conv_fpn_feat)

        # prepare gt; the ignore_* sentinels mark positions excluded from the losses
        ignore_label = X.block_grad(X.var('ignore_label', init=X.constant(p.loss_setting.ignore_label), shape=(1,1)))
        ignore_offset = X.block_grad(X.var('ignore_offset', init=X.constant(p.loss_setting.ignore_offset), shape=(1,1,1)))
        gt_bbox = X.var('gt_bbox')
        im_info = X.var('im_info')
        centerness_labels, cls_labels, offset_labels = make_fcos_gt(gt_bbox, im_info,
                                                                    p.loss_setting.ignore_offset,
                                                                    p.loss_setting.ignore_label,
                                                                    p.FCOSParam.num_classifier)
        # labels are constants with respect to the optimizer
        centerness_labels = X.block_grad(centerness_labels)
        cls_labels = X.block_grad(cls_labels)
        offset_labels = X.block_grad(offset_labels)

        # gather output logits across pyramid levels
        cls_logit_dict_list = []
        centerness_logit_dict_list = []
        offset_logit_dict_list = []
        for stride in p.FCOSParam.stride:
            # (c,H1,W1), (c,H2,W2), ..., (c,H5,W5) -> (H1W1+H2W2+...+H5W5), ...c..., (H1W1+H2W2+...+H5W5)
            cls_logit_dict_list.append(mx.sym.reshape(cls_logit_dict[stride], shape=(0,0,-1)))
            centerness_logit_dict_list.append(mx.sym.reshape(centerness_logit_dict[stride], shape=(0,0,-1)))
            offset_logit_dict_list.append(mx.sym.reshape(offset_logit_dict[stride], shape=(0,0,-1)))
        cls_logits = mx.sym.reshape(mx.sym.concat(*cls_logit_dict_list, dim=2), shape=(0,-1))
        centerness_logits = mx.sym.reshape(mx.sym.concat(*centerness_logit_dict_list, dim=2), shape=(0,-1))
        offset_logits = mx.sym.reshape(mx.sym.concat(*offset_logit_dict_list, dim=2), shape=(0,4,-1))

        # classification: sigmoid focal loss over non-ignored positions
        nonignore_mask = mx.sym.broadcast_not_equal(lhs=cls_labels, rhs=ignore_label)
        nonignore_mask = X.block_grad(nonignore_mask)
        cls_loss = make_sigmoid_focal_loss(gamma=p.loss_setting.focal_loss_gamma, alpha=p.loss_setting.focal_loss_alpha,
                                           logits=cls_logits, labels=cls_labels, nonignore_mask=nonignore_mask)
        cls_loss = X.loss(cls_loss, grad_scale=1)

        # centerness: BCE over positions that are neither ignored nor non-positive
        nonignore_mask = mx.sym.broadcast_logical_and(lhs=mx.sym.broadcast_not_equal( lhs=X.block_grad(centerness_labels), rhs=ignore_label ),
                                                      rhs=mx.sym.broadcast_greater( lhs=centerness_labels, rhs=mx.sym.full((1,1), 0) )
                                                     )
        nonignore_mask = X.block_grad(nonignore_mask)
        centerness_loss = make_binary_cross_entropy_loss(centerness_logits, centerness_labels, nonignore_mask)
        centerness_loss = X.loss(centerness_loss, grad_scale=1)

        # offset: centerness-weighted IoU loss
        offset_loss = IoULoss(offset_logits, offset_labels, ignore_offset, centerness_labels, name='offset_loss')
        return centerness_loss, cls_loss, offset_loss
Example #2
0
    def cls_pc_loss(self, logits, tsd_logits, gt_label, scale_loss_shift):
        '''
        TSD classification progressive constraint.

        Penalizes the TSD head when its probability on the ground-truth class
        does not exceed the sibling head's probability by an adaptive margin.
        Args:
            logits: [batch_image * image_roi, num_class]
            tsd_logits: [batch_image * image_roi, num_class]
            gt_label:  [batch_image * image_roi]
            scale_loss_shift: float
        Returns:
            loss: [batch_image * image_roi]
        '''
        cfg = self.p
        roi_count = cfg.batch_image * cfg.image_roi
        pc_margin = cfg.TSD.pc_cls_margin

        # per-class probabilities for both heads
        sib_prob = mx.sym.SoftmaxActivation(logits, mode='instance')
        tsd_prob = mx.sym.SoftmaxActivation(tsd_logits, mode='instance')

        # probability assigned to the ground-truth class; the sibling head's
        # score is treated as a constant (no gradient flows into it)
        sib_score = X.block_grad(mx.sym.pick(sib_prob, gt_label, axis=1))
        tsd_score = mx.sym.pick(tsd_prob, gt_label, axis=1)

        # margin is capped by the sibling head's remaining headroom
        margin = mx.sym.minimum(1. - sib_score, pc_margin)
        # hinge: positive only while tsd_score < sib_score + margin
        hinge = mx.sym.relu(sib_score + margin - tsd_score)

        return X.loss(hinge,
                      grad_scale=1. / roi_count * scale_loss_shift,
                      name='cls_pc_loss')
Example #3
0
    def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
        """RPN losses: multi-class softmax over fg/bg plus weighted smooth-L1
        regression on the anchor deltas.

        Returns (cls_loss, reg_loss) symbols.
        """
        p = self.p
        norm_factor = p.batch_image * p.anchor_generate.image_anchor
        loss_shift = 128.0 if p.fp16 else 1.0

        cls_logit, bbox_delta = self.get_output(conv_feat)

        # (N,C,H,W) -> (N,2,C/2,H,W) so softmax runs over the fg/bg axis
        cls_logit_reshape = X.reshape(
            cls_logit,
            shape=(0, -4, 2, -1, 0, 0),
            name="rpn_cls_logit_reshape")
        cls_loss = X.softmax_output(
            data=cls_logit_reshape,
            label=cls_label,
            multi_output=True,
            normalization='valid',
            use_ignore=True,
            ignore_label=-1,
            grad_scale=1.0 * loss_shift,
            name="rpn_cls_loss")

        # regression: smooth-L1 on the residual, masked by bbox_weight
        residual = bbox_delta - bbox_target
        reg_loss = bbox_weight * X.smooth_l1(residual,
                                             scalar=3.0,
                                             name='rpn_reg_l1')
        reg_loss = X.loss(
            reg_loss,
            grad_scale=1.0 / norm_factor * loss_shift,
            name='rpn_reg_loss')

        return cls_loss, reg_loss
Example #4
0
    def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
        """Bbox-head losses: softmax classification plus weighted smooth-L1
        regression; also returns the labels reshaped per image and detached
        from the gradient graph.

        Returns (cls_loss, reg_loss, cls_label) symbols.
        """
        p = self.p
        n_img = p.batch_image
        n_roi = p.image_roi * n_img
        loss_shift = 128.0 if p.fp16 else 1.0

        cls_logit, bbox_delta = self.get_output(conv_feat)

        # classification loss, normalized over the whole batch
        cls_loss = X.softmax_output(
            data=cls_logit,
            label=cls_label,
            normalization='batch',
            grad_scale=1.0 * loss_shift,
            name='bbox_cls_loss')

        # bounding-box regression: smooth-L1 on the residual, masked by weight
        residual = bbox_delta - bbox_target
        reg_loss = bbox_weight * X.smooth_l1(residual,
                                             scalar=1.0,
                                             name='bbox_reg_l1')
        reg_loss = X.loss(
            reg_loss,
            grad_scale=1.0 / n_roi * loss_shift,
            name='bbox_reg_loss',
        )

        # expose labels per image, gradient-blocked
        cls_label = X.block_grad(
            X.reshape(cls_label, shape=(n_img, -1), name='bbox_label_reshape'),
            name='bbox_label_blockgrad')

        # output
        return cls_loss, reg_loss, cls_label
Example #5
0
    def get_loss(self, conv_fpn_feat, cls_label, bbox_target, bbox_weight):
        """FPN-RPN losses: per-stride logits are flattened, concatenated along
        the spatial axis, then scored with softmax (cls) and smooth-L1 (reg).

        Returns (cls_loss, reg_loss) symbols.
        """
        p = self.p
        strides = p.anchor_generate.stride
        norm_factor = p.batch_image * p.anchor_generate.image_anchor
        loss_shift = 128.0 if p.fp16 else 1.0

        cls_logit_dict, bbox_delta_dict = self.get_output(conv_fpn_feat)

        # flatten each stride's maps: (N,2A,H,W)->(N,2,-1), (N,C,H,W)->(N,C,-1)
        cls_logit_list = [
            X.reshape(data=cls_logit_dict[s],
                      shape=(0, 2, -1),
                      name="rpn_cls_score_reshape_stride%s" % s)
            for s in strides
        ]
        bbox_delta_list = [
            X.reshape(data=bbox_delta_dict[s],
                      shape=(0, 0, -1),
                      name="rpn_bbox_pred_reshape_stride%s" % s)
            for s in strides
        ]

        # concat output of each level along the flattened spatial axis
        bbox_delta_concat = X.concat(bbox_delta_list, axis=2, name="rpn_bbox_pred_concat")
        cls_logit_concat = X.concat(cls_logit_list, axis=2, name="rpn_cls_score_concat")

        cls_loss = X.softmax_output(
            data=cls_logit_concat,
            label=cls_label,
            multi_output=True,
            normalization='valid',
            use_ignore=True,
            ignore_label=-1,
            grad_scale=1.0 * loss_shift,
            name="rpn_cls_loss"
        )

        # regression: weighted smooth-L1 on the delta residual
        reg_loss = bbox_weight * X.smooth_l1(
            bbox_delta_concat - bbox_target,
            scalar=3.0,
            name='rpn_reg_l1'
        )
        reg_loss = X.loss(
            reg_loss,
            grad_scale=1.0 / norm_factor * loss_shift,
            name='rpn_reg_loss'
        )
        return cls_loss, reg_loss
Example #6
0
    def get_loss(self, conv_feat, gt_bboxes, im_infos, rpn_groups):
        """RPN losses with optional GroupSoftmax classification.

        Args:
            conv_feat: backbone feature symbol fed to the RPN head.
            gt_bboxes, im_infos: not read here; targets come from the
                externally-bound rpn_* variables below — NOTE(review) confirm
                callers rely on variable binding rather than these arguments.
            rpn_groups: class-group symbol, used only when p.use_groupsoftmax.

        Returns:
            (cls_loss, reg_loss) symbols.
        """
        p = self.p
        num_class = p.num_class
        batch_image = p.batch_image
        image_anchor = p.anchor_generate.image_anchor

        cls_logit, bbox_delta = self.get_output(conv_feat)

        scale_loss_shift = 128.0 if p.fp16 else 1.0

        # targets are provided as externally-bound variables
        cls_label = X.var("rpn_cls_label")
        bbox_target = X.var("rpn_reg_target")
        bbox_weight = X.var("rpn_reg_weight")

        # classification loss
        cls_logit_reshape = X.reshape(
            cls_logit,
            shape=(0, -4, num_class, -1, 0,
                   0),  # (N,C,H,W) -> (N,num_class,C/num_class,H,W)
            name="rpn_cls_logit_reshape")

        # both branches share identical softmax configuration; only the
        # operator (and the extra `group` argument) differs, so build the
        # kwargs once instead of duplicating them
        softmax_kwargs = dict(
            data=cls_logit_reshape,
            label=cls_label,
            multi_output=True,
            normalization='valid',
            use_ignore=True,
            ignore_label=-1,
            grad_scale=1.0 * scale_loss_shift,
            name="rpn_cls_loss")
        if p.use_groupsoftmax:
            cls_loss = mx.sym.contrib.GroupSoftmaxOutput(group=rpn_groups,
                                                         **softmax_kwargs)
        else:
            cls_loss = X.softmax_output(**softmax_kwargs)

        # regression loss: weighted smooth-L1 on the delta residual
        reg_loss = X.smooth_l1((bbox_delta - bbox_target),
                               scalar=3.0,
                               name='rpn_reg_l1')
        reg_loss = bbox_weight * reg_loss
        reg_loss = X.loss(reg_loss,
                          grad_scale=1.0 / (batch_image * image_anchor) *
                          scale_loss_shift,
                          name='rpn_reg_loss')

        return cls_loss, reg_loss
Example #7
0
def IoULoss(x_box, y_box, ignore_offset, centerness_label, name='iouloss'):
    """Centerness-weighted -log(IoU) loss on (left, top, right, bottom) offsets.

    Args:
        x_box: predicted offsets, layout (N, 4, positions).
        y_box: target offsets, same layout; gradient-blocked below.
        ignore_offset: sentinel symbol marking positions to exclude.
        centerness_label: per-position centerness targets, used both to mask
            (> 0) and to weight the per-position loss.
        name: name of the returned loss symbol.

    Returns:
        Scalar loss symbol (X.loss with grad_scale=1).
    """
    centerness_label = mx.sym.reshape(centerness_label, shape=(0, 1, -1))
    y_box = X.block_grad(y_box)

    # split targets along the offset axis
    target_left = mx.sym.slice_axis(y_box, axis=1, begin=0, end=1)
    target_top = mx.sym.slice_axis(y_box, axis=1, begin=1, end=2)
    target_right = mx.sym.slice_axis(y_box, axis=1, begin=2, end=3)
    target_bottom = mx.sym.slice_axis(y_box, axis=1, begin=3, end=4)

    # filter out out-of-bbox area, loss is only computed inside bboxes
    nonignore_mask = mx.sym.broadcast_logical_and(
        lhs=mx.sym.broadcast_not_equal(lhs=target_left, rhs=ignore_offset),
        rhs=mx.sym.broadcast_greater(lhs=centerness_label,
                                     rhs=mx.sym.full((1, 1, 1), 0)))
    nonignore_mask = X.block_grad(nonignore_mask)
    # clip predictions to [0, 1e4], then zero out ignored positions
    x_box = mx.sym.clip(x_box, a_min=0, a_max=1e4)
    x_box = mx.sym.broadcast_mul(lhs=x_box, rhs=nonignore_mask)
    centerness_label = centerness_label * nonignore_mask

    # split predictions the same way
    pred_left = mx.sym.slice_axis(x_box, axis=1, begin=0, end=1)
    pred_top = mx.sym.slice_axis(x_box, axis=1, begin=1, end=2)
    pred_right = mx.sym.slice_axis(x_box, axis=1, begin=2, end=3)
    pred_bottom = mx.sym.slice_axis(x_box, axis=1, begin=3, end=4)

    target_area = (target_left + target_right) * (target_top + target_bottom)
    pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)

    # elementwise min computed via stack + min over the new leading axis
    w_intersect = mx.sym.min(
        mx.sym.stack(pred_left, target_left, axis=0), axis=0) + mx.sym.min(
            mx.sym.stack(pred_right, target_right, axis=0), axis=0)
    h_intersect = mx.sym.min(
        mx.sym.stack(pred_bottom, target_bottom, axis=0), axis=0) + mx.sym.min(
            mx.sym.stack(pred_top, target_top, axis=0), axis=0)

    area_intersect = w_intersect * h_intersect
    area_union = (target_area + pred_area - area_intersect)

    # -log IoU, smoothed with +1 in numerator and denominator
    loss = -mx.sym.log((area_intersect + 1.0) / (area_union + 1.0))

    # weight by centerness and normalize by total centerness mass
    loss = mx.sym.broadcast_mul(lhs=loss, rhs=centerness_label)
    loss = mx.sym.sum(loss) / (mx.sym.sum(centerness_label) + 1e-30)

    return X.loss(loss, grad_scale=1, name=name)
Example #8
0
    def get_loss(self, conv_fpn_feat, gt_bbox, im_info):
        """FPN-RPN losses; anchor targets are produced either in-graph by the
        TVM target generator (when p.nnvm_rpn_target) or bound externally as
        rpn_* variables.

        Args:
            conv_fpn_feat: dict of FPN feature symbols keyed by stride.
            gt_bbox: ground-truth boxes; read only on the nnvm path.
            im_info: image meta info; read only on the nnvm path.

        Returns:
            (cls_loss, reg_loss, gradient-blocked cls_label) symbols.
        """
        p = self.p
        batch_image = p.batch_image
        image_anchor = p.anchor_assign.image_anchor
        rpn_stride = p.anchor_generate.stride
        # NOTE(review): anchor_scale / anchor_ratio are read but unused below
        anchor_scale = p.anchor_generate.scale
        anchor_ratio = p.anchor_generate.ratio
        num_anchor = len(p.anchor_generate.ratio) * len(
            p.anchor_generate.scale)

        cls_logit_dict, bbox_delta_dict = self.get_output(conv_fpn_feat)

        scale_loss_shift = 128.0 if p.fp16 else 1.0

        rpn_cls_logit_list = []
        rpn_bbox_delta_list = []
        feat_list = []

        # flatten each stride's logits/deltas; keep raw logits for the
        # in-graph target generator (it needs per-level feature shapes)
        for stride in rpn_stride:
            rpn_cls_logit = cls_logit_dict[stride]
            rpn_bbox_delta = bbox_delta_dict[stride]
            rpn_cls_logit_reshape = X.reshape(
                data=rpn_cls_logit,
                shape=(0, 2, num_anchor, -1),
                name="rpn_cls_score_reshape_stride%s" % stride)
            rpn_bbox_delta_reshape = X.reshape(
                data=rpn_bbox_delta,
                shape=(0, 0, -1),
                name="rpn_bbox_pred_reshape_stride%s" % stride)
            rpn_bbox_delta_list.append(rpn_bbox_delta_reshape)
            rpn_cls_logit_list.append(rpn_cls_logit_reshape)
            feat_list.append(rpn_cls_logit)

        if p.nnvm_rpn_target:
            # generate anchor targets inside the graph
            from mxnext.tvm.rpn_target import _fpn_rpn_target_batch

            anchor_list = [
                self.anchor_dict["stride%s" % s] for s in rpn_stride
            ]
            # drop any extra columns (e.g. class) beyond x1,y1,x2,y2
            gt_bbox = mx.sym.slice_axis(gt_bbox, axis=-1, begin=0, end=4)

            max_side = p.anchor_generate.max_side
            allowed_border = p.anchor_assign.allowed_border
            fg_fraction = p.anchor_assign.pos_fraction
            fg_thr = p.anchor_assign.pos_thr
            bg_thr = p.anchor_assign.neg_thr

            cls_label, bbox_target, bbox_weight = _fpn_rpn_target_batch(
                mx.sym, feat_list, anchor_list, gt_bbox, im_info, batch_image,
                num_anchor, max_side, rpn_stride, allowed_border, image_anchor,
                fg_fraction, fg_thr, bg_thr)
        else:
            # targets provided as externally-bound variables
            cls_label = X.var("rpn_cls_label")
            bbox_target = X.var("rpn_reg_target")
            bbox_weight = X.var("rpn_reg_weight")

        # concat output of each level
        rpn_bbox_delta_concat = X.concat(rpn_bbox_delta_list,
                                         axis=2,
                                         name="rpn_bbox_pred_concat")
        rpn_cls_logit_concat = X.concat(rpn_cls_logit_list,
                                        axis=-1,
                                        name="rpn_cls_score_concat")

        cls_loss = X.softmax_output(data=rpn_cls_logit_concat,
                                    label=cls_label,
                                    multi_output=True,
                                    normalization='valid',
                                    use_ignore=True,
                                    ignore_label=-1,
                                    grad_scale=1.0 * scale_loss_shift,
                                    name="rpn_cls_loss")

        # regression loss: weighted smooth-L1 on the delta residual
        reg_loss = X.smooth_l1((rpn_bbox_delta_concat - bbox_target),
                               scalar=3.0,
                               name='rpn_reg_l1')
        reg_loss = bbox_weight * reg_loss
        reg_loss = X.loss(reg_loss,
                          grad_scale=1.0 / (batch_image * image_anchor) *
                          scale_loss_shift,
                          name='rpn_reg_loss')
        return cls_loss, reg_loss, X.stop_grad(cls_label,
                                               "rpn_cls_label_blockgrad")
Example #9
0
    def reg_pc_loss(self, bbox_delta, tsd_bbox_delta, rois, tsd_rois, gt_bbox,
                    gt_label, scale_loss_shift):
        '''
        TSD regression progressive constraint: the TSD head's decoded boxes
        should reach a higher IoU with their matched GT than the sibling
        head's, by an adaptive margin.
        Args:
            bbox_delta: [batch_image * image_roi, num_class*4]
            tsd_bbox_delta: [batch_image * image_roi, num_class*4]
            rois: [batch_image, image_roi, 4]
            tsd_rois: [batch_image, image_roi, 4]
            gt_bbox: [batch_image, max_gt_num, 4]
            gt_label:  [batch_image * image_roi]
            scale_loss_shift: float
        Returns:
            loss: [batch_image * image_roi]
        '''
        def _box_decode(rois, deltas, means, stds):
            # Decode per-roi regression deltas (dx,dy,dw,dh, de-normalized by
            # means/stds) into corner boxes stacked as (x1, y1, x2, y2).
            rois = X.block_grad(rois)
            rois = mx.sym.reshape(rois, [-1, 4])
            deltas = mx.sym.reshape(deltas, [-1, 4])

            x1, y1, x2, y2 = mx.sym.split(rois,
                                          axis=-1,
                                          num_outputs=4,
                                          squeeze_axis=True)
            dx, dy, dw, dh = mx.sym.split(deltas,
                                          axis=-1,
                                          num_outputs=4,
                                          squeeze_axis=True)

            # undo target normalization
            dx = dx * stds[0] + means[0]
            dy = dy * stds[1] + means[1]
            dw = dw * stds[2] + means[2]
            dh = dh * stds[3] + means[3]

            # roi center and (inclusive) width/height
            x = (x1 + x2) * 0.5
            y = (y1 + y2) * 0.5
            w = x2 - x1 + 1
            h = y2 - y1 + 1

            nx = x + dx * w
            ny = y + dy * h
            nw = w * mx.sym.exp(dw)
            nh = h * mx.sym.exp(dh)

            nx1 = nx - 0.5 * nw
            ny1 = ny - 0.5 * nh
            nx2 = nx + 0.5 * nw
            ny2 = ny + 0.5 * nh

            return mx.sym.stack(nx1,
                                ny1,
                                nx2,
                                ny2,
                                axis=1,
                                name='pc_reg_loss_decoded_roi')

        def _gather_3d(data, indices, n):
            # For each row, pick the n components belonging to the class
            # given by `indices` (per-class regression layout).
            datas = mx.sym.split(data,
                                 axis=-1,
                                 num_outputs=n,
                                 squeeze_axis=True)
            outputs = []
            for d in datas:
                outputs.append(mx.sym.pick(d, indices, axis=1))
            return mx.sym.stack(*outputs, axis=1)

        batch_image = self.p.batch_image
        image_roi = self.p.image_roi
        batch_roi = batch_image * image_roi
        num_class = self.p.num_class
        bbox_mean = self.p.regress_target.mean
        bbox_std = self.p.regress_target.std
        margin = self.p.TSD.pc_reg_margin

        gt_label = mx.sym.reshape(gt_label, (-1, ))

        # select each roi's deltas for its ground-truth class
        bbox_delta = mx.sym.reshape(bbox_delta,
                                    (batch_image * image_roi, num_class, 4))
        tsd_bbox_delta = mx.sym.reshape(
            tsd_bbox_delta, (batch_image * image_roi, num_class, 4))

        bbox_delta = _gather_3d(bbox_delta, gt_label, n=4)
        tsd_bbox_delta = _gather_3d(tsd_bbox_delta, gt_label, n=4)

        # decode both heads' predictions into absolute boxes
        boxes = _box_decode(rois, bbox_delta, bbox_mean, bbox_std)
        tsd_bboxes = _box_decode(tsd_rois, tsd_bbox_delta, bbox_mean, bbox_std)

        rois = mx.sym.reshape(rois, [batch_image, -1, 4])
        tsd_rois = mx.sym.reshape(tsd_rois, [batch_image, -1, 4])

        boxes = mx.sym.reshape(boxes, [batch_image, -1, 4])
        tsd_bboxes = mx.sym.reshape(tsd_bboxes, [batch_image, -1, 4])

        # split into per-image groups for the IoU matching below
        rois_group = mx.sym.split(rois,
                                  axis=0,
                                  num_outputs=batch_image,
                                  squeeze_axis=True)
        tsd_rois_group = mx.sym.split(tsd_rois,
                                      axis=0,
                                      num_outputs=batch_image,
                                      squeeze_axis=True)
        boxes_group = mx.sym.split(boxes,
                                   axis=0,
                                   num_outputs=batch_image,
                                   squeeze_axis=True)
        tsd_bboxes_group = mx.sym.split(tsd_bboxes,
                                        axis=0,
                                        num_outputs=batch_image,
                                        squeeze_axis=True)
        gt_group = mx.sym.split(gt_bbox,
                                axis=0,
                                num_outputs=batch_image,
                                squeeze_axis=True)

        ious = []
        tsd_ious = []
        for i, (rois_i, tsd_rois_i, boxes_i, tsd_boxes_i, gt_i) in \
            enumerate(zip(rois_group, tsd_rois_group, boxes_group, tsd_bboxes_group, gt_group)):

            # match each roi to its best-overlapping GT box
            iou_mat = get_iou_mat(rois_i, gt_i, image_roi)
            tsd_iou_mat = get_iou_mat(tsd_rois_i, gt_i, image_roi)

            matched_gt = mx.sym.gather_nd(
                gt_i, X.reshape(mx.sym.argmax(iou_mat, axis=1), [1, -1]))
            tsd_matched_gt = mx.sym.gather_nd(
                gt_i, X.reshape(mx.sym.argmax(tsd_iou_mat, axis=1), [1, -1]))

            matched_gt = mx.sym.slice_axis(matched_gt, axis=-1, begin=0, end=4)
            tsd_matched_gt = mx.sym.slice_axis(tsd_matched_gt,
                                               axis=-1,
                                               begin=0,
                                               end=4)

            # IoU of the decoded boxes against their matched GT
            ious.append(get_iou(boxes_i, matched_gt))
            tsd_ious.append(get_iou(tsd_boxes_i, tsd_matched_gt))

        iou = mx.sym.concat(*ious, dim=0)
        tsd_iou = mx.sym.concat(*tsd_ious, dim=0)

        # foreground mask (label 0 is background); sibling IoU is a constant
        weight = X.block_grad(gt_label != 0)
        iou = X.block_grad(iou)

        # hinge with adaptive margin capped by the remaining IoU headroom
        reg_pc_margin = mx.sym.minimum(1. - iou, margin)
        loss = mx.sym.relu(-(tsd_iou - iou - reg_pc_margin))

        grad_scale = 1. / batch_roi
        grad_scale *= scale_loss_shift
        loss = X.loss(weight * loss, grad_scale=grad_scale, name='reg_pc_loss')
        return loss
Example #10
0
    def get_loss(self, rois, roi_feat, fpn_conv_feats, cls_label, bbox_target,
                 bbox_weight, gt_bbox):
        '''
        TSD head losses: sibling + TSD classification/regression, plus the
        optional progressive-constraint (PC) losses.
        Args:
            rois: [batch_image, image_roi, 4]
            roi_feat: [batch_image * image_roi, 256, roi_size, roi_size]
            fpn_conv_feats: dict of FPN features, each [batch_image, in_channels, fh, fw]
            cls_label: [batch_image * image_roi]
            bbox_target: [batch_image * image_roi, num_class * 4]
            bbox_weight: [batch_image * image_roi, num_class * 4]
            gt_bbox: [batch_image, max_gt_num, 4]
        Returns:
            cls_loss: [batch_image * image_roi, num_class]
            reg_loss: [batch_image * image_roi, num_class * 4]
            tsd_cls_loss: [batch_image * image_roi, num_class]
            tsd_reg_loss: [batch_image * image_roi, num_class * 4]
            tsd_cls_pc_loss: [batch_image * image_roi] (only if p.TSD.pc_cls)
            tsd_reg_pc_loss: [batch_image * image_roi] (only if p.TSD.pc_reg)
            cls_label: [batch_image, image_roi]
        '''
        p = self.p
        assert not p.regress_target.class_agnostic
        batch_image = p.batch_image
        image_roi = p.image_roi
        batch_roi = batch_image * image_roi
        smooth_l1_scalar = p.regress_target.smooth_l1_scalar or 1.0

        cls_logit, bbox_delta, tsd_cls_logit, tsd_bbox_delta, delta_c, delta_r = self.get_output(
            fpn_conv_feats, roi_feat, rois, is_train=True)

        # regression targets for the TSD branch are recomputed against the
        # delta-r-shifted rois
        rois_r = self._get_delta_r_box(delta_r, rois)
        tsd_reg_target = self.get_reg_target(
            rois_r, gt_bbox)  # [batch_roi, num_class*4]

        scale_loss_shift = 128 if self.p.fp16 else 1.0
        # origin loss
        cls_loss = X.softmax_output(data=cls_logit,
                                    label=cls_label,
                                    normalization='batch',
                                    grad_scale=1.0 * scale_loss_shift,
                                    name='bbox_cls_loss')
        reg_loss = X.smooth_l1(bbox_delta - bbox_target,
                               scalar=smooth_l1_scalar,
                               name='bbox_reg_l1')
        reg_loss = bbox_weight * reg_loss
        reg_loss = X.loss(
            reg_loss,
            grad_scale=1.0 / batch_roi * scale_loss_shift,
            name='bbox_reg_loss',
        )
        # tsd loss
        tsd_cls_loss = X.softmax_output(data=tsd_cls_logit,
                                        label=cls_label,
                                        normalization='batch',
                                        grad_scale=1.0 * scale_loss_shift,
                                        name='tsd_bbox_cls_loss')
        tsd_reg_loss = X.smooth_l1(tsd_bbox_delta - tsd_reg_target,
                                   scalar=smooth_l1_scalar,
                                   name='tsd_bbox_reg_l1')
        tsd_reg_loss = bbox_weight * tsd_reg_loss
        tsd_reg_loss = X.loss(
            tsd_reg_loss,
            grad_scale=1.0 / batch_roi * scale_loss_shift,
            name='tsd_bbox_reg_loss',
        )

        # BUG FIX: the original list included the undefined name
        # `tsd_cls_pc_loss` (NameError at graph construction); the PC losses
        # are appended conditionally below instead.
        losses = [cls_loss, reg_loss, tsd_cls_loss, tsd_reg_loss]
        if p.TSD.pc_cls:
            losses.append(
                self.cls_pc_loss(cls_logit, tsd_cls_logit, cls_label,
                                 scale_loss_shift))
        if p.TSD.pc_reg:
            losses.append(
                self.reg_pc_loss(bbox_delta, tsd_bbox_delta, rois, rois_r,
                                 gt_bbox, cls_label, scale_loss_shift))
        # append label, reshaped per image and detached from the gradient graph
        cls_label = X.reshape(cls_label,
                              shape=(batch_image, -1),
                              name='bbox_label_reshape')
        cls_label = X.block_grad(cls_label, name='bbox_label_blockgrad')
        losses.append(cls_label)

        return tuple(losses)
Example #11
0
    def emd_loss(self,
                 cls_logit,
                 cls_label,
                 cls_sec_logit,
                 cls_sec_label,
                 bbox_delta,
                 bbox_target,
                 bbox_sec_delta,
                 bbox_sec_target,
                 bbox_weight,
                 bbox_sec_weight,
                 prefix=""):
        """EMD-style set loss for paired predictions.

        Scores both possible matchings between the two prediction heads and
        the (primary, secondary) target pair, and keeps the cheaper matching
        per sample before averaging.
        """
        p = self.p
        l1_scalar = p.regress_target.smooth_l1_scalar or 1.0
        loss_shift = 128.0 if p.fp16 else 1.0

        def _cls_term(logit, label, tag):
            # cross-entropy of one head against one label assignment
            return self.softmax_entropy(logit, label, prefix=prefix + tag)

        def _reg_term(delta, target, weight, tag):
            # weighted smooth-L1 of one head against one target assignment
            return weight * X.smooth_l1(delta - target,
                                        scalar=l1_scalar,
                                        name=prefix + tag)

        # matching A: head1 -> primary targets, head2 -> secondary targets
        cls_a = _cls_term(cls_logit, cls_label, 'cls_loss11') + \
            _cls_term(cls_sec_logit, cls_sec_label, 'cls_loss12')
        reg_a = _reg_term(bbox_delta, bbox_target, bbox_weight,
                          'bbox_reg_l1_11') + \
            _reg_term(bbox_sec_delta, bbox_sec_target, bbox_sec_weight,
                      'bbox_reg_l1_12')

        # matching B: head1 -> secondary targets, head2 -> primary targets
        cls_b = _cls_term(cls_logit, cls_sec_label, 'cls_loss21') + \
            _cls_term(cls_sec_logit, cls_label, 'cls_loss22')
        reg_b = _reg_term(bbox_delta, bbox_sec_target, bbox_sec_weight,
                          'bbox_reg_l1_21') + \
            _reg_term(bbox_sec_delta, bbox_target, bbox_weight,
                      'bbox_reg_l1_22')

        # per-sample totals, then keep the cheaper matching
        total_a = mx.sym.sum(cls_a, axis=-1) + mx.sym.sum(reg_a, axis=-1)
        total_b = mx.sym.sum(cls_b, axis=-1) + mx.sym.sum(reg_b, axis=-1)

        best = mx.sym.mean(mx.sym.minimum(total_a, total_b))
        return X.loss(best,
                      grad_scale=1.0 * loss_shift,
                      name=prefix + 'cls_reg_loss')