示例#1
0
    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Build the sub-networks selected by the global ``cfg``.

        Args:
            num_classes: Forwarded to the RCNN head.
            use_xyz: Forwarded to both the RPN and the RCNN head.
            mode: Forwarded to the RPN.

        Raises:
            NotImplementedError: If RCNN is enabled with a backbone other
                than 'pointnet' or 'pointsift'.
        """
        super().__init__()

        # At least one of the two stages has to be switched on.
        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.PSP.ENABLED:
            self.psp = PSPNet(n_classes=1)

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        # Shared MLP that maps twice the RPN feature width back to the single
        # width; it is created regardless of which stages are enabled.
        width = cfg.RPN.FP_MLPS[0][-1]
        mlp_spec = [width * 2, width]
        self.merge_down = pt_utils.SharedMLP(mlp_spec, bn=cfg.RPN.USE_BN)

        if not cfg.RCNN.ENABLED:
            return

        backbone = cfg.RCNN.BACKBONE
        if backbone == 'pointnet':
            self.rcnn_net = RCNNNet(num_classes=num_classes,
                                    input_channels=128,
                                    use_xyz=use_xyz)
        elif backbone != 'pointsift':
            raise NotImplementedError
        # 'pointsift' is accepted but builds no RCNN head.
示例#2
0
class PointRCNN(nn.Module):
    """Wrapper that runs either the RPN stage or the RCNN/IoU stage.

    Which sub-networks are built and executed is decided by the global
    ``cfg`` flags ``RPN.ENABLED``, ``RCNN.ENABLED`` and ``IOUN.ENABLED``.
    """

    def __init__(self,
                 num_classes,
                 num_point=512,
                 use_xyz=True,
                 mode='TRAIN',
                 old_model=False):
        """Create the enabled sub-networks.

        Args:
            num_classes: Forwarded to ``RCNNNet``.
            num_point: Forwarded to ``RCNNNet``.
            use_xyz: Forwarded to both ``RPN`` and ``RCNNNet``.
            mode: Stored on ``self.mode`` and forwarded to ``RPN``.
            old_model: Forwarded to ``RPN``.
        """
        super().__init__()
        self.mode = mode
        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED or cfg.IOUN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode, old_model=old_model)

        if cfg.RCNN.ENABLED or cfg.IOUN.ENABLED:
            rcnn_input_channels = 128  # channels of rpn features x,y,z,r,mask
            self.rcnn_net = RCNNNet(num_point=num_point,
                                    num_classes=num_classes,
                                    input_channels=rcnn_input_channels,
                                    use_xyz=use_xyz)

    def forward(self, input_data):
        """Run the stage selected by ``cfg``.

        The RPN takes precedence: when ``cfg.RPN.ENABLED`` is set only the RPN
        is executed, otherwise the RCNN/IoU head is run on ``input_data``.

        Returns:
            dict: A fresh dict holding the outputs of the executed stage.

        Raises:
            NotImplementedError: If no stage is enabled.
        """
        if cfg.RPN.ENABLED:
            return self.rpn_forward(input_data)
        if cfg.RCNN.ENABLED or cfg.IOUN.ENABLED:
            return self.rcnn_forward(input_data)
        raise NotImplementedError

    def rpn_forward(self, input_data):
        """Run only the RPN and return its outputs in a new dict.

        Gradients are recorded only when the RPN is not fixed and the module
        is in training mode; a fixed RPN is switched to eval mode first.

        Raises:
            NotImplementedError: If ``cfg.RPN.ENABLED`` is false. The previous
                code fell through to ``return output`` with ``output`` unbound.
        """
        if not cfg.RPN.ENABLED:
            raise NotImplementedError('rpn_forward requires cfg.RPN.ENABLED')

        output = {}
        with torch.set_grad_enabled((not cfg.RPN.FIXED) and self.training):
            if cfg.RPN.FIXED:
                self.rpn.eval()
            rpn_output = self.rpn(input_data)
            output.update(rpn_output)
        return output

    def rcnn_forward(self, rcnn_input_info):
        """Run only the RCNN head and return its outputs in a new dict."""
        output = {}
        rcnn_output = self.rcnn_net(rcnn_input_info)
        output.update(rcnn_output)
        return output
示例#3
0
    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Instantiate the RPN and/or RCNN head according to ``cfg``.

        Args:
            num_classes: Forwarded to the RCNN head.
            use_xyz: Forwarded to both the RPN and the RCNN head.
            mode: Forwarded to the RPN.

        Raises:
            NotImplementedError: If RCNN is enabled with a backbone other
                than 'pointnet' or 'pointsift'.
        """
        super().__init__()

        # At least one stage must be active.
        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        if not cfg.RCNN.ENABLED:
            return

        backbone = cfg.RCNN.BACKBONE
        if backbone == 'pointnet':
            # The head consumes 128 channels of RPN features.
            self.rcnn_net = RCNNNet(num_classes=num_classes,
                                    input_channels=128,
                                    use_xyz=use_xyz)
        elif backbone != 'pointsift':
            raise NotImplementedError
        # 'pointsift' is accepted but builds no RCNN head.
