def test_vote_module():
    from mmdet3d.models.model_utils import VoteModule

    vote_loss = dict(type='ChamferDistance',
                     mode='l1',
                     reduction='none',
                     loss_dst_weight=10.0)
    self = VoteModule(vote_per_seed=3, in_channels=8, vote_loss=vote_loss)

    seed_xyz = torch.rand([2, 64, 3], dtype=torch.float32)  # (b, npoints, 3)
    seed_features = torch.rand(
        [2, 8, 64], dtype=torch.float32)  # (b, in_channels, npoints)

    # test forward
    vote_xyz, vote_features, vote_offset = self(seed_xyz, seed_features)
    assert vote_xyz.shape == torch.Size([2, 192, 3])
    assert vote_features.shape == torch.Size([2, 8, 192])
    assert vote_offset.shape == torch.Size([2, 3, 192])

    # test clip offset and without feature residual
    self = VoteModule(vote_per_seed=1,
                      in_channels=8,
                      num_points=32,
                      with_res_feat=False,
                      vote_xyz_range=(2.0, 2.0, 2.0))

    vote_xyz, vote_features, vote_offset = self(seed_xyz, seed_features)
    assert vote_xyz.shape == torch.Size([2, 32, 3])
    assert vote_features.shape == torch.Size([2, 8, 32])
    assert vote_offset.shape == torch.Size([2, 3, 32])
    assert torch.allclose(seed_features[..., :32], vote_features)
    assert vote_offset.max() <= 2.0
    assert vote_offset.min() >= -2.0
