def compute_binary_diff(pred_node):
    """Sum how badly predicted part boxes satisfy the predicted part relations.

    Walks the subtree rooted at ``pred_node``. For every edge stored on a
    non-leaf node, 500 points are sampled from the mesh of each endpoint's box
    (``utils.gen_obb_mesh`` + ``utils.sample_pc``), and then:

    * ``'ADJ'`` edges add the mean of the two smallest square-rooted
      ``chamferLoss`` distances, i.e. the closest-point gap between the boxes;
    * ``'TRANS_SYM'`` / ``'REF_SYM'`` / ``'ROT_SYM'`` edges map box A's points
      through the transform returned by the matching ``compute_sym`` solver
      and add the mean of the two average square-rooted ``chamferLoss``
      distances to box B's points.

    Relies on module-level ``torch``, ``device``, ``utils``, ``chamferLoss``
    and ``compute_sym``.

    Args:
        pred_node: tree node exposing ``is_leaf``, ``children`` and ``edges``
            (``None`` or a list of dicts with ``'part_a'``, ``'part_b'`` and
            ``'type'``); children expose ``box`` and ``get_box_quat()``.

    Returns:
        ``(binary_diff, binary_tot)``: the accumulated error and the number of
        edges evaluated in the subtree; ``(0, 0)`` for a leaf.

    Raises:
        ValueError: if an edge type is not one of the four listed above.
    """
    if pred_node.is_leaf:
        return 0, 0

    binary_diff = 0
    binary_tot = 0

    # Errors accumulated over all child subtrees.
    for cnode in pred_node.children:
        cur_binary_diff, cur_binary_tot = compute_binary_diff(cnode)
        binary_diff += cur_binary_diff
        binary_tot += cur_binary_tot

    # Errors of the edges stored on the current node.
    if pred_node.edges is None:
        return binary_diff, binary_tot

    def _box_pc(part_id):
        # 500 points sampled from the mesh of child ``part_id``'s box, plus
        # its quaternion box parameters (numpy) for the symmetry solvers.
        part = pred_node.children[part_id]
        obb = part.box.cpu().numpy()
        obb_quat = part.get_box_quat().cpu().numpy()
        mesh_v, mesh_f = utils.gen_obb_mesh(obb)
        pc = utils.sample_pc(mesh_v, mesh_f, n_points=500)
        return torch.tensor(pc, dtype=torch.float32, device=device), obb_quat

    for edge in pred_node.edges:
        edge_type = edge['type']
        if edge_type == 'ADJ':
            sym_fn = None
        elif edge_type == 'TRANS_SYM':
            sym_fn = compute_sym.compute_trans_sym
        elif edge_type == 'REF_SYM':
            sym_fn = compute_sym.compute_ref_sym
        elif edge_type == 'ROT_SYM':
            sym_fn = compute_sym.compute_rot_sym
        else:
            # Raise (before any sampling work) instead of `assert '<msg>'`:
            # asserting a non-empty string is always true, which let a stale
            # transform from an earlier edge leak into the result.
            raise ValueError('ERROR: unknown symmetry type: %s' % edge_type)

        pc1, obb_quat1 = _box_pc(edge['part_a'])
        pc2, obb_quat2 = _box_pc(edge['part_b'])

        # NOTE(review): the sqrt suggests chamferLoss returns squared
        # nearest-neighbour distances -- confirm against its definition.
        if sym_fn is None:
            # ADJ: closest-point gap, averaged over both directions.
            dist1, dist2 = chamferLoss(pc1.view(1, -1, 3), pc2.view(1, -1, 3))
            binary_diff += (dist1.sqrt().min().item() + dist2.sqrt().min().item()) / 2
        else:
            mat1to2, _ = sym_fn(obb_quat1.reshape(-1), obb_quat2.reshape(-1))
            mat1to2 = torch.tensor(mat1to2, dtype=torch.float32, device=device)
            # mat1to2 is indexed as [linear part | offset column].
            transformed_pc1 = pc1.matmul(mat1to2[:, :3].t()) + mat1to2[:, 3].unsqueeze(dim=0)
            dist1, dist2 = chamferLoss(transformed_pc1.view(1, -1, 3), pc2.view(1, -1, 3))
            loss = (dist1.sqrt().mean() + dist2.sqrt().mean()) / 2
            binary_diff += loss.item()
        binary_tot += 1

    return binary_diff, binary_tot
