Example #1
0
    def compute_shape_diff(source, target, device=None):
        """Build a ``Tree.DiffNode`` hierarchy describing ``source`` -> ``target``.

        The returned root node is created as ``'SAME'`` and always carries a
        ``box_diff`` tensor, viewed as ``(1, -1)``, computed by
        ``compute_box_diff`` from the two root boxes. Its children depend on
        the leaf status of the two nodes:

        * both leaves: no children, node type becomes ``'LEAF'``;
        * source leaf, target non-leaf: one ``'ADD'`` node per target child,
          with the target child stored in ``subnode``;
        * source non-leaf, target leaf: one ``'DEL'`` node per source child;
        * both non-leaf: children whose label occurs on both sides are paired
          per label by ``linear_assignment`` on ``boxLoss`` distances. Then,
          in source order, each source child contributes either a ``'DEL'``
          node (unpaired) or the recursive diff against its partner, and
          finally every unpaired target child is appended as an ``'ADD'``
          node.

        Args:
            source: node of the source tree.
            target: node of the target tree.
            device: not used; the box diff is always created on
                ``source.box.device``. Kept for interface compatibility.

        Returns:
            The root ``Tree.DiffNode``.
        """
        dnode = Tree.DiffNode('SAME')

        # get box_diff for every node
        source_box_quat = source.get_box_quat().squeeze().cpu().numpy()
        target_box_quat = target.get_box_quat().squeeze().cpu().numpy()
        dnode.box_diff = torch.tensor(compute_box_diff(source_box_quat,
                                                       target_box_quat),
                                      dtype=torch.float32,
                                      device=source.box.device).view(1, -1)

        if source.is_leaf:
            if target.is_leaf:
                dnode.node_type = 'LEAF'
            else:
                # everything below the target is new
                for cnode in target.children:
                    cdiffnode = Tree.DiffNode('ADD')
                    cdiffnode.subnode = cnode
                    dnode.children.append(cdiffnode)
            return dnode

        if target.is_leaf:
            # everything below the source is removed
            for _ in source.children:
                dnode.children.append(Tree.DiffNode('DEL'))
            return dnode

        # only labels present on both sides can produce matches
        source_sem = {node.label for node in source.children}
        target_sem = {node.label for node in target.children}
        intersect_sem = set.intersection(source_sem, target_sem)

        # label -> child indices (in child order), restricted to shared labels
        source_cnodes_per_sem = {}
        for node_id, cnode in enumerate(source.children):
            if cnode.label in intersect_sem:
                source_cnodes_per_sem.setdefault(cnode.label, []).append(node_id)

        target_cnodes_per_sem = {}
        for node_id, cnode in enumerate(target.children):
            if cnode.label in intersect_sem:
                target_cnodes_per_sem.setdefault(cnode.label, []).append(node_id)

        matched_source2target = {}
        unmatched_source_ids = set(range(len(source.children)))
        unmatched_target_ids = set(range(len(target.children)))
        for sem in intersect_sem:
            source_ids = source_cnodes_per_sem[sem]
            target_ids = target_cnodes_per_sem[sem]

            # The distance matrix only selects an assignment, so no autograd
            # graph is needed (same as the matching in the loss functions).
            with torch.no_grad():
                source_boxes = torch.cat(
                    [source.children[cid].get_box_quat() for cid in source_ids],
                    dim=0)
                target_boxes = torch.cat(
                    [target.children[cid].get_box_quat() for cid in target_ids],
                    dim=0)

                num_source = source_boxes.size(0)
                num_target = target_boxes.size(0)

                # all (source, target) pairs, flattened to rows of 10 values
                source_boxes_tiled = source_boxes.unsqueeze(dim=1).repeat(
                    1, num_target, 1)
                target_boxes_tiled = target_boxes.unsqueeze(dim=0).repeat(
                    num_source, 1, 1)

                dmat = boxLoss(source_boxes_tiled.view(-1, 10),
                               target_boxes_tiled.view(-1, 10)).view(
                                   -1, num_source, num_target).cpu()
                _, cur_source_idx, cur_target_idx = linear_assignment(dmat)

            # translate per-label positions back to child indices
            for src_pos, tgt_pos in zip(cur_source_idx, cur_target_idx):
                source_node_id = source_ids[src_pos]
                target_node_id = target_ids[tgt_pos]
                unmatched_source_ids.remove(source_node_id)
                unmatched_target_ids.remove(target_node_id)
                matched_source2target[source_node_id] = target_node_id

        # one entry per source child, in source order
        for node_id, cnode in enumerate(source.children):
            if node_id in unmatched_source_ids:
                dnode.children.append(Tree.DiffNode('DEL'))
            else:
                dnode.children.append(
                    Tree.compute_shape_diff(
                        cnode,
                        target.children[matched_source2target[node_id]]))

        # then every target child that found no partner
        for i in unmatched_target_ids:
            cdiffnode = Tree.DiffNode('ADD')
            cdiffnode.subnode = target.children[i]
            dnode.children.append(cdiffnode)

        return dnode
Example #2
0
    def struct_dist(gt_node, pred_node):
        """Return a structural distance between two part hierarchies.

        * Two leaves: ``0``.
        * Leaf vs. non-leaf: number of boxes in the non-leaf's subtree minus one.
        * Two non-leaves: children are paired per semantic label. A label with
          exactly one box on each side is paired directly, and larger groups
          are paired by ``utils.linear_assignment`` on ``box_dist``. Every
          unpaired child adds ``len(child.boxes())``, and every pair adds its
          recursive ``struct_dist``.

        Args:
            gt_node: ground-truth tree node.
            pred_node: predicted tree node.

        Returns:
            The distance: the int ``0`` or a box count for the leaf cases, and
            a float accumulated from ``0.0`` otherwise.
        """
        if gt_node.is_leaf:
            if pred_node.is_leaf:
                return 0
            return len(pred_node.boxes()) - 1
        if pred_node.is_leaf:
            return len(gt_node.boxes()) - 1

        # only labels present on both sides can produce matches
        gt_sem = {node.label for node in gt_node.children}
        pred_sem = {node.label for node in pred_node.children}
        intersect_sem = set.intersection(gt_sem, pred_sem)

        # label -> child indices (in child order), restricted to shared labels
        gt_cnodes_per_sem = {}
        for node_id, gt_cnode in enumerate(gt_node.children):
            if gt_cnode.label in intersect_sem:
                gt_cnodes_per_sem.setdefault(gt_cnode.label, []).append(node_id)

        pred_cnodes_per_sem = {}
        for node_id, pred_cnode in enumerate(pred_node.children):
            if pred_cnode.label in intersect_sem:
                pred_cnodes_per_sem.setdefault(pred_cnode.label,
                                               []).append(node_id)

        # (gt child index, pred child index) for every matched pair; no
        # fixed-size buffer, so any number of children is supported
        matched_pairs = []
        for sem in intersect_sem:
            gt_ids = gt_cnodes_per_sem[sem]
            pred_ids = pred_cnodes_per_sem[sem]
            gt_boxes = torch.cat(
                [gt_node.children[cid].get_box_quat() for cid in gt_ids], dim=0)
            pred_boxes = torch.cat(
                [pred_node.children[cid].get_box_quat() for cid in pred_ids],
                dim=0)

            num_gt = gt_boxes.size(0)
            num_pred = pred_boxes.size(0)

            if num_gt == 1 and num_pred == 1:
                # a single candidate on each side: no assignment to solve
                cur_matched_gt_idx = [0]
                cur_matched_pred_idx = [0]
            else:
                # all (gt, pred) pairs, flattened to rows of 10 values
                gt_boxes_tiled = gt_boxes.unsqueeze(dim=1).repeat(
                    1, num_pred, 1)
                pred_boxes_tiled = pred_boxes.unsqueeze(dim=0).repeat(
                    num_gt, 1, 1)
                dmat = box_dist(gt_boxes_tiled.view(-1, 10),
                                pred_boxes_tiled.view(-1, 10)).view(
                                    -1, num_gt, num_pred)
                _, cur_matched_gt_idx, cur_matched_pred_idx = \
                    utils.linear_assignment(dmat)

            # translate per-label positions back to child indices
            for gt_pos, pred_pos in zip(cur_matched_gt_idx,
                                        cur_matched_pred_idx):
                matched_pairs.append((gt_ids[gt_pos], pred_ids[pred_pos]))

        matched_gt_ids = {gt_id for gt_id, _ in matched_pairs}
        matched_pred_ids = {pred_id for _, pred_id in matched_pairs}

        struct_diff = 0.0
        # each unmatched child costs the size of its whole subtree
        for i, gt_child in enumerate(gt_node.children):
            if i not in matched_gt_ids:
                struct_diff += len(gt_child.boxes())

        for i, pred_child in enumerate(pred_node.children):
            if i not in matched_pred_ids:
                struct_diff += len(pred_child.boxes())

        # each matched pair costs its recursive distance
        for gt_id, pred_id in matched_pairs:
            struct_diff += struct_dist(gt_node.children[gt_id],
                                       pred_node.children[pred_id])

        return struct_diff
