def compute_shape_diff(source, target, device=None):
    """Build a ``Tree.DiffNode`` hierarchy describing how ``source`` maps to ``target``.

    The returned node always carries ``box_diff``: a float32 tensor of shape
    ``(1, -1)`` produced by ``compute_box_diff`` from the two nodes' box
    quaternions. Its type and children are decided as follows:

    * both leaves: the node type becomes ``'LEAF'``;
    * leaf source, non-leaf target: one ``'ADD'`` child (``subnode`` set to the
      target child) per target child;
    * non-leaf source, leaf target: one ``'DEL'`` child per source child;
    * both non-leaf: children sharing a semantic ``label`` are paired by
      ``linear_assignment`` on ``boxLoss`` costs. The result holds one entry per
      source child, in source order (``'DEL'`` if unmatched, otherwise the
      recursive diff of the matched pair), followed by one ``'ADD'`` entry per
      unmatched target child in ascending target-index order.

    Args:
        source: the node being edited.
        target: the node ``source`` should become.
        device: accepted for API compatibility; not used (``box_diff`` is
            placed on ``source.box.device``).

    Returns:
        The root ``Tree.DiffNode`` of the diff.
    """
    dnode = Tree.DiffNode('SAME')

    # The box difference is recorded for every node pair, leaf or not.
    source_box_quat = source.get_box_quat().squeeze().cpu().numpy()
    target_box_quat = target.get_box_quat().squeeze().cpu().numpy()
    dnode.box_diff = torch.tensor(
        compute_box_diff(source_box_quat, target_box_quat),
        dtype=torch.float32, device=source.box.device).view(1, -1)

    if source.is_leaf:
        if target.is_leaf:
            dnode.node_type = 'LEAF'
        else:
            for cnode in target.children:
                cdiffnode = Tree.DiffNode('ADD')
                cdiffnode.subnode = cnode
                dnode.children.append(cdiffnode)
        return dnode

    if target.is_leaf:
        for _ in source.children:
            dnode.children.append(Tree.DiffNode('DEL'))
        return dnode

    # Only children whose label appears on both sides can be matched.
    intersect_sem = set.intersection(
        {cnode.label for cnode in source.children},
        {cnode.label for cnode in target.children})

    source_cnodes_per_sem = {}
    for node_id, cnode in enumerate(source.children):
        if cnode.label in intersect_sem:
            source_cnodes_per_sem.setdefault(cnode.label, []).append(node_id)
    target_cnodes_per_sem = {}
    for node_id, cnode in enumerate(target.children):
        if cnode.label in intersect_sem:
            target_cnodes_per_sem.setdefault(cnode.label, []).append(node_id)

    matched_source2target = {}
    for sem in intersect_sem:
        source_ids = source_cnodes_per_sem[sem]
        target_ids = target_cnodes_per_sem[sem]
        # The cost matrix only drives the assignment; no autograd graph needed.
        with torch.no_grad():
            source_boxes = torch.cat(
                [source.children[cid].get_box_quat() for cid in source_ids], dim=0)
            target_boxes = torch.cat(
                [target.children[cid].get_box_quat() for cid in target_ids], dim=0)
            num_source = source_boxes.size(0)
            num_target = target_boxes.size(0)
            source_boxes_tiled = source_boxes.unsqueeze(dim=1).repeat(1, num_target, 1)
            target_boxes_tiled = target_boxes.unsqueeze(dim=0).repeat(num_source, 1, 1)
            dmat = boxLoss(source_boxes_tiled.view(-1, 10),
                           target_boxes_tiled.view(-1, 10)).view(
                               -1, num_source, num_target).cpu()
        _, cur_matched_source_ids, cur_matched_target_ids = linear_assignment(dmat)
        for src_local, tgt_local in zip(cur_matched_source_ids, cur_matched_target_ids):
            matched_source2target[source_ids[src_local]] = target_ids[tgt_local]

    # One entry per source child, in source order.
    for node_id, cnode in enumerate(source.children):
        if node_id in matched_source2target:
            dnode.children.append(Tree.compute_shape_diff(
                cnode, target.children[matched_source2target[node_id]]))
        else:
            dnode.children.append(Tree.DiffNode('DEL'))

    # Then every unmatched target child is an addition, in target order.
    matched_target_ids = set(matched_source2target.values())
    for target_id, tnode in enumerate(target.children):
        if target_id not in matched_target_ids:
            cdiffnode = Tree.DiffNode('ADD')
            cdiffnode.subnode = tnode
            dnode.children.append(cdiffnode)

    return dnode
def struct_dist(gt_node, pred_node):
    """Count structural mismatches between a ground-truth and a predicted hierarchy.

    Leaf/leaf pairs cost 0. A leaf paired with a non-leaf costs
    ``len(other.boxes()) - 1``. For two non-leaf nodes, children are matched
    within each semantic ``label`` present on both sides (trivially when there
    is one candidate per side, otherwise by ``utils.linear_assignment`` on
    ``box_dist`` costs). Every unmatched child on either side adds
    ``len(child.boxes())``, and every matched pair adds its recursive
    ``struct_dist``.

    Args:
        gt_node: ground-truth node.
        pred_node: predicted node.

    Returns:
        The accumulated mismatch count (``int`` for leaf comparisons,
        ``float`` otherwise).
    """
    if gt_node.is_leaf:
        if pred_node.is_leaf:
            return 0
        return len(pred_node.boxes()) - 1
    if pred_node.is_leaf:
        return len(gt_node.boxes()) - 1

    intersect_sem = set.intersection(
        {node.label for node in gt_node.children},
        {node.label for node in pred_node.children})

    gt_cnodes_per_sem = {}
    for node_id, gt_cnode in enumerate(gt_node.children):
        if gt_cnode.label in intersect_sem:
            gt_cnodes_per_sem.setdefault(gt_cnode.label, []).append(node_id)
    pred_cnodes_per_sem = {}
    for node_id, pred_cnode in enumerate(pred_node.children):
        if pred_cnode.label in intersect_sem:
            pred_cnodes_per_sem.setdefault(pred_cnode.label, []).append(node_id)

    # Parallel lists of matched child indices (gt side, pred side).
    matched_gt_idx = []
    matched_pred_idx = []
    for sem in intersect_sem:
        gt_ids = gt_cnodes_per_sem[sem]
        pred_ids = pred_cnodes_per_sem[sem]
        gt_boxes = torch.cat(
            [gt_node.children[cid].get_box_quat() for cid in gt_ids], dim=0)
        pred_boxes = torch.cat(
            [pred_node.children[cid].get_box_quat() for cid in pred_ids], dim=0)
        num_gt = gt_boxes.size(0)
        num_pred = pred_boxes.size(0)
        if num_gt == 1 and num_pred == 1:
            # One candidate on each side: the match is trivial.
            cur_matched_gt_idx = [0]
            cur_matched_pred_idx = [0]
        else:
            gt_boxes_tiled = gt_boxes.unsqueeze(dim=1).repeat(1, num_pred, 1)
            pred_boxes_tiled = pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
            dmat = box_dist(gt_boxes_tiled.view(-1, 10),
                            pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_pred)
            _, cur_matched_gt_idx, cur_matched_pred_idx = utils.linear_assignment(dmat)
        for gt_local, pred_local in zip(cur_matched_gt_idx, cur_matched_pred_idx):
            matched_gt_idx.append(gt_ids[gt_local])
            matched_pred_idx.append(pred_ids[pred_local])

    matched_gt_set = set(matched_gt_idx)
    matched_pred_set = set(matched_pred_idx)

    struct_diff = 0.0
    for i, child in enumerate(gt_node.children):
        if i not in matched_gt_set:
            struct_diff += len(child.boxes())
    for i, child in enumerate(pred_node.children):
        if i not in matched_pred_set:
            struct_diff += len(child.boxes())
    for gt_id, pred_id in zip(matched_gt_idx, matched_pred_idx):
        struct_diff += struct_dist(gt_node.children[gt_id],
                                   pred_node.children[pred_id])
    return struct_diff