示例#4
0
    def __init__(self,
                 num_classes,
                 num_point=512,
                 use_xyz=True,
                 mode='TRAIN',
                 old_model=False):
        """Create the sub-networks enabled in the global ``cfg``.

        Args:
            num_classes: Forwarded to ``RCNNNet``.
            num_point: Forwarded to ``RCNNNet``.
            use_xyz: Forwarded to both ``RPN`` and ``RCNNNet``.
            mode: Stored on ``self.mode`` and forwarded to ``RPN``.
            old_model: Forwarded to ``RPN``.
        """
        super().__init__()
        self.mode = mode

        # At least one of the three stages has to be active.
        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED or cfg.IOUN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode, old_model=old_model)

        needs_head = cfg.RCNN.ENABLED or cfg.IOUN.ENABLED
        if needs_head:
            # 128 input channels: rpn features x,y,z,r,mask (original note).
            head_kwargs = dict(num_point=num_point,
                               num_classes=num_classes,
                               input_channels=128,
                               use_xyz=use_xyz)
            self.rcnn_net = RCNNNet(**head_kwargs)
示例#5
0
class PointRCNN(nn.Module):
    """Two-stage detector wrapper (RPN followed by an RCNN head).

    The stages that are built and run are chosen by the global ``cfg``.
    Every successful forward pass prints its wall-clock duration.
    """

    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Instantiate the enabled stages.

        Args:
            num_classes: Forwarded to the RCNN head.
            use_xyz: Forwarded to both the RPN and the RCNN head.
            mode: Forwarded to the RPN.

        Raises:
            NotImplementedError: If RCNN is enabled with a backbone other
                than 'pointnet' or 'pointsift'.
        """
        super().__init__()

        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        if not cfg.RCNN.ENABLED:
            return

        backbone = cfg.RCNN.BACKBONE
        if backbone == 'pointnet':
            # The head consumes 128 channels of RPN features.
            self.rcnn_net = RCNNNet(num_classes=num_classes,
                                    input_channels=128,
                                    use_xyz=use_xyz)
        elif backbone != 'pointsift':
            raise NotImplementedError
        # 'pointsift' is accepted but builds no RCNN head.

    def forward(self, input_data):
        """Run the enabled stages and return their merged outputs.

        With the RPN enabled, its outputs are collected first; if RCNN is
        enabled too, proposals are generated without gradient tracking and
        fed to the RCNN head. With only RCNN enabled, ``input_data`` is passed
        straight to the head.

        Raises:
            NotImplementedError: If neither stage is enabled.
        """
        started_at = datetime.now()

        if cfg.RPN.ENABLED:
            result = {}

            # Stage 1: gradients only when the RPN is trainable and training.
            track_grad = (not cfg.RPN.FIXED) and self.training
            with torch.set_grad_enabled(track_grad):
                if cfg.RPN.FIXED:
                    self.rpn.eval()
                rpn_out = self.rpn(input_data)
                result.update(rpn_out)

            if cfg.RCNN.ENABLED:
                # Stage 2 preparation: proposals are built without gradients.
                with torch.no_grad():
                    cls_logits = rpn_out['rpn_cls']
                    box_reg = rpn_out['rpn_reg']
                    xyz = rpn_out['backbone_xyz']
                    feats = rpn_out['backbone_features']

                    fg_logits = cls_logits[:, :, 0]
                    fg_prob = torch.sigmoid(fg_logits)
                    fg_mask = (fg_prob > cfg.RPN.SCORE_THRESH).float()
                    depth = torch.norm(xyz, p=2, dim=2)

                    # (B, M, 7) proposals according to the original note
                    proposals, proposal_scores = self.rpn.proposal_layer(fg_logits, box_reg, xyz)

                    result['rois'] = proposals
                    result['roi_scores_raw'] = proposal_scores
                    result['seg_result'] = fg_mask

                head_input = {
                    'rpn_xyz': xyz,
                    'rpn_features': feats.permute((0, 2, 1)),
                    'seg_mask': fg_mask,
                    'roi_boxes3d': proposals,
                    'pts_depth': depth,
                }
                if self.training:
                    head_input['gt_boxes3d'] = input_data['gt_boxes3d']

                result.update(self.rcnn_net(head_input))

        elif cfg.RCNN.ENABLED:
            result = self.rcnn_net(input_data)
        else:
            raise NotImplementedError

        print(f'PointRCNN forward time: {datetime.now() - started_at}')
        return result
示例#6
0
class PointRCNN(nn.Module):
    """Two-stage detector wrapper with an optional frustum-based score bump.

    When ``cfg.FUSSION.FRUSTUM_AVAILABLE`` is set and a ``sample_id`` is
    given, RPN scores of points lying inside externally detected 'Car' boxes
    are raised by 0.5 before proposals are generated.
    """

    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Instantiate the enabled stages.

        Args:
            num_classes: Forwarded to the RCNN head.
            use_xyz: Forwarded to both the RPN and the RCNN head.
            mode: Forwarded to the RPN.

        Raises:
            NotImplementedError: If RCNN is enabled with a backbone other
                than 'pointnet' or 'pointsift'.
        """
        super().__init__()

        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        if cfg.RCNN.ENABLED:
            rcnn_input_channels = 128  # channels of rpn features
            if cfg.RCNN.BACKBONE == 'pointnet':
                self.rcnn_net = RCNNNet(num_classes=num_classes, input_channels=rcnn_input_channels, use_xyz=use_xyz)
            elif cfg.RCNN.BACKBONE == 'pointsift':
                pass  # accepted, but no RCNN head is built
            else:
                raise NotImplementedError

    @staticmethod
    def _car_bounds(obj_list):
        """Return (xmin, xmax, ymin, ymax, zmin, zmax) for every 'Car' object.

        x/z extents come from corners 0-3 and the y extent from corners 0 and
        4 of ``generate_corners3d()``, so the box is axis-aligned.
        """
        bounds = []
        for obj in obj_list:
            if obj.cls_type != 'Car':
                continue
            corners = obj.generate_corners3d()
            xs = [corners[0, 0], corners[1, 0], corners[2, 0], corners[3, 0]]
            zs = [corners[0, 2], corners[1, 2], corners[2, 2], corners[3, 2]]
            ys = [corners[0, 1], corners[4, 1]]
            bounds.append((min(xs), max(xs), min(ys), max(ys), min(zs), max(zs)))
        return bounds

    @staticmethod
    def _point_in_bounds(bounds, d):
        """Return True if point ``d`` (x, y, z first) lies in any box of ``bounds``."""
        x, y, z = d[0], d[1], d[2]
        for xmin, xmax, ymin, ymax, zmin, zmax in bounds:
            if ymin <= y <= ymax and xmin <= x <= xmax and zmin <= z <= zmax:
                return True
        return False

    def is_input_data_in_box3d(self, obj_list, d):
        """Tell whether point ``d`` falls inside the box of any 'Car' object.

        Args:
            obj_list: Objects exposing ``cls_type`` and ``generate_corners3d()``.
            d: Indexable point whose first three entries are x, y, z.

        Returns:
            bool: True if the point is inside (bounds inclusive) the
            axis-aligned extent of at least one 'Car' object.
        """
        return self._point_in_bounds(self._car_bounds(obj_list), d)

    def fpnet_test(self, input_data, sample_id, rpn_scores_raw):
        """Add 0.5 to the raw RPN score of every point inside a frustum 'Car' box.

        Only the first batch element is processed (``input_data[0]``,
        ``sample_id[0]`` and row 0 of ``rpn_scores_raw``). The label file is
        read from ``../data/KITTI/frustum_net/data/<sample_id>.txt``.

        Args:
            input_data: Batched point tensor; element 0 is moved to the CPU.
            sample_id: Indexable whose first entry formats with ``%06d``.
            rpn_scores_raw: Score tensor, modified in place.

        Returns:
            The same ``rpn_scores_raw`` object after the update.
        """
        import os
        import lib.utils.kitti_utils as kitti_utils

        points = input_data[0].cpu().numpy()
        frame_id = sample_id[0]
        fpnet_file = os.path.join('../data/KITTI/frustum_net/data/', '%06d.txt' % frame_id)
        obj_list = kitti_utils.get_objects_from_label(fpnet_file)

        # Box extents do not depend on the point, so compute them once
        # and not once per point.
        bounds = self._car_bounds(obj_list)
        for i, point in enumerate(points):
            if self._point_in_bounds(bounds, point):
                rpn_scores_raw[0, i] = rpn_scores_raw[0, i] + 0.5
        return rpn_scores_raw

    def forward(self, input_data, sample_id=None):
        """Run the enabled stages and return their merged outputs.

        Args:
            input_data: Dict passed to the RPN (or directly to the RCNN head
                when only RCNN is enabled). ``'pts_input'`` is read for the
                frustum update and ``'gt_boxes3d'`` while training.
            sample_id: Optional sample identifiers enabling the frustum update.

        Raises:
            NotImplementedError: If neither stage is enabled.
        """
        if cfg.RPN.ENABLED:
            output = {}
            # rpn inference
            with torch.set_grad_enabled((not cfg.RPN.FIXED) and self.training):
                if cfg.RPN.FIXED:
                    self.rpn.eval()
                rpn_output = self.rpn(input_data)
                output.update(rpn_output)

            # rcnn inference
            if cfg.RCNN.ENABLED:
                with torch.no_grad():
                    rpn_cls, rpn_reg = rpn_output['rpn_cls'], rpn_output['rpn_reg']
                    backbone_xyz, backbone_features = rpn_output['backbone_xyz'], rpn_output['backbone_features']

                    rpn_scores_raw = rpn_cls[:, :, 0]
                    rpn_scores_norm = torch.sigmoid(rpn_scores_raw)
                    # The mask is derived before the frustum bump below.
                    seg_mask = (rpn_scores_norm > cfg.RPN.SCORE_THRESH).float()
                    pts_depth = torch.norm(backbone_xyz, p=2, dim=2)

                    # frustum update score; identity check so array-like
                    # sample ids are never compared element-wise with None.
                    # NOTE(review): rpn_scores_raw is a slice of rpn_cls, so the
                    # in-place bump presumably also alters output['rpn_cls'] -
                    # confirm this is intended.
                    if cfg.FUSSION.FRUSTUM_AVAILABLE and sample_id is not None:
                        rpn_scores_raw = self.fpnet_test(input_data['pts_input'], sample_id, rpn_scores_raw)

                    # proposal layer
                    rois, roi_scores_raw = self.rpn.proposal_layer(rpn_scores_raw, rpn_reg, backbone_xyz)  # (B, M, 7)

                    output['rois'] = rois
                    output['roi_scores_raw'] = roi_scores_raw
                    output['seg_result'] = seg_mask

                rcnn_input_info = {'rpn_xyz': backbone_xyz,
                                   'rpn_features': backbone_features.permute((0, 2, 1)),
                                   'seg_mask': seg_mask,
                                   'roi_boxes3d': rois,
                                   'pts_depth': pts_depth}
                if self.training:
                    rcnn_input_info['gt_boxes3d'] = input_data['gt_boxes3d']

                rcnn_output = self.rcnn_net(rcnn_input_info)
                output.update(rcnn_output)

        elif cfg.RCNN.ENABLED:
            output = self.rcnn_net(input_data)
        else:
            raise NotImplementedError

        return output