Beispiel #2
0
    def __init__(self,
                 num_classes,
                 bbox_coder,
                 train_cfg=None,
                 test_cfg=None,
                 vote_module_cfg=None,
                 vote_aggregation_cfg=None,
                 pred_layer_cfg=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 center_loss_mse=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None,
                 iou_loss=None):
        super(VoteHead, self).__init__()
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_module_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']

        self.objectness_loss = build_loss(objectness_loss)
        self.dir_res_loss = build_loss(dir_res_loss)
        self.dir_class_loss = build_loss(dir_class_loss)
        self.size_res_loss = build_loss(size_res_loss)
        if size_class_loss is not None:
            self.size_class_loss = build_loss(size_class_loss)
        if semantic_loss is not None:
            self.semantic_loss = build_loss(semantic_loss)
        if iou_loss is not None:
            self.iou_loss = build_loss(iou_loss)
        else:
            self.iou_loss = None
        if center_loss is not None:
            self.center_loss = build_loss(center_loss)
        if center_loss_mse is not None:
            self.center_loss_mse = build_loss(center_loss_mse)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
        self.num_dir_bins = self.bbox_coder.num_dir_bins

        self.vote_module = VoteModule(**vote_module_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
        self.fp16_enabled = False

        # Bbox classification and regression
        self.conv_pred = BaseConvBboxHead(
            **pred_layer_cfg,
            num_cls_out_channels=self._get_cls_out_channels(),
            num_reg_out_channels=self._get_reg_out_channels())
def test_vote_module():
    from mmdet3d.models.model_utils import VoteModule

    vote_loss = dict(
        type='ChamferDistance',
        mode='l1',
        reduction='none',
        loss_dst_weight=10.0)
    self = VoteModule(vote_per_seed=3, in_channels=8, vote_loss=vote_loss)

    seed_xyz = torch.rand([2, 64, 3], dtype=torch.float32)  # (b, npoints, 3)
    seed_features = torch.rand(
        [2, 8, 64], dtype=torch.float32)  # (b, in_channels, npoints)

    # test forward
    vote_xyz, vote_features = self(seed_xyz, seed_features)
    assert vote_xyz.shape == torch.Size([2, 192, 3])
    assert vote_features.shape == torch.Size([2, 8, 192])
Beispiel #4
0
    def __init__(self,
                 num_dims,
                 num_classes,
                 primitive_mode,
                 train_cfg=None,
                 test_cfg=None,
                 vote_module_cfg=None,
                 vote_aggregation_cfg=None,
                 feat_channels=(128, 128),
                 upper_thresh=100.0,
                 surface_thresh=0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 semantic_reg_loss=None,
                 semantic_cls_loss=None,
                 init_cfg=None):
        super(PrimitiveHead, self).__init__(init_cfg=init_cfg)
        assert primitive_mode in ['z', 'xy', 'line']
        # The dimension of primitive semantic information.
        self.num_dims = num_dims
        self.num_classes = num_classes
        self.primitive_mode = primitive_mode
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_module_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']
        self.upper_thresh = upper_thresh
        self.surface_thresh = surface_thresh

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.semantic_reg_loss = build_loss(semantic_reg_loss)
        self.semantic_cls_loss = build_loss(semantic_cls_loss)

        assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[
            'in_channels']

        # Primitive existence flag prediction
        self.flag_conv = ConvModule(
            vote_module_cfg['conv_channels'][-1],
            vote_module_cfg['conv_channels'][-1] // 2,
            1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=True,
            inplace=True)
        self.flag_pred = torch.nn.Conv1d(
            vote_module_cfg['conv_channels'][-1] // 2, 2, 1)

        self.vote_module = VoteModule(**vote_module_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)

        prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
        conv_pred_list = list()
        for k in range(len(feat_channels)):
            conv_pred_list.append(
                ConvModule(
                    prev_channel,
                    feat_channels[k],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=True))
            prev_channel = feat_channels[k]
        self.conv_pred = nn.Sequential(*conv_pred_list)

        conv_out_channel = 3 + num_dims + num_classes
        self.conv_pred.add_module('conv_out',
                                  nn.Conv1d(prev_channel, conv_out_channel, 1))
Beispiel #5
0
class PrimitiveHead(BaseModule):
    r"""Primitive head of `H3DNet <https://arxiv.org/abs/2006.05682>`_.

    Args:
        num_dims (int): The dimension of primitive semantic information.
        num_classes (int): The number of class.
        primitive_mode (str): The mode of primitive module,
            avaliable mode ['z', 'xy', 'line'].
        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
            decoding boxes.
        train_cfg (dict): Config for training.
        test_cfg (dict): Config for testing.
        vote_module_cfg (dict): Config of VoteModule for point-wise votes.
        vote_aggregation_cfg (dict): Config of vote aggregation layer.
        feat_channels (tuple[int]): Convolution channels of
            prediction layer.
        upper_thresh (float): Threshold for line matching.
        surface_thresh (float): Threshold for suface matching.
        conv_cfg (dict): Config of convolution in prediction layer.
        norm_cfg (dict): Config of BN in prediction layer.
        objectness_loss (dict): Config of objectness loss.
        center_loss (dict): Config of center loss.
        semantic_loss (dict): Config of point-wise semantic segmentation loss.
    """

    def __init__(self,
                 num_dims,
                 num_classes,
                 primitive_mode,
                 train_cfg=None,
                 test_cfg=None,
                 vote_module_cfg=None,
                 vote_aggregation_cfg=None,
                 feat_channels=(128, 128),
                 upper_thresh=100.0,
                 surface_thresh=0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 semantic_reg_loss=None,
                 semantic_cls_loss=None,
                 init_cfg=None):
        super(PrimitiveHead, self).__init__(init_cfg=init_cfg)
        assert primitive_mode in ['z', 'xy', 'line']
        # The dimension of primitive semantic information.
        self.num_dims = num_dims
        self.num_classes = num_classes
        self.primitive_mode = primitive_mode
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_module_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']
        self.upper_thresh = upper_thresh
        self.surface_thresh = surface_thresh

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.semantic_reg_loss = build_loss(semantic_reg_loss)
        self.semantic_cls_loss = build_loss(semantic_cls_loss)

        assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[
            'in_channels']

        # Primitive existence flag prediction
        self.flag_conv = ConvModule(
            vote_module_cfg['conv_channels'][-1],
            vote_module_cfg['conv_channels'][-1] // 2,
            1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=True,
            inplace=True)
        self.flag_pred = torch.nn.Conv1d(
            vote_module_cfg['conv_channels'][-1] // 2, 2, 1)

        self.vote_module = VoteModule(**vote_module_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)

        prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
        conv_pred_list = list()
        for k in range(len(feat_channels)):
            conv_pred_list.append(
                ConvModule(
                    prev_channel,
                    feat_channels[k],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=True))
            prev_channel = feat_channels[k]
        self.conv_pred = nn.Sequential(*conv_pred_list)

        conv_out_channel = 3 + num_dims + num_classes
        self.conv_pred.add_module('conv_out',
                                  nn.Conv1d(prev_channel, conv_out_channel, 1))

    def forward(self, feats_dict, sample_mod):
        """Forward pass.

        Args:
            feats_dict (dict): Feature dict from backbone.
            sample_mod (str): Sample mode for vote aggregation layer.
                valid modes are "vote", "seed" and "random".

        Returns:
            dict: Predictions of primitive head.
        """
        assert sample_mod in ['vote', 'seed', 'random']

        seed_points = feats_dict['fp_xyz_net0'][-1]
        seed_features = feats_dict['hd_feature']
        results = {}

        primitive_flag = self.flag_conv(seed_features)
        primitive_flag = self.flag_pred(primitive_flag)

        results['pred_flag_' + self.primitive_mode] = primitive_flag

        # 1. generate vote_points from seed_points
        vote_points, vote_features, _ = self.vote_module(
            seed_points, seed_features)
        results['vote_' + self.primitive_mode] = vote_points
        results['vote_features_' + self.primitive_mode] = vote_features

        # 2. aggregate vote_points
        if sample_mod == 'vote':
            # use fps in vote_aggregation
            sample_indices = None
        elif sample_mod == 'seed':
            # FPS on seed and choose the votes corresponding to the seeds
            sample_indices = furthest_point_sample(seed_points,
                                                   self.num_proposal)
        elif sample_mod == 'random':
            # Random sampling from the votes
            batch_size, num_seed = seed_points.shape[:2]
            sample_indices = torch.randint(
                0,
                num_seed, (batch_size, self.num_proposal),
                dtype=torch.int32,
                device=seed_points.device)
        else:
            raise NotImplementedError('Unsupported sample mod!')

        vote_aggregation_ret = self.vote_aggregation(vote_points,
                                                     vote_features,
                                                     sample_indices)
        aggregated_points, features, aggregated_indices = vote_aggregation_ret
        results['aggregated_points_' + self.primitive_mode] = aggregated_points
        results['aggregated_features_' + self.primitive_mode] = features
        results['aggregated_indices_' +
                self.primitive_mode] = aggregated_indices

        # 3. predict primitive offsets and semantic information
        predictions = self.conv_pred(features)

        # 4. decode predictions
        decode_ret = self.primitive_decode_scores(predictions,
                                                  aggregated_points)
        results.update(decode_ret)

        center, pred_ind = self.get_primitive_center(
            primitive_flag, decode_ret['center_' + self.primitive_mode])

        results['pred_' + self.primitive_mode + '_ind'] = pred_ind
        results['pred_' + self.primitive_mode + '_center'] = center
        return results

    def loss(self,
             bbox_preds,
             points,
             gt_bboxes_3d,
             gt_labels_3d,
             pts_semantic_mask=None,
             pts_instance_mask=None,
             img_metas=None,
             gt_bboxes_ignore=None):
        """Compute loss.

        Args:
            bbox_preds (dict): Predictions from forward of primitive head.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise
                instance mask.
            img_metas (list[dict]): Contain pcd and img's meta info.
            gt_bboxes_ignore (None | list[torch.Tensor]): Specify
                which bounding.

        Returns:
            dict: Losses of Primitive Head.
        """
        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
                                   pts_semantic_mask, pts_instance_mask,
                                   bbox_preds)

        (point_mask, point_offset, gt_primitive_center, gt_primitive_semantic,
         gt_sem_cls_label, gt_primitive_mask) = targets

        losses = {}
        # Compute the loss of primitive existence flag
        pred_flag = bbox_preds['pred_flag_' + self.primitive_mode]
        flag_loss = self.objectness_loss(pred_flag, gt_primitive_mask.long())
        losses['flag_loss_' + self.primitive_mode] = flag_loss

        # calculate vote loss
        vote_loss = self.vote_module.get_loss(
            bbox_preds['seed_points'],
            bbox_preds['vote_' + self.primitive_mode],
            bbox_preds['seed_indices'], point_mask, point_offset)
        losses['vote_loss_' + self.primitive_mode] = vote_loss

        num_proposal = bbox_preds['aggregated_points_' +
                                  self.primitive_mode].shape[1]
        primitive_center = bbox_preds['center_' + self.primitive_mode]
        if self.primitive_mode != 'line':
            primitive_semantic = bbox_preds['size_residuals_' +
                                            self.primitive_mode].contiguous()
        else:
            primitive_semantic = None
        semancitc_scores = bbox_preds['sem_cls_scores_' +
                                      self.primitive_mode].transpose(2, 1)

        gt_primitive_mask = gt_primitive_mask / \
            (gt_primitive_mask.sum() + 1e-6)
        center_loss, size_loss, sem_cls_loss = self.compute_primitive_loss(
            primitive_center, primitive_semantic, semancitc_scores,
            num_proposal, gt_primitive_center, gt_primitive_semantic,
            gt_sem_cls_label, gt_primitive_mask)
        losses['center_loss_' + self.primitive_mode] = center_loss
        losses['size_loss_' + self.primitive_mode] = size_loss
        losses['sem_loss_' + self.primitive_mode] = sem_cls_loss

        return losses

    def get_targets(self,
                    points,
                    gt_bboxes_3d,
                    gt_labels_3d,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    bbox_preds=None):
        """Generate targets of primitive head.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
                bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic
                label of each batch.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise instance
                label of each batch.
            bbox_preds (dict): Predictions from forward of primitive head.

        Returns:
            tuple[torch.Tensor]: Targets of primitive head.
        """
        for index in range(len(gt_labels_3d)):
            if len(gt_labels_3d[index]) == 0:
                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
                    1, gt_bboxes_3d[index].tensor.shape[-1])
                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)

        if pts_semantic_mask is None:
            pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
            pts_instance_mask = [None for i in range(len(gt_labels_3d))]

        (point_mask, point_sem,
         point_offset) = multi_apply(self.get_targets_single, points,
                                     gt_bboxes_3d, gt_labels_3d,
                                     pts_semantic_mask, pts_instance_mask)

        point_mask = torch.stack(point_mask)
        point_sem = torch.stack(point_sem)
        point_offset = torch.stack(point_offset)

        batch_size = point_mask.shape[0]
        num_proposal = bbox_preds['aggregated_points_' +
                                  self.primitive_mode].shape[1]
        num_seed = bbox_preds['seed_points'].shape[1]
        seed_inds = bbox_preds['seed_indices'].long()
        seed_inds_expand = seed_inds.view(batch_size, num_seed,
                                          1).repeat(1, 1, 3)
        seed_gt_votes = torch.gather(point_offset, 1, seed_inds_expand)
        seed_gt_votes += bbox_preds['seed_points']
        gt_primitive_center = seed_gt_votes.view(batch_size * num_proposal, 1,
                                                 3)

        seed_inds_expand_sem = seed_inds.view(batch_size, num_seed, 1).repeat(
            1, 1, 4 + self.num_dims)
        seed_gt_sem = torch.gather(point_sem, 1, seed_inds_expand_sem)
        gt_primitive_semantic = seed_gt_sem[:, :, 3:3 + self.num_dims].view(
            batch_size * num_proposal, 1, self.num_dims).contiguous()

        gt_sem_cls_label = seed_gt_sem[:, :, -1].long()

        gt_votes_mask = torch.gather(point_mask, 1, seed_inds)

        return (point_mask, point_offset, gt_primitive_center,
                gt_primitive_semantic, gt_sem_cls_label, gt_votes_mask)

    def get_targets_single(self,
                           points,
                           gt_bboxes_3d,
                           gt_labels_3d,
                           pts_semantic_mask=None,
                           pts_instance_mask=None):
        """Generate targets of primitive head for single batch.

        Args:
            points (torch.Tensor): Points of each batch.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth \
                boxes of each batch.
            gt_labels_3d (torch.Tensor): Labels of each batch.
            pts_semantic_mask (None | torch.Tensor): Point-wise semantic
                label of each batch.
            pts_instance_mask (None | torch.Tensor): Point-wise instance
                label of each batch.

        Returns:
            tuple[torch.Tensor]: Targets of primitive head.
        """
        gt_bboxes_3d = gt_bboxes_3d.to(points.device)
        num_points = points.shape[0]

        point_mask = points.new_zeros(num_points)
        # Offset to the primitive center
        point_offset = points.new_zeros([num_points, 3])
        # Semantic information of primitive center
        point_sem = points.new_zeros([num_points, 3 + self.num_dims + 1])

        # Generate pts_semantic_mask and pts_instance_mask when they are None
        if pts_semantic_mask is None or pts_instance_mask is None:
            points2box_mask = gt_bboxes_3d.points_in_boxes(points)
            assignment = points2box_mask.argmax(1)
            background_mask = points2box_mask.max(1)[0] == 0

            if pts_semantic_mask is None:
                pts_semantic_mask = gt_labels_3d[assignment]
                pts_semantic_mask[background_mask] = self.num_classes

            if pts_instance_mask is None:
                pts_instance_mask = assignment
                pts_instance_mask[background_mask] = gt_labels_3d.shape[0]

        instance_flag = torch.nonzero(
            pts_semantic_mask != self.num_classes, as_tuple=False).squeeze(1)
        instance_labels = pts_instance_mask[instance_flag].unique()

        with_yaw = gt_bboxes_3d.with_yaw
        for i, i_instance in enumerate(instance_labels):
            indices = instance_flag[pts_instance_mask[instance_flag] ==
                                    i_instance]
            coords = points[indices, :3]
            cur_cls_label = pts_semantic_mask[indices][0]

            # Bbox Corners
            cur_corners = gt_bboxes_3d.corners[i]

            plane_lower_temp = points.new_tensor(
                [0, 0, 1, -cur_corners[7, -1]])
            upper_points = cur_corners[[1, 2, 5, 6]]
            refined_distance = (upper_points * plane_lower_temp[:3]).sum(dim=1)

            if self.check_horizon(upper_points) and \
                    plane_lower_temp[0] + plane_lower_temp[1] < \
                    self.train_cfg['lower_thresh']:
                plane_lower = points.new_tensor(
                    [0, 0, 1, plane_lower_temp[-1]])
                plane_upper = points.new_tensor(
                    [0, 0, 1, -torch.mean(refined_distance)])
            else:
                raise NotImplementedError('Only horizontal plane is support!')

            if self.check_dist(plane_upper, upper_points) is False:
                raise NotImplementedError(
                    'Mean distance to plane should be lower than thresh!')

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_lower, coords)

            # Get bottom four lines
            if self.primitive_mode == 'line':
                point2line_matching = self.match_point2line(
                    coords[selected], cur_corners, with_yaw, mode='bottom')

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_line_targets(point_mask,
                                                        point_offset,
                                                        point_sem,
                                                        coords[selected],
                                                        indices[selected],
                                                        cur_cls_label,
                                                        point2line_matching,
                                                        cur_corners,
                                                        [1, 1, 0, 0],
                                                        with_yaw,
                                                        mode='bottom')

            # Set the surface labels here
            if self.primitive_mode == 'z' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    point2plane_dist[selected].var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(point_mask,
                                                           point_offset,
                                                           point_sem,
                                                           coords[selected],
                                                           indices[selected],
                                                           cur_cls_label,
                                                           cur_corners,
                                                           with_yaw,
                                                           mode='bottom')

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_upper, coords)

            # Get top four lines
            if self.primitive_mode == 'line':
                point2line_matching = self.match_point2line(
                    coords[selected], cur_corners, with_yaw, mode='top')

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_line_targets(point_mask,
                                                        point_offset,
                                                        point_sem,
                                                        coords[selected],
                                                        indices[selected],
                                                        cur_cls_label,
                                                        point2line_matching,
                                                        cur_corners,
                                                        [1, 1, 0, 0],
                                                        with_yaw,
                                                        mode='top')

            if self.primitive_mode == 'z' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    point2plane_dist[selected].var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(point_mask,
                                                           point_offset,
                                                           point_sem,
                                                           coords[selected],
                                                           indices[selected],
                                                           cur_cls_label,
                                                           cur_corners,
                                                           with_yaw,
                                                           mode='top')

            # Get left two lines
            plane_left_temp = self._get_plane_fomulation(
                cur_corners[2] - cur_corners[3],
                cur_corners[3] - cur_corners[0], cur_corners[0])

            right_points = cur_corners[[4, 5, 7, 6]]
            plane_left_temp /= torch.norm(plane_left_temp[:3])
            refined_distance = (right_points * plane_left_temp[:3]).sum(dim=1)

            if plane_left_temp[2] < self.train_cfg['lower_thresh']:
                plane_left = plane_left_temp
                plane_right = points.new_tensor([
                    plane_left_temp[0], plane_left_temp[1], plane_left_temp[2],
                    -refined_distance.mean()
                ])
            else:
                raise NotImplementedError(
                    'Normal vector of the plane should be horizontal!')

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_left, coords)

            # Get left four lines
            if self.primitive_mode == 'line':
                point2line_matching = self.match_point2line(
                    coords[selected], cur_corners, with_yaw, mode='left')
                point_mask, point_offset, point_sem = \
                    self._assign_primitive_line_targets(
                        point_mask, point_offset, point_sem,
                        coords[selected], indices[selected], cur_cls_label,
                        point2line_matching[2:], cur_corners, [2, 2],
                        with_yaw, mode='left')

            if self.primitive_mode == 'xy' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    point2plane_dist[selected].var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(
                        point_mask, point_offset, point_sem,
                        coords[selected], indices[selected], cur_cls_label,
                        cur_corners, with_yaw, mode='left')

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_right, coords)

            # Get right four lines
            if self.primitive_mode == 'line':
                point2line_matching = self.match_point2line(
                    coords[selected], cur_corners, with_yaw, mode='right')

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_line_targets(
                        point_mask, point_offset, point_sem,
                        coords[selected], indices[selected], cur_cls_label,
                        point2line_matching[2:], cur_corners, [2, 2],
                        with_yaw, mode='right')

            if self.primitive_mode == 'xy' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    point2plane_dist[selected].var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(
                        point_mask, point_offset, point_sem,
                        coords[selected], indices[selected], cur_cls_label,
                        cur_corners, with_yaw, mode='right')

            plane_front_temp = self._get_plane_fomulation(
                cur_corners[0] - cur_corners[4],
                cur_corners[4] - cur_corners[5], cur_corners[5])

            back_points = cur_corners[[3, 2, 7, 6]]
            plane_front_temp /= torch.norm(plane_front_temp[:3])
            refined_distance = (back_points * plane_front_temp[:3]).sum(dim=1)

            if plane_front_temp[2] < self.train_cfg['lower_thresh']:
                plane_front = plane_front_temp
                plane_back = points.new_tensor([
                    plane_front_temp[0], plane_front_temp[1],
                    plane_front_temp[2], -torch.mean(refined_distance)
                ])
            else:
                raise NotImplementedError(
                    'Normal vector of the plane should be horizontal!')

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_front, coords)

            if self.primitive_mode == 'xy' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    (point2plane_dist[selected]).var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(
                        point_mask, point_offset, point_sem,
                        coords[selected], indices[selected], cur_cls_label,
                        cur_corners, with_yaw, mode='front')

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_back, coords)

            if self.primitive_mode == 'xy' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    point2plane_dist[selected].var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(
                        point_mask, point_offset, point_sem,
                        coords[selected], indices[selected], cur_cls_label,
                        cur_corners, with_yaw, mode='back')

        return (point_mask, point_sem, point_offset)

    def primitive_decode_scores(self, predictions, aggregated_points):
        """Decode predicted parts to primitive head.

        Args:
            predictions (torch.Tensor): primitive pridictions of each batch.
            aggregated_points (torch.Tensor): The aggregated points
                of vote stage.

        Returns:
            Dict: Predictions of primitive head, including center,
                semantic size and semantic scores.
        """

        ret_dict = {}
        pred_transposed = predictions.transpose(2, 1)

        center = aggregated_points + pred_transposed[:, :, 0:3]
        ret_dict['center_' + self.primitive_mode] = center

        if self.primitive_mode in ['z', 'xy']:
            ret_dict['size_residuals_' + self.primitive_mode] = \
                pred_transposed[:, :, 3:3 + self.num_dims]

        ret_dict['sem_cls_scores_' + self.primitive_mode] = \
            pred_transposed[:, :, 3 + self.num_dims:]

        return ret_dict

    def check_horizon(self, points):
        """Check whether is a horizontal plane.

        Args:
            points (torch.Tensor): Points of input.

        Returns:
            Bool: Flag of result.
        """
        return (points[0][-1] == points[1][-1]) and \
               (points[1][-1] == points[2][-1]) and \
               (points[2][-1] == points[3][-1])

    def check_dist(self, plane_equ, points):
        """Whether the mean of points to plane distance is lower than thresh.

        Args:
            plane_equ (torch.Tensor): Plane to be checked.
            points (torch.Tensor): Points to be checked.

        Returns:
            Tuple: Flag of result.
        """
        return (points[:, 2] +
                plane_equ[-1]).sum() / 4.0 < self.train_cfg['lower_thresh']

    def point2line_dist(self, points, pts_a, pts_b):
        """Calculate the distance from point to line.

        Args:
            points (torch.Tensor): Points of input.
            pts_a (torch.Tensor): Point on the specific line.
            pts_b (torch.Tensor): Point on the specific line.

        Returns:
            torch.Tensor: Distance between each point to line.
        """
        line_a2b = pts_b - pts_a
        line_a2pts = points - pts_a
        length = (line_a2pts * line_a2b.view(1, 3)).sum(1) / \
            line_a2b.norm()
        dist = (line_a2pts.norm(dim=1)**2 - length**2).sqrt()

        return dist

    def match_point2line(self, points, corners, with_yaw, mode='bottom'):
        """Match points to corresponding line.

        Args:
            points (torch.Tensor): Points of input.
            corners (torch.Tensor): Eight corners of a bounding box.
            with_yaw (Bool): Whether the boundind box is with rotation.
            mode (str, optional): Specify which line should be matched,
                available mode are ('bottom', 'top', 'left', 'right').
                Defaults to 'bottom'.

        Returns:
            Tuple: Flag of matching correspondence.
        """
        if with_yaw:
            corners_pair = {
                'bottom': [[0, 3], [4, 7], [0, 4], [3, 7]],
                'top': [[1, 2], [5, 6], [1, 5], [2, 6]],
                'left': [[0, 1], [3, 2], [0, 1], [3, 2]],
                'right': [[4, 5], [7, 6], [4, 5], [7, 6]]
            }
            selected_list = []
            for pair_index in corners_pair[mode]:
                selected = self.point2line_dist(
                    points, corners[pair_index[0]], corners[pair_index[1]]) \
                    < self.train_cfg['line_thresh']
                selected_list.append(selected)
        else:
            xmin, ymin, _ = corners.min(0)[0]
            xmax, ymax, _ = corners.max(0)[0]
            sel1 = torch.abs(points[:, 0] -
                             xmin) < self.train_cfg['line_thresh']
            sel2 = torch.abs(points[:, 0] -
                             xmax) < self.train_cfg['line_thresh']
            sel3 = torch.abs(points[:, 1] -
                             ymin) < self.train_cfg['line_thresh']
            sel4 = torch.abs(points[:, 1] -
                             ymax) < self.train_cfg['line_thresh']
            selected_list = [sel1, sel2, sel3, sel4]
        return selected_list

    def match_point2plane(self, plane, points):
        """Match points to plane.

        Args:
            plane (torch.Tensor): Equation of the plane.
            points (torch.Tensor): Points of input.

        Returns:
            Tuple: Distance of each point to the plane and
                flag of matching correspondence.
        """
        point2plane_dist = torch.abs((points * plane[:3]).sum(dim=1) +
                                     plane[-1])
        min_dist = point2plane_dist.min()
        selected = torch.abs(point2plane_dist -
                             min_dist) < self.train_cfg['dist_thresh']
        return point2plane_dist, selected

    def compute_primitive_loss(self, primitive_center, primitive_semantic,
                               semantic_scores, num_proposal,
                               gt_primitive_center, gt_primitive_semantic,
                               gt_sem_cls_label, gt_primitive_mask):
        """Compute loss of primitive module.

        Args:
            primitive_center (torch.Tensor): Pridictions of primitive center.
            primitive_semantic (torch.Tensor): Pridictions of primitive
                semantic.
            semantic_scores (torch.Tensor): Pridictions of primitive
                semantic scores.
            num_proposal (int): The number of primitive proposal.
            gt_primitive_center (torch.Tensor): Ground truth of
                primitive center.
            gt_votes_sem (torch.Tensor): Ground truth of primitive semantic.
            gt_sem_cls_label (torch.Tensor): Ground truth of primitive
                semantic class.
            gt_primitive_mask (torch.Tensor): Ground truth of primitive mask.

        Returns:
            Tuple: Loss of primitive module.
        """
        batch_size = primitive_center.shape[0]
        vote_xyz_reshape = primitive_center.view(batch_size * num_proposal, -1,
                                                 3)

        center_loss = self.center_loss(
            vote_xyz_reshape,
            gt_primitive_center,
            dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1]

        if self.primitive_mode != 'line':
            size_xyz_reshape = primitive_semantic.view(
                batch_size * num_proposal, -1, self.num_dims).contiguous()
            size_loss = self.semantic_reg_loss(
                size_xyz_reshape,
                gt_primitive_semantic,
                dst_weight=gt_primitive_mask.view(batch_size * num_proposal,
                                                  1))[1]
        else:
            size_loss = center_loss.new_tensor(0.0)

        # Semantic cls loss
        sem_cls_loss = self.semantic_cls_loss(
            semantic_scores, gt_sem_cls_label, weight=gt_primitive_mask)

        return center_loss, size_loss, sem_cls_loss

    def get_primitive_center(self, pred_flag, center):
        """Generate primitive center from predictions.

        Args:
            pred_flag (torch.Tensor): Scores of primitive center.
            center (torch.Tensor): Pridictions of primitive center.

        Returns:
            Tuple: Primitive center and the prediction indices.
        """
        ind_normal = F.softmax(pred_flag, dim=1)
        pred_indices = (ind_normal[:, 1, :] >
                        self.surface_thresh).detach().float()
        selected = (ind_normal[:, 1, :] <=
                    self.surface_thresh).detach().float()
        offset = torch.ones_like(center) * self.upper_thresh
        center = center + offset * selected.unsqueeze(-1)
        return center, pred_indices

    def _assign_primitive_line_targets(self,
                                       point_mask,
                                       point_offset,
                                       point_sem,
                                       coords,
                                       indices,
                                       cls_label,
                                       point2line_matching,
                                       corners,
                                       center_axises,
                                       with_yaw,
                                       mode='bottom'):
        """Generate targets of line primitive.

        Args:
            point_mask (torch.Tensor): Tensor to store the ground
                truth of mask.
            point_offset (torch.Tensor): Tensor to store the ground
                truth of offset.
            point_sem (torch.Tensor): Tensor to store the ground
                truth of semantic.
            coords (torch.Tensor): The selected points.
            indices (torch.Tensor): Indices of the selected points.
            cls_label (int): Class label of the ground truth bounding box.
            point2line_matching (torch.Tensor): Flag indicate that
                matching line of each point.
            corners (torch.Tensor): Corners of the ground truth bounding box.
            center_axises (list[int]): Indicate in which axis the line center
                should be refined.
            with_yaw (Bool): Whether the boundind box is with rotation.
            mode (str, optional): Specify which line should be matched,
                available mode are ('bottom', 'top', 'left', 'right').
                Defaults to 'bottom'.

        Returns:
            Tuple: Targets of the line primitive.
        """
        corners_pair = {
            'bottom': [[0, 3], [4, 7], [0, 4], [3, 7]],
            'top': [[1, 2], [5, 6], [1, 5], [2, 6]],
            'left': [[0, 1], [3, 2]],
            'right': [[4, 5], [7, 6]]
        }
        corners_pair = corners_pair[mode]
        assert len(corners_pair) == len(point2line_matching) == len(
            center_axises)
        for line_select, center_axis, pair_index in zip(
                point2line_matching, center_axises, corners_pair):
            if line_select.sum() > self.train_cfg['num_point_line']:
                point_mask[indices[line_select]] = 1.0

                if with_yaw:
                    line_center = (corners[pair_index[0]] +
                                   corners[pair_index[1]]) / 2
                else:
                    line_center = coords[line_select].mean(dim=0)
                    line_center[center_axis] = corners[:, center_axis].mean()

                point_offset[indices[line_select]] = \
                    line_center - coords[line_select]
                point_sem[indices[line_select]] = \
                    point_sem.new_tensor([line_center[0], line_center[1],
                                          line_center[2], cls_label])
        return point_mask, point_offset, point_sem

    def _assign_primitive_surface_targets(self,
                                          point_mask,
                                          point_offset,
                                          point_sem,
                                          coords,
                                          indices,
                                          cls_label,
                                          corners,
                                          with_yaw,
                                          mode='bottom'):
        """Generate targets for primitive z and primitive xy.

        Args:
            point_mask (torch.Tensor): Tensor to store the ground
                truth of mask.
            point_offset (torch.Tensor): Tensor to store the ground
                truth of offset.
            point_sem (torch.Tensor): Tensor to store the ground
                truth of semantic.
            coords (torch.Tensor): The selected points.
            indices (torch.Tensor): Indices of the selected points.
            cls_label (int): Class label of the ground truth bounding box.
            corners (torch.Tensor): Corners of the ground truth bounding box.
            with_yaw (Bool): Whether the boundind box is with rotation.
            mode (str, optional): Specify which line should be matched,
                available mode are ('bottom', 'top', 'left', 'right',
                'front', 'back').
                Defaults to 'bottom'.

        Returns:
            Tuple: Targets of the center primitive.
        """
        point_mask[indices] = 1.0
        corners_pair = {
            'bottom': [0, 7],
            'top': [1, 6],
            'left': [0, 1],
            'right': [4, 5],
            'front': [0, 1],
            'back': [3, 2]
        }
        pair_index = corners_pair[mode]
        if self.primitive_mode == 'z':
            if with_yaw:
                center = (corners[pair_index[0]] +
                          corners[pair_index[1]]) / 2.0
                center[2] = coords[:, 2].mean()
                point_sem[indices] = point_sem.new_tensor([
                    center[0], center[1],
                    center[2], (corners[4] - corners[0]).norm(),
                    (corners[3] - corners[0]).norm(), cls_label
                ])
            else:
                center = point_mask.new_tensor([
                    corners[:, 0].mean(), corners[:, 1].mean(),
                    coords[:, 2].mean()
                ])
                point_sem[indices] = point_sem.new_tensor([
                    center[0], center[1], center[2],
                    corners[:, 0].max() - corners[:, 0].min(),
                    corners[:, 1].max() - corners[:, 1].min(), cls_label
                ])
        elif self.primitive_mode == 'xy':
            if with_yaw:
                center = coords.mean(0)
                center[2] = (corners[pair_index[0], 2] +
                             corners[pair_index[1], 2]) / 2.0
                point_sem[indices] = point_sem.new_tensor([
                    center[0], center[1], center[2],
                    corners[pair_index[1], 2] - corners[pair_index[0], 2],
                    cls_label
                ])
            else:
                center = point_mask.new_tensor([
                    coords[:, 0].mean(), coords[:, 1].mean(),
                    corners[:, 2].mean()
                ])
                point_sem[indices] = point_sem.new_tensor([
                    center[0], center[1], center[2],
                    corners[:, 2].max() - corners[:, 2].min(), cls_label
                ])
        point_offset[indices] = center - coords
        return point_mask, point_offset, point_sem

    def _get_plane_fomulation(self, vector1, vector2, point):
        """Compute the equation of the plane.

        Args:
            vector1 (torch.Tensor): Parallel vector of the plane.
            vector2 (torch.Tensor): Parallel vector of the plane.
            point (torch.Tensor): Point on the plane.

        Returns:
            torch.Tensor: Equation of the plane.
        """
        surface_norm = torch.cross(vector1, vector2)
        surface_dis = -torch.dot(surface_norm, point)
        plane = point.new_tensor(
            [surface_norm[0], surface_norm[1], surface_norm[2], surface_dis])
        return plane