def node_recon_loss(self, node_latent, gt_node):
    """Compute box-hierarchy reconstruction losses for ``gt_node``'s subtree.

    Leaf: box loss (``boxLossEstimator``) and is-leaf loss for the node
    itself. ``'exists'`` and ``'semantic'`` are zeros shaped like the box loss.

    Non-leaf: decodes child features and predicted child boxes, then matches
    them to the ground-truth children by ``linear_assignment`` on
    ``boxLossEstimator`` costs. It accumulates:

    * this node's box and is-leaf losses;
    * a cross-entropy (``ceLoss``) semantic loss over matched children;
    * an existence BCE over all child slots (target 1 for matched slots,
      0 otherwise);
    * a ``0.01``-weighted squared penalty on columns ``3:6`` of unmatched
      predicted boxes;
    * the recursive losses of every matched child.

    Returns:
        dict with keys ``'box'``, ``'leaf'``, ``'exists'``, ``'semantic'``.
    """
    if gt_node.is_leaf:
        box = self.box_decoder(node_latent)
        box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
        is_leaf_logit = self.leaf_classifier(node_latent)
        is_leaf_loss = self.isLeafLossEstimator(
            is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
        return {
            'box': box_loss,
            'leaf': is_leaf_loss,
            'exists': torch.zeros_like(box_loss),
            'semantic': torch.zeros_like(box_loss),
        }

    child_feats, child_sem_logits, child_exists_logits = self.child_decoder(node_latent)

    # Generate a box prediction for each child slot.
    feature_len = node_latent.size(1)
    child_pred_boxes = self.box_decoder(child_feats.view(-1, feature_len))
    num_child_parts = child_pred_boxes.size(0)

    # Hungarian matching between predicted and gt child boxes; the cost
    # matrix needs no gradients.
    with torch.no_grad():
        child_gt_boxes = torch.cat(
            [child_node.get_box_quat() for child_node in gt_node.children], dim=0)
        num_gt = child_gt_boxes.size(0)
        child_pred_boxes_tiled = child_pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
        child_gt_boxes_tiled = child_gt_boxes.unsqueeze(dim=1).repeat(1, num_child_parts, 1)
        dist_mat = self.boxLossEstimator(
            child_gt_boxes_tiled.view(-1, 10),
            child_pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_child_parts)
        _, matched_gt_idx, matched_pred_idx = linear_assignment(dist_mat)

    # Train the current node to be non-leaf.
    is_leaf_logit = self.leaf_classifier(node_latent)
    is_leaf_loss = self.isLeafLossEstimator(
        is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

    # Train the current node box to the ground truth.
    box = self.box_decoder(node_latent)
    box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))

    # Semantic targets/logits of matched slots; matched slots should exist.
    child_sem_gt_labels = []
    child_sem_pred_logits = []
    child_exists_gt = torch.zeros_like(child_exists_logits)
    for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx):
        child_sem_gt_labels.append(gt_node.children[gt_idx].get_semantic_id())
        child_sem_pred_logits.append(child_sem_logits[0, pred_idx, :].view(1, -1))
        child_exists_gt[:, pred_idx, :] = 1

    # Train semantic labels.
    child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
    child_sem_gt_labels = torch.tensor(
        child_sem_gt_labels, dtype=torch.int64, device=child_sem_pred_logits.device)
    semantic_loss = self.ceLoss(child_sem_pred_logits, child_sem_gt_labels).sum()

    # Push columns 3:6 of unmatched predicted boxes toward zero.
    matched_pred_set = {int(idx) for idx in matched_pred_idx}
    unmatched_boxes = [
        child_pred_boxes[i, 3:6].view(1, -1)
        for i in range(num_child_parts) if i not in matched_pred_set
    ]
    if unmatched_boxes:
        unused_box_loss = torch.cat(unmatched_boxes, dim=0).pow(2).sum() * 0.01
    else:
        unused_box_loss = 0.0

    # Train existence scores for every child slot.
    child_exists_loss = F.binary_cross_entropy_with_logits(
        input=child_exists_logits, target=child_exists_gt, reduction='none').sum()

    # Recurse into matched children and aggregate their losses.
    for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx):
        child_losses = self.node_recon_loss(
            child_feats[:, pred_idx, :], gt_node.children[gt_idx])
        box_loss = box_loss + child_losses['box']
        is_leaf_loss = is_leaf_loss + child_losses['leaf']
        child_exists_loss = child_exists_loss + child_losses['exists']
        semantic_loss = semantic_loss + child_losses['semantic']

    return {
        'box': box_loss + unused_box_loss,
        'leaf': is_leaf_loss,
        'exists': child_exists_loss,
        'semantic': semantic_loss,
    }
