Beispiel #1
0
    def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
        """Build the focal classification loss and smooth-L1 box regression loss.

        Per-stride outputs of ``self.get_output`` are reshaped, concatenated
        across strides and fed to ``X.focal_loss`` / ``X.smooth_l1``.  When
        ``p.sync_loss`` is truthy both losses are divided by a foreground count
        derived from the ``rpn_fg_count`` variable; otherwise normalisation is
        delegated to the focal-loss operator (``normalization='valid'``) and
        to ``X.bbox_norm``.

        Args:
            conv_feat: features forwarded unchanged to ``self.get_output``.
            cls_label: classification label symbol passed to the focal loss
                (and to ``X.bbox_norm`` when ``sync_loss`` is off).
            bbox_target: regression target subtracted from the concatenated
                box deltas.
            bbox_weight: weight symbol multiplied onto the smooth-L1 output.

        Returns:
            tuple: ``(cls_loss, reg_loss)`` symbols.
        """
        import mxnet as mx

        p = self.p
        stride = p.anchor_generate.stride
        # A scalar stride must be wrapped so the per-stride loop below can
        # iterate over it. ``(stride)`` without a trailing comma is just
        # ``stride`` again; tuples and lists are already iterable and are
        # left untouched.
        if not isinstance(stride, (tuple, list)):
            stride = (stride,)
        num_class = p.num_class
        num_base_anchor = len(p.anchor_generate.ratio) * len(
            p.anchor_generate.scale)
        image_per_device = p.batch_image
        sync_loss = p.sync_loss or False

        cls_logit_dict, bbox_delta_dict = self.get_output(conv_feat)
        cls_logit_reshape_list = []
        bbox_delta_reshape_list = []

        # Gradients are scaled by 128 when fp16 is enabled, by 1 otherwise.
        scale_loss_shift = 128.0 if p.fp16 else 1.0
        if sync_loss:
            # Foreground count variable times images per device; only the
            # first element along axis 0 is kept as the divisor.
            fg_count = X.var("rpn_fg_count") * image_per_device
            fg_count = mx.sym.slice_axis(fg_count, axis=0, begin=0, end=1)

        # reshape logit and delta
        for s in stride:
            # (N, A * C, H, W) -> (N, A, C, H * W)
            cls_logit = X.reshape(data=cls_logit_dict["stride%s" % s],
                                  shape=(0, num_base_anchor, num_class - 1,
                                         -1),
                                  name="cls_stride%s_reshape" % s)
            # (N, A, C, H * W) -> (N, A, H * W, C)
            cls_logit = X.transpose(data=cls_logit,
                                    axes=(0, 1, 3, 2),
                                    name="cls_stride%s_transpose" % s)
            # (N, A, H * W, C) -> (N, A * H * W, C)
            cls_logit = X.reshape(data=cls_logit,
                                  shape=(0, -3, 0),
                                  name="cls_stride%s_transpose_reshape" % s)

            # (N, A * 4, H, W) -> (N, A * 4, H * W)
            bbox_delta = X.reshape(data=bbox_delta_dict["stride%s" % s],
                                   shape=(0, 0, -1),
                                   name="bbox_stride%s_reshape" % s)

            cls_logit_reshape_list.append(cls_logit)
            bbox_delta_reshape_list.append(bbox_delta)

        # NOTE: the symbol is named "bbox_logit_concat" although it holds the
        # class logits; the name is kept so the symbol name does not change.
        cls_logit_concat = X.concat(cls_logit_reshape_list,
                                    axis=1,
                                    name="bbox_logit_concat")
        bbox_delta_concat = X.concat(bbox_delta_reshape_list,
                                     axis=2,
                                     name="bbox_delta_concat")

        # classification loss
        if sync_loss:
            # Normalise explicitly by the foreground count, then wrap the
            # result as a loss with the fp16 gradient scale.
            cls_loss = X.focal_loss(data=cls_logit_concat,
                                    label=cls_label,
                                    alpha=p.focal_loss.alpha,
                                    gamma=p.focal_loss.gamma,
                                    workspace=1500,
                                    out_grad=True)
            cls_loss = mx.sym.broadcast_div(cls_loss, fg_count)
            cls_loss = X.make_loss(cls_loss,
                                   grad_scale=scale_loss_shift,
                                   name="cls_loss")
        else:
            cls_loss = X.focal_loss(data=cls_logit_concat,
                                    label=cls_label,
                                    normalization='valid',
                                    alpha=p.focal_loss.alpha,
                                    gamma=p.focal_loss.gamma,
                                    grad_scale=1.0 * scale_loss_shift,
                                    workspace=1024,
                                    name="cls_loss")

        # smooth_l1 receives sqrt(1 / 0.11) as its ``scalar`` argument.
        scalar = 0.11
        # regression loss
        bbox_loss = bbox_weight * X.smooth_l1(
            data=bbox_delta_concat - bbox_target,
            scalar=math.sqrt(1 / scalar),
            name="bbox_loss")
        if sync_loss:
            bbox_loss = mx.sym.broadcast_div(bbox_loss, fg_count)
        else:
            bbox_loss = X.bbox_norm(data=bbox_loss,
                                    label=cls_label,
                                    name="bbox_norm")
        reg_loss = X.make_loss(data=bbox_loss,
                               grad_scale=1.0 * scale_loss_shift,
                               name="reg_loss")

        return cls_loss, reg_loss
