def get_rcnn_loss(model, ret_dict, tb_dict):
    """Compute the RCNN loss = classification term + box-regression term.

    :param model: network, possibly wrapped in nn.DataParallel; supplies
        ``rcnn_net.cls_loss_func``
    :param ret_dict: RCNN forward output/targets; must contain 'rcnn_cls',
        'rcnn_reg', 'cls_label', 'reg_valid_mask', 'roi_boxes3d',
        'gt_of_rois' and 'pts_input'
    :param tb_dict: mutated in place with scalar entries for logging
    :return: total rcnn loss (tensor)
    """
    rcnn_cls, rcnn_reg = ret_dict['rcnn_cls'], ret_dict['rcnn_reg']
    cls_label = ret_dict['cls_label'].float()
    reg_valid_mask = ret_dict['reg_valid_mask']
    roi_boxes3d = ret_dict['roi_boxes3d']
    roi_size = roi_boxes3d[:, 3:6]
    gt_boxes3d_ct = ret_dict['gt_of_rois']
    pts_input = ret_dict['pts_input']

    # rcnn classification loss -- unwrap DataParallel to reach the functor
    if isinstance(model, nn.DataParallel):
        cls_loss_func = model.module.rcnn_net.cls_loss_func
    else:
        cls_loss_func = model.rcnn_net.cls_loss_func

    cls_label_flat = cls_label.view(-1)

    if cfg.RCNN.LOSS_CLS == 'SigmoidFocalLoss':
        rcnn_cls_flat = rcnn_cls.view(-1)
        cls_target = (cls_label_flat > 0).float()
        pos = (cls_label_flat > 0).float()
        neg = (cls_label_flat == 0).float()
        # labels < 0 are "ignore" samples: zero weight; normalize by #pos
        cls_weights = pos + neg
        pos_normalizer = pos.sum()
        cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)
        rcnn_loss_cls = cls_loss_func(rcnn_cls_flat, cls_target, cls_weights)
        rcnn_loss_cls_pos = (rcnn_loss_cls * pos).sum()
        rcnn_loss_cls_neg = (rcnn_loss_cls * neg).sum()
        rcnn_loss_cls = rcnn_loss_cls.sum()
        # NOTE(review): keys say 'rpn_' although this is the RCNN loss;
        # kept unchanged because downstream logging may read these names.
        tb_dict['rpn_loss_cls_pos'] = rcnn_loss_cls_pos.item()
        tb_dict['rpn_loss_cls_neg'] = rcnn_loss_cls_neg.item()
    elif cfg.RCNN.LOSS_CLS == 'BinaryCrossEntropy':
        rcnn_cls_flat = rcnn_cls.view(-1)
        batch_loss_cls = F.binary_cross_entropy(
            torch.sigmoid(rcnn_cls_flat), cls_label, reduction='none')
        cls_valid_mask = (cls_label_flat >= 0).float()
        rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(
            cls_valid_mask.sum(), min=1.0)
    elif cfg.RCNN.LOSS_CLS == 'CrossEntropy':
        # BUGFIX: was `cfg.TRAIN.LOSS_CLS` -- this branch selects the RCNN
        # classification loss and must test the same config key as the
        # branches above, otherwise it can never be reached consistently.
        rcnn_cls_reshape = rcnn_cls.view(rcnn_cls.shape[0], -1)
        cls_target = cls_label_flat.long()
        cls_valid_mask = (cls_label_flat >= 0).float()
        batch_loss_cls = cls_loss_func(rcnn_cls_reshape, cls_target)
        normalizer = torch.clamp(cls_valid_mask.sum(), min=1.0)
        rcnn_loss_cls = (batch_loss_cls.mean(dim=1) *
                         cls_valid_mask).sum() / normalizer
    else:
        raise NotImplementedError

    # rcnn regression loss -- only over foreground ROIs
    batch_size = pts_input.shape[0]
    fg_mask = (reg_valid_mask > 0)
    fg_sum = fg_mask.long().sum().item()
    if fg_sum != 0:
        all_anchor_size = roi_size
        anchor_size = all_anchor_size[
            fg_mask] if cfg.RCNN.SIZE_RES_ON_ROI else MEAN_SIZE
        loss_loc, loss_angle, loss_size, reg_loss_dict = \
            loss_utils.get_reg_loss(rcnn_reg.view(batch_size, -1)[fg_mask],
                                    gt_boxes3d_ct.view(batch_size, 7)[fg_mask],
                                    loc_scope=cfg.RCNN.LOC_SCOPE,
                                    loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                    num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                    anchor_size=anchor_size,
                                    get_xz_fine=True,
                                    get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                    loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
                                    loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                    get_ry_fine=True)
        loss_size = 3 * loss_size  # consistent with old codes
        rcnn_loss_reg = loss_loc + loss_angle + loss_size
        tb_dict.update(reg_loss_dict)
    else:
        # no foreground: zero regression loss that still carries a graph
        loss_loc = loss_angle = loss_size = rcnn_loss_reg = rcnn_loss_cls * 0

    rcnn_loss = rcnn_loss_cls + rcnn_loss_reg
    tb_dict['rcnn_loss_cls'] = rcnn_loss_cls.item()
    tb_dict['rcnn_loss_reg'] = rcnn_loss_reg.item()
    tb_dict['rcnn_loss'] = rcnn_loss.item()
    tb_dict['rcnn_loss_loc'] = loss_loc.item()
    tb_dict['rcnn_loss_angle'] = loss_angle.item()
    tb_dict['rcnn_loss_size'] = loss_size.item()
    tb_dict['rcnn_cls_fg'] = (cls_label > 0).sum().item()
    tb_dict['rcnn_cls_bg'] = (cls_label == 0).sum().item()
    tb_dict['rcnn_reg_fg'] = reg_valid_mask.sum().item()
    return rcnn_loss