def node_diff_recon_loss(self, z, z_skip, obj_node, gt_diff):
    """Compute losses for predicting the diff ``gt_diff`` applied to ``obj_node``.

    ``gt_diff.children`` is indexed positionally. Entry ``i`` for
    ``i < len(obj_node.children)`` describes existing child ``i``. The
    remaining entries describe added subtrees, which are read through
    ``.subnode``.

    For each existing child:

    * a diff feature is extracted and classified against
      ``cdiff.get_node_type_id()``;
    * ``'SAME'`` children recurse into this method;
    * ``'SAME'`` and ``'LEAF'`` children also get a box-diff loss.

    Added subtrees are decoded by ``add_child_decoder`` and matched to the
    ground-truth added subtrees by ``linear_assignment`` on
    ``boxLossEstimator`` costs. They are then trained with:

    * ``node_recon_loss``;
    * a semantic cross-entropy;
    * an existence BCE;
    * a ``0.01``-weighted squared penalty on columns ``3:6`` of unmatched
      predicted boxes.

    Returns:
        dict with keys ``'box'``, ``'leaf'``, ``'exists'``, ``'semantic'``,
        ``'diffnode_type'``, ``'diffnode_box'``. Each starts as
        ``torch.zeros(1)`` on ``z.device`` and is accumulated in place.
    """
    box_loss = torch.zeros(1, device=z.device)
    is_leaf_loss = torch.zeros(1, device=z.device)
    child_exists_loss = torch.zeros(1, device=z.device)
    semantic_loss = torch.zeros(1, device=z.device)
    diffnode_type_loss = torch.zeros(1, device=z.device)
    diffnode_box_loss = torch.zeros(1, device=z.device)

    num_obj_children = len(obj_node.children)

    # DEL / SAME / LEAF: one diff-type classification per existing child.
    node_diff_types = []
    gt_node_diff_types = []
    for i, cnode in enumerate(obj_node.children):
        cdiff = gt_diff.children[i]
        node_diff_feat = self.node_diff_feature_extractor(torch.cat([
            z, z_skip, cnode.box_feature, cnode.feature,
            cnode.get_semantic_one_hot()
        ], dim=1))
        node_diff_types.append(self.node_diff_classifier(node_diff_feat))
        gt_node_diff_types.append(cdiff.get_node_type_id())

        if cdiff.node_type == 'SAME':
            child_losses = self.node_diff_recon_loss(
                z=node_diff_feat, z_skip=z_skip, obj_node=cnode, gt_diff=cdiff)
            diffnode_type_loss += child_losses['diffnode_type']
            diffnode_box_loss += child_losses['diffnode_box']
            box_loss += child_losses['box']
            is_leaf_loss += child_losses['leaf']
            child_exists_loss += child_losses['exists']
            semantic_loss += child_losses['semantic']

        if cdiff.node_type in ('LEAF', 'SAME'):
            box2 = self.apply_box_diff(
                cnode.get_box_quat(), self.box_diff_decoder(node_diff_feat))
            gt_box2 = self.apply_box_diff(cnode.get_box_quat(), cdiff.box_diff)
            diffnode_box_loss += self.boxLossEstimator(box2, gt_box2).sum()

    if num_obj_children > 0:
        node_diff_types = torch.cat(node_diff_types, dim=0)
        gt_node_diff_types = torch.tensor(
            gt_node_diff_types, dtype=torch.int64, device=node_diff_types.device)
        diffnode_type_loss += self.ceLoss(node_diff_types, gt_node_diff_types).sum()

    # ADD: decode candidate new children and match them to the added subtrees.
    add_child_feats, add_child_sem_logits, add_child_exists_logits = \
        self.add_child_decoder(z)
    add_child_boxes = self.box_decoder(
        add_child_feats.view(-1, self.conf.feature_size))
    num_pred = add_child_boxes.size(0)
    child_exists_gt = torch.zeros_like(add_child_exists_logits)

    added_subnodes = [gt_diff.children[i].subnode
                      for i in range(num_obj_children, len(gt_diff.children))]
    num_gt = len(added_subnodes)

    matched_gt_idx = []
    matched_pred_idx = []
    if num_gt > 0:
        with torch.no_grad():
            add_child_gt_boxes = torch.cat(
                [subnode.get_box_quat() for subnode in added_subnodes], dim=0)
            pred_tiled = add_child_boxes.unsqueeze(0).repeat(num_gt, 1, 1)
            gt_tiled = add_child_gt_boxes.unsqueeze(1).repeat(1, num_pred, 1)
            dmat = self.boxLossEstimator(
                gt_tiled.view(-1, 10), pred_tiled.view(-1, 10)).view(
                    -1, num_gt, num_pred)
            _, matched_gt_idx, matched_pred_idx = linear_assignment(dmat)

        child_sem_gt_labels = []
        child_sem_pred_logits = []
        for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx):
            subnode = added_subnodes[gt_idx]
            child_sem_gt_labels.append(subnode.get_semantic_id())
            child_sem_pred_logits.append(
                add_child_sem_logits[0, pred_idx, :].view(1, -1))
            child_exists_gt[:, pred_idx, :] = 1

            # Train the matched added subtree as a full reconstruction.
            child_losses = self.node_recon_loss(
                add_child_feats[:, pred_idx, :], subnode)
            box_loss += child_losses['box']
            is_leaf_loss += child_losses['leaf']
            child_exists_loss += child_losses['exists']
            semantic_loss += child_losses['semantic']

        # Train semantic labels of the matched added children.
        child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
        child_sem_gt_labels = torch.tensor(
            child_sem_gt_labels, dtype=torch.int64,
            device=add_child_sem_logits.device)
        semantic_loss += self.ceLoss(child_sem_pred_logits, child_sem_gt_labels).sum()

    # Push columns 3:6 of unmatched predicted boxes toward zero.
    matched_pred_set = {int(idx) for idx in matched_pred_idx}
    unmatched_boxes = [add_child_boxes[i, 3:6].view(1, -1)
                       for i in range(num_pred) if i not in matched_pred_set]
    if unmatched_boxes:
        box_loss += torch.cat(unmatched_boxes, dim=0).pow(2).sum() * 0.01

    # Train existence scores of the ADD slots.
    child_exists_loss += F.binary_cross_entropy_with_logits(
        input=add_child_exists_logits, target=child_exists_gt,
        reduction='none').sum()

    return {'box': box_loss, 'leaf': is_leaf_loss, 'exists': child_exists_loss,
            'semantic': semantic_loss, 'diffnode_type': diffnode_type_loss,
            'diffnode_box': diffnode_box_loss}
