Example #1
def compute_binary_diff(pred_node):
    if pred_node.is_leaf:
        return 0, 0
    else:
        binary_diff = 0
        binary_tot = 0

        # all children
        for cnode in pred_node.children:
            cur_binary_diff, cur_binary_tot = compute_binary_diff(cnode)
            binary_diff += cur_binary_diff
            binary_tot += cur_binary_tot

        # current node
        if pred_node.edges is not None:
            for edge in pred_node.edges:
                pred_part_a_id = edge['part_a']
                obb1 = pred_node.children[pred_part_a_id].box.cpu().numpy()
                obb_quat1 = pred_node.children[pred_part_a_id].get_box_quat().cpu().numpy()
                mesh_v1, mesh_f1 = utils.gen_obb_mesh(obb1)
                pc1 = utils.sample_pc(mesh_v1, mesh_f1, n_points=500)
                pc1 = torch.tensor(pc1, dtype=torch.float32, device=device)
                pred_part_b_id = edge['part_b']
                obb2 = pred_node.children[pred_part_b_id].box.cpu().numpy()
                obb_quat2 = pred_node.children[pred_part_b_id].get_box_quat().cpu().numpy()
                mesh_v2, mesh_f2 = utils.gen_obb_mesh(obb2)
                pc2 = utils.sample_pc(mesh_v2, mesh_f2, n_points=500)
                pc2 = torch.tensor(pc2, dtype=torch.float32, device=device)
                if edge['type'] == 'ADJ':
                    dist1, dist2 = chamferLoss(pc1.view(1, -1, 3), pc2.view(1, -1, 3))
                    binary_diff += (dist1.sqrt().min().item() + dist2.sqrt().min().item()) / 2
                elif 'SYM' in edge['type']:
                    if edge['type'] == 'TRANS_SYM':
                        mat1to2, _ = compute_sym.compute_trans_sym(obb_quat1.reshape(-1), obb_quat2.reshape(-1))
                    elif edge['type'] == 'REF_SYM':
                        mat1to2, _ = compute_sym.compute_ref_sym(obb_quat1.reshape(-1), obb_quat2.reshape(-1))
                    elif edge['type'] == 'ROT_SYM':
                        mat1to2, _ = compute_sym.compute_rot_sym(obb_quat1.reshape(-1), obb_quat2.reshape(-1))
                    else:
                        raise ValueError('ERROR: unknown symmetry type: %s' % edge['type'])
                    mat1to2 = torch.tensor(mat1to2, dtype=torch.float32, device=device)
                    transformed_pc1 = pc1.matmul(torch.transpose(mat1to2[:, :3], 0, 1)) + \
                            mat1to2[:, 3].unsqueeze(dim=0).repeat(pc1.size(0), 1)
                    dist1, dist2 = chamferLoss(transformed_pc1.view(1, -1, 3), pc2.view(1, -1, 3))
                    loss = (dist1.sqrt().mean() + dist2.sqrt().mean()) / 2
                    binary_diff += loss.item()
                else:
                    raise ValueError('ERROR: unknown edge type: %s' % edge['type'])
                binary_tot += 1

        return binary_diff, binary_tot
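
The SYM branches above apply a 3x4 map [R | t] returned by compute_sym to one sampled point cloud before measuring Chamfer distance against the other. A minimal self-contained sketch of that step (the helper name is illustrative, not part of the original code):

import torch

def apply_sym_transform(pc, mat1to2):
    # pc: (N, 3) point cloud; mat1to2: (3, 4) transform [R | t]
    # mirrors the `pc1.matmul(R^T) + t` pattern used in compute_binary_diff
    rot = mat1to2[:, :3]            # (3, 3) linear part
    trans = mat1to2[:, 3]           # (3,) translation
    return pc.matmul(rot.t()) + trans.unsqueeze(dim=0)

# sanity check: the identity transform leaves the points unchanged
pts = torch.rand(500, 3)
identity = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=1)
assert torch.allclose(apply_sym_transform(pts, identity), pts)
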
Example #2
    def node_recon_loss(self, node_latent, gt_node):
        gt_geo = gt_node.geo
        gt_geo_feat = gt_node.geo_feat
        gt_center = gt_geo.mean(dim=1)
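        # gt_scale: radius of the GT point cloud, i.e. the largest point-to-centroid distance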
        gt_scale = (gt_geo - gt_center.unsqueeze(dim=1).repeat(1, self.conf.num_point, 1)).pow(2).sum(dim=2).max(dim=1)[0].sqrt().view(1, 1)

        geo_local, geo_center, geo_scale, geo_feat = self.node_decoder(node_latent)
        geo_global = self.geoToGlobal(geo_local, geo_center, geo_scale)

        # geo loss for the current part
        latent_loss = self.mseLoss(geo_feat, gt_geo_feat).mean()
        center_loss = self.mseLoss(geo_center, gt_center).mean()
        scale_loss = self.mseLoss(geo_scale, gt_scale).mean()
        geo_loss = self.chamferDist(geo_global, gt_geo)

        if gt_node.is_leaf:
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
            return {'leaf': is_leaf_loss, 'geo': geo_loss, 'center': center_loss, 'scale': scale_loss, 
                    'latent': latent_loss, 'exists': torch.zeros_like(geo_loss), 
                    'semantic': torch.zeros_like(geo_loss), 'edge_exists': torch.zeros_like(geo_loss), 
                    'sym': torch.zeros_like(geo_loss), 'adj': torch.zeros_like(geo_loss)}, geo_global, geo_global
        else:
            child_feats, child_sem_logits, child_exists_logits, edge_exists_logits = \
                    self.child_decoder(node_latent)

            # generate geo prediction for each child
            feature_len = node_latent.size(1)

            all_geo = []
            all_leaf_geo = []
            all_geo.append(geo_global)

            child_pred_geo_local, child_pred_geo_center, child_pred_geo_scale, _ = \
                    self.node_decoder(child_feats.view(-1, feature_len))
            child_pred_geo = self.geoToGlobal(child_pred_geo_local, child_pred_geo_center, child_pred_geo_scale)
            num_pred = child_pred_geo.size(0)

            # perform hungarian matching between pred geo and gt geo
            with torch.no_grad():
                child_gt_geo = torch.cat([child_node.geo for child_node in gt_node.children], dim=0)
                num_gt = child_gt_geo.size(0)

                child_pred_geo_tiled = child_pred_geo.unsqueeze(dim=0).repeat(num_gt, 1, 1, 1)
                child_gt_geo_tiled = child_gt_geo.unsqueeze(dim=1).repeat(1, num_pred, 1, 1)

                dist_mat = self.chamferDist(child_pred_geo_tiled.view(-1, self.conf.num_point, 3), \
                        child_gt_geo_tiled.view(-1, self.conf.num_point, 3)).view(1, num_gt, num_pred)
                _, matched_gt_idx, matched_pred_idx = linear_assignment(dist_mat)

                # get edge ground truth
                edge_type_list_gt, edge_indices_gt = gt_node.edge_tensors(
                    edge_types=self.conf.edge_types, device=child_feats.device, type_onehot=False)

                gt2pred = {gt_idx: pred_idx for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx)}
                edge_exists_gt = torch.zeros_like(edge_exists_logits)

                sym_from, sym_to, sym_mat, sym_type = [], [], [], []
                adj_from, adj_to = [], []
                for i in range(edge_indices_gt.shape[1]//2):
                    if edge_indices_gt[0, i, 0].item() not in gt2pred or edge_indices_gt[0, i, 1].item() not in gt2pred:
                        """
                            one of the adjacent nodes of the current gt edge was not matched 
                            to any node in the prediction, ignore this edge
                        """
                        continue

                    # correlate gt edges to pred edges
                    edge_from_idx = gt2pred[edge_indices_gt[0, i, 0].item()]
                    edge_to_idx = gt2pred[edge_indices_gt[0, i, 1].item()]
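                    # register the matched edge in both directions so the target tensor stays symmetric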
                    edge_exists_gt[:, edge_from_idx, edge_to_idx, edge_type_list_gt[0:1, i]] = 1
                    edge_exists_gt[:, edge_to_idx, edge_from_idx, edge_type_list_gt[0:1, i]] = 1

                    # compute binary edge parameters for each matched pred edge
                    if edge_type_list_gt[0, i].item() == 0: # ADJ
                        adj_from.append(edge_from_idx)
                        adj_to.append(edge_to_idx)
                    else:   # SYM
                        obb1 = compute_sym.compute_obb(child_pred_geo[edge_from_idx].cpu().detach().numpy())
                        obb2 = compute_sym.compute_obb(child_pred_geo[edge_to_idx].cpu().detach().numpy())
                        if edge_type_list_gt[0, i].item() == 1: # ROT_SYM
                            mat1to2, mat2to1 = compute_sym.compute_rot_sym(obb1, obb2)
                        elif edge_type_list_gt[0, i].item() == 2: # TRANS_SYM
                            mat1to2, mat2to1 = compute_sym.compute_trans_sym(obb1, obb2)
                        else:   # REF_SYM
                            mat1to2, mat2to1 = compute_sym.compute_ref_sym(obb1, obb2)
                        sym_from.append(edge_from_idx)
                        sym_to.append(edge_to_idx)
                        sym_mat.append(torch.tensor(mat1to2, dtype=torch.float32, device=self.conf.device).view(1, 3, 4))
                        sym_type.append(edge_type_list_gt[0, i].item())

            # train the current node to be non-leaf
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

            # gather information
            child_sem_gt_labels = []
            child_sem_pred_logits = []
            child_exists_gt = torch.zeros_like(child_exists_logits)
            for i in range(len(matched_gt_idx)):
                child_sem_gt_labels.append(gt_node.children[matched_gt_idx[i]].get_semantic_id())
                child_sem_pred_logits.append(child_sem_logits[0, matched_pred_idx[i], :].view(1, -1))
                child_exists_gt[:, matched_pred_idx[i], :] = 1
 
            # train semantic labels
            child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
            child_sem_gt_labels = torch.tensor(child_sem_gt_labels, dtype=torch.int64, device=child_sem_pred_logits.device)
            semantic_loss = self.semCELoss(child_sem_pred_logits, child_sem_gt_labels)
            semantic_loss = semantic_loss.sum()
           
            # train unused geo to zero
            unmatched_scales = []
            for i in range(num_pred):
                if i not in matched_pred_idx:
                    unmatched_scales.append(child_pred_geo_scale[i:i+1])
            if len(unmatched_scales) > 0:
                unmatched_scales = torch.cat(unmatched_scales, dim=0)
                unused_geo_loss = unmatched_scales.pow(2).sum()
            else:
                unused_geo_loss = 0.0

            # train exist scores
            child_exists_loss = F.binary_cross_entropy_with_logits(\
                input=child_exists_logits, target=child_exists_gt, reduction='none')
            child_exists_loss = child_exists_loss.sum()

            # train edge exists scores
            edge_exists_loss = F.binary_cross_entropy_with_logits(\
                    input=edge_exists_logits, target=edge_exists_gt, reduction='none')
            edge_exists_loss = edge_exists_loss.sum()
            # rescale to make it comparable to other losses, 
            # which are in the order of the number of child nodes
            edge_exists_loss = edge_exists_loss / (edge_exists_gt.shape[2]*edge_exists_gt.shape[3])

            # compute and train binary losses
            sym_loss = 0
            if len(sym_from) > 0:
                sym_from_th = torch.tensor(sym_from, dtype=torch.long, device=self.conf.device)
                pc_from = child_pred_geo[sym_from_th, :]
                sym_to_th = torch.tensor(sym_to, dtype=torch.long, device=self.conf.device)
                pc_to = child_pred_geo[sym_to_th, :]
                sym_mat_th = torch.cat(sym_mat, dim=0)
                transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat_th[:, :, :3], 1, 2)) + \
                        sym_mat_th[:, :, 3].unsqueeze(dim=1).repeat(1, pc_from.size(1), 1)
                loss = self.chamferDist(transformed_pc_from, pc_to) * 2
                sym_loss = loss.sum()

            adj_loss = 0
            if len(adj_from) > 0:
                adj_from_th = torch.tensor(adj_from, dtype=torch.long, device=self.conf.device)
                pc_from = child_pred_geo[adj_from_th, :]
                adj_to_th = torch.tensor(adj_to, dtype=torch.long, device=self.conf.device)
                pc_to = child_pred_geo[adj_to_th, :]
                dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                loss = (dist1.min(dim=1)[0] + dist2.min(dim=1)[0])
                adj_loss = loss.sum()

            # call children + aggregate losses
            pred2allgeo = dict()
            pred2allleafgeo = dict()
            for i in range(len(matched_gt_idx)):
                child_losses, child_all_geo, child_all_leaf_geo = self.node_recon_loss(
                    child_feats[:, matched_pred_idx[i], :], gt_node.children[matched_gt_idx[i]])
                pred2allgeo[matched_pred_idx[i]] = child_all_geo
                pred2allleafgeo[matched_pred_idx[i]] = child_all_leaf_geo
                all_geo.append(child_all_geo)
                all_leaf_geo.append(child_all_leaf_geo)
                latent_loss = latent_loss + child_losses['latent']
                geo_loss = geo_loss + child_losses['geo']
                center_loss = center_loss + child_losses['center']
                scale_loss = scale_loss + child_losses['scale']
                is_leaf_loss = is_leaf_loss + child_losses['leaf']
                child_exists_loss = child_exists_loss + child_losses['exists']
                semantic_loss = semantic_loss + child_losses['semantic']
                edge_exists_loss = edge_exists_loss + child_losses['edge_exists']
                sym_loss = sym_loss + child_losses['sym']
                adj_loss = adj_loss + child_losses['adj']

            # for sym-edges, train subtree to be symmetric
            for i in range(len(sym_from)):
                s1 = pred2allgeo[sym_from[i]].size(0)
                s2 = pred2allgeo[sym_to[i]].size(0)
                if s1 > 1 and s2 > 1:
                    pc_from = pred2allgeo[sym_from[i]][1:, :, :].view(-1, 3)
                    pc_to = pred2allgeo[sym_to[i]][1:, :, :].view(-1, 3)
                    transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat[i][0, :, :3], 0, 1)) + \
                            sym_mat[i][0, :, 3].unsqueeze(dim=0).repeat(pc_from.size(0), 1)
                    dist1, dist2 = self.chamferLoss(transformed_pc_from.view(1, -1, 3), pc_to.view(1, -1, 3))
                    sym_loss += (dist1.mean() + dist2.mean()) * (s1 + s2) / 2

            # for adj-edges, train leaf-nodes in subtrees to be adjacent
            for i in range(len(adj_from)):
                if pred2allgeo[adj_from[i]].size(0) > pred2allleafgeo[adj_from[i]].size(0) \
                        or pred2allgeo[adj_to[i]].size(0) > pred2allleafgeo[adj_to[i]].size(0):
                    pc_from = pred2allleafgeo[adj_from[i]].view(1, -1, 3)
                    pc_to = pred2allleafgeo[adj_to[i]].view(1, -1, 3)
                    dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                    adj_loss += dist1.min() + dist2.min()

            return {'leaf': is_leaf_loss, 'geo': geo_loss + unused_geo_loss, 'center': center_loss, 
                    'scale': scale_loss, 'latent': latent_loss, 'exists': child_exists_loss, 
                    'semantic': semantic_loss, 'edge_exists': edge_exists_loss, 
                    'sym': sym_loss, 'adj': adj_loss}, torch.cat(all_geo, dim=0), torch.cat(all_leaf_geo, dim=0)
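
Both reconstruction losses rely on a linear_assignment helper to match predicted children to ground-truth children by pairwise cost. A hedged sketch of such a helper built on SciPy's Hungarian solver, matching the call pattern used above (the exact implementation in the original utils module may differ):

from scipy.optimize import linear_sum_assignment

def linear_assignment(dist_mat):
    # dist_mat: (1, num_gt, num_pred) pairwise cost matrix
    cost = dist_mat[0].detach().cpu().numpy()
    gt_idx, pred_idx = linear_sum_assignment(cost)
    total_cost = cost[gt_idx, pred_idx].sum()
    return total_cost, list(gt_idx), list(pred_idx)
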
Example #3
def compute_struct_diff(gt_node, pred_node):
    if gt_node.is_leaf:
        if pred_node.is_leaf:
            return 0, 0, 0, 0, 0, 0
        else:
            return len(pred_node.boxes()) - 1, 0, 0, pred_node.get_subtree_edge_count(), 0, 0
    else:
        if pred_node.is_leaf:
            return len(gt_node.boxes()) - 1, 0, gt_node.get_subtree_edge_count() * 2, 0, 0, 0
        else:
            gt_sem = set([node.label for node in gt_node.children])
            pred_sem = set([node.label for node in pred_node.children])
            intersect_sem = set.intersection(gt_sem, pred_sem)

            gt_cnodes_per_sem = dict()
            for node_id, gt_cnode in enumerate(gt_node.children):
                if gt_cnode.label in intersect_sem:
                    if gt_cnode.label not in gt_cnodes_per_sem:
                        gt_cnodes_per_sem[gt_cnode.label] = []
                    gt_cnodes_per_sem[gt_cnode.label].append(node_id)

            pred_cnodes_per_sem = dict()
            for node_id, pred_cnode in enumerate(pred_node.children):
                if pred_cnode.label in intersect_sem:
                    if pred_cnode.label not in pred_cnodes_per_sem:
                        pred_cnodes_per_sem[pred_cnode.label] = []
                    pred_cnodes_per_sem[pred_cnode.label].append(node_id)

            matched_gt_idx = []
            matched_pred_idx = []
            matched_gt2pred = np.zeros((conf.max_child_num), dtype=np.int32)
            for sem in intersect_sem:
                gt_boxes = torch.cat([gt_node.children[cid].get_box_quat() for cid in gt_cnodes_per_sem[sem]], dim=0).to(device)
                pred_boxes = torch.cat([pred_node.children[cid].get_box_quat() for cid in pred_cnodes_per_sem[sem]], dim=0).to(device)

                num_gt = gt_boxes.size(0)
                num_pred = pred_boxes.size(0)

                if num_gt == 1 and num_pred == 1:
                    cur_matched_gt_idx = [0]
                    cur_matched_pred_idx = [0]
                else:
                    gt_boxes_tiled = gt_boxes.unsqueeze(dim=1).repeat(1, num_pred, 1)
                    pred_boxes_tiled = pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
                    dmat = boxLoss(gt_boxes_tiled.view(-1, 10), pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_pred).cpu()
                    _, cur_matched_gt_idx, cur_matched_pred_idx = utils.linear_assignment(dmat)

                for i in range(len(cur_matched_gt_idx)):
                    matched_gt_idx.append(gt_cnodes_per_sem[sem][cur_matched_gt_idx[i]])
                    matched_pred_idx.append(pred_cnodes_per_sem[sem][cur_matched_pred_idx[i]])
                    matched_gt2pred[gt_cnodes_per_sem[sem][cur_matched_gt_idx[i]]] = pred_cnodes_per_sem[sem][cur_matched_pred_idx[i]]

            struct_diff = 0.0
            edge_both = 0
            edge_gt = 0
            edge_pred = 0
            gt_binary_diff = 0.0
            gt_binary_tot = 0
            for i in range(len(gt_node.children)):
                if i not in matched_gt_idx:
                    struct_diff += len(gt_node.children[i].boxes())
                    edge_gt += gt_node.children[i].get_subtree_edge_count() * 2

            for i in range(len(pred_node.children)):
                if i not in matched_pred_idx:
                    struct_diff += len(pred_node.children[i].boxes())
                    edge_pred += pred_node.children[i].get_subtree_edge_count()

            for i in range(len(matched_gt_idx)):
                gt_id = matched_gt_idx[i]
                pred_id = matched_pred_idx[i]
                cur_struct_diff, cur_edge_both, cur_edge_gt, cur_edge_pred, cur_gt_binary_diff, cur_gt_binary_tot = compute_struct_diff(
                    gt_node.children[gt_id], pred_node.children[pred_id])
                gt_binary_diff += cur_gt_binary_diff
                gt_binary_tot += cur_gt_binary_tot
                struct_diff += cur_struct_diff
                edge_both += cur_edge_both
                edge_gt += cur_edge_gt
                edge_pred += cur_edge_pred
                pred_node.children[pred_id].part_id = gt_node.children[gt_id].part_id

            if pred_node.edges is not None:
                edge_pred += len(pred_node.edges)

            if gt_node.edges is not None:
                edge_gt += len(gt_node.edges) * 2
                pred_edges = np.zeros((conf.max_child_num, conf.max_child_num, len(conf.edge_types)), dtype=bool)
                for edge in pred_node.edges:
                    pred_part_a_id = edge['part_a']
                    pred_part_b_id = edge['part_b']
                    edge_type_id = conf.edge_types.index(edge['type'])
                    pred_edges[pred_part_a_id, pred_part_b_id,
                               edge_type_id] = True

                for edge in gt_node.edges:
                    gt_part_a_id = edge['part_a']
                    gt_part_b_id = edge['part_b']
                    edge_type_id = conf.edge_types.index(edge['type'])
                    if gt_part_a_id in matched_gt_idx and gt_part_b_id in matched_gt_idx:
                        pred_part_a_id = matched_gt2pred[gt_part_a_id]
                        pred_part_b_id = matched_gt2pred[gt_part_b_id]
                        edge_both += pred_edges[pred_part_a_id, pred_part_b_id,
                                                edge_type_id]
                        edge_both += pred_edges[pred_part_b_id, pred_part_a_id,
                                                edge_type_id]

                        # gt edges eval
                        obb1 = pred_node.children[pred_part_a_id].box.cpu().numpy()
                        obb_quat1 = pred_node.children[pred_part_a_id].get_box_quat().cpu().numpy()
                        mesh_v1, mesh_f1 = utils.gen_obb_mesh(obb1)
                        pc1 = utils.sample_pc(mesh_v1, mesh_f1, n_points=500)
                        pc1 = torch.tensor(pc1, dtype=torch.float32, device=device)
                        obb2 = pred_node.children[pred_part_b_id].box.cpu().numpy()
                        obb_quat2 = pred_node.children[pred_part_b_id].get_box_quat().cpu().numpy()
                        mesh_v2, mesh_f2 = utils.gen_obb_mesh(obb2)
                        pc2 = utils.sample_pc(mesh_v2, mesh_f2, n_points=500)
                        pc2 = torch.tensor(pc2, dtype=torch.float32, device=device)
                        if edge_type_id == 0:  # ADJ
                            dist1, dist2 = chamferLoss(pc1.view(1, -1, 3),
                                                       pc2.view(1, -1, 3))
                            gt_binary_diff += (dist1.sqrt().min().item() +
                                               dist2.sqrt().min().item()) / 2
                        else:  # SYM
                            if edge_type_id == 2:  # TRANS_SYM
                                mat1to2, _ = compute_sym.compute_trans_sym(
                                    obb_quat1.reshape(-1),
                                    obb_quat2.reshape(-1))
                            elif edge_type_id == 3:  # REF_SYM
                                mat1to2, _ = compute_sym.compute_ref_sym(
                                    obb_quat1.reshape(-1),
                                    obb_quat2.reshape(-1))
                            elif edge_type_id == 1:  # ROT_SYM
                                mat1to2, _ = compute_sym.compute_rot_sym(
                                    obb_quat1.reshape(-1),
                                    obb_quat2.reshape(-1))
                            else:
                                raise ValueError('ERROR: unknown symmetry type: %s' % edge['type'])
                            mat1to2 = torch.tensor(mat1to2, dtype=torch.float32, device=device)
                            transformed_pc1 = pc1.matmul(torch.transpose(mat1to2[:, :3], 0, 1)) + \
                                    mat1to2[:, 3].unsqueeze(dim=0).repeat(pc1.size(0), 1)
                            dist1, dist2 = chamferLoss(
                                transformed_pc1.view(1, -1, 3),
                                pc2.view(1, -1, 3))
                            loss = (dist1.sqrt().mean() +
                                    dist2.sqrt().mean()) / 2
                            gt_binary_diff += loss.item()
                        gt_binary_tot += 1

            return struct_diff, edge_both, edge_gt, edge_pred, gt_binary_diff, gt_binary_tot
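
The six counters returned by compute_struct_diff are typically accumulated over many (ground-truth, prediction) pairs before being reported. A hedged aggregation sketch; the metric names and the iterable of node pairs are assumptions, not part of the original script:

def aggregate_struct_metrics(node_pairs):
    # node_pairs: iterable of (gt_root, pred_root) tuples following the node interface above
    tot_struct, tot_gt_boxes = 0.0, 0
    tot_edge_both, tot_edge_gt = 0, 0
    tot_bin_diff, tot_bin = 0.0, 0
    for gt_root, pred_root in node_pairs:
        sd, e_both, e_gt, e_pred, b_diff, b_tot = compute_struct_diff(gt_root, pred_root)
        tot_struct += sd
        tot_gt_boxes += len(gt_root.boxes())
        tot_edge_both += e_both
        tot_edge_gt += e_gt
        tot_bin_diff += b_diff
        tot_bin += b_tot
    return {
        'struct_diff_per_gt_box': tot_struct / max(tot_gt_boxes, 1),
        'edge_recall': tot_edge_both / max(tot_edge_gt, 1),
        'mean_gt_binary_diff': tot_bin_diff / max(tot_bin, 1),
    }
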
Example #4
    def node_recon_loss(self, node_latent, gt_node):
        if gt_node.is_leaf:
            box = self.box_decoder(node_latent)
            box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
            anchor_loss = self.anchorLossEstimator(box, gt_node.get_box_quat().view(1, -1))
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
            return {'box': box_loss, 'leaf': is_leaf_loss, 'anchor': anchor_loss, 
                    'exists': torch.zeros_like(box_loss), 'semantic': torch.zeros_like(box_loss),
                    'edge_exists': torch.zeros_like(box_loss),
                    'sym': torch.zeros_like(box_loss), 'adj': torch.zeros_like(box_loss)}, box, box
        else:
            child_feats, child_sem_logits, child_exists_logits, edge_exists_logits = \
                    self.child_decoder(node_latent)

            # generate box prediction for each child
            feature_len = node_latent.size(1)
            child_pred_boxes = self.box_decoder(child_feats.view(-1, feature_len))
            num_child_parts = child_pred_boxes.size(0)

            # perform hungarian matching between pred boxes and gt boxes
            with torch.no_grad():
                child_gt_boxes = torch.cat([child_node.get_box_quat().view(1, -1) for child_node in gt_node.children], dim=0)
                num_gt = child_gt_boxes.size(0)

                child_pred_boxes_tiled = child_pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
                child_gt_boxes_tiled = child_gt_boxes.unsqueeze(dim=1).repeat(1, num_child_parts, 1)

                dist_mat = self.boxLossEstimator(child_gt_boxes_tiled.view(-1, 10), child_pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_child_parts)

                _, matched_gt_idx, matched_pred_idx = linear_assignment(dist_mat)

                # get edge ground truth
                edge_type_list_gt, edge_indices_gt = gt_node.edge_tensors(
                    edge_types=self.conf.edge_types, device=child_feats.device, type_onehot=False)

                gt2pred = {gt_idx: pred_idx for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx)}
                edge_exists_gt = torch.zeros_like(edge_exists_logits)

                sym_from, sym_to, sym_mat, sym_type = [], [], [], []
                adj_from, adj_to = [], []
                for i in range(edge_indices_gt.shape[1]//2):
                    if edge_indices_gt[0, i, 0].item() not in gt2pred or edge_indices_gt[0, i, 1].item() not in gt2pred:
                        """
                            one of the adjacent nodes of the current gt edge was not matched 
                            to any node in the prediction, ignore this edge
                        """
                        continue
                    
                    # correlate gt edges to pred edges
                    edge_from_idx = gt2pred[edge_indices_gt[0, i, 0].item()]
                    edge_to_idx = gt2pred[edge_indices_gt[0, i, 1].item()]
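                    # register the matched edge in both directions so the target tensor stays symmetric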
                    edge_exists_gt[:, edge_from_idx, edge_to_idx, edge_type_list_gt[0:1, i]] = 1
                    edge_exists_gt[:, edge_to_idx, edge_from_idx, edge_type_list_gt[0:1, i]] = 1

                    # compute binary edge parameters for each matched pred edge
                    if edge_type_list_gt[0, i].item() == 0: # ADJ
                        adj_from.append(edge_from_idx)
                        adj_to.append(edge_to_idx)
                    else:   # SYM
                        obb1 = child_pred_boxes[edge_from_idx].cpu().detach().numpy()
                        obb2 = child_pred_boxes[edge_to_idx].cpu().detach().numpy()
                        if edge_type_list_gt[0, i].item() == 1: # ROT_SYM
                            mat1to2, mat2to1 = compute_sym.compute_rot_sym(obb1, obb2)
                        elif edge_type_list_gt[0, i].item() == 2: # TRANS_SYM
                            mat1to2, mat2to1 = compute_sym.compute_trans_sym(obb1, obb2)
                        else:   # REF_SYM
                            mat1to2, mat2to1 = compute_sym.compute_ref_sym(obb1, obb2)
                        sym_from.append(edge_from_idx)
                        sym_to.append(edge_to_idx)
                        sym_mat.append(torch.tensor(mat1to2, dtype=torch.float32, device=self.conf.device).view(1, 3, 4))
                        sym_type.append(edge_type_list_gt[0, i].item())

            # train the current node to be non-leaf
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

            # train the current node box to gt
            all_boxes = []
            all_leaf_boxes = []
            box = self.box_decoder(node_latent)
            all_boxes.append(box)
            box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
            anchor_loss = self.anchorLossEstimator(box, gt_node.get_box_quat().view(1, -1))

            # gather information
            child_sem_gt_labels = []
            child_sem_pred_logits = []
            child_box_gt = []
            child_box_pred = []
            child_exists_gt = torch.zeros_like(child_exists_logits)
            for i in range(len(matched_gt_idx)):
                child_sem_gt_labels.append(gt_node.children[matched_gt_idx[i]].get_semantic_id())
                child_sem_pred_logits.append(child_sem_logits[0, matched_pred_idx[i], :].view(1, -1))
                child_box_gt.append(gt_node.children[matched_gt_idx[i]].get_box_quat().view(1, -1))
                child_box_pred.append(child_pred_boxes[matched_pred_idx[i], :].view(1, -1))
                child_exists_gt[:, matched_pred_idx[i], :] = 1

            # train semantic labels
            child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
            child_sem_gt_labels = torch.tensor(child_sem_gt_labels, dtype=torch.int64, \
                    device=child_sem_pred_logits.device)
            semantic_loss = self.semCELoss(child_sem_pred_logits, child_sem_gt_labels)
            semantic_loss = semantic_loss.sum()

            # train unused boxes to zeros
            unmatched_boxes = []
            for i in range(num_child_parts):
                if i not in matched_pred_idx:
                    unmatched_boxes.append(child_pred_boxes[i, 3:6].view(1, -1))
            if len(unmatched_boxes) > 0:
                unmatched_boxes = torch.cat(unmatched_boxes, dim=0)
                unused_box_loss = unmatched_boxes.pow(2).sum() * 0.01
            else:
                unused_box_loss = 0.0

            # train exist scores
            child_exists_loss = F.binary_cross_entropy_with_logits(\
                input=child_exists_logits, target=child_exists_gt, reduction='none')
            child_exists_loss = child_exists_loss.sum()

            # train edge exists scores
            edge_exists_loss = F.binary_cross_entropy_with_logits(\
                    input=edge_exists_logits, target=edge_exists_gt, reduction='none')
            edge_exists_loss = edge_exists_loss.sum()
            # rescale to make it comparable to other losses, 
            # which are in the order of the number of child nodes
            edge_exists_loss = edge_exists_loss / (edge_exists_gt.shape[2]*edge_exists_gt.shape[3]) 

            # compute and train binary losses
            sym_loss = 0
            if len(sym_from) > 0:
                sym_from_th = torch.tensor(sym_from, dtype=torch.long, device=self.conf.device)
                obb_from = child_pred_boxes[sym_from_th, :]
                with torch.no_grad():
                    reweight_from = get_surface_reweighting_batch(obb_from[:, 3:6], self.unit_cube.size(0))
                pc_from = transform_pc_batch(self.unit_cube, obb_from)
                sym_to_th = torch.tensor(sym_to, dtype=torch.long, device=self.conf.device)
                obb_to = child_pred_boxes[sym_to_th, :]
                with torch.no_grad():
                    reweight_to = get_surface_reweighting_batch(obb_to[:, 3:6], self.unit_cube.size(0))
                pc_to = transform_pc_batch(self.unit_cube, obb_to)
                sym_mat_th = torch.cat(sym_mat, dim=0)
                transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat_th[:, :, :3], 1, 2)) + \
                        sym_mat_th[:, :, 3].unsqueeze(dim=1).repeat(1, pc_from.size(1), 1)
                dist1, dist2 = self.chamferLoss(transformed_pc_from, pc_to)
                loss1 = (dist1 * reweight_from).sum(dim=1) / (reweight_from.sum(dim=1) + 1e-12)
                loss2 = (dist2 * reweight_to).sum(dim=1) / (reweight_to.sum(dim=1) + 1e-12)
                loss = loss1 + loss2
                sym_loss = loss.sum()

            adj_loss = 0
            if len(adj_from) > 0:
                adj_from_th = torch.tensor(adj_from, dtype=torch.long, device=self.conf.device)
                obb_from = child_pred_boxes[adj_from_th, :]
                pc_from = transform_pc_batch(self.unit_cube, obb_from)
                adj_to_th = torch.tensor(adj_to, dtype=torch.long, device=self.conf.device)
                obb_to = child_pred_boxes[adj_to_th, :]
                pc_to = transform_pc_batch(self.unit_cube, obb_to)
                dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                loss = (dist1.min(dim=1)[0] + dist2.min(dim=1)[0])
                adj_loss = loss.sum()

            # call children + aggregate losses
            pred2allboxes = dict()
            pred2allleafboxes = dict()
            for i in range(len(matched_gt_idx)):
                child_losses, child_all_boxes, child_all_leaf_boxes = self.node_recon_loss(
                    child_feats[:, matched_pred_idx[i], :], gt_node.children[matched_gt_idx[i]])
                pred2allboxes[matched_pred_idx[i]] = child_all_boxes
                pred2allleafboxes[matched_pred_idx[i]] = child_all_leaf_boxes
                all_boxes.append(child_all_boxes)
                all_leaf_boxes.append(child_all_leaf_boxes)
                box_loss = box_loss + child_losses['box']
                anchor_loss = anchor_loss + child_losses['anchor'] 
                is_leaf_loss = is_leaf_loss + child_losses['leaf']
                child_exists_loss = child_exists_loss + child_losses['exists']
                semantic_loss = semantic_loss + child_losses['semantic']
                edge_exists_loss = edge_exists_loss + child_losses['edge_exists']
                sym_loss = sym_loss + child_losses['sym']
                adj_loss = adj_loss + child_losses['adj']

            # for sym-edges, train subtree to be symmetric
            for i in range(len(sym_from)):
                s1 = pred2allboxes[sym_from[i]].size(0)
                s2 = pred2allboxes[sym_to[i]].size(0)
                if s1 > 1 and s2 > 1:
                    obbs_from = pred2allboxes[sym_from[i]][1:, :]
                    obbs_to = pred2allboxes[sym_to[i]][1:, :]
                    pc_from = transform_pc_batch(self.unit_cube, obbs_from).view(-1, 3)
                    pc_to = transform_pc_batch(self.unit_cube, obbs_to).view(-1, 3)
                    transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat[i][0, :, :3], 0, 1)) + \
                            sym_mat[i][0, :, 3].unsqueeze(dim=0).repeat(pc_from.size(0), 1)
                    dist1, dist2 = self.chamferLoss(transformed_pc_from.view(1, -1, 3), pc_to.view(1, -1, 3))
                    sym_loss += (dist1.mean() + dist2.mean()) * (s1 + s2) / 2

            # for adj-edges, train leaf-nodes in subtrees to be adjacent
            for i in range(len(adj_from)):
                if pred2allboxes[adj_from[i]].size(0) > pred2allleafboxes[adj_from[i]].size(0) \
                        or pred2allboxes[adj_to[i]].size(0) > pred2allleafboxes[adj_to[i]].size(0):
                    obbs_from = pred2allleafboxes[adj_from[i]]
                    obbs_to = pred2allleafboxes[adj_to[i]]
                    pc_from = transform_pc_batch(self.unit_cube, obbs_from).view(1, -1, 3)
                    pc_to = transform_pc_batch(self.unit_cube, obbs_to).view(1, -1, 3)
                    dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                    adj_loss += dist1.min() + dist2.min()

            return {'box': box_loss + unused_box_loss, 'leaf': is_leaf_loss, 'anchor': anchor_loss, 
                    'exists': child_exists_loss, 'semantic': semantic_loss,
                    'edge_exists': edge_exists_loss, 'sym': sym_loss, 'adj': adj_loss}, \
                            torch.cat(all_boxes, dim=0), torch.cat(all_leaf_boxes, dim=0)
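
The dictionary of per-term losses returned at the root is usually collapsed into a single scalar for backpropagation. A minimal sketch with hypothetical weights (the real training configuration is not part of this listing):

# hypothetical per-term weights
LOSS_WEIGHTS = {'box': 1.0, 'leaf': 1.0, 'anchor': 1.0, 'exists': 1.0,
                'semantic': 0.1, 'edge_exists': 1.0, 'sym': 1.0, 'adj': 1.0}

def total_recon_loss(losses):
    # collapse the dict returned by node_recon_loss into one scalar
    return sum(LOSS_WEIGHTS[k] * v for k, v in losses.items())

# usage sketch (model, root_latent, and gt_root are placeholders):
# losses, all_boxes, all_leaf_boxes = model.node_recon_loss(root_latent, gt_root)
# total_recon_loss(losses).backward()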