# Shared imports for the snippets below; the module paths assume the
# PointRCNN-style repository layout these examples were extracted from.
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.config import cfg
import lib.utils.loss_utils as loss_utils
import lib.utils.kitti_utils as kitti_utils
import lib.utils.iou3d.iou3d_utils as iou3d_utils
import lib.utils.roipool3d.roipool3d_utils as roipool3d_utils
from lib.utils.bbox_transform import decode_bbox_target

# mean box size per class, used as the regression anchor (assumed config key)
MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()


def get_rcnn_loss(model, ret_dict, tb_dict):
        rcnn_cls, rcnn_reg = ret_dict['rcnn_cls'], ret_dict['rcnn_reg']

        cls_label = ret_dict['cls_label'].float()
        reg_valid_mask = ret_dict['reg_valid_mask']
        roi_boxes3d = ret_dict['roi_boxes3d']
        roi_size = roi_boxes3d[:, 3:6]
        gt_boxes3d_ct = ret_dict['gt_of_rois']
        pts_input = ret_dict['pts_input']

        # rcnn classification loss
        if isinstance(model, nn.DataParallel):
            cls_loss_func = model.module.rcnn_net.cls_loss_func
        else:
            cls_loss_func = model.rcnn_net.cls_loss_func

        cls_label_flat = cls_label.view(-1)

        if cfg.RCNN.LOSS_CLS == 'SigmoidFocalLoss':
            rcnn_cls_flat = rcnn_cls.view(-1)

            cls_target = (cls_label_flat > 0).float()
            pos = (cls_label_flat > 0).float()
            neg = (cls_label_flat == 0).float()
            cls_weights = pos + neg
            pos_normalizer = pos.sum()
            cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)

            rcnn_loss_cls = cls_loss_func(rcnn_cls_flat, cls_target,
                                          cls_weights)
            rcnn_loss_cls_pos = (rcnn_loss_cls * pos).sum()
            rcnn_loss_cls_neg = (rcnn_loss_cls * neg).sum()
            rcnn_loss_cls = rcnn_loss_cls.sum()
            tb_dict['rcnn_loss_cls_pos'] = rcnn_loss_cls_pos.item()
            tb_dict['rcnn_loss_cls_neg'] = rcnn_loss_cls_neg.item()

        elif cfg.RCNN.LOSS_CLS == 'BinaryCrossEntropy':
            rcnn_cls_flat = rcnn_cls.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat), cls_label_flat, reduction='none')
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)

        elif cfg.RCNN.LOSS_CLS == 'CrossEntropy':
            rcnn_cls_reshape = rcnn_cls.view(rcnn_cls.shape[0], -1)
            cls_target = cls_label_flat.long()
            cls_valid_mask = (cls_label_flat >= 0).float()

            batch_loss_cls = cls_loss_func(rcnn_cls_reshape, cls_target)
            normalizer = torch.clamp(cls_valid_mask.sum(), min=1.0)
            rcnn_loss_cls = (batch_loss_cls.mean(dim=1) *
                             cls_valid_mask).sum() / normalizer

        else:
            raise NotImplementedError

        # rcnn regression loss
        batch_size = pts_input.shape[0]
        fg_mask = (reg_valid_mask > 0)
        fg_sum = fg_mask.long().sum().item()
        if fg_sum != 0:
            all_anchor_size = roi_size
            anchor_size = all_anchor_size[
                fg_mask] if cfg.RCNN.SIZE_RES_ON_ROI else MEAN_SIZE

            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg.view(batch_size, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)

            loss_size = 3 * loss_size  # consistent with old codes
            rcnn_loss_reg = loss_loc + loss_angle + loss_size
            tb_dict.update(reg_loss_dict)
        else:
            loss_loc = loss_angle = loss_size = rcnn_loss_reg = rcnn_loss_cls * 0

        rcnn_loss = rcnn_loss_cls + rcnn_loss_reg
        tb_dict['rcnn_loss_cls'] = rcnn_loss_cls.item()
        tb_dict['rcnn_loss_reg'] = rcnn_loss_reg.item()
        tb_dict['rcnn_loss'] = rcnn_loss.item()

        tb_dict['rcnn_loss_loc'] = loss_loc.item()
        tb_dict['rcnn_loss_angle'] = loss_angle.item()
        tb_dict['rcnn_loss_size'] = loss_size.item()

        tb_dict['rcnn_cls_fg'] = (cls_label > 0).sum().item()
        tb_dict['rcnn_cls_bg'] = (cls_label == 0).sum().item()
        tb_dict['rcnn_reg_fg'] = reg_valid_mask.sum().item()

        return rcnn_loss
    def get_rpn_loss(model, rpn_cls, rpn_reg, rpn_cls_label, rpn_reg_label,
                     tb_dict):
        if isinstance(model, nn.DataParallel):
            rpn_cls_loss_func = model.module.rpn.rpn_cls_loss_func
        else:
            rpn_cls_loss_func = model.rpn.rpn_cls_loss_func

        rpn_cls_label_flat = rpn_cls_label.view(-1)
        rpn_cls_flat = rpn_cls.view(-1)
        fg_mask = (rpn_cls_label_flat > 0)

        # RPN classification loss
        if cfg.RPN.LOSS_CLS == 'DiceLoss':
            rpn_loss_cls = rpn_cls_loss_func(rpn_cls, rpn_cls_label_flat)

        elif cfg.RPN.LOSS_CLS == 'SigmoidFocalLoss':
            rpn_cls_target = (rpn_cls_label_flat > 0).float()
            pos = (rpn_cls_label_flat > 0).float()
            neg = (rpn_cls_label_flat == 0).float()
            cls_weights = pos + neg
            pos_normalizer = pos.sum()
            cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)
            rpn_loss_cls = rpn_cls_loss_func(rpn_cls_flat, rpn_cls_target,
                                             cls_weights)
            rpn_loss_cls_pos = (rpn_loss_cls * pos).sum()
            rpn_loss_cls_neg = (rpn_loss_cls * neg).sum()
            rpn_loss_cls = rpn_loss_cls.sum()
            tb_dict['rpn_loss_cls_pos'] = rpn_loss_cls_pos.item()
            tb_dict['rpn_loss_cls_neg'] = rpn_loss_cls_neg.item()

        elif cfg.RPN.LOSS_CLS == 'BinaryCrossEntropy':
            weight = rpn_cls_flat.new(rpn_cls_flat.shape[0]).fill_(1.0)
            weight[fg_mask] = cfg.RPN.FG_WEIGHT
            rpn_cls_label_target = (rpn_cls_label_flat > 0).float()
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rpn_cls_flat),
                rpn_cls_label_target,
                weight=weight,
                reduction='none')
            cls_valid_mask = (rpn_cls_label_flat >= 0).float()
            rpn_loss_cls = (batch_loss_cls *
                            cls_valid_mask).sum() / torch.clamp(
                                cls_valid_mask.sum(), min=1.0)
        else:
            raise NotImplementedError

        # RPN regression loss
        point_num = rpn_reg.size(0) * rpn_reg.size(1)
        fg_sum = fg_mask.long().sum().item()
        if fg_sum != 0:
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rpn_reg.view(point_num, -1)[fg_mask],
                                        rpn_reg_label.view(point_num, 7)[fg_mask],
                                        loc_scope=cfg.RPN.LOC_SCOPE,
                                        loc_bin_size=cfg.RPN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RPN.NUM_HEAD_BIN,
                                        anchor_size=MEAN_SIZE,
                                        get_xz_fine=cfg.RPN.LOC_XZ_FINE,
                                        get_y_by_bin=False,
                                        get_ry_fine=False)

            loss_size = 3 * loss_size  # consistent with old codes
            rpn_loss_reg = loss_loc + loss_angle + loss_size
        else:
            loss_loc = loss_angle = loss_size = rpn_loss_reg = rpn_loss_cls * 0

        rpn_loss = (rpn_loss_cls * cfg.RPN.LOSS_WEIGHT[0] +
                    rpn_loss_reg * cfg.RPN.LOSS_WEIGHT[1])

        tb_dict.update({
            'rpn_loss_cls': rpn_loss_cls.item(),
            'rpn_loss_reg': rpn_loss_reg.item(),
            'rpn_loss': rpn_loss.item(),
            'rpn_fg_sum': fg_sum,
            'rpn_loss_loc': loss_loc.item(),
            'rpn_loss_angle': loss_angle.item(),
            'rpn_loss_size': loss_size.item()
        })

        return rpn_loss
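
Both loss functions above normalize the SigmoidFocalLoss weights by the clamped foreground count rather than the batch size, so a frame with few positives still yields comparably scaled gradients. A minimal, self-contained sketch of that weighting with dummy labels (the -1 = ignore, 0 = background, >0 = foreground convention is taken from the code above):

import torch

labels = torch.tensor([1., 0., -1., 1., 0.])  # -1 ignore, 0 background, 1 foreground
pos = (labels > 0).float()                    # [1, 0, 0, 1, 0]
neg = (labels == 0).float()                   # [0, 1, 0, 0, 1]
cls_weights = pos + neg                       # ignored points get weight 0
cls_weights = cls_weights / torch.clamp(pos.sum(), min=1.0)
print(cls_weights)                            # tensor([0.5000, 0.5000, 0.0000, 0.5000, 0.5000])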
Example #3
    def forward(self, input_data):
        """
        :param input_data: input dict
        :return:
        """
        input_data2 = input_data.copy()
        pred_boxes3d_1st = input_data2['pred_boxes3d_1st']
        ret_dict = {}
        batch_size = input_data['roi_boxes3d'].size(0)
        if self.training:

            input_data2['roi_boxes3d'] = pred_boxes3d_1st
            with torch.no_grad():
                target_dict_2nd = self.proposal_target_layer(input_data2,
                                                             stage=2)
            pts_input_2 = torch.cat((target_dict_2nd['sampled_pts'],
                                     target_dict_2nd['pts_feature']),
                                    dim=2)
            target_dict_2nd['pts_input'] = pts_input_2
            roi = target_dict_2nd['roi_boxes3d']
            #roi = pred_boxes3d_1st

        else:
            input_data2['roi_boxes3d'] = pred_boxes3d_1st
            #input_data2['roi_boxes3d']=torch.cat((pred_boxes3d_1st, input_data['roi_boxes3d']), 1)
            roi = pred_boxes3d_1st
            #roi=torch.cat((pred_boxes3d_1st, input_data['roi_boxes3d']), 1)
            pts_input_2 = self.roipooling(input_data2)

        xyz_2, features_2 = self._break_up_pc(pts_input_2)
        #print(xyz_2.size(),xyz.size(),features_2.size(),features.size())
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_2 = pts_input_2[...,
                                      0:self.rcnn_input_channel].transpose(
                                          1, 2).unsqueeze(dim=3)
            xyz_feature_2 = self.xyz_up_layer(xyz_input_2)

            rpn_feature_2 = pts_input_2[...,
                                        self.rcnn_input_channel:].transpose(
                                            1, 2).unsqueeze(dim=3)

            merged_feature_2 = torch.cat((xyz_feature_2, rpn_feature_2), dim=1)
            merged_feature_2 = self.merge_down_layer(merged_feature_2)
            l_xyz_2, l_features_2 = [xyz_2], [merged_feature_2.squeeze(dim=3)]
        else:
            l_xyz_2, l_features_2 = [xyz_2], [features_2]
        #print(l_xyz_2[0].size(), l_xyz[0].size(), l_features_2[0].size(), l_features[0].size())
        for i in range(len(self.SA_modules)):
            li_xyz_2, li_features_2 = self.SA_modules[i](l_xyz_2[i],
                                                         l_features_2[i])
            l_xyz_2.append(li_xyz_2)
            l_features_2.append(li_features_2)

        batch_size_2 = pts_input_2.shape[0]
        anchor_size = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
        rcnn_cls_2nd = self.cls_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_2nd = self.reg_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        pre_iou2 = self.iou_layer(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)
        #loss

        if self.training:
            cls_label = target_dict_2nd['cls_label'].float()
            rcnn_cls_flat = rcnn_cls_2nd.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat),
                cls_label.view(-1),
                reduction='none')
            cls_label_flat = cls_label.view(-1)
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict_2nd['gt_of_rois']
            reg_valid_mask = target_dict_2nd['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            #print(rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask].size(0))
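            # fallback: if no foreground ROI survives sampling, compute the reg loss
            # on the background ROIs instead so the graph still produces a gradient
            # (apparent intent of the check below)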
            if rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask].size(0) == 0:
                fg_mask = (reg_valid_mask <= 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size

            two = {
                'rcnn_loss_cls_2nd': rcnn_loss_cls,
                'rcnn_loss_reg_2nd': rcnn_loss_reg
            }

        else:
            two = {}

        sec = {'rcnn_cls_2nd': rcnn_cls_2nd, 'rcnn_reg_2nd': rcnn_reg_2nd}
        #print(input_data['roi_boxes3d'].shape,input_data2['roi_boxes3d'].shape)

        pred_boxes3d_2nd = decode_bbox_target(
            roi.view(-1, 7),
            rcnn_reg_2nd.view(-1, rcnn_reg_2nd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)
        input_data3 = input_data.copy()
        if self.training:

            input_data3['roi_boxes3d'] = pred_boxes3d_2nd
            # print(input_data3['roi_boxes3d'].shape)
            with torch.no_grad():
                target_dict_3rd = self.proposal_target_layer(input_data3,
                                                             stage=3)

            pts_input_3 = torch.cat((target_dict_3rd['sampled_pts'],
                                     target_dict_3rd['pts_feature']),
                                    dim=2)
            target_dict_3rd['pts_input'] = pts_input_3
            roi = target_dict_3rd['roi_boxes3d']
            #roi = pred_boxes3d_2nd
        else:
            input_data3['roi_boxes3d'] = pred_boxes3d_2nd
            # input_data3['roi_boxes3d']=torch.cat((pred_boxes3d_2nd, input_data2['roi_boxes3d']), 1)
            roi = pred_boxes3d_2nd
            # roi=torch.cat((pred_boxes3d_2nd, input_data2['roi_boxes3d']), 1)
            pts_input_3 = self.roipooling(input_data3)
        xyz_3, features_3 = self._break_up_pc(pts_input_3)

        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_3 = pts_input_3[...,
                                      0:self.rcnn_input_channel].transpose(
                                          1, 2).unsqueeze(dim=3)
            xyz_feature_3 = self.xyz_up_layer_3(xyz_input_3)

            rpn_feature_3 = pts_input_3[...,
                                        self.rcnn_input_channel:].transpose(
                                            1, 2).unsqueeze(dim=3)

            merged_feature_3 = torch.cat((xyz_feature_3, rpn_feature_3), dim=1)
            merged_feature_3 = self.merge_down_layer_3(merged_feature_3)
            l_xyz_3, l_features_3 = [xyz_3], [merged_feature_3.squeeze(dim=3)]
        else:
            l_xyz_3, l_features_3 = [xyz_3], [features_3]

        for i in range(len(self.SA_modules_3)):
            li_xyz_3, li_features_3 = self.SA_modules_3[i](l_xyz_3[i],
                                                           l_features_3[i])
            l_xyz_3.append(li_xyz_3)
            l_features_3.append(li_features_3)
        del xyz_2, features_2, l_features_2
        rcnn_cls_3rd = self.cls_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_3rd = self.reg_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        pre_iou3 = self.iou_layer(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)
        # loss
        if self.training:
            cls_label = target_dict_3rd['cls_label'].float()
            rcnn_cls_flat = rcnn_cls_3rd.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat),
                cls_label.view(-1),
                reduction='none')
            cls_label_flat = cls_label.view(-1)
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict_3rd['gt_of_rois']
            reg_valid_mask = target_dict_3rd['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)

            if rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask].size(0) == 0:
                fg_mask = (reg_valid_mask <= 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size

            # three = {'rcnn_loss_cls_3rd': rcnn_loss_cls, 'rcnn_loss_reg_3rd': rcnn_loss_reg}

        else:
            three = {}
        pred_boxes3d_3rd = decode_bbox_target(
            roi.view(-1, 7),
            rcnn_reg_3rd.view(-1, rcnn_reg_3rd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)
        if self.training:
            gt = target_dict_3rd['real_gt']
            iou_label = []
            for i in range(batch_size_2):
                iou_label.append(
                    iou3d_utils.boxes_iou3d_gpu(
                        pred_boxes3d_3rd.view(-1, 7)[i].view(1, 7),
                        gt[i].view(1, 7)))
            iou_label = torch.cat(iou_label)
            iou_label = (iou_label - 0.5) * 2
            iou_loss = F.mse_loss((pre_iou3[fg_mask]), iou_label[fg_mask])
            #print(iou_loss.item())
            three = {
                'rcnn_loss_cls_3rd': rcnn_loss_cls,
                'rcnn_loss_reg_3rd': rcnn_loss_reg,
                'rcnn_iou_loss': iou_loss
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask
        pre_iou3 = pre_iou3 / 2 + 0.5
        pre_iou2 = pre_iou2 / 2 + 0.5
        ret_dict = {
            'rcnn_cls_3rd': rcnn_cls_3rd,
            'rcnn_reg_3rd': rcnn_reg_3rd,
            'pred_boxes3d_1st': pred_boxes3d_1st,
            'pred_boxes3d_2nd': pred_boxes3d_2nd,
            'pred_boxes3d_3rd': pred_boxes3d_3rd,
            'pre_iou3': pre_iou3,
            'pre_iou2': pre_iou2
        }
        ret_dict.update(sec)
        ret_dict.update(two)
        ret_dict.update(three)

        return ret_dict
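
The IoU head in this example regresses a rescaled IoU: the training target is (iou_label - 0.5) * 2, which maps IoU in [0, 1] onto [-1, 1], and the prediction is mapped back with pred / 2 + 0.5 before being returned. A small round-trip check of that mapping (plain torch, nothing model-specific assumed):

import torch

iou = torch.tensor([0.0, 0.5, 0.7, 1.0])
target = (iou - 0.5) * 2        # tensor([-1.0000, 0.0000, 0.4000, 1.0000])
recovered = target / 2 + 0.5    # the inverse mapping applied to pre_iou2 / pre_iou3
assert torch.allclose(recovered, iou)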
Example #4
    def forward(self, input_data):
        """
        :param input_data: input dict
        :return:
        """

        if cfg.RCNN.ROI_SAMPLE_JIT:
            if self.training:
                with torch.no_grad():
                    target_dict = self.proposal_target_layer(input_data,
                                                             stage=1)

                pts_input = torch.cat(
                    (target_dict['sampled_pts'], target_dict['pts_feature']),
                    dim=2)
                target_dict['pts_input'] = pts_input
            else:
                rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data[
                    'rpn_features']
                batch_rois = input_data['roi_boxes3d']
                if cfg.RCNN.USE_INTENSITY:
                    pts_extra_input_list = [
                        input_data['rpn_intensity'].unsqueeze(dim=2),
                        input_data['seg_mask'].unsqueeze(dim=2)
                    ]
                else:
                    pts_extra_input_list = [
                        input_data['seg_mask'].unsqueeze(dim=2)
                    ]

                if cfg.RCNN.USE_DEPTH:
                    pts_depth = input_data['pts_depth'] / 70.0 - 0.5
                    pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
                pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

                pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
                pooled_features, pooled_empty_flag = \
                        roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                                      sampled_pt_num=cfg.RCNN.NUM_POINTS)

                # canonical transformation
                batch_size = batch_rois.shape[0]
                roi_center = batch_rois[:, :, 0:3]
                pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
                for k in range(batch_size):
                    pooled_features[k, :, :,
                                    0:3] = kitti_utils.rotate_pc_along_y_torch(
                                        pooled_features[k, :, :, 0:3],
                                        batch_rois[k, :, 6])

                pts_input = pooled_features.view(-1, pooled_features.shape[2],
                                                 pooled_features.shape[3])
        else:
            pts_input = input_data['pts_input']
            target_dict = {}
            target_dict['pts_input'] = input_data['pts_input']
            target_dict['roi_boxes3d'] = input_data['roi_boxes3d']
            if self.training:
                #input_data['ori_roi'] = torch.cat((input_data['ori_roi'], input_data['roi_boxes3d']), 1)
                target_dict['cls_label'] = input_data['cls_label']
                target_dict['reg_valid_mask'] = input_data[
                    'reg_valid_mask'].view(-1)
                target_dict['gt_of_rois'] = input_data['gt_boxes3d_ct']
        #print(pts_input.shape)
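        # reshape to (num_rois, num_points, channels); 512 points per ROI and 128
        # RPN feature channels are assumed from the config used with this model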
        pts_input = pts_input.view(-1, 512, 128 + self.rcnn_input_channel)
        xyz, features = self._break_up_pc(pts_input)
        anchor_size = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(
                1, 2).unsqueeze(dim=3)
            #xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(1, 2)

            xyz_feature = self.xyz_up_layer(xyz_input)

            rpn_feature = pts_input[..., self.rcnn_input_channel:].transpose(
                1, 2).unsqueeze(dim=3)

            merged_feature = torch.cat((xyz_feature, rpn_feature), dim=1)
            merged_feature = self.merge_down_layer(merged_feature)
            l_xyz, l_features = [xyz], [merged_feature.squeeze(dim=3)]
        else:
            l_xyz, l_features = [xyz], [features]

        for i in range(len(self.SA_modules)):

            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        batch_size = input_data['roi_boxes3d'].size(0)
        batch_size_2 = pts_input.shape[0]  # for loss fun
        #print(input_data['roi_boxes3d'].shape,pts_input.shape)
        rcnn_cls = self.cls_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg = self.reg_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        if self.training:
            roi_boxes3d = target_dict['roi_boxes3d'].view(-1, 7)
            cls_label = target_dict['cls_label'].float()
            rcnn_cls_flat = rcnn_cls.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat),
                cls_label.view(-1),
                reduction='none')
            cls_label_flat = cls_label.view(-1)
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict['gt_of_rois']
            reg_valid_mask = target_dict['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            #print(rcnn_reg.view(batch_size_2, -1)[fg_mask].shape)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg.view(batch_size_2, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size

            one = {
                'rcnn_loss_cls': rcnn_loss_cls,
                'rcnn_loss_reg': rcnn_loss_reg
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask
        else:
            roi_boxes3d = input_data['roi_boxes3d'].view(-1, 7)
            one = {}
        #print(rcnn_reg.size(),roi_boxes3d.size())
        #print(roi_boxes3d.shape, rcnn_reg.shape)
        pred_boxes3d_1st = decode_bbox_target(
            roi_boxes3d.view(-1, 7),
            rcnn_reg.view(-1, rcnn_reg.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)
        if not self.training and cfg.RCNN.ENABLED and not cfg.RPN.ENABLED:
            pred_boxes3d_1st = pred_boxes3d_1st.view(-1, 7)

        input_data2 = input_data.copy()

        #print(input_data['roi_boxes3d'].size())
        if self.training:
            #input_data2['roi_boxes3d'] = torch.cat((pred_boxes3d_1st, input_data['ori_roi']), 1)
            input_data2['roi_boxes3d'] = torch.cat(
                (pred_boxes3d_1st, input_data['roi_boxes3d']), 1)
            #input_data2['roi_boxes3d'] = input_data['gt_boxes3d']
            #input_data2['roi_boxes3d'] = pred_boxes3d_1st
            #print(input_data2['roi_boxes3d'].shape)
            with torch.no_grad():
                target_dict_2nd = self.proposal_target_layer(input_data2,
                                                             stage=2)
            '''
            reg_valid_mask = target_dict_2nd['reg_valid_mask']
            fg_mask_num2 = (reg_valid_mask > 0).sum()
            if fg_mask_num2< 10*batch_size:
                input_data2['roi_boxes3d'] = torch.cat((pred_boxes3d_1st, input_data['roi_boxes3d']), 1)
                with torch.no_grad():
                    target_dict_2nd = self.proposal_target_layer(input_data2, stage=2)
            '''
            pts_input_2 = torch.cat((target_dict_2nd['sampled_pts'],
                                     target_dict_2nd['pts_feature']),
                                    dim=2)
            target_dict_2nd['pts_input'] = pts_input_2
            roi = target_dict_2nd['roi_boxes3d']

        else:
            input_data2['roi_boxes3d'] = pred_boxes3d_1st
            #input_data2['roi_boxes3d']=torch.cat((pred_boxes3d_1st, input_data['roi_boxes3d']), 1)
            roi = pred_boxes3d_1st
            #roi=torch.cat((pred_boxes3d_1st, input_data['roi_boxes3d']), 1)
            pts_input_2 = self.roipooling(input_data2)
        #print(pts_input_2.shape)
        xyz_2, features_2 = self._break_up_pc(pts_input_2)
        #print(xyz_2.size(),xyz.size(),features_2.size(),features.size())
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_2 = pts_input_2[...,
                                      0:self.rcnn_input_channel].transpose(
                                          1, 2).unsqueeze(dim=3)
            xyz_feature_2 = self.xyz_up_layer(xyz_input_2)

            rpn_feature_2 = pts_input_2[...,
                                        self.rcnn_input_channel:].transpose(
                                            1, 2).unsqueeze(dim=3)

            merged_feature_2 = torch.cat((xyz_feature_2, rpn_feature_2), dim=1)
            merged_feature_2 = self.merge_down_layer(merged_feature_2)
            l_xyz_2, l_features_2 = [xyz_2], [merged_feature_2.squeeze(dim=3)]
        else:
            l_xyz_2, l_features_2 = [xyz_2], [features_2]
        #print(l_xyz_2[0].size(), l_xyz[0].size(), l_features_2[0].size(), l_features[0].size())
        for i in range(len(self.SA_modules)):
            li_xyz_2, li_features_2 = self.SA_modules[i](l_xyz_2[i],
                                                         l_features_2[i])
            l_xyz_2.append(li_xyz_2)
            l_features_2.append(li_features_2)
        del xyz, features, l_features

        rcnn_cls_2nd = self.cls_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_2nd = self.reg_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        #loss
        if self.training:
            cls_label = target_dict_2nd['cls_label'].float()
            rcnn_cls_flat = rcnn_cls_2nd.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat),
                cls_label.view(-1),
                reduction='none')
            cls_label_flat = cls_label.view(-1)
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict_2nd['gt_of_rois']
            reg_valid_mask = target_dict_2nd['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            #print(rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask].size(0))
            if rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask].size(0) == 0:
                fg_mask = (reg_valid_mask <= 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size

            two = {
                'rcnn_loss_cls_2nd': rcnn_loss_cls,
                'rcnn_loss_reg_2nd': rcnn_loss_reg
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask

        else:
            two = {}

        sec = {'rcnn_cls_2nd': rcnn_cls_2nd, 'rcnn_reg_2nd': rcnn_reg_2nd}
        #print(input_data['roi_boxes3d'].shape,input_data2['roi_boxes3d'].shape)

        pred_boxes3d_2nd = decode_bbox_target(
            roi.view(-1, 7),
            rcnn_reg_2nd.view(-1, rcnn_reg_2nd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        ## 3rd
        #print(target_dict['roi_boxes3d'].shape,target_dict_2nd['roi_boxes3d'].shape)
        #print(pred_boxes3d_1st.shape,input_data['roi_boxes3d'].shape)
        #print(target_dict['gt_of_rois']+target_dict['roi_boxes3d'],target_dict_2nd['gt_of_rois']+target_dict_2nd['roi_boxes3d'])
        input_data3 = input_data2.copy()
        #del input_data2

        if self.training:
            input_data3['roi_boxes3d'] = torch.cat(
                (pred_boxes3d_2nd, input_data2['roi_boxes3d']), 1)
            #input_data3['roi_boxes3d'] = input_data2['gt_boxes3d']
            #input_data3['roi_boxes3d'] = pred_boxes3d_2nd
            #print(input_data3['roi_boxes3d'].shape)
            with torch.no_grad():
                target_dict_3rd = self.proposal_target_layer(input_data3,
                                                             stage=3)
            '''
            reg_valid_mask = target_dict_3rd['reg_valid_mask']
            fg_mask_num3 = (reg_valid_mask > 0).sum()
            
            if fg_mask_num3.item() < 10 * batch_size:
                input_data3['roi_boxes3d'] = torch.cat((pred_boxes3d_2nd, input_data2['roi_boxes3d']), 1)
                with torch.no_grad():
                    target_dict_3rd = self.proposal_target_layer(input_data2, stage=3)
            '''
            #print(fg_mask_num2.item(),fg_mask_num3.item())
            pts_input_3 = torch.cat((target_dict_3rd['sampled_pts'],
                                     target_dict_3rd['pts_feature']),
                                    dim=2)
            target_dict_3rd['pts_input'] = pts_input_3
            roi = target_dict_3rd['roi_boxes3d']
        else:
            input_data3['roi_boxes3d'] = pred_boxes3d_2nd
            #input_data3['roi_boxes3d']=torch.cat((pred_boxes3d_2nd, input_data2['roi_boxes3d']), 1)
            roi = pred_boxes3d_2nd
            #roi=torch.cat((pred_boxes3d_2nd, input_data2['roi_boxes3d']), 1)
            pts_input_3 = self.roipooling(input_data3)
        xyz_3, features_3 = self._break_up_pc(pts_input_3)

        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_3 = pts_input_3[...,
                                      0:self.rcnn_input_channel].transpose(
                                          1, 2).unsqueeze(dim=3)
            xyz_feature_3 = self.xyz_up_layer(xyz_input_3)

            rpn_feature_3 = pts_input_3[...,
                                        self.rcnn_input_channel:].transpose(
                                            1, 2).unsqueeze(dim=3)

            merged_feature_3 = torch.cat((xyz_feature_3, rpn_feature_3), dim=1)
            merged_feature_3 = self.merge_down_layer(merged_feature_3)
            l_xyz_3, l_features_3 = [xyz_3], [merged_feature_3.squeeze(dim=3)]
        else:
            l_xyz_3, l_features_3 = [xyz_3], [features_3]

        for i in range(len(self.SA_modules)):
            li_xyz_3, li_features_3 = self.SA_modules[i](l_xyz_3[i],
                                                         l_features_3[i])
            l_xyz_3.append(li_xyz_3)
            l_features_3.append(li_features_3)
        del xyz_2, features_2, l_features_2
        rcnn_cls_3rd = self.cls_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_3rd = self.reg_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)

        #loss
        if self.training:
            cls_label = target_dict_3rd['cls_label'].float()
            rcnn_cls_flat = rcnn_cls_3rd.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat),
                cls_label.view(-1),
                reduction='none')
            cls_label_flat = cls_label.view(-1)
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict_3rd['gt_of_rois']
            reg_valid_mask = target_dict_3rd['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            #cls_mask=(target_dict_3rd['cls_label']>0)
            #print(rcnn_reg_3rd.view(batch_size_2, -1)[cls_mask].size(0))
            #print(rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask].size(0))
            if rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask].size(0) == 0:
                fg_mask = (reg_valid_mask <= 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size

            three = {
                'rcnn_loss_cls_3rd': rcnn_loss_cls,
                'rcnn_loss_reg_3rd': rcnn_loss_reg
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask

        else:
            three = {}
        pred_boxes3d_3rd = decode_bbox_target(
            roi.view(-1, 7),
            rcnn_reg_3rd.view(-1, rcnn_reg_3rd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)
        ret_dict = {
            'rcnn_cls': rcnn_cls,
            'rcnn_reg': rcnn_reg,
            'rcnn_cls_3rd': rcnn_cls_3rd,
            'rcnn_reg_3rd': rcnn_reg_3rd,
            'pred_boxes3d_1st': pred_boxes3d_1st,
            'pred_boxes3d_2nd': pred_boxes3d_2nd,
            'pred_boxes3d_3rd': pred_boxes3d_3rd
        }
        ret_dict.update(sec)
        ret_dict.update(one)
        ret_dict.update(two)
        ret_dict.update(three)
        if self.training:
            ret_dict.update(target_dict)
        return ret_dict
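
The eval path above canonically transforms the pooled points before they enter the RCNN head: each ROI's points are shifted by the ROI center and then rotated about the y axis by the ROI heading via kitti_utils.rotate_pc_along_y_torch. A minimal pure-torch sketch of that transform; the rotation sign convention here is an assumption and should be checked against kitti_utils:

import torch

def rotate_pc_along_y(pc, angle):
    # rotate (N, 3) points about the y axis (sign convention assumed)
    cosa, sina = torch.cos(angle), torch.sin(angle)
    x, z = pc[:, 0].clone(), pc[:, 2].clone()
    pc = pc.clone()
    pc[:, 0] = x * cosa - z * sina
    pc[:, 2] = x * sina + z * cosa
    return pc

points = torch.randn(100, 3)                  # pooled points of one ROI
roi_center = torch.tensor([10.0, 1.0, 20.0])  # ROI box center (x, y, z)
roi_ry = torch.tensor(0.3)                    # ROI heading angle
canonical = rotate_pc_along_y(points - roi_center, roi_ry)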
def get_rpn_loss(model,
                 rpn_cls,
                 rpn_reg,
                 rpn_cls_label,
                 rpn_reg_label,
                 tb_dict=None):
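    # ModelReturn is unused inside this function; it is presumably meant for a
    # model_fn-style trainer wrapper (see the hedged sketch after this function)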
    ModelReturn = namedtuple("ModelReturn", ["loss", "tb_dict", "disp_dict"])
    MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()

    if isinstance(model, nn.DataParallel):
        rpn_cls_loss_func = model.module.rpn.rpn_cls_loss_func
    else:
        rpn_cls_loss_func = model.rpn.rpn_cls_loss_func

    rpn_cls_label_flat = rpn_cls_label.view(-1)
    rpn_cls_flat = rpn_cls.view(-1)
    fg_mask = rpn_cls_label_flat > 0

    # RPN classification loss
    if cfg.RPN.LOSS_CLS == "DiceLoss":
        rpn_loss_cls = rpn_cls_loss_func(rpn_cls, rpn_cls_label_flat)

    elif cfg.RPN.LOSS_CLS == "SigmoidFocalLoss":
        rpn_cls_target = (rpn_cls_label_flat > 0).float()
        pos = (rpn_cls_label_flat > 0).float()
        neg = (rpn_cls_label_flat == 0).float()
        cls_weights = pos + neg
        pos_normalizer = pos.sum()
        cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)
        rpn_loss_cls = rpn_cls_loss_func(rpn_cls_flat, rpn_cls_target,
                                         cls_weights)
        rpn_loss_cls_pos = (rpn_loss_cls * pos).sum()
        rpn_loss_cls_neg = (rpn_loss_cls * neg).sum()
        rpn_loss_cls = rpn_loss_cls.sum()
        if tb_dict is not None:
            tb_dict["rpn_loss_cls_pos"] = rpn_loss_cls_pos.item()
            tb_dict["rpn_loss_cls_neg"] = rpn_loss_cls_neg.item()

    elif cfg.RPN.LOSS_CLS == "BinaryCrossEntropy":
        weight = rpn_cls_flat.new(rpn_cls_flat.shape[0]).fill_(1.0)
        weight[fg_mask] = cfg.RPN.FG_WEIGHT
        rpn_cls_label_target = (rpn_cls_label_flat > 0).float()
        batch_loss_cls = F.binary_cross_entropy(
            torch.sigmoid(rpn_cls_flat),
            rpn_cls_label_target,
            weight=weight,
            reduction="none",
        )
        cls_valid_mask = (rpn_cls_label_flat >= 0).float()
        rpn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(
            cls_valid_mask.sum(), min=1.0)
    else:
        raise NotImplementedError

    # RPN regression loss
    point_num = rpn_reg.size(0) * rpn_reg.size(1)
    fg_sum = fg_mask.long().sum().item()
    if fg_sum != 0:
        loss_loc, loss_angle, loss_size, reg_loss_dict = loss_utils.get_reg_loss(
            rpn_reg.view(point_num, -1)[fg_mask],
            rpn_reg_label.view(point_num, 7)[fg_mask],
            loc_scope=cfg.RPN.LOC_SCOPE,
            loc_bin_size=cfg.RPN.LOC_BIN_SIZE,
            num_head_bin=cfg.RPN.NUM_HEAD_BIN,
            anchor_size=MEAN_SIZE,
            get_xz_fine=cfg.RPN.LOC_XZ_FINE,
            get_y_by_bin=False,
            get_ry_fine=False,
        )

        loss_size = 3 * loss_size  # consistent with old codes
        rpn_loss_reg = loss_loc + loss_angle + loss_size
    else:
        loss_loc = loss_angle = loss_size = rpn_loss_reg = rpn_loss_cls * 0

    rpn_loss = (rpn_loss_cls * cfg.RPN.LOSS_WEIGHT[0] +
                rpn_loss_reg * cfg.RPN.LOSS_WEIGHT[1])

    if tb_dict is not None:
        tb_dict.update({
            "rpn_loss_cls": rpn_loss_cls.item(),
            "rpn_loss_reg": rpn_loss_reg.item(),
            "rpn_loss": rpn_loss.item(),
            "rpn_fg_sum": fg_sum,
            "rpn_loss_loc": loss_loc.item(),
            "rpn_loss_angle": loss_angle.item(),
            "rpn_loss_size": loss_size.item(),
        })

    return rpn_loss
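
The ModelReturn namedtuple declared at the top of this variant is not consumed by the function itself, which suggests the loss is returned through a model_fn-style trainer hook. A hedged sketch of such a wrapper; the hook signature and the dict keys are assumptions, not part of the source:

from collections import namedtuple

ModelReturn = namedtuple("ModelReturn", ["loss", "tb_dict", "disp_dict"])

def model_fn(model, data):  # hypothetical trainer hook
    ret_dict = model(data)
    tb_dict, disp_dict = {}, {}
    rpn_loss = get_rpn_loss(model, ret_dict["rpn_cls"], ret_dict["rpn_reg"],
                            data["rpn_cls_label"], data["rpn_reg_label"],
                            tb_dict=tb_dict)
    disp_dict["loss"] = rpn_loss.item()
    return ModelReturn(rpn_loss, tb_dict, disp_dict)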
    def get_rcnn_loss(model, ret_dict, tb_dict):
        rcnn_cls, rcnn_reg = ret_dict['rcnn_cls'], ret_dict['rcnn_reg']

        cls_label = ret_dict['cls_label'].float()  #### cls_label process
        reg_valid_mask = ret_dict['reg_valid_mask']
        roi_boxes3d = ret_dict['roi_boxes3d']
        roi_size = roi_boxes3d[:, 3:6]
        gt_boxes3d_ct = ret_dict['gt_of_rois']
        pts_input = ret_dict['pts_input']

        # rcnn classification loss
        if isinstance(model, nn.DataParallel):
            cls_loss_func = model.module.rcnn_net.cls_loss_func
        else:
            cls_loss_func = model.rcnn_net.cls_loss_func
        # print("cls_label",cls_label) #### -1, 0, 1로 이루어진 tensor 256개
        # print("cls_label_size",cls_label.size()) #### torch.size([256])
        cls_label_flat = cls_label.view(-1)

        if cfg.RCNN.LOSS_CLS == 'SigmoidFocalLoss':
            rcnn_cls_flat = rcnn_cls.view(-1)

            cls_target = (cls_label_flat > 0).float()
            pos = (cls_label_flat > 0).float()
            neg = (cls_label_flat == 0).float()
            cls_weights = pos + neg
            pos_normalizer = pos.sum()
            cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)

            rcnn_loss_cls = cls_loss_func(rcnn_cls_flat, cls_target,
                                          cls_weights)
            rcnn_loss_cls_pos = (rcnn_loss_cls * pos).sum()
            rcnn_loss_cls_neg = (rcnn_loss_cls * neg).sum()
            rcnn_loss_cls = rcnn_loss_cls.sum()
            tb_dict['rcnn_loss_cls_pos'] = rcnn_loss_cls_pos.item()
            tb_dict['rcnn_loss_cls_neg'] = rcnn_loss_cls_neg.item()

        elif cfg.RCNN.LOSS_CLS == 'BinaryCrossEntropy':
            rcnn_cls_flat = rcnn_cls.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat), cls_label_flat, reduction='none')
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)

        elif cfg.RCNN.LOSS_CLS == 'CrossEntropy':  #### TRAIN -> RCNN
            # elif cfg.TRAIN.LOSS_CLS == 'CrossEntropy':
            # print(rcnn_cls.size()) #### torch.size([256,4])
            # tensor([[ 0.0186, -0.0566, -0.0374, -0.0273],
            #         [-0.0119, -0.0458, -0.0206, -0.0464],
            #         [-0.0035, -0.0503, -0.0334, -0.0135],
            #         ...,
            #         [ 0.0098, -0.0219, -0.0139, -0.0330],
            #         [ 0.0182, -0.0153,  0.0086, -0.0376],
            #         [ 0.0071, -0.0278,  0.0146, -0.0302]], device='cuda:0',
            # print(rcnn_cls_reshape.size()) #### torch.size([256,4])
            # rcnn_cls_reshape = rcnn_cls.view(rcnn_cls.shape[0], -1).sum(dim=1) #### choose sum / mean
            # rcnn_cls_reshape = rcnn_cls.view(-1)

            rcnn_cls_reshape = rcnn_cls.view(rcnn_cls.shape[0], -1)
            cls_target = cls_label_flat.long()

            # print("cls_target", cls_target) # print(cls_target.size()) #### torch.size([256])
            # tensor([ 1, -1,  1,  1, -1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,
            #         1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  0,  0,  0,  0,
            #         0, -1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
            #         0,  0,  0, -1, -1, -1,  0, -1, -1, -1, -1, -1,  1,  1, -1,  1,  1, -1,
            #         1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1,
            #         1, -1,  1, -1,  1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
            #         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
            #         0,  0, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1, -1,
            #         1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  0,  0,
            #         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
            #         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,
            #         -1,  1,  1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,
            #         -1,  1, -1,  1, -1,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
            #         0,  0,  0,  0,  0,  0, -1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
            #         0,  0,  0,  0], device='cuda:0')
            cls_valid_mask = (cls_target >= 0).float()  #### -1 marks ignored labels

            cls_target_final = torch.zeros(rcnn_cls_reshape.shape[0],
                                           rcnn_cls_reshape.shape[1])
            # print(cls_target_final.size())  # torch.Size([256, 4])
            for i in range(cls_target_final.shape[0]):
                if cls_target[i] == -1:
                    cls_target[i] = 0  # remap ignore labels; they are masked out below
            #    cls_target_final[i, cls_target[i].cpu().numpy()] = 1  #### class indices must form one contiguous range 0, 1, 2, 3, ...
            # print("cls_target_final", cls_target)  # torch.Size([256, 4]) #### too many 1 values...
            cls_target_final = cls_target_final.cuda()
            cls_target_final = cls_target_final.long()
            # print("cls_target_final", cls_target_final)  # torch.Size([256, 4])

            #### cls_target = cls_target.unsqueeze(1) #### size([256,1])
            #### cls_target = torch.cat((cls_target, cls_target, cls_target, cls_target), 1) #### size([256,4])
            #### cls_target = cls_target.view(-1)

            # print("rcnn_cls_reshape", rcnn_cls_reshape) #### size([256,4])
            # print(cls_target.size()) #### size([256])
            batch_loss_cls = cls_loss_func(rcnn_cls_reshape,
                                           cls_target)  #### per-ROI loss
            # print(batch_loss_cls.size())  #### torch.Size([256])
            normalizer = torch.clamp(cls_valid_mask.sum(), min=1.0)
            # batch_loss_cls is already one value per ROI (see the size print above),
            # so no extra mean() over a class dimension is needed: mask the ignored
            # labels and normalize by the valid count.
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / normalizer
            # rcnn_loss_cls = (batch_loss_cls.mean(dim=1) * cls_valid_mask).sum() / normalizer

        else:
            raise NotImplementedError

        # rcnn regression loss
        batch_size = pts_input.shape[0]
        fg_mask = (reg_valid_mask > 0)
        fg_sum = fg_mask.long().sum().item()
        if fg_sum != 0:
            all_anchor_size = roi_size
            anchor_size = all_anchor_size[
                fg_mask] if cfg.RCNN.SIZE_RES_ON_ROI else MEAN_SIZE

            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg.view(batch_size, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size, 8)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)

            loss_size = 3 * loss_size  # consistent with old codes
            rcnn_loss_reg = loss_loc + loss_angle + loss_size
            tb_dict.update(reg_loss_dict)
        else:
            loss_loc = loss_angle = loss_size = rcnn_loss_reg = rcnn_loss_cls * 0

        rcnn_loss = rcnn_loss_cls + rcnn_loss_reg
        tb_dict['rcnn_loss_cls'] = rcnn_loss_cls.item()
        tb_dict['rcnn_loss_reg'] = rcnn_loss_reg.item()
        tb_dict['rcnn_loss'] = rcnn_loss.item()

        tb_dict['rcnn_loss_loc'] = loss_loc.item()
        tb_dict['rcnn_loss_angle'] = loss_angle.item()
        tb_dict['rcnn_loss_size'] = loss_size.item()
        # fg : foreground, bg : background
        tb_dict['rcnn_cls_fg'] = (cls_label > 0).sum().item()
        tb_dict['rcnn_cls_bg'] = (cls_label == 0).sum().item()
        tb_dict['rcnn_reg_fg'] = reg_valid_mask.sum().item()

        return rcnn_loss
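
The CrossEntropy branch above cannot rely on a plain mean() because ignore labels (-1) must not contribute; it remaps them to a valid class index, zeroes their contribution with cls_valid_mask, and divides by the valid count. A minimal sketch of that masked average with dummy tensors (the clamp stands in for the remapping loop above):

import torch
import torch.nn.functional as F

logits = torch.randn(6, 4)                   # (N, num_classes)
target = torch.tensor([1, -1, 0, 2, -1, 3])  # -1 = ignore
valid = (target >= 0).float()
per_roi = F.cross_entropy(logits, target.clamp(min=0), reduction='none')
loss = (per_roi * valid).sum() / torch.clamp(valid.sum(), min=1.0)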
    def get_rpn_loss(model, rpn_cls, rpn_reg, rpn_cls_label, rpn_reg_label,
                     tb_dict):
        if isinstance(model, nn.DataParallel):
            rpn_cls_loss_func = model.module.rpn.rpn_cls_loss_func
        else:
            rpn_cls_loss_func = model.rpn.rpn_cls_loss_func
        # model.rpn.rpn_cls_loss_func is defined in lib/net/rpn.py

        rpn_cls_label_flat = rpn_cls_label.view(-1)
        rpn_cls_flat = rpn_cls.view(-1)
        fg_mask = (rpn_cls_label_flat > 0)

        # RPN classification loss
        if cfg.RPN.LOSS_CLS == 'DiceLoss':
            rpn_loss_cls = rpn_cls_loss_func(rpn_cls, rpn_cls_label_flat)

        elif cfg.RPN.LOSS_CLS == 'SigmoidFocalLoss':
            rpn_cls_target = (rpn_cls_label_flat > 0).float()
            pos = (rpn_cls_label_flat > 0).float()
            neg = (rpn_cls_label_flat == 0).float()
            cls_weights = pos + neg
            pos_normalizer = pos.sum()
            cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)
            rpn_loss_cls = rpn_cls_loss_func(rpn_cls_flat, rpn_cls_target,
                                             cls_weights)
            rpn_loss_cls_pos = (rpn_loss_cls * pos).sum()
            rpn_loss_cls_neg = (rpn_loss_cls * neg).sum()
            rpn_loss_cls = rpn_loss_cls.sum()
            tb_dict['rpn_loss_cls_pos'] = rpn_loss_cls_pos.item()
            tb_dict['rpn_loss_cls_neg'] = rpn_loss_cls_neg.item()

        elif cfg.RPN.LOSS_CLS == 'BinaryCrossEntropy':
            weight = rpn_cls_flat.new(rpn_cls_flat.shape[0]).fill_(1.0)
            weight[fg_mask] = cfg.RPN.FG_WEIGHT
            rpn_cls_label_target = (rpn_cls_label_flat > 0).float()
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rpn_cls_flat),
                rpn_cls_label_target,
                weight=weight,
                reduction='none')
            cls_valid_mask = (rpn_cls_label_flat >= 0).float()
            rpn_loss_cls = (batch_loss_cls *
                            cls_valid_mask).sum() / torch.clamp(
                                cls_valid_mask.sum(), min=1.0)
        else:
            raise NotImplementedError

        # RPN regression loss
        point_num = rpn_reg.size(0) * rpn_reg.size(
            1)  # note that rpn_reg holds a whole batch (multiple scenes)

        # total number of points (across all scenes) that lie inside a box
        fg_sum = fg_mask.long().sum().item()
        if fg_sum != 0:
            # Flatten to one list of points: the scene index no longer matters since
            # every point already carries its own label, and only points labeled as
            # inside a box are regressed (background points have no box target).
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rpn_reg.view(point_num, -1)[fg_mask],
                                        rpn_reg_label.view(point_num, 9)[fg_mask],
                                        loc_scope=cfg.RPN.LOC_SCOPE,
                                        loc_bin_size=cfg.RPN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RPN.NUM_HEAD_BIN,
                                        anchor_size=MEAN_SIZE,
                                        get_xz_fine=cfg.RPN.LOC_XZ_FINE,
                                        get_y_by_bin=False,
                                        get_ry_fine=False)

            #loss_size = 3 * loss_size  # consistent with old codes
            rpn_loss_reg = loss_loc + loss_angle + loss_size
        else:
            loss_loc = loss_angle = loss_size = rpn_loss_reg = rpn_loss_cls * 0

        # weighted sum of the two terms; LOSS_WEIGHT is typically [1.0, 1.0]
        rpn_loss = (rpn_loss_cls * cfg.RPN.LOSS_WEIGHT[0] +
                    rpn_loss_reg * cfg.RPN.LOSS_WEIGHT[1])

        tb_dict.update({
            'rpn_loss_cls': rpn_loss_cls.item(),
            'rpn_loss_reg': rpn_loss_reg.item(),
            'rpn_loss': rpn_loss.item(),
            'rpn_fg_sum': fg_sum,
            'rpn_loss_loc': loss_loc.item(),
            'rpn_loss_angle': loss_angle.item(),
            'rpn_loss_size': loss_size.item()
        })

        return rpn_loss