def node_recon_loss(self, node_latent, gt_node):
    """Compute point-cloud reconstruction losses for ``gt_node``'s subtree.

    The node's own geometry is decoded from ``node_latent`` with
    ``self.node_decoder`` and ``self.geoToGlobal``. It is compared against:

    * ``gt_node.geo`` (``chamferDist``);
    * the ground-truth center (mean of ``gt_node.geo``);
    * the ground-truth scale (largest point distance from that center);
    * the ground-truth latent feature ``gt_node.geo_feat``.

    For a non-leaf node, the children are decoded and matched to the
    ground-truth children by ``linear_assignment`` on Chamfer distance. The
    following losses are then added:

    * is-leaf loss;
    * semantic loss (``semCELoss``);
    * child-existence and edge-existence BCE losses;
    * a squared penalty on the scales of unmatched predicted children;
    * symmetry and adjacency losses for ground-truth edges whose endpoints
      were both matched, at this level and between the matched subtrees;
    * the recursive losses of every matched child.

    Edge type ids: 0 = ADJ, 1 = ROT_SYM, 2 = TRANS_SYM, otherwise REF_SYM.

    Returns:
        ``(losses, all_geo, all_leaf_geo)``, where:

        * ``losses`` has keys ``'leaf'``, ``'geo'``, ``'center'``,
          ``'scale'``, ``'latent'``, ``'exists'``, ``'semantic'``,
          ``'edge_exists'``, ``'sym'``, ``'adj'``;
        * ``all_geo`` concatenates this node's decoded geometry with that of
          all matched descendants;
        * ``all_leaf_geo`` concatenates only leaf geometries.

        For a leaf, both geometry outputs are this node's decoded geometry.
    """
    gt_geo = gt_node.geo
    gt_geo_feat = gt_node.geo_feat
    gt_center = gt_geo.mean(dim=1)
    gt_scale = (gt_geo - gt_center.unsqueeze(dim=1).repeat(
        1, self.conf.num_point, 1)).pow(2).sum(dim=2).max(
            dim=1)[0].sqrt().view(1, 1)

    geo_local, geo_center, geo_scale, geo_feat = self.node_decoder(node_latent)
    geo_global = self.geoToGlobal(geo_local, geo_center, geo_scale)

    # Geometry losses for the current part.
    latent_loss = self.mseLoss(geo_feat, gt_geo_feat).mean()
    center_loss = self.mseLoss(geo_center, gt_center).mean()
    scale_loss = self.mseLoss(geo_scale, gt_scale).mean()
    geo_loss = self.chamferDist(geo_global, gt_geo)

    if gt_node.is_leaf:
        is_leaf_logit = self.leaf_classifier(node_latent)
        is_leaf_loss = self.isLeafLossEstimator(
            is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
        return {
            'leaf': is_leaf_loss,
            'geo': geo_loss,
            'center': center_loss,
            'scale': scale_loss,
            'latent': latent_loss,
            'exists': torch.zeros_like(geo_loss),
            'semantic': torch.zeros_like(geo_loss),
            'edge_exists': torch.zeros_like(geo_loss),
            'sym': torch.zeros_like(geo_loss),
            'adj': torch.zeros_like(geo_loss)
        }, geo_global, geo_global

    child_feats, child_sem_logits, child_exists_logits, edge_exists_logits = \
        self.child_decoder(node_latent)

    # Generate a geometry prediction for each child slot.
    feature_len = node_latent.size(1)

    all_geo = [geo_global]
    all_leaf_geo = []

    child_pred_geo_local, child_pred_geo_center, child_pred_geo_scale, _ = \
        self.node_decoder(child_feats.view(-1, feature_len))
    child_pred_geo = self.geoToGlobal(child_pred_geo_local, child_pred_geo_center,
                                      child_pred_geo_scale)
    num_pred = child_pred_geo.size(0)

    # Hungarian matching between predicted and gt child geometry.
    with torch.no_grad():
        child_gt_geo = torch.cat(
            [child_node.geo for child_node in gt_node.children], dim=0)
        num_gt = child_gt_geo.size(0)
        child_pred_geo_tiled = child_pred_geo.unsqueeze(dim=0).repeat(
            num_gt, 1, 1, 1)
        child_gt_geo_tiled = child_gt_geo.unsqueeze(dim=1).repeat(
            1, num_pred, 1, 1)
        dist_mat = self.chamferDist(
            child_pred_geo_tiled.view(-1, self.conf.num_point, 3),
            child_gt_geo_tiled.view(-1, self.conf.num_point, 3)).view(
                1, num_gt, num_pred)
        _, matched_gt_idx, matched_pred_idx = linear_assignment(dist_mat)

    # Ground-truth edges.
    edge_type_list_gt, edge_indices_gt = gt_node.edge_tensors(
        edge_types=self.conf.edge_types, device=child_feats.device,
        type_onehot=False)

    gt2pred = {gt_idx: pred_idx
               for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx)}
    edge_exists_gt = torch.zeros_like(edge_exists_logits)

    sym_from = []
    sym_to = []
    sym_mat = []
    adj_from = []
    adj_to = []
    for i in range(edge_indices_gt.shape[1] // 2):
        gt_from = edge_indices_gt[0, i, 0].item()
        gt_to = edge_indices_gt[0, i, 1].item()
        if gt_from not in gt2pred or gt_to not in gt2pred:
            # One endpoint of this gt edge was not matched to any predicted
            # node, so the edge cannot be supervised; ignore it.
            continue

        # Correlate gt edges to pred edges.
        edge_from_idx = gt2pred[gt_from]
        edge_to_idx = gt2pred[gt_to]
        edge_exists_gt[:, edge_from_idx, edge_to_idx, edge_type_list_gt[0:1, i]] = 1
        edge_exists_gt[:, edge_to_idx, edge_from_idx, edge_type_list_gt[0:1, i]] = 1

        # Compute binary edge parameters for each matched pred edge.
        edge_type = edge_type_list_gt[0, i].item()
        if edge_type == 0:  # ADJ
            adj_from.append(edge_from_idx)
            adj_to.append(edge_to_idx)
        else:  # SYM
            obb1 = compute_sym.compute_obb(
                child_pred_geo[edge_from_idx].cpu().detach().numpy())
            obb2 = compute_sym.compute_obb(
                child_pred_geo[edge_to_idx].cpu().detach().numpy())
            if edge_type == 1:  # ROT_SYM
                mat1to2, _ = compute_sym.compute_rot_sym(obb1, obb2)
            elif edge_type == 2:  # TRANS_SYM
                mat1to2, _ = compute_sym.compute_trans_sym(obb1, obb2)
            else:  # REF_SYM
                mat1to2, _ = compute_sym.compute_ref_sym(obb1, obb2)
            sym_from.append(edge_from_idx)
            sym_to.append(edge_to_idx)
            sym_mat.append(torch.tensor(
                mat1to2, dtype=torch.float32,
                device=self.conf.device).view(1, 3, 4))

    # Train the current node to be non-leaf.
    is_leaf_logit = self.leaf_classifier(node_latent)
    is_leaf_loss = self.isLeafLossEstimator(
        is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

    # Semantic targets/logits of matched slots; matched slots should exist.
    child_sem_gt_labels = []
    child_sem_pred_logits = []
    child_exists_gt = torch.zeros_like(child_exists_logits)
    for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx):
        child_sem_gt_labels.append(gt_node.children[gt_idx].get_semantic_id())
        child_sem_pred_logits.append(child_sem_logits[0, pred_idx, :].view(1, -1))
        child_exists_gt[:, pred_idx, :] = 1

    # Train semantic labels.
    child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
    child_sem_gt_labels = torch.tensor(
        child_sem_gt_labels, dtype=torch.int64,
        device=child_sem_pred_logits.device)
    semantic_loss = self.semCELoss(child_sem_pred_logits, child_sem_gt_labels).sum()

    # Push the scales of unmatched predicted children toward zero.
    matched_pred_set = {int(idx) for idx in matched_pred_idx}
    unmatched_scales = [child_pred_geo_scale[i:i + 1]
                        for i in range(num_pred) if i not in matched_pred_set]
    if unmatched_scales:
        unused_geo_loss = torch.cat(unmatched_scales, dim=0).pow(2).sum()
    else:
        unused_geo_loss = 0.0

    # Train existence scores.
    child_exists_loss = F.binary_cross_entropy_with_logits(
        input=child_exists_logits, target=child_exists_gt, reduction='none').sum()

    # Train edge-existence scores.
    edge_exists_loss = F.binary_cross_entropy_with_logits(
        input=edge_exists_logits, target=edge_exists_gt, reduction='none').sum()
    # Rescale to make it comparable to the other losses, which are in the
    # order of the number of child nodes.
    edge_exists_loss = edge_exists_loss / (
        edge_exists_gt.shape[2] * edge_exists_gt.shape[3])

    # Symmetry loss between matched children at this level.
    sym_loss = 0
    if len(sym_from) > 0:
        sym_from_th = torch.tensor(sym_from, dtype=torch.long, device=self.conf.device)
        pc_from = child_pred_geo[sym_from_th, :]
        sym_to_th = torch.tensor(sym_to, dtype=torch.long, device=self.conf.device)
        pc_to = child_pred_geo[sym_to_th, :]
        sym_mat_th = torch.cat(sym_mat, dim=0)
        transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat_th[:, :, :3], 1, 2)) + \
            sym_mat_th[:, :, 3].unsqueeze(dim=1).repeat(1, pc_from.size(1), 1)
        loss = self.chamferDist(transformed_pc_from, pc_to) * 2
        sym_loss = loss.sum()

    # Adjacency loss between matched children at this level.
    adj_loss = 0
    if len(adj_from) > 0:
        adj_from_th = torch.tensor(adj_from, dtype=torch.long, device=self.conf.device)
        pc_from = child_pred_geo[adj_from_th, :]
        adj_to_th = torch.tensor(adj_to, dtype=torch.long, device=self.conf.device)
        pc_to = child_pred_geo[adj_to_th, :]
        dist1, dist2 = self.chamferLoss(pc_from, pc_to)
        loss = (dist1.min(dim=1)[0] + dist2.min(dim=1)[0])
        adj_loss = loss.sum()

    # Recurse into matched children and aggregate their losses.
    pred2allgeo = dict()
    pred2allleafgeo = dict()
    for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx):
        child_losses, child_all_geo, child_all_leaf_geo = self.node_recon_loss(
            child_feats[:, pred_idx, :], gt_node.children[gt_idx])
        pred2allgeo[pred_idx] = child_all_geo
        pred2allleafgeo[pred_idx] = child_all_leaf_geo
        all_geo.append(child_all_geo)
        all_leaf_geo.append(child_all_leaf_geo)
        latent_loss = latent_loss + child_losses['latent']
        geo_loss = geo_loss + child_losses['geo']
        center_loss = center_loss + child_losses['center']
        scale_loss = scale_loss + child_losses['scale']
        is_leaf_loss = is_leaf_loss + child_losses['leaf']
        child_exists_loss = child_exists_loss + child_losses['exists']
        semantic_loss = semantic_loss + child_losses['semantic']
        edge_exists_loss = edge_exists_loss + child_losses['edge_exists']
        sym_loss = sym_loss + child_losses['sym']
        adj_loss = adj_loss + child_losses['adj']

    # For sym-edges, train the descendants (index 1 onward) to be symmetric.
    for i in range(len(sym_from)):
        s1 = pred2allgeo[sym_from[i]].size(0)
        s2 = pred2allgeo[sym_to[i]].size(0)
        if s1 > 1 and s2 > 1:
            pc_from = pred2allgeo[sym_from[i]][1:, :, :].view(-1, 3)
            pc_to = pred2allgeo[sym_to[i]][1:, :, :].view(-1, 3)
            transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat[i][0, :, :3], 0, 1)) + \
                sym_mat[i][0, :, 3].unsqueeze(dim=0).repeat(pc_from.size(0), 1)
            dist1, dist2 = self.chamferLoss(
                transformed_pc_from.view(1, -1, 3), pc_to.view(1, -1, 3))
            sym_loss += (dist1.mean() + dist2.mean()) * (s1 + s2) / 2

    # For adj-edges, train leaf geometry in non-trivial subtrees to touch.
    for i in range(len(adj_from)):
        if pred2allgeo[adj_from[i]].size(0) > pred2allleafgeo[adj_from[i]].size(0) \
                or pred2allgeo[adj_to[i]].size(0) > pred2allleafgeo[adj_to[i]].size(0):
            pc_from = pred2allleafgeo[adj_from[i]].view(1, -1, 3)
            pc_to = pred2allleafgeo[adj_to[i]].view(1, -1, 3)
            dist1, dist2 = self.chamferLoss(pc_from, pc_to)
            adj_loss += dist1.min() + dist2.min()

    return {
        'leaf': is_leaf_loss,
        'geo': geo_loss + unused_geo_loss,
        'center': center_loss,
        'scale': scale_loss,
        'latent': latent_loss,
        'exists': child_exists_loss,
        'semantic': semantic_loss,
        'edge_exists': edge_exists_loss,
        'sym': sym_loss,
        'adj': adj_loss
    }, torch.cat(all_geo, dim=0), torch.cat(all_leaf_geo, dim=0)
