def forward(self, data_dicts):
        """Run the detector head; return predictions or (losses, metrics).

        Args:
            data_dicts: dict of input tensors. Keys read here: 'point_cloud',
                'one_hot', 'ref_label', 'box3d_center', 'size_class',
                'box3d_size', 'box3d_heading' and 'center_ref1'..'center_ref4'.
                Shapes in the inline comments are the original author's
                examples for a batch of 32.

        Returns:
            When 'box3d_center' is absent or None (test mode / RGB-detection
            input): a tuple ``(cls_probs, center_preds, heading_preds,
            size_preds)`` viewed as (bs, -1, 2), (bs, -1, 3), (bs, -1) and
            (bs, -1, 3).
            Otherwise: ``(losses, metrics)`` -- two dicts holding the loss
            terms and the training-monitoring metrics.

        Raises:
            AssertionError: in the labelled path, if no position has
                ``ref_label == 1`` (nothing to regress boxes for).
        """
        point_cloud = data_dicts.get('point_cloud')  # e.g. [32, 4, 1024]
        one_hot = data_dicts.get('one_hot')  # e.g. [32, 3]
        ref_label = data_dicts.get('ref_label')  # e.g. [32, 140]
        bs = point_cloud.shape[0]

        # Labels; when present they are used to compute the losses below.
        box3d_center_label = data_dicts.get('box3d_center')  # e.g. [32, 3]
        size_class_label = data_dicts.get('size_class')  # e.g. [32]
        # Raw (not residual) size / heading targets; encoded further below.
        box3d_size_label = data_dicts.get('box3d_size')
        box3d_heading_label = data_dicts.get('box3d_heading')

        # Reference centers at four resolutions,
        # e.g. [32, 3, 280], [32, 3, 140], [32, 3, 70], [32, 3, 35].
        center_ref1 = data_dicts.get('center_ref1')
        center_ref2 = data_dicts.get('center_ref2')
        center_ref3 = data_dicts.get('center_ref3')
        center_ref4 = data_dicts.get('center_ref4')

        # Split xyz from the optional extra per-point channels:
        # one extra channel for 4-channel clouds, three for 6-channel clouds.
        object_point_cloud_xyz = point_cloud[:, :3, :].contiguous()
        if point_cloud.shape[1] == 4:
            object_point_cloud_i = point_cloud[:, [3], :].contiguous(
            )  # e.g. [32, 1, 1024]
        elif point_cloud.shape[1] == 6:
            object_point_cloud_i = point_cloud[:, 3:6, :].contiguous(
            )  # e.g. [32, 3, 1024]
        else:
            object_point_cloud_i = None

        mean_size_array = torch.from_numpy(g_mean_size_arr).type_as(
            point_cloud)

        # Per-resolution features, e.g. [32, 131, 280/140/70/35].
        feat1, feat2, feat3, feat4 = self.feat_net(
            object_point_cloud_xyz,
            [center_ref1, center_ref2, center_ref3, center_ref4],
            object_point_cloud_i, one_hot)
        x = self.conv_net(feat1, feat2, feat3,
                          feat4)  # e.g. [32, 768, 140]

        cls_scores = self.cls_out(x)  # e.g. [32, 2, 140]
        outputs = self.reg_out(x)  # e.g. [32, 39, 140]

        num_out = outputs.shape[2]
        output_size = outputs.shape[1]
        # b, c, n -> (b * n), c: one row per output position.
        cls_scores = cls_scores.permute(0, 2, 1).contiguous().view(
            -1, 2)  # e.g. [4480, 2]
        outputs = outputs.permute(0, 2, 1).contiguous().view(
            -1, output_size)  # e.g. [4480, 39]

        # Flatten center_ref2 the same way; its positions are assumed to line
        # up with `outputs` (both have 140 per sample in the author's notes).
        center_ref2 = center_ref2.permute(0, 2, 1).contiguous().view(
            -1, 3)  # e.g. [4480, 3]

        cls_probs = F.softmax(cls_scores, -1)  # e.g. [4480, 2]

        if box3d_center_label is None:
            # No labels (test mode or RGB-detection input): decode and return.
            det_outputs = self._slice_output(outputs)
            center_boxnet, heading_scores, heading_res_norm, size_scores, size_res_norm = det_outputs

            heading_probs = F.softmax(heading_scores, -1)  # e.g. [4480, 12]
            size_probs = F.softmax(size_scores, -1)  # e.g. [4480, 3]

            heading_pred_label = torch.argmax(heading_probs, -1)
            size_pred_label = torch.argmax(size_probs, -1)

            center_preds = center_boxnet + center_ref2

            heading_preds = angle_decode(heading_res_norm, heading_pred_label)
            size_preds = size_decode(size_res_norm, mean_size_array,
                                     size_pred_label)

            # Restore the per-sample grouping.
            cls_probs = cls_probs.view(bs, -1, 2)
            center_preds = center_preds.view(bs, -1, 3)

            size_preds = size_preds.view(bs, -1, 3)
            heading_preds = heading_preds.view(bs, -1)

            outputs = (cls_probs, center_preds, heading_preds, size_preds)

            return outputs

        # Labelled path: only foreground positions (ref_label == 1) take part
        # in the box-regression losses.
        fg_idx = (ref_label.view(-1) == 1).nonzero().view(
            -1)  # e.g. [99]

        assert fg_idx.numel() != 0

        outputs = outputs[fg_idx, :]  # e.g. [99, 39]
        center_ref2 = center_ref2[fg_idx]  # e.g. [99, 3]

        det_outputs = self._slice_output(outputs)
        center_boxnet, heading_scores, heading_res_norm, size_scores, size_res_norm = det_outputs

        heading_probs = F.softmax(heading_scores, -1)  # e.g. [99, 12]
        size_probs = F.softmax(size_scores, -1)  # e.g. [99, 3]
        # Classification loss covers every position; label -1 is ignored.
        # (This was previously computed twice, along with the softmaxes.)
        cls_loss = softmax_focal_loss_ignore(cls_probs,
                                             ref_label.view(-1),
                                             ignore_idx=-1)

        # Prepare labels: broadcast each per-sample label to all num_out
        # positions, flatten, and keep only the foreground rows.
        center_label = box3d_center_label.unsqueeze(1).expand(-1, num_out, -1)\
            .contiguous().view(-1, 3)[fg_idx]  # e.g. [99, 3]
        size_label = box3d_size_label.unsqueeze(1).expand(-1, num_out, -1)\
            .contiguous().view(-1, 3)[fg_idx]  # e.g. [99, 3]
        heading_label = box3d_heading_label.view(-1, 1).expand(-1, num_out)\
            .contiguous().view(-1)[fg_idx]  # e.g. [99]
        size_class_label = size_class_label.view(-1, 1).expand(-1, num_out)\
            .contiguous().view(-1)[fg_idx]  # e.g. [99]

        # Encode regression targets relative to the reference centers / mean sizes.
        center_gt_offsets = center_encode(center_label,
                                          center_ref2)  # e.g. [99, 3]
        heading_class_label, heading_res_norm_label = angle_encode(
            heading_label)  # e.g. [99], [99]
        size_res_label_norm = size_encode(
            size_label, mean_size_array,
            size_class_label)  # e.g. [99, 3]

        # Loss calculation.
        center_loss = self.get_center_loss(center_boxnet, center_gt_offsets)

        heading_class_loss, heading_res_norm_loss = self.get_heading_loss(
            heading_scores, heading_res_norm, heading_class_label,
            heading_res_norm_label)

        size_class_loss, size_res_norm_loss = self.get_size_loss(
            size_scores, size_res_norm, size_class_label, size_res_label_norm)

        # Corner-loss regularization decodes with the ground-truth classes.
        center_preds = center_decode(center_ref2, center_boxnet)
        heading = angle_decode(heading_res_norm, heading_class_label)
        size = size_decode(size_res_norm, mean_size_array, size_class_label)

        corners_loss, corner_gts = self.get_corner_loss(
            (center_preds, heading, size),
            (center_label, heading_label, size_label))

        BOX_LOSS_WEIGHT = cfg.LOSS.BOX_LOSS_WEIGHT
        CORNER_LOSS_WEIGHT = cfg.LOSS.CORNER_LOSS_WEIGHT
        HEAD_REG_WEIGHT = cfg.LOSS.HEAD_REG_WEIGHT
        SIZE_REG_WEIGHT = cfg.LOSS.SIZE_REG_WEIGHT

        # Weighted sum of all losses.
        loss = cls_loss + \
            BOX_LOSS_WEIGHT * (center_loss +
                               heading_class_loss + size_class_loss +
                               HEAD_REG_WEIGHT * heading_res_norm_loss +
                               SIZE_REG_WEIGHT * size_res_norm_loss +
                               CORNER_LOSS_WEIGHT * corners_loss)

        # Metrics to monitor training status (no gradients needed).
        with torch.no_grad():

            # Accuracy.
            cls_prec = get_accuracy(cls_probs, ref_label.view(-1))
            heading_prec = get_accuracy(heading_probs,
                                        heading_class_label.view(-1))
            size_prec = get_accuracy(size_probs, size_class_label.view(-1))

            # IoU metrics: decode with the *predicted* classes this time.
            heading_pred_label = torch.argmax(heading_probs, -1)
            size_pred_label = torch.argmax(size_probs, -1)

            heading_preds = angle_decode(heading_res_norm, heading_pred_label)
            size_preds = size_decode(size_res_norm, mean_size_array,
                                     size_pred_label)

            corner_preds = get_box3d_corners_helper(center_preds,
                                                    heading_preds, size_preds)
            overlap = rbbox_iou_3d_pair(corner_preds.detach().cpu().numpy(),
                                        corner_gts.detach().cpu().numpy())

            iou2ds, iou3ds = overlap[:, 0], overlap[:, 1]
            iou2d_mean = iou2ds.mean()
            iou3d_mean = iou3ds.mean()
            iou3d_gt_mean = (iou3ds >= cfg.IOU_THRESH).mean()
            iou2d_mean = torch.tensor(iou2d_mean).type_as(cls_prec)
            iou3d_mean = torch.tensor(iou3d_mean).type_as(cls_prec)
            iou3d_gt_mean = torch.tensor(iou3d_gt_mean).type_as(cls_prec)

        losses = {
            'total_loss': loss,
            'cls_loss': cls_loss,
            'center_loss': center_loss,
            'heading_class_loss': heading_class_loss,
            'heading_residual_normalized_loss': heading_res_norm_loss,
            'size_class_loss': size_class_loss,
            'size_residual_normalized_loss': size_res_norm_loss,
            'corners_loss': corners_loss
        }
        metrics = {
            'cls_acc': cls_prec,
            'head_acc': heading_prec,
            'size_acc': size_prec,
            'iou2d': iou2d_mean,
            'iou3d': iou3d_mean,
            'iou3d_' + str(cfg.IOU_THRESH): iou3d_gt_mean
        }

        return losses, metrics