def get_rpn_loss(model, rpn_cls, rpn_reg, rpn_cls_label, rpn_reg_label,
                 tb_dict):
    """Compute the weighted RPN loss (classification + box regression).

    :param model: network, possibly wrapped in nn.DataParallel; supplies
        ``rpn.rpn_cls_loss_func``
    :param rpn_cls: per-point classification logits
    :param rpn_reg: per-point regression output
    :param rpn_cls_label: per-point labels (<0 ignore, 0 bg, >0 fg)
    :param rpn_reg_label: per-point 7-dof regression targets
    :param tb_dict: mutated in place with scalar logging entries
    :return: weighted sum of classification and regression losses
    """
    # Unwrap DataParallel to reach the classification loss functor.
    net = model.module if isinstance(model, nn.DataParallel) else model
    rpn_cls_loss_func = net.rpn.rpn_cls_loss_func

    label_flat = rpn_cls_label.view(-1)
    logit_flat = rpn_cls.view(-1)
    fg_mask = (label_flat > 0)

    # ---- classification term ----
    if cfg.RPN.LOSS_CLS == 'DiceLoss':
        rpn_loss_cls = rpn_cls_loss_func(rpn_cls, label_flat)
    elif cfg.RPN.LOSS_CLS == 'SigmoidFocalLoss':
        pos = (label_flat > 0).float()
        neg = (label_flat == 0).float()
        cls_target = (label_flat > 0).float()
        # ignore samples (<0) get zero weight; normalize by #foreground
        weights = (pos + neg) / torch.clamp(pos.sum(), min=1.0)
        per_point_loss = rpn_cls_loss_func(logit_flat, cls_target, weights)
        tb_dict['rpn_loss_cls_pos'] = (per_point_loss * pos).sum().item()
        tb_dict['rpn_loss_cls_neg'] = (per_point_loss * neg).sum().item()
        rpn_loss_cls = per_point_loss.sum()
    elif cfg.RPN.LOSS_CLS == 'BinaryCrossEntropy':
        # foreground points are up-weighted by FG_WEIGHT
        weight = logit_flat.new(logit_flat.shape[0]).fill_(1.0)
        weight[fg_mask] = cfg.RPN.FG_WEIGHT
        bce_target = (label_flat > 0).float()
        per_point_loss = F.binary_cross_entropy(torch.sigmoid(logit_flat),
                                                bce_target,
                                                weight=weight,
                                                reduction='none')
        valid = (label_flat >= 0).float()
        rpn_loss_cls = (per_point_loss * valid).sum() / torch.clamp(
            valid.sum(), min=1.0)
    else:
        raise NotImplementedError

    # ---- regression term, foreground points only ----
    point_num = rpn_reg.size(0) * rpn_reg.size(1)
    fg_sum = fg_mask.long().sum().item()
    if fg_sum == 0:
        # no foreground: zero losses that still carry the autograd graph
        loss_loc = loss_angle = loss_size = rpn_loss_reg = rpn_loss_cls * 0
    else:
        loss_loc, loss_angle, loss_size, reg_loss_dict = \
            loss_utils.get_reg_loss(rpn_reg.view(point_num, -1)[fg_mask],
                                    rpn_reg_label.view(point_num, 7)[fg_mask],
                                    loc_scope=cfg.RPN.LOC_SCOPE,
                                    loc_bin_size=cfg.RPN.LOC_BIN_SIZE,
                                    num_head_bin=cfg.RPN.NUM_HEAD_BIN,
                                    anchor_size=MEAN_SIZE,
                                    get_xz_fine=cfg.RPN.LOC_XZ_FINE,
                                    get_y_by_bin=False,
                                    get_ry_fine=False)
        loss_size = 3 * loss_size  # consistent with old codes
        rpn_loss_reg = loss_loc + loss_angle + loss_size

    rpn_loss = (rpn_loss_cls * cfg.RPN.LOSS_WEIGHT[0] +
                rpn_loss_reg * cfg.RPN.LOSS_WEIGHT[1])

    tb_dict['rpn_loss_cls'] = rpn_loss_cls.item()
    tb_dict['rpn_loss_reg'] = rpn_loss_reg.item()
    tb_dict['rpn_loss'] = rpn_loss.item()
    tb_dict['rpn_fg_sum'] = fg_sum
    tb_dict['rpn_loss_loc'] = loss_loc.item()
    tb_dict['rpn_loss_angle'] = loss_angle.item()
    tb_dict['rpn_loss_size'] = loss_size.item()
    return rpn_loss
