Beispiel #1
0
def offset_target_single(flat_anchors,
                         valid_flags,
                         bbox_pred,
                         gt_bboxes,
                         gt_cheby,
                         gt_skeleton,
                         gt_bboxes_ignore,
                         gt_labels,
                         img_meta,
                         target_means,
                         target_stds,
                         num_coords,
                         cfg,
                         label_channels=1,
                         sampling=True,
                         unmap_outputs=True):
    """Compute classification and offset-regression targets for one image.

    Anchors flagged as outside the image by ``anchor_inside_flags`` are
    dropped. The remaining anchors are assigned to ground truth and sampled.
    Positive samples are encoded with ``bbox2offset``.

    Args:
        flat_anchors: 2-D anchor tensor for this image (one row per anchor).
        valid_flags: Per-anchor validity flags matching ``flat_anchors``.
        bbox_pred: Unused here; kept for interface compatibility.
        gt_bboxes: Ground-truth boxes, forwarded to the assigner/sampler.
        gt_cheby: Extra ground-truth annotation, forwarded to
            ``assign_and_sample``.
        gt_skeleton: Extra ground-truth annotation, forwarded to
            ``assign_and_sample``.
        gt_bboxes_ignore: Ground-truth boxes to ignore during assignment.
        gt_labels: Ground-truth labels, or ``None``. With ``None``, every
            positive anchor gets label 1.
        img_meta: Image meta dict; only ``img_meta['img_shape'][:2]`` is read.
        target_means: Normalization means passed to ``bbox2offset``.
        target_stds: Normalization stds passed to ``bbox2offset``.
        num_coords: Number of regression target columns per anchor.
        cfg: Training config. Reads ``allowed_border``, ``assigner``,
            ``use_centerness`` and ``pos_weight``.
        label_channels: Unused here; kept for interface compatibility.
        sampling: If True, use ``assign_and_sample``. Otherwise use the
            configured assigner together with a ``PseudoSampler``.
        unmap_outputs: If True, scatter per-anchor outputs back to the full
            ``flat_anchors`` set. Otherwise they stay in inside-anchor order.

    Returns:
        tuple: ``(labels, label_weights, bbox_targets, bbox_weights,
        pos_inds, neg_inds)``. If no anchor lies inside the image, a 6-tuple
        of ``None`` is returned instead.
    """
    inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                       img_meta['img_shape'][:2],
                                       cfg.allowed_border)
    if not inside_flags.any():
        return (None, ) * 6
    # assign gt and sample anchors
    anchors = flat_anchors[inside_flags, :]

    if sampling:
        assign_result, sampling_result = assign_and_sample(
            anchors, gt_bboxes, gt_cheby, gt_skeleton, gt_bboxes_ignore, None,
            cfg)
    else:
        bbox_assigner = build_assigner(cfg.assigner)
        assign_result = bbox_assigner.assign(anchors, gt_bboxes,
                                             gt_bboxes_ignore, gt_labels)
        bbox_sampler = PseudoSampler()
        sampling_result = bbox_sampler.sample(assign_result, anchors,
                                              gt_bboxes)

    num_valid_anchors = anchors.shape[0]

    labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long)
    label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
    # Allocate on the anchors' device rather than a hard-coded .cuda(), so
    # all outputs share one device and CPU execution works as well.
    bbox_targets = anchors.new_zeros((num_valid_anchors, num_coords),
                                     dtype=torch.float)
    bbox_weights = anchors.new_zeros((num_valid_anchors, num_coords),
                                     dtype=torch.float)

    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds
    if len(pos_inds) > 0:
        deltas, weights = bbox2offset(sampling_result.pos_bboxes,
                                      sampling_result.pos_gt_bboxes,
                                      target_means, target_stds)
        bbox_targets[pos_inds, :] = deltas
        # When use_centerness is set, each positive's weight from
        # bbox2offset is broadcast across all regression columns.
        bbox_weights[pos_inds, :] = weights.unsqueeze(
            1) if cfg.use_centerness else 1.0
        if gt_labels is None:
            labels[pos_inds] = 1
        else:
            labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        if cfg.pos_weight <= 0:
            label_weights[pos_inds] = weights if cfg.use_centerness else 1.0
        else:
            label_weights[pos_inds] = cfg.pos_weight
        # Experimental adaptive weighting; hard-disabled.
        use_adaptive_weights = False
        if use_adaptive_weights:
            # Down-weight positives whose last two encoded target components
            # have a large norm. `deltas` holds the encoded positive targets.
            pos_weights = torch.norm(deltas[:, -2:], dim=1)
            pos_weights = 1.0 / (0.5 + pos_weights)
            bbox_weights[pos_inds, :] = pos_weights.reshape(-1, 1)
            label_weights[pos_inds] = pos_weights
    if len(neg_inds) > 0:
        label_weights[neg_inds] = 1.0

    # map up to original set of anchors
    if unmap_outputs:
        num_total_anchors = flat_anchors.size(0)
        labels = unmap(labels, num_total_anchors, inside_flags)
        label_weights = unmap(label_weights, num_total_anchors, inside_flags)
        bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
        bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
    # Return in both cases. With unmap_outputs=False the targets are still
    # valid, just indexed by inside anchors only.
    return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
            neg_inds)