def node_recon_loss(self, node_latent, gt_node):
    """Reconstruction losses of the point-cloud ("geo") decoder on one subtree.

    Decodes ``node_latent`` into a part point cloud and compares it with
    ``gt_node``. For a non-leaf ``gt_node`` the child decoder is also run. The
    predicted children are matched to the ground-truth children by
    ``linear_assignment`` on pairwise ``chamferDist`` (Hungarian matching), and
    the method recurses into every matched pair. Unmatched predicted children
    are not recursed into; instead their predicted scale is pushed to zero.

    Args:
        node_latent: latent code of the current node; ``size(1)`` is used as
            the per-child feature length.
        gt_node: ground-truth node providing ``geo``, ``geo_feat``,
            ``is_leaf``, ``children``, ``edge_tensors(...)`` and, on children,
            ``get_semantic_id()``.

    Returns:
        ``(losses, all_geo, all_leaf_geo)``:

        * ``losses``: dict with keys 'leaf', 'geo', 'center', 'scale',
          'latent', 'exists', 'semantic', 'edge_exists', 'sym', 'adj', summed
          over this node and all matched descendants. At a leaf the structural
          terms are ``zeros_like(geo_loss)``.
        * ``all_geo``: this node's decoded point cloud followed by those of
          every matched descendant, concatenated along dim 0.
        * ``all_leaf_geo``: decoded point clouds of leaves only. For a leaf,
          both tensors are its own ``geo_global``.
    """
    gt_geo = gt_node.geo
    gt_geo_feat = gt_node.geo_feat
    gt_center = gt_geo.mean(dim=1)
    # Scale = distance from the centre to the farthest gt point.
    # NOTE(review): the final .view(1, 1) implies gt_geo is a single cloud,
    # presumably (1, num_point, 3) -- confirm.
    gt_scale = (gt_geo - gt_center.unsqueeze(dim=1).repeat(1, self.conf.num_point, 1)).pow(2).sum(dim=2).max(dim=1)[0].sqrt().view(1, 1)

    geo_local, geo_center, geo_scale, geo_feat = self.node_decoder(node_latent)
    geo_global = self.geoToGlobal(geo_local, geo_center, geo_scale)

    # geo loss for the current part
    latent_loss = self.mseLoss(geo_feat, gt_geo_feat).mean()
    center_loss = self.mseLoss(geo_center, gt_center).mean()
    scale_loss = self.mseLoss(geo_scale, gt_scale).mean()
    geo_loss = self.chamferDist(geo_global, gt_geo)

    if gt_node.is_leaf:
        is_leaf_logit = self.leaf_classifier(node_latent)
        is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
        # Structural losses do not apply to a leaf; zeros keep the dict shape uniform.
        return {'leaf': is_leaf_loss, 'geo': geo_loss, 'center': center_loss, 'scale': scale_loss, 'latent': latent_loss,
                'exists': torch.zeros_like(geo_loss), 'semantic': torch.zeros_like(geo_loss),
                'edge_exists': torch.zeros_like(geo_loss), 'sym': torch.zeros_like(geo_loss),
                'adj': torch.zeros_like(geo_loss)}, geo_global, geo_global
    else:
        child_feats, child_sem_logits, child_exists_logits, edge_exists_logits = \
            self.child_decoder(node_latent)

        # generate geo prediction for each child
        feature_len = node_latent.size(1)

        # all_geo starts with this node's own cloud; children are appended later.
        all_geo = []; all_leaf_geo = [];
        all_geo.append(geo_global)

        child_pred_geo_local, child_pred_geo_center, child_pred_geo_scale, _ = \
            self.node_decoder(child_feats.view(-1, feature_len))
        child_pred_geo = self.geoToGlobal(child_pred_geo_local, child_pred_geo_center, child_pred_geo_scale)
        num_pred = child_pred_geo.size(0)

        # perform hungarian matching between pred geo and gt geo
        # (no gradients needed: only the resulting index pairs are used)
        with torch.no_grad():
            child_gt_geo = torch.cat([child_node.geo for child_node in gt_node.children], dim=0)
            num_gt = child_gt_geo.size(0)

            # Tile to every (gt, pred) pair so one chamferDist call scores all pairs.
            child_pred_geo_tiled = child_pred_geo.unsqueeze(dim=0).repeat(num_gt, 1, 1, 1)
            child_gt_geo_tiled = child_gt_geo.unsqueeze(dim=1).repeat(1, num_pred, 1, 1)

            dist_mat = self.chamferDist(child_pred_geo_tiled.view(-1, self.conf.num_point, 3), \
                    child_gt_geo_tiled.view(-1, self.conf.num_point, 3)).view(1, num_gt, num_pred)
            _, matched_gt_idx, matched_pred_idx = linear_assignment(dist_mat)

        # get edge ground truth
        edge_type_list_gt, edge_indices_gt = gt_node.edge_tensors(
            edge_types=self.conf.edge_types, device=child_feats.device, type_onehot=False)

        gt2pred = {gt_idx: pred_idx for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx)}
        edge_exists_gt = torch.zeros_like(edge_exists_logits)

        # sym_type is collected but not read below.
        sym_from = []; sym_to = []; sym_mat = []; sym_type = []; adj_from = []; adj_to = [];
        # Only the first half of the gt edge list is visited.
        # NOTE(review): presumably edge_tensors lists each edge in both
        # directions -- confirm.
        for i in range(edge_indices_gt.shape[1]//2):
            if edge_indices_gt[0, i, 0].item() not in gt2pred or edge_indices_gt[0, i, 1].item() not in gt2pred:
                """ one of the adjacent nodes of the current gt edge was not matched to any node in the prediction, ignore this edge """
                continue

            # correlate gt edges to pred edges
            edge_from_idx = gt2pred[edge_indices_gt[0, i, 0].item()]
            edge_to_idx = gt2pred[edge_indices_gt[0, i, 1].item()]
            # Mark the edge as existing in both directions.
            edge_exists_gt[:, edge_from_idx, edge_to_idx, edge_type_list_gt[0:1, i]] = 1
            edge_exists_gt[:, edge_to_idx, edge_from_idx, edge_type_list_gt[0:1, i]] = 1

            # compute binary edge parameters for each matched pred edge
            if edge_type_list_gt[0, i].item() == 0: # ADJ
                adj_from.append(edge_from_idx)
                adj_to.append(edge_to_idx)
            else:   # SYM
                # Symmetry transforms are estimated on CPU/numpy from OBBs fitted
                # to the detached predicted clouds, i.e. treated as constants.
                obb1 = compute_sym.compute_obb(child_pred_geo[edge_from_idx].cpu().detach().numpy())
                obb2 = compute_sym.compute_obb(child_pred_geo[edge_to_idx].cpu().detach().numpy())
                if edge_type_list_gt[0, i].item() == 1: # ROT_SYM
                    mat1to2, mat2to1 = compute_sym.compute_rot_sym(obb1, obb2)
                elif edge_type_list_gt[0, i].item() == 2: # TRANS_SYM
                    mat1to2, mat2to1 = compute_sym.compute_trans_sym(obb1, obb2)
                else:   # REF_SYM
                    mat1to2, mat2to1 = compute_sym.compute_ref_sym(obb1, obb2)
                sym_from.append(edge_from_idx)
                sym_to.append(edge_to_idx)
                sym_mat.append(torch.tensor(mat1to2, dtype=torch.float32, device=self.conf.device).view(1, 3, 4))
                sym_type.append(edge_type_list_gt[0, i].item())

        # train the current node to be non-leaf
        is_leaf_logit = self.leaf_classifier(node_latent)
        is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

        # gather information
        child_sem_gt_labels = []
        child_sem_pred_logits = []
        child_exists_gt = torch.zeros_like(child_exists_logits)
        for i in range(len(matched_gt_idx)):
            child_sem_gt_labels.append(gt_node.children[matched_gt_idx[i]].get_semantic_id())
            child_sem_pred_logits.append(child_sem_logits[0, matched_pred_idx[i], :].view(1, -1))
            child_exists_gt[:, matched_pred_idx[i], :] = 1

        # train semantic labels
        child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
        child_sem_gt_labels = torch.tensor(child_sem_gt_labels, dtype=torch.int64, device=child_sem_pred_logits.device)
        semantic_loss = self.semCELoss(child_sem_pred_logits, child_sem_gt_labels)
        semantic_loss = semantic_loss.sum()

        # train unused geo to zero
        # (predicted children without a gt match get their scale pushed to 0)
        unmatched_scales = []
        for i in range(num_pred):
            if i not in matched_pred_idx:
                unmatched_scales.append(child_pred_geo_scale[i:i+1])
        if len(unmatched_scales) > 0:
            unmatched_scales = torch.cat(unmatched_scales, dim=0)
            unused_geo_loss = unmatched_scales.pow(2).sum()
        else:
            unused_geo_loss = 0.0

        # train exist scores
        child_exists_loss = F.binary_cross_entropy_with_logits(\
            input=child_exists_logits, target=child_exists_gt, reduction='none')
        child_exists_loss = child_exists_loss.sum()

        # train edge exists scores
        edge_exists_loss = F.binary_cross_entropy_with_logits(\
                input=edge_exists_logits, target=edge_exists_gt, reduction='none')
        edge_exists_loss = edge_exists_loss.sum()
        # rescale to make it comparable to other losses,
        # which are in the order of the number of child nodes
        edge_exists_loss = edge_exists_loss / (edge_exists_gt.shape[2]*edge_exists_gt.shape[3])

        # compute and train binary losses
        sym_loss = 0
        if len(sym_from) > 0:
            sym_from_th = torch.tensor(sym_from, dtype=torch.long, device=self.conf.device)
            pc_from = child_pred_geo[sym_from_th, :]
            sym_to_th = torch.tensor(sym_to, dtype=torch.long, device=self.conf.device)
            pc_to = child_pred_geo[sym_to_th, :]
            sym_mat_th = torch.cat(sym_mat, dim=0)
            # Apply each 3x4 transform: points @ linear^T + offset.
            transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat_th[:, :, :3], 1, 2)) + \
                    sym_mat_th[:, :, 3].unsqueeze(dim=1).repeat(1, pc_from.size(1), 1)
            loss = self.chamferDist(transformed_pc_from, pc_to) * 2
            sym_loss = loss.sum()

        adj_loss = 0
        if len(adj_from) > 0:
            adj_from_th = torch.tensor(adj_from, dtype=torch.long, device=self.conf.device)
            pc_from = child_pred_geo[adj_from_th, :]
            adj_to_th = torch.tensor(adj_to, dtype=torch.long, device=self.conf.device)
            pc_to = child_pred_geo[adj_to_th, :]
            dist1, dist2 = self.chamferLoss(pc_from, pc_to)
            # Adjacency: only the closest point pair matters (min, not mean).
            loss = (dist1.min(dim=1)[0] + dist2.min(dim=1)[0])
            adj_loss = loss.sum()

        # call children + aggregate losses
        pred2allgeo = dict(); pred2allleafgeo = dict();
        for i in range(len(matched_gt_idx)):
            child_losses, child_all_geo, child_all_leaf_geo = self.node_recon_loss(
                child_feats[:, matched_pred_idx[i], :], gt_node.children[matched_gt_idx[i]])
            pred2allgeo[matched_pred_idx[i]] = child_all_geo
            pred2allleafgeo[matched_pred_idx[i]] = child_all_leaf_geo
            all_geo.append(child_all_geo)
            all_leaf_geo.append(child_all_leaf_geo)
            latent_loss = latent_loss + child_losses['latent']
            geo_loss = geo_loss + child_losses['geo']
            center_loss = center_loss + child_losses['center']
            scale_loss = scale_loss + child_losses['scale']
            is_leaf_loss = is_leaf_loss + child_losses['leaf']
            child_exists_loss = child_exists_loss + child_losses['exists']
            semantic_loss = semantic_loss + child_losses['semantic']
            edge_exists_loss = edge_exists_loss + child_losses['edge_exists']
            sym_loss = sym_loss + child_losses['sym']
            adj_loss = adj_loss + child_losses['adj']

        # for sym-edges, train subtree to be symmetric
        for i in range(len(sym_from)):
            # size(0) > 1 means the child has descendants: row 0 of its
            # all-geo tensor is the child itself, so [1:] keeps only descendants.
            s1 = pred2allgeo[sym_from[i]].size(0)
            s2 = pred2allgeo[sym_to[i]].size(0)
            if s1 > 1 and s2 > 1:
                pc_from = pred2allgeo[sym_from[i]][1:, :, :].view(-1, 3)
                pc_to = pred2allgeo[sym_to[i]][1:, :, :].view(-1, 3)
                transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat[i][0, :, :3], 0, 1)) + \
                        sym_mat[i][0, :, 3].unsqueeze(dim=0).repeat(pc_from.size(0), 1)
                dist1, dist2 = self.chamferLoss(transformed_pc_from.view(1, -1, 3), pc_to.view(1, -1, 3))
                sym_loss += (dist1.mean() + dist2.mean()) * (s1 + s2) / 2

        # for adj-edges, train leaf-nodes in subtrees to be adjacent
        for i in range(len(adj_from)):
            # all-geo has more rows than leaf-geo only when the child is a non-leaf.
            if pred2allgeo[adj_from[i]].size(0) > pred2allleafgeo[adj_from[i]].size(0) \
                    or pred2allgeo[adj_to[i]].size(0) > pred2allleafgeo[adj_to[i]].size(0):
                pc_from = pred2allleafgeo[adj_from[i]].view(1, -1, 3)
                pc_to = pred2allleafgeo[adj_to[i]].view(1, -1, 3)
                dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                adj_loss += dist1.min() + dist2.min()

        return {'leaf': is_leaf_loss, 'geo': geo_loss + unused_geo_loss, 'center': center_loss, 'scale': scale_loss, 'latent': latent_loss,
                'exists': child_exists_loss, 'semantic': semantic_loss, 'edge_exists': edge_exists_loss,
                'sym': sym_loss, 'adj': adj_loss}, \
            torch.cat(all_geo, dim=0), torch.cat(all_leaf_geo, dim=0)