示例#7
0
class PointRCNN(nn.Module):
    """Two-stage detector wrapper that fuses point features with PSPNet image features.

    Per-point image features are gathered from the PSPNet output at the pixel
    indices in ``input_data['pts_img']``, concatenated with the RPN backbone
    features and reduced by ``merge_down`` before entering the RCNN head.
    """

    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Instantiate the enabled sub-networks.

        Args:
            num_classes: Forwarded to the RCNN head.
            use_xyz: Forwarded to both the RPN and the RCNN head.
            mode: Forwarded to the RPN.

        Raises:
            NotImplementedError: If RCNN is enabled with a backbone other
                than 'pointnet' or 'pointsift'.
        """
        super().__init__()

        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.PSP.ENABLED:
            self.psp = PSPNet(n_classes=1)

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        # merge down the xyz features and image features from PSPNet
        feature_channel = cfg.RPN.FP_MLPS[0][-1]
        self.merge_down = pt_utils.SharedMLP(
            [feature_channel * 2, feature_channel], bn=cfg.RPN.USE_BN)

        if cfg.RCNN.ENABLED:
            rcnn_input_channels = 128  # channels of 128 merged rpn and PSPnet feature
            if cfg.RCNN.BACKBONE == 'pointnet':
                self.rcnn_net = RCNNNet(num_classes=num_classes,
                                        input_channels=rcnn_input_channels,
                                        use_xyz=use_xyz)
            elif cfg.RCNN.BACKBONE == 'pointsift':
                pass  # accepted, but no RCNN head is built
            else:
                raise NotImplementedError

    def forward(self, input_data):
        """Run the enabled stages and return their merged outputs.

        Args:
            input_data: Dict providing ``'pts_input'`` for the RPN and, for
                the RCNN path, ``'img'``, ``'pts_img'`` and (while training)
                ``'gt_boxes3d'``. With only RCNN enabled the dict is passed
                to the head unchanged.

        Raises:
            NotImplementedError: If neither stage is enabled, or if the
                RPN+RCNN path is used without ``cfg.PSP.ENABLED`` (the fusion
                step needs the PSPNet output).
        """
        if cfg.RPN.ENABLED:
            # Only the point cloud goes to the RPN (B,C,N per the original note).
            rpn_input_info = {'pts_input': input_data['pts_input']}
            output = {}
            # rpn inference
            with torch.set_grad_enabled((not cfg.RPN.FIXED) and self.training):
                if cfg.RPN.FIXED:
                    self.rpn.eval()
                rpn_output = self.rpn(rpn_input_info)
                output.update(rpn_output)

            # rcnn inference
            if cfg.RCNN.ENABLED:
                if not cfg.PSP.ENABLED:
                    # Previously this path died later on an unbound
                    # ``psp_output``; fail early with an explicit message.
                    raise NotImplementedError(
                        'RCNN feature fusion requires cfg.PSP.ENABLED')

                with torch.no_grad():
                    rpn_cls, rpn_reg = rpn_output['rpn_cls'], rpn_output['rpn_reg']
                    backbone_xyz = rpn_output['backbone_xyz']
                    backbone_features = rpn_output['backbone_features']

                    rpn_scores_raw = rpn_cls[:, :, 0]
                    rpn_scores_norm = torch.sigmoid(rpn_scores_raw)
                    seg_mask = (rpn_scores_norm > cfg.RPN.SCORE_THRESH).float()
                    pts_depth = torch.norm(backbone_xyz, p=2, dim=2)

                    # proposal layer
                    rois, roi_scores_raw = self.rpn.proposal_layer(
                        rpn_scores_raw, rpn_reg, backbone_xyz)  # (B, M, 7)

                    output['rois'] = rois
                    output['roi_scores_raw'] = roi_scores_raw
                    output['seg_result'] = seg_mask

                if cfg.PSP.FIXED:
                    self.psp.eval()
                    with torch.no_grad():
                        psp_output = self.psp(input_data['img'])
                else:
                    # Follow the wrapper's own mode; an unconditional
                    # train() would also switch PSPNet to training behaviour
                    # while evaluating.
                    self.psp.train(self.training)
                    psp_output = self.psp(input_data['img'])

                # Gather one image feature vector per point at its pixel index.
                pts_img = input_data['pts_img']
                x_ind = pts_img[:, :, 0]
                y_ind = pts_img[:, :, 1]
                img_features = torch.stack(
                    [psp_output[i, :, y_ind[i], x_ind[i]]
                     for i in range(psp_output.shape[0])],
                    dim=0)

                # concatenate backbone_feature and img_features and then merge them
                backbone_features = torch.cat(
                    (backbone_features.unsqueeze(dim=3),
                     img_features.unsqueeze(dim=3)),
                    dim=1)
                backbone_features = self.merge_down(backbone_features)
                backbone_features = backbone_features.squeeze(dim=3)

                rcnn_input_info = {
                    'rpn_xyz': backbone_xyz,
                    'rpn_features': backbone_features.permute((0, 2, 1)),  # (B N C)
                    'seg_mask': seg_mask,
                    'roi_boxes3d': rois,
                    'pts_depth': pts_depth
                }
                if self.training:
                    rcnn_input_info['gt_boxes3d'] = input_data['gt_boxes3d']

                rcnn_output = self.rcnn_net(rcnn_input_info)
                output.update(rcnn_output)

        elif cfg.RCNN.ENABLED:
            output = self.rcnn_net(input_data)
        else:
            raise NotImplementedError

        return output
示例#8
0
    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Instantiate the RPN and the RCNN head chosen by ``cfg.RCNN.BACKBONE``.

        Prints the parameter count of the RCNN head when one was built.

        Args:
            num_classes: Forwarded to the RCNN head.
            use_xyz: Forwarded to both the RPN and the RCNN head.
            mode: Forwarded to the RPN.

        Raises:
            NotImplementedError: If RCNN is enabled with an unknown backbone.
        """
        super().__init__()

        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        if cfg.RCNN.ENABLED:
            rcnn_input_channels = 128  # channels of rpn features
            backbone = cfg.RCNN.BACKBONE
            head_cls = None
            if backbone == 'pointnet':
                head_cls = RCNNNet
            elif backbone == 'pointsift':
                pass  # accepted, but no RCNN head is built
            elif backbone == 'edgeconv':
                rcnn_input_channels = cfg.RCNN.GCN_CONFIG.FILTERS[-1] * 2
                head_cls = GCNNet
            elif backbone == 'rotnet':
                rcnn_input_channels = get_num_rot(
                    cfg.RCNN.ROT_CONFIG.DEGREE_RES)
                head_cls = RotRCNN
            elif backbone == 'deepgcn':
                head_cls = DenseRCNN
            elif backbone == 'refine':
                head_cls = RefineRCNNNet
            elif backbone == 'deeprefine':
                head_cls = RefineDeepRCNNNet
            elif backbone == 'deepfeats':
                head_cls = DenseFeatRCNN
            elif backbone == 'deepfeatsrefine':
                head_cls = DenseFeatRefineRCNN
            else:
                raise NotImplementedError

            # Every head shares the same constructor keywords.
            if head_cls is not None:
                self.rcnn_net = head_cls(num_classes=num_classes,
                                         input_channels=rcnn_input_channels,
                                         use_xyz=use_xyz)

        if cfg.CGCN.ENABLED:
            self.cgcn_net = None

        # rcnn_net is absent for RPN-only configs and for 'pointsift'; the
        # unguarded print raised AttributeError in those cases.
        if hasattr(self, 'rcnn_net'):
            print(sum(p.numel() for p in self.rcnn_net.parameters()))
