def forward(self, data_dicts):
    """Run the detector; when box labels are present also compute losses/metrics.

    ``data_dicts`` maps names to batched tensors. Keys read here:
    ``point_cloud``, ``one_hot``, ``ref_label``, ``box3d_center``,
    ``size_class``, ``box3d_size``, ``box3d_heading`` and
    ``center_ref1`` .. ``center_ref4``. Shapes quoted below as "e.g." are
    example sizes from a batch of 32 and are not enforced here.

    Inference mode (``box3d_center`` missing/None) returns the tuple
    ``(cls_probs, center_preds, heading_preds, size_preds)`` reshaped to
    ``(B, -1, 2)``, ``(B, -1, 3)``, ``(B, -1)`` and ``(B, -1, 3)``.
    Predicted centers are the regressed offsets added to the
    ``center_ref2`` reference points.

    Training mode returns ``(losses, metrics)`` dicts. Box-regression
    losses use only the anchors whose ``ref_label`` equals 1. The
    classification focal loss uses every anchor and is called with
    ``ignore_idx=-1``.
    """
    point_cloud = data_dicts.get('point_cloud')  # (B, C, N), e.g. (32, 4, 1024)
    one_hot = data_dicts.get('one_hot')  # e.g. (32, 3)
    ref_label = data_dicts.get('ref_label')  # per-anchor labels, e.g. (32, 140)
    bs = point_cloud.shape[0]

    # Labels; a missing box3d_center selects inference mode below.
    box3d_center_label = data_dicts.get('box3d_center')  # e.g. (32, 3)
    size_class_label = data_dicts.get('size_class')  # e.g. (32,)
    box3d_size_label = data_dicts.get('box3d_size')  # absolute size, not a residual
    box3d_heading_label = data_dicts.get('box3d_heading')  # absolute angle, not a residual

    # Multi-scale reference points, (B, 3, N_k).
    center_ref1 = data_dicts.get('center_ref1')  # e.g. (32, 3, 280)
    center_ref2 = data_dicts.get('center_ref2')  # e.g. (32, 3, 140)
    center_ref3 = data_dicts.get('center_ref3')  # e.g. (32, 3, 70)
    center_ref4 = data_dicts.get('center_ref4')  # e.g. (32, 3, 35)

    object_point_cloud_xyz = point_cloud[:, :3, :].contiguous()
    # Extra per-point features: 4 channels -> one extra, 6 channels -> three extra.
    if point_cloud.shape[1] == 4:
        object_point_cloud_i = point_cloud[:, [3], :].contiguous()
    elif point_cloud.shape[1] == 6:
        object_point_cloud_i = point_cloud[:, 3:6, :].contiguous()
    else:
        object_point_cloud_i = None

    mean_size_array = torch.from_numpy(g_mean_size_arr).type_as(point_cloud)

    feat1, feat2, feat3, feat4 = self.feat_net(
        object_point_cloud_xyz,
        [center_ref1, center_ref2, center_ref3, center_ref4],
        object_point_cloud_i,
        one_hot)

    x = self.conv_net(feat1, feat2, feat3, feat4)
    cls_scores = self.cls_out(x)  # (B, 2, num_out)
    outputs = self.reg_out(x)  # (B, output_size, num_out)
    num_out = outputs.shape[2]
    output_size = outputs.shape[1]

    # (B, C, N) -> (B * N, C): one row per anchor.
    cls_scores = cls_scores.permute(0, 2, 1).contiguous().view(-1, 2)
    outputs = outputs.permute(0, 2, 1).contiguous().view(-1, output_size)
    center_ref2 = center_ref2.permute(0, 2, 1).contiguous().view(-1, 3)

    cls_probs = F.softmax(cls_scores, -1)

    if box3d_center_label is None:
        # No labels (test mode / detections from RGB): decode and return boxes.
        det_outputs = self._slice_output(outputs)
        center_boxnet, heading_scores, heading_res_norm, size_scores, size_res_norm = det_outputs

        heading_probs = F.softmax(heading_scores, -1)
        size_probs = F.softmax(size_scores, -1)
        heading_pred_label = torch.argmax(heading_probs, -1)
        size_pred_label = torch.argmax(size_probs, -1)

        center_preds = center_boxnet + center_ref2
        heading_preds = angle_decode(heading_res_norm, heading_pred_label)
        size_preds = size_decode(size_res_norm, mean_size_array, size_pred_label)

        return (cls_probs.view(bs, -1, 2),
                center_preds.view(bs, -1, 3),
                heading_preds.view(bs, -1),
                size_preds.view(bs, -1, 3))

    # Foreground anchors (ref_label == 1) are the only ones that get box losses.
    fg_idx = (ref_label.view(-1) == 1).nonzero().view(-1)
    assert fg_idx.numel() != 0, 'no foreground anchors (ref_label == 1) in batch'
    outputs = outputs[fg_idx, :]
    center_ref2 = center_ref2[fg_idx]

    det_outputs = self._slice_output(outputs)
    center_boxnet, heading_scores, heading_res_norm, size_scores, size_res_norm = det_outputs
    heading_probs = F.softmax(heading_scores, -1)
    size_probs = F.softmax(size_scores, -1)

    # Classification loss over all anchors (computed once).
    cls_loss = softmax_focal_loss_ignore(cls_probs, ref_label.view(-1), ignore_idx=-1)

    # Broadcast the per-sample labels to every anchor, then keep foreground rows.
    center_label = box3d_center_label.unsqueeze(1).expand(-1, num_out, -1) \
        .contiguous().view(-1, 3)[fg_idx]
    size_label = box3d_size_label.unsqueeze(1).expand(-1, num_out, -1) \
        .contiguous().view(-1, 3)[fg_idx]
    heading_label = box3d_heading_label.view(-1, 1).expand(-1, num_out) \
        .contiguous().view(-1)[fg_idx]
    size_class_label = size_class_label.view(-1, 1).expand(-1, num_out) \
        .contiguous().view(-1)[fg_idx]

    # Encode regression targets relative to the reference points / mean sizes.
    center_gt_offsets = center_encode(center_label, center_ref2)
    heading_class_label, heading_res_norm_label = angle_encode(heading_label)
    size_res_label_norm = size_encode(size_label, mean_size_array, size_class_label)

    center_loss = self.get_center_loss(center_boxnet, center_gt_offsets)
    heading_class_loss, heading_res_norm_loss = self.get_heading_loss(
        heading_scores, heading_res_norm, heading_class_label, heading_res_norm_label)
    size_class_loss, size_res_norm_loss = self.get_size_loss(
        size_scores, size_res_norm, size_class_label, size_res_label_norm)

    # Corner regularisation: decode with the ground-truth bins.
    center_preds = center_decode(center_ref2, center_boxnet)
    heading = angle_decode(heading_res_norm, heading_class_label)
    size = size_decode(size_res_norm, mean_size_array, size_class_label)
    corners_loss, corner_gts = self.get_corner_loss(
        (center_preds, heading, size),
        (center_label, heading_label, size_label))

    loss_cfg = cfg.LOSS
    loss = cls_loss + loss_cfg.BOX_LOSS_WEIGHT * (
        center_loss + heading_class_loss + size_class_loss
        + loss_cfg.HEAD_REG_WEIGHT * heading_res_norm_loss
        + loss_cfg.SIZE_REG_WEIGHT * size_res_norm_loss
        + loss_cfg.CORNER_LOSS_WEIGHT * corners_loss)

    # Monitoring metrics only; no gradients needed.
    with torch.no_grad():
        cls_prec = get_accuracy(cls_probs, ref_label.view(-1))
        heading_prec = get_accuracy(heading_probs, heading_class_label.view(-1))
        size_prec = get_accuracy(size_probs, size_class_label.view(-1))

        # IoU of boxes decoded with the *predicted* bins against ground truth.
        heading_pred_label = torch.argmax(heading_probs, -1)
        size_pred_label = torch.argmax(size_probs, -1)
        heading_preds = angle_decode(heading_res_norm, heading_pred_label)
        size_preds = size_decode(size_res_norm, mean_size_array, size_pred_label)
        corner_preds = get_box3d_corners_helper(center_preds, heading_preds, size_preds)
        overlap = rbbox_iou_3d_pair(corner_preds.detach().cpu().numpy(),
                                    corner_gts.detach().cpu().numpy())
        iou2ds, iou3ds = overlap[:, 0], overlap[:, 1]
        iou2d_mean = iou2ds.mean()
        iou3d_mean = iou3ds.mean()
        iou3d_gt_mean = (iou3ds >= cfg.IOU_THRESH).mean()
        iou2d_mean = torch.tensor(iou2d_mean).type_as(cls_prec)
        iou3d_mean = torch.tensor(iou3d_mean).type_as(cls_prec)
        iou3d_gt_mean = torch.tensor(iou3d_gt_mean).type_as(cls_prec)

    losses = {
        'total_loss': loss,
        'cls_loss': cls_loss,
        'center_loss': center_loss,
        'heading_class_loss': heading_class_loss,
        'heading_residual_normalized_loss': heading_res_norm_loss,
        'size_class_loss': size_class_loss,
        'size_residual_normalized_loss': size_res_norm_loss,
        'corners_loss': corners_loss
    }
    metrics = {
        'cls_acc': cls_prec,
        'head_acc': heading_prec,
        'size_acc': size_prec,
        'iou2d': iou2d_mean,
        'iou3d': iou3d_mean,
        'iou3d_' + str(cfg.IOU_THRESH): iou3d_gt_mean
    }
    return losses, metrics