def compute_struct_diff(gt_node, pred_node):
    """Structurally compare a predicted hierarchy with the ground truth.

    Children of two non-leaf nodes are matched only within the same semantic
    ``label``. When a label has more than one candidate on either side, pairs
    are assigned with ``utils.linear_assignment`` on ``boxLoss`` between the
    10-D box parameterisations. Matched pairs are compared recursively, and
    everything left unmatched counts as structural error. Each ground-truth
    edge whose endpoints are both matched is also evaluated on the matched
    predicted boxes (the same ADJ / symmetry error as used for predicted edges).

    Side effect: every matched predicted child gets the ``part_id`` of its
    ground-truth partner.

    Relies on module-level ``torch``, ``device``, ``conf``, ``boxLoss``,
    ``chamferLoss``, ``utils`` and ``compute_sym``.

    Args:
        gt_node: ground-truth tree node.
        pred_node: predicted tree node. Both expose ``is_leaf``, ``label``,
            ``children``, ``edges`` (``None`` or a list of dicts with
            ``'part_a'``, ``'part_b'``, ``'type'``), ``boxes()``,
            ``get_subtree_edge_count()`` and ``get_box_quat()``; predicted
            children also expose ``box``.

    Returns:
        ``(struct_diff, edge_both, edge_gt, edge_pred, gt_binary_diff,
        gt_binary_tot)``:

        * ``struct_diff``: number of boxes in unmatched subtrees.
        * ``edge_both``: gt edges (checked in both directions) reproduced by
          the prediction.
        * ``edge_gt``: gt edge count, doubled.
        * ``edge_pred``: predicted edge count.
        * ``gt_binary_diff`` / ``gt_binary_tot``: accumulated relation error of
          the matched gt edges, and how many edges were evaluated.

    Raises:
        ValueError: for an edge type missing from ``conf.edge_types``
            (``list.index``), or for a matched gt edge whose type id is not
            0-3.
    """
    # Leaf vs. leaf: identical structure. Leaf vs. subtree: every box but the
    # shared root, and every edge of the subtree, counts as a mismatch.
    if gt_node.is_leaf:
        if pred_node.is_leaf:
            return 0, 0, 0, 0, 0, 0
        return len(pred_node.boxes()) - 1, 0, 0, pred_node.get_subtree_edge_count(), 0, 0
    if pred_node.is_leaf:
        return len(gt_node.boxes()) - 1, 0, gt_node.get_subtree_edge_count() * 2, 0, 0, 0

    # Children may only be matched to children that carry the same label.
    gt_sem = {node.label for node in gt_node.children}
    pred_sem = {node.label for node in pred_node.children}
    intersect_sem = set.intersection(gt_sem, pred_sem)

    gt_cnodes_per_sem = {}
    for node_id, gt_cnode in enumerate(gt_node.children):
        if gt_cnode.label in intersect_sem:
            gt_cnodes_per_sem.setdefault(gt_cnode.label, []).append(node_id)
    pred_cnodes_per_sem = {}
    for node_id, pred_cnode in enumerate(pred_node.children):
        if pred_cnode.label in intersect_sem:
            pred_cnodes_per_sem.setdefault(pred_cnode.label, []).append(node_id)

    matched_gt_idx = []
    matched_pred_idx = []
    # gt child index -> pred child index. A dict needs no
    # conf.max_child_num-sized buffer.
    matched_gt2pred = {}
    for sem in intersect_sem:
        gt_ids = gt_cnodes_per_sem[sem]
        pred_ids = pred_cnodes_per_sem[sem]
        gt_boxes = torch.cat([gt_node.children[cid].get_box_quat() for cid in gt_ids], dim=0).to(device)
        pred_boxes = torch.cat([pred_node.children[cid].get_box_quat() for cid in pred_ids], dim=0).to(device)
        num_gt = gt_boxes.size(0)
        num_pred = pred_boxes.size(0)
        if num_gt == 1 and num_pred == 1:
            cur_matched_gt_idx = [0]
            cur_matched_pred_idx = [0]
        else:
            # Score every (gt, pred) pair with boxLoss, then solve the assignment.
            gt_boxes_tiled = gt_boxes.unsqueeze(dim=1).repeat(1, num_pred, 1)
            pred_boxes_tiled = pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
            dmat = boxLoss(gt_boxes_tiled.view(-1, 10),
                           pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_pred).cpu()
            _, cur_matched_gt_idx, cur_matched_pred_idx = utils.linear_assignment(dmat)
        for gi, pi in zip(cur_matched_gt_idx, cur_matched_pred_idx):
            gt_id = gt_ids[gi]
            pred_id = pred_ids[pi]
            matched_gt_idx.append(gt_id)
            matched_pred_idx.append(pred_id)
            matched_gt2pred[gt_id] = pred_id

    matched_gt_set = set(matched_gt_idx)
    matched_pred_set = set(matched_pred_idx)

    struct_diff = 0.0
    edge_both = 0
    edge_gt = 0
    edge_pred = 0
    gt_binary_diff = 0.0
    gt_binary_tot = 0

    # Unmatched children: their whole subtrees count as structural error.
    for i, gt_cnode in enumerate(gt_node.children):
        if i not in matched_gt_set:
            struct_diff += len(gt_cnode.boxes())
            edge_gt += gt_cnode.get_subtree_edge_count() * 2
    for i, pred_cnode in enumerate(pred_node.children):
        if i not in matched_pred_set:
            struct_diff += len(pred_cnode.boxes())
            edge_pred += pred_cnode.get_subtree_edge_count()

    # Matched children: recurse and accumulate.
    for gt_id, pred_id in zip(matched_gt_idx, matched_pred_idx):
        (cur_struct_diff, cur_edge_both, cur_edge_gt, cur_edge_pred,
         cur_gt_binary_diff, cur_gt_binary_tot) = compute_struct_diff(
            gt_node.children[gt_id], pred_node.children[pred_id])
        gt_binary_diff += cur_gt_binary_diff
        gt_binary_tot += cur_gt_binary_tot
        struct_diff += cur_struct_diff
        edge_both += cur_edge_both
        edge_gt += cur_edge_gt
        edge_pred += cur_edge_pred
        pred_node.children[pred_id].part_id = gt_node.children[gt_id].part_id

    if pred_node.edges is not None:
        edge_pred += len(pred_node.edges)

    if gt_node.edges is None:
        return struct_diff, edge_both, edge_gt, edge_pred, gt_binary_diff, gt_binary_tot

    edge_gt += len(gt_node.edges) * 2
    # (part_a, part_b, edge_type_id) triples present in the prediction. This
    # replaces a bool array built with the removed `np.bool` alias. A missing
    # predicted edge list counts as "no edges" instead of raising TypeError.
    pred_edges = {
        (edge['part_a'], edge['part_b'], conf.edge_types.index(edge['type']))
        for edge in (pred_node.edges or ())
    }

    def _box_pc(part_id):
        # 500 points sampled from the mesh of predicted child ``part_id``'s
        # box, plus its quaternion box parameters (numpy) for compute_sym.
        part = pred_node.children[part_id]
        obb = part.box.cpu().numpy()
        obb_quat = part.get_box_quat().cpu().numpy()
        mesh_v, mesh_f = utils.gen_obb_mesh(obb)
        pc = utils.sample_pc(mesh_v, mesh_f, n_points=500)
        return torch.tensor(pc, dtype=torch.float32, device=device), obb_quat

    for edge in gt_node.edges:
        gt_part_a_id = edge['part_a']
        gt_part_b_id = edge['part_b']
        edge_type_id = conf.edge_types.index(edge['type'])
        if gt_part_a_id not in matched_gt_set or gt_part_b_id not in matched_gt_set:
            continue
        pred_part_a_id = matched_gt2pred[gt_part_a_id]
        pred_part_b_id = matched_gt2pred[gt_part_b_id]
        # Each gt edge is checked in both directions, matching the `* 2` in edge_gt.
        edge_both += int((pred_part_a_id, pred_part_b_id, edge_type_id) in pred_edges)
        edge_both += int((pred_part_b_id, pred_part_a_id, edge_type_id) in pred_edges)

        # gt edges eval on the matched predicted boxes. The type ids are the
        # ones annotated in this code: 0 ADJ, 1 ROT_SYM, 2 TRANS_SYM,
        # 3 REF_SYM (presumably the order of conf.edge_types).
        if edge_type_id == 0:
            sym_fn = None
        elif edge_type_id == 1:
            sym_fn = compute_sym.compute_rot_sym
        elif edge_type_id == 2:
            sym_fn = compute_sym.compute_trans_sym
        elif edge_type_id == 3:
            sym_fn = compute_sym.compute_ref_sym
        else:
            # Raise instead of `assert '<msg>'`, which is always true and left
            # mat1to2 stale or unbound.
            raise ValueError('ERROR: unknown symmetry type: %s' % edge['type'])

        pc1, obb_quat1 = _box_pc(pred_part_a_id)
        pc2, obb_quat2 = _box_pc(pred_part_b_id)

        if sym_fn is None:
            # ADJ: closest-point gap, averaged over both directions.
            dist1, dist2 = chamferLoss(pc1.view(1, -1, 3), pc2.view(1, -1, 3))
            gt_binary_diff += (dist1.sqrt().min().item() + dist2.sqrt().min().item()) / 2
        else:
            mat1to2, _ = sym_fn(obb_quat1.reshape(-1), obb_quat2.reshape(-1))
            mat1to2 = torch.tensor(mat1to2, dtype=torch.float32, device=device)
            # mat1to2 is indexed as [linear part | offset column].
            transformed_pc1 = pc1.matmul(mat1to2[:, :3].t()) + mat1to2[:, 3].unsqueeze(dim=0)
            dist1, dist2 = chamferLoss(transformed_pc1.view(1, -1, 3), pc2.view(1, -1, 3))
            loss = (dist1.sqrt().mean() + dist2.sqrt().mean()) / 2
            gt_binary_diff += loss.item()
        gt_binary_tot += 1

    return struct_diff, edge_both, edge_gt, edge_pred, gt_binary_diff, gt_binary_tot