示例#9
0
class PointRCNN(nn.Module):
    """Two-stage detector wrapper with a configurable RCNN backbone.

    The RCNN head class is picked from ``cfg.RCNN.BACKBONE``; which stages
    run is decided by ``cfg.RPN.ENABLED`` and ``cfg.RCNN.ENABLED``.
    """

    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Instantiate the RPN and the selected RCNN head.

        Prints the parameter count of the RCNN head when one was built.

        Args:
            num_classes: Forwarded to the RCNN head.
            use_xyz: Forwarded to both the RPN and the RCNN head.
            mode: Forwarded to the RPN.

        Raises:
            NotImplementedError: If RCNN is enabled with an unknown backbone.
        """
        super().__init__()

        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        if cfg.RCNN.ENABLED:
            rcnn_input_channels = 128  # channels of rpn features
            backbone = cfg.RCNN.BACKBONE
            head_cls = None
            if backbone == 'pointnet':
                head_cls = RCNNNet
            elif backbone == 'pointsift':
                pass  # accepted, but no RCNN head is built
            elif backbone == 'edgeconv':
                rcnn_input_channels = cfg.RCNN.GCN_CONFIG.FILTERS[-1] * 2
                head_cls = GCNNet
            elif backbone == 'rotnet':
                rcnn_input_channels = get_num_rot(
                    cfg.RCNN.ROT_CONFIG.DEGREE_RES)
                head_cls = RotRCNN
            elif backbone == 'deepgcn':
                head_cls = DenseRCNN
            elif backbone == 'refine':
                head_cls = RefineRCNNNet
            elif backbone == 'deeprefine':
                head_cls = RefineDeepRCNNNet
            elif backbone == 'deepfeats':
                head_cls = DenseFeatRCNN
            elif backbone == 'deepfeatsrefine':
                head_cls = DenseFeatRefineRCNN
            else:
                raise NotImplementedError

            # Every head shares the same constructor keywords.
            if head_cls is not None:
                self.rcnn_net = head_cls(num_classes=num_classes,
                                         input_channels=rcnn_input_channels,
                                         use_xyz=use_xyz)

        if cfg.CGCN.ENABLED:
            self.cgcn_net = None

        # rcnn_net is absent for RPN-only configs and for 'pointsift'; the
        # unguarded print raised AttributeError in those cases.
        if hasattr(self, 'rcnn_net'):
            print(sum(p.numel() for p in self.rcnn_net.parameters()))

    def forward(self, input_data):
        """Run the enabled stages and return their merged outputs.

        With the RPN enabled its outputs are collected first; if RCNN is
        enabled too, proposals are generated without gradient tracking and
        handed to the RCNN head together with the backbone features. With
        only RCNN enabled, ``input_data`` goes straight to the head.

        Raises:
            NotImplementedError: If neither stage is enabled.
        """
        if cfg.RPN.ENABLED:
            output = {}
            # rpn inference
            with torch.set_grad_enabled((not cfg.RPN.FIXED) and self.training):
                if cfg.RPN.FIXED:
                    self.rpn.eval()
                rpn_output = self.rpn(input_data)
                output.update(rpn_output)

            # rcnn inference
            if cfg.RCNN.ENABLED:
                with torch.no_grad():
                    rpn_cls, rpn_reg = rpn_output['rpn_cls'], rpn_output['rpn_reg']
                    backbone_xyz = rpn_output['backbone_xyz']
                    backbone_features = rpn_output['backbone_features']

                    rpn_scores_raw = rpn_cls[:, :, 0]
                    rpn_scores_norm = torch.sigmoid(rpn_scores_raw)
                    seg_mask = (rpn_scores_norm > cfg.RPN.SCORE_THRESH).float()
                    pts_depth = torch.norm(backbone_xyz, p=2, dim=2)

                    # proposal layer
                    rois, roi_scores_raw = self.rpn.proposal_layer(
                        rpn_scores_raw, rpn_reg, backbone_xyz)  # (B, M, 7)

                    output['rois'] = rois
                    output['roi_scores_raw'] = roi_scores_raw
                    output['seg_result'] = seg_mask

                rcnn_input_info = {
                    'rpn_xyz': backbone_xyz,
                    'rpn_features': backbone_features.permute((0, 2, 1)),
                    'seg_mask': seg_mask,
                    'roi_boxes3d': rois,
                    'pts_depth': pts_depth
                }
                if self.training:
                    rcnn_input_info['gt_boxes3d'] = input_data['gt_boxes3d']

                rcnn_output = self.rcnn_net(rcnn_input_info)
                output.update(rcnn_output)

        elif cfg.RCNN.ENABLED:
            output = self.rcnn_net(input_data)
        else:
            raise NotImplementedError
        return output
示例#10
0
class cat(nn.Module):
    """RPN + RCNN wrapper followed by a cascade refinement (or IoU) stage.

    After the RCNN head, ``cfg.TRAIN.IOU_LAYER`` selects the extra stage:
    any value other than 'split' runs ``self.cascade`` on the RCNN inputs plus
    the first-stage box predictions; 'split' runs ``self.iou``.
    """

    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Instantiate the RPN, the RCNN head and the cascade stage.

        Args:
            num_classes: Forwarded to the RCNN head and the cascade stage.
            use_xyz: Forwarded to the RPN, the RCNN head and the cascade stage.
            mode: Forwarded to the RPN.

        Raises:
            NotImplementedError: If RCNN is enabled with a backbone other
                than 'pointnet' or 'pointsift'.
        """
        super().__init__()

        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        # Defined up front because the cascade stage below needs it even when
        # RCNN is disabled (it used to be unbound in that configuration).
        rcnn_input_channels = 128  # channels of rpn features
        if cfg.RCNN.ENABLED:
            if cfg.RCNN.BACKBONE == 'pointnet':
                self.rcnn_net = RCNNNet(num_classes=num_classes,
                                        input_channels=rcnn_input_channels,
                                        use_xyz=use_xyz)
            elif cfg.RCNN.BACKBONE == 'pointsift':
                pass  # accepted, but no RCNN head is built
            else:
                raise NotImplementedError
        self.cascade = cascade(num_classes=num_classes,
                               input_channels=rcnn_input_channels,
                               use_xyz=use_xyz)
        # NOTE(review): no ``self.iou`` is created here (its construction was
        # commented out), yet forward() uses it when IOU_LAYER == 'split'.

    def forward(self, input_data):
        """Run the enabled stages and return their merged outputs.

        Raises:
            NotImplementedError: If neither RPN nor RCNN is enabled.
        """
        if cfg.RPN.ENABLED:
            output = {}
            # rpn inference
            with torch.set_grad_enabled((not cfg.RPN.FIXED) and self.training):
                if cfg.RPN.FIXED:
                    self.rpn.eval()
                rpn_output = self.rpn(input_data)
                output.update(rpn_output)

            # rcnn inference
            if cfg.RCNN.ENABLED:
                with torch.no_grad():
                    rpn_cls, rpn_reg = rpn_output['rpn_cls'], rpn_output['rpn_reg']
                    backbone_xyz = rpn_output['backbone_xyz']
                    backbone_features = rpn_output['backbone_features']
                    rpn_scores_raw = rpn_cls[:, :, 0]
                    rpn_scores_norm = torch.sigmoid(rpn_scores_raw)
                    seg_mask = (rpn_scores_norm > cfg.RPN.SCORE_THRESH).float()
                    pts_depth = torch.norm(backbone_xyz, p=2, dim=2)

                    # proposal layer
                    rois, roi_scores_raw = self.rpn.proposal_layer(
                        rpn_scores_raw, rpn_reg, backbone_xyz)  # (B, M, 7)
                    output['rois'] = rois
                    output['roi_scores_raw'] = roi_scores_raw
                    output['seg_result'] = seg_mask

                rcnn_input_info = {
                    'rpn_xyz': backbone_xyz,
                    'rpn_features': backbone_features.permute((0, 2, 1)),
                    'seg_mask': seg_mask,
                    'roi_boxes3d': rois,
                    'pts_depth': pts_depth
                }
                if self.training:
                    rcnn_input_info['gt_boxes3d'] = input_data['gt_boxes3d']
                rcnn_output = self.rcnn_net(rcnn_input_info)
                output.update(rcnn_output)

                if cfg.TRAIN.IOU_LAYER != 'split':
                    # Second stage: same inputs plus the first-stage boxes.
                    cascade_input = rcnn_input_info.copy()
                    cascade_input['pred_boxes3d_1st'] = output['pred_boxes3d_1st']
                    cascade_output = self.cascade(cascade_input)
                    output.update(cascade_output)
                else:
                    # NOTE(review): ``self.iou`` is not set in __init__, so
                    # this branch fails unless it is attached elsewhere.
                    iou_out = self.iou(rcnn_output)
                    output.update(iou_out)

        elif cfg.RCNN.ENABLED:
            output = self.rcnn_net(input_data)
        else:
            raise NotImplementedError

        return output