Beispiel #2
0
def radius_target_single(flat_anchors,
                         valid_flags,
                         bbox_pred,
                         gt_bboxes,
                         gt_cheby,
                         gt_skeleton,
                         gt_bboxes_ignore,
                         gt_labels,
                         img_meta,
                         target_means,
                         target_stds,
                         num_coords,
                         cfg,
                         label_channels=1,
                         sampling=True,
                         unmap_outputs=True):
    """Compute classification, radius and center targets for one image.

    Anchors flagged as outside the image by ``anchor_inside_flags`` are
    dropped. The remaining anchors are assigned to ground truth and sampled.
    Positive samples are encoded with ``bbox2radius``.

    The encoded target is split column-wise:

    * the first ``num_coords - 3`` columns become ``bbox_targets``;
    * the last 3 columns become ``ctr_targets``.

    Args:
        flat_anchors: 2-D anchor tensor for this image (one row per anchor).
        valid_flags: Per-anchor validity flags matching ``flat_anchors``.
        bbox_pred: Unused here; kept for interface compatibility.
        gt_bboxes: Ground-truth boxes, forwarded to the assigner/sampler.
        gt_cheby: Extra ground-truth annotation, forwarded to
            ``assign_and_sample``.
        gt_skeleton: Extra ground-truth annotation, forwarded to
            ``assign_and_sample``.
        gt_bboxes_ignore: Ground-truth boxes to ignore during assignment.
        gt_labels: Ground-truth labels, or ``None``. With ``None``, every
            positive anchor gets label 1.
        img_meta: Image meta dict; only ``img_meta['img_shape'][:2]`` is read.
        target_means: Normalization means passed to ``bbox2radius``.
        target_stds: Normalization stds passed to ``bbox2radius``.
        num_coords: Total number of encoded target columns, including the
            trailing 3 center columns.
        cfg: Training config. Reads ``allowed_border``, ``assigner``,
            ``use_centerness`` and ``pos_weight``.
        label_channels: Unused here; kept for interface compatibility.
        sampling: If True, use ``assign_and_sample``. Otherwise use the
            configured assigner together with a ``PseudoSampler``.
        unmap_outputs: If True, scatter per-anchor outputs back to the full
            ``flat_anchors`` set.

    Returns:
        tuple: ``(labels, label_weights, bbox_targets, bbox_weights,
        ctr_targets, ctr_weights, pos_inds, neg_inds)``. If no anchor lies
        inside the image, a 6-tuple of ``None`` is returned instead.
    """
    inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                       img_meta['img_shape'][:2],
                                       cfg.allowed_border)
    if not inside_flags.any():
        # NOTE(review): the normal return below has 8 elements; confirm
        # callers handle this shorter all-None tuple.
        return (None, ) * 6
    # assign gt and sample anchors
    anchors = flat_anchors[inside_flags, :]
    if sampling:
        assign_result, sampling_result = assign_and_sample(
            anchors, gt_bboxes, gt_cheby, gt_skeleton, gt_bboxes_ignore, None,
            cfg)
    else:
        bbox_assigner = build_assigner(cfg.assigner)
        assign_result = bbox_assigner.assign(anchors, gt_bboxes,
                                             gt_bboxes_ignore, gt_labels)
        bbox_sampler = PseudoSampler()
        sampling_result = bbox_sampler.sample(assign_result, anchors,
                                              gt_bboxes)

    num_valid_anchors = anchors.shape[0]

    labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long)
    label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
    # Allocate on the anchors' device rather than a hard-coded .cuda(), so
    # all outputs share one device and CPU execution works as well.
    bbox_targets = anchors.new_zeros((num_valid_anchors, num_coords - 3),
                                     dtype=torch.float)
    bbox_weights = anchors.new_zeros((num_valid_anchors, num_coords - 3),
                                     dtype=torch.float)
    ctr_targets = anchors.new_zeros((num_valid_anchors, 3), dtype=torch.float)
    ctr_weights = anchors.new_zeros((num_valid_anchors, 3), dtype=torch.float)

    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds
    if len(pos_inds) > 0:
        deltas, weights = bbox2radius(sampling_result.pos_bboxes,
                                      sampling_result.pos_gt_bboxes,
                                      sampling_result.pos_gt_skeleton,
                                      num_coords, target_means, target_stds)

        # Presumably deltas has num_coords columns (TODO confirm against
        # bbox2radius): radius part first, 3 center values last.
        bbox_targets[pos_inds, :] = deltas[:, :-3]
        bbox_weights[pos_inds, :] = weights.unsqueeze(
            1) if cfg.use_centerness else 1.0
        ctr_targets[pos_inds, :] = deltas[:, -3:]
        ctr_weights[pos_inds, :] = weights.unsqueeze(
            1) if cfg.use_centerness else 1.0
        if gt_labels is None:
            labels[pos_inds] = 1
        else:
            labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        if cfg.pos_weight <= 0:
            label_weights[pos_inds] = weights if cfg.use_centerness else 1.0
        else:
            label_weights[pos_inds] = cfg.pos_weight
    if len(neg_inds) > 0:
        label_weights[neg_inds] = 1.0

    # map up to original set of anchors
    if unmap_outputs:
        num_total_anchors = flat_anchors.size(0)
        labels = unmap(labels, num_total_anchors, inside_flags)
        label_weights = unmap(label_weights, num_total_anchors, inside_flags)
        bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
        bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
        ctr_targets = unmap(ctr_targets, num_total_anchors, inside_flags)
        ctr_weights = unmap(ctr_weights, num_total_anchors, inside_flags)
    return (labels, label_weights, bbox_targets, bbox_weights, ctr_targets,
            ctr_weights, pos_inds, neg_inds)
