Example #1
    def get_loss(self, conv_feat, gt_bbox):
        import mxnet as mx
        import mxnext as X

        from models.RepPoints.point_ops import (_gen_points, _offset_to_pts,
                                                _point_target,
                                                _offset_to_boxes, _points2bbox)
        p = self.p
        batch_image = p.batch_image
        num_points = p.point_generate.num_points
        scale = p.point_generate.scale
        stride = p.point_generate.stride
        transform = p.point_generate.transform
        target_scale = p.point_target.target_scale
        num_pos = p.point_target.num_pos
        pos_iou_thr = p.bbox_target.pos_iou_thr
        neg_iou_thr = p.bbox_target.neg_iou_thr
        min_pos_iou = p.bbox_target.min_pos_iou

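        # get_output returns three dicts keyed by "stride%s": the init-stage
        # offset maps, the refine-stage offset maps and the classification
        # logits for every stride in the pyramid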
        pts_out_inits, pts_out_refines, cls_outs = self.get_output(conv_feat)

        points = dict()
        bboxes = dict()
        pts_coordinate_preds_inits = dict()
        pts_coordinate_preds_refines = dict()
        for s in stride:
            # generate the base grid points for this stride from the feature-map size
            points["stride%s" % s] = _gen_points(mx.symbol,
                                                 pts_out_inits["stride%s" % s],
                                                 s)
            # decode bboxes from the (gradient-blocked) init-stage offsets; used for refine-stage target assignment
            bboxes["stride%s" % s] = _offset_to_boxes(
                mx.symbol,
                points["stride%s" % s],
                X.block_grad(pts_out_inits["stride%s" % s]),
                s,
                transform,
                moment_transfer=self.moment_transfer)
            # convert init-stage offsets to point coordinates
            pts_coordinate_preds_inits["stride%s" % s] = _offset_to_pts(
                mx.symbol, points["stride%s" % s],
                pts_out_inits["stride%s" % s], s, num_points)
            # convert refine-stage offsets to point coordinates
            pts_coordinate_preds_refines["stride%s" % s] = _offset_to_pts(
                mx.symbol, points["stride%s" % s],
                pts_out_refines["stride%s" % s], s, num_points)

        # for init stage, use points assignment
        point_proposals = mx.symbol.tile(X.concat(
            [points["stride%s" % s] for s in stride],
            axis=1,
            name="point_concat"),
                                         reps=(batch_image, 1, 1))
        points_labels_init, points_gts_init, points_weight_init = _point_target(
            mx.symbol,
            point_proposals,
            gt_bbox,
            batch_image,
            "point",
            scale=target_scale,
            num_pos=num_pos)
        # for refine stage, use max iou assignment
        box_proposals = X.concat([bboxes["stride%s" % s] for s in stride],
                                 axis=1,
                                 name="box_concat")
        points_labels_refine, points_gts_refine, points_weight_refine = _point_target(
            mx.symbol,
            box_proposals,
            gt_bbox,
            batch_image,
            "box",
            pos_iou_thr=pos_iou_thr,
            neg_iou_thr=neg_iou_thr,
            min_pos_iou=min_pos_iou)

        bboxes_out_strides = dict()
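        # flatten the cls outputs to (N, H*W, C) per stride and build a matching
        # (N, H*W, 4) tensor filled with the stride value; it is used below as
        # the normalization term of the box regression losses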
        for s in stride:
            cls_outs["stride%s" % s] = X.reshape(
                X.transpose(data=cls_outs["stride%s" % s], axes=(0, 2, 3, 1)),
                (0, -3, -2))
            bboxes_out_strides["stride%s" % s] = mx.symbol.repeat(
                mx.symbol.ones_like(
                    mx.symbol.slice_axis(
                        cls_outs["stride%s" % s], begin=0, end=1, axis=-1)),
                repeats=4,
                axis=-1) * s

        # cls branch
        cls_outs_concat = X.concat([cls_outs["stride%s" % s] for s in stride],
                                   axis=1,
                                   name="cls_concat")
        cls_loss = X.focal_loss(data=cls_outs_concat,
                                label=points_labels_refine,
                                normalization='valid',
                                alpha=p.focal_loss.alpha,
                                gamma=p.focal_loss.gamma,
                                grad_scale=1.0,
                                workspace=1500,
                                name="cls_loss")

        # init box branch
        pts_inits_concat_ = X.concat(
            [pts_coordinate_preds_inits["stride%s" % s] for s in stride],
            axis=1,
            name="pts_init_concat_")
        pts_inits_concat = X.reshape(pts_inits_concat_, (-3, -2),
                                     name="pts_inits_concat")
        bboxes_inits_concat_ = _points2bbox(
            mx.symbol,
            pts_inits_concat,
            transform,
            y_first=False,
            moment_transfer=self.moment_transfer)
        bboxes_inits_concat = X.reshape(bboxes_inits_concat_,
                                        (-4, batch_image, -1, -2))
        normalize_term = X.concat(
            [bboxes_out_strides["stride%s" % s] for s in stride],
            axis=1,
            name="normalize_term") * scale
        pts_init_loss = X.smooth_l1(
            data=(bboxes_inits_concat - points_gts_init) / normalize_term,
            scalar=3.0,
            name="pts_init_l1_loss")
        pts_init_loss = pts_init_loss * points_weight_init
        pts_init_loss = X.bbox_norm(data=pts_init_loss,
                                    label=points_labels_init,
                                    name="pts_init_norm_loss")
        pts_init_loss = X.make_loss(data=pts_init_loss,
                                    grad_scale=0.5,
                                    name="pts_init_loss")
        points_init_labels = X.block_grad(points_labels_refine,
                                          name="points_init_labels")

        # refine box branch
        pts_refines_concat_ = X.concat(
            [pts_coordinate_preds_refines["stride%s" % s] for s in stride],
            axis=1,
            name="pts_refines_concat_")
        pts_refines_concat = X.reshape(pts_refines_concat_, (-3, -2),
                                       name="pts_refines_concat")
        bboxes_refines_concat_ = _points2bbox(
            mx.symbol,
            pts_refines_concat,
            transform,
            y_first=False,
            moment_transfer=self.moment_transfer)
        bboxes_refines_concat = X.reshape(bboxes_refines_concat_,
                                          (-4, batch_image, -1, -2))
        pts_refine_loss = X.smooth_l1(
            data=(bboxes_refines_concat - points_gts_refine) / normalize_term,
            scalar=3.0,
            name="pts_refine_l1_loss")
        pts_refine_loss = pts_refine_loss * points_weight_refine
        pts_refine_loss = X.bbox_norm(data=pts_refine_loss,
                                      label=points_labels_refine,
                                      name="pts_refine_norm_loss")
        pts_refine_loss = X.make_loss(data=pts_refine_loss,
                                      grad_scale=1.0,
                                      name="pts_refine_loss")
        points_refine_labels = X.block_grad(points_labels_refine,
                                            name="point_refine_labels")

        return cls_loss, pts_init_loss, pts_refine_loss, points_init_labels, points_refine_labels
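
The classification branch above uses X.focal_loss with the alpha/gamma weighting of the standard sigmoid focal loss. For reference only, here is a minimal NumPy sketch of that per-class formula; the helper below is hypothetical (it is not the X.focal_loss operator), and alpha=0.25 / gamma=2.0 are merely common defaults chosen for illustration.

import numpy as np

def sigmoid_focal_loss(p, is_fg, alpha=0.25, gamma=2.0):
    """Per-anchor, per-class sigmoid focal loss.
    p: predicted probability for this class; is_fg: 1 if the sample is a
    positive for this class, 0 otherwise."""
    pt = p if is_fg else 1.0 - p          # probability of the correct decision
    a = alpha if is_fg else 1.0 - alpha   # class-balance weight
    return -a * (1.0 - pt) ** gamma * np.log(pt)

# a hard positive contributes far more than an equally-confident easy negative
print(sigmoid_focal_loss(0.1, is_fg=1))   # ~0.466
print(sigmoid_focal_loss(0.1, is_fg=0))   # ~0.00079
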
Example #2
    def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
        import math

        import mxnet as mx
        import mxnext as X

        p = self.p
        stride = p.anchor_generate.stride
        if not isinstance(stride, tuple):
            stride = (stride, )  # a bare (stride) would not create a tuple
        num_class = p.num_class
        num_base_anchor = len(p.anchor_generate.ratio) * len(
            p.anchor_generate.scale)
        image_per_device = p.batch_image
        sync_loss = p.sync_loss or False

        cls_logit_dict, bbox_delta_dict = self.get_output(conv_feat)
        cls_logit_reshape_list = []
        bbox_delta_reshape_list = []

        scale_loss_shift = 128.0 if p.fp16 else 1.0
        if sync_loss:
            fg_count = X.var("rpn_fg_count") * image_per_device
            fg_count = mx.sym.slice_axis(fg_count, axis=0, begin=0, end=1)

        # reshape logit and delta
        for s in stride:
            # (N, A * C, H, W) -> (N, A, C, H * W)
            cls_logit = X.reshape(data=cls_logit_dict["stride%s" % s],
                                  shape=(0, num_base_anchor, num_class - 1,
                                         -1),
                                  name="cls_stride%s_reshape" % s)
            # (N, A, C, H * W) -> (N, A, H * W, C)
            cls_logit = X.transpose(data=cls_logit,
                                    axes=(0, 1, 3, 2),
                                    name="cls_stride%s_transpose" % s)
            # (N, A, H * W, C) -> (N, A * H * W, C)
            cls_logit = X.reshape(data=cls_logit,
                                  shape=(0, -3, 0),
                                  name="cls_stride%s_transpose_reshape" % s)

            # (N, A * 4, H, W) -> (N, A * 4, H * W)
            bbox_delta = X.reshape(data=bbox_delta_dict["stride%s" % s],
                                   shape=(0, 0, -1),
                                   name="bbox_stride%s_reshape" % s)

            cls_logit_reshape_list.append(cls_logit)
            bbox_delta_reshape_list.append(bbox_delta)

        cls_logit_concat = X.concat(cls_logit_reshape_list,
                                    axis=1,
                                    name="cls_logit_concat")
        bbox_delta_concat = X.concat(bbox_delta_reshape_list,
                                     axis=2,
                                     name="bbox_delta_concat")

        # classification loss
        if sync_loss:
            cls_loss = X.focal_loss(data=cls_logit_concat,
                                    label=cls_label,
                                    alpha=p.focal_loss.alpha,
                                    gamma=p.focal_loss.gamma,
                                    workspace=1500,
                                    out_grad=True)
            cls_loss = mx.sym.broadcast_div(cls_loss, fg_count)
            cls_loss = X.make_loss(cls_loss,
                                   grad_scale=scale_loss_shift,
                                   name="cls_loss")
        else:
            cls_loss = X.focal_loss(data=cls_logit_concat,
                                    label=cls_label,
                                    normalization='valid',
                                    alpha=p.focal_loss.alpha,
                                    gamma=p.focal_loss.gamma,
                                    grad_scale=1.0 * scale_loss_shift,
                                    workspace=1024,
                                    name="cls_loss")

        # regression loss: smooth L1 turns linear at |x| = scalar
        scalar = 0.11
        bbox_loss = bbox_weight * X.smooth_l1(
            data=bbox_delta_concat - bbox_target,
            scalar=math.sqrt(1 / scalar),
            name="bbox_loss")
        if sync_loss:
            bbox_loss = mx.sym.broadcast_div(bbox_loss, fg_count)
        else:
            bbox_loss = X.bbox_norm(data=bbox_loss,
                                    label=cls_label,
                                    name="bbox_norm")
        reg_loss = X.make_loss(data=bbox_loss,
                               grad_scale=1.0 * scale_loss_shift,
                               name="reg_loss")

        return cls_loss, reg_loss
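
Assuming X.smooth_l1 follows the sigma convention of MXNet's smooth_l1 operator, the scalar argument is the sigma of the piecewise loss, which is quadratic for |x| < 1/sigma^2 and linear beyond; scalar = math.sqrt(1 / 0.11) therefore puts the transition at |x| = 0.11. A small NumPy sketch of that piecewise definition (an illustrative helper written under this assumption, not the operator itself):

import numpy as np

def smooth_l1(x, sigma):
    """Smooth L1 with the scalar(=sigma) convention used above:
    0.5*(sigma*x)^2 for |x| < 1/sigma^2, |x| - 0.5/sigma^2 otherwise."""
    x = np.asarray(x, dtype=np.float64)
    quad = np.abs(x) < 1.0 / sigma ** 2
    return np.where(quad, 0.5 * (sigma * x) ** 2, np.abs(x) - 0.5 / sigma ** 2)

sigma = np.sqrt(1 / 0.11)
print(smooth_l1([0.05, 0.11, 1.0], sigma))   # ~[0.0114, 0.055, 0.945]
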
Example #3
    def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
        import math

        import mxnext as X

        p = self.p
        stride = p.anchor_generate.stride
        if not isinstance(stride, tuple):
            stride = (stride, )  # a bare (stride) would not create a tuple
        num_class = p.num_class
        num_base_anchor = len(p.anchor_generate.ratio) * len(
            p.anchor_generate.scale)

        cls_logit_list, bbox_delta_list = self.get_output(conv_feat)

        # reshape logit and delta
        for i, s in enumerate(stride):
            # (N, A * C, H, W) -> (N, A, C, H * W)
            cls_logit = X.reshape(data=cls_logit_list[i],
                                  shape=(0, num_base_anchor, num_class - 1,
                                         -1),
                                  name="cls_stride%s_reshape" % s)
            # (N, A, C, H * W) -> (N, A, H * W, C)
            cls_logit = X.transpose(data=cls_logit,
                                    axes=(0, 1, 3, 2),
                                    name="cls_stride%s_transpose" % s)
            # (N, A, H * W, C) -> (N, A * H * W, C)
            cls_logit = X.reshape(data=cls_logit,
                                  shape=(0, -3, 0),
                                  name="cls_stride%s_transpose_reshape" % s)

            # (N, A * 4, H, W) -> (N, A * 4, H * W)
            bbox_delta = X.reshape(data=bbox_delta_list[i],
                                   shape=(0, 0, -1),
                                   name="bbox_stride%s_reshape" % s)

            cls_logit_list[i] = cls_logit
            bbox_delta_list[i] = bbox_delta

        cls_logit_concat = X.concat(cls_logit_list,
                                    axis=1,
                                    name="cls_logit_concat")
        bbox_delta_concat = X.concat(bbox_delta_list,
                                     axis=2,
                                     name="bbox_delta_concat")

        # classification loss
        cls_loss = X.focal_loss(data=cls_logit_concat,
                                label=cls_label,
                                normalization='valid',
                                alpha=p.focal_loss.alpha,
                                gamma=p.focal_loss.gamma,
                                grad_scale=1.0,
                                workspace=1024,
                                name="cls_loss")

        # regression loss: smooth L1 turns linear at |x| = scalar
        scalar = 0.11
        bbox_norm = X.bbox_norm(data=bbox_delta_concat - bbox_target,
                                label=cls_label,
                                name="bbox_norm")
        bbox_loss = bbox_weight * X.smooth_l1(
            data=bbox_norm, scalar=math.sqrt(1 / scalar), name="bbox_loss")
        reg_loss = X.make_loss(data=bbox_loss, grad_scale=1.0, name="reg_loss")

        return cls_loss, reg_loss
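
The reshape/transpose chain applied to the classification logits in Examples #2 and #3 turns the raw (N, A * C, H, W) head output into (N, A * H * W, C) so that all strides can be concatenated along axis 1 and fed to the focal loss. A quick NDArray sanity check of that chain; the sizes (N=2, A=9 anchors, C=80 foreground classes, H=W=4) are toy values chosen purely for illustration.

import mxnet as mx

N, A, C, H, W = 2, 9, 80, 4, 4
x = mx.nd.zeros((N, A * C, H, W))             # raw classification head output
x = mx.nd.reshape(x, shape=(0, A, C, -1))     # (N, A, C, H*W)
x = mx.nd.transpose(x, axes=(0, 1, 3, 2))     # (N, A, H*W, C)
x = mx.nd.reshape(x, shape=(0, -3, 0))        # (N, A*H*W, C)
print(x.shape)                                # (2, 144, 80)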