def compute_box_and_sem_cls_loss(data_dict, config):
    """ Compute 3D bounding box and semantic classification loss.

    Args:
        data_dict: dict (read-only) holding predictions ('center',
            'heading_scores', 'heading_residuals_normalized', 'size_scores',
            'size_residuals_normalized', 'sem_cls_scores') and labels
            ('center_label', 'box_label_mask', 'objectness_label',
            'object_assignment', 'heading_class_label',
            'heading_residual_label', 'size_class_label',
            'size_residual_label', 'sem_cls_label')
        config: dataset config providing num_heading_bin, num_size_cluster
            and mean_size_arr

    Returns:
        center_loss
        heading_cls_loss
        heading_reg_loss
        size_cls_loss
        size_reg_loss
        sem_cls_loss
    """
    num_heading_bin = config.num_heading_bin
    num_size_cluster = config.num_size_cluster
    mean_size_arr = config.mean_size_arr

    object_assignment = data_dict['object_assignment']
    batch_size = object_assignment.shape[0]

    # Center loss: two-sided nearest-neighbour distance between predicted
    # and GT centers. The pred->GT term is averaged over positive proposals
    # (objectness mask), the GT->pred term over valid GT boxes.
    pred_center = data_dict['center']
    gt_center = data_dict['center_label'][:, :, 0:3]
    dist1, _, dist2, _ = nn_distance(pred_center,
                                     gt_center)  # dist1: BxK, dist2: BxK2
    box_label_mask = data_dict['box_label_mask']
    objectness_label = data_dict['objectness_label'].float()
    centroid_reg_loss1 = \
        torch.sum(dist1*objectness_label)/(torch.sum(objectness_label)+1e-6)
    centroid_reg_loss2 = \
        torch.sum(dist2*box_label_mask)/(torch.sum(box_label_mask)+1e-6)
    center_loss = centroid_reg_loss1 + centroid_reg_loss2

    # Heading classification loss, averaged over positive proposals only.
    heading_class_label = torch.gather(
        data_dict['heading_class_label'], 1,
        object_assignment)  # select (B,K) from (B,K2)
    criterion_heading_class = nn.CrossEntropyLoss(reduction='none')
    heading_class_loss = criterion_heading_class(
        data_dict['heading_scores'].transpose(2, 1),
        heading_class_label)  # (B,K)
    heading_class_loss = torch.sum(heading_class_loss * objectness_label) / (
        torch.sum(objectness_label) + 1e-6)

    # Heading residual regression: GT residual is normalized by
    # (pi / num_heading_bin) — presumably half a heading bin width, so the
    # normalized target lies roughly in [-1, 1] — and only the GT bin's
    # predicted residual is supervised, via a one-hot mask.
    heading_residual_label = torch.gather(
        data_dict['heading_residual_label'], 1,
        object_assignment)  # select (B,K) from (B,K2)
    heading_residual_normalized_label = heading_residual_label / (
        np.pi / num_heading_bin)

    # Ref: https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/3
    heading_label_one_hot = torch.cuda.FloatTensor(
        batch_size, heading_class_label.shape[1], num_heading_bin).zero_()
    heading_label_one_hot.scatter_(
        2, heading_class_label.unsqueeze(-1),
        1)  # src==1 so it's *one-hot* (B,K,num_heading_bin)
    heading_residual_normalized_loss = huber_loss(
        torch.sum(data_dict['heading_residuals_normalized'] *
                  heading_label_one_hot, -1) -
        heading_residual_normalized_label,
        delta=1.0)  # (B,K)
    heading_residual_normalized_loss = torch.sum(
        heading_residual_normalized_loss *
        objectness_label) / (torch.sum(objectness_label) + 1e-6)

    # Size classification loss.
    size_class_label = torch.gather(
        data_dict['size_class_label'], 1,
        object_assignment)  # select (B,K) from (B,K2)
    criterion_size_class = nn.CrossEntropyLoss(reduction='none')
    size_class_loss = criterion_size_class(
        data_dict['size_scores'].transpose(2, 1), size_class_label)  # (B,K)
    size_class_loss = torch.sum(size_class_loss * objectness_label) / (
        torch.sum(objectness_label) + 1e-6)

    # Size residual regression: only the GT size cluster's predicted residual
    # is supervised; the GT residual is normalized by that cluster's mean
    # size (anisotropic, per-axis).
    size_residual_label = torch.gather(
        data_dict['size_residual_label'], 1,
        object_assignment.unsqueeze(-1).repeat(
            1, 1, 3))  # select (B,K,3) from (B,K2,3)
    size_label_one_hot = torch.cuda.FloatTensor(
        batch_size, size_class_label.shape[1], num_size_cluster).zero_()
    size_label_one_hot.scatter_(
        2, size_class_label.unsqueeze(-1),
        1)  # src==1 so it's *one-hot* (B,K,num_size_cluster)
    size_label_one_hot_tiled = size_label_one_hot.unsqueeze(-1).repeat(
        1, 1, 1, 3)  # (B,K,num_size_cluster,3)
    predicted_size_residual_normalized = torch.sum(
        data_dict['size_residuals_normalized'] * size_label_one_hot_tiled,
        2)  # (B,K,3)

    mean_size_arr_expanded = torch.from_numpy(
        mean_size_arr.astype(np.float32)).cuda().unsqueeze(0).unsqueeze(
            0)  # (1,1,num_size_cluster,3)
    mean_size_label = torch.sum(size_label_one_hot_tiled *
                                mean_size_arr_expanded, 2)  # (B,K,3)
    size_residual_label_normalized = size_residual_label / mean_size_label  # (B,K,3)
    size_residual_normalized_loss = torch.mean(
        huber_loss(predicted_size_residual_normalized -
                   size_residual_label_normalized, delta=1.0),
        -1)  # (B,K,3) -> (B,K)
    size_residual_normalized_loss = torch.sum(
        size_residual_normalized_loss *
        objectness_label) / (torch.sum(objectness_label) + 1e-6)

    # 3.4 Semantic cls loss
    sem_cls_label = torch.gather(
        data_dict['sem_cls_label'], 1,
        object_assignment)  # select (B,K) from (B,K2)
    criterion_sem_cls = nn.CrossEntropyLoss(reduction='none')
    sem_cls_loss = criterion_sem_cls(
        data_dict['sem_cls_scores'].transpose(2, 1), sem_cls_label)  # (B,K)
    sem_cls_loss = torch.sum(
        sem_cls_loss * objectness_label) / (torch.sum(objectness_label) + 1e-6)

    return center_loss, heading_class_loss, heading_residual_normalized_loss, \
        size_class_loss, size_residual_normalized_loss, sem_cls_loss