Beispiel #3
0
    def _get_target_single(self,
                           flat_anchors,
                           valid_flags,
                           num_level_anchors,
                           gt_bboxes,
                           gt_bboxes_ignore,
                           gt_labels,
                           img_meta,
                           label_channels=1,
                           sampling=True,
                           unmap_outputs=True):
        """Compute classification and box-regression targets for one image.

        Anchors flagged as outside the image by ``anchor_inside_flags`` are
        dropped. The remaining anchors are assigned with ``self.assigner``
        (which also receives the per-level counts of inside anchors) and
        sampled with ``self.sampler``. Positives are encoded with
        ``bbox2delta``.

        Args:
            flat_anchors: 2-D anchor tensor for this image (one row per
                anchor).
            valid_flags: Per-anchor validity flags matching ``flat_anchors``.
            num_level_anchors: Per-level anchor counts. Converted with
                ``inside_flags`` into per-level counts of inside anchors.
            gt_bboxes: Ground-truth boxes.
            gt_bboxes_ignore: Ground-truth boxes to ignore during assignment.
            gt_labels: Ground-truth labels, or ``None``. With ``None``, every
                positive anchor gets label 1.
            img_meta: Image meta dict; only ``img_meta['img_shape'][:2]`` is
                read.
            label_channels: Unused here; kept for interface compatibility.
            sampling: Unused here. ``self.sampler`` is always used.
            unmap_outputs: If True, scatter outputs (including ``anchors``)
                back to the full ``flat_anchors`` set.

        Returns:
            tuple: ``(anchors, labels, label_weights, bbox_targets,
            bbox_weights, pos_inds, neg_inds)``. If no anchor lies inside the
            image, a 7-tuple of ``None`` of the same arity is returned.
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg.allowed_border)
        if not inside_flags.any():
            # Must match the 7-element normal return so callers can unpack.
            return (None, ) * 7
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]

        num_level_anchors_inside = self.get_num_level_anchors_inside(
            num_level_anchors, inside_flags)
        assign_result = self.assigner.assign(anchors, num_level_anchors_inside,
                                             gt_bboxes, gt_bboxes_ignore,
                                             gt_labels)
        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.background_label,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes,
                                          sampling_result.pos_gt_bboxes,
                                          self.target_means, self.target_stds)
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                labels[pos_inds] = 1
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            anchors = unmap(anchors, num_total_anchors, inside_flags)
            # NOTE(review): inside-image negatives were labelled with
            # self.background_label, but outside anchors are filled with
            # self.num_classes. The two agree only if background_label ==
            # num_classes; confirm that holds for this head.
            labels = unmap(labels,
                           num_total_anchors,
                           inside_flags,
                           fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
                pos_inds, neg_inds)