Beispiel #6
0
class VoteHead(nn.Module):
    r"""Bbox head of `Votenet <https://arxiv.org/abs/1904.09664>`_.

    Args:
        num_classes (int): The number of class.
        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
            decoding boxes.
        train_cfg (dict): Config for training.
        test_cfg (dict): Config for testing.
        vote_module_cfg (dict): Config of VoteModule for point-wise votes.
        vote_aggregation_cfg (dict): Config of vote aggregation layer.
        pred_layer_cfg (dict): Config of classfication and regression
            prediction layers.
        conv_cfg (dict): Config of convolution in prediction layer.
        norm_cfg (dict): Config of BN in prediction layer.
        objectness_loss (dict): Config of objectness loss.
        center_loss (dict): Config of center loss.
        dir_class_loss (dict): Config of direction classification loss.
        dir_res_loss (dict): Config of direction residual regression loss.
        size_class_loss (dict): Config of size classification loss.
        size_res_loss (dict): Config of size residual regression loss.
        semantic_loss (dict): Config of point-wise semantic segmentation loss.
    """

    def __init__(self,
                 num_classes,
                 bbox_coder,
                 train_cfg=None,
                 test_cfg=None,
                 vote_module_cfg=None,
                 vote_aggregation_cfg=None,
                 pred_layer_cfg=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None,
                 iou_loss=None):
        super(VoteHead, self).__init__()
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_module_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.dir_res_loss = build_loss(dir_res_loss)
        self.dir_class_loss = build_loss(dir_class_loss)
        self.size_res_loss = build_loss(size_res_loss)
        if size_class_loss is not None:
            self.size_class_loss = build_loss(size_class_loss)
        if semantic_loss is not None:
            self.semantic_loss = build_loss(semantic_loss)
        if iou_loss is not None:
            self.iou_loss = build_loss(iou_loss)
        else:
            self.iou_loss = None

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
        self.num_dir_bins = self.bbox_coder.num_dir_bins

        self.vote_module = VoteModule(**vote_module_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
        self.fp16_enabled = False

        # Bbox classification and regression
        self.conv_pred = BaseConvBboxHead(
            **pred_layer_cfg,
            num_cls_out_channels=self._get_cls_out_channels(),
            num_reg_out_channels=self._get_reg_out_channels())

    def init_weights(self):
        """Initialize weights of VoteHead."""
        pass

    def _get_cls_out_channels(self):
        """Return the channel number of classification outputs."""
        # Class numbers (k) + objectness (2)
        return self.num_classes + 2

    def _get_reg_out_channels(self):
        """Return the channel number of regression outputs."""
        # Objectness scores (2), center residual (3),
        # heading class+residual (num_dir_bins*2),
        # size class+residual(num_sizes*4)
        return 3 + self.num_dir_bins * 2 + self.num_sizes * 4

    def _extract_input(self, feat_dict):
        """Extract inputs from features dictionary.

        Args:
            feat_dict (dict): Feature dict from backbone.

        Returns:
            torch.Tensor: Coordinates of input points.
            torch.Tensor: Features of input points.
            torch.Tensor: Indices of input points.
        """
        seed_points = feat_dict['fp_xyz'][-1]
        seed_features = feat_dict['fp_features'][-1]
        seed_indices = feat_dict['fp_indices'][-1]

        return seed_points, seed_features, seed_indices

    def forward(self, feat_dict, sample_mod):
        """Forward pass.

        Note:
            The forward of VoteHead is devided into 4 steps:

                1. Generate vote_points from seed_points.
                2. Aggregate vote_points.
                3. Predict bbox and score.
                4. Decode predictions.

        Args:
            feat_dict (dict): Feature dict from backbone.
            sample_mod (str): Sample mode for vote aggregation layer.
                valid modes are "vote", "seed", "random" and "spec".

        Returns:
            dict: Predictions of vote head.
        """
        assert sample_mod in ['vote', 'seed', 'random', 'spec']

        seed_points, seed_features, seed_indices = self._extract_input(
            feat_dict)

        # 1. generate vote_points from seed_points
        vote_points, vote_features, vote_offset = self.vote_module(
            seed_points, seed_features)
        results = dict(
            seed_points=seed_points,
            seed_indices=seed_indices,
            vote_points=vote_points,
            vote_features=vote_features,
            vote_offset=vote_offset)

        # 2. aggregate vote_points
        if sample_mod == 'vote':
            # use fps in vote_aggregation
            aggregation_inputs = dict(
                points_xyz=vote_points, features=vote_features)
        elif sample_mod == 'seed':
            # FPS on seed and choose the votes corresponding to the seeds
            sample_indices = furthest_point_sample(seed_points,
                                                   self.num_proposal)
            aggregation_inputs = dict(
                points_xyz=vote_points,
                features=vote_features,
                indices=sample_indices)
        elif sample_mod == 'random':
            # Random sampling from the votes
            batch_size, num_seed = seed_points.shape[:2]
            sample_indices = seed_points.new_tensor(
                torch.randint(0, num_seed, (batch_size, self.num_proposal)),
                dtype=torch.int32)
            aggregation_inputs = dict(
                points_xyz=vote_points,
                features=vote_features,
                indices=sample_indices)
        elif sample_mod == 'spec':
            # Specify the new center in vote_aggregation
            aggregation_inputs = dict(
                points_xyz=seed_points,
                features=seed_features,
                target_xyz=vote_points)
        else:
            raise NotImplementedError(
                f'Sample mode {sample_mod} is not supported!')

        vote_aggregation_ret = self.vote_aggregation(**aggregation_inputs)
        aggregated_points, features, aggregated_indices = vote_aggregation_ret


        results['aggregated_points'] = aggregated_points
        results['aggregated_features'] = features
        results['aggregated_indices'] = aggregated_indices

        # 3. predict bbox and score
        cls_predictions, reg_predictions = self.conv_pred(features)

        # 4. decode predictions
        decode_res = self.bbox_coder.split_pred(cls_predictions,
                                                reg_predictions,
                                                aggregated_points)

        results.update(decode_res)

        return results

    @force_fp32(apply_to=('bbox_preds', ))
    def loss(self,
             bbox_preds,
             points,
             gt_bboxes_3d,
             gt_labels_3d,
             pts_semantic_mask=None,
             pts_instance_mask=None,
             img_metas=None,
             gt_bboxes_ignore=None,
             ret_target=False):
        """Compute loss.

        Args:
            bbox_preds (dict): Predictions from forward of vote head.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise
                instance mask.
            img_metas (list[dict]): Contain pcd and img's meta info.
            gt_bboxes_ignore (None | list[torch.Tensor]): Specify
                which bounding.
            ret_target (Bool): Return targets or not.

        Returns:
            dict: Losses of Votenet.
        """
        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
                                   pts_semantic_mask, pts_instance_mask,
                                   bbox_preds)
        (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
         dir_class_targets, dir_res_targets, center_targets,
         assigned_center_targets, mask_targets, valid_gt_masks,
         objectness_targets, objectness_weights, box_loss_weights,
         valid_gt_weights) = targets

        # calculate vote loss
        vote_loss = self.vote_module.get_loss(bbox_preds['seed_points'],
                                              bbox_preds['vote_points'],
                                              bbox_preds['seed_indices'],
                                              vote_target_masks, vote_targets)

        # calculate objectness loss
        objectness_loss = self.objectness_loss(
            bbox_preds['obj_scores'].transpose(2, 1),
            objectness_targets,
            weight=objectness_weights)

        # calculate center loss
        source2target_loss, target2source_loss = self.center_loss(
            bbox_preds['center'],
            center_targets,
            src_weight=box_loss_weights,
            dst_weight=valid_gt_weights)
        center_loss = source2target_loss + target2source_loss

        # calculate direction class loss
        dir_class_loss = self.dir_class_loss(
            bbox_preds['dir_class'].transpose(2, 1),
            dir_class_targets,
            weight=box_loss_weights)

        # calculate direction residual loss
        batch_size, proposal_num = size_class_targets.shape[:2]
        heading_label_one_hot = vote_targets.new_zeros(
            (batch_size, proposal_num, self.num_dir_bins))
        heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
        dir_res_norm = torch.sum(
            bbox_preds['dir_res_norm'] * heading_label_one_hot, -1)
        dir_res_loss = self.dir_res_loss(
            dir_res_norm, dir_res_targets, weight=box_loss_weights)

        # calculate size class loss
        size_class_loss = self.size_class_loss(
            bbox_preds['size_class'].transpose(2, 1),
            size_class_targets,
            weight=box_loss_weights)

        # calculate size residual loss
        one_hot_size_targets = vote_targets.new_zeros(
            (batch_size, proposal_num, self.num_sizes))
        one_hot_size_targets.scatter_(2, size_class_targets.unsqueeze(-1), 1)
        one_hot_size_targets_expand = one_hot_size_targets.unsqueeze(
            -1).repeat(1, 1, 1, 3).contiguous()
        size_residual_norm = torch.sum(
            bbox_preds['size_res_norm'] * one_hot_size_targets_expand, 2)
        box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
            1, 1, 3)
        size_res_loss = self.size_res_loss(
            size_residual_norm,
            size_res_targets,
            weight=box_loss_weights_expand)

        # calculate semantic loss
        semantic_loss = self.semantic_loss(
            bbox_preds['sem_scores'].transpose(2, 1),
            mask_targets,
            weight=box_loss_weights)

        losses = dict(
            vote_loss=vote_loss,
            objectness_loss=objectness_loss,
            semantic_loss=semantic_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_class_loss=size_class_loss,
            size_res_loss=size_res_loss)

        if self.iou_loss:
            corners_pred = self.bbox_coder.decode_corners(
                bbox_preds['center'], size_residual_norm,
                one_hot_size_targets_expand)
            corners_target = self.bbox_coder.decode_corners(
                assigned_center_targets, size_res_targets,
                one_hot_size_targets_expand)
            iou_loss = self.iou_loss(
                corners_pred, corners_target, weight=box_loss_weights)
            losses['iou_loss'] = iou_loss

        if ret_target:
            losses['targets'] = targets

        return losses

    def get_targets(self,
                    points,
                    gt_bboxes_3d,
                    gt_labels_3d,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    bbox_preds=None):
        """Generate targets of vote head.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
                bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic
                label of each batch.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise instance
                label of each batch.
            bbox_preds (torch.Tensor): Bounding box predictions of vote head.

        Returns:
            tuple[torch.Tensor]: Targets of vote head.
        """
        # find empty example
        valid_gt_masks = list()
        gt_num = list()
        for index in range(len(gt_labels_3d)):
            if len(gt_labels_3d[index]) == 0:
                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
                    1, gt_bboxes_3d[index].tensor.shape[-1])
                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
                valid_gt_masks.append(gt_labels_3d[index].new_zeros(1))
                gt_num.append(1)
            else:
                valid_gt_masks.append(gt_labels_3d[index].new_ones(
                    gt_labels_3d[index].shape))
                gt_num.append(gt_labels_3d[index].shape[0])
        max_gt_num = max(gt_num)

        if pts_semantic_mask is None:
            pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
            pts_instance_mask = [None for i in range(len(gt_labels_3d))]

        aggregated_points = [
            bbox_preds['aggregated_points'][i]
            for i in range(len(gt_labels_3d))
        ]

        (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
         dir_class_targets, dir_res_targets, center_targets,
         assigned_center_targets, mask_targets, objectness_targets,
         objectness_masks) = multi_apply(self.get_targets_single, points,
                                         gt_bboxes_3d, gt_labels_3d,
                                         pts_semantic_mask, pts_instance_mask,
                                         aggregated_points)

        # pad targets as original code of votenet.
        for index in range(len(gt_labels_3d)):
            pad_num = max_gt_num - gt_labels_3d[index].shape[0]
            center_targets[index] = F.pad(center_targets[index],
                                          (0, 0, 0, pad_num))
            valid_gt_masks[index] = F.pad(valid_gt_masks[index], (0, pad_num))

        vote_targets = torch.stack(vote_targets)
        vote_target_masks = torch.stack(vote_target_masks)
        center_targets = torch.stack(center_targets)
        valid_gt_masks = torch.stack(valid_gt_masks)

        assigned_center_targets = torch.stack(assigned_center_targets)
        objectness_targets = torch.stack(objectness_targets)
        objectness_weights = torch.stack(objectness_masks)
        objectness_weights /= (torch.sum(objectness_weights) + 1e-6)
        box_loss_weights = objectness_targets.float() / (
            torch.sum(objectness_targets).float() + 1e-6)
        valid_gt_weights = valid_gt_masks.float() / (
            torch.sum(valid_gt_masks.float()) + 1e-6)
        dir_class_targets = torch.stack(dir_class_targets)
        dir_res_targets = torch.stack(dir_res_targets)
        size_class_targets = torch.stack(size_class_targets)
        size_res_targets = torch.stack(size_res_targets)
        mask_targets = torch.stack(mask_targets)

        return (vote_targets, vote_target_masks, size_class_targets,
                size_res_targets, dir_class_targets, dir_res_targets,
                center_targets, assigned_center_targets, mask_targets,
                valid_gt_masks, objectness_targets, objectness_weights,
                box_loss_weights, valid_gt_weights)

    def get_targets_single(self,
                           points,
                           gt_bboxes_3d,
                           gt_labels_3d,
                           pts_semantic_mask=None,
                           pts_instance_mask=None,
                           aggregated_points=None):
        """Generate targets of vote head for single batch.

        Args:
            points (torch.Tensor): Points of each batch.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth \
                boxes of each batch.
            gt_labels_3d (torch.Tensor): Labels of each batch.
            pts_semantic_mask (None | torch.Tensor): Point-wise semantic
                label of each batch.
            pts_instance_mask (None | torch.Tensor): Point-wise instance
                label of each batch.
            aggregated_points (torch.Tensor): Aggregated points from
                vote aggregation layer.

        Returns:
            tuple[torch.Tensor]: Targets of vote head.
        """
        assert self.bbox_coder.with_rot or pts_semantic_mask is not None

        gt_bboxes_3d = gt_bboxes_3d.to(points.device)

        # generate votes target
        num_points = points.shape[0]
        if self.bbox_coder.with_rot:
            vote_targets = points.new_zeros([num_points, 3 * self.gt_per_seed])
            vote_target_masks = points.new_zeros([num_points],
                                                 dtype=torch.long)
            vote_target_idx = points.new_zeros([num_points], dtype=torch.long)
            box_indices_all = gt_bboxes_3d.points_in_boxes(points)
            for i in range(gt_labels_3d.shape[0]):
                box_indices = box_indices_all[:, i]
                indices = torch.nonzero(
                    box_indices, as_tuple=False).squeeze(-1)
                selected_points = points[indices]
                vote_target_masks[indices] = 1
                vote_targets_tmp = vote_targets[indices]
                votes = gt_bboxes_3d.gravity_center[i].unsqueeze(
                    0) - selected_points[:, :3]

                for j in range(self.gt_per_seed):
                    column_indices = torch.nonzero(
                        vote_target_idx[indices] == j,
                        as_tuple=False).squeeze(-1)
                    vote_targets_tmp[column_indices,
                                     int(j * 3):int(j * 3 +
                                                    3)] = votes[column_indices]
                    if j == 0:
                        vote_targets_tmp[column_indices] = votes[
                            column_indices].repeat(1, self.gt_per_seed)

                vote_targets[indices] = vote_targets_tmp
                vote_target_idx[indices] = torch.clamp(
                    vote_target_idx[indices] + 1, max=2)
        elif pts_semantic_mask is not None:
            vote_targets = points.new_zeros([num_points, 3])
            vote_target_masks = points.new_zeros([num_points],
                                                 dtype=torch.long)

            for i in torch.unique(pts_instance_mask):
                indices = torch.nonzero(
                    pts_instance_mask == i, as_tuple=False).squeeze(-1)
                if pts_semantic_mask[indices[0]] < self.num_classes:
                    selected_points = points[indices, :3]
                    center = 0.5 * (
                        selected_points.min(0)[0] + selected_points.max(0)[0])
                    vote_targets[indices, :] = center - selected_points
                    vote_target_masks[indices] = 1
            vote_targets = vote_targets.repeat((1, self.gt_per_seed))
        else:
            raise NotImplementedError

        (center_targets, size_class_targets, size_res_targets,
         dir_class_targets,
         dir_res_targets) = self.bbox_coder.encode(gt_bboxes_3d, gt_labels_3d)

        proposal_num = aggregated_points.shape[0]
        distance1, _, assignment, _ = chamfer_distance(
            aggregated_points.unsqueeze(0),
            center_targets.unsqueeze(0),
            reduction='none')
        assignment = assignment.squeeze(0)
        euclidean_distance1 = torch.sqrt(distance1.squeeze(0) + 1e-6)

        objectness_targets = points.new_zeros((proposal_num), dtype=torch.long)
        objectness_targets[
            euclidean_distance1 < self.train_cfg['pos_distance_thr']] = 1

        objectness_masks = points.new_zeros((proposal_num))
        objectness_masks[
            euclidean_distance1 < self.train_cfg['pos_distance_thr']] = 1.0
        objectness_masks[
            euclidean_distance1 > self.train_cfg['neg_distance_thr']] = 1.0

        dir_class_targets = dir_class_targets[assignment]
        dir_res_targets = dir_res_targets[assignment]
        dir_res_targets /= (np.pi / self.num_dir_bins)
        size_class_targets = size_class_targets[assignment]
        size_res_targets = size_res_targets[assignment]

        one_hot_size_targets = gt_bboxes_3d.tensor.new_zeros(
            (proposal_num, self.num_sizes))
        one_hot_size_targets.scatter_(1, size_class_targets.unsqueeze(-1), 1)
        one_hot_size_targets = one_hot_size_targets.unsqueeze(-1).repeat(
            1, 1, 3)
        mean_sizes = size_res_targets.new_tensor(
            self.bbox_coder.mean_sizes).unsqueeze(0)
        pos_mean_sizes = torch.sum(one_hot_size_targets * mean_sizes, 1)
        size_res_targets /= pos_mean_sizes

        mask_targets = gt_labels_3d[assignment]
        assigned_center_targets = center_targets[assignment]

        return (vote_targets, vote_target_masks, size_class_targets,
                size_res_targets, dir_class_targets,
                dir_res_targets, center_targets, assigned_center_targets,
                mask_targets.long(), objectness_targets, objectness_masks)

    def get_bboxes(self,
                   points,
                   bbox_preds,
                   input_metas,
                   rescale=False,
                   use_nms=True):
        """Generate bboxes from vote head predictions.

        Args:
            points (torch.Tensor): Input points.
            bbox_preds (dict): Predictions from vote head.
            input_metas (list[dict]): Point cloud and image's meta info.
            rescale (bool): Whether to rescale bboxes.
            use_nms (bool): Whether to apply NMS, skip nms postprocessing
                while using vote head in rpn stage.

        Returns:
            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
        """
        # decode boxes
        obj_scores = F.softmax(bbox_preds['obj_scores'], dim=-1)[..., -1]
        sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
        bbox3d = self.bbox_coder.decode(bbox_preds)

        if use_nms:
            batch_size = bbox3d.shape[0]
            results = list()
            for b in range(batch_size):
                bbox_selected, score_selected, labels = \
                    self.multiclass_nms_single(obj_scores[b], sem_scores[b],
                                               bbox3d[b], points[b, ..., :3],
                                               input_metas[b])
                bbox = input_metas[b]['box_type_3d'](
                    bbox_selected,
                    box_dim=bbox_selected.shape[-1],
                    with_yaw=self.bbox_coder.with_rot)
                results.append((bbox, score_selected, labels))

            return results
        else:
            return bbox3d

    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
                              input_meta):
        """Multi-class nms in single batch.

        Args:
            obj_scores (torch.Tensor): Objectness score of bounding boxes.
            sem_scores (torch.Tensor): semantic class score of bounding boxes.
            bbox (torch.Tensor): Predicted bounding boxes.
            points (torch.Tensor): Input points.
            input_meta (dict): Point cloud and image's meta info.

        Returns:
            tuple[torch.Tensor]: Bounding boxes, scores and labels.
        """
        bbox = input_meta['box_type_3d'](
            bbox,
            box_dim=bbox.shape[-1],
            with_yaw=self.bbox_coder.with_rot,
            origin=(0.5, 0.5, 0.5))
        box_indices = bbox.points_in_boxes(points)

        corner3d = bbox.corners
        minmax_box3d = corner3d.new(torch.Size((corner3d.shape[0], 6)))
        minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
        minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]

        nonempty_box_mask = box_indices.T.sum(1) > 5

        bbox_classes = torch.argmax(sem_scores, -1)
        nms_selected = aligned_3d_nms(minmax_box3d[nonempty_box_mask],
                                      obj_scores[nonempty_box_mask],
                                      bbox_classes[nonempty_box_mask],
                                      self.test_cfg.nms_thr)

        # filter empty boxes and boxes with low score
        scores_mask = (obj_scores > self.test_cfg.score_thr)
        nonempty_box_inds = torch.nonzero(
            nonempty_box_mask, as_tuple=False).flatten()
        nonempty_mask = torch.zeros_like(bbox_classes).scatter(
            0, nonempty_box_inds[nms_selected], 1)
        selected = (nonempty_mask.bool() & scores_mask.bool())

        if self.test_cfg.per_class_proposal:
            bbox_selected, score_selected, labels = [], [], []
            for k in range(sem_scores.shape[-1]):
                bbox_selected.append(bbox[selected].tensor)
                score_selected.append(obj_scores[selected] *
                                      sem_scores[selected][:, k])
                labels.append(
                    torch.zeros_like(bbox_classes[selected]).fill_(k))
            bbox_selected = torch.cat(bbox_selected, 0)
            score_selected = torch.cat(score_selected, 0)
            labels = torch.cat(labels, 0)
        else:
            bbox_selected = bbox[selected].tensor
            score_selected = obj_scores[selected]
            labels = bbox_classes[selected]

        return bbox_selected, score_selected, labels