def node_recon_loss(self, node_latent, gt_node):
    """Reconstruction losses of the box decoder on one subtree.

    Decodes ``node_latent`` into a 10-D box and compares it with
    ``gt_node.get_box_quat()``. For a non-leaf ``gt_node`` the child decoder is
    also run. The predicted child boxes are matched to the ground-truth
    children by ``linear_assignment`` on pairwise ``boxLossEstimator``
    (Hungarian matching), and the method recurses into every matched pair.
    Unmatched predicted children are not recursed into; instead their size
    entries (columns 3:6) are pushed to zero.

    Args:
        node_latent: latent code of the current node; ``size(1)`` is used as
            the per-child feature length.
        gt_node: ground-truth node providing ``is_leaf``, ``children``,
            ``get_box_quat()``, ``edge_tensors(...)`` and, on children,
            ``get_semantic_id()``.

    Returns:
        ``(losses, all_boxes, all_leaf_boxes)``:

        * ``losses``: dict with keys 'box', 'leaf', 'anchor', 'exists',
          'semantic', 'edge_exists', 'sym', 'adj', summed over this node and
          all matched descendants. At a leaf the structural terms are
          ``zeros_like(box_loss)``.
        * ``all_boxes``: this node's decoded box followed by those of every
          matched descendant, concatenated along dim 0.
        * ``all_leaf_boxes``: decoded boxes of leaves only. For a leaf, both
          tensors are its own box.
    """
    if gt_node.is_leaf:
        box = self.box_decoder(node_latent)
        box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
        anchor_loss = self.anchorLossEstimator(box, gt_node.get_box_quat().view(1, -1))
        is_leaf_logit = self.leaf_classifier(node_latent)
        is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
        # Structural losses do not apply to a leaf; zeros keep the dict shape uniform.
        return {'box': box_loss, 'leaf': is_leaf_loss, 'anchor': anchor_loss,
                'exists': torch.zeros_like(box_loss), 'semantic': torch.zeros_like(box_loss),
                'edge_exists': torch.zeros_like(box_loss),
                'sym': torch.zeros_like(box_loss), 'adj': torch.zeros_like(box_loss)}, box, box
    else:
        child_feats, child_sem_logits, child_exists_logits, edge_exists_logits = \
            self.child_decoder(node_latent)

        # generate box prediction for each child
        feature_len = node_latent.size(1)
        child_pred_boxes = self.box_decoder(child_feats.view(-1, feature_len))
        num_child_parts = child_pred_boxes.size(0)

        # perform hungarian matching between pred boxes and gt boxes
        # (no gradients needed: only the resulting index pairs are used)
        with torch.no_grad():
            child_gt_boxes = torch.cat([child_node.get_box_quat().view(1, -1) for child_node in gt_node.children], dim=0)
            num_gt = child_gt_boxes.size(0)

            # Tile to every (gt, pred) pair so one loss call scores all pairs.
            child_pred_boxes_tiled = child_pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
            child_gt_boxes_tiled = child_gt_boxes.unsqueeze(dim=1).repeat(1, num_child_parts, 1)

            dist_mat = self.boxLossEstimator(child_gt_boxes_tiled.view(-1, 10), child_pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_child_parts)

            _, matched_gt_idx, matched_pred_idx = linear_assignment(dist_mat)

        # get edge ground truth
        edge_type_list_gt, edge_indices_gt = gt_node.edge_tensors(
            edge_types=self.conf.edge_types, device=child_feats.device, type_onehot=False)

        gt2pred = {gt_idx: pred_idx for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx)}
        edge_exists_gt = torch.zeros_like(edge_exists_logits)

        # sym_type is collected but not read below.
        sym_from = []; sym_to = []; sym_mat = []; sym_type = []; adj_from = []; adj_to = [];
        # Only the first half of the gt edge list is visited.
        # NOTE(review): presumably edge_tensors lists each edge in both
        # directions -- confirm.
        for i in range(edge_indices_gt.shape[1]//2):
            if edge_indices_gt[0, i, 0].item() not in gt2pred or edge_indices_gt[0, i, 1].item() not in gt2pred:
                """ one of the adjacent nodes of the current gt edge was not matched to any node in the prediction, ignore this edge """
                continue

            # correlate gt edges to pred edges
            edge_from_idx = gt2pred[edge_indices_gt[0, i, 0].item()]
            edge_to_idx = gt2pred[edge_indices_gt[0, i, 1].item()]
            # Mark the edge as existing in both directions.
            edge_exists_gt[:, edge_from_idx, edge_to_idx, edge_type_list_gt[0:1, i]] = 1
            edge_exists_gt[:, edge_to_idx, edge_from_idx, edge_type_list_gt[0:1, i]] = 1

            # compute binary edge parameters for each matched pred edge
            if edge_type_list_gt[0, i].item() == 0: # ADJ
                adj_from.append(edge_from_idx)
                adj_to.append(edge_to_idx)
            else:   # SYM
                # Transforms are computed on detached numpy boxes, i.e. treated as constants.
                if edge_type_list_gt[0, i].item() == 1: # ROT_SYM
                    mat1to2, mat2to1 = compute_sym.compute_rot_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                elif edge_type_list_gt[0, i].item() == 2: # TRANS_SYM
                    mat1to2, mat2to1 = compute_sym.compute_trans_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                else:   # REF_SYM
                    mat1to2, mat2to1 = compute_sym.compute_ref_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                sym_from.append(edge_from_idx)
                sym_to.append(edge_to_idx)
                sym_mat.append(torch.tensor(mat1to2, dtype=torch.float32, device=self.conf.device).view(1, 3, 4))
                sym_type.append(edge_type_list_gt[0, i].item())

        # train the current node to be non-leaf
        is_leaf_logit = self.leaf_classifier(node_latent)
        is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

        # train the current node box to gt
        # (all_boxes starts with this node's own box; children are appended later)
        all_boxes = []; all_leaf_boxes = [];
        box = self.box_decoder(node_latent)
        all_boxes.append(box)
        box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
        anchor_loss = self.anchorLossEstimator(box, gt_node.get_box_quat().view(1, -1))

        # gather information
        # (child_box_gt / child_box_pred are filled but not read below)
        child_sem_gt_labels = []
        child_sem_pred_logits = []
        child_box_gt = []
        child_box_pred = []
        child_exists_gt = torch.zeros_like(child_exists_logits)
        for i in range(len(matched_gt_idx)):
            child_sem_gt_labels.append(gt_node.children[matched_gt_idx[i]].get_semantic_id())
            child_sem_pred_logits.append(child_sem_logits[0, matched_pred_idx[i], :].view(1, -1))
            child_box_gt.append(gt_node.children[matched_gt_idx[i]].get_box_quat().view(1, -1))
            child_box_pred.append(child_pred_boxes[matched_pred_idx[i], :].view(1, -1))
            child_exists_gt[:, matched_pred_idx[i], :] = 1

        # train semantic labels
        child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
        child_sem_gt_labels = torch.tensor(child_sem_gt_labels, dtype=torch.int64, \
            device=child_sem_pred_logits.device)
        semantic_loss = self.semCELoss(child_sem_pred_logits, child_sem_gt_labels)
        semantic_loss = semantic_loss.sum()

        # train unused boxes to zeros
        # (columns 3:6 of each unmatched predicted box, weighted by 0.01)
        unmatched_boxes = []
        for i in range(num_child_parts):
            if i not in matched_pred_idx:
                unmatched_boxes.append(child_pred_boxes[i, 3:6].view(1, -1))
        if len(unmatched_boxes) > 0:
            unmatched_boxes = torch.cat(unmatched_boxes, dim=0)
            unused_box_loss = unmatched_boxes.pow(2).sum() * 0.01
        else:
            unused_box_loss = 0.0

        # train exist scores
        child_exists_loss = F.binary_cross_entropy_with_logits(\
            input=child_exists_logits, target=child_exists_gt, reduction='none')
        child_exists_loss = child_exists_loss.sum()

        # train edge exists scores
        edge_exists_loss = F.binary_cross_entropy_with_logits(\
                input=edge_exists_logits, target=edge_exists_gt, reduction='none')
        edge_exists_loss = edge_exists_loss.sum()
        # rescale to make it comparable to other losses,
        # which are in the order of the number of child nodes
        edge_exists_loss = edge_exists_loss / (edge_exists_gt.shape[2]*edge_exists_gt.shape[3])

        # compute and train binary losses
        sym_loss = 0
        if len(sym_from) > 0:
            sym_from_th = torch.tensor(sym_from, dtype=torch.long, device=self.conf.device)
            obb_from = child_pred_boxes[sym_from_th, :]
            # Per-point weights are constants; only the points carry gradients.
            with torch.no_grad():
                reweight_from = get_surface_reweighting_batch(obb_from[:, 3:6], self.unit_cube.size(0))
            pc_from = transform_pc_batch(self.unit_cube, obb_from)
            sym_to_th = torch.tensor(sym_to, dtype=torch.long, device=self.conf.device)
            obb_to = child_pred_boxes[sym_to_th, :]
            with torch.no_grad():
                reweight_to = get_surface_reweighting_batch(obb_to[:, 3:6], self.unit_cube.size(0))
            pc_to = transform_pc_batch(self.unit_cube, obb_to)
            sym_mat_th = torch.cat(sym_mat, dim=0)
            # Apply each 3x4 transform: points @ linear^T + offset.
            transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat_th[:, :, :3], 1, 2)) + \
                    sym_mat_th[:, :, 3].unsqueeze(dim=1).repeat(1, pc_from.size(1), 1)
            dist1, dist2 = self.chamferLoss(transformed_pc_from, pc_to)
            # Weighted mean per pair; 1e-12 guards against a zero weight sum.
            loss1 = (dist1 * reweight_from).sum(dim=1) / (reweight_from.sum(dim=1) + 1e-12)
            loss2 = (dist2 * reweight_to).sum(dim=1) / (reweight_to.sum(dim=1) + 1e-12)
            loss = loss1 + loss2
            sym_loss = loss.sum()

        adj_loss = 0
        if len(adj_from) > 0:
            adj_from_th = torch.tensor(adj_from, dtype=torch.long, device=self.conf.device)
            obb_from = child_pred_boxes[adj_from_th, :]
            pc_from = transform_pc_batch(self.unit_cube, obb_from)
            adj_to_th = torch.tensor(adj_to, dtype=torch.long, device=self.conf.device)
            obb_to = child_pred_boxes[adj_to_th, :]
            pc_to = transform_pc_batch(self.unit_cube, obb_to)
            dist1, dist2 = self.chamferLoss(pc_from, pc_to)
            # Adjacency: only the closest point pair matters (min, not mean).
            loss = (dist1.min(dim=1)[0] + dist2.min(dim=1)[0])
            adj_loss = loss.sum()

        # call children + aggregate losses
        pred2allboxes = dict(); pred2allleafboxes = dict();
        for i in range(len(matched_gt_idx)):
            child_losses, child_all_boxes, child_all_leaf_boxes = self.node_recon_loss(
                child_feats[:, matched_pred_idx[i], :], gt_node.children[matched_gt_idx[i]])
            pred2allboxes[matched_pred_idx[i]] = child_all_boxes
            pred2allleafboxes[matched_pred_idx[i]] = child_all_leaf_boxes
            all_boxes.append(child_all_boxes)
            all_leaf_boxes.append(child_all_leaf_boxes)
            box_loss = box_loss + child_losses['box']
            anchor_loss = anchor_loss + child_losses['anchor']
            is_leaf_loss = is_leaf_loss + child_losses['leaf']
            child_exists_loss = child_exists_loss + child_losses['exists']
            semantic_loss = semantic_loss + child_losses['semantic']
            edge_exists_loss = edge_exists_loss + child_losses['edge_exists']
            sym_loss = sym_loss + child_losses['sym']
            adj_loss = adj_loss + child_losses['adj']

        # for sym-edges, train subtree to be symmetric
        for i in range(len(sym_from)):
            # size(0) > 1 means the child has descendants: row 0 is the child
            # itself, so [1:] keeps only the descendants' boxes.
            s1 = pred2allboxes[sym_from[i]].size(0)
            s2 = pred2allboxes[sym_to[i]].size(0)
            if s1 > 1 and s2 > 1:
                obbs_from = pred2allboxes[sym_from[i]][1:, :]
                obbs_to = pred2allboxes[sym_to[i]][1:, :]
                pc_from = transform_pc_batch(self.unit_cube, obbs_from).view(-1, 3)
                pc_to = transform_pc_batch(self.unit_cube, obbs_to).view(-1, 3)
                transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat[i][0, :, :3], 0, 1)) + \
                        sym_mat[i][0, :, 3].unsqueeze(dim=0).repeat(pc_from.size(0), 1)
                dist1, dist2 = self.chamferLoss(transformed_pc_from.view(1, -1, 3), pc_to.view(1, -1, 3))
                sym_loss += (dist1.mean() + dist2.mean()) * (s1 + s2) / 2

        # for adj-edges, train leaf-nodes in subtrees to be adjacent
        for i in range(len(adj_from)):
            # all-boxes has more rows than leaf-boxes only when the child is a non-leaf.
            if pred2allboxes[adj_from[i]].size(0) > pred2allleafboxes[adj_from[i]].size(0) \
                    or pred2allboxes[adj_to[i]].size(0) > pred2allleafboxes[adj_to[i]].size(0):
                obbs_from = pred2allleafboxes[adj_from[i]]
                obbs_to = pred2allleafboxes[adj_to[i]]
                pc_from = transform_pc_batch(self.unit_cube, obbs_from).view(1, -1, 3)
                pc_to = transform_pc_batch(self.unit_cube, obbs_to).view(1, -1, 3)
                dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                adj_loss += dist1.min() + dist2.min()

        return {'box': box_loss + unused_box_loss, 'leaf': is_leaf_loss, 'anchor': anchor_loss,
                'exists': child_exists_loss, 'semantic': semantic_loss,
                'edge_exists': edge_exists_loss, 'sym': sym_loss, 'adj': adj_loss}, \
            torch.cat(all_boxes, dim=0), torch.cat(all_leaf_boxes, dim=0)