def forward(self, input_data):
    """Cascade refinement head: refines the 1st-stage boxes through a 2nd
    and a 3rd RCNN stage and predicts a per-box IoU score for each stage.

    :param input_data: input dict; must contain 'pred_boxes3d_1st' and
        'roi_boxes3d' (other keys are consumed by the proposal target layer
        during training / by ROI pooling at inference)
    :return: dict with per-stage cls/reg outputs, decoded boxes for all
        three stages, rescaled IoU predictions, and (training only) the
        stage-2/3 losses
    """
    input_data2 = input_data.copy()
    pred_boxes3d_1st = input_data2['pred_boxes3d_1st']
    ret_dict = {}
    batch_size = input_data['roi_boxes3d'].size(0)

    # ---- stage 2: gather ROI point features from the 1st-stage boxes ----
    if self.training:
        # Assign targets to the refined boxes; no gradients through the
        # target assignment itself.
        input_data2['roi_boxes3d'] = pred_boxes3d_1st
        with torch.no_grad():
            target_dict_2nd = self.proposal_target_layer(input_data2, stage=2)
        pts_input_2 = torch.cat((target_dict_2nd['sampled_pts'],
                                 target_dict_2nd['pts_feature']), dim=2)
        target_dict_2nd['pts_input'] = pts_input_2
        roi = target_dict_2nd['roi_boxes3d']
    else:
        input_data2['roi_boxes3d'] = pred_boxes3d_1st
        roi = pred_boxes3d_1st
        pts_input_2 = self.roipooling(input_data2)
    xyz_2, features_2 = self._break_up_pc(pts_input_2)
    if cfg.RCNN.USE_RPN_FEATURES:
        # Split pooled features into canonical-xyz channels and RPN feature
        # channels, lift both and merge into a single feature map.
        xyz_input_2 = pts_input_2[..., 0:self.rcnn_input_channel].transpose(
            1, 2).unsqueeze(dim=3)
        xyz_feature_2 = self.xyz_up_layer(xyz_input_2)
        rpn_feature_2 = pts_input_2[..., self.rcnn_input_channel:].transpose(
            1, 2).unsqueeze(dim=3)
        merged_feature_2 = torch.cat((xyz_feature_2, rpn_feature_2), dim=1)
        merged_feature_2 = self.merge_down_layer(merged_feature_2)
        l_xyz_2, l_features_2 = [xyz_2], [merged_feature_2.squeeze(dim=3)]
    else:
        # NOTE(review): 'l_xyz__2' (double underscore) looks like a typo --
        # it leaves 'l_xyz_2' undefined for the SA loop below when
        # cfg.RCNN.USE_RPN_FEATURES is False; confirm this path is unused.
        l_xyz__2, l_features_2 = [xyz_2], [features_2]
    # PointNet++ set-abstraction stack
    for i in range(len(self.SA_modules)):
        li_xyz_2, li_features_2 = self.SA_modules[i](l_xyz_2[i],
                                                     l_features_2[i])
        l_xyz_2.append(li_xyz_2)
        l_features_2.append(li_features_2)
    batch_size_2 = pts_input_2.shape[0]
    anchor_size = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
    # Per-ROI heads: classification, box regression, IoU prediction.
    rcnn_cls_2nd = self.cls_layer_2nd(l_features_2[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
    rcnn_reg_2nd = self.reg_layer_2nd(l_features_2[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
    pre_iou2 = self.iou_layer(l_features_2[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)
    # ---- stage-2 loss ----
    if self.training:
        cls_label = target_dict_2nd['cls_label'].float()
        rcnn_cls_flat = rcnn_cls_2nd.view(-1)
        batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat),
                                                cls_label.view(-1),
                                                reduction='none')
        cls_label_flat = cls_label.view(-1)
        # labels < 0 are "ignore" samples, excluded from the BCE average
        cls_valid_mask = (cls_label_flat >= 0).float()
        rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(
            cls_valid_mask.sum(), min=1.0)
        gt_boxes3d_ct = target_dict_2nd['gt_of_rois']
        reg_valid_mask = target_dict_2nd['reg_valid_mask']
        fg_mask = (reg_valid_mask > 0)
        # Fallback: if sampling yielded no foreground ROI, regress the
        # remaining ones so get_reg_loss never sees an empty batch.
        if rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask].size(0) == 0:
            fg_mask = (reg_valid_mask <= 0)
        loss_loc, loss_angle, loss_size, reg_loss_dict = \
            loss_utils.get_reg_loss(rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask],
                                    gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                    loc_scope=cfg.RCNN.LOC_SCOPE,
                                    loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                    num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                    anchor_size=anchor_size,
                                    get_xz_fine=True,
                                    get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                    loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
                                    loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                    get_ry_fine=True)
        # size term weighted x3, matching the standalone loss helpers
        rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size
        two = {
            'rcnn_loss_cls_2nd': rcnn_loss_cls,
            'rcnn_loss_reg_2nd': rcnn_loss_reg
        }
    else:
        two = {}
    sec = {'rcnn_cls_2nd': rcnn_cls_2nd, 'rcnn_reg_2nd': rcnn_reg_2nd}
    # Decode stage-2 residuals on top of the stage-2 ROIs.
    pred_boxes3d_2nd = decode_bbox_target(
        roi.view(-1, 7),
        rcnn_reg_2nd.view(-1, rcnn_reg_2nd.shape[-1]),
        anchor_size=anchor_size,
        loc_scope=cfg.RCNN.LOC_SCOPE,
        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
        get_xz_fine=True,
        get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
        loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
        get_ry_fine=True).view(batch_size, -1, 7)

    # ---- stage 3: same pipeline on the stage-2 boxes ----
    input_data3 = input_data.copy()
    if self.training:
        input_data3['roi_boxes3d'] = pred_boxes3d_2nd
        with torch.no_grad():
            target_dict_3rd = self.proposal_target_layer(input_data3, stage=3)
        pts_input_3 = torch.cat((target_dict_3rd['sampled_pts'],
                                 target_dict_3rd['pts_feature']), dim=2)
        target_dict_3rd['pts_input'] = pts_input_3
        roi = target_dict_3rd['roi_boxes3d']
    else:
        input_data3['roi_boxes3d'] = pred_boxes3d_2nd
        roi = pred_boxes3d_2nd
        pts_input_3 = self.roipooling(input_data3)
    xyz_3, features_3 = self._break_up_pc(pts_input_3)
    if cfg.RCNN.USE_RPN_FEATURES:
        # stage 3 has its own up/merge layers ("_3" modules)
        xyz_input_3 = pts_input_3[..., 0:self.rcnn_input_channel].transpose(
            1, 2).unsqueeze(dim=3)
        xyz_feature_3 = self.xyz_up_layer_3(xyz_input_3)
        rpn_feature_3 = pts_input_3[..., self.rcnn_input_channel:].transpose(
            1, 2).unsqueeze(dim=3)
        merged_feature_3 = torch.cat((xyz_feature_3, rpn_feature_3), dim=1)
        merged_feature_3 = self.merge_down_layer_3(merged_feature_3)
        l_xyz_3, l_features_3 = [xyz_3], [merged_feature_3.squeeze(dim=3)]
    else:
        # NOTE(review): writes 'l_xyz'/'l_features' rather than the _3 lists
        # consumed below -- this branch would crash; confirm it is unused.
        l_xyz, l_features = [xyz_3], [features_3]
    for i in range(len(self.SA_modules_3)):
        li_xyz_3, li_features_3 = self.SA_modules_3[i](l_xyz_3[i],
                                                       l_features_3[i])
        l_xyz_3.append(li_xyz_3)
        l_features_3.append(li_features_3)
    # free stage-2 intermediates before the stage-3 heads run
    del xyz_2, features_2, l_features_2
    rcnn_cls_3rd = self.cls_layer_3rd(l_features_3[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
    rcnn_reg_3rd = self.reg_layer_3rd(l_features_3[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
    pre_iou3 = self.iou_layer(l_features_3[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)
    # ---- stage-3 loss (same scheme as stage 2) ----
    if self.training:
        cls_label = target_dict_3rd['cls_label'].float()
        rcnn_cls_flat = rcnn_cls_3rd.view(-1)
        # NOTE(review): target here is 'cls_label' (not .view(-1)) while the
        # logits are flattened; relies on cls_label already being 1-D --
        # confirm against the stage-2 branch, which flattens explicitly.
        batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat),
                                                cls_label,
                                                reduction='none')
        cls_label_flat = cls_label.view(-1)
        cls_valid_mask = (cls_label_flat >= 0).float()
        rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(
            cls_valid_mask.sum(), min=1.0)
        gt_boxes3d_ct = target_dict_3rd['gt_of_rois']
        reg_valid_mask = target_dict_3rd['reg_valid_mask']
        fg_mask = (reg_valid_mask > 0)
        if rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask].size(0) == 0:
            fg_mask = (reg_valid_mask <= 0)
        loss_loc, loss_angle, loss_size, reg_loss_dict = \
            loss_utils.get_reg_loss(rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask],
                                    gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                    loc_scope=cfg.RCNN.LOC_SCOPE,
                                    loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                    num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                    anchor_size=anchor_size,
                                    get_xz_fine=True,
                                    get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                    loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
                                    loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                    get_ry_fine=True)
        rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size
    else:
        three = {}
    pred_boxes3d_3rd = decode_bbox_target(
        roi.view(-1, 7),
        rcnn_reg_3rd.view(-1, rcnn_reg_3rd.shape[-1]),
        anchor_size=anchor_size,
        loc_scope=cfg.RCNN.LOC_SCOPE,
        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
        get_xz_fine=True,
        get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
        loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
        get_ry_fine=True).view(batch_size, -1, 7)
    if self.training:
        # IoU-prediction target: per-ROI 3D IoU between the decoded stage-3
        # box and its ground truth, rescaled from [0, 1] to [-1, 1].
        gt = target_dict_3rd['real_gt']
        iou_label = []
        for i in range(batch_size_2):
            iou_label.append(
                iou3d_utils.boxes_iou3d_gpu(
                    pred_boxes3d_3rd.view(-1, 7)[i].view(1, 7),
                    gt[i].view(1, 7)))
        iou_label = torch.cat(iou_label)
        iou_label = (iou_label - 0.5) * 2
        # NOTE(review): reuses 'fg_mask' from the regression branch above,
        # which may have been flipped by the empty-foreground fallback --
        # confirm this is the intended mask for the IoU loss.
        iou_loss = F.mse_loss((pre_iou3[fg_mask]), iou_label[fg_mask])
        three = {
            'rcnn_loss_cls_3rd': rcnn_loss_cls,
            'rcnn_loss_reg_3rd': rcnn_loss_reg,
            'rcnn_iou_loss': iou_loss
        }
        # drop large loss intermediates (losses live on in 'three')
        del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask
    # map predicted IoU from the network's [-1, 1] range back to [0, 1]
    pre_iou3 = pre_iou3 / 2 + 0.5
    pre_iou2 = pre_iou2 / 2 + 0.5
    ret_dict = {
        'rcnn_cls_3rd': rcnn_cls_3rd,
        'rcnn_reg_3rd': rcnn_reg_3rd,
        'pred_boxes3d_1st': pred_boxes3d_1st,
        'pred_boxes3d_2nd': pred_boxes3d_2nd,
        'pred_boxes3d_3rd': pred_boxes3d_3rd,
        'pre_iou3': pre_iou3,
        'pre_iou2': pre_iou2
    }
    ret_dict.update(sec)
    ret_dict.update(two)
    ret_dict.update(three)
    return ret_dict
def forward(self, input_data):
    """Full three-stage RCNN forward pass: stage-1 refinement of the RPN
    ROIs followed by two cascade refinement stages sharing the stage-1
    feature layers.

    :param input_data: input dict; required keys depend on
        cfg.RCNN.ROI_SAMPLE_JIT and on training vs. inference (see below)
    :return: dict with per-stage cls/reg outputs, decoded boxes for all
        three stages, and (training only) per-stage losses and targets
    """
    # ---- stage 1: build per-ROI point inputs ----
    if cfg.RCNN.ROI_SAMPLE_JIT:
        if self.training:
            # Sample ROIs / assign targets on the fly; no gradients through
            # the target assignment.
            with torch.no_grad():
                target_dict = self.proposal_target_layer(input_data, stage=1)
            pts_input = torch.cat(
                (target_dict['sampled_pts'], target_dict['pts_feature']),
                dim=2)
            target_dict['pts_input'] = pts_input
        else:
            # Inference: pool RPN point features inside each ROI.
            rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data[
                'rpn_features']
            batch_rois = input_data['roi_boxes3d']
            if cfg.RCNN.USE_INTENSITY:
                pts_extra_input_list = [
                    input_data['rpn_intensity'].unsqueeze(dim=2),
                    input_data['seg_mask'].unsqueeze(dim=2)
                ]
            else:
                pts_extra_input_list = [
                    input_data['seg_mask'].unsqueeze(dim=2)
                ]
            if cfg.RCNN.USE_DEPTH:
                # normalize depth to roughly [-0.5, 0.5]
                pts_depth = input_data['pts_depth'] / 70.0 - 0.5
                pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
            pts_extra_input = torch.cat(pts_extra_input_list, dim=2)
            pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
            pooled_features, pooled_empty_flag = \
                roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois,
                                              cfg.RCNN.POOL_EXTRA_WIDTH,
                                              sampled_pt_num=cfg.RCNN.NUM_POINTS)
            # canonical transformation: move points into each ROI's frame
            batch_size = batch_rois.shape[0]
            roi_center = batch_rois[:, :, 0:3]
            pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
            for k in range(batch_size):
                pooled_features[k, :, :, 0:3] = kitti_utils.rotate_pc_along_y_torch(
                    pooled_features[k, :, :, 0:3], batch_rois[k, :, 6])
            pts_input = pooled_features.view(-1, pooled_features.shape[2],
                                             pooled_features.shape[3])
    else:
        # Offline-sampled inputs are provided directly.
        pts_input = input_data['pts_input']
        target_dict = {}
        target_dict['pts_input'] = input_data['pts_input']
        target_dict['roi_boxes3d'] = input_data['roi_boxes3d']
        if self.training:
            target_dict['cls_label'] = input_data['cls_label']
            target_dict['reg_valid_mask'] = input_data[
                'reg_valid_mask'].view(-1)
            target_dict['gt_of_rois'] = input_data['gt_boxes3d_ct']
    # NOTE(review): hard-codes 512 points and 128 feature channels per ROI
    # -- presumably matches cfg.RCNN.NUM_POINTS and the RPN feature width;
    # confirm before changing either config value.
    pts_input = pts_input.view(-1, 512, 128 + self.rcnn_input_channel)
    xyz, features = self._break_up_pc(pts_input)
    anchor_size = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
    if cfg.RCNN.USE_RPN_FEATURES:
        # Split canonical-xyz channels from RPN feature channels, lift both
        # and merge into a single feature map.
        xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(
            1, 2).unsqueeze(dim=3)
        xyz_feature = self.xyz_up_layer(xyz_input)
        rpn_feature = pts_input[..., self.rcnn_input_channel:].transpose(
            1, 2).unsqueeze(dim=3)
        merged_feature = torch.cat((xyz_feature, rpn_feature), dim=1)
        merged_feature = self.merge_down_layer(merged_feature)
        l_xyz, l_features = [xyz], [merged_feature.squeeze(dim=3)]
    else:
        l_xyz, l_features = [xyz], [features]
    # PointNet++ set-abstraction stack (shared by all three stages)
    for i in range(len(self.SA_modules)):
        li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
        l_xyz.append(li_xyz)
        l_features.append(li_features)
    batch_size = input_data['roi_boxes3d'].size(0)
    batch_size_2 = pts_input.shape[0]  # for loss fun
    rcnn_cls = self.cls_layer(l_features[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
    rcnn_reg = self.reg_layer(l_features[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
    # ---- stage-1 loss ----
    if self.training:
        roi_boxes3d = target_dict['roi_boxes3d'].view(-1, 7)
        cls_label = target_dict['cls_label'].float()
        rcnn_cls_flat = rcnn_cls.view(-1)
        batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat),
                                                cls_label.view(-1),
                                                reduction='none')
        cls_label_flat = cls_label.view(-1)
        # labels < 0 are "ignore" samples, excluded from the BCE average
        cls_valid_mask = (cls_label_flat >= 0).float()
        rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(
            cls_valid_mask.sum(), min=1.0)
        gt_boxes3d_ct = target_dict['gt_of_rois']
        reg_valid_mask = target_dict['reg_valid_mask']
        fg_mask = (reg_valid_mask > 0)
        # NOTE(review): unlike stages 2/3 there is no empty-foreground
        # fallback here -- get_reg_loss would receive an empty batch if no
        # ROI is foreground; confirm stage-1 sampling guarantees fg ROIs.
        loss_loc, loss_angle, loss_size, reg_loss_dict = \
            loss_utils.get_reg_loss(rcnn_reg.view(batch_size_2, -1)[fg_mask],
                                    gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                    loc_scope=cfg.RCNN.LOC_SCOPE,
                                    loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                    num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                    anchor_size=anchor_size,
                                    get_xz_fine=True,
                                    get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                    loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
                                    loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                    get_ry_fine=True)
        # size term weighted x3, matching the standalone loss helpers
        rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size
        one = {
            'rcnn_loss_cls': rcnn_loss_cls,
            'rcnn_loss_reg': rcnn_loss_reg
        }
        # drop large loss intermediates (losses live on in 'one')
        del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask
    else:
        roi_boxes3d = input_data['roi_boxes3d'].view(-1, 7)
        one = {}
    # Decode stage-1 residuals on top of the stage-1 ROIs.
    pred_boxes3d_1st = decode_bbox_target(
        roi_boxes3d.view(-1, 7),
        rcnn_reg.view(-1, rcnn_reg.shape[-1]),
        anchor_size=anchor_size,
        loc_scope=cfg.RCNN.LOC_SCOPE,
        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
        get_xz_fine=True,
        get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
        loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
        get_ry_fine=True).view(batch_size, -1, 7)
    # RCNN-only evaluation expects flat (N, 7) boxes
    if self.training == False and cfg.RCNN.ENABLED and not cfg.RPN.ENABLED:
        pred_boxes3d_1st = pred_boxes3d_1st.view(-1, 7)

    # ---- stage 2 ----
    input_data2 = input_data.copy()
    if self.training:
        # Train stage 2 on refined boxes plus the original ROIs to keep
        # enough positive samples.
        input_data2['roi_boxes3d'] = torch.cat(
            (pred_boxes3d_1st, input_data['roi_boxes3d']), 1)
        with torch.no_grad():
            target_dict_2nd = self.proposal_target_layer(input_data2, stage=2)
        '''
        reg_valid_mask = target_dict_2nd['reg_valid_mask']
        fg_mask_num2 = (reg_valid_mask > 0).sum()
        if fg_mask_num2< 10*batch_size:
            input_data2['roi_boxes3d'] = torch.cat((pred_boxes3d_1st, input_data['roi_boxes3d']), 1)
            with torch.no_grad():
                target_dict_2nd = self.proposal_target_layer(input_data2, stage=2)
        '''
        pts_input_2 = torch.cat((target_dict_2nd['sampled_pts'],
                                 target_dict_2nd['pts_feature']), dim=2)
        target_dict_2nd['pts_input'] = pts_input_2
        roi = target_dict_2nd['roi_boxes3d']
    else:
        input_data2['roi_boxes3d'] = pred_boxes3d_1st
        roi = pred_boxes3d_1st
        pts_input_2 = self.roipooling(input_data2)
    xyz_2, features_2 = self._break_up_pc(pts_input_2)
    if cfg.RCNN.USE_RPN_FEATURES:
        # stage 2 reuses the stage-1 up/merge layers
        xyz_input_2 = pts_input_2[..., 0:self.rcnn_input_channel].transpose(
            1, 2).unsqueeze(dim=3)
        xyz_feature_2 = self.xyz_up_layer(xyz_input_2)
        rpn_feature_2 = pts_input_2[..., self.rcnn_input_channel:].transpose(
            1, 2).unsqueeze(dim=3)
        merged_feature_2 = torch.cat((xyz_feature_2, rpn_feature_2), dim=1)
        merged_feature_2 = self.merge_down_layer(merged_feature_2)
        l_xyz_2, l_features_2 = [xyz_2], [merged_feature_2.squeeze(dim=3)]
    else:
        # NOTE(review): 'l_xyz__2' (double underscore) looks like a typo --
        # it leaves 'l_xyz_2' undefined for the SA loop below when
        # cfg.RCNN.USE_RPN_FEATURES is False; confirm this path is unused.
        l_xyz__2, l_features_2 = [xyz_2], [features_2]
    for i in range(len(self.SA_modules)):
        li_xyz_2, li_features_2 = self.SA_modules[i](l_xyz_2[i],
                                                     l_features_2[i])
        l_xyz_2.append(li_xyz_2)
        l_features_2.append(li_features_2)
    # free stage-1 intermediates before the stage-2 heads run
    del xyz, features, l_features
    rcnn_cls_2nd = self.cls_layer_2nd(l_features_2[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
    rcnn_reg_2nd = self.reg_layer_2nd(l_features_2[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
    # ---- stage-2 loss ----
    if self.training:
        cls_label = target_dict_2nd['cls_label'].float()
        rcnn_cls_flat = rcnn_cls_2nd.view(-1)
        batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat),
                                                cls_label.view(-1),
                                                reduction='none')
        cls_label_flat = cls_label.view(-1)
        cls_valid_mask = (cls_label_flat >= 0).float()
        rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(
            cls_valid_mask.sum(), min=1.0)
        gt_boxes3d_ct = target_dict_2nd['gt_of_rois']
        reg_valid_mask = target_dict_2nd['reg_valid_mask']
        fg_mask = (reg_valid_mask > 0)
        # Fallback: if sampling yielded no foreground ROI, regress the
        # remaining ones so get_reg_loss never sees an empty batch.
        if rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask].size(0) == 0:
            fg_mask = (reg_valid_mask <= 0)
        loss_loc, loss_angle, loss_size, reg_loss_dict = \
            loss_utils.get_reg_loss(rcnn_reg_2nd.view(batch_size_2, -1)[fg_mask],
                                    gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                    loc_scope=cfg.RCNN.LOC_SCOPE,
                                    loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                    num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                    anchor_size=anchor_size,
                                    get_xz_fine=True,
                                    get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                    loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
                                    loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                    get_ry_fine=True)
        rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size
        two = {
            'rcnn_loss_cls_2nd': rcnn_loss_cls,
            'rcnn_loss_reg_2nd': rcnn_loss_reg
        }
        del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask
    else:
        two = {}
    sec = {'rcnn_cls_2nd': rcnn_cls_2nd, 'rcnn_reg_2nd': rcnn_reg_2nd}
    pred_boxes3d_2nd = decode_bbox_target(
        roi.view(-1, 7),
        rcnn_reg_2nd.view(-1, rcnn_reg_2nd.shape[-1]),
        anchor_size=anchor_size,
        loc_scope=cfg.RCNN.LOC_SCOPE,
        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
        get_xz_fine=True,
        get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
        loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
        get_ry_fine=True).view(batch_size, -1, 7)

    # ---- stage 3 ----
    input_data3 = input_data2.copy()
    if self.training:
        input_data3['roi_boxes3d'] = torch.cat(
            (pred_boxes3d_2nd, input_data2['roi_boxes3d']), 1)
        with torch.no_grad():
            target_dict_3rd = self.proposal_target_layer(input_data3, stage=3)
        '''
        reg_valid_mask = target_dict_3rd['reg_valid_mask']
        fg_mask_num3 = (reg_valid_mask > 0).sum()
        if fg_mask_num3.item() < 10 * batch_size:
            input_data3['roi_boxes3d'] = torch.cat((pred_boxes3d_2nd, input_data2['roi_boxes3d']), 1)
            with torch.no_grad():
                target_dict_3rd = self.proposal_target_layer(input_data2, stage=3)
        '''
        pts_input_3 = torch.cat((target_dict_3rd['sampled_pts'],
                                 target_dict_3rd['pts_feature']), dim=2)
        target_dict_3rd['pts_input'] = pts_input_3
        roi = target_dict_3rd['roi_boxes3d']
    else:
        input_data3['roi_boxes3d'] = pred_boxes3d_2nd
        roi = pred_boxes3d_2nd
        pts_input_3 = self.roipooling(input_data3)
    xyz_3, features_3 = self._break_up_pc(pts_input_3)
    if cfg.RCNN.USE_RPN_FEATURES:
        # stage 3 also reuses the stage-1 up/merge layers here
        xyz_input_3 = pts_input_3[..., 0:self.rcnn_input_channel].transpose(
            1, 2).unsqueeze(dim=3)
        xyz_feature_3 = self.xyz_up_layer(xyz_input_3)
        rpn_feature_3 = pts_input_3[..., self.rcnn_input_channel:].transpose(
            1, 2).unsqueeze(dim=3)
        merged_feature_3 = torch.cat((xyz_feature_3, rpn_feature_3), dim=1)
        merged_feature_3 = self.merge_down_layer(merged_feature_3)
        l_xyz_3, l_features_3 = [xyz_3], [merged_feature_3.squeeze(dim=3)]
    else:
        # NOTE(review): writes 'l_xyz'/'l_features' rather than the _3 lists
        # consumed below -- this branch would crash; confirm it is unused.
        l_xyz, l_features = [xyz_3], [features_3]
    for i in range(len(self.SA_modules)):
        li_xyz_3, li_features_3 = self.SA_modules[i](l_xyz_3[i],
                                                     l_features_3[i])
        l_xyz_3.append(li_xyz_3)
        l_features_3.append(li_features_3)
    # free stage-2 intermediates before the stage-3 heads run
    del xyz_2, features_2, l_features_2
    rcnn_cls_3rd = self.cls_layer_3rd(l_features_3[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
    rcnn_reg_3rd = self.reg_layer_3rd(l_features_3[-1]).transpose(
        1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
    # ---- stage-3 loss ----
    if self.training:
        cls_label = target_dict_3rd['cls_label'].float()
        rcnn_cls_flat = rcnn_cls_3rd.view(-1)
        # NOTE(review): target here is 'cls_label' (not .view(-1)) while the
        # logits are flattened; relies on cls_label already being 1-D --
        # confirm against the stage-1/2 branches, which flatten explicitly.
        batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat),
                                                cls_label,
                                                reduction='none')
        cls_label_flat = cls_label.view(-1)
        cls_valid_mask = (cls_label_flat >= 0).float()
        rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(
            cls_valid_mask.sum(), min=1.0)
        gt_boxes3d_ct = target_dict_3rd['gt_of_rois']
        reg_valid_mask = target_dict_3rd['reg_valid_mask']
        fg_mask = (reg_valid_mask > 0)
        if rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask].size(0) == 0:
            fg_mask = (reg_valid_mask <= 0)
        loss_loc, loss_angle, loss_size, reg_loss_dict = \
            loss_utils.get_reg_loss(rcnn_reg_3rd.view(batch_size_2, -1)[fg_mask],
                                    gt_boxes3d_ct.view(batch_size_2, 7)[fg_mask],
                                    loc_scope=cfg.RCNN.LOC_SCOPE,
                                    loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                    num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                    anchor_size=anchor_size,
                                    get_xz_fine=True,
                                    get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                    loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
                                    loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                    get_ry_fine=True)
        rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size
        three = {
            'rcnn_loss_cls_3rd': rcnn_loss_cls,
            'rcnn_loss_reg_3rd': rcnn_loss_reg
        }
        del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask
    else:
        three = {}
    pred_boxes3d_3rd = decode_bbox_target(
        roi.view(-1, 7),
        rcnn_reg_3rd.view(-1, rcnn_reg_3rd.shape[-1]),
        anchor_size=anchor_size,
        loc_scope=cfg.RCNN.LOC_SCOPE,
        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
        get_xz_fine=True,
        get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
        loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
        get_ry_fine=True).view(batch_size, -1, 7)
    ret_dict = {
        'rcnn_cls': rcnn_cls,
        'rcnn_reg': rcnn_reg,
        'rcnn_cls_3rd': rcnn_cls_3rd,
        'rcnn_reg_3rd': rcnn_reg_3rd,
        'pred_boxes3d_1st': pred_boxes3d_1st,
        'pred_boxes3d_2nd': pred_boxes3d_2nd,
        'pred_boxes3d_3rd': pred_boxes3d_3rd
    }
    ret_dict.update(sec)
    ret_dict.update(one)
    ret_dict.update(two)
    ret_dict.update(three)
    if self.training:
        ret_dict.update(target_dict)
    return ret_dict
def get_rpn_loss(model, rpn_cls, rpn_reg, rpn_cls_label, rpn_reg_label,
                 tb_dict=None):
    """Compute the weighted RPN loss (classification + box regression).

    :param model: network, possibly wrapped in nn.DataParallel; supplies
        ``rpn.rpn_cls_loss_func``
    :param rpn_cls: per-point classification logits
    :param rpn_reg: per-point regression output
    :param rpn_cls_label: per-point labels (<0 ignore, 0 bg, >0 fg)
    :param rpn_reg_label: per-point 7-dof regression targets
    :param tb_dict: optional dict mutated in place with scalar log entries
    :return: weighted sum of classification and regression losses
    """
    # FIX: removed an unused local -- the previous version rebuilt a
    # `ModelReturn` namedtuple here on every call without ever using it.
    MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
    if isinstance(model, nn.DataParallel):
        rpn_cls_loss_func = model.module.rpn.rpn_cls_loss_func
    else:
        rpn_cls_loss_func = model.rpn.rpn_cls_loss_func
    rpn_cls_label_flat = rpn_cls_label.view(-1)
    rpn_cls_flat = rpn_cls.view(-1)
    fg_mask = rpn_cls_label_flat > 0

    # RPN classification loss
    if cfg.RPN.LOSS_CLS == "DiceLoss":
        rpn_loss_cls = rpn_cls_loss_func(rpn_cls, rpn_cls_label_flat)
    elif cfg.RPN.LOSS_CLS == "SigmoidFocalLoss":
        rpn_cls_target = (rpn_cls_label_flat > 0).float()
        pos = (rpn_cls_label_flat > 0).float()
        neg = (rpn_cls_label_flat == 0).float()
        # ignore samples (<0) get zero weight; normalize by #foreground
        cls_weights = pos + neg
        pos_normalizer = pos.sum()
        cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)
        rpn_loss_cls = rpn_cls_loss_func(rpn_cls_flat, rpn_cls_target,
                                         cls_weights)
        rpn_loss_cls_pos = (rpn_loss_cls * pos).sum()
        rpn_loss_cls_neg = (rpn_loss_cls * neg).sum()
        rpn_loss_cls = rpn_loss_cls.sum()
        if tb_dict is not None:
            tb_dict["rpn_loss_cls_pos"] = rpn_loss_cls_pos.item()
            tb_dict["rpn_loss_cls_neg"] = rpn_loss_cls_neg.item()
    elif cfg.RPN.LOSS_CLS == "BinaryCrossEntropy":
        # foreground points are up-weighted by FG_WEIGHT
        weight = rpn_cls_flat.new(rpn_cls_flat.shape[0]).fill_(1.0)
        weight[fg_mask] = cfg.RPN.FG_WEIGHT
        rpn_cls_label_target = (rpn_cls_label_flat > 0).float()
        batch_loss_cls = F.binary_cross_entropy(
            torch.sigmoid(rpn_cls_flat),
            rpn_cls_label_target,
            weight=weight,
            reduction="none",
        )
        cls_valid_mask = (rpn_cls_label_flat >= 0).float()
        rpn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(
            cls_valid_mask.sum(), min=1.0)
    else:
        raise NotImplementedError

    # RPN regression loss, foreground points only
    point_num = rpn_reg.size(0) * rpn_reg.size(1)
    fg_sum = fg_mask.long().sum().item()
    if fg_sum != 0:
        loss_loc, loss_angle, loss_size, reg_loss_dict = loss_utils.get_reg_loss(
            rpn_reg.view(point_num, -1)[fg_mask],
            rpn_reg_label.view(point_num, 7)[fg_mask],
            loc_scope=cfg.RPN.LOC_SCOPE,
            loc_bin_size=cfg.RPN.LOC_BIN_SIZE,
            num_head_bin=cfg.RPN.NUM_HEAD_BIN,
            anchor_size=MEAN_SIZE,
            get_xz_fine=cfg.RPN.LOC_XZ_FINE,
            get_y_by_bin=False,
            get_ry_fine=False,
        )
        loss_size = 3 * loss_size  # consistent with old codes
        rpn_loss_reg = loss_loc + loss_angle + loss_size
    else:
        # no foreground: zero losses that still carry the autograd graph
        loss_loc = loss_angle = loss_size = rpn_loss_reg = rpn_loss_cls * 0

    rpn_loss = (rpn_loss_cls * cfg.RPN.LOSS_WEIGHT[0] +
                rpn_loss_reg * cfg.RPN.LOSS_WEIGHT[1])
    if tb_dict is not None:
        tb_dict.update({
            "rpn_loss_cls": rpn_loss_cls.item(),
            "rpn_loss_reg": rpn_loss_reg.item(),
            "rpn_loss": rpn_loss.item(),
            "rpn_fg_sum": fg_sum,
            "rpn_loss_loc": loss_loc.item(),
            "rpn_loss_angle": loss_angle.item(),
            "rpn_loss_size": loss_size.item(),
        })
    return rpn_loss
def get_rcnn_loss(model, ret_dict, tb_dict):
    """Compute the RCNN (second-stage) loss: ROI classification + box refinement.

    Args:
        model: full network; may be wrapped in ``nn.DataParallel``.
        ret_dict: dict produced by the RCNN forward pass / target assignment;
            reads 'rcnn_cls', 'rcnn_reg', 'cls_label', 'reg_valid_mask',
            'roi_boxes3d', 'gt_of_rois', 'pts_input'.
        tb_dict: dict filled with scalar loss components for tensorboard.

    Returns:
        Scalar tensor: classification loss + regression loss.
    """
    rcnn_cls, rcnn_reg = ret_dict['rcnn_cls'], ret_dict['rcnn_reg']
    cls_label = ret_dict['cls_label'].float()  # per-ROI labels; values observed here are -1 (ignore), 0 (bg), >0 (fg)
    reg_valid_mask = ret_dict['reg_valid_mask']
    roi_boxes3d = ret_dict['roi_boxes3d']
    roi_size = roi_boxes3d[:, 3:6]  # (h, w, l) of each ROI — assumed box layout, TODO confirm
    gt_boxes3d_ct = ret_dict['gt_of_rois']
    pts_input = ret_dict['pts_input']

    # rcnn classification loss — unwrap DataParallel to reach the loss function
    if isinstance(model, nn.DataParallel):
        cls_loss_func = model.module.rcnn_net.cls_loss_func
    else:
        cls_loss_func = model.rcnn_net.cls_loss_func

    cls_label_flat = cls_label.view(-1)  # shape (num_rois,), e.g. 256

    if cfg.RCNN.LOSS_CLS == 'SigmoidFocalLoss':
        rcnn_cls_flat = rcnn_cls.view(-1)
        cls_target = (cls_label_flat > 0).float()
        pos = (cls_label_flat > 0).float()
        neg = (cls_label_flat == 0).float()
        # valid ROIs weighted 1, normalized by foreground count; label -1 gets weight 0
        cls_weights = pos + neg
        pos_normalizer = pos.sum()
        cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)
        rcnn_loss_cls = cls_loss_func(rcnn_cls_flat, cls_target, cls_weights)
        rcnn_loss_cls_pos = (rcnn_loss_cls * pos).sum()
        rcnn_loss_cls_neg = (rcnn_loss_cls * neg).sum()
        rcnn_loss_cls = rcnn_loss_cls.sum()
        # NOTE(review): keys say 'rpn_*' but these are RCNN values — looks like a
        # copy-paste slip inherited from upstream; fixing it would change log keys.
        tb_dict['rpn_loss_cls_pos'] = rcnn_loss_cls_pos.item()
        tb_dict['rpn_loss_cls_neg'] = rcnn_loss_cls_neg.item()
    elif cfg.RCNN.LOSS_CLS == 'BinaryCrossEntropy':
        rcnn_cls_flat = rcnn_cls.view(-1)
        # NOTE(review): target is cls_label (not clamped to {0,1}); assumes labels
        # are already binary where valid — TODO confirm.
        batch_loss_cls = F.binary_cross_entropy(
            torch.sigmoid(rcnn_cls_flat), cls_label, reduction='none')
        cls_valid_mask = (cls_label_flat >= 0).float()
        rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(
            cls_valid_mask.sum(), min=1.0)
    elif cfg.RCNN.LOSS_CLS == 'CrossEntropy':  # changed from cfg.TRAIN.LOSS_CLS upstream
        # rcnn_cls observed as (256, 4) logits; cls_label_flat as (256,) in {-1,0,1}
        rcnn_cls_reshape = rcnn_cls.view(rcnn_cls.shape[0], -1)
        cls_target = cls_label_flat.long()
        # -1 means "ignore"; record validity BEFORE the loop below rewrites -1 -> 0
        cls_valid_mask = (cls_target >= 0).float()
        # NOTE(review): cls_target_final is built, moved to GPU and cast below but
        # never used in the loss — leftover from an abandoned one-hot experiment.
        cls_target_final = torch.zeros(rcnn_cls_reshape.shape[0], rcnn_cls_reshape.shape[1])
        for i in range(cls_target_final.shape[0]):
            # Replace ignore-label -1 with class 0 so CrossEntropy does not crash;
            # these entries are zeroed out via cls_valid_mask anyway.
            if cls_target[i] == -1:
                cls_target[i] = 0
            # cls_target_final[i, cls_target[i].cpu().numpy()] = 1  # classes must be consecutive 0,1,2,3,...
        cls_target_final = cls_target_final.cuda()
        cls_target_final = cls_target_final.long()
        batch_loss_cls = cls_loss_func(rcnn_cls_reshape, cls_target)  # per-ROI loss, observed shape (256,)
        normalizer = torch.clamp(cls_valid_mask.sum(), min=1.0)
        # NOTE(review): if batch_loss_cls is 1-D, mean(dim=0) is a scalar, so the
        # mask multiplication broadcasts and the masking has no effect (result is
        # just the unmasked mean). The commented alternative used mean(dim=1).
        # Suspect the intended form is (batch_loss_cls * cls_valid_mask).sum() /
        # normalizer — TODO confirm cls_loss_func's reduction before changing.
        rcnn_loss_cls = (batch_loss_cls.mean(dim=0) * cls_valid_mask).sum() / normalizer
        # rcnn_loss_cls = (batch_loss_cls.mean(dim=1) * cls_valid_mask).sum() / normalizer
    else:
        raise NotImplementedError

    # rcnn regression loss — only ROIs flagged valid for regression contribute
    batch_size = pts_input.shape[0]
    fg_mask = (reg_valid_mask > 0)
    fg_sum = fg_mask.long().sum().item()
    if fg_sum != 0:
        all_anchor_size = roi_size
        # per-ROI anchor size when SIZE_RES_ON_ROI, otherwise the class mean size
        # (MEAN_SIZE must exist at module scope — TODO confirm; the first
        # get_rpn_loss only defines it locally)
        anchor_size = all_anchor_size[fg_mask] if cfg.RCNN.SIZE_RES_ON_ROI else MEAN_SIZE
        # NOTE(review): GT boxes viewed with 8 columns here (upstream used 7) —
        # presumably an extra label channel was added; verify against gt_of_rois.
        loss_loc, loss_angle, loss_size, reg_loss_dict = \
            loss_utils.get_reg_loss(rcnn_reg.view(batch_size, -1)[fg_mask],
                                    gt_boxes3d_ct.view(batch_size, 8)[fg_mask],
                                    loc_scope=cfg.RCNN.LOC_SCOPE,
                                    loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                    num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                    anchor_size=anchor_size,
                                    get_xz_fine=True,
                                    get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                    loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
                                    loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                    get_ry_fine=True)
        loss_size = 3 * loss_size  # consistent with old codes
        rcnn_loss_reg = loss_loc + loss_angle + loss_size
        tb_dict.update(reg_loss_dict)
    else:
        # no valid ROIs: zero losses kept on the autograd graph
        loss_loc = loss_angle = loss_size = rcnn_loss_reg = rcnn_loss_cls * 0

    rcnn_loss = rcnn_loss_cls + rcnn_loss_reg
    tb_dict['rcnn_loss_cls'] = rcnn_loss_cls.item()
    tb_dict['rcnn_loss_reg'] = rcnn_loss_reg.item()
    tb_dict['rcnn_loss'] = rcnn_loss.item()
    tb_dict['rcnn_loss_loc'] = loss_loc.item()
    tb_dict['rcnn_loss_angle'] = loss_angle.item()
    tb_dict['rcnn_loss_size'] = loss_size.item()
    # fg : foreground, bg : background
    tb_dict['rcnn_cls_fg'] = (cls_label > 0).sum().item()
    tb_dict['rcnn_cls_bg'] = (cls_label == 0).sum().item()
    tb_dict['rcnn_reg_fg'] = reg_valid_mask.sum().item()
    return rcnn_loss
def get_rpn_loss(model, rpn_cls, rpn_reg, rpn_cls_label, rpn_reg_label, tb_dict):
    """RPN loss = weighted sum of per-point classification and box regression.

    NOTE(review): this redefines get_rpn_loss and shadows the earlier version at
    module level; unlike that one, tb_dict is mandatory here, regression labels
    have 9 channels, and the ``3 * loss_size`` rescaling is disabled.
    """
    # The classification loss function is defined in lib/net/rpn.py; unwrap
    # DataParallel if necessary to reach it.
    net = model.module if isinstance(model, nn.DataParallel) else model
    rpn_cls_loss_func = net.rpn.rpn_cls_loss_func

    labels_flat = rpn_cls_label.view(-1)
    logits_flat = rpn_cls.view(-1)
    fg_mask = labels_flat > 0  # points lying inside a GT box

    # ---- classification term ----------------------------------------------
    if cfg.RPN.LOSS_CLS == 'DiceLoss':
        rpn_loss_cls = rpn_cls_loss_func(rpn_cls, labels_flat)
    elif cfg.RPN.LOSS_CLS == 'SigmoidFocalLoss':
        pos = fg_mask.float()
        neg = (labels_flat == 0).float()
        rpn_cls_target = pos
        # every valid point weighted 1, normalized by the foreground count
        cls_weights = (pos + neg) / torch.clamp(pos.sum(), min=1.0)
        per_point_loss = rpn_cls_loss_func(logits_flat, rpn_cls_target, cls_weights)
        tb_dict['rpn_loss_cls_pos'] = (per_point_loss * pos).sum().item()
        tb_dict['rpn_loss_cls_neg'] = (per_point_loss * neg).sum().item()
        rpn_loss_cls = per_point_loss.sum()
    elif cfg.RPN.LOSS_CLS == 'BinaryCrossEntropy':
        # foreground points get a larger weight; label < 0 points are masked out
        weight = torch.ones_like(logits_flat)
        weight[fg_mask] = cfg.RPN.FG_WEIGHT
        bce_target = fg_mask.float()
        per_point_loss = F.binary_cross_entropy(
            torch.sigmoid(logits_flat), bce_target,
            weight=weight, reduction='none')
        valid = (labels_flat >= 0).float()
        rpn_loss_cls = (per_point_loss * valid).sum() / torch.clamp(valid.sum(), min=1.0)
    else:
        raise NotImplementedError

    # ---- regression term ---------------------------------------------------
    # rpn_reg is batched (multiple scenes); flatten to one big list of points,
    # since each point already carries its own label. Only points inside a box
    # are regressed — the rest have no meaningful targets.
    point_num = rpn_reg.size(0) * rpn_reg.size(1)
    fg_sum = fg_mask.long().sum().item()
    if fg_sum == 0:
        # no foreground anywhere: zero losses, kept on the autograd graph
        loss_loc = loss_angle = loss_size = rpn_loss_reg = rpn_loss_cls * 0
    else:
        loss_loc, loss_angle, loss_size, reg_loss_dict = loss_utils.get_reg_loss(
            rpn_reg.view(point_num, -1)[fg_mask],
            rpn_reg_label.view(point_num, 9)[fg_mask],
            loc_scope=cfg.RPN.LOC_SCOPE,
            loc_bin_size=cfg.RPN.LOC_BIN_SIZE,
            num_head_bin=cfg.RPN.NUM_HEAD_BIN,
            anchor_size=MEAN_SIZE,
            get_xz_fine=cfg.RPN.LOC_XZ_FINE,
            get_y_by_bin=False,
            get_ry_fine=False)
        # loss_size = 3 * loss_size  # old-code rescaling, intentionally disabled
        rpn_loss_reg = loss_loc + loss_angle + loss_size

    # weighted sum; LOSS_WEIGHT is typically [1.0, 1.0]
    rpn_loss = (rpn_loss_cls * cfg.RPN.LOSS_WEIGHT[0] +
                rpn_loss_reg * cfg.RPN.LOSS_WEIGHT[1])

    tb_dict.update({
        'rpn_loss_cls': rpn_loss_cls.item(),
        'rpn_loss_reg': rpn_loss_reg.item(),
        'rpn_loss': rpn_loss.item(),
        'rpn_fg_sum': fg_sum,
        'rpn_loss_loc': loss_loc.item(),
        'rpn_loss_angle': loss_angle.item(),
        'rpn_loss_size': loss_size.item()
    })
    return rpn_loss