class GeneralizedPointRCNN(nn.Module):
    """RPN + RCNN wrapper with optional domain-adaptation (DA) heads.

    When ``cfg.DA.ENABLED`` is set, ``da_rpn`` and ``da_rcnn`` heads are
    built; they are applied to the RPN backbone features and to the RCNN
    ``'l_features'`` output respectively, depending on further ``cfg.DA`` flags.
    """

    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Instantiate the enabled sub-networks.

        Args:
            num_classes: Forwarded to the RCNN head.
            use_xyz: Forwarded to both the RPN and the RCNN head.
            mode: Forwarded to the RPN.

        Raises:
            NotImplementedError: If RCNN is enabled with a backbone other
                than 'pointnet' or 'pointsift'.
        """
        super().__init__()

        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        if cfg.RCNN.ENABLED:
            rcnn_input_channels = 128  # channels of rpn features
            if cfg.RCNN.BACKBONE == 'pointnet':
                self.rcnn_net = RCNNNet(num_classes=num_classes,
                                        input_channels=rcnn_input_channels,
                                        use_xyz=use_xyz)
            elif cfg.RCNN.BACKBONE == 'pointsift':
                pass  # accepted, but no RCNN head is built
            else:
                raise NotImplementedError

        if cfg.DA.ENABLED:
            self.da_rpn = da_rpn(cfg)
            self.da_rcnn = da_rcnn(cfg)

    def forward(self, input_data):
        """Run the enabled stages and return their merged outputs.

        Shapes below are the original author's notes, not checked here.

        :param input_data: dict
            pts_input: B,16384,3
            gt_boxes3d: B,14(?),7
            is_source: B*2,
        :return: dict with the outputs of every executed stage
        :raises NotImplementedError: if neither RPN nor RCNN is enabled
        """
        if cfg.RPN.ENABLED:
            output = {}
            # rpn inference
            with torch.set_grad_enabled((not cfg.RPN.FIXED) and self.training):
                if cfg.RPN.FIXED:
                    self.rpn.eval()
                rpn_output = self.rpn(input_data)
                output.update(rpn_output)
                if cfg.DA.ENABLED and cfg.DA.DA_IMG.ENABLED:
                    # The DA head on backbone features is frozen with the RPN.
                    if cfg.RPN.FIXED:
                        self.da_rpn.eval()
                    da_rpn_output = self.da_rpn(
                        rpn_output['backbone_features'])
                    output.update(da_rpn_output)

            # rpn_output (author's notes):
            #     rpn_cls: B,N,1 (Foreground Mask)
            #     rpn_reg: B,N,76 (3D RoIs)
            #     backbone_xyz: B,N,3 (Point Coords)
            #     backbone_features: B,128,N (Semantic Features)

            # rcnn inference
            if cfg.RCNN.ENABLED:
                with torch.no_grad():
                    rpn_cls, rpn_reg = rpn_output['rpn_cls'], rpn_output['rpn_reg']
                    backbone_xyz = rpn_output['backbone_xyz']
                    backbone_features = rpn_output['backbone_features']

                    rpn_scores_raw = rpn_cls[:, :, 0]
                    rpn_scores_norm = torch.sigmoid(rpn_scores_raw)
                    seg_mask = (rpn_scores_norm > cfg.RPN.SCORE_THRESH).float()
                    # distance, sqrt(x**2+y**2+z**2)
                    pts_depth = torch.norm(backbone_xyz, p=2, dim=2)

                    # proposal layer (leftover debugger breakpoint removed)
                    rois, roi_scores_raw = self.rpn.proposal_layer(
                        rpn_scores_raw, rpn_reg, backbone_xyz)  # (B, M, 7)
                    output['rois'] = rois
                    output['roi_scores_raw'] = roi_scores_raw
                    output['seg_result'] = seg_mask

                rcnn_input_info = {
                    'rpn_xyz': backbone_xyz,
                    'rpn_features': backbone_features.permute((0, 2, 1)).contiguous(),
                    'seg_mask': seg_mask,
                    'roi_boxes3d': rois,
                    'pts_depth': pts_depth
                }

                if self.training:
                    # To avoid using gt when sampling
                    rcnn_input_info['is_source'] = input_data['is_source']
                    rcnn_input_info['gt_boxes3d'] = input_data['gt_boxes3d']

                rcnn_output = self.rcnn_net(rcnn_input_info)
                output.update(rcnn_output)

                if cfg.DA.ENABLED and cfg.DA.DA_INS.ENABLED and self.training:
                    # l_features: [B*64, 512, 133] per the original note
                    da_rcnn_output = self.da_rcnn(output['l_features'])
                    output.update(da_rcnn_output)

        elif cfg.RCNN.ENABLED:
            # Leftover debugger breakpoint removed from this branch as well.
            output = self.rcnn_net(input_data)
        else:
            raise NotImplementedError
        return output
示例#12
0
class PointRCNN(nn.Module):
    """Two-stage network that chains an RPN stage and an RCNN stage.

    Which stages are built is decided by the global ``cfg``:
    ``cfg.RPN.ENABLED`` creates ``self.rpn`` and ``cfg.RCNN.ENABLED`` creates
    ``self.rcnn_net``. At least one of the two must be enabled.
    """

    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Build the sub-networks selected by ``cfg``.

        Args:
            num_classes: Forwarded to ``RCNNNet``.
            use_xyz: Forwarded to both ``RPN`` and ``RCNNNet``.
            mode: Forwarded to ``RPN``.

        Raises:
            AssertionError: If neither the RPN nor the RCNN stage is enabled.
            NotImplementedError: If the RCNN stage is enabled and
                ``cfg.RCNN.BACKBONE`` is anything other than ``'pointnet'``.
        """
        super().__init__()

        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        if cfg.RCNN.ENABLED:
            rcnn_input_channels = 128  # channels of rpn features
            if cfg.RCNN.BACKBONE == 'pointnet':
                self.rcnn_net = RCNNNet(num_classes=num_classes,
                                        input_channels=rcnn_input_channels,
                                        use_xyz=use_xyz)
            else:
                # 'pointsift' used to fall through silently here, leaving
                # self.rcnn_net undefined so that forward() crashed later with
                # an AttributeError. Fail at construction time instead.
                raise NotImplementedError(
                    'RCNN backbone %r is not implemented' %
                    (cfg.RCNN.BACKBONE, ))

    def forward(self, input_data):
        """Run the enabled stages on one batch.

        Args:
            input_data: Dict handed to ``self.rpn`` when the RPN stage is
                enabled, otherwise handed directly to ``self.rcnn_net``.
                When both stages are enabled and the module is in training
                mode it must contain the key ``'gt_boxes3d'``.

        Returns:
            dict: With the RPN stage enabled, everything returned by
            ``self.rpn``; if the RCNN stage is enabled as well, additionally
            ``'rois'``, ``'roi_scores_raw'``, ``'seg_result'`` and everything
            returned by ``self.rcnn_net``. With only the RCNN stage enabled,
            the return value of ``self.rcnn_net(input_data)``.

        Raises:
            NotImplementedError: If neither stage is enabled.
        """
        if cfg.RPN.ENABLED:
            output = {}
            # rpn inference: gradients are tracked only when the RPN is not
            # fixed and the module is in training mode.
            with torch.set_grad_enabled((not cfg.RPN.FIXED) and self.training):
                if cfg.RPN.FIXED:
                    # keep the fixed RPN in eval mode even while training
                    self.rpn.eval()
                rpn_output = self.rpn(input_data)
                output.update(rpn_output)

            # rcnn inference
            if cfg.RCNN.ENABLED:
                # Proposals are generated without gradient tracking.
                with torch.no_grad():
                    rpn_cls, rpn_reg = rpn_output['rpn_cls'], rpn_output[
                        'rpn_reg']
                    backbone_xyz, backbone_features = rpn_output[
                        'backbone_xyz'], rpn_output['backbone_features']

                    # Only channel 0 of the classification output is used.
                    rpn_scores_raw = rpn_cls[:, :, 0]
                    rpn_scores_norm = torch.sigmoid(rpn_scores_raw)
                    seg_mask = (rpn_scores_norm > cfg.RPN.SCORE_THRESH).float()
                    # L2 norm of every point over its coordinate axis.
                    pts_depth = torch.norm(backbone_xyz, p=2, dim=2)

                    # proposal layer
                    rois, roi_scores_raw = self.rpn.proposal_layer(
                        rpn_scores_raw, rpn_reg, backbone_xyz)  # (B, M, 7)
                    output['rois'] = rois
                    output['roi_scores_raw'] = roi_scores_raw
                    output['seg_result'] = seg_mask

                rcnn_input_info = {
                    'rpn_xyz': backbone_xyz,
                    # NOTE(review): presumably swaps (B, C, N) to (B, N, C);
                    # confirm against the RPN backbone output layout.
                    'rpn_features': backbone_features.permute((0, 2, 1)),
                    'seg_mask': seg_mask,
                    'roi_boxes3d': rois,
                    'pts_depth': pts_depth
                }
                if self.training:
                    rcnn_input_info['gt_boxes3d'] = input_data['gt_boxes3d']

                rcnn_output = self.rcnn_net(rcnn_input_info)
                output.update(rcnn_output)

        elif cfg.RCNN.ENABLED:
            # RCNN-only mode: the caller supplies the RCNN inputs directly.
            output = self.rcnn_net(input_data)
        else:
            raise NotImplementedError
        return output