Beispiel #2
0
    def get_loss(self, conv_feat, gt_bbox, im_info):
        """Build the FreeAnchor positive and negative loss symbols.

        Per-stride class logits and box deltas from ``self.get_output`` are
        flattened, concatenated over strides and reshaped so the last axis is
        ``num_class - 1`` (logits) or ``4`` (deltas).  The loss values come
        from ``_positive_loss`` / ``_negative_loss`` in
        ``models.FreeAnchor.ops`` and are wrapped with ``X.make_loss``.

        Args:
            conv_feat: features forwarded unchanged to ``self.get_output``.
            gt_bbox: ground-truth box symbol handed to both loss helpers.
            im_info: image info symbol handed to ``_negative_loss``.

        Returns:
            tuple: ``(positive_loss, negative_loss)`` symbols.
        """
        import mxnet as mx

        p = self.p
        stride = p.anchor_generate.stride
        # A scalar stride must be wrapped so the loops below can iterate over
        # it. ``(stride)`` without a trailing comma is just ``stride`` again;
        # tuples and lists are already iterable and are left untouched.
        if not isinstance(stride, (tuple, list)):
            stride = (stride,)
        num_class = p.num_class
        num_base_anchor = len(p.anchor_generate.ratio) * len(p.anchor_generate.scale)
        image_per_device = p.batch_image

        cls_logit_dict, bbox_delta_dict = self.get_output(conv_feat)
        cls_logit_reshape_list = []
        bbox_delta_reshape_list = []
        # Un-reshaped class-logit maps, passed to _prepare_anchors below.
        feat_list = []

        # Gradients are scaled by 128 when fp16 is enabled, by 1 otherwise.
        scale_loss_shift = 128.0 if p.fp16 else 1.0

        # reshape logit and delta
        for s in stride:
            # (N, A * C, H, W) -> (N, A * C, H * W)
            cls_logit = X.reshape(
                data=cls_logit_dict["stride%s" % s],
                shape=(0, 0, -1),
                name="cls_stride%s_reshape" % s
            )

            # (N, A * 4, H, W) -> (N, A * 4, H * W)
            bbox_delta = X.reshape(
                data=bbox_delta_dict["stride%s" % s],
                shape=(0, 0, -1),
                name="bbox_stride%s_reshape" % s
            )

            cls_logit_reshape_list.append(cls_logit)
            bbox_delta_reshape_list.append(bbox_delta)
            feat_list.append(cls_logit_dict["stride%s" % s])

        # cls_logits -> (N, H' * W' * A, C)
        cls_logits = X.concat(cls_logit_reshape_list, axis=2, name="cls_logit_concat")
        cls_logits = X.transpose(cls_logits, axes=(0, 2, 1), name="cls_logit_transpose")
        cls_logits = X.reshape(cls_logits, shape=(0, -1, num_class - 1), name="cls_logit_reshape")
        cls_prob = X.sigmoid(cls_logits)
        # bbox_deltas -> (N, H' * W' * A, 4)
        bbox_deltas = X.concat(bbox_delta_reshape_list, axis=2, name="bbox_delta_concat")
        bbox_deltas = X.transpose(bbox_deltas, axes=(0, 2, 1), name="bbox_delta_transpose")
        bbox_deltas = X.reshape(bbox_deltas, shape=(0, -1, 4), name="bbox_delta_reshape")

        anchor_list = [self.anchor_dict["stride%s" % s] for s in stride]
        bbox_thr = p.anchor_assign.bbox_thr
        pre_anchor_top_n = p.anchor_assign.pre_anchor_top_n
        alpha = p.focal_loss.alpha
        gamma = p.focal_loss.gamma
        # Fall back to zero mean / unit std when the config leaves them falsy.
        anchor_target_mean = p.head.mean or (0, 0, 0, 0)
        anchor_target_std = p.head.std or (1, 1, 1, 1)

        from models.FreeAnchor.ops import _prepare_anchors, _positive_loss, _negative_loss
        anchors = _prepare_anchors(
            mx.sym, feat_list, anchor_list, image_per_device, num_base_anchor)

        positive_loss = _positive_loss(
            mx.sym, anchors, gt_bbox, cls_prob, bbox_deltas, image_per_device,
            alpha, pre_anchor_top_n, anchor_target_mean, anchor_target_std
        )
        positive_loss = X.make_loss(
            data=positive_loss,
            grad_scale=1.0 * scale_loss_shift,
            name="positive_loss"
        )

        negative_loss = _negative_loss(
            mx.sym, anchors, gt_bbox, cls_prob, bbox_deltas, im_info, image_per_device,
            num_class, alpha, gamma, pre_anchor_top_n, bbox_thr,
            anchor_target_mean, anchor_target_std
        )
        negative_loss = X.make_loss(
            data=negative_loss,
            grad_scale=1.0 * scale_loss_shift,
            name="negative_loss"
        )

        return positive_loss, negative_loss