def compute_box_and_sem_cls_loss(end_points, supervised_inds, dataset_config,
                                 config_dict):
    """ Compute 3D bounding box and semantic classification loss.

    Only the supervised (labeled) half of the batch — selected by
    `supervised_inds` — contributes to the losses. Besides returning the loss
    terms, this writes monitoring / IoU-estimation values into `end_points`:
    'cls_acc', 'pred_iou_value', 'pred_iou_obj_value', 'obj_count', and, when
    the corresponding prediction heads exist, 'jitter_iou_acc',
    'jitter_iou_acc_obj', 'jitter_iou_loss', 'iou_acc', 'iou_acc_obj',
    'iou_loss'.

    Args:
        end_points: dict of predictions and labels; monitoring keys listed
            above are written into it
        supervised_inds: indices of the labeled samples within the batch
        dataset_config: provides num_heading_bin, num_size_cluster,
            mean_size_arr and class2size_gpu / class2angle_gpu
        config_dict: unused here apart from its presence in the signature

    Returns:
        center_loss
        heading_cls_loss
        heading_reg_loss
        size_cls_loss
        size_reg_loss
        sem_cls_loss
    """
    num_heading_bin = dataset_config.num_heading_bin
    num_size_cluster = dataset_config.num_size_cluster
    mean_size_arr = dataset_config.mean_size_arr

    object_assignment = end_points['object_assignment']
    # B below is the number of supervised samples, not the full batch size.
    batch_size = object_assignment.shape[0]

    # Compute center loss: two-sided nearest-neighbour distance between
    # predicted and GT centers of the supervised samples.
    dist1, ind1, dist2, _ = nn_distance(
        end_points['center'][supervised_inds, ...],
        end_points['center_label'][supervised_inds,
                                   ...][:, :,
                                        0:3])  # dist1: BxK, dist2: BxK2
    box_label_mask = end_points['box_label_mask'][supervised_inds, ...]
    # NOTE(review): 'objectness_label' is NOT indexed by supervised_inds —
    # presumably it is already computed only for the supervised samples
    # elsewhere; confirm against the objectness-loss code.
    objectness_label = end_points['objectness_label'].float()
    centroid_reg_loss1 = \
        torch.sum(dist1 * objectness_label) / (torch.sum(objectness_label) + 1e-6)
    centroid_reg_loss2 = \
        torch.sum(dist2 * box_label_mask) / (torch.sum(box_label_mask) + 1e-6)
    center_loss = centroid_reg_loss1 + centroid_reg_loss2

    # Compute heading loss (classification over heading bins).
    heading_class_label = torch.gather(
        end_points['heading_class_label'][supervised_inds, ...], 1,
        object_assignment)  # select (B,K) from (B,K2)
    criterion_heading_class = nn.CrossEntropyLoss(reduction='none')
    heading_class_loss = criterion_heading_class(
        end_points['heading_scores'][supervised_inds, ...].transpose(2, 1),
        heading_class_label)  # (B,K)
    heading_class_loss = torch.sum(heading_class_loss * objectness_label) / (
        torch.sum(objectness_label) + 1e-6)

    # Heading residual regression: only the GT bin's predicted residual is
    # supervised (via a one-hot mask); GT residual normalized by
    # (pi / num_heading_bin).
    heading_residual_label = torch.gather(
        end_points['heading_residual_label'][supervised_inds, ...], 1,
        object_assignment)  # select (B,K) from (B,K2)
    heading_residual_normalized_label = heading_residual_label / (
        np.pi / num_heading_bin)

    # Ref: https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/3
    heading_label_one_hot = torch.cuda.FloatTensor(
        batch_size, heading_class_label.shape[1], num_heading_bin).zero_()
    heading_label_one_hot.scatter_(
        2, heading_class_label.unsqueeze(-1),
        1)  # src==1 so it's *one-hot* (B,K,num_heading_bin)
    heading_residual_normalized_loss = huber_loss(torch.sum(
        end_points['heading_residuals_normalized'][supervised_inds, ...] *
        heading_label_one_hot, -1) - heading_residual_normalized_label,
                                                  delta=1.0)  # (B,K)
    heading_residual_normalized_loss = torch.sum(
        heading_residual_normalized_loss *
        objectness_label) / (torch.sum(objectness_label) + 1e-6)

    # Compute size loss (classification over size clusters).
    size_class_label = torch.gather(
        end_points['size_class_label'][supervised_inds, ...], 1,
        object_assignment)  # select (B,K) from (B,K2)
    criterion_size_class = nn.CrossEntropyLoss(reduction='none')
    size_class_loss = criterion_size_class(
        end_points['size_scores'][supervised_inds, ...].transpose(2, 1),
        size_class_label)  # (B,K)
    size_class_loss = torch.sum(size_class_loss * objectness_label) / (
        torch.sum(objectness_label) + 1e-6)

    # Size residual regression: only the GT cluster's predicted residual is
    # supervised; GT residual normalized per-axis by the cluster mean size.
    size_residual_label = torch.gather(
        end_points['size_residual_label'][supervised_inds, ...], 1,
        object_assignment.unsqueeze(-1).repeat(
            1, 1, 3))  # select (B,K,3) from (B,K2,3)
    size_label_one_hot = torch.cuda.FloatTensor(batch_size,
                                                size_class_label.shape[1],
                                                num_size_cluster).zero_()
    size_label_one_hot.scatter_(
        2, size_class_label.unsqueeze(-1),
        1)  # src==1 so it's *one-hot* (B,K,num_size_cluster)
    size_label_one_hot_tiled = size_label_one_hot.unsqueeze(-1).repeat(
        1, 1, 1, 3)  # (B,K,num_size_cluster,3)
    predicted_size_residual_normalized = torch.sum(
        end_points['size_residuals_normalized'][supervised_inds, ...] *
        size_label_one_hot_tiled, 2)  # (B,K,3)

    mean_size_arr_expanded = torch.from_numpy(mean_size_arr.astype(
        np.float32)).cuda().unsqueeze(0).unsqueeze(
            0)  # (1,1,num_size_cluster,3)
    mean_size_label = torch.sum(size_label_one_hot_tiled *
                                mean_size_arr_expanded, 2)  # (B,K,3)
    size_residual_label_normalized = size_residual_label / mean_size_label  # (B,K,3)
    size_residual_normalized_loss = torch.mean(
        huber_loss(predicted_size_residual_normalized -
                   size_residual_label_normalized, delta=1.0),
        -1)  # (B,K,3) -> (B,K)
    size_residual_normalized_loss = torch.sum(
        size_residual_normalized_loss *
        objectness_label) / (torch.sum(objectness_label) + 1e-6)

    # 3.4 Semantic cls loss
    sem_cls_label = torch.gather(
        end_points['sem_cls_label'][supervised_inds, ...], 1,
        object_assignment)  # select (B,K) from (B,K2)
    criterion_sem_cls = nn.CrossEntropyLoss(reduction='none')
    sem_cls_loss = criterion_sem_cls(
        end_points['sem_cls_scores'][supervised_inds, ...].transpose(2, 1),
        sem_cls_label)  # (B,K)
    sem_cls_loss = torch.sum(
        sem_cls_loss * objectness_label) / (torch.sum(objectness_label) + 1e-6)

    # Classification accuracy over positive proposals (monitoring only).
    end_points['cls_acc'] = torch.sum(
        (sem_cls_label ==
         end_points['sem_cls_scores'][supervised_inds, ...].argmax(dim=-1)) *
        objectness_label) / (torch.sum(objectness_label) + 1e-6)

    # IoU between decoded predicted boxes and their assigned GT boxes;
    # used both for monitoring and to supervise the IoU head below.
    iou_labels, _, iou_assignment = compute_iou_labels(
        # aggregated_vote_xyz -> center
        end_points,
        supervised_inds,
        end_points['aggregated_vote_xyz'][supervised_inds, ...],
        end_points['center'][supervised_inds, ...],
        None,
        None,
        end_points['heading_scores'][supervised_inds, ...],
        end_points['heading_residuals'][supervised_inds, ...],
        end_points['size_scores'][supervised_inds, ...],
        end_points['size_residuals'][supervised_inds, ...],
        config_dict={'dataset_config': dataset_config})
    end_points['pred_iou_value'] = torch.sum(iou_labels) / iou_labels.view(
        -1).shape[0]
    end_points['pred_iou_obj_value'] = torch.sum(
        iou_labels * objectness_label) / (torch.sum(objectness_label) + 1e-6)
    end_points['obj_count'] = torch.sum(objectness_label)

    if 'jitter_center' in end_points.keys():
        # Supervise the IoU head on jittered boxes: compute each jittered
        # box's best IoU against the GT boxes of its own sample, then
        # regress the predicted IoU score toward that value.
        jitter_center = end_points['jitter_center'][supervised_inds, ...]
        jitter_size = end_points['jitter_size'][supervised_inds, ...]
        jitter_heading = end_points['jitter_heading'][supervised_inds, ...]
        # All jittered boxes are treated as positives.
        jitter_objectness_label = torch.ones(batch_size,
                                             jitter_heading.shape[1]).cuda()

        # Advanced indexing returns a copy, so writing -1000 below does NOT
        # modify end_points['center_label']; padding GT slots are pushed far
        # away so no jittered box can match them.
        center_label = end_points['center_label'][supervised_inds, ...]
        zero_mask = (1 - end_points['box_label_mask'][supervised_inds, ...]
                     ).unsqueeze(-1).expand(-1, -1, 3).bool()
        center_label[zero_mask] = -1000
        heading_class_label = end_points['heading_class_label'][
            supervised_inds, ...]
        heading_residual_label = end_points['heading_residual_label'][
            supervised_inds, ...]
        size_class_label = end_points['size_class_label'][supervised_inds,
                                                          ...]
        size_residual_label = end_points['size_residual_label'][
            supervised_inds, ...]

        gt_size = dataset_config.class2size_gpu(size_class_label,
                                                size_residual_label) / 2
        gt_angle = dataset_config.class2angle_gpu(heading_class_label,
                                                  heading_residual_label)
        # 7-DoF boxes: (cx, cy, cz, dx, dy, dz, -angle); the angle is
        # negated — presumably to match box3d_iou_batch_gpu's convention.
        gt_bbox = torch.cat([center_label, gt_size * 2, -gt_angle[:, :,
                                                                  None]],
                            dim=2)
        pred_bbox = torch.cat(
            [jitter_center, jitter_size, -jitter_heading[:, :, None]],
            axis=2)
        pred_num = pred_bbox.shape[1]
        gt_bbox_ = gt_bbox.view(-1, 7)
        pred_bbox_ = pred_bbox.view(-1, 7)
        # box3d_iou_batch_gpu is computed all-pairs across the flattened
        # batch; the gather with `inds` below keeps only the diagonal, i.e.
        # each jittered box against the GTs of its own sample.
        jitter_iou_labels = box3d_iou_batch_gpu(pred_bbox_, gt_bbox_)
        jitter_iou_labels, jitter_object_assignment = jitter_iou_labels.view(
            batch_size * pred_num, batch_size, -1).max(dim=2)
        inds = torch.arange(batch_size).cuda().unsqueeze(1).expand(
            -1, pred_num).contiguous().view(-1, 1)
        jitter_iou_labels = jitter_iou_labels.gather(dim=1, index=inds).view(
            batch_size, -1)
        jitter_iou_labels = jitter_iou_labels.detach()
        jitter_object_assignment = jitter_object_assignment.gather(
            dim=1, index=inds).view(batch_size, -1)

        jitter_sem_class_label = torch.gather(
            end_points['sem_cls_label'][supervised_inds, ...], 1,
            jitter_object_assignment)  # select (B,K) from (B,K2)
        jitter_iou_pred = nn.Sigmoid()(
            end_points['iou_scores_jitter'][supervised_inds, ...])
        if jitter_iou_pred.shape[2] > 1:
            # Per-class IoU head: pick the score of the GT semantic class of
            # the assigned GT box.
            jitter_iou_pred = torch.gather(
                jitter_iou_pred, 2,
                jitter_sem_class_label.unsqueeze(-1)).squeeze(-1)
        else:
            # Class-agnostic IoU head.
            jitter_iou_pred = jitter_iou_pred.squeeze(-1)
        # Mean absolute IoU error (monitoring only).
        jitter_iou_acc = torch.abs(jitter_iou_pred - jitter_iou_labels)
        end_points['jitter_iou_acc'] = torch.sum(
            jitter_iou_acc) / jitter_iou_acc.view(-1).shape[0]
        end_points['jitter_iou_acc_obj'] = torch.sum(
            jitter_iou_acc * jitter_objectness_label) / (
                torch.sum(jitter_objectness_label) + 1e-6)
        # Huber regression of predicted IoU toward the (detached) IoU label.
        jitter_iou_loss = huber_loss(jitter_iou_pred -
                                     jitter_iou_labels.detach(),
                                     delta=1.0)
        jitter_iou_loss = torch.sum(
            jitter_iou_loss * jitter_objectness_label) / (
                torch.sum(jitter_objectness_label) + 1e-6)
        end_points['jitter_iou_loss'] = jitter_iou_loss

    if 'iou_scores' in end_points.keys():
        # Supervise the IoU head on the actual proposals using iou_labels
        # computed above.
        iou_pred = nn.Sigmoid()(end_points['iou_scores'][supervised_inds,
                                                         ...])
        if iou_pred.shape[2] > 1:
            # Per-class IoU head: use the GT semantic class of the GT box
            # assigned by compute_iou_labels.
            iou_sem_cls_label = torch.gather(
                end_points['sem_cls_label'][supervised_inds, ...], 1,
                iou_assignment)
            iou_pred = torch.gather(iou_pred, 2,
                                    iou_sem_cls_label.unsqueeze(-1)).squeeze(
                                        -1)
        else:
            # Class-agnostic IoU head.
            iou_pred = iou_pred.squeeze(-1)
        iou_acc = torch.abs(iou_pred - iou_labels)
        end_points['iou_acc'] = torch.sum(iou_acc) / torch.sum(
            torch.ones(iou_acc.shape))
        end_points['iou_acc_obj'] = torch.sum(
            iou_acc * objectness_label) / (torch.sum(objectness_label) + 1e-6)
        # NOTE(review): unlike the other terms, this loss is a plain mean
        # over all proposals rather than an objectness-weighted mean.
        iou_loss = huber_loss(iou_pred - iou_labels.detach(),
                              delta=1.0)  # (B, K, 1)
        iou_loss = iou_loss.mean()
        end_points['iou_loss'] = iou_loss

    return center_loss, heading_class_loss, heading_residual_normalized_loss, \
        size_class_loss, size_residual_normalized_loss, sem_cls_loss