Example #3
0
    def node_recon_loss(self, node_latent, gt_node):
        """Compute the recursive box/structure reconstruction losses for a subtree.

        For a leaf ``gt_node``, the box decoded from ``node_latent`` and the
        leaf classifier are trained against ``gt_node``; the ``'exists'`` and
        ``'semantic'`` terms are zeros.

        For a non-leaf, ``self.child_decoder`` predicts a set of child slots.
        Their decoded boxes are matched to the ground-truth children by
        ``linear_assignment`` on ``self.boxLossEstimator`` distances. Then:

        * matched slots are supervised for semantic label and existence, and
          recursed into;
        * unmatched slots get ``0.01 * sum(box[3:6] ** 2)`` added to the box
          loss.

        Args:
            node_latent: latent code of the node; ``node_latent.size(1)`` is
                used as the feature length (presumably shape
                ``(1, feature_len)`` — TODO confirm).
            gt_node: ground-truth tree node.

        Returns:
            dict with keys ``'box'``, ``'leaf'``, ``'exists'`` and
            ``'semantic'``, each accumulated over the whole subtree.
        """
        if gt_node.is_leaf:
            box = self.box_decoder(node_latent)
            box_loss = self.boxLossEstimator(
                box,
                gt_node.get_box_quat().view(1, -1))
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(
                is_leaf_logit,
                is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
            return {
                'box': box_loss,
                'leaf': is_leaf_loss,
                'exists': torch.zeros_like(box_loss),
                'semantic': torch.zeros_like(box_loss)
            }

        child_feats, child_sem_logits, child_exists_logits = \
            self.child_decoder(node_latent)

        # generate box prediction for each child slot
        feature_len = node_latent.size(1)
        child_pred_boxes = self.box_decoder(child_feats.view(-1, feature_len))
        num_child_parts = child_pred_boxes.size(0)

        # Hungarian matching between pred boxes and gt boxes; the distance
        # matrix only selects the assignment, so no gradients are tracked.
        with torch.no_grad():
            child_gt_boxes = torch.cat(
                [child_node.get_box_quat() for child_node in gt_node.children],
                dim=0)
            num_gt = child_gt_boxes.size(0)

            # all (gt, pred) pairs, flattened to rows of 10 values
            child_pred_boxes_tiled = child_pred_boxes.unsqueeze(
                dim=0).repeat(num_gt, 1, 1)
            child_gt_boxes_tiled = child_gt_boxes.unsqueeze(dim=1).repeat(
                1, num_child_parts, 1)

            dist_mat = self.boxLossEstimator(
                child_gt_boxes_tiled.view(-1, 10),
                child_pred_boxes_tiled.view(-1, 10)).view(
                    -1, num_gt, num_child_parts)

            _, matched_gt_idx, matched_pred_idx = linear_assignment(dist_mat)

        # train the current node to be non-leaf
        is_leaf_logit = self.leaf_classifier(node_latent)
        is_leaf_loss = self.isLeafLossEstimator(
            is_leaf_logit,
            is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

        # train the current node box to gt
        box = self.box_decoder(node_latent)
        box_loss = self.boxLossEstimator(
            box,
            gt_node.get_box_quat().view(1, -1))

        # gather semantic targets/logits of matched pairs and mark matched
        # slots as existing
        child_sem_gt_labels = []
        child_sem_pred_logits = []
        child_exists_gt = torch.zeros_like(child_exists_logits)
        for gt_i, pred_i in zip(matched_gt_idx, matched_pred_idx):
            child_sem_gt_labels.append(gt_node.children[gt_i].get_semantic_id())
            child_sem_pred_logits.append(
                child_sem_logits[0, pred_i, :].view(1, -1))
            child_exists_gt[:, pred_i, :] = 1

        # train semantic labels
        child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
        child_sem_gt_labels = torch.tensor(child_sem_gt_labels,
                                           dtype=torch.int64,
                                           device=child_sem_pred_logits.device)
        semantic_loss = self.ceLoss(child_sem_pred_logits,
                                    child_sem_gt_labels).sum()

        # train unused boxes to zeros (entries 3:6 of each unmatched slot);
        # int() so membership works whatever int-like type the indices have
        matched_pred_set = {int(p) for p in matched_pred_idx}
        unmatched_boxes = [
            child_pred_boxes[i, 3:6].view(1, -1)
            for i in range(num_child_parts) if i not in matched_pred_set
        ]
        if unmatched_boxes:
            unused_box_loss = torch.cat(unmatched_boxes,
                                        dim=0).pow(2).sum() * 0.01
        else:
            unused_box_loss = 0.0

        # train exist scores
        child_exists_loss = F.binary_cross_entropy_with_logits(
            input=child_exists_logits,
            target=child_exists_gt,
            reduction='none').sum()

        # recurse into matched children and aggregate their losses
        for gt_i, pred_i in zip(matched_gt_idx, matched_pred_idx):
            child_losses = self.node_recon_loss(child_feats[:, pred_i, :],
                                                gt_node.children[gt_i])
            box_loss = box_loss + child_losses['box']
            is_leaf_loss = is_leaf_loss + child_losses['leaf']
            child_exists_loss = child_exists_loss + child_losses['exists']
            semantic_loss = semantic_loss + child_losses['semantic']

        return {
            'box': box_loss + unused_box_loss,
            'leaf': is_leaf_loss,
            'exists': child_exists_loss,
            'semantic': semantic_loss
        }
Example #4
0
    def node_diff_recon_loss(self, z, z_skip, obj_node, gt_diff):
        """Compute the losses for reconstructing ``gt_diff`` on top of ``obj_node``.

        Existing children: each child of ``obj_node`` is paired by index with
        the same position in ``gt_diff.children``.

        * A diff feature is extracted from ``z``, ``z_skip`` and the child's
          ``box_feature``, ``feature`` and semantic one-hot.
        * That feature is classified against ``cdiff.get_node_type_id()``.
        * ``'SAME'`` children are recursed into.
        * ``'SAME'`` and ``'LEAF'`` children also get a box loss. It compares
          the child box edited by the predicted diff with the child box edited
          by ``cdiff.box_diff``.

        Added children: the entries of ``gt_diff.children`` after the first
        ``len(obj_node.children)`` are read as added subtrees via
        ``.subnode``.

        * They are matched to the slots of ``self.add_child_decoder`` by
          ``linear_assignment`` on ``self.boxLossEstimator`` distances.
        * Matched slots are supervised for semantics and existence and
          trained with ``self.node_recon_loss``.
        * Unmatched slots get ``0.01 * sum(box[3:6] ** 2)`` added to the box
          loss.

        Args:
            z: latent of the current node.
            z_skip: skip-connection latent, passed unchanged to recursive calls.
            obj_node: the existing (source) node.
            gt_diff: ground-truth diff node for ``obj_node``.

        Returns:
            dict with keys 'box', 'leaf', 'exists', 'semantic',
            'diffnode_type', 'diffnode_box'; each starts from
            ``torch.zeros(1)`` on ``z.device`` and is accumulated in place.
        """
        # initialize all losses to zeros
        box_loss = torch.zeros(1, device=z.device)
        is_leaf_loss = torch.zeros(1, device=z.device)
        child_exists_loss = torch.zeros(1, device=z.device)
        semantic_loss = torch.zeros(1, device=z.device)
        diffnode_type_loss = torch.zeros(1, device=z.device)
        diffnode_box_loss = torch.zeros(1, device=z.device)

        num_existing = len(obj_node.children)

        # DEL/SAME/LEAF: classify the diff type of every existing child
        node_diff_types = []
        gt_node_diff_types = []
        for i, cnode in enumerate(obj_node.children):
            cdiff = gt_diff.children[i]
            node_diff_feat = self.node_diff_feature_extractor(
                torch.cat([
                    z, z_skip, cnode.box_feature, cnode.feature,
                    cnode.get_semantic_one_hot()
                ], dim=1))
            node_diff_types.append(self.node_diff_classifier(node_diff_feat))
            gt_node_diff_types.append(cdiff.get_node_type_id())
            if cdiff.node_type == 'SAME':
                child_losses = self.node_diff_recon_loss(z=node_diff_feat,
                                                         z_skip=z_skip,
                                                         obj_node=cnode,
                                                         gt_diff=cdiff)
                diffnode_type_loss += child_losses['diffnode_type']
                diffnode_box_loss += child_losses['diffnode_box']
                box_loss += child_losses['box']
                is_leaf_loss += child_losses['leaf']
                child_exists_loss += child_losses['exists']
                semantic_loss += child_losses['semantic']
            if cdiff.node_type in ('LEAF', 'SAME'):
                # box edited by the predicted diff vs. box edited by the gt diff
                box2 = self.apply_box_diff(
                    cnode.get_box_quat(),
                    self.box_diff_decoder(node_diff_feat))
                gt_box2 = self.apply_box_diff(cnode.get_box_quat(),
                                              cdiff.box_diff)
                diffnode_box_loss += self.boxLossEstimator(box2, gt_box2).sum()

        # diff node loss
        if num_existing > 0:
            node_diff_types = torch.cat(node_diff_types, dim=0)
            gt_node_diff_types = torch.tensor(gt_node_diff_types,
                                              dtype=torch.int64,
                                              device=node_diff_types.device)
            diffnode_type_loss += self.ceLoss(node_diff_types,
                                              gt_node_diff_types).sum()

        # ADD
        (add_child_feats, add_child_sem_logits,
         add_child_exists_logits) = self.add_child_decoder(z)

        add_child_boxes = self.box_decoder(
            add_child_feats.view(-1, self.conf.feature_size))
        num_pred = add_child_boxes.size(0)

        child_exists_gt = torch.zeros_like(add_child_exists_logits)

        # gt subtrees to be added: the diff children after the existing ones
        add_gt_nodes = [
            gt_diff.children[i].subnode
            for i in range(num_existing, len(gt_diff.children))
        ]
        num_gt = len(add_gt_nodes)

        # always defined, so the unmatched-slot loop below never depends on
        # short-circuit evaluation
        matched_gt_idx = []
        matched_pred_idx = []
        with torch.no_grad():
            if num_gt > 0:
                add_child_gt_boxes = torch.cat(
                    [node.get_box_quat() for node in add_gt_nodes], dim=0)
                # all (gt, pred) pairs, flattened to rows of 10 values
                pred_tiled = add_child_boxes.unsqueeze(0).repeat(num_gt, 1, 1)
                gt_tiled = add_child_gt_boxes.unsqueeze(1).repeat(
                    1, num_pred, 1)
                dmat = self.boxLossEstimator(gt_tiled.view(-1, 10),
                                             pred_tiled.view(-1, 10)).view(
                                                 -1, num_gt, num_pred)
                _, matched_gt_idx, matched_pred_idx = linear_assignment(dmat)

        if num_gt > 0:
            # gather semantic targets/logits and train each added subtree
            child_sem_gt_labels = []
            child_sem_pred_logits = []
            for gt_i, pred_i in zip(matched_gt_idx, matched_pred_idx):
                add_node = add_gt_nodes[gt_i]
                child_sem_gt_labels.append(add_node.get_semantic_id())
                child_sem_pred_logits.append(
                    add_child_sem_logits[0, pred_i, :].view(1, -1))
                child_exists_gt[:, pred_i, :] = 1

                # train add node subtree
                child_losses = self.node_recon_loss(
                    add_child_feats[:, pred_i, :], add_node)
                box_loss += child_losses['box']
                is_leaf_loss += child_losses['leaf']
                child_exists_loss += child_losses['exists']
                semantic_loss += child_losses['semantic']

            # train semantic labels
            child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
            child_sem_gt_labels = torch.tensor(
                child_sem_gt_labels,
                dtype=torch.int64,
                device=add_child_sem_logits.device)
            semantic_loss += self.ceLoss(child_sem_pred_logits,
                                         child_sem_gt_labels).sum()

        # train unused boxes to zeros (entries 3:6 of each unmatched slot)
        matched_pred_set = {int(p) for p in matched_pred_idx}
        unmatched_boxes = [
            add_child_boxes[i, 3:6].view(1, -1)
            for i in range(num_pred) if i not in matched_pred_set
        ]
        if unmatched_boxes:
            unmatched_boxes = torch.cat(unmatched_boxes, dim=0)
            box_loss += unmatched_boxes.pow(2).sum() * 0.01

        # train exist scores
        child_exists_loss += F.binary_cross_entropy_with_logits(
            input=add_child_exists_logits,
            target=child_exists_gt,
            reduction='none').sum()

        return {
            'box': box_loss,
            'leaf': is_leaf_loss,
            'exists': child_exists_loss,
            'semantic': semantic_loss,
            'diffnode_type': diffnode_type_loss,
            'diffnode_box': diffnode_box_loss
        }
Example #5
0
    def node_recon_loss(self, node_latent, gt_node):
        """Compute recursive geometry/structure losses for ``gt_node``'s subtree.

        The current node is decoded by ``self.node_decoder`` into local points,
        a center, a scale and a feature; ``self.geoToGlobal`` combines them into
        ``geo_global``. The current node is compared with ``gt_node`` through
        latent-feature, center and scale MSE terms plus a ``self.chamferDist``
        term.

        For a non-leaf, ``self.child_decoder`` predicts child slots. They are
        matched to the ground-truth children by ``linear_assignment`` on
        Chamfer distances. Then:

        * matched slots are supervised for semantics, existence and edges;
        * ADJ/SYM edges between matched slots add binary losses;
        * matched slots are recursed into;
        * unmatched slots have their predicted scale pushed toward zero.

        Args:
            node_latent: latent code of this node; ``node_latent.size(1)`` is
                used as the feature length.
            gt_node: ground-truth node; this method reads ``geo``,
                ``geo_feat``, ``is_leaf``, ``children`` and ``edge_tensors``.

        Returns:
            A 3-tuple ``(losses, all_geo, all_leaf_geo)``.

            * ``losses`` is a dict with keys ``'leaf'``, ``'geo'``,
              ``'center'``, ``'scale'``, ``'latent'``, ``'exists'``,
              ``'semantic'``, ``'edge_exists'``, ``'sym'`` and ``'adj'``.
            * ``all_geo`` concatenates this node's predicted geometry followed
              by that of every matched descendant.
            * ``all_leaf_geo`` concatenates only leaf geometry.

            For a leaf, both geometry outputs are ``geo_global``.
        """
        gt_geo = gt_node.geo
        gt_geo_feat = gt_node.geo_feat
        # centroid of the gt points (mean over dim 1); gt_geo is presumably
        # (1, num_point, 3) — TODO confirm
        gt_center = gt_geo.mean(dim=1)
        # gt "radius": largest distance of any gt point from the centroid,
        # reshaped to (1, 1), so one shape per call is assumed
        gt_scale = (gt_geo - gt_center.unsqueeze(dim=1).repeat(
            1, self.conf.num_point, 1)).pow(2).sum(dim=2).max(
                dim=1)[0].sqrt().view(1, 1)

        geo_local, geo_center, geo_scale, geo_feat = self.node_decoder(
            node_latent)
        # presumably maps the decoded local points into the global frame
        # using the decoded center/scale (see geoToGlobal)
        geo_global = self.geoToGlobal(geo_local, geo_center, geo_scale)

        # geo loss for the current part
        latent_loss = self.mseLoss(geo_feat, gt_geo_feat).mean()
        center_loss = self.mseLoss(geo_center, gt_center).mean()
        scale_loss = self.mseLoss(geo_scale, gt_scale).mean()
        geo_loss = self.chamferDist(geo_global, gt_geo)

        if gt_node.is_leaf:
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(
                is_leaf_logit,
                is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
            # structure-related terms are zero for a leaf; a leaf's geometry is
            # both its "all geo" and its "leaf geo"
            return {
                'leaf': is_leaf_loss,
                'geo': geo_loss,
                'center': center_loss,
                'scale': scale_loss,
                'latent': latent_loss,
                'exists': torch.zeros_like(geo_loss),
                'semantic': torch.zeros_like(geo_loss),
                'edge_exists': torch.zeros_like(geo_loss),
                'sym': torch.zeros_like(geo_loss),
                'adj': torch.zeros_like(geo_loss)
            }, geo_global, geo_global
        else:
            child_feats, child_sem_logits, child_exists_logits, edge_exists_logits = \
                    self.child_decoder(node_latent)

            # generate geo prediction for each child
            feature_len = node_latent.size(1)

            # all_geo starts with this node's own geometry; children's
            # geometries are appended during the recursion below
            all_geo = []
            all_leaf_geo = []
            all_geo.append(geo_global)

            child_pred_geo_local, child_pred_geo_center, child_pred_geo_scale, _ = \
                    self.node_decoder(child_feats.view(-1, feature_len))
            child_pred_geo = self.geoToGlobal(child_pred_geo_local,
                                              child_pred_geo_center,
                                              child_pred_geo_scale)
            num_pred = child_pred_geo.size(0)

            # perform hungarian matching between pred geo and gt geo
            with torch.no_grad():
                child_gt_geo = torch.cat(
                    [child_node.geo for child_node in gt_node.children], dim=0)
                num_gt = child_gt_geo.size(0)

                # every (gt, pred) pair of point clouds, flattened for a
                # batched Chamfer distance
                child_pred_geo_tiled = child_pred_geo.unsqueeze(dim=0).repeat(
                    num_gt, 1, 1, 1)
                child_gt_geo_tiled = child_gt_geo.unsqueeze(dim=1).repeat(
                    1, num_pred, 1, 1)

                dist_mat = self.chamferDist(child_pred_geo_tiled.view(-1, self.conf.num_point, 3), \
                        child_gt_geo_tiled.view(-1, self.conf.num_point, 3)).view(1, num_gt, num_pred)
                _, matched_gt_idx, matched_pred_idx = linear_assignment(
                    dist_mat)

                # get edge ground truth
                edge_type_list_gt, edge_indices_gt = gt_node.edge_tensors(
                    edge_types=self.conf.edge_types,
                    device=child_feats.device,
                    type_onehot=False)

                # gt child index -> matched pred slot index
                gt2pred = {
                    gt_idx: pred_idx
                    for gt_idx, pred_idx in zip(matched_gt_idx,
                                                matched_pred_idx)
                }
                edge_exists_gt = torch.zeros_like(edge_exists_logits)

                # per matched SYM edge: endpoints, 3x4 transform and type;
                # per matched ADJ edge: endpoints
                # NOTE(review): sym_type is filled but not read in this method
                sym_from = []
                sym_to = []
                sym_mat = []
                sym_type = []
                adj_from = []
                adj_to = []
                # Only the first half of the edge list is visited; presumably
                # gt edges are stored in both directions — TODO confirm
                # against gt_node.edge_tensors. Targets are set in both
                # directions below.
                for i in range(edge_indices_gt.shape[1] // 2):
                    if edge_indices_gt[
                            0, i, 0].item() not in gt2pred or edge_indices_gt[
                                0, i, 1].item() not in gt2pred:
                        """
                            one of the adjacent nodes of the current gt edge was not matched 
                            to any node in the prediction, ignore this edge
                        """
                        continue

                    # correlate gt edges to pred edges
                    edge_from_idx = gt2pred[edge_indices_gt[0, i, 0].item()]
                    edge_to_idx = gt2pred[edge_indices_gt[0, i, 1].item()]
                    edge_exists_gt[:, edge_from_idx, edge_to_idx,
                                   edge_type_list_gt[0:1, i]] = 1
                    edge_exists_gt[:, edge_to_idx, edge_from_idx,
                                   edge_type_list_gt[0:1, i]] = 1

                    # compute binary edge parameters for each matched pred edge
                    if edge_type_list_gt[0, i].item() == 0:  # ADJ
                        adj_from.append(edge_from_idx)
                        adj_to.append(edge_to_idx)
                    else:  # SYM
                        # the symmetry transform is fitted on the (detached)
                        # predicted point clouds of the two endpoints
                        obb1 = compute_sym.compute_obb(
                            child_pred_geo[edge_from_idx].cpu().detach().numpy(
                            ))
                        obb2 = compute_sym.compute_obb(
                            child_pred_geo[edge_to_idx].cpu().detach().numpy())
                        if edge_type_list_gt[0, i].item() == 1:  # ROT_SYM
                            mat1to2, mat2to1 = compute_sym.compute_rot_sym(
                                obb1, obb2)
                        elif edge_type_list_gt[0, i].item() == 2:  # TRANS_SYM
                            mat1to2, mat2to1 = compute_sym.compute_trans_sym(
                                obb1, obb2)
                        else:  # REF_SYM
                            mat1to2, mat2to1 = compute_sym.compute_ref_sym(
                                obb1, obb2)
                        sym_from.append(edge_from_idx)
                        sym_to.append(edge_to_idx)
                        # stored as (1, 3, 4); used below as a 3x3 linear
                        # part [:, :, :3] plus a translation column [:, :, 3]
                        sym_mat.append(
                            torch.tensor(mat1to2,
                                         dtype=torch.float32,
                                         device=self.conf.device).view(
                                             1, 3, 4))
                        sym_type.append(edge_type_list_gt[0, i].item())

            # train the current node to be non-leaf
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(
                is_leaf_logit,
                is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

            # gather information
            child_sem_gt_labels = []
            child_sem_pred_logits = []
            child_exists_gt = torch.zeros_like(child_exists_logits)
            for i in range(len(matched_gt_idx)):
                child_sem_gt_labels.append(
                    gt_node.children[matched_gt_idx[i]].get_semantic_id())
                child_sem_pred_logits.append(
                    child_sem_logits[0, matched_pred_idx[i], :].view(1, -1))
                child_exists_gt[:, matched_pred_idx[i], :] = 1

            # train semantic labels
            child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
            child_sem_gt_labels = torch.tensor(
                child_sem_gt_labels,
                dtype=torch.int64,
                device=child_sem_pred_logits.device)
            semantic_loss = self.semCELoss(child_sem_pred_logits,
                                           child_sem_gt_labels)
            semantic_loss = semantic_loss.sum()

            # train unused geo to zero (squared predicted scale of every
            # unmatched slot)
            unmatched_scales = []
            for i in range(num_pred):
                if i not in matched_pred_idx:
                    unmatched_scales.append(child_pred_geo_scale[i:i + 1])
            if len(unmatched_scales) > 0:
                unmatched_scales = torch.cat(unmatched_scales, dim=0)
                unused_geo_loss = unmatched_scales.pow(2).sum()
            else:
                unused_geo_loss = 0.0

            # train exist scores
            child_exists_loss = F.binary_cross_entropy_with_logits(\
                input=child_exists_logits, target=child_exists_gt, reduction='none')
            child_exists_loss = child_exists_loss.sum()

            # train edge exists scores
            edge_exists_loss = F.binary_cross_entropy_with_logits(\
                    input=edge_exists_logits, target=edge_exists_gt, reduction='none')
            edge_exists_loss = edge_exists_loss.sum()
            # rescale to make it comparable to other losses,
            # which are in the order of the number of child nodes
            edge_exists_loss = edge_exists_loss / (edge_exists_gt.shape[2] *
                                                   edge_exists_gt.shape[3])

            # compute and train binary losses
            # SYM: transform the "from" cloud with its fitted 3x4 matrix and
            # compare it with the "to" cloud (Chamfer distance, weighted x2)
            sym_loss = 0
            if len(sym_from) > 0:
                sym_from_th = torch.tensor(sym_from,
                                           dtype=torch.long,
                                           device=self.conf.device)
                pc_from = child_pred_geo[sym_from_th, :]
                sym_to_th = torch.tensor(sym_to,
                                         dtype=torch.long,
                                         device=self.conf.device)
                pc_to = child_pred_geo[sym_to_th, :]
                sym_mat_th = torch.cat(sym_mat, dim=0)
                transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat_th[:, :, :3], 1, 2)) + \
                        sym_mat_th[:, :, 3].unsqueeze(dim=1).repeat(1, pc_from.size(1), 1)
                loss = self.chamferDist(transformed_pc_from, pc_to) * 2
                sym_loss = loss.sum()

            # ADJ: per edge, sum of the smallest entries of the two distance
            # outputs of self.chamferLoss (pulls the closest points together)
            adj_loss = 0
            if len(adj_from) > 0:
                adj_from_th = torch.tensor(adj_from,
                                           dtype=torch.long,
                                           device=self.conf.device)
                pc_from = child_pred_geo[adj_from_th, :]
                adj_to_th = torch.tensor(adj_to,
                                         dtype=torch.long,
                                         device=self.conf.device)
                pc_to = child_pred_geo[adj_to_th, :]
                dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                loss = (dist1.min(dim=1)[0] + dist2.min(dim=1)[0])
                adj_loss = loss.sum()

            # call children + aggregate losses
            # pred slot index -> geometry returned by that child's recursion
            pred2allgeo = dict()
            pred2allleafgeo = dict()
            for i in range(len(matched_gt_idx)):
                child_losses, child_all_geo, child_all_leaf_geo = self.node_recon_loss(
                    child_feats[:, matched_pred_idx[i], :],
                    gt_node.children[matched_gt_idx[i]])
                pred2allgeo[matched_pred_idx[i]] = child_all_geo
                pred2allleafgeo[matched_pred_idx[i]] = child_all_leaf_geo
                all_geo.append(child_all_geo)
                all_leaf_geo.append(child_all_leaf_geo)
                latent_loss = latent_loss + child_losses['latent']
                geo_loss = geo_loss + child_losses['geo']
                center_loss = center_loss + child_losses['center']
                scale_loss = scale_loss + child_losses['scale']
                is_leaf_loss = is_leaf_loss + child_losses['leaf']
                child_exists_loss = child_exists_loss + child_losses['exists']
                semantic_loss = semantic_loss + child_losses['semantic']
                edge_exists_loss = edge_exists_loss + child_losses[
                    'edge_exists']
                sym_loss = sym_loss + child_losses['sym']
                adj_loss = adj_loss + child_losses['adj']

            # for sym-edges, train subtree to be symmetric
            # (index 0 of each child's all_geo is that child's own geometry,
            # already handled above, so only descendants [1:] are compared;
            # the loss is weighted by the average subtree size)
            for i in range(len(sym_from)):
                s1 = pred2allgeo[sym_from[i]].size(0)
                s2 = pred2allgeo[sym_to[i]].size(0)
                if s1 > 1 and s2 > 1:
                    pc_from = pred2allgeo[sym_from[i]][1:, :, :].view(-1, 3)
                    pc_to = pred2allgeo[sym_to[i]][1:, :, :].view(-1, 3)
                    transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat[i][0, :, :3], 0, 1)) + \
                            sym_mat[i][0, :, 3].unsqueeze(dim=0).repeat(pc_from.size(0), 1)
                    dist1, dist2 = self.chamferLoss(
                        transformed_pc_from.view(1, -1, 3),
                        pc_to.view(1, -1, 3))
                    sym_loss += (dist1.mean() + dist2.mean()) * (s1 + s2) / 2

            # for adj-edges, train leaf-nodes in subtrees to be adjacent
            # (only when at least one endpoint is a non-leaf, i.e. its all_geo
            # holds more entries than its leaf geo)
            for i in range(len(adj_from)):
                if pred2allgeo[adj_from[i]].size(0) > pred2allleafgeo[adj_from[i]].size(0) \
                        or pred2allgeo[adj_to[i]].size(0) > pred2allleafgeo[adj_to[i]].size(0):
                    pc_from = pred2allleafgeo[adj_from[i]].view(1, -1, 3)
                    pc_to = pred2allleafgeo[adj_to[i]].view(1, -1, 3)
                    dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                    adj_loss += dist1.min() + dist2.min()

            return {
                'leaf': is_leaf_loss,
                'geo': geo_loss + unused_geo_loss,
                'center': center_loss,
                'scale': scale_loss,
                'latent': latent_loss,
                'exists': child_exists_loss,
                'semantic': semantic_loss,
                'edge_exists': edge_exists_loss,
                'sym': sym_loss,
                'adj': adj_loss
            }, torch.cat(all_geo, dim=0), torch.cat(all_leaf_geo, dim=0)
def _group_child_ids_by_label(node, labels):
    """Group the indices of ``node.children`` by semantic label.

    Only children whose ``label`` is contained in ``labels`` are kept; the
    indices inside each group stay in child order.

    Returns:
        dict mapping label -> list of child indices.
    """
    groups = dict()
    for node_id, cnode in enumerate(node.children):
        if cnode.label in labels:
            groups.setdefault(cnode.label, []).append(node_id)
    return groups


def _sample_box_pc(box_node, n_points=500):
    """Sample a surface point cloud from the oriented box of ``box_node``.

    Relies on the module-level ``utils`` mesh/sampling helpers and ``device``.

    Returns:
        ``(pc, obb_quat)``: ``pc`` is a float32 tensor of the sampled points on
        ``device``; ``obb_quat`` is ``box_node.get_box_quat()`` as a numpy array.
    """
    obb = box_node.box.cpu().numpy()
    obb_quat = box_node.get_box_quat().cpu().numpy()
    mesh_v, mesh_f = utils.gen_obb_mesh(obb)
    pc = utils.sample_pc(mesh_v, mesh_f, n_points=n_points)
    pc = torch.tensor(pc, dtype=torch.float32, device=device)
    return pc, obb_quat


def compute_struct_diff(gt_node, pred_node):
    """Recursively compare a predicted part hierarchy against the ground truth.

    Children of the two nodes are matched per semantic label: a single pair
    is matched directly, larger groups via ``utils.linear_assignment`` on the
    ``boxLoss`` distance matrix. Unmatched subtrees count as structural
    differences. Matched pairs are compared recursively. Ground-truth edges
    between matched children are checked for presence in the prediction and
    evaluated geometrically on the matched predicted boxes.

    Side effect: each matched predicted child receives the ``part_id`` of its
    ground-truth counterpart.

    Reads the module-level globals ``conf``, ``device``, ``boxLoss``,
    ``chamferLoss``, ``utils`` and ``compute_sym``.

    Args:
        gt_node: ground-truth tree node.
        pred_node: predicted tree node to compare against ``gt_node``.

    Returns:
        A 6-tuple ``(struct_diff, edge_both, edge_gt, edge_pred,
        gt_binary_diff, gt_binary_tot)``:
          * ``struct_diff`` -- number of boxes in subtrees that could not be
            matched (for a leaf/non-leaf mismatch, ``len(node.boxes()) - 1``).
          * ``edge_both`` -- number of ground-truth edge directions (a->b and
            b->a are counted separately) that are also present in the prediction.
          * ``edge_gt`` -- ground-truth edge count, each edge counted twice.
          * ``edge_pred`` -- predicted edge count.
          * ``gt_binary_diff`` -- accumulated geometric error of the ground-truth
            edges, measured on the matched predicted boxes.
          * ``gt_binary_tot`` -- number of ground-truth edges that were evaluated.

    Raises:
        ValueError: if a ground-truth edge has an edge-type index other than
            0 (ADJ), 1 (ROT_SYM), 2 (TRANS_SYM) or 3 (REF_SYM).
    """
    # Leaf / non-leaf mismatches: the whole extra subtree counts as a difference.
    if gt_node.is_leaf:
        if pred_node.is_leaf:
            return 0, 0, 0, 0, 0, 0
        return len(pred_node.boxes()) - 1, 0, 0, pred_node.get_subtree_edge_count(), 0, 0
    if pred_node.is_leaf:
        return len(gt_node.boxes()) - 1, 0, gt_node.get_subtree_edge_count() * 2, 0, 0, 0

    # Children can only be matched within a shared semantic label.
    gt_sem = set([node.label for node in gt_node.children])
    pred_sem = set([node.label for node in pred_node.children])
    intersect_sem = set.intersection(gt_sem, pred_sem)

    gt_cnodes_per_sem = _group_child_ids_by_label(gt_node, intersect_sem)
    pred_cnodes_per_sem = _group_child_ids_by_label(pred_node, intersect_sem)

    matched_gt_idx = []
    matched_pred_idx = []
    matched_gt2pred = np.zeros((conf.max_child_num), dtype=np.int32)
    for sem in intersect_sem:
        gt_ids = gt_cnodes_per_sem[sem]
        pred_ids = pred_cnodes_per_sem[sem]
        gt_boxes = torch.cat([gt_node.children[cid].get_box_quat() for cid in gt_ids], dim=0).to(device)
        pred_boxes = torch.cat([pred_node.children[cid].get_box_quat() for cid in pred_ids], dim=0).to(device)

        num_gt = gt_boxes.size(0)
        num_pred = pred_boxes.size(0)

        if num_gt == 1 and num_pred == 1:
            cur_matched_gt_idx = [0]
            cur_matched_pred_idx = [0]
        else:
            # Pairwise box distances (num_gt x num_pred), then Hungarian matching.
            gt_boxes_tiled = gt_boxes.unsqueeze(dim=1).repeat(1, num_pred, 1)
            pred_boxes_tiled = pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
            dmat = boxLoss(gt_boxes_tiled.view(-1, 10), pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_pred).cpu()
            _, cur_matched_gt_idx, cur_matched_pred_idx = utils.linear_assignment(dmat)

        # Translate per-label positions back into child indices.
        for local_gt, local_pred in zip(cur_matched_gt_idx, cur_matched_pred_idx):
            gt_id = gt_ids[local_gt]
            pred_id = pred_ids[local_pred]
            matched_gt_idx.append(gt_id)
            matched_pred_idx.append(pred_id)
            matched_gt2pred[gt_id] = pred_id

    struct_diff = 0.0
    edge_both = 0
    edge_gt = 0
    edge_pred = 0
    gt_binary_diff = 0.0
    gt_binary_tot = 0

    # Unmatched children: every box and edge in their subtree is a difference.
    for i, gt_child in enumerate(gt_node.children):
        if i not in matched_gt_idx:
            struct_diff += len(gt_child.boxes())
            edge_gt += gt_child.get_subtree_edge_count() * 2

    for i, pred_child in enumerate(pred_node.children):
        if i not in matched_pred_idx:
            struct_diff += len(pred_child.boxes())
            edge_pred += pred_child.get_subtree_edge_count()

    # Matched children: recurse and accumulate.
    for gt_id, pred_id in zip(matched_gt_idx, matched_pred_idx):
        (cur_struct_diff, cur_edge_both, cur_edge_gt, cur_edge_pred,
         cur_gt_binary_diff, cur_gt_binary_tot) = compute_struct_diff(
            gt_node.children[gt_id], pred_node.children[pred_id])
        gt_binary_diff += cur_gt_binary_diff
        gt_binary_tot += cur_gt_binary_tot
        struct_diff += cur_struct_diff
        edge_both += cur_edge_both
        edge_gt += cur_edge_gt
        edge_pred += cur_edge_pred
        pred_node.children[pred_id].part_id = gt_node.children[gt_id].part_id

    if pred_node.edges is not None:
        edge_pred += len(pred_node.edges)

    if gt_node.edges is not None:
        edge_gt += len(gt_node.edges) * 2
        # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the supported dtype.
        pred_edges = np.zeros((conf.max_child_num, conf.max_child_num, len(conf.edge_types)), dtype=bool)
        # The prediction may have no edge list even when the ground truth does.
        if pred_node.edges is not None:
            for edge in pred_node.edges:
                edge_type_id = conf.edge_types.index(edge['type'])
                pred_edges[edge['part_a'], edge['part_b'], edge_type_id] = True

        for edge in gt_node.edges:
            gt_part_a_id = edge['part_a']
            gt_part_b_id = edge['part_b']
            edge_type_id = conf.edge_types.index(edge['type'])
            # Edges touching an unmatched ground-truth child cannot be compared.
            if gt_part_a_id not in matched_gt_idx or gt_part_b_id not in matched_gt_idx:
                continue

            pred_part_a_id = matched_gt2pred[gt_part_a_id]
            pred_part_b_id = matched_gt2pred[gt_part_b_id]
            # Count each direction of the ground-truth edge present in the prediction.
            edge_both += int(pred_edges[pred_part_a_id, pred_part_b_id, edge_type_id])
            edge_both += int(pred_edges[pred_part_b_id, pred_part_a_id, edge_type_id])

            # Geometric evaluation of the ground-truth edge on the matched predicted boxes.
            pc1, obb_quat1 = _sample_box_pc(pred_node.children[pred_part_a_id])
            pc2, obb_quat2 = _sample_box_pc(pred_node.children[pred_part_b_id])
            if edge_type_id == 0:  # ADJ: closest-point distance between the two boxes
                dist1, dist2 = chamferLoss(pc1.view(1, -1, 3), pc2.view(1, -1, 3))
                gt_binary_diff += (dist1.sqrt().min().item() + dist2.sqrt().min().item()) / 2
            else:  # SYM: map box 1 onto box 2 and measure the mean chamfer distance
                if edge_type_id == 1:  # ROT_SYM
                    sym_fn = compute_sym.compute_rot_sym
                elif edge_type_id == 2:  # TRANS_SYM
                    sym_fn = compute_sym.compute_trans_sym
                elif edge_type_id == 3:  # REF_SYM
                    sym_fn = compute_sym.compute_ref_sym
                else:
                    # Fail loudly: otherwise mat1to2 would be undefined or stale
                    # from a previously evaluated edge.
                    raise ValueError('ERROR: unknown symmetry type: %s' % edge['type'])
                mat1to2, _ = sym_fn(obb_quat1.reshape(-1), obb_quat2.reshape(-1))
                mat1to2 = torch.tensor(mat1to2, dtype=torch.float32, device=device)
                transformed_pc1 = pc1.matmul(torch.transpose(mat1to2[:, :3], 0, 1)) + \
                        mat1to2[:, 3].unsqueeze(dim=0).repeat(pc1.size(0), 1)
                dist1, dist2 = chamferLoss(transformed_pc1.view(1, -1, 3), pc2.view(1, -1, 3))
                loss = (dist1.sqrt().mean() + dist2.sqrt().mean()) / 2
                gt_binary_diff += loss.item()
            gt_binary_tot += 1

    return struct_diff, edge_both, edge_gt, edge_pred, gt_binary_diff, gt_binary_tot
Example #7
0
    def node_recon_loss(self, node_latent, gt_node):
        """Compute reconstruction losses for the subtree rooted at ``gt_node``.

        The node's latent code is decoded and compared with ``gt_node``. For a
        non-leaf node, the predicted children are matched to the ground-truth
        children with ``linear_assignment`` on a box-distance matrix. The method
        then recurses into every matched pair.

        Args:
            node_latent: latent feature of the current node. ``node_latent.size(1)``
                is used as the feature length (presumably shape (1, feature_len)
                -- TODO confirm).
            gt_node: ground-truth tree node providing ``is_leaf``, ``children``,
                ``get_box_quat()`` and ``edge_tensors()``. Its children also
                provide ``get_semantic_id()``.

        Returns:
            A tuple ``(losses, all_boxes, all_leaf_boxes)``:
              * ``losses`` -- dict with keys 'box', 'leaf', 'anchor', 'exists',
                'semantic', 'edge_exists', 'sym', 'adj', accumulated over the
                whole subtree.
              * ``all_boxes`` -- this node's predicted box followed by the boxes
                returned for all matched descendants, concatenated along dim 0.
              * ``all_leaf_boxes`` -- predicted boxes of the leaves reached by the
                recursion, concatenated along dim 0.
        """
        if gt_node.is_leaf:
            # Leaf: only the box, anchor and is-leaf terms apply; the rest are zero.
            box = self.box_decoder(node_latent)
            box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
            anchor_loss = self.anchorLossEstimator(box, gt_node.get_box_quat().view(1, -1))
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
            # The leaf's own box is both its "all boxes" and its "leaf boxes".
            return {'box': box_loss, 'leaf': is_leaf_loss, 'anchor': anchor_loss, 
                    'exists': torch.zeros_like(box_loss), 'semantic': torch.zeros_like(box_loss),
                    'edge_exists': torch.zeros_like(box_loss),
                    'sym': torch.zeros_like(box_loss), 'adj': torch.zeros_like(box_loss)}, box, box
        else:
            child_feats, child_sem_logits, child_exists_logits, edge_exists_logits = \
                    self.child_decoder(node_latent)

            # generate box prediction for each child
            feature_len = node_latent.size(1)
            child_pred_boxes = self.box_decoder(child_feats.view(-1, feature_len))
            num_child_parts = child_pred_boxes.size(0)

            # perform hungarian matching between pred boxes and gt boxes
            # (no gradients: the matching only selects which pairs are supervised)
            with torch.no_grad():
                child_gt_boxes = torch.cat([child_node.get_box_quat().view(1, -1) for child_node in gt_node.children], dim=0)
                num_gt = child_gt_boxes.size(0)

                # Tile to all (gt, pred) pairs; each box is a 10-vector here (see view(-1, 10)).
                child_pred_boxes_tiled = child_pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
                child_gt_boxes_tiled = child_gt_boxes.unsqueeze(dim=1).repeat(1, num_child_parts, 1)

                dist_mat = self.boxLossEstimator(child_gt_boxes_tiled.view(-1, 10), child_pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_child_parts)

                _, matched_gt_idx, matched_pred_idx = linear_assignment(dist_mat)

                # get edge ground truth
                edge_type_list_gt, edge_indices_gt = gt_node.edge_tensors(
                    edge_types=self.conf.edge_types, device=child_feats.device, type_onehot=False)

                gt2pred = {gt_idx: pred_idx for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx)}
                edge_exists_gt = torch.zeros_like(edge_exists_logits)

                sym_from = []; sym_to = []; sym_mat = []; sym_type = []; adj_from = []; adj_to = [];
                # Only the first half of the edge list is visited. Presumably edge_tensors
                # lists each edge in both directions -- TODO confirm.
                for i in range(edge_indices_gt.shape[1]//2):
                    if edge_indices_gt[0, i, 0].item() not in gt2pred or edge_indices_gt[0, i, 1].item() not in gt2pred:
                        """
                            one of the adjacent nodes of the current gt edge was not matched 
                            to any node in the prediction, ignore this edge
                        """
                        continue
                    
                    # correlate gt edges to pred edges
                    # (the edge-exists target is set in both directions)
                    edge_from_idx = gt2pred[edge_indices_gt[0, i, 0].item()]
                    edge_to_idx = gt2pred[edge_indices_gt[0, i, 1].item()]
                    edge_exists_gt[:, edge_from_idx, edge_to_idx, edge_type_list_gt[0:1, i]] = 1
                    edge_exists_gt[:, edge_to_idx, edge_from_idx, edge_type_list_gt[0:1, i]] = 1

                    # compute binary edge parameters for each matched pred edge
                    # (edge type ids: 0 ADJ, 1 ROT_SYM, 2 TRANS_SYM, anything else REF_SYM)
                    if edge_type_list_gt[0, i].item() == 0: # ADJ
                        adj_from.append(edge_from_idx)
                        adj_to.append(edge_to_idx)
                    else:   # SYM
                        # The symmetry transform is estimated from the detached *predicted* boxes.
                        if edge_type_list_gt[0, i].item() == 1: # ROT_SYM
                            mat1to2, mat2to1 = compute_sym.compute_rot_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                        elif edge_type_list_gt[0, i].item() == 2: # TRANS_SYM
                            mat1to2, mat2to1 = compute_sym.compute_trans_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                        else:   # REF_SYM
                            mat1to2, mat2to1 = compute_sym.compute_ref_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                        sym_from.append(edge_from_idx)
                        sym_to.append(edge_to_idx)
                        sym_mat.append(torch.tensor(mat1to2, dtype=torch.float32, device=self.conf.device).view(1, 3, 4))
                        sym_type.append(edge_type_list_gt[0, i].item())

            # train the current node to be non-leaf
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

            # train the current node box to gt
            # (this node's own box is row 0 of all_boxes; the subtree losses below rely on that)
            all_boxes = []; all_leaf_boxes = [];
            box = self.box_decoder(node_latent)
            all_boxes.append(box)
            box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
            anchor_loss = self.anchorLossEstimator(box, gt_node.get_box_quat().view(1, -1))

            # gather information
            child_sem_gt_labels = []
            child_sem_pred_logits = []
            child_box_gt = []
            child_box_pred = []
            child_exists_gt = torch.zeros_like(child_exists_logits)
            for i in range(len(matched_gt_idx)):
                child_sem_gt_labels.append(gt_node.children[matched_gt_idx[i]].get_semantic_id())
                child_sem_pred_logits.append(child_sem_logits[0, matched_pred_idx[i], :].view(1, -1))
                child_box_gt.append(gt_node.children[matched_gt_idx[i]].get_box_quat().view(1, -1))
                child_box_pred.append(child_pred_boxes[matched_pred_idx[i], :].view(1, -1))
                child_exists_gt[:, matched_pred_idx[i], :] = 1

            # train semantic labels
            child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
            child_sem_gt_labels = torch.tensor(child_sem_gt_labels, dtype=torch.int64, \
                    device=child_sem_pred_logits.device)
            semantic_loss = self.semCELoss(child_sem_pred_logits, child_sem_gt_labels)
            semantic_loss = semantic_loss.sum()

            # train unused boxes to zeros
            # (penalizes columns 3:6 of unmatched predicted boxes -- presumably the
            # box size -- with a small 0.01 weight)
            unmatched_boxes = []
            for i in range(num_child_parts):
                if i not in matched_pred_idx:
                    unmatched_boxes.append(child_pred_boxes[i, 3:6].view(1, -1))
            if len(unmatched_boxes) > 0:
                unmatched_boxes = torch.cat(unmatched_boxes, dim=0)
                unused_box_loss = unmatched_boxes.pow(2).sum() * 0.01
            else:
                unused_box_loss = 0.0

            # train exist scores
            child_exists_loss = F.binary_cross_entropy_with_logits(\
                input=child_exists_logits, target=child_exists_gt, reduction='none')
            child_exists_loss = child_exists_loss.sum()

            # train edge exists scores
            edge_exists_loss = F.binary_cross_entropy_with_logits(\
                    input=edge_exists_logits, target=edge_exists_gt, reduction='none')
            edge_exists_loss = edge_exists_loss.sum()
            # rescale to make it comparable to other losses, 
            # which are in the order of the number of child nodes
            edge_exists_loss = edge_exists_loss / (edge_exists_gt.shape[2]*edge_exists_gt.shape[3]) 

            # compute and train binary losses
            # Symmetry: transform unit-cube samples of the "from" box by mat1to2 and
            # take a surface-reweighted chamfer distance to the "to" box.
            sym_loss = 0
            if len(sym_from) > 0:
                sym_from_th = torch.tensor(sym_from, dtype=torch.long, device=self.conf.device)
                obb_from = child_pred_boxes[sym_from_th, :]
                with torch.no_grad():
                    reweight_from = get_surface_reweighting_batch(obb_from[:, 3:6], self.unit_cube.size(0))
                pc_from = transform_pc_batch(self.unit_cube, obb_from)
                sym_to_th = torch.tensor(sym_to, dtype=torch.long, device=self.conf.device)
                obb_to = child_pred_boxes[sym_to_th, :]
                with torch.no_grad():
                    reweight_to = get_surface_reweighting_batch(obb_to[:, 3:6], self.unit_cube.size(0))
                pc_to = transform_pc_batch(self.unit_cube, obb_to)
                sym_mat_th = torch.cat(sym_mat, dim=0)
                transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat_th[:, :, :3], 1, 2)) + \
                        sym_mat_th[:, :, 3].unsqueeze(dim=1).repeat(1, pc_from.size(1), 1)
                dist1, dist2 = self.chamferLoss(transformed_pc_from, pc_to)
                # Weighted mean per edge; 1e-12 guards against an all-zero weight sum.
                loss1 = (dist1 * reweight_from).sum(dim=1) / (reweight_from.sum(dim=1) + 1e-12)
                loss2 = (dist2 * reweight_to).sum(dim=1) / (reweight_to.sum(dim=1) + 1e-12)
                loss = loss1 + loss2
                sym_loss = loss.sum()

            # Adjacency: pull the closest points of the two boxes together (min, not mean).
            adj_loss = 0
            if len(adj_from) > 0:
                adj_from_th = torch.tensor(adj_from, dtype=torch.long, device=self.conf.device)
                obb_from = child_pred_boxes[adj_from_th, :]
                pc_from = transform_pc_batch(self.unit_cube, obb_from)
                adj_to_th = torch.tensor(adj_to, dtype=torch.long, device=self.conf.device)
                obb_to = child_pred_boxes[adj_to_th, :]
                pc_to = transform_pc_batch(self.unit_cube, obb_to)
                dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                loss = (dist1.min(dim=1)[0] + dist2.min(dim=1)[0])
                adj_loss = loss.sum()

            # call children + aggregate losses
            # (pred2all* map a matched predicted child index to its subtree's boxes)
            pred2allboxes = dict(); pred2allleafboxes = dict();
            for i in range(len(matched_gt_idx)):
                child_losses, child_all_boxes, child_all_leaf_boxes = self.node_recon_loss(
                    child_feats[:, matched_pred_idx[i], :], gt_node.children[matched_gt_idx[i]])
                pred2allboxes[matched_pred_idx[i]] = child_all_boxes
                pred2allleafboxes[matched_pred_idx[i]] = child_all_leaf_boxes
                all_boxes.append(child_all_boxes)
                all_leaf_boxes.append(child_all_leaf_boxes)
                box_loss = box_loss + child_losses['box']
                anchor_loss = anchor_loss + child_losses['anchor'] 
                is_leaf_loss = is_leaf_loss + child_losses['leaf']
                child_exists_loss = child_exists_loss + child_losses['exists']
                semantic_loss = semantic_loss + child_losses['semantic']
                edge_exists_loss = edge_exists_loss + child_losses['edge_exists']
                sym_loss = sym_loss + child_losses['sym']
                adj_loss = adj_loss + child_losses['adj']

            # for sym-edges, train subtree to be symmetric
            # Row 0 of each subtree tensor is the child's own box, so [1:] keeps only
            # its descendants; the mean chamfer term is scaled by the average subtree size.
            for i in range(len(sym_from)):
                s1 = pred2allboxes[sym_from[i]].size(0)
                s2 = pred2allboxes[sym_to[i]].size(0)
                if s1 > 1 and s2 > 1:
                    obbs_from = pred2allboxes[sym_from[i]][1:, :]
                    obbs_to = pred2allboxes[sym_to[i]][1:, :]
                    pc_from = transform_pc_batch(self.unit_cube, obbs_from).view(-1, 3)
                    pc_to = transform_pc_batch(self.unit_cube, obbs_to).view(-1, 3)
                    transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat[i][0, :, :3], 0, 1)) + \
                            sym_mat[i][0, :, 3].unsqueeze(dim=0).repeat(pc_from.size(0), 1)
                    dist1, dist2 = self.chamferLoss(transformed_pc_from.view(1, -1, 3), pc_to.view(1, -1, 3))
                    sym_loss += (dist1.mean() + dist2.mean()) * (s1 + s2) / 2

            # for adj-edges, train leaf-nodes in subtrees to be adjacent
            # Only when at least one endpoint has more boxes than leaf boxes, i.e. is not
            # itself a leaf (for a leaf child both tensors hold the same single box).
            for i in range(len(adj_from)):
                if pred2allboxes[adj_from[i]].size(0) > pred2allleafboxes[adj_from[i]].size(0) \
                        or pred2allboxes[adj_to[i]].size(0) > pred2allleafboxes[adj_to[i]].size(0):
                    obbs_from = pred2allleafboxes[adj_from[i]]
                    obbs_to = pred2allleafboxes[adj_to[i]]
                    pc_from = transform_pc_batch(self.unit_cube, obbs_from).view(1, -1, 3)
                    pc_to = transform_pc_batch(self.unit_cube, obbs_to).view(1, -1, 3)
                    dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                    adj_loss += dist1.min() + dist2.min()

            return {'box': box_loss + unused_box_loss, 'leaf': is_leaf_loss, 'anchor': anchor_loss, 
                    'exists': child_exists_loss, 'semantic': semantic_loss,
                    'edge_exists': edge_exists_loss, 'sym': sym_loss, 'adj': adj_loss}, \
                            torch.cat(all_boxes, dim=0), torch.cat(all_leaf_boxes, dim=0)