Beispiel #3
0
    def get_loss(self, conv_feat, gt_bbox):
        """Assemble the RepPoints training symbols.

        Builds the focal classification loss and the smooth-L1 losses of the
        init and refine point stages from the per-stride outputs of
        ``self.get_output``.

        Args:
            conv_feat: features forwarded unchanged to ``self.get_output``.
            gt_bbox: ground-truth box symbol handed to ``_point_target``.

        Returns:
            tuple: ``(cls_loss, pts_init_loss, pts_refine_loss,
            points_init_labels, points_refine_labels)``.
        """
        from models.RepPoints.point_ops import (_gen_points, _offset_to_pts,
                                                _point_target,
                                                _offset_to_boxes, _points2bbox)
        cfg = self.p
        batch_image = cfg.batch_image
        num_points = cfg.point_generate.num_points
        scale = cfg.point_generate.scale
        stride = cfg.point_generate.stride
        transform = cfg.point_generate.transform
        target_scale = cfg.point_target.target_scale
        num_pos = cfg.point_target.num_pos
        pos_iou_thr = cfg.bbox_target.pos_iou_thr
        neg_iou_thr = cfg.bbox_target.neg_iou_thr
        min_pos_iou = cfg.bbox_target.min_pos_iou

        pts_out_inits, pts_out_refines, cls_outs = self.get_output(conv_feat)

        # Dictionary keys used for every per-stride lookup below.
        stride_keys = ["stride%s" % s for s in stride]

        points = {}
        bboxes = {}
        init_pts_preds = {}
        refine_pts_preds = {}
        for s, key in zip(stride, stride_keys):
            init_offset = pts_out_inits[key]
            # base points for this stride, derived from the init output
            base_points = _gen_points(mx.symbol, init_offset, s)
            points[key] = base_points
            # boxes after the init stage; gradient is blocked on the offsets
            bboxes[key] = _offset_to_boxes(
                mx.symbol,
                base_points,
                X.block_grad(init_offset),
                s,
                transform,
                moment_transfer=self.moment_transfer)
            # final point predictions of the init stage
            init_pts_preds[key] = _offset_to_pts(
                mx.symbol, base_points, init_offset, s, num_points)
            # final point predictions of the refine stage
            refine_pts_preds[key] = _offset_to_pts(
                mx.symbol, base_points, pts_out_refines[key], s, num_points)

        # init stage: targets come from point assignment
        all_points = X.concat([points[key] for key in stride_keys],
                              axis=1,
                              name="point_concat")
        point_proposals = mx.symbol.tile(all_points,
                                         reps=(batch_image, 1, 1))
        init_targets = _point_target(
            mx.symbol,
            point_proposals,
            gt_bbox,
            batch_image,
            "point",
            scale=target_scale,
            num_pos=num_pos)
        points_labels_init, points_gts_init, points_weight_init = init_targets

        # refine stage: targets come from max-IoU assignment
        box_proposals = X.concat([bboxes[key] for key in stride_keys],
                                 axis=1,
                                 name="box_concat")
        refine_targets = _point_target(
            mx.symbol,
            box_proposals,
            gt_bbox,
            batch_image,
            "box",
            pos_iou_thr=pos_iou_thr,
            neg_iou_thr=neg_iou_thr,
            min_pos_iou=min_pos_iou)
        points_labels_refine, points_gts_refine, points_weight_refine = refine_targets

        bboxes_out_strides = {}
        for s, key in zip(stride, stride_keys):
            # Move axis 1 to the end, then merge axes 1 and 2. The entry in
            # ``cls_outs`` is overwritten with the flattened symbol.
            cls_transposed = X.transpose(data=cls_outs[key], axes=(0, 2, 3, 1))
            cls_outs[key] = X.reshape(cls_transposed, (0, -3, -2))
            # Symbol of ones with a last axis of length 4, scaled by the stride.
            first_slot = mx.symbol.slice_axis(
                cls_outs[key], begin=0, end=1, axis=-1)
            ones = mx.symbol.ones_like(first_slot)
            bboxes_out_strides[key] = mx.symbol.repeat(
                ones, repeats=4, axis=-1) * s

        # cls branch
        cls_outs_concat = X.concat([cls_outs[key] for key in stride_keys],
                                   axis=1,
                                   name="cls_concat")
        cls_loss = X.focal_loss(data=cls_outs_concat,
                                label=points_labels_refine,
                                normalization='valid',
                                alpha=cfg.focal_loss.alpha,
                                gamma=cfg.focal_loss.gamma,
                                grad_scale=1.0,
                                workspace=1500,
                                name="cls_loss")

        # init box branch
        init_pts_concat = X.concat(
            [init_pts_preds[key] for key in stride_keys],
            axis=1,
            name="pts_init_concat_")
        # (-3, -2): merge the first two axes, keep the remaining ones
        init_pts_flat = X.reshape(init_pts_concat, (-3, -2),
                                  name="pts_inits_concat")
        init_boxes_flat = _points2bbox(
            mx.symbol,
            init_pts_flat,
            transform,
            y_first=False,
            moment_transfer=self.moment_transfer)
        # (-4, batch_image, -1, -2): split the first axis back per image
        init_boxes = X.reshape(init_boxes_flat,
                               (-4, batch_image, -1, -2))
        normalize_term = X.concat(
            [bboxes_out_strides[key] for key in stride_keys],
            axis=1,
            name="normalize_term") * scale
        init_residual = (init_boxes - points_gts_init) / normalize_term
        pts_init_loss = X.smooth_l1(data=init_residual,
                                    scalar=3.0,
                                    name="pts_init_l1_loss")
        pts_init_loss = pts_init_loss * points_weight_init
        pts_init_loss = X.bbox_norm(data=pts_init_loss,
                                    label=points_labels_init,
                                    name="pts_init_norm_loss")
        pts_init_loss = X.make_loss(data=pts_init_loss,
                                    grad_scale=0.5,
                                    name="pts_init_loss")
        # NOTE(review): this output is built from the refine-stage labels,
        # not from points_labels_init -- confirm that is intended.
        points_init_labels = X.block_grad(points_labels_refine,
                                          name="points_init_labels")

        # refine box branch
        refine_pts_concat = X.concat(
            [refine_pts_preds[key] for key in stride_keys],
            axis=1,
            name="pts_refines_concat_")
        refine_pts_flat = X.reshape(refine_pts_concat, (-3, -2),
                                    name="pts_refines_concat")
        refine_boxes_flat = _points2bbox(
            mx.symbol,
            refine_pts_flat,
            transform,
            y_first=False,
            moment_transfer=self.moment_transfer)
        refine_boxes = X.reshape(refine_boxes_flat,
                                 (-4, batch_image, -1, -2))
        refine_residual = (refine_boxes - points_gts_refine) / normalize_term
        pts_refine_loss = X.smooth_l1(data=refine_residual,
                                      scalar=3.0,
                                      name="pts_refine_l1_loss")
        pts_refine_loss = pts_refine_loss * points_weight_refine
        pts_refine_loss = X.bbox_norm(data=pts_refine_loss,
                                      label=points_labels_refine,
                                      name="pts_refine_norm_loss")
        pts_refine_loss = X.make_loss(data=pts_refine_loss,
                                      grad_scale=1.0,
                                      name="pts_refine_loss")
        points_refine_labels = X.block_grad(points_labels_refine,
                                            name="point_refine_labels")

        return (cls_loss, pts_init_loss, pts_refine_loss,
                points_init_labels, points_refine_labels)