def compute_box_and_sem_cls_loss(end_points, config, test_time=False):
    """ Compute 3D bounding box and semantic classification loss.

    Besides returning the loss terms, writes monitoring / IoU-estimation
    values into `end_points`: 'cls_acc', 'cls_acc_obj', 'iou_labels',
    'pred_iou_value', 'pred_iou_obj_value', and — when an 'iou_scores' head
    exists — 'iou_acc', 'iou_acc_obj', 'iou_loss'.

    Args:
        end_points: dict of predictions and labels; monitoring keys listed
            above are written into it
        config: dataset config providing num_heading_bin, num_size_cluster,
            num_class and mean_size_arr
        test_time: unused in this function; kept for caller compatibility

    Returns:
        center_loss
        heading_cls_loss
        heading_reg_loss
        size_cls_loss
        size_reg_loss
        sem_cls_loss
    """
    num_heading_bin = config.num_heading_bin
    num_size_cluster = config.num_size_cluster
    num_class = config.num_class  # NOTE(review): unused in this function
    mean_size_arr = config.mean_size_arr

    object_assignment = end_points['object_assignment']
    batch_size = object_assignment.shape[0]

    # Compute center loss: two-sided nearest-neighbour distance between
    # predicted and GT centers. The pred->GT term is averaged over positive
    # proposals, the GT->pred term over valid GT boxes.
    pred_center = end_points['center']
    gt_center = end_points['center_label'][:, :, 0:3]
    dist1, ind1, dist2, _ = nn_distance(
        pred_center, gt_center)  # dist1: BxK, dist2: BxK2 (ind1 unused)
    box_label_mask = end_points['box_label_mask']
    objectness_label = end_points['objectness_label'].float()
    centroid_reg_loss1 = \
        torch.sum(dist1*objectness_label)/(torch.sum(objectness_label)+1e-6)
    centroid_reg_loss2 = \
        torch.sum(dist2*box_label_mask)/(torch.sum(box_label_mask)+1e-6)
    center_loss = centroid_reg_loss1 + centroid_reg_loss2

    # Compute heading loss (classification over heading bins).
    heading_class_label = torch.gather(
        end_points['heading_class_label'], 1,
        object_assignment)  # select (B,K) from (B,K2)
    criterion_heading_class = nn.CrossEntropyLoss(reduction='none')
    heading_class_loss = criterion_heading_class(
        end_points['heading_scores'].transpose(2, 1),
        heading_class_label)  # (B,K)
    heading_class_loss = torch.sum(heading_class_loss * objectness_label) / (
        torch.sum(objectness_label) + 1e-6)

    # Heading residual regression: only the GT bin's predicted residual is
    # supervised (one-hot mask); GT residual normalized by
    # (pi / num_heading_bin).
    heading_residual_label = torch.gather(
        end_points['heading_residual_label'], 1,
        object_assignment)  # select (B,K) from (B,K2)
    heading_residual_normalized_label = heading_residual_label / (
        np.pi / num_heading_bin)

    # Ref: https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/3
    heading_label_one_hot = torch.cuda.FloatTensor(
        batch_size, heading_class_label.shape[1], num_heading_bin).zero_()
    heading_label_one_hot.scatter_(
        2, heading_class_label.unsqueeze(-1),
        1)  # src==1 so it's *one-hot* (B,K,num_heading_bin)
    heading_residual_normalized_loss = huber_loss(torch.sum(
        end_points['heading_residuals_normalized'] * heading_label_one_hot,
        -1) - heading_residual_normalized_label,
                                                  delta=1.0)  # (B,K)
    heading_residual_normalized_loss = torch.sum(
        heading_residual_normalized_loss *
        objectness_label) / (torch.sum(objectness_label) + 1e-6)

    # Compute size loss (classification over size clusters).
    size_class_label = torch.gather(
        end_points['size_class_label'], 1,
        object_assignment)  # select (B,K) from (B,K2)
    criterion_size_class = nn.CrossEntropyLoss(reduction='none')
    size_class_loss = criterion_size_class(end_points['size_scores'].transpose(
        2, 1), size_class_label)  # (B,K)
    size_class_loss = torch.sum(size_class_loss * objectness_label) / (
        torch.sum(objectness_label) + 1e-6)

    # Size residual regression: only the GT cluster's predicted residual is
    # supervised; GT residual normalized per-axis by the cluster mean size.
    size_residual_label = torch.gather(
        end_points['size_residual_label'], 1,
        object_assignment.unsqueeze(-1).repeat(
            1, 1, 3))  # select (B,K,3) from (B,K2,3)
    size_label_one_hot = torch.cuda.FloatTensor(batch_size,
                                                size_class_label.shape[1],
                                                num_size_cluster).zero_()
    size_label_one_hot.scatter_(
        2, size_class_label.unsqueeze(-1),
        1)  # src==1 so it's *one-hot* (B,K,num_size_cluster)
    size_label_one_hot_tiled = size_label_one_hot.unsqueeze(-1).repeat(
        1, 1, 1, 3)  # (B,K,num_size_cluster,3)
    predicted_size_residual_normalized = torch.sum(
        end_points['size_residuals_normalized'] * size_label_one_hot_tiled,
        2)  # (B,K,3)

    mean_size_arr_expanded = torch.from_numpy(mean_size_arr.astype(
        np.float32)).cuda().unsqueeze(0).unsqueeze(
            0)  # (1,1,num_size_cluster,3)
    mean_size_label = torch.sum(size_label_one_hot_tiled *
                                mean_size_arr_expanded, 2)  # (B,K,3)
    size_residual_label_normalized = size_residual_label / mean_size_label  # (B,K,3)
    size_residual_normalized_loss = torch.mean(
        huber_loss(predicted_size_residual_normalized -
                   size_residual_label_normalized, delta=1.0),
        -1)  # (B,K,3) -> (B,K)
    size_residual_normalized_loss = torch.sum(
        size_residual_normalized_loss *
        objectness_label) / (torch.sum(objectness_label) + 1e-6)

    # 3.4 Semantic cls loss
    sem_cls_label = torch.gather(
        end_points['sem_cls_label'], 1,
        object_assignment)  # select (B,K) from (B,K2)
    criterion_sem_cls = nn.CrossEntropyLoss(reduction='none')
    sem_cls_loss = criterion_sem_cls(end_points['sem_cls_scores'].transpose(
        2, 1), sem_cls_label)  # (B,K)
    sem_cls_loss = torch.sum(
        sem_cls_loss * objectness_label) / (torch.sum(objectness_label) + 1e-6)

    # Classification accuracy over all proposals, and over positives only
    # (monitoring).
    end_points['cls_acc'] = torch.sum(
        (sem_cls_label == end_points['sem_cls_scores'].argmax(
            dim=-1))).float() / sem_cls_label.view(-1).shape[0]
    end_points['cls_acc_obj'] = torch.sum(
        (sem_cls_label == end_points['sem_cls_scores'].argmax(dim=-1)) *
        objectness_label) / (torch.sum(objectness_label) + 1e-6)

    # end_points['center'].retain_grad()
    # All batch elements are "supervised" here; compute_iou_labels expects
    # explicit batch indices, so pass the full range.
    mask = torch.arange(batch_size).cuda()
    iou_labels, iou_zero_mask, _ = compute_iou_labels(
        end_points, mask, end_points['aggregated_vote_xyz'],
        end_points['center'], None, None, end_points['heading_scores'],
        end_points['heading_residuals'], end_points['size_scores'],
        end_points['size_residuals'], {'dataset_config': config})
    end_points['iou_labels'] = iou_labels
    end_points['pred_iou_value'] = torch.sum(iou_labels) / iou_labels.view(
        -1).shape[0]
    end_points['pred_iou_obj_value'] = torch.sum(
        iou_labels * objectness_label) / (torch.sum(objectness_label) + 1e-6)

    if 'iou_scores' in end_points.keys():
        # Supervise the IoU-estimation head toward the computed IoU labels.
        iou_pred = nn.Sigmoid()(end_points['iou_scores'])
        if iou_pred.shape[2] > 1:
            # Per-class IoU head: pick the score of the predicted semantic
            # class (argmax of sem_cls_scores).
            iou_pred = torch.gather(
                iou_pred, 2, end_points['sem_cls_scores'].argmax(
                    dim=-1).unsqueeze(-1)).squeeze(
                        -1)  # use pred semantic labels
        else:
            # Class-agnostic IoU head.
            iou_pred = iou_pred.squeeze(-1)
        iou_acc = torch.abs(iou_pred - iou_labels)
        end_points['iou_acc'] = torch.sum(iou_acc) / torch.sum(
            torch.ones(iou_acc.shape))
        end_points['iou_acc_obj'] = torch.sum(
            iou_acc * objectness_label) / (torch.sum(objectness_label) + 1e-6)
        # NOTE(review): iou_labels is not detached here, unlike the
        # supervised-branch variant of this function — confirm whether
        # gradients are meant to flow into the box-decoding path.
        iou_loss = huber_loss(iou_pred - iou_labels, delta=1.0)  # (B, K, 1)
        iou_loss = torch.sum(
            iou_loss * objectness_label) / (torch.sum(objectness_label) +
                                            1e-6)
        end_points['iou_loss'] = iou_loss

    return center_loss, heading_class_loss, heading_residual_normalized_loss, \
        size_class_loss, size_residual_normalized_loss, sem_cls_loss