Beispiel #7
0
    def __init__(self,
                 num_classes,
                 bbox_coder,
                 train_cfg=None,
                 test_cfg=None,
                 vote_moudule_cfg=None,
                 vote_aggregation_cfg=None,
                 feat_channels=(128, 128),
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None):
        super(VoteHead, self).__init__()
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_moudule_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.dir_class_loss = build_loss(dir_class_loss)
        self.dir_res_loss = build_loss(dir_res_loss)
        self.size_class_loss = build_loss(size_class_loss)
        self.size_res_loss = build_loss(size_res_loss)
        self.semantic_loss = build_loss(semantic_loss)

        assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[
            'in_channels']

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
        self.num_dir_bins = self.bbox_coder.num_dir_bins

        self.vote_module = VoteModule(**vote_moudule_cfg)
        self.vote_aggregation = PointSAModule(**vote_aggregation_cfg)

        prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
        conv_pred_list = list()
        for k in range(len(feat_channels)):
            conv_pred_list.append(
                ConvModule(prev_channel,
                           feat_channels[k],
                           1,
                           padding=0,
                           conv_cfg=conv_cfg,
                           norm_cfg=norm_cfg,
                           bias=True,
                           inplace=True))
            prev_channel = feat_channels[k]
        self.conv_pred = nn.Sequential(*conv_pred_list)

        # Objectness scores (2), center residual (3),
        # heading class+residual (num_dir_bins*2),
        # size class+residual(num_sizes*4)
        conv_out_channel = (2 + 3 + self.num_dir_bins * 2 +
                            self.num_sizes * 4 + num_classes)
        self.conv_pred.add_module('conv_out',
                                  nn.Conv1d(prev_channel, conv_out_channel, 1))