def compute_struct_diff(gt_node, pred_node):
    """Compare a predicted hierarchy with the ground truth, structurally and by edges.

    Children are matched per shared semantic ``label``: trivially when there
    is one candidate per side, otherwise by ``utils.linear_assignment`` on
    ``boxLoss`` costs. Side effect: every matched predicted child has its
    ``part_id`` overwritten with the matched ground-truth child's ``part_id``.

    Returns:
        Tuple ``(struct_diff, edge_both, edge_gt, edge_pred, gt_binary_diff,
        gt_binary_tot)``:

        * ``struct_diff``: boxes in unmatched subtrees. A leaf vs. non-leaf
          pair counts ``len(boxes()) - 1``.
        * ``edge_both``: ground-truth edges (checked in both directions) that
          also exist in the prediction between the matched children.
        * ``edge_gt`` / ``edge_pred``: edge counts on each side. Ground-truth
          edges are counted twice.
        * ``gt_binary_diff`` / ``gt_binary_tot``: summed geometric error of
          ground-truth edges evaluated on the matched predicted boxes, and the
          number of edges evaluated.

    Raises:
        ValueError: if a ground-truth edge has a type id other than
            0 (ADJ), 1 (ROT_SYM), 2 (TRANS_SYM) or 3 (REF_SYM).
    """
    if gt_node.is_leaf:
        if pred_node.is_leaf:
            return 0, 0, 0, 0, 0, 0
        return len(pred_node.boxes()) - 1, 0, 0, pred_node.get_subtree_edge_count(), 0, 0
    if pred_node.is_leaf:
        return len(gt_node.boxes()) - 1, 0, gt_node.get_subtree_edge_count() * 2, 0, 0, 0

    intersect_sem = set.intersection(
        {node.label for node in gt_node.children},
        {node.label for node in pred_node.children})

    gt_cnodes_per_sem = {}
    for node_id, gt_cnode in enumerate(gt_node.children):
        if gt_cnode.label in intersect_sem:
            gt_cnodes_per_sem.setdefault(gt_cnode.label, []).append(node_id)
    pred_cnodes_per_sem = {}
    for node_id, pred_cnode in enumerate(pred_node.children):
        if pred_cnode.label in intersect_sem:
            pred_cnodes_per_sem.setdefault(pred_cnode.label, []).append(node_id)

    # Ordered matched pairs, plus a gt -> pred lookup for the edge evaluation.
    matched_gt_idx = []
    matched_pred_idx = []
    matched_gt2pred = {}
    for sem in intersect_sem:
        gt_ids = gt_cnodes_per_sem[sem]
        pred_ids = pred_cnodes_per_sem[sem]
        gt_boxes = torch.cat(
            [gt_node.children[cid].get_box_quat() for cid in gt_ids], dim=0).to(device)
        pred_boxes = torch.cat(
            [pred_node.children[cid].get_box_quat() for cid in pred_ids], dim=0).to(device)
        num_gt = gt_boxes.size(0)
        num_pred = pred_boxes.size(0)
        if num_gt == 1 and num_pred == 1:
            cur_matched_gt_idx = [0]
            cur_matched_pred_idx = [0]
        else:
            gt_boxes_tiled = gt_boxes.unsqueeze(dim=1).repeat(1, num_pred, 1)
            pred_boxes_tiled = pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
            dmat = boxLoss(gt_boxes_tiled.view(-1, 10),
                           pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_pred).cpu()
            _, cur_matched_gt_idx, cur_matched_pred_idx = utils.linear_assignment(dmat)
        for gt_local, pred_local in zip(cur_matched_gt_idx, cur_matched_pred_idx):
            gt_id = gt_ids[gt_local]
            pred_id = pred_ids[pred_local]
            matched_gt_idx.append(gt_id)
            matched_pred_idx.append(pred_id)
            matched_gt2pred[gt_id] = pred_id

    matched_pred_set = set(matched_pred_idx)

    struct_diff = 0.0
    edge_both = 0
    edge_gt = 0
    edge_pred = 0
    gt_binary_diff = 0.0
    gt_binary_tot = 0

    for i, child in enumerate(gt_node.children):
        if i not in matched_gt2pred:
            struct_diff += len(child.boxes())
            edge_gt += child.get_subtree_edge_count() * 2
    for i, child in enumerate(pred_node.children):
        if i not in matched_pred_set:
            struct_diff += len(child.boxes())
            edge_pred += child.get_subtree_edge_count()
    for gt_id, pred_id in zip(matched_gt_idx, matched_pred_idx):
        (cur_struct_diff, cur_edge_both, cur_edge_gt, cur_edge_pred,
         cur_gt_binary_diff, cur_gt_binary_tot) = compute_struct_diff(
             gt_node.children[gt_id], pred_node.children[pred_id])
        gt_binary_diff += cur_gt_binary_diff
        gt_binary_tot += cur_gt_binary_tot
        struct_diff += cur_struct_diff
        edge_both += cur_edge_both
        edge_gt += cur_edge_gt
        edge_pred += cur_edge_pred
        pred_node.children[pred_id].part_id = gt_node.children[gt_id].part_id

    pred_edge_list = pred_node.edges if pred_node.edges is not None else []
    edge_pred += len(pred_edge_list)

    if gt_node.edges is not None:
        edge_gt += len(gt_node.edges) * 2

        # Dense lookup of predicted edges by (part_a, part_b, type id).
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        pred_edges = np.zeros(
            (conf.max_child_num, conf.max_child_num, len(conf.edge_types)), dtype=bool)
        for edge in pred_edge_list:
            edge_type_id = conf.edge_types.index(edge['type'])
            pred_edges[edge['part_a'], edge['part_b'], edge_type_id] = True

        for edge in gt_node.edges:
            gt_part_a_id = edge['part_a']
            gt_part_b_id = edge['part_b']
            edge_type_id = conf.edge_types.index(edge['type'])
            if gt_part_a_id not in matched_gt2pred or gt_part_b_id not in matched_gt2pred:
                continue
            pred_part_a_id = matched_gt2pred[gt_part_a_id]
            pred_part_b_id = matched_gt2pred[gt_part_b_id]
            edge_both += int(pred_edges[pred_part_a_id, pred_part_b_id, edge_type_id])
            edge_both += int(pred_edges[pred_part_b_id, pred_part_a_id, edge_type_id])

            # Evaluate the gt edge on point clouds sampled from the matched
            # predicted boxes.
            obb1 = pred_node.children[pred_part_a_id].box.cpu().numpy()
            obb_quat1 = pred_node.children[pred_part_a_id].get_box_quat().cpu().numpy()
            mesh_v1, mesh_f1 = utils.gen_obb_mesh(obb1)
            pc1 = utils.sample_pc(mesh_v1, mesh_f1, n_points=500)
            pc1 = torch.tensor(pc1, dtype=torch.float32, device=device)
            obb2 = pred_node.children[pred_part_b_id].box.cpu().numpy()
            obb_quat2 = pred_node.children[pred_part_b_id].get_box_quat().cpu().numpy()
            mesh_v2, mesh_f2 = utils.gen_obb_mesh(obb2)
            pc2 = utils.sample_pc(mesh_v2, mesh_f2, n_points=500)
            pc2 = torch.tensor(pc2, dtype=torch.float32, device=device)

            if edge_type_id == 0:  # ADJ
                dist1, dist2 = chamferLoss(pc1.view(1, -1, 3), pc2.view(1, -1, 3))
                gt_binary_diff += (dist1.sqrt().min().item() + dist2.sqrt().min().item()) / 2
            else:  # SYM
                if edge_type_id == 2:  # TRANS_SYM
                    mat1to2, _ = compute_sym.compute_trans_sym(
                        obb_quat1.reshape(-1), obb_quat2.reshape(-1))
                elif edge_type_id == 3:  # REF_SYM
                    mat1to2, _ = compute_sym.compute_ref_sym(
                        obb_quat1.reshape(-1), obb_quat2.reshape(-1))
                elif edge_type_id == 1:  # ROT_SYM
                    mat1to2, _ = compute_sym.compute_rot_sym(
                        obb_quat1.reshape(-1), obb_quat2.reshape(-1))
                else:
                    # An ``assert`` on a non-empty string never fires, so this
                    # must be an explicit raise.
                    raise ValueError('ERROR: unknown symmetry type: %s' % edge['type'])
                mat1to2 = torch.tensor(mat1to2, dtype=torch.float32, device=device)
                transformed_pc1 = pc1.matmul(torch.transpose(mat1to2[:, :3], 0, 1)) + \
                    mat1to2[:, 3].unsqueeze(dim=0).repeat(pc1.size(0), 1)
                dist1, dist2 = chamferLoss(
                    transformed_pc1.view(1, -1, 3), pc2.view(1, -1, 3))
                loss = (dist1.sqrt().mean() + dist2.sqrt().mean()) / 2
                gt_binary_diff += loss.item()
            gt_binary_tot += 1

    return struct_diff, edge_both, edge_gt, edge_pred, gt_binary_diff, gt_binary_tot
