コード例 #1
0
ファイル: rbbox_head.py プロジェクト: ming71/RIDet
 def get_target(self, sampling_results, gt_bboxes, gt_labels,
                rcnn_train_cfg):
     """Assemble the training targets for the sampled proposals.

     Gathers the positive proposals, the negative proposals and the ground
     truth boxes/labels matched to the positives from every entry of
     ``sampling_results`` and forwards them to ``bbox_target``.

     Args:
         sampling_results: Iterable of sampling results, each exposing
             ``pos_bboxes``, ``neg_bboxes``, ``pos_gt_bboxes`` and
             ``pos_gt_labels`` (presumably one entry per image -- confirm
             against the caller).
         gt_bboxes: Not read by this method.
         gt_labels: Not read by this method.
         rcnn_train_cfg: Training config passed through unchanged to
             ``bbox_target``.

     Returns:
         The value returned by ``bbox_target``.
     """
     positives = [item.pos_bboxes for item in sampling_results]
     negatives = [item.neg_bboxes for item in sampling_results]
     matched_boxes = [item.pos_gt_bboxes for item in sampling_results]
     matched_labels = [item.pos_gt_labels for item in sampling_results]

     # Class-agnostic regression uses a single set of box deltas.
     if self.reg_class_agnostic:
         num_reg_classes = 1
     else:
         num_reg_classes = self.num_classes

     return bbox_target(
         positives,
         negatives,
         matched_boxes,
         matched_labels,
         rcnn_train_cfg,
         num_reg_classes,
         target_means=self.target_means,
         target_stds=self.target_stds)
コード例 #2
0
 def get_target(self, sampling_results, gt_bboxes, gt_labels,
                rcnn_train_cfg):
     """Assemble the training targets, including matched 3D ground truth.

     Gathers the positive proposals, the negative proposals and the ground
     truth (2D boxes, 3D boxes and labels) matched to the positives from
     every entry of ``sampling_results`` and forwards them to
     ``bbox_target``.

     Args:
         sampling_results: Iterable of sampling results, each exposing
             ``pos_bboxes``, ``neg_bboxes``, ``pos_gt_bboxes``,
             ``pos_gt_bboxes_3d`` and ``pos_gt_labels`` (presumably one
             entry per image -- confirm against the caller).
         gt_bboxes: Not read by this method.
         gt_labels: Not read by this method.
         rcnn_train_cfg: Training config passed through unchanged to
             ``bbox_target``.

     Returns:
         The value returned by ``bbox_target``.
     """
     positives = [item.pos_bboxes for item in sampling_results]
     negatives = [item.neg_bboxes for item in sampling_results]
     matched_boxes = [item.pos_gt_bboxes for item in sampling_results]
     matched_boxes_3d = [item.pos_gt_bboxes_3d for item in sampling_results]
     matched_labels = [item.pos_gt_labels for item in sampling_results]

     # Class-agnostic regression uses a single set of box deltas.
     if self.reg_class_agnostic:
         num_reg_classes = 1
     else:
         num_reg_classes = self.num_classes

     return bbox_target(
         positives,
         negatives,
         matched_boxes,
         matched_labels,
         rcnn_train_cfg,
         num_reg_classes,
         target_means=self.target_means,
         target_stds=self.target_stds,
         pos_gt_bboxes_3d=matched_boxes_3d)
コード例 #3
0
ファイル: bbox_head.py プロジェクト: apulis/lvis
    def get_target(self,
                   sampling_results,
                   gt_bboxes,
                   gt_labels,
                   rcnn_train_cfg,
                   img_metas=None):
        """Build classification and box-regression targets.

        The box targets come from ``bbox_target``; the class labels and
        their weights are then post-processed by ``process_class_label``,
        which also yields ``target_meta``.

        Args:
            sampling_results: Iterable of sampling results, each exposing
                ``pos_bboxes``, ``neg_bboxes``, ``pos_gt_bboxes`` and
                ``pos_gt_labels`` (presumably one entry per image --
                confirm against the caller).
            gt_bboxes: Not read by this method.
            gt_labels: Not read by this method.
            rcnn_train_cfg: Training config passed through unchanged to
                ``bbox_target``.
            img_metas: Passed through to ``process_class_label``.
                Defaults to ``None``.

        Returns:
            tuple: ``(labels, label_weights, bbox_targets, bbox_weights,
            target_meta)``.
        """
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        neg_proposals = [res.neg_bboxes for res in sampling_results]
        pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
        # Class-agnostic regression uses a single set of box deltas.
        reg_classes = 1 if self.reg_class_agnostic else self.num_classes
        labels, label_weights, bbox_targets, bbox_weights = bbox_target(
            pos_proposals,
            neg_proposals,
            pos_gt_bboxes,
            pos_gt_labels,
            rcnn_train_cfg,
            reg_classes,
            target_means=self.target_means,
            target_stds=self.target_stds,
            concat=self.concat_targets)

        labels, label_weights, target_meta = process_class_label(
            labels,
            label_weights,
            img_metas,
            concat_targets=self.concat_targets,
            sparse_label=self.sparse_label,
            graph=self.graph,
            use_sigmoid_cls=self.use_sigmoid_cls,
            eql_cfg=self.eql_cfg,
            num_classes=self.num_classes)

        # With concat_targets disabled, bbox_target was asked not to
        # concatenate, so the box targets and weights are joined here, after
        # the label processing (presumably they arrive as per-image
        # sequences of tensors -- confirm against bbox_target).
        if not self.concat_targets:
            bbox_targets = torch.cat(bbox_targets)
            bbox_weights = torch.cat(bbox_weights)

        return labels, label_weights, bbox_targets, bbox_weights, target_meta