Beispiel #4
0
    def get_prediction(self, conv_feat, im_info):
        """Build the inference symbols: class scores and clipped boxes.

        For every stride the refine-stage outputs are converted to boxes, the
        highest-scoring entries are selected per image with ``topk``, and the
        boxes are clipped against the first two ``im_info`` fields (treated as
        height and width) minus one.  Per-stride results are concatenated
        along axis 1.

        Args:
            conv_feat: features forwarded unchanged to ``self.get_output``.
            im_info: image info symbol; sliced per image along axis 0 and
                split into three fields along axis 1.

        Returns:
            tuple: ``(cls_score, bbox_xyxy)`` symbols.
        """
        from models.RepPoints.point_ops import _gen_points, _points2bbox
        cfg = self.p
        batch_image = cfg.batch_image
        stride = cfg.point_generate.stride
        transform = cfg.point_generate.transform
        pre_nms_top_n = cfg.proposal.pre_nms_top_n

        pts_out_inits, pts_out_refines, cls_outs = self.get_output(conv_feat)

        cls_score_dict = {}
        bbox_xyxy_dict = {}
        for s in stride:
            key = "stride%s" % s
            # NOTE: for strides above 32 the top-k size is hard-coded to -1
            # because those low-resolution feature maps yield fewer proposals
            # than pre_nms_top_n. Choose appropriate params here when feeding
            # low-resolution input images.
            if s <= 32:
                top_k = pre_nms_top_n
            else:
                top_k = -1

            base_points = _gen_points(mx.symbol, pts_out_inits[key], s)
            refine_boxes = _points2bbox(mx.symbol,
                                        pts_out_refines[key],
                                        transform,
                                        moment_transfer=self.moment_transfer)
            # Move axis 1 to the end, then merge axes 1 and 2.
            refine_boxes = X.transpose(data=refine_boxes, axes=(0, 2, 3, 1))
            refine_boxes = X.reshape(refine_boxes, (0, -3, -2))
            cls_flat = X.transpose(data=cls_outs[key], axes=(0, 2, 3, 1))
            cls_flat = X.reshape(cls_flat, (0, -3, -2))
            scores = X.sigmoid(cls_flat)
            # Rank positions by their best score over the last axis.
            best_scores = mx.symbol.max(scores, axis=-1)
            top_index = mx.symbol.topk(best_scores, axis=1, k=top_k)

            image_scores = []
            image_boxes = []
            for i in range(batch_image):
                index_i = X.reshape(
                    mx.symbol.slice_axis(top_index, axis=0, begin=i,
                                         end=i + 1),
                    (-1, ))
                scores_i = X.reshape(
                    mx.symbol.slice_axis(scores, axis=0, begin=i, end=i + 1),
                    (-3, -2))
                points_i = X.reshape(base_points, (-3, -2))
                boxes_i = X.reshape(
                    mx.symbol.slice_axis(refine_boxes, axis=0, begin=i,
                                         end=i + 1),
                    (-3, -2))

                # keep only the selected positions
                scores_i = mx.symbol.take(scores_i, index_i)
                points_i = mx.symbol.take(points_i, index_i)
                boxes_i = mx.symbol.take(boxes_i, index_i)

                # first two point fields, duplicated to line up with 4 coords
                points_i = mx.symbol.slice_axis(points_i, axis=-1, begin=0,
                                                end=2)
                points_xyxy_i = X.concat(
                    [points_i, points_i],
                    axis=-1,
                    name="points_xyxy_b{}_s{}".format(i, s))
                boxes_i = boxes_i * s + points_xyxy_i

                im_info_i = mx.symbol.slice_axis(im_info, axis=0, begin=i,
                                                 end=i + 1)
                h_i, w_i, _ = mx.symbol.split(im_info_i, num_outputs=3, axis=1)
                l_i, t_i, r_i, b_i = mx.symbol.split(boxes_i,
                                                     num_outputs=4,
                                                     axis=1)
                # clamp each coordinate to [0, bound - 1]
                clipped = [
                    mx.symbol.maximum(
                        mx.symbol.broadcast_minimum(coord, bound - 1.0), 0.0)
                    for coord, bound in ((l_i, w_i), (t_i, h_i),
                                         (r_i, w_i), (b_i, h_i))
                ]
                boxes_i = X.concat(clipped,
                                   axis=1,
                                   name="clip_bboxes_b{}_s{}".format(i, s))

                image_scores.append(scores_i)
                image_boxes.append(boxes_i)

            stride_scores = mx.symbol.stack(*image_scores, axis=0)
            # Prepend one all-zero slot along the last axis
            # (presumably the background class -- confirm with consumers).
            zero_slot = mx.symbol.zeros_like(
                mx.symbol.slice_axis(stride_scores, axis=-1, begin=0, end=1))
            stride_scores = X.concat([zero_slot, stride_scores],
                                     axis=-1,
                                     name="cls_score_s{}".format(s))
            stride_boxes = mx.symbol.stack(*image_boxes, axis=0)

            cls_score_dict[key] = stride_scores
            bbox_xyxy_dict[key] = stride_boxes

        cls_score = X.concat([cls_score_dict["stride%s" % s] for s in stride],
                             axis=1,
                             name="cls_score_concat")
        bbox_xyxy = X.concat([bbox_xyxy_dict["stride%s" % s] for s in stride],
                             axis=1,
                             name="bbox_xyxy_concat")

        return cls_score, bbox_xyxy
