    def data_augmentation(self, pts, rois, gt_of_rois):
        """
        :param pts: (B, M, 512, 3)
        :param rois: (B, M, 7)
        :param gt_of_rois: (B, M, 7)
        :return:
        """
        batch_size, boxes_num = pts.shape[0], pts.shape[1]

        # rotation augmentation: uniform angle in [-pi/AUG_ROT_RANGE, pi/AUG_ROT_RANGE)
        angles = ((torch.rand((batch_size, boxes_num), device=pts.device) - 0.5) / 0.5) * (np.pi / cfg.AUG_ROT_RANGE)

        # calculate gt alpha from gt_of_rois
        temp_x, temp_z, temp_ry = gt_of_rois[:, :, 0], gt_of_rois[:, :, 2], gt_of_rois[:, :, 6]
        temp_beta = torch.atan2(temp_z, temp_x)
        gt_alpha = -torch.sign(temp_beta) * np.pi / 2 + temp_beta + temp_ry  # (B, M)

        temp_x, temp_z, temp_ry = rois[:, :, 0], rois[:, :, 2], rois[:, :, 6]
        temp_beta = torch.atan2(temp_z, temp_x)
        roi_alpha = -torch.sign(temp_beta) * np.pi / 2 + temp_beta + temp_ry  # (B, M)

        for k in range(batch_size):
            pts[k] = kitti_utils.rotate_pc_along_y_torch(pts[k], angles[k])
            gt_of_rois[k] = kitti_utils.rotate_pc_along_y_torch(gt_of_rois[k].unsqueeze(dim=1), angles[k]).squeeze(dim=1)
            rois[k] = kitti_utils.rotate_pc_along_y_torch(rois[k].unsqueeze(dim=1), angles[k]).squeeze(dim=1)

        # recalculate ry after rotation so that the observation angle (alpha) is preserved
        temp_x, temp_z = gt_of_rois[:, :, 0], gt_of_rois[:, :, 2]
        temp_beta = torch.atan2(temp_z, temp_x)
        gt_of_rois[:, :, 6] = torch.sign(temp_beta) * np.pi / 2 + gt_alpha - temp_beta

        temp_x, temp_z = rois[:, :, 0], rois[:, :, 2]
        temp_beta = torch.atan2(temp_z, temp_x)
        rois[:, :, 6] = torch.sign(temp_beta) * np.pi / 2 + roi_alpha - temp_beta

        # scaling augmentation
        scales = 1 + ((torch.rand((batch_size, boxes_num), device=pts.device) - 0.5) / 0.5) * 0.05
        pts = pts * scales.unsqueeze(dim=2).unsqueeze(dim=3)
        gt_of_rois[:, :, 0:6] = gt_of_rois[:, :, 0:6] * scales.unsqueeze(dim=2)
        rois[:, :, 0:6] = rois[:, :, 0:6] * scales.unsqueeze(dim=2)

        # flip augmentation
        flip_flag = torch.sign(torch.rand((batch_size, boxes_num), device=pts.device) - 0.5)
        pts[:, :, :, 0] = pts[:, :, :, 0] * flip_flag.unsqueeze(dim=2)
        gt_of_rois[:, :, 0] = gt_of_rois[:, :, 0] * flip_flag
        # flip orientation: ry > 0: pi - ry, ry < 0: -pi - ry
        src_ry = gt_of_rois[:, :, 6]
        ry = (flip_flag == 1).float() * src_ry + (flip_flag == -1).float() * (torch.sign(src_ry) * np.pi - src_ry)
        gt_of_rois[:, :, 6] = ry

        rois[:, :, 0] = rois[:, :, 0] * flip_flag
        # flip orientation: ry > 0: pi - ry, ry < 0: -pi - ry
        src_ry = rois[:, :, 6]
        ry = (flip_flag == 1).float() * src_ry + (flip_flag == -1).float() * (torch.sign(src_ry) * np.pi - src_ry)
        rois[:, :, 6] = ry

        return pts, rois, gt_of_rois
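Every snippet on this page leans on kitti_utils.rotate_pc_along_y_torch. A minimal sketch of what such a helper can look like (an assumption for readability, not the project's exact code), for pc of shape (N, M, 3 + C) with x in column 0 and z in column 2, and per-row angles of shape (N,):

import torch

def rotate_pc_along_y_torch(pc, rot_angle):
    """Rotate each row of points around the camera y-axis by its own angle."""
    cosa = torch.cos(rot_angle).view(-1, 1)  # (N, 1)
    sina = torch.sin(rot_angle).view(-1, 1)  # (N, 1)
    row_1 = torch.cat([cosa, -sina], dim=1)  # (N, 2)
    row_2 = torch.cat([sina, cosa], dim=1)   # (N, 2)
    R = torch.cat([row_1.unsqueeze(1), row_2.unsqueeze(1)], dim=1)  # (N, 2, 2)
    # rotate the (x, z) columns in place, one rotation matrix per row
    pc[:, :, [0, 2]] = torch.matmul(pc[:, :, [0, 2]], R.permute(0, 2, 1))
    return pc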

    @staticmethod
    def canonical_transform_batch(pts_input, roi_boxes3d, gt_boxes3d):
        """
        :param pts_input: (N, npoints, 3 + C)
        :param roi_boxes3d: (N, 7)
        :param gt_boxes3d: (N, 7)
        :return:
        """
        roi_ry = roi_boxes3d[:, 6] % (2 * np.pi)  # 0 ~ 2pi
        roi_center = roi_boxes3d[:, 0:3]
        # shift to center
        pts_input[:, :, [0, 1, 2]] = pts_input[:, :, [0, 1, 2]] - roi_center.reshape(-1, 1, 3)
        gt_boxes3d_ct = np.copy(gt_boxes3d)
        gt_boxes3d_ct[:, 0:3] = gt_boxes3d_ct[:, 0:3] - roi_center
        # rotate to the direction of head
        gt_boxes3d_ct = kitti_utils.rotate_pc_along_y_torch(torch.from_numpy(gt_boxes3d_ct.reshape(-1, 1, 7)),
                                                            torch.from_numpy(roi_ry)).numpy().reshape(-1, 7)
        gt_boxes3d_ct[:, 6] = gt_boxes3d_ct[:, 6] - roi_ry
        pts_input = kitti_utils.rotate_pc_along_y_torch(torch.from_numpy(pts_input), torch.from_numpy(roi_ry)).numpy()

        return pts_input, gt_boxes3d_ct
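A quick smoke test for canonical_transform_batch; RCNNTargetLayer and the tiny shapes are hypothetical, chosen only to pin down the expected input/output contract:

import numpy as np

pts = np.random.randn(4, 16, 3).astype(np.float32)   # 4 ROIs, 16 pooled points each
rois = np.random.randn(4, 7).astype(np.float32)      # (x, y, z, h, w, l, ry) per ROI
gts = np.random.randn(4, 7).astype(np.float32)
pts_ct, gt_ct = RCNNTargetLayer.canonical_transform_batch(pts, rois, gts)
assert pts_ct.shape == (4, 16, 3) and gt_ct.shape == (4, 7)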
Example #3
    def roipooling(self, input_data):
        rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data[
            'rpn_features']
        batch_rois = input_data['roi_boxes3d']
        # when running the RCNN standalone (no RPN at eval time), inputs arrive
        # without a batch dimension, so concatenate along the last axis instead
        if not self.training and cfg.RCNN.ENABLED and not cfg.RPN.ENABLED:
            dd = -1
        else:
            dd = 2
        if cfg.RCNN.USE_INTENSITY:
            pts_extra_input_list = [
                input_data['rpn_intensity'].unsqueeze(dim=dd),
                input_data['seg_mask'].unsqueeze(dim=dd)
            ]
        else:
            pts_extra_input_list = [input_data['seg_mask'].unsqueeze(dim=dd)]

        if cfg.RCNN.USE_DEPTH:
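            # normalize depth to roughly [-0.5, 0.5], assuming ~70 m maximum range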
            pts_depth = input_data['pts_depth'] / 70.0 - 0.5
            pts_extra_input_list.append(pts_depth.unsqueeze(dim=dd))
        pts_extra_input = torch.cat(pts_extra_input_list, dim=dd)

        pts_feature = torch.cat((pts_extra_input, rpn_features), dim=dd)
        if not self.training and cfg.RCNN.ENABLED and not cfg.RPN.ENABLED:
            batch_rois = torch.squeeze(batch_rois, 1)
            rpn_xyz = rpn_xyz.unsqueeze(dim=0)
            pts_feature = pts_feature.unsqueeze(dim=0)
            batch_rois = batch_rois.unsqueeze(dim=0)
        pooled_features, pooled_empty_flag = \
            roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                          sampled_pt_num=cfg.RCNN.NUM_POINTS)

        # canonical transformation

        batch_size = batch_rois.shape[0]
        roi_center = batch_rois[:, :, 0:3]
        pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
        for k in range(batch_size):
            pooled_features[k, :, :, 0:3] = kitti_utils.rotate_pc_along_y_torch(
                pooled_features[k, :, :, 0:3], batch_rois[k, :, 6])

        pts_input = pooled_features.view(-1, pooled_features.shape[2],
                                         pooled_features.shape[3])
        return pts_input
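roipool3d_utils.roipool3d_gpu is a CUDA kernel; a toy CPU stand-in (an assumption: axis-aligned containment only, ignoring the box heading and POOL_EXTRA_WIDTH) that reproduces its (B, M, num_sample, 3 + C) / (B, M) output contract could look like:

import torch

def roipool3d_sketch(xyz, feats, boxes, num_sample=512):
    """xyz: (B, N, 3), feats: (B, N, C), boxes: (B, M, 7) as (x, y, z, h, w, l, ry).
    Returns pooled points+features (B, M, num_sample, 3 + C) and an empty flag (B, M)."""
    B, M, C = xyz.shape[0], boxes.shape[1], feats.shape[2]
    pooled = xyz.new_zeros(B, M, num_sample, 3 + C)
    empty = torch.zeros(B, M, dtype=torch.long)
    for b in range(B):
        for m in range(M):
            cx, cy, cz, h, w, l = boxes[b, m, :6]
            inside = ((xyz[b, :, 0] - cx).abs() < l / 2) & \
                     ((xyz[b, :, 2] - cz).abs() < w / 2) & \
                     (xyz[b, :, 1] < cy) & (xyz[b, :, 1] > cy - h)  # y is the bottom face
            idx = inside.nonzero(as_tuple=False).view(-1)
            if idx.numel() == 0:
                empty[b, m] = 1
                continue
            # sample with replacement up to the fixed point budget
            choice = idx[torch.randint(idx.numel(), (num_sample,))]
            pooled[b, m] = torch.cat([xyz[b, choice], feats[b, choice]], dim=1)
    return pooled, empty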
Example #4
    def forward(self, input_dict):
        roi_boxes3d, gt_boxes3d = input_dict['roi_boxes3d'], input_dict[
            'gt_boxes3d']

        batch_rois, batch_gt_of_rois, batch_roi_iou = self.sample_rois_for_rcnn(
            roi_boxes3d, gt_boxes3d)

        rpn_xyz, rpn_features = input_dict['rpn_xyz'], input_dict[
            'rpn_features']
        if cfg.RCNN.USE_INTENSITY:
            pts_extra_input_list = [
                input_dict['rpn_intensity'].unsqueeze(dim=2),
                input_dict['seg_mask'].unsqueeze(dim=2)
            ]
        else:
            pts_extra_input_list = [input_dict['seg_mask'].unsqueeze(dim=2)]

        if cfg.RCNN.USE_DEPTH:
            pts_depth = input_dict['pts_depth'] / 70.0 - 0.5
            pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))

        if cfg.RCNN.USE_RGB:
            pts_rgb = input_dict['pts_rgb']
            pts_extra_input_list.append(pts_rgb)

        pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

        # point cloud pooling
        pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
        pooled_features, pooled_empty_flag = \
            roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                          sampled_pt_num = cfg.RCNN.NUM_POINTS)

        sampled_pts = pooled_features[:, :, :, 0:3]
        sampled_features = pooled_features[:, :, :, 3:]
        # channel 3 holds the seg_mask (intensity disabled here), so this is the
        # fraction of segmented-foreground points inside each ROI
        mask_score = pooled_features[:, :, :, 3].sum(-1) / cfg.RCNN.NUM_POINTS

        # data augmentation
        if cfg.AUG_DATA:
            sampled_pts, batch_rois, batch_gt_of_rois = \
                self.data_augmentation(sampled_pts, batch_rois, batch_gt_of_rois)

        # canonical transformation
        batch_size = batch_rois.shape[0]
        roi_ry = batch_rois[:, :, 6] % (2 * np.pi)
        roi_center = batch_rois[:, :, 0:3]
        sampled_pts = sampled_pts - roi_center.unsqueeze(
            dim=2)  # (B, M, 512, 3)
        batch_gt_of_rois[:, :, 0:3] = batch_gt_of_rois[:, :, 0:3] - roi_center
        batch_gt_of_rois[:, :, 6] = batch_gt_of_rois[:, :, 6] - roi_ry

        for k in range(batch_size):
            sampled_pts[k] = kitti_utils.rotate_pc_along_y_torch(
                sampled_pts[k], batch_rois[k, :, 6])
            batch_gt_of_rois[k] = kitti_utils.rotate_pc_along_y_torch(
                batch_gt_of_rois[k].unsqueeze(dim=1), roi_ry[k]).squeeze(dim=1)

        # regression valid mask
        valid_mask = (pooled_empty_flag == 0)
        reg_valid_mask = ((batch_roi_iou > cfg.RCNN.REG_FG_THRESH)
                          & valid_mask).long()

        # classification label
        batch_cls_label = (batch_roi_iou > cfg.RCNN.CLS_FG_THRESH).long()
        invalid_mask = (batch_roi_iou > cfg.RCNN.CLS_BG_THRESH) & (
            batch_roi_iou < cfg.RCNN.CLS_FG_THRESH)
        batch_cls_label[valid_mask == 0] = -1
        batch_cls_label[invalid_mask > 0] = -1

        output_dict = {
            'sampled_pts': sampled_pts.view(-1, cfg.RCNN.NUM_POINTS, 3),
            'pts_feature': sampled_features.view(-1, cfg.RCNN.NUM_POINTS, sampled_features.shape[3]),
            'cls_label': batch_cls_label.view(-1),
            'mask_score': mask_score.view(-1),
            'reg_valid_mask': reg_valid_mask.view(-1),
            'gt_of_rois': batch_gt_of_rois.view(-1, 7),
            'gt_iou': batch_roi_iou.view(-1),
            'roi_boxes3d': batch_rois.view(-1, 7)
        }

        return output_dict
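The labeling rule above on a concrete IoU vector, using the threshold values quoted in the comments later on this page (CLS_BG_THRESH = 0.45, CLS_FG_THRESH = 0.6):

import torch

iou = torch.tensor([0.70, 0.50, 0.30, 0.62])
cls_label = (iou > 0.6).long()           # [1, 0, 0, 1]
ambiguous = (iou > 0.45) & (iou < 0.6)   # only 0.50 falls in the ignore band
cls_label[ambiguous] = -1                # [1, -1, 0, 1]; -1 entries are masked out of the loss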
Example #5
    def forward(self, input_data):
        """
        :param input_data: input dict
        :return:
        """
        if cfg.RCNN.ROI_SAMPLE_JIT:
            if self.training:
                with torch.no_grad():
                    target_dict = self.proposal_target_layer(input_data)

                pts_input = torch.cat((target_dict['sampled_pts'], target_dict['pts_feature']), dim=2)
                target_dict['pts_input'] = pts_input
            else:
                rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data['rpn_features']
                batch_rois = input_data['roi_boxes3d']
                if cfg.RCNN.USE_INTENSITY:
                    pts_extra_input_list = [input_data['rpn_intensity'].unsqueeze(dim=2),
                                            input_data['seg_mask'].unsqueeze(dim=2)]
                else:
                    pts_extra_input_list = [input_data['seg_mask'].unsqueeze(dim=2)]

                if cfg.RCNN.USE_DEPTH:
                    pts_depth = input_data['pts_depth'] / 70.0 - 0.5
                    pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
                pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

                pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
                pooled_features, pooled_empty_flag = \
                        roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                                      sampled_pt_num=cfg.RCNN.NUM_POINTS)

                # canonical transformation
                batch_size = batch_rois.shape[0]
                roi_center = batch_rois[:, :, 0:3]
                pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
                for k in range(batch_size):
                    pooled_features[k, :, :, 0:3] = kitti_utils.rotate_pc_along_y_torch(pooled_features[k, :, :, 0:3],
                                                                                        batch_rois[k, :, 6])

                pts_input = pooled_features.view(-1, pooled_features.shape[2], pooled_features.shape[3])
        else:
            pts_input = input_data['pts_input']
            target_dict = {}
            target_dict['pts_input'] = input_data['pts_input']
            target_dict['roi_boxes3d'] = input_data['roi_boxes3d']
            if self.training:
                target_dict['cls_label'] = input_data['cls_label']
                target_dict['reg_valid_mask'] = input_data['reg_valid_mask']
                target_dict['gt_of_rois'] = input_data['gt_boxes3d_ct']

        xyz, features = self._break_up_pc(pts_input)

        if cfg.RCNN.USE_RPN_FEATURES:  # True in the default config
            xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(1, 2).unsqueeze(dim=3)
            xyz_feature = self.xyz_up_layer(xyz_input)

            rpn_feature = pts_input[..., self.rcnn_input_channel:].transpose(1, 2).unsqueeze(dim=3)

            merged_feature = torch.cat((xyz_feature, rpn_feature), dim=1)
            merged_feature = self.merge_down_layer(merged_feature)
            l_xyz, l_features = [xyz], [merged_feature.squeeze(dim=3)]
        else:
            l_xyz, l_features = [xyz], [features]

        for i in range(len(self.SA_modules)):
            li_xyz, li_features, _ = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        rcnn_cls = self.cls_layer(l_features[-1]).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, 1 or 2)
        rcnn_reg = self.reg_layer(l_features[-1]).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, C)
        if cfg.USE_IOU_BRANCH:
            rcnn_iou_branch = self.iou_branch(l_features[-1]).transpose(1, 2).contiguous().squeeze(dim=1)  # (B,1)
            ret_dict = {'rcnn_cls': rcnn_cls, 'rcnn_reg': rcnn_reg, 'rcnn_iou_branch': rcnn_iou_branch}
        else:
            ret_dict = {'rcnn_cls': rcnn_cls, 'rcnn_reg': rcnn_reg}

        if self.training:
            ret_dict.update(target_dict)
        return ret_dict
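Every forward pass here also assumes self._break_up_pc. A sketch consistent with how its outputs are consumed above (xyz kept as (B, N, 3), features turned channels-first); treat it as an assumption rather than the project's exact helper:

def _break_up_pc(self, pc):
    # split (B, N, 3 + C) into xyz coordinates and a channels-first feature tensor
    xyz = pc[..., 0:3].contiguous()                 # (B, N, 3)
    features = (pc[..., 3:].transpose(1, 2).contiguous()
                if pc.size(-1) > 3 else None)       # (B, C, N)
    return xyz, features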
Example #6
    def forward(self, input_data):
        """
        :param input_data: input dict
        :return:
        """
        if cfg.RCNN.ROI_SAMPLE_JIT:
            if self.training:
                with torch.no_grad():
                    target_dict = self.proposal_target_layer(input_data)
                """
                target_dict:
                    sampled_pts [B*64, 512, 3]
                    pts_feature [B*64, 512, 130]
                    cls_label [B*64]
                    reg_valid_mask [B*64]
                    gt_of_rois [B*64, 7]
                    gt_iou [B*64]
                    roi_boxes3d [B*64, 7]
                """
                pts_input = torch.cat((target_dict['sampled_pts'], target_dict['pts_feature']), dim=2)
                target_dict['pts_input'] = pts_input
            else:
                rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data['rpn_features']
                batch_rois = input_data['roi_boxes3d']
                if cfg.RCNN.USE_INTENSITY:
                    pts_extra_input_list = [input_data['rpn_intensity'].unsqueeze(dim=2),
                                            input_data['seg_mask'].unsqueeze(dim=2)]
                else:
                    pts_extra_input_list = [input_data['seg_mask'].unsqueeze(dim=2)]
                if cfg.RCNN.USE_DEPTH:
                    pts_depth = input_data['pts_depth'] / 70.0 - 0.5
                    if cfg.RCNN.USE_MAX_DENSITY and cfg.DA.INS.ENABLED:
                        # per-domain normalization: 70 m for source frames (is_source == 1), 40 m for target
                        max_depth = 30.0 * input_data['is_source'] + 40.0
                        pts_depth = input_data['pts_depth'] / max_depth - 0.5
                    pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
                pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

                pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
                pooled_features, pooled_empty_flag = \
                        roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                                      sampled_pt_num=cfg.RCNN.NUM_POINTS)

                # canonical transformation
                batch_size = batch_rois.shape[0]
                roi_center = batch_rois[:, :, 0:3]
                pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
                for k in range(batch_size):
                    pooled_features[k, :, :, 0:3] = kitti_utils.rotate_pc_along_y_torch(pooled_features[k, :, :, 0:3],
                                                                                        batch_rois[k, :, 6])

                pts_input = pooled_features.view(-1, pooled_features.shape[2], pooled_features.shape[3])
        else:
            pts_input = input_data['pts_input']
            target_dict = {}
            target_dict['pts_input'] = input_data['pts_input']
            target_dict['roi_boxes3d'] = input_data['roi_boxes3d']
            if self.training:
                target_dict['cls_label'] = input_data['cls_label']
                target_dict['reg_valid_mask'] = input_data['reg_valid_mask']
                target_dict['gt_of_rois'] = input_data['gt_boxes3d_ct']

        xyz, features = self._break_up_pc(pts_input) # (64, 512, 3), (64,130,512)

        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(1, 2).unsqueeze(dim=3)  # (B*64, 5, 512, 1)
            xyz_feature = self.xyz_up_layer(xyz_input)  # (B*64, 128, 512, 1)

            rpn_feature = pts_input[..., self.rcnn_input_channel:].transpose(1, 2).unsqueeze(dim=3)  # (B*64, 128, 512, 1)

            merged_feature = torch.cat((xyz_feature, rpn_feature), dim=1)  # (B*64, 256, 512, 1)
            merged_feature = self.merge_down_layer(merged_feature)  # (B*64, 128, 512, 1)
            l_xyz, l_features = [xyz], [merged_feature.squeeze(dim=3)]
        else:
            l_xyz, l_features = [xyz], [features]

        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)
        #l_features: [B*64, 128, 512] [B*64, 128, 128] [B*64, 256, 32] [B*64, 512, 1]
        rcnn_cls = self.cls_layer(l_features[-1]).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, 1 or 2)
        rcnn_reg = self.reg_layer(l_features[-1]).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, C)
        ret_dict = {'rcnn_cls': rcnn_cls, 'rcnn_reg': rcnn_reg}
        if self.training:
            target_dict['l_features'] = l_features[-1] # For DA
            ret_dict.update(target_dict)
        return ret_dict
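Example #7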
    def forward(self, input_data):
        """
        :param input_data: input dict
        :return:
        """
        if cfg.RCNN.ROI_SAMPLE_JIT:
            if self.training:
                with torch.no_grad():
                    target_dict = self.proposal_target_layer(input_data,
                                                             stage=1)

                pts_input = torch.cat(
                    (target_dict['sampled_pts'], target_dict['pts_feature']),
                    dim=2)
                target_dict['pts_input'] = pts_input
            else:
                rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data[
                    'rpn_features']
                batch_rois = input_data['roi_boxes3d']
                if cfg.RCNN.USE_INTENSITY:
                    pts_extra_input_list = [
                        input_data['rpn_intensity'].unsqueeze(dim=2),
                        input_data['seg_mask'].unsqueeze(dim=2)
                    ]
                else:
                    pts_extra_input_list = [
                        input_data['seg_mask'].unsqueeze(dim=2)
                    ]

                if cfg.RCNN.USE_DEPTH:
                    pts_depth = input_data['pts_depth'] / 70.0 - 0.5
                    pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
                pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

                pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
                pooled_features, pooled_empty_flag = \
                        roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                                      sampled_pt_num=cfg.RCNN.NUM_POINTS)

                # canonical transformation
                batch_size = batch_rois.shape[0]
                roi_center = batch_rois[:, :, 0:3]
                pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
                for k in range(batch_size):
                    pooled_features[k, :, :,
                                    0:3] = kitti_utils.rotate_pc_along_y_torch(
                                        pooled_features[k, :, :, 0:3],
                                        batch_rois[k, :, 6])

                pts_input = pooled_features.view(-1, pooled_features.shape[2],
                                                 pooled_features.shape[3])
        else:
            pts_input = input_data['pts_input']
            target_dict = {}
            target_dict['pts_input'] = input_data['pts_input']
            target_dict['roi_boxes3d'] = input_data['roi_boxes3d']
            if self.training:
                target_dict['cls_label'] = input_data['cls_label']
                target_dict['reg_valid_mask'] = input_data['reg_valid_mask']
                target_dict['gt_of_rois'] = input_data['gt_boxes3d_ct']

        xyz, features = self._break_up_pc(pts_input)

        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(
                1, 2).unsqueeze(dim=3)
            xyz_feature = self.xyz_up_layer(xyz_input)

            rpn_feature = pts_input[..., self.rcnn_input_channel:].transpose(
                1, 2).unsqueeze(dim=3)

            merged_feature = torch.cat((xyz_feature, rpn_feature), dim=1)
            merged_feature = self.merge_down_layer(merged_feature)
            l_xyz, l_features = [xyz], [merged_feature.squeeze(dim=3)]
        else:
            l_xyz, l_features = [xyz], [features]

        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        batch_size = input_data['roi_boxes3d'].size(0)
        rcnn_cls = self.cls_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg = self.reg_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        if self.training:
            roi_boxes3d = target_dict['roi_boxes3d'].view(-1, 7)
        else:
            # at eval time no target_dict is built, so decode against the input ROIs
            roi_boxes3d = input_data['roi_boxes3d'].view(-1, 7)
        anchor_size = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
        pred_boxes3d_1st = decode_bbox_target(
            roi_boxes3d.view(-1, 7),
            rcnn_reg.view(-1, rcnn_reg.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        ## 2nd stage: re-sample targets around the first-stage predictions
        input_data2 = input_data  # note: an alias, not a copy; input_data is mutated in place
        input_data2['roi_boxes3d'] = pred_boxes3d_1st
        with torch.no_grad():
            target_dict_2nd = self.proposal_target_layer(input_data2, stage=2)

        pts_input_2 = torch.cat(
            (target_dict_2nd['sampled_pts'], target_dict_2nd['pts_feature']),
            dim=2)
        target_dict_2nd['pts_input'] = pts_input_2

        xyz_2, features_2 = self._break_up_pc(pts_input_2)
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_2 = pts_input_2[...,
                                      0:self.rcnn_input_channel].transpose(
                                          1, 2).unsqueeze(dim=3)
            xyz_feature_2 = self.xyz_up_layer(xyz_input_2)

            rpn_feature_2 = pts_input_2[...,
                                        self.rcnn_input_channel:].transpose(
                                            1, 2).unsqueeze(dim=3)

            merged_feature_2 = torch.cat((xyz_feature_2, rpn_feature_2), dim=1)
            merged_feature_2 = self.merge_down_layer(merged_feature_2)
            l_xyz_2, l_features_2 = [xyz_2], [merged_feature_2.squeeze(dim=3)]
        else:
            l_xyz_2, l_features_2 = [xyz_2], [features_2]
        for i in range(len(self.SA_modules)):
            li_xyz_2, li_features_2 = self.SA_modules[i](l_xyz_2[i],
                                                         l_features_2[i])
            l_xyz_2.append(li_xyz_2)
            l_features_2.append(li_features_2)
        del xyz, features, l_features

        rcnn_cls_2nd = self.cls_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_2nd = self.reg_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        sec = {'rcnn_cls_2nd': rcnn_cls_2nd, 'rcnn_reg_2nd': rcnn_reg_2nd}
        pred_boxes3d_2nd = decode_bbox_target(
            pred_boxes3d_1st.view(-1, 7),
            rcnn_reg_2nd.view(-1, rcnn_reg_2nd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        ## 3rd stage: repeat the refinement with the second-stage boxes
        input_data['roi_boxes3d'] = pred_boxes3d_2nd
        with torch.no_grad():
            target_dict_3rd = self.proposal_target_layer(input_data, stage=3)
        pts_input_3 = torch.cat(
            (target_dict_3rd['sampled_pts'], target_dict_3rd['pts_feature']),
            dim=2)
        target_dict_3rd['pts_input'] = pts_input_3
        xyz_3, features_3 = self._break_up_pc(pts_input_3)

        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_3 = pts_input_3[..., 0:self.rcnn_input_channel].transpose(
                1, 2).unsqueeze(dim=3)
            xyz_feature_3 = self.xyz_up_layer(xyz_input_3)

            rpn_feature_3 = pts_input_3[...,
                                        self.rcnn_input_channel:].transpose(
                                            1, 2).unsqueeze(dim=3)

            merged_feature_3 = torch.cat((xyz_feature_3, rpn_feature_3), dim=1)
            merged_feature_3 = self.merge_down_layer(merged_feature_3)
            l_xyz_3, l_features_3 = [xyz_3], [merged_feature_3.squeeze(dim=3)]
        else:
            l_xyz_3, l_features_3 = [xyz_3], [features_3]

        for i in range(len(self.SA_modules)):
            li_xyz_3, li_features_3 = self.SA_modules[i](l_xyz_3[i],
                                                         l_features_3[i])
            l_xyz_3.append(li_xyz_3)
            l_features_3.append(li_features_3)
        del xyz_2, features_2, l_features_2
        rcnn_cls_3rd = self.cls_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_3rd = self.reg_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        pred_boxes3d_3rd = decode_bbox_target(
            pred_boxes3d_2nd.view(-1, 7),
            rcnn_reg_3rd.view(-1, rcnn_reg_3rd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)
        ret_dict = {
            'rcnn_cls': rcnn_cls,
            'rcnn_reg': rcnn_reg,
            'rcnn_cls_3rd': rcnn_cls_3rd,
            'rcnn_reg_3rd': rcnn_reg_3rd,
            'pred_boxes3d_1st': pred_boxes3d_1st,
            'pred_boxes3d_2nd': pred_boxes3d_2nd,
            'pred_boxes3d_3rd': pred_boxes3d_3rd
        }
        ret_dict.update(sec)
        if self.training:
            ret_dict.update(target_dict)
        return ret_dict
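Stripped of pooling and loss bookkeeping, the three-stage cascade above follows one pattern: each stage regresses residuals against the previous stage's boxes, and decode_bbox_target turns them into the next stage's proposals. A runnable toy of that control flow (extract_roi_features and decode_boxes are hypothetical stand-ins, not the project's API):

import torch

def extract_roi_features(rois):   # stands in for pooling + canonical transform + SA modules
    return rois.new_zeros(rois.shape[0], rois.shape[1], 512)

def decode_boxes(rois, reg):      # stands in for decode_bbox_target
    return rois + reg[..., :7]

rois = torch.randn(2, 64, 7)      # stage-0 proposals, (B, M, 7)
for stage in range(3):            # 1st, 2nd and 3rd RCNN stages
    feats = extract_roi_features(rois)
    reg = feats[..., :7]          # dummy regression head output
    rois = decode_boxes(rois, reg)  # refined boxes become the next stage's proposals
print(rois.shape)                 # torch.Size([2, 64, 7])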
Example #8
    def forward(self, input_data):
        """
        :param input_data: input dict
        :return:
        """

        if cfg.RCNN.ROI_SAMPLE_JIT:
            if self.training:
                with torch.no_grad():
                    target_dict = self.proposal_target_layer(input_data,
                                                             stage=1)

                pts_input = torch.cat(
                    (target_dict['sampled_pts'], target_dict['pts_feature']),
                    dim=2)
                target_dict['pts_input'] = pts_input
            else:
                rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data[
                    'rpn_features']
                batch_rois = input_data['roi_boxes3d']
                if cfg.RCNN.USE_INTENSITY:
                    pts_extra_input_list = [
                        input_data['rpn_intensity'].unsqueeze(dim=2),
                        input_data['seg_mask'].unsqueeze(dim=2)
                    ]
                else:
                    pts_extra_input_list = [
                        input_data['seg_mask'].unsqueeze(dim=2)
                    ]

                if cfg.RCNN.USE_DEPTH:
                    pts_depth = input_data['pts_depth'] / 70.0 - 0.5
                    pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
                pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

                pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
                pooled_features, pooled_empty_flag = \
                        roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                                      sampled_pt_num=cfg.RCNN.NUM_POINTS)

                # canonical transformation
                batch_size = batch_rois.shape[0]
                roi_center = batch_rois[:, :, 0:3]
                pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
                for k in range(batch_size):
                    pooled_features[k, :, :,
                                    0:3] = kitti_utils.rotate_pc_along_y_torch(
                                        pooled_features[k, :, :, 0:3],
                                        batch_rois[k, :, 6])

                pts_input = pooled_features.view(-1, pooled_features.shape[2],
                                                 pooled_features.shape[3])
        else:
            pts_input = input_data['pts_input']
            target_dict = {}
            target_dict['pts_input'] = input_data['pts_input']
            target_dict['roi_boxes3d'] = input_data['roi_boxes3d']
            if self.training:
                target_dict['cls_label'] = input_data['cls_label']
                target_dict['reg_valid_mask'] = input_data[
                    'reg_valid_mask'].view(-1)
                target_dict['gt_of_rois'] = input_data['gt_boxes3d_ct']
        # flatten to (num_rois, 512 points, self.rcnn_input_channel + 128 feature channels)
        pts_input = pts_input.view(-1, 512, 128 + self.rcnn_input_channel)
        xyz, features = self._break_up_pc(pts_input)
        anchor_size = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(
                1, 2).unsqueeze(dim=3)

            xyz_feature = self.xyz_up_layer(xyz_input)

            rpn_feature = pts_input[..., self.rcnn_input_channel:].transpose(
                1, 2).unsqueeze(dim=3)

            merged_feature = torch.cat((xyz_feature, rpn_feature), dim=1)
            merged_feature = self.merge_down_layer(merged_feature)
            l_xyz, l_features = [xyz], [merged_feature.squeeze(dim=3)]
        else:
            l_xyz, l_features = [xyz], [features]

        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        batch_size = input_data['roi_boxes3d'].size(0)
        batch_size_2 = pts_input.shape[0]  # flattened ROI count, used by the losses below
        rcnn_cls = self.cls_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg = self.reg_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        if self.training:
            roi_boxes3d = target_dict['roi_boxes3d'].view(-1, 7)
            cls_label = target_dict['cls_label'].float()
            rcnn_cls_flat = rcnn_cls.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat),
                cls_label.view(-1),
                reduction='none')
            cls_label_flat = cls_label.view(-1)
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict['gt_of_rois']
            reg_valid_mask = target_dict['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg.view(batch_size_2, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size

            one = {
                'rcnn_loss_cls': rcnn_loss_cls,
                'rcnn_loss_reg': rcnn_loss_reg
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask
        else:
            roi_boxes3d = input_data['roi_boxes3d'].view(-1, 7)
            one = {}
        pred_boxes3d_1st = decode_bbox_target(
            roi_boxes3d.view(-1, 7),
            rcnn_reg.view(-1, rcnn_reg.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)
        if not self.training and cfg.RCNN.ENABLED and not cfg.RPN.ENABLED:
            pred_boxes3d_1st = pred_boxes3d_1st.view(-1, 7)

        input_data2 = input_data.copy()

        if self.training:
            # widen the 2nd-stage proposal pool with the original ROIs so that
            # enough foreground samples survive the IoU-based re-sampling
            input_data2['roi_boxes3d'] = torch.cat(
                (pred_boxes3d_1st, input_data['roi_boxes3d']), 1)
            with torch.no_grad():
                target_dict_2nd = self.proposal_target_layer(input_data2,
                                                             stage=2)
            pts_input_2 = torch.cat((target_dict_2nd['sampled_pts'],
                                     target_dict_2nd['pts_feature']),
                                    dim=2)
            target_dict_2nd['pts_input'] = pts_input_2
            roi = target_dict_2nd['roi_boxes3d']

        else:
            input_data2['roi_boxes3d'] = pred_boxes3d_1st
            roi = pred_boxes3d_1st
            pts_input_2 = self.roipooling(input_data2)
        xyz_2, features_2 = self._break_up_pc(pts_input_2)
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_2 = pts_input_2[...,
                                      0:self.rcnn_input_channel].transpose(
                                          1, 2).unsqueeze(dim=3)
            xyz_feature_2 = self.xyz_up_layer(xyz_input_2)

            rpn_feature_2 = pts_input_2[...,
                                        self.rcnn_input_channel:].transpose(
                                            1, 2).unsqueeze(dim=3)

            merged_feature_2 = torch.cat((xyz_feature_2, rpn_feature_2), dim=1)
            merged_feature_2 = self.merge_down_layer(merged_feature_2)
            l_xyz_2, l_features_2 = [xyz_2], [merged_feature_2.squeeze(dim=3)]
        else:
            l_xyz_2, l_features_2 = [xyz_2], [features_2]
        for i in range(len(self.SA_modules)):
            li_xyz_2, li_features_2 = self.SA_modules[i](l_xyz_2[i],
                                                         l_features_2[i])
            l_xyz_2.append(li_xyz_2)
            l_features_2.append(li_features_2)
        del xyz, features, l_features

        rcnn_cls_2nd = self.cls_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_2nd = self.reg_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        # 2nd-stage loss (training only)
        if self.training:
            cls_label = target_dict_2nd['cls_label'].float()
            rcnn_cls_flat = rcnn_cls_2nd.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat),
                cls_label.view(-1),
                reduction='none')
            cls_label_flat = cls_label.view(-1)
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict_2nd['gt_of_rois']
            reg_valid_mask = target_dict_2nd['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            # if IoU re-sampling produced no foreground ROI, fall back to background
            # ROIs so the regression loss still sees a non-empty batch
            if rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask].size(0) == 0:
                fg_mask = (reg_valid_mask <= 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size

            two = {
                'rcnn_loss_cls_2nd': rcnn_loss_cls,
                'rcnn_loss_reg_2nd': rcnn_loss_reg
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask

        else:
            two = {}

        sec = {'rcnn_cls_2nd': rcnn_cls_2nd, 'rcnn_reg_2nd': rcnn_reg_2nd}

        pred_boxes3d_2nd = decode_bbox_target(
            roi.view(-1, 7),
            rcnn_reg_2nd.view(-1, rcnn_reg_2nd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        ## 3rd stage: repeat the refinement with the second-stage boxes
        input_data3 = input_data2.copy()

        if self.training:
            input_data3['roi_boxes3d'] = torch.cat(
                (pred_boxes3d_2nd, input_data2['roi_boxes3d']), 1)
            with torch.no_grad():
                target_dict_3rd = self.proposal_target_layer(input_data3,
                                                             stage=3)
            pts_input_3 = torch.cat((target_dict_3rd['sampled_pts'],
                                     target_dict_3rd['pts_feature']),
                                    dim=2)
            target_dict_3rd['pts_input'] = pts_input_3
            roi = target_dict_3rd['roi_boxes3d']
        else:
            input_data3['roi_boxes3d'] = pred_boxes3d_2nd
            roi = pred_boxes3d_2nd
            pts_input_3 = self.roipooling(input_data3)
        xyz_3, features_3 = self._break_up_pc(pts_input_3)

        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_3 = pts_input_3[...,
                                      0:self.rcnn_input_channel].transpose(
                                          1, 2).unsqueeze(dim=3)
            xyz_feature_3 = self.xyz_up_layer(xyz_input_3)

            rpn_feature_3 = pts_input_3[...,
                                        self.rcnn_input_channel:].transpose(
                                            1, 2).unsqueeze(dim=3)

            merged_feature_3 = torch.cat((xyz_feature_3, rpn_feature_3), dim=1)
            merged_feature_3 = self.merge_down_layer(merged_feature_3)
            l_xyz_3, l_features_3 = [xyz_3], [merged_feature_3.squeeze(dim=3)]
        else:
            l_xyz_3, l_features_3 = [xyz_3], [features_3]

        for i in range(len(self.SA_modules)):
            li_xyz_3, li_features_3 = self.SA_modules[i](l_xyz_3[i],
                                                         l_features_3[i])
            l_xyz_3.append(li_xyz_3)
            l_features_3.append(li_features_3)
        del xyz_2, features_2, l_features_2
        rcnn_cls_3rd = self.cls_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_3rd = self.reg_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)

        # 3rd-stage loss (training only)
        if self.training:
            cls_label = target_dict_3rd['cls_label'].float()
            rcnn_cls_flat = rcnn_cls_3rd.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat), cls_label.view(-1), reduction='none')
            cls_label_flat = cls_label.view(-1)
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict_3rd['gt_of_rois']
            reg_valid_mask = target_dict_3rd['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            # same empty-foreground fallback as in the 2nd stage
            if rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask].size(0) == 0:
                fg_mask = (reg_valid_mask <= 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size

            three = {
                'rcnn_loss_cls_3rd': rcnn_loss_cls,
                'rcnn_loss_reg_3rd': rcnn_loss_reg
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask

        else:
            three = {}
        pred_boxes3d_3rd = decode_bbox_target(
            roi.view(-1, 7),
            rcnn_reg_3rd.view(-1, rcnn_reg_3rd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)
        ret_dict = {
            'rcnn_cls': rcnn_cls,
            'rcnn_reg': rcnn_reg,
            'rcnn_cls_3rd': rcnn_cls_3rd,
            'rcnn_reg_3rd': rcnn_reg_3rd,
            'pred_boxes3d_1st': pred_boxes3d_1st,
            'pred_boxes3d_2nd': pred_boxes3d_2nd,
            'pred_boxes3d_3rd': pred_boxes3d_3rd
        }
        ret_dict.update(sec)
        ret_dict.update(one)
        ret_dict.update(two)
        ret_dict.update(three)
        if self.training:
            ret_dict.update(target_dict)
        return ret_dict
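The per-stage classification loss above is a masked binary cross-entropy: labels of -1 mark ambiguous ROIs and contribute nothing. A self-contained rendering of the same pattern (the clamp keeps targets in [0, 1], which newer PyTorch requires; the -1 entries are zeroed by the mask anyway):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
labels = torch.tensor([1.0, 0.0, -1.0, 0.0])    # -1 = ignore this ROI
loss = F.binary_cross_entropy(torch.sigmoid(logits), labels.clamp(min=0.0), reduction='none')
valid = (labels >= 0).float()
masked = (loss * valid).sum() / torch.clamp(valid.sum(), min=1.0)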
Example #9
    def forward(self, input_dict):
        roi_boxes3d, gt_boxes3d = input_dict['roi_boxes3d'], input_dict[
            'gt_boxes3d']
        # print("roi boxes3d", roi_boxes3d.size()) # size([4, 512, 7])
        # print("gt boxes3d", gt_boxes3d.size()) # size([4, 24..., 7])
        batch_rois, batch_gt_of_rois, batch_roi_iou, cls_list = self.sample_rois_for_rcnn(
            roi_boxes3d, gt_boxes3d)
        # print("batch rois", batch_rois.size()) # size([4, 64, 7])
        # print("batch gt of rois", batch_gt_of_rois.size()) # size([4, 64, 7])
        # print("batch roi iou", batch_roi_iou) # size([4, 64]) == 256
        # tensor([[0.6967, 0.5912, 0.6044, 0.6927, 0.5814, 0.6073, 0.5659, 0.5636, 0.5824,
        #  0.6014, 0.5500, 0.7392, 0.6286, 0.5571, 0.5547, 0.6427, 0.6389, 0.6798,
        #  0.5665, 0.7235, 0.5719, 0.6230, 0.6209, 0.5710, 0.5914, 0.6589, 0.6718,
        #  0.6226, 0.6686, 0.5762, 0.5874, 0.6522, 0.1930, 0.2827, 0.1654, 0.2159,
        #  0.1009, 0.2866, 0.3317, 0.0569, 0.0657, 0.1869, 0.0328, 0.0452, 0.1108,
        #  0.3662, 0.2471, 0.0362, 0.0560, 0.1665, 0.3016, 0.0098, 0.3256, 0.2459,
        #  0.1474, 0.1535, 0.3206, 0.2376, 0.0000, 0.0000, 0.0000, 0.0540, 0.0000, 0.0000]

        rpn_xyz, rpn_features = input_dict['rpn_xyz'], input_dict[
            'rpn_features']

        if cfg.RCNN.USE_INTENSITY:  # False
            pts_extra_input_list = [
                input_dict['rpn_intensity'].unsqueeze(dim=2),
                input_dict['seg_mask'].unsqueeze(dim=2)
            ]
        else:
            pts_extra_input_list = [input_dict['seg_mask'].unsqueeze(dim=2)]  # 0/1 foreground mask

        if cfg.RCNN.USE_DEPTH:  # True
            pts_depth = input_dict['pts_depth'] / 70.0 - 0.5  # normalize to ~[-0.5, 0.5], 70 m max range
            pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
        pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

        # point cloud pooling
        pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
        pooled_features, pooled_empty_flag = \
            roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                          sampled_pt_num=cfg.RCNN.NUM_POINTS)

        sampled_pts = pooled_features[:, :, :, 0:3]
        sampled_features = pooled_features[:, :, :, 3:]

        # data augmentation
        if cfg.AUG_DATA:  # True
            sampled_pts, batch_rois, batch_gt_of_rois = \
                self.data_augmentation(sampled_pts, batch_rois, batch_gt_of_rois)

        # canonical transformation
        batch_size = batch_rois.shape[0]
        roi_ry = batch_rois[:, :, 6] % (2 * np.pi)
        roi_center = batch_rois[:, :, 0:3]
        sampled_pts = sampled_pts - roi_center.unsqueeze(
            dim=2)  # (B, M, 512, 3)
        batch_gt_of_rois[:, :, 0:3] = batch_gt_of_rois[:, :, 0:3] - roi_center
        batch_gt_of_rois[:, :, 6] = batch_gt_of_rois[:, :, 6] - roi_ry

        for k in range(batch_size):
            sampled_pts[k] = kitti_utils.rotate_pc_along_y_torch(
                sampled_pts[k], batch_rois[k, :, 6])
            batch_gt_of_rois[k] = kitti_utils.rotate_pc_along_y_torch(
                batch_gt_of_rois[k].unsqueeze(dim=1), roi_ry[k]).squeeze(dim=1)

        # regression valid mask
        valid_mask = (pooled_empty_flag == 0)
        reg_valid_mask = ((batch_roi_iou > cfg.RCNN.REG_FG_THRESH)
                          & valid_mask).long()

        # classification label: IoU above CLS_FG_THRESH (0.6) is positive
        batch_cls_label = (batch_roi_iou > cfg.RCNN.CLS_FG_THRESH).long()
        # ROIs with CLS_BG_THRESH (0.45) < IoU < CLS_FG_THRESH (0.6) are ambiguous
        invalid_mask = (batch_roi_iou > cfg.RCNN.CLS_BG_THRESH) & (
            batch_roi_iou < cfg.RCNN.CLS_FG_THRESH)
        batch_cls_label[valid_mask == 0] = -1  # empty ROIs
        batch_cls_label[invalid_mask > 0] = -1  # ambiguous ROIs

        batch_cls_label = batch_cls_label.view(-1)
        cls_list = cls_list.view(-1)

        # replace the binary foreground label with the matched ground-truth class id
        fg = (batch_cls_label == 1)
        batch_cls_label[fg] = cls_list[fg]

        output_dict = {
            'sampled_pts': sampled_pts.view(-1, cfg.RCNN.NUM_POINTS, 3),
            'pts_feature': sampled_features.view(-1, cfg.RCNN.NUM_POINTS, sampled_features.shape[3]),
            'cls_label': batch_cls_label,
            'reg_valid_mask': reg_valid_mask.view(-1),
            'gt_of_rois': batch_gt_of_rois.view(-1, 8),  # 7 box parameters + class id in this variant
            'gt_iou': batch_roi_iou.view(-1),
            'roi_boxes3d': batch_rois.view(-1, 7)
        }

        return output_dict
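Example #10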
    def forward(self, input_data):
        """
        :param input_data: input dict
        :return:
        """
        if cfg.RCNN.ROI_SAMPLE_JIT:
            if self.training:
                with torch.no_grad():
                    target_dict = self.proposal_target_layer(input_data)

                pts_input = torch.cat((target_dict['sampled_pts'], target_dict['pts_feature']), dim=2)
                target_dict['pts_input'] = pts_input
            else:
                rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data['rpn_features']
                batch_rois = input_data['roi_boxes3d']
                if cfg.RCNN.USE_INTENSITY:
                    pts_extra_input_list = [input_data['rpn_intensity'].unsqueeze(dim=2),
                                            input_data['seg_mask'].unsqueeze(dim=2)]
                else:
                    pts_extra_input_list = [input_data['seg_mask'].unsqueeze(dim=2)]

                if cfg.RCNN.USE_DEPTH:
                    pts_depth = input_data['pts_depth'] / 70.0 - 0.5
                    pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
                pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

                pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
                pooled_features, pooled_empty_flag = \
                        roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                                      sampled_pt_num=cfg.RCNN.NUM_POINTS)

                # canonical transformation
                batch_size = batch_rois.shape[0]
                roi_center = batch_rois[:, :, 0:3]
                pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
                for k in range(batch_size):
                    pooled_features[k, :, :, 0:3] = kitti_utils.rotate_pc_along_y_torch(pooled_features[k, :, :, 0:3],
                                                                                        batch_rois[k, :, 6])

                pts_input = pooled_features.view(-1, pooled_features.shape[2], pooled_features.shape[3])
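                # flatten so each ROI is one sample: (B, M, npts, C) -> (B * M, npts, C)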
        else:
            pts_input = input_data['pts_input']
            target_dict = {}
            target_dict['pts_input'] = input_data['pts_input']
            target_dict['roi_boxes3d'] = input_data['roi_boxes3d']
            if self.training:
                target_dict['cls_label'] = input_data['cls_label']
                target_dict['reg_valid_mask'] = input_data['reg_valid_mask']
                target_dict['gt_of_rois'] = input_data['gt_boxes3d_ct']

        xyz, features = self._break_up_pc(pts_input)
        batch_size = input_data['roi_boxes3d'].size(0)
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(1, 2).unsqueeze(dim=3)
            xyz_feature = self.xyz_up_layer(xyz_input)

            rpn_feature = pts_input[..., self.rcnn_input_channel:].transpose(1, 2).unsqueeze(dim=3)

            merged_feature = torch.cat((xyz_feature, rpn_feature), dim=1)
            merged_feature = self.merge_down_layer(merged_feature)
            l_xyz, l_features = [xyz], [merged_feature.squeeze(dim=3)]
        else:
            l_xyz, l_features = [xyz], [features]

        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        rcnn_cls = self.cls_layer(l_features[-1]).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, 1 or 2)
        rcnn_reg = self.reg_layer(l_features[-1]).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, C)
        anchor_size = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
        if self.training:
            roi_boxes3d = target_dict['roi_boxes3d'].view(-1, 7)
        else:
            roi_boxes3d = input_data['roi_boxes3d']
        pred_boxes3d_1st = decode_bbox_target(roi_boxes3d.view(-1, 7), rcnn_reg.view(-1, rcnn_reg.shape[-1]),
                                              anchor_size=anchor_size,
                                              loc_scope=cfg.RCNN.LOC_SCOPE,
                                              loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                              num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                              get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                              loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                              get_ry_fine=True).view(batch_size, -1, 7)
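        # decode_bbox_target converts the bin-based regression output (location
        # bins of size LOC_BIN_SIZE over LOC_SCOPE, NUM_HEAD_BIN heading bins,
        # residuals on top) back into absolute (x, y, z, h, w, l, ry) boxes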
        ret_dict = {'rcnn_cls': rcnn_cls, 'rcnn_reg': rcnn_reg, 'pred_boxes3d_1st': pred_boxes3d_1st}
        ret_dict['pooled_feature'] = l_features[-1]
        if cfg.TRAIN.IOU_LAYER == 'split' and self.training:
            gt = target_dict['real_gt']
            iou_label = []
            num_rois = pts_input.shape[0]
            for i in range(num_rois):
                iou_label.append(
                    iou3d_utils.boxes_iou3d_gpu(pred_boxes3d_1st.view(-1, 7)[i].view(1, 7), gt[i].view(1, 7)))
            iou_label = torch.cat(iou_label)
            iou_label = (iou_label - 0.5) * 2  # map IoU from [0, 1] to [-1, 1]
            ret_dict['iou_label'] = iou_label
        if self.training:
            ret_dict.update(target_dict)
        return ret_dict
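
The 'split' IoU branch above supervises an extra head with the 3D IoU between each first-stage box and its ground truth, mapped linearly from [0, 1] to [-1, 1]. A minimal sketch of the mapping; the inverse shown for reading predictions back is an assumption, not part of the source above:

import torch

iou = torch.tensor([0.0, 0.5, 0.75, 1.0])
iou_label = (iou - 0.5) * 2     # [0, 1] -> [-1, 1], as in the code above
iou_pred = iou_label / 2 + 0.5  # hypothetical inverse at inference time
print(iou_label)  # tensor([-1.0000,  0.0000,  0.5000,  1.0000])
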
    def forward(self, input_data):
        """
        :param input_data: input dict
        :return:
        """
        # print("input data printing!!!", input_data.keys())
        #### dict_keys(['rpn_xyz', 'rpn_features', 'seg_mask', 'roi_boxes3d', 'pts_depth', 'gt_boxes3d'])
        if cfg.RCNN.ROI_SAMPLE_JIT:  # True
            if self.training:  #### Use this!
                with torch.no_grad():
                    target_dict = self.proposal_target_layer(
                        input_data)  #### core point!

                pts_input = torch.cat(
                    (target_dict['sampled_pts'], target_dict['pts_feature']),
                    dim=2)
                target_dict['pts_input'] = pts_input
            else:  # eval path, not exercised in this configuration
                rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data['rpn_features']
                batch_rois = input_data['roi_boxes3d']
                if cfg.RCNN.USE_INTENSITY:
                    pts_extra_input_list = [
                        input_data['rpn_intensity'].unsqueeze(dim=2),
                        input_data['seg_mask'].unsqueeze(dim=2)
                    ]
                else:
                    pts_extra_input_list = [
                        input_data['seg_mask'].unsqueeze(dim=2)
                    ]

                if cfg.RCNN.USE_DEPTH:
                    # normalize depth by the ~70 m KITTI range to roughly [-0.5, 0.5]
                    pts_depth = input_data['pts_depth'] / 70.0 - 0.5
                    pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
                pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

                pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
                pooled_features, pooled_empty_flag = \
                        roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                                      sampled_pt_num=cfg.RCNN.NUM_POINTS)

                # canonical transformation
                batch_size = batch_rois.shape[0]
                roi_center = batch_rois[:, :, 0:3]
                pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
                for k in range(batch_size):
                    pooled_features[k, :, :, 0:3] = kitti_utils.rotate_pc_along_y_torch(
                        pooled_features[k, :, :, 0:3], batch_rois[k, :, 6])

                pts_input = pooled_features.view(-1, pooled_features.shape[2],
                                                 pooled_features.shape[3])
        else:  # offline ROI sampling path, not exercised in this configuration
            pts_input = input_data['pts_input']
            target_dict = {}
            target_dict['pts_input'] = input_data['pts_input']
            target_dict['roi_boxes3d'] = input_data['roi_boxes3d']
            if self.training:
                target_dict['cls_label'] = input_data['cls_label']
                target_dict['reg_valid_mask'] = input_data['reg_valid_mask']
                target_dict['gt_of_rois'] = input_data['gt_boxes3d_ct']

        xyz, features = self._break_up_pc(pts_input)

        if cfg.RCNN.USE_RPN_FEATURES:  # True in this configuration
            xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(
                1, 2).unsqueeze(dim=3)                  # (256, 5, 512, 1)
            xyz_feature = self.xyz_up_layer(xyz_input)  # (256, 128, 512, 1)

            rpn_feature = pts_input[..., self.rcnn_input_channel:].transpose(
                1, 2).unsqueeze(dim=3)

            # (256, 256, 512, 1) -> (256, 128, 512, 1) after the merge layer
            merged_feature = torch.cat((xyz_feature, rpn_feature), dim=1)
            merged_feature = self.merge_down_layer(merged_feature)
            l_xyz, l_features = [xyz], [merged_feature.squeeze(dim=3)]

        else:
            l_xyz, l_features = [xyz], [features]

        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        rcnn_cls = self.cls_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B, 1 or 2)
        rcnn_reg = self.reg_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B, C)

        ret_dict = {'rcnn_cls': rcnn_cls, 'rcnn_reg': rcnn_reg}
        # rcnn_cls: (256, 4), rcnn_reg: (256, 46) in this configuration

        if self.training:
            ret_dict.update(target_dict)
        # resulting keys (training): rcnn_cls (256, 4), rcnn_reg (256, 46),
        # sampled_pts (256, 512, 3), pts_feature (256, 512, 130),
        # cls_label (256,), reg_valid_mask (256,), gt_of_rois (256, 7),
        # gt_iou (256,), roi_boxes3d (256, 7), pts_input (256, 512, 133)
        return ret_dict
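
Both forward passes above lean on kitti_utils.rotate_pc_along_y_torch for the canonical transformation. A standalone NumPy illustration of a y-axis rotation in the x-z plane; the sign convention here is an assumption, and the repository's implementation may rotate the opposite way:

import numpy as np

def rotate_pc_along_y(pc, angle):
    # rotate (N, 3) points about the camera y axis; assumed sign convention
    c, s = np.cos(angle), np.sin(angle)
    rot = np.array([[c, s], [-s, c]])
    pc[:, [0, 2]] = pc[:, [0, 2]] @ rot
    return pc

pts = np.array([[1.0, 0.0, 0.0]])
print(rotate_pc_along_y(pts, np.pi / 2))  # ~[[0., 0., 1.]]: +x rotates onto +z
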
Example #12
    def forward(self, input_dict):
        roi_boxes3d = input_dict['roi_boxes3d']
        gt_boxes3d = input_dict['gt_boxes3d']

        batch_rois, batch_gt_of_rois, batch_roi_iou = self.sample_rois_for_rcnn(
            roi_boxes3d, gt_boxes3d)

        rpn_xyz, rpn_features = input_dict['rpn_xyz'], input_dict['rpn_features']
        if cfg.RCNN.USE_INTENSITY:
            pts_extra_input_list = [
                input_dict['rpn_intensity'].unsqueeze(dim=2),
                input_dict['seg_mask'].unsqueeze(dim=2)
            ]
        else:
            pts_extra_input_list = [input_dict['seg_mask'].unsqueeze(dim=2)]

        if cfg.RCNN.USE_DEPTH:
            pts_depth = input_dict['pts_depth'] / 70.0 - 0.5
            pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
        pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

        # point cloud pooling
        pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)

        ori_pts_idx, ori_pooled_empty_flag = \
            simple_roipool3d_utils.simple_roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                                        sampled_pt_num=cfg.RCNN.NUM_POINTS)

        pts_idx = ori_pts_idx.clone().detach()
        pooled_empty_flag = ori_pooled_empty_flag.clone().detach()
        pooled_empty_flag = pooled_empty_flag.unsqueeze(2).unsqueeze(3).float()
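        # pooled_empty_flag: (B, M) -> (B, M, 1, 1) so it broadcasts over points
        # and channels when the features of empty ROIs are zeroed below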

        batch_size, num_box, pts_num = pts_idx.shape
        fea_ch = pts_feature.shape[2]

        pts_idx = pts_idx.view(batch_size, -1, 1)     # (B, num_box * 512, 1)
        pts_idx = pts_idx.repeat(1, 1, fea_ch + 3)    # (B, num_box * 512, ch + 3)

        # gather each ROI's points and features: (B, num_box * 512, ch + 3)
        pooled_features = torch.gather(
            torch.cat([rpn_xyz, pts_feature], dim=2), 1, pts_idx.long())
        pooled_features = pooled_features.view(batch_size, num_box, pts_num, -1)
        pooled_features = pooled_features * (1 - pooled_empty_flag)  # zero out empty ROIs

        sampled_pts = pooled_features[:, :, :, 0:3]
        sampled_features = pooled_features[:, :, :, 3:]

        # data augmentation
        if cfg.AUG_DATA:
            sampled_pts, batch_rois, batch_gt_of_rois = \
                self.data_augmentation(sampled_pts, batch_rois, batch_gt_of_rois)

        # canonical transformation
        batch_size = batch_rois.shape[0]
        roi_ry = batch_rois[:, :, 6] % (2 * np.pi)
        roi_center = batch_rois[:, :, 0:3]
        sampled_pts = sampled_pts - roi_center.unsqueeze(
            dim=2)  # (B, M, 512, 3)
        batch_gt_of_rois[:, :, 0:3] = batch_gt_of_rois[:, :, 0:3] - roi_center
        batch_gt_of_rois[:, :, 6] = batch_gt_of_rois[:, :, 6] - roi_ry

        for k in range(batch_size):
            sampled_pts[k] = kitti_utils.rotate_pc_along_y_torch(
                sampled_pts[k], batch_rois[k, :, 6])
            batch_gt_of_rois[k] = kitti_utils.rotate_pc_along_y_torch(
                batch_gt_of_rois[k].unsqueeze(dim=1), roi_ry[k]).squeeze(dim=1)

        # regression valid mask
        valid_mask = (ori_pooled_empty_flag == 0)
        reg_valid_mask = ((batch_roi_iou > cfg.RCNN.REG_FG_THRESH)
                          & valid_mask).long()

        # classification label
        batch_cls_label = (batch_roi_iou > cfg.RCNN.CLS_FG_THRESH).long()
        invalid_mask = (batch_roi_iou > cfg.RCNN.CLS_BG_THRESH) & (
            batch_roi_iou < cfg.RCNN.CLS_FG_THRESH)
        batch_cls_label[valid_mask == 0] = -1
        batch_cls_label[invalid_mask > 0] = -1

        output_dict = {
            'sampled_pts':
            sampled_pts.view(-1, cfg.RCNN.NUM_POINTS, 3),
            'pts_feature':
            sampled_features.view(-1, cfg.RCNN.NUM_POINTS,
                                  sampled_features.shape[3]),
            'cls_label':
            batch_cls_label.view(-1),
            'reg_valid_mask':
            reg_valid_mask.view(-1),
            'gt_of_rois':
            batch_gt_of_rois.view(-1, 7),
            'gt_iou':
            batch_roi_iou.view(-1),
            'roi_boxes3d':
            batch_rois.view(-1, 7)
        }

        return output_dict
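
The pooling above takes per-ROI point indices from the CUDA op and then gathers the concatenated (xyz, feature) tensor with a single torch.gather. A minimal standalone sketch of that indexing pattern; all names and sizes are made up for illustration:

import torch

B, N, C = 2, 6, 4  # scenes, points per scene, channels (xyz + features)
M, S = 3, 2        # boxes per scene, sampled points per box

feats = torch.arange(B * N * C, dtype=torch.float32).view(B, N, C)
idx = torch.randint(0, N, (B, M, S))           # per-box point indices

idx_flat = idx.view(B, -1, 1).repeat(1, 1, C)  # (B, M * S, C)
pooled = torch.gather(feats, 1, idx_flat)      # pick each index's full channel row
pooled = pooled.view(B, M, S, C)
print(pooled.shape)  # torch.Size([2, 3, 2, 4])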