def node_recon_loss(self, node_latent, gt_node):
    """Recursively compute reconstruction losses for the subtree rooted at ``gt_node``.

    The decoded representation of this node (``node_latent``) is compared to
    the ground-truth node.

    * Leaf: only this node's own box (box + anchor losses) and the leaf
      classifier are supervised.
    * Non-leaf: the predicted child slots are matched to the ground-truth
      children with ``linear_assignment`` on a pairwise ``boxLossEstimator``
      cost matrix. The method then supervises child existence, child semantic
      labels and edge existence. It adds symmetry/adjacency consistency losses
      on the predicted child boxes and recurses into every matched child.

    Args:
        node_latent: latent feature of this node. ``node_latent.size(1)`` is
            used as the per-child feature length, so presumably shaped
            (1, feature_len) -- TODO confirm against the caller.
        gt_node: ground-truth tree node. Uses ``is_leaf``, ``children``,
            ``get_box_quat()``, ``edge_tensors()``, and ``get_semantic_id()``
            on its children.

    Returns:
        A 3-tuple ``(losses, all_boxes, all_leaf_boxes)``:

        - ``losses``: dict with keys 'box', 'leaf', 'anchor', 'exists',
          'semantic', 'edge_exists', 'sym', 'adj'. Each entry is this node's
          term plus the same entry from every matched child. A leaf returns
          zeros for everything except 'box', 'leaf' and 'anchor'. For a
          non-leaf, 'box' also includes a small penalty on unmatched child
          slots.
        - ``all_boxes``: predicted boxes of this node and of its matched
          descendants, concatenated along dim 0. This node's own box comes
          first; the subtree symmetry loss relies on that when it slices
          ``[1:, :]``.
        - ``all_leaf_boxes``: predicted boxes of the leaves of the matched
          subtree, concatenated along dim 0.
    """
    if gt_node.is_leaf:
        # Leaf: supervise this node's box against the gt box and push the
        # leaf classifier toward gt_node.is_leaf; structural losses are zero.
        box = self.box_decoder(node_latent)
        box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
        anchor_loss = self.anchorLossEstimator(box, gt_node.get_box_quat().view(1, -1))
        is_leaf_logit = self.leaf_classifier(node_latent)
        is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
        # A leaf's box is both its "all boxes" and its "leaf boxes" output.
        return {'box': box_loss, 'leaf': is_leaf_loss, 'anchor': anchor_loss,
                'exists': torch.zeros_like(box_loss), 'semantic': torch.zeros_like(box_loss),
                'edge_exists': torch.zeros_like(box_loss),
                'sym': torch.zeros_like(box_loss), 'adj': torch.zeros_like(box_loss)}, box, box
    else:
        # Decode per-child-slot features plus semantic, existence and
        # edge-existence logits for all child slots at once.
        child_feats, child_sem_logits, child_exists_logits, edge_exists_logits = \
            self.child_decoder(node_latent)

        # generate box prediction for each child
        feature_len = node_latent.size(1)
        child_pred_boxes = self.box_decoder(child_feats.view(-1, feature_len))
        num_child_parts = child_pred_boxes.size(0)

        # perform hungarian matching between pred boxes and gt boxes
        # (gradients are disabled: the matching only picks correspondences)
        with torch.no_grad():
            child_gt_boxes = torch.cat([child_node.get_box_quat().view(1, -1) for child_node in gt_node.children], dim=0)
            num_gt = child_gt_boxes.size(0)

            # Tile both sets so every (gt, pred) pair is compared; boxes are
            # flattened to 10 values each for the pairwise box loss.
            child_pred_boxes_tiled = child_pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
            child_gt_boxes_tiled = child_gt_boxes.unsqueeze(dim=1).repeat(1, num_child_parts, 1)

            # Cost matrix with gt children along dim 1 and pred slots along dim 2.
            dist_mat = self.boxLossEstimator(child_gt_boxes_tiled.view(-1, 10), child_pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_child_parts)

            _, matched_gt_idx, matched_pred_idx = linear_assignment(dist_mat)

        # get edge ground truth
        edge_type_list_gt, edge_indices_gt = gt_node.edge_tensors(
            edge_types=self.conf.edge_types, device=child_feats.device, type_onehot=False)

        # Map each matched gt child index to the pred slot it was assigned to.
        gt2pred = {gt_idx: pred_idx for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx)}
        edge_exists_gt = torch.zeros_like(edge_exists_logits)

        # Endpoints (pred slot indices) of matched sym/adj edges. sym_mat holds
        # one 3x4 transform per sym edge, in the same order as sym_from/sym_to.
        # NOTE(review): sym_type is collected but not read in this method.
        sym_from = []; sym_to = []; sym_mat = []; sym_type = []; adj_from = []; adj_to = [];
        # Only the first half of the edge list is visited. This presumably
        # assumes edge_tensors() lists every edge twice (both directions) --
        # TODO confirm against Tree.edge_tensors.
        for i in range(edge_indices_gt.shape[1]//2):
            if edge_indices_gt[0, i, 0].item() not in gt2pred or edge_indices_gt[0, i, 1].item() not in gt2pred:
                """ one of the adjacent nodes of the current gt edge was not matched to any node in the prediction, ignore this edge """
                continue

            # correlate gt edges to pred edges
            edge_from_idx = gt2pred[edge_indices_gt[0, i, 0].item()]
            edge_to_idx = gt2pred[edge_indices_gt[0, i, 1].item()]
            # Mark the edge as existing in both directions between the two pred slots.
            edge_exists_gt[:, edge_from_idx, edge_to_idx, edge_type_list_gt[0:1, i]] = 1
            edge_exists_gt[:, edge_to_idx, edge_from_idx, edge_type_list_gt[0:1, i]] = 1

            # compute binary edge parameters for each matched pred edge
            if edge_type_list_gt[0, i].item() == 0: # ADJ
                adj_from.append(edge_from_idx)
                adj_to.append(edge_to_idx)
            else: # SYM
                # The transform is computed from the *predicted* boxes (detached,
                # on CPU/numpy); only mat1to2 is used below.
                if edge_type_list_gt[0, i].item() == 1: # ROT_SYM
                    mat1to2, mat2to1 = compute_sym.compute_rot_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                elif edge_type_list_gt[0, i].item() == 2: # TRANS_SYM
                    mat1to2, mat2to1 = compute_sym.compute_trans_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                else: # REF_SYM (any other non-zero type id also lands here)
                    mat1to2, mat2to1 = compute_sym.compute_ref_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                sym_from.append(edge_from_idx)
                sym_to.append(edge_to_idx)
                # NOTE(review): created on self.conf.device, which is assumed to
                # match the device of child_pred_boxes.
                sym_mat.append(torch.tensor(mat1to2, dtype=torch.float32, device=self.conf.device).view(1, 3, 4))
                sym_type.append(edge_type_list_gt[0, i].item())

        # train the current node to be non-leaf
        is_leaf_logit = self.leaf_classifier(node_latent)
        is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

        # train the current node box to gt
        # This node's own box is appended first, so it is row 0 of all_boxes.
        all_boxes = []; all_leaf_boxes = [];
        box = self.box_decoder(node_latent)
        all_boxes.append(box)
        box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
        anchor_loss = self.anchorLossEstimator(box, gt_node.get_box_quat().view(1, -1))

        # gather information
        # NOTE(review): child_box_gt / child_box_pred are filled but not read
        # later in this method.
        child_sem_gt_labels = []
        child_sem_pred_logits = []
        child_box_gt = []
        child_box_pred = []
        child_exists_gt = torch.zeros_like(child_exists_logits)
        for i in range(len(matched_gt_idx)):
            child_sem_gt_labels.append(gt_node.children[matched_gt_idx[i]].get_semantic_id())
            child_sem_pred_logits.append(child_sem_logits[0, matched_pred_idx[i], :].view(1, -1))
            child_box_gt.append(gt_node.children[matched_gt_idx[i]].get_box_quat().view(1, -1))
            child_box_pred.append(child_pred_boxes[matched_pred_idx[i], :].view(1, -1))
            # Matched pred slots should be predicted as existing; all others as absent.
            child_exists_gt[:, matched_pred_idx[i], :] = 1

        # train semantic labels
        # NOTE(review): torch.cat raises if no child was matched.
        child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
        child_sem_gt_labels = torch.tensor(child_sem_gt_labels, dtype=torch.int64, \
            device=child_sem_pred_logits.device)
        semantic_loss = self.semCELoss(child_sem_pred_logits, child_sem_gt_labels)
        semantic_loss = semantic_loss.sum()

        # train unused boxes to zeros: penalize (0.01 * sum of squares) columns
        # 3:6 of every unmatched pred slot (presumably the box size -- confirm)
        unmatched_boxes = []
        for i in range(num_child_parts):
            if i not in matched_pred_idx:
                unmatched_boxes.append(child_pred_boxes[i, 3:6].view(1, -1))
        if len(unmatched_boxes) > 0:
            unmatched_boxes = torch.cat(unmatched_boxes, dim=0)
            unused_box_loss = unmatched_boxes.pow(2).sum() * 0.01
        else:
            unused_box_loss = 0.0

        # train exist scores
        child_exists_loss = F.binary_cross_entropy_with_logits(\
            input=child_exists_logits, target=child_exists_gt, reduction='none')
        child_exists_loss = child_exists_loss.sum()

        # train edge exists scores
        edge_exists_loss = F.binary_cross_entropy_with_logits(\
            input=edge_exists_logits, target=edge_exists_gt, reduction='none')
        edge_exists_loss = edge_exists_loss.sum()
        # rescale to make it comparable to other losses,
        # which are in the order of the number of child nodes
        edge_exists_loss = edge_exists_loss / (edge_exists_gt.shape[2]*edge_exists_gt.shape[3])

        # compute and train binary losses
        # Symmetry: map the from-box's points through the sym transform and
        # compare with the to-box's points using a reweighted chamfer distance.
        sym_loss = 0
        if len(sym_from) > 0:
            sym_from_th = torch.tensor(sym_from, dtype=torch.long, device=self.conf.device)
            obb_from = child_pred_boxes[sym_from_th, :]
            # The weights are treated as constants (no gradient through them).
            with torch.no_grad():
                reweight_from = get_surface_reweighting_batch(obb_from[:, 3:6], self.unit_cube.size(0))
            # Outside no_grad on purpose: the point clouds carry gradients to the boxes.
            pc_from = transform_pc_batch(self.unit_cube, obb_from)
            sym_to_th = torch.tensor(sym_to, dtype=torch.long, device=self.conf.device)
            obb_to = child_pred_boxes[sym_to_th, :]
            with torch.no_grad():
                reweight_to = get_surface_reweighting_batch(obb_to[:, 3:6], self.unit_cube.size(0))
            pc_to = transform_pc_batch(self.unit_cube, obb_to)
            sym_mat_th = torch.cat(sym_mat, dim=0)
            # Apply each 3x4 transform: points @ M[:, :3]^T + M[:, 3].
            transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat_th[:, :, :3], 1, 2)) + \
                sym_mat_th[:, :, 3].unsqueeze(dim=1).repeat(1, pc_from.size(1), 1)
            dist1, dist2 = self.chamferLoss(transformed_pc_from, pc_to)
            # Weighted mean per edge; 1e-12 guards against a zero weight sum.
            loss1 = (dist1 * reweight_from).sum(dim=1) / (reweight_from.sum(dim=1) + 1e-12)
            loss2 = (dist2 * reweight_to).sum(dim=1) / (reweight_to.sum(dim=1) + 1e-12)
            loss = loss1 + loss2
            sym_loss = loss.sum()

        # Adjacency: for each adj edge, penalize the smallest chamfer distance
        # (in both directions) between the two boxes' point clouds, i.e. the
        # gap between their closest points.
        adj_loss = 0
        if len(adj_from) > 0:
            adj_from_th = torch.tensor(adj_from, dtype=torch.long, device=self.conf.device)
            obb_from = child_pred_boxes[adj_from_th, :]
            pc_from = transform_pc_batch(self.unit_cube, obb_from)
            adj_to_th = torch.tensor(adj_to, dtype=torch.long, device=self.conf.device)
            obb_to = child_pred_boxes[adj_to_th, :]
            pc_to = transform_pc_batch(self.unit_cube, obb_to)
            dist1, dist2 = self.chamferLoss(pc_from, pc_to)
            loss = (dist1.min(dim=1)[0] + dist2.min(dim=1)[0])
            adj_loss = loss.sum()

        # call children + aggregate losses
        # Recurse into each matched (pred slot, gt child) pair and keep the
        # child's subtree boxes keyed by pred slot for the subtree losses below.
        pred2allboxes = dict(); pred2allleafboxes = dict();
        for i in range(len(matched_gt_idx)):
            child_losses, child_all_boxes, child_all_leaf_boxes = self.node_recon_loss(
                child_feats[:, matched_pred_idx[i], :], gt_node.children[matched_gt_idx[i]])
            pred2allboxes[matched_pred_idx[i]] = child_all_boxes
            pred2allleafboxes[matched_pred_idx[i]] = child_all_leaf_boxes
            all_boxes.append(child_all_boxes)
            all_leaf_boxes.append(child_all_leaf_boxes)
            box_loss = box_loss + child_losses['box']
            anchor_loss = anchor_loss + child_losses['anchor']
            is_leaf_loss = is_leaf_loss + child_losses['leaf']
            child_exists_loss = child_exists_loss + child_losses['exists']
            semantic_loss = semantic_loss + child_losses['semantic']
            edge_exists_loss = edge_exists_loss + child_losses['edge_exists']
            sym_loss = sym_loss + child_losses['sym']
            adj_loss = adj_loss + child_losses['adj']

        # for sym-edges, train subtree to be symmetric
        # When both endpoints have descendants (more than one box), apply this
        # edge's transform to all descendant boxes of the from-subtree and
        # compare with the to-subtree. [1:, :] drops each child's own box
        # (row 0). The loss is scaled by the average subtree size.
        for i in range(len(sym_from)):
            s1 = pred2allboxes[sym_from[i]].size(0)
            s2 = pred2allboxes[sym_to[i]].size(0)
            if s1 > 1 and s2 > 1:
                obbs_from = pred2allboxes[sym_from[i]][1:, :]
                obbs_to = pred2allboxes[sym_to[i]][1:, :]
                pc_from = transform_pc_batch(self.unit_cube, obbs_from).view(-1, 3)
                pc_to = transform_pc_batch(self.unit_cube, obbs_to).view(-1, 3)
                transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat[i][0, :, :3], 0, 1)) + \
                    sym_mat[i][0, :, 3].unsqueeze(dim=0).repeat(pc_from.size(0), 1)
                dist1, dist2 = self.chamferLoss(transformed_pc_from.view(1, -1, 3), pc_to.view(1, -1, 3))
                sym_loss += (dist1.mean() + dist2.mean()) * (s1 + s2) / 2

        # for adj-edges, train leaf-nodes in subtrees to be adjacent
        # Only when at least one endpoint is a non-leaf (it has more boxes than
        # leaf boxes); two leaves were already handled by the adjacency loss above.
        for i in range(len(adj_from)):
            if pred2allboxes[adj_from[i]].size(0) > pred2allleafboxes[adj_from[i]].size(0) \
                    or pred2allboxes[adj_to[i]].size(0) > pred2allleafboxes[adj_to[i]].size(0):
                obbs_from = pred2allleafboxes[adj_from[i]]
                obbs_to = pred2allleafboxes[adj_to[i]]
                pc_from = transform_pc_batch(self.unit_cube, obbs_from).view(1, -1, 3)
                pc_to = transform_pc_batch(self.unit_cube, obbs_to).view(1, -1, 3)
                dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                adj_loss += dist1.min() + dist2.min()

        # The unmatched-slot penalty is folded into the returned 'box' loss.
        return {'box': box_loss + unused_box_loss, 'leaf': is_leaf_loss, 'anchor': anchor_loss,
                'exists': child_exists_loss, 'semantic': semantic_loss,
                'edge_exists': edge_exists_loss, 'sym': sym_loss, 'adj': adj_loss}, \
            torch.cat(all_boxes, dim=0), torch.cat(all_leaf_boxes, dim=0)