Beispiel #5
0
    def get_loss(self, conv_feat, cls_label, bbox_target, bbox_weight):
        """Build the focal classification loss and smooth-L1 box regression loss.

        ``self.get_output`` is expected to return two sequences indexed in
        the same order as the strides.  Each entry is reshaped, the results
        are concatenated across strides, and the losses are built with
        ``X.focal_loss``, ``X.bbox_norm``, ``X.smooth_l1`` and ``X.make_loss``.

        Args:
            conv_feat: features forwarded unchanged to ``self.get_output``.
            cls_label: classification label symbol passed to the focal loss
                and to ``X.bbox_norm``.
            bbox_target: regression target subtracted from the concatenated
                box deltas.
            bbox_weight: weight symbol multiplied onto the smooth-L1 output.

        Returns:
            tuple: ``(cls_loss, reg_loss)`` symbols.
        """
        p = self.p
        stride = p.anchor_generate.stride
        # A scalar stride must be wrapped so the per-stride loop below can
        # iterate over it. ``(stride)`` without a trailing comma is just
        # ``stride`` again; tuples and lists are already iterable and are
        # left untouched.
        if not isinstance(stride, (tuple, list)):
            stride = (stride,)
        num_class = p.num_class
        num_base_anchor = len(p.anchor_generate.ratio) * len(
            p.anchor_generate.scale)

        cls_logit_list, bbox_delta_list = self.get_output(conv_feat)

        # Collect reshaped symbols in fresh lists so the sequences returned by
        # ``get_output`` are not overwritten in place.
        cls_logit_reshape_list = []
        bbox_delta_reshape_list = []

        # reshape logit and delta
        for i, s in enumerate(stride):
            # (N, A * C, H, W) -> (N, A, C, H * W)
            cls_logit = X.reshape(data=cls_logit_list[i],
                                  shape=(0, num_base_anchor, num_class - 1,
                                         -1),
                                  name="cls_stride%s_reshape" % s)
            # (N, A, C, H * W) -> (N, A, H * W, C)
            cls_logit = X.transpose(data=cls_logit,
                                    axes=(0, 1, 3, 2),
                                    name="cls_stride%s_transpose" % s)
            # (N, A, H * W, C) -> (N, A * H * W, C)
            cls_logit = X.reshape(data=cls_logit,
                                  shape=(0, -3, 0),
                                  name="cls_stride%s_transpose_reshape" % s)

            # (N, A * 4, H, W) -> (N, A * 4, H * W)
            bbox_delta = X.reshape(data=bbox_delta_list[i],
                                   shape=(0, 0, -1),
                                   name="bbox_stride%s_reshape" % s)

            cls_logit_reshape_list.append(cls_logit)
            bbox_delta_reshape_list.append(bbox_delta)

        # NOTE: the symbol is named "bbox_logit_concat" although it holds the
        # class logits; the name is kept so the symbol name does not change.
        cls_logit_concat = X.concat(cls_logit_reshape_list,
                                    axis=1,
                                    name="bbox_logit_concat")
        bbox_delta_concat = X.concat(bbox_delta_reshape_list,
                                     axis=2,
                                     name="bbox_delta_concat")

        # classification loss
        cls_loss = X.focal_loss(data=cls_logit_concat,
                                label=cls_label,
                                normalization='valid',
                                alpha=p.focal_loss.alpha,
                                gamma=p.focal_loss.gamma,
                                grad_scale=1.0,
                                workspace=1024,
                                name="cls_loss")

        # smooth_l1 receives sqrt(1 / 0.11) as its ``scalar`` argument.
        scalar = 0.11
        # regression loss: the residual goes through bbox_norm first, then
        # smooth-L1, and is weighted afterwards
        bbox_norm = X.bbox_norm(data=bbox_delta_concat - bbox_target,
                                label=cls_label,
                                name="bbox_norm")
        bbox_loss = bbox_weight * X.smooth_l1(
            data=bbox_norm, scalar=math.sqrt(1 / scalar), name="bbox_loss")
        reg_loss = X.make_loss(data=bbox_loss, grad_scale=1.0, name="reg_loss")

        return cls_loss, reg_loss