def forward(self, data_dicts):
    """Image-augmented variant: run the detector; compute losses when labelled.

    Keys read from ``data_dicts``: ``image``, ``P``, ``query_v1``,
    ``point_cloud``, ``one_hot``, ``label``, ``size_class``,
    ``box3d_center``, ``box3d_heading``, ``box3d_size`` and
    ``center_ref1`` .. ``center_ref4``. ``image`` goes through
    ``self.cnn``. The result is passed, together with ``P`` and
    ``query_v1``, to ``self.feat_net``.

    Inference (``box3d_center`` is None; asserted to happen only when
    ``self.training`` is False) returns
    ``(cls_probs, center_preds, heading_preds, size_preds)`` reshaped to
    ``(B, -1, 2)``, ``(B, -1, 3)``, ``(B, -1)`` and ``(B, -1, 3)``.

    Training returns ``(losses, metrics)`` dicts. Box-regression losses use
    only anchors whose ``label`` equals 1. The per-sample ``box3d_heading``
    and ``size_class`` labels are reshaped to ``(B, -1)`` before being
    broadcast to the anchors, so ``(B,)``, ``(B, 1)`` and ``(B, num_out)``
    inputs are all accepted.
    """
    image = data_dicts.get('image')
    out_image = self.cnn(image)
    P = data_dicts.get('P')
    query_v1 = data_dicts.get('query_v1')

    point_cloud = data_dicts.get('point_cloud')  # (B, C, N)
    one_hot_vec = data_dicts.get('one_hot')
    cls_label = data_dicts.get('label')
    size_class_label = data_dicts.get('size_class')
    center_label = data_dicts.get('box3d_center')
    heading_label = data_dicts.get('box3d_heading')
    size_label = data_dicts.get('box3d_size')

    center_ref1 = data_dicts.get('center_ref1')
    center_ref2 = data_dicts.get('center_ref2')
    center_ref3 = data_dicts.get('center_ref3')
    center_ref4 = data_dicts.get('center_ref4')

    batch_size = point_cloud.shape[0]

    object_point_cloud_xyz = point_cloud[:, :3, :].contiguous()
    # Only channel 3 is used as an extra per-point feature, however many extras exist.
    if point_cloud.shape[1] > 3:
        object_point_cloud_i = point_cloud[:, [3], :].contiguous()
    else:
        object_point_cloud_i = None

    mean_size_array = torch.from_numpy(MEAN_SIZE_ARRAY).type_as(point_cloud)

    feat1, feat2, feat3, feat4 = self.feat_net(
        object_point_cloud_xyz,
        [center_ref1, center_ref2, center_ref3, center_ref4],
        object_point_cloud_i, one_hot_vec, out_image, P, query_v1)

    x = self.conv_net(feat1, feat2, feat3, feat4)
    cls_scores = self.cls_out(x)
    outputs = self.reg_out(x)
    num_out = outputs.shape[2]
    output_size = outputs.shape[1]

    # (B, C, N) -> (B * N, C): one row per anchor.
    cls_scores = cls_scores.permute(0, 2, 1).contiguous().view(-1, 2)
    outputs = outputs.permute(0, 2, 1).contiguous().view(-1, output_size)
    center_ref2 = center_ref2.permute(0, 2, 1).contiguous().view(-1, 3)

    cls_probs = F.softmax(cls_scores, -1)

    if center_label is None:
        assert not self.training, 'Please provide labels for training.'

        det_outputs = self._slice_output(outputs)
        center_boxnet, heading_scores, heading_res_norm, size_scores, size_res_norm = det_outputs

        # Decode using the most probable heading/size bins.
        heading_probs = F.softmax(heading_scores, -1)
        size_probs = F.softmax(size_scores, -1)
        heading_pred_label = torch.argmax(heading_probs, -1)
        size_pred_label = torch.argmax(size_probs, -1)

        center_preds = center_boxnet + center_ref2
        heading_preds = angle_decode(heading_res_norm, heading_pred_label)
        size_preds = size_decode(size_res_norm, mean_size_array, size_pred_label)

        return (cls_probs.view(batch_size, -1, 2),
                center_preds.view(batch_size, -1, 3),
                heading_preds.view(batch_size, -1),
                size_preds.view(batch_size, -1, 3))

    # Foreground anchors (label == 1) are the only ones that get box losses.
    fg_idx = (cls_label.view(-1) == 1).nonzero().view(-1)
    assert fg_idx.numel() != 0, 'no foreground anchors (label == 1) in batch'
    outputs = outputs[fg_idx, :]
    center_ref2 = center_ref2[fg_idx]

    det_outputs = self._slice_output(outputs)
    center_boxnet, heading_scores, heading_res_norm, size_scores, size_res_norm = det_outputs
    heading_probs = F.softmax(heading_scores, -1)
    size_probs = F.softmax(size_scores, -1)

    cls_loss = softmax_focal_loss_ignore(cls_probs, cls_label.view(-1), ignore_idx=-1)

    # Broadcast per-sample labels to every anchor, then keep foreground rows.
    # reshape(batch_size, -1) lets 1-D (B,) labels broadcast like (B, 1) ones.
    center_label = center_label.unsqueeze(1).expand(-1, num_out, -1) \
        .contiguous().view(-1, 3)[fg_idx]
    heading_label = heading_label.reshape(batch_size, -1).expand(-1, num_out) \
        .contiguous().view(-1)[fg_idx]
    size_label = size_label.unsqueeze(1).expand(-1, num_out, -1) \
        .contiguous().view(-1, 3)[fg_idx]
    size_class_label = size_class_label.reshape(batch_size, -1).expand(-1, num_out) \
        .contiguous().view(-1)[fg_idx]

    # Encode regression targets relative to the reference points / mean sizes.
    center_gt_offsets = center_encode(center_label, center_ref2)
    heading_class_label, heading_res_norm_label = angle_encode(heading_label)
    size_res_label_norm = size_encode(size_label, mean_size_array, size_class_label)

    center_loss = self.get_center_loss(center_boxnet, center_gt_offsets)
    heading_class_loss, heading_res_norm_loss = self.get_heading_loss(
        heading_scores, heading_res_norm, heading_class_label, heading_res_norm_label)
    size_class_loss, size_res_norm_loss = self.get_size_loss(
        size_scores, size_res_norm, size_class_label, size_res_label_norm)

    # Corner regularisation: decode with the ground-truth bins.
    center_preds = center_decode(center_ref2, center_boxnet)
    heading = angle_decode(heading_res_norm, heading_class_label)
    size = size_decode(size_res_norm, mean_size_array, size_class_label)
    corners_loss, corner_gts = self.get_corner_loss(
        (center_preds, heading, size),
        (center_label, heading_label, size_label))

    loss_cfg = cfg.LOSS
    loss = cls_loss + loss_cfg.BOX_LOSS_WEIGHT * (
        center_loss + heading_class_loss + size_class_loss
        + loss_cfg.HEAD_REG_WEIGHT * heading_res_norm_loss
        + loss_cfg.SIZE_REG_WEIGHT * size_res_norm_loss
        + loss_cfg.CORNER_LOSS_WEIGHT * corners_loss)

    # Monitoring metrics only; no gradients needed.
    with torch.no_grad():
        cls_prec = get_accuracy(cls_probs, cls_label.view(-1))
        heading_prec = get_accuracy(heading_probs, heading_class_label.view(-1))
        size_prec = get_accuracy(size_probs, size_class_label.view(-1))

        # IoU of boxes decoded with the *predicted* bins against ground truth.
        heading_pred_label = torch.argmax(heading_probs, -1)
        size_pred_label = torch.argmax(size_probs, -1)
        heading_preds = angle_decode(heading_res_norm, heading_pred_label)
        size_preds = size_decode(size_res_norm, mean_size_array, size_pred_label)
        corner_preds = get_box3d_corners_helper(center_preds, heading_preds, size_preds)
        overlap = rbbox_iou_3d_pair(corner_preds.detach().cpu().numpy(),
                                    corner_gts.detach().cpu().numpy())
        iou2ds, iou3ds = overlap[:, 0], overlap[:, 1]
        iou2d_mean = iou2ds.mean()
        iou3d_mean = iou3ds.mean()
        iou3d_gt_mean = (iou3ds >= cfg.IOU_THRESH).mean()
        iou2d_mean = torch.tensor(iou2d_mean).type_as(cls_prec)
        iou3d_mean = torch.tensor(iou3d_mean).type_as(cls_prec)
        iou3d_gt_mean = torch.tensor(iou3d_gt_mean).type_as(cls_prec)

    losses = {
        'total_loss': loss,
        'cls_loss': cls_loss,
        'center_loss': center_loss,
        'head_cls_loss': heading_class_loss,
        'head_res_loss': heading_res_norm_loss,
        'size_cls_loss': size_class_loss,
        'size_res_loss': size_res_norm_loss,
        'corners_loss': corners_loss
    }
    metrics = {
        'cls_acc': cls_prec,
        'head_acc': heading_prec,
        'size_acc': size_prec,
        'IoU_2D': iou2d_mean,
        'IoU_3D': iou3d_mean,
        'IoU_' + str(cfg.IOU_THRESH): iou3d_gt_mean
    }
    return losses, metrics
def get_iou_cc(bb1, bb2):
    """Return the 3D IoU between a single pair of boxes.

    Both inputs get a leading batch axis of length 1. The pair is then
    scored with ``box_ops_cc.rbbox_iou_3d_pair``. Elsewhere in this file
    the result's column 0 is read as the 2D IoU and column 1 as the 3D IoU.
    This function returns column 1 of the only row.

    NOTE(review): bb1/bb2 are presumably box corner arrays, as passed to
    rbbox_iou_3d_pair in ``forward``. The exact shape is not visible here,
    so confirm it against box_ops_cc.
    """
    pair_overlap = box_ops_cc.rbbox_iou_3d_pair(bb1[None, ...], bb2[None, ...])
    only_row = pair_overlap[0]
    return only_row[1]