コード例 #2
0
    def forward(self, data_dicts):
        """Run the image-augmented detector head; return predictions or (losses, metrics).

        Args:
            data_dicts: dict of inputs. Keys read here: 'image' (passed through
                ``self.cnn``), 'P', 'query_v1', 'point_cloud', 'one_hot',
                'label', 'size_class', 'box3d_center', 'box3d_heading',
                'box3d_size' and 'center_ref1'..'center_ref4'.

        Returns:
            When 'box3d_center' is absent or None: a tuple
            ``(cls_probs, center_preds, heading_preds, size_preds)`` viewed as
            (batch_size, -1, 2), (batch_size, -1, 3), (batch_size, -1) and
            (batch_size, -1, 3).
            Otherwise: ``(losses, metrics)`` -- two dicts holding the loss
            terms and the training-monitoring metrics.

        Raises:
            AssertionError: if labels are missing while ``self.training`` is
                set, or if no position has ``label == 1``.
        """

        image = data_dicts.get('image')
        out_image = self.cnn(image)
        P = data_dicts.get('P')
        query_v1 = data_dicts.get('query_v1')

        point_cloud = data_dicts.get('point_cloud')
        one_hot_vec = data_dicts.get('one_hot')
        cls_label = data_dicts.get('label')
        size_class_label = data_dicts.get('size_class')
        center_label = data_dicts.get('box3d_center')
        heading_label = data_dicts.get('box3d_heading')
        size_label = data_dicts.get('box3d_size')

        center_ref1 = data_dicts.get('center_ref1')
        center_ref2 = data_dicts.get('center_ref2')
        center_ref3 = data_dicts.get('center_ref3')
        center_ref4 = data_dicts.get('center_ref4')

        batch_size = point_cloud.shape[0]

        # xyz plus, if present, only the 4th channel as an extra feature.
        object_point_cloud_xyz = point_cloud[:, :3, :].contiguous()
        if point_cloud.shape[1] > 3:
            object_point_cloud_i = point_cloud[:, [3], :].contiguous()
        else:
            object_point_cloud_i = None

        mean_size_array = torch.from_numpy(MEAN_SIZE_ARRAY).type_as(
            point_cloud)

        feat1, feat2, feat3, feat4 = self.feat_net(
            object_point_cloud_xyz,
            [center_ref1, center_ref2, center_ref3, center_ref4],
            object_point_cloud_i, one_hot_vec, out_image, P, query_v1)

        x = self.conv_net(feat1, feat2, feat3, feat4)

        cls_scores = self.cls_out(x)
        outputs = self.reg_out(x)

        num_out = outputs.shape[2]
        output_size = outputs.shape[1]
        # b, c, n -> (b * n), c: one row per output position.
        cls_scores = cls_scores.permute(0, 2, 1).contiguous().view(-1, 2)
        outputs = outputs.permute(0, 2, 1).contiguous().view(-1, output_size)

        # Flatten center_ref2 the same way; its positions are assumed to line
        # up with the rows of `outputs`.
        center_ref2 = center_ref2.permute(0, 2, 1).contiguous().view(-1, 3)

        cls_probs = F.softmax(cls_scores, -1)

        if center_label is None:
            assert not self.training, 'Please provide labels for training.'

            det_outputs = self._slice_output(outputs)

            center_boxnet, heading_scores, heading_res_norm, size_scores, size_res_norm = det_outputs

            # Decode with the predicted heading / size classes.
            heading_probs = F.softmax(heading_scores, -1)
            size_probs = F.softmax(size_scores, -1)

            heading_pred_label = torch.argmax(heading_probs, -1)
            size_pred_label = torch.argmax(size_probs, -1)

            center_preds = center_boxnet + center_ref2

            heading_preds = angle_decode(heading_res_norm, heading_pred_label)
            size_preds = size_decode(size_res_norm, mean_size_array,
                                     size_pred_label)

            # Restore the per-sample grouping.
            cls_probs = cls_probs.view(batch_size, -1, 2)
            center_preds = center_preds.view(batch_size, -1, 3)

            size_preds = size_preds.view(batch_size, -1, 3)
            heading_preds = heading_preds.view(batch_size, -1)

            outputs = (cls_probs, center_preds, heading_preds, size_preds)
            return outputs

        # Labelled path: only foreground positions (label == 1) take part in
        # the box-regression losses.
        fg_idx = (cls_label.view(-1) == 1).nonzero().view(-1)

        assert fg_idx.numel() != 0

        outputs = outputs[fg_idx, :]
        center_ref2 = center_ref2[fg_idx]

        det_outputs = self._slice_output(outputs)

        center_boxnet, heading_scores, heading_res_norm, size_scores, size_res_norm = det_outputs

        heading_probs = F.softmax(heading_scores, -1)
        size_probs = F.softmax(size_scores, -1)

        # Classification loss covers every position; label -1 is ignored.
        cls_loss = softmax_focal_loss_ignore(cls_probs,
                                             cls_label.view(-1),
                                             ignore_idx=-1)

        # Prepare labels: broadcast each per-sample label to all num_out
        # positions, flatten, and keep only the foreground rows.
        # The scalar-per-sample labels are reshaped to (batch, 1) first so that
        # expand(-1, num_out) works whether they arrive as (batch,) or
        # (batch, 1); expand() rejects -1 for a new leading dimension.
        center_label = center_label.unsqueeze(1).expand(
            -1, num_out, -1).contiguous().view(-1, 3)[fg_idx]
        heading_label = heading_label.view(-1, 1).expand(
            -1, num_out).contiguous().view(-1)[fg_idx]
        size_label = size_label.unsqueeze(1).expand(
            -1, num_out, -1).contiguous().view(-1, 3)[fg_idx]
        size_class_label = size_class_label.view(-1, 1).expand(
            -1, num_out).contiguous().view(-1)[fg_idx]

        # Encode regression targets relative to the reference centers / mean sizes.
        center_gt_offsets = center_encode(center_label, center_ref2)
        heading_class_label, heading_res_norm_label = angle_encode(
            heading_label)
        size_res_label_norm = size_encode(size_label, mean_size_array,
                                          size_class_label)

        # Loss calculation.
        center_loss = self.get_center_loss(center_boxnet, center_gt_offsets)

        heading_class_loss, heading_res_norm_loss = self.get_heading_loss(
            heading_scores, heading_res_norm, heading_class_label,
            heading_res_norm_label)

        size_class_loss, size_res_norm_loss = self.get_size_loss(
            size_scores, size_res_norm, size_class_label, size_res_label_norm)

        # Corner-loss regularization decodes with the ground-truth classes.
        center_preds = center_decode(center_ref2, center_boxnet)
        heading = angle_decode(heading_res_norm, heading_class_label)
        size = size_decode(size_res_norm, mean_size_array, size_class_label)

        corners_loss, corner_gts = self.get_corner_loss(
            (center_preds, heading, size),
            (center_label, heading_label, size_label))

        BOX_LOSS_WEIGHT = cfg.LOSS.BOX_LOSS_WEIGHT
        CORNER_LOSS_WEIGHT = cfg.LOSS.CORNER_LOSS_WEIGHT
        HEAD_REG_WEIGHT = cfg.LOSS.HEAD_REG_WEIGHT
        SIZE_REG_WEIGHT = cfg.LOSS.SIZE_REG_WEIGHT

        # Weighted sum of all losses.
        loss = cls_loss + \
            BOX_LOSS_WEIGHT * (center_loss +
                               heading_class_loss + size_class_loss +
                               HEAD_REG_WEIGHT * heading_res_norm_loss +
                               SIZE_REG_WEIGHT * size_res_norm_loss +
                               CORNER_LOSS_WEIGHT * corners_loss)

        # Metrics to monitor training status (no gradients needed).
        with torch.no_grad():

            # Accuracy.
            cls_prec = get_accuracy(cls_probs, cls_label.view(-1))
            heading_prec = get_accuracy(heading_probs,
                                        heading_class_label.view(-1))
            size_prec = get_accuracy(size_probs, size_class_label.view(-1))

            # IoU metrics: decode with the *predicted* classes this time.
            heading_pred_label = torch.argmax(heading_probs, -1)
            size_pred_label = torch.argmax(size_probs, -1)

            heading_preds = angle_decode(heading_res_norm, heading_pred_label)
            size_preds = size_decode(size_res_norm, mean_size_array,
                                     size_pred_label)

            corner_preds = get_box3d_corners_helper(center_preds,
                                                    heading_preds, size_preds)
            overlap = rbbox_iou_3d_pair(corner_preds.detach().cpu().numpy(),
                                        corner_gts.detach().cpu().numpy())

            iou2ds, iou3ds = overlap[:, 0], overlap[:, 1]
            iou2d_mean = iou2ds.mean()
            iou3d_mean = iou3ds.mean()
            iou3d_gt_mean = (iou3ds >= cfg.IOU_THRESH).mean()
            iou2d_mean = torch.tensor(iou2d_mean).type_as(cls_prec)
            iou3d_mean = torch.tensor(iou3d_mean).type_as(cls_prec)
            iou3d_gt_mean = torch.tensor(iou3d_gt_mean).type_as(cls_prec)

        losses = {
            'total_loss': loss,
            'cls_loss': cls_loss,
            'center_loss': center_loss,
            'head_cls_loss': heading_class_loss,
            'head_res_loss': heading_res_norm_loss,
            'size_cls_loss': size_class_loss,
            'size_res_loss': size_res_norm_loss,
            'corners_loss': corners_loss
        }

        metrics = {
            'cls_acc': cls_prec,
            'head_acc': heading_prec,
            'size_acc': size_prec,
            'IoU_2D': iou2d_mean,
            'IoU_3D': iou3d_mean,
            'IoU_' + str(cfg.IOU_THRESH): iou3d_gt_mean
        }

        return losses, metrics
コード例 #3
0
def get_iou_cc(bb1, bb2):
    """Compute the overlap of one pair of boxes with ``box_ops_cc``.

    Each input gets a leading axis of length 1 so that the single pair can be
    passed to the pairwise routine ``box_ops_cc.rbbox_iou_3d_pair``. The
    entry in column 1 of the single result row is returned.
    """
    # NOTE(review): column 1 is presumably the 3D IoU (column 0 the 2D IoU),
    # matching how the training code reads rbbox_iou_3d_pair output; confirm.
    single_a = bb1[None, ...]
    single_b = bb2[None, ...]
    pair_overlaps = box_ops_cc.rbbox_iou_3d_pair(single_a, single_b)
    only_row = pair_overlaps[0]
    return only_row[1]