Example #1
def box_dist(box_feature, gt_box_feature):
    # pose a shared unit-cube point template by the predicted and GT box features
    # (assumes module-level `unit_cube` of shape (P, 3), `utils`, and `chamfer_loss`)
    pred_box_pc = utils.transform_pc_batch(unit_cube, box_feature)
    pred_reweight = utils.get_surface_reweighting_batch(box_feature[:, 3:6], unit_cube.size(0))
    gt_box_pc = utils.transform_pc_batch(unit_cube, gt_box_feature)
    gt_reweight = utils.get_surface_reweighting_batch(gt_box_feature[:, 3:6], unit_cube.size(0))
    # bidirectional Chamfer terms, each point weighted by the area of its cube face
    dist1, dist2 = chamfer_loss(gt_box_pc, pred_box_pc)
    loss1 = (dist1 * gt_reweight).sum(dim=1) / (gt_reweight.sum(dim=1) + 1e-12)
    loss2 = (dist2 * pred_reweight).sum(dim=1) / (pred_reweight.sum(dim=1) + 1e-12)
    loss = (loss1 + loss2) / 2
    return loss
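The examples here assume a `chamfer_loss`/`chamfer_distance` that returns per-point squared nearest-neighbor distances in both directions. A minimal O(P·Q) reference sketch of that contract (the real implementation is likely a fused CUDA kernel; the return layout is inferred from how `dist1`/`dist2` are used above):

import torch

def chamfer_loss(pc1, pc2):
    # pc1: (B, P, 3), pc2: (B, Q, 3)
    diff = pc1.unsqueeze(2) - pc2.unsqueeze(1)   # (B, P, Q, 3) pairwise offsets
    d2 = (diff * diff).sum(dim=3)                # (B, P, Q) squared distances
    dist1 = d2.min(dim=2)[0]                     # (B, P): each pc1 point -> nearest in pc2
    dist2 = d2.min(dim=1)[0]                     # (B, Q): each pc2 point -> nearest in pc1
    return dist1, dist2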
Example #2
    def get_anchor_loss(self, box_size, quat1, quat2, center1, center2):
        # assemble 10-D box features: [center(3), size(3), quaternion(4)]
        box1 = torch.cat([center1, box_size, quat1], dim=-1)
        box2 = torch.cat([center2, box_size, quat2], dim=-1)

        anchor1_pc = transform_pc_batch(self.unit_anchor, box1, anchor=True)
        anchor2_pc = transform_pc_batch(self.unit_anchor, box2, anchor=True)

        # symmetric Chamfer distance between the two posed anchor point sets
        d1, d2 = chamfer_distance(anchor1_pc, anchor2_pc, transpose=False)
        loss_per_data = (d1.mean(dim=1) + d2.mean(dim=1)) / 2
        return loss_per_data
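`transform_pc_batch` poses a shared point template by a 10-D box feature [center(3), size(3), quaternion(4)]. A sketch of that behavior, assuming a w-first quaternion layout and that `anchor=True` applies the rotation only (both are assumptions, not confirmed by this listing):

import torch

def qrot(q, v):
    # rotate vectors v (N, 3) by unit quaternions q (N, 4), w-first layout assumed
    uv = torch.cross(q[:, 1:], v, dim=1)
    uuv = torch.cross(q[:, 1:], uv, dim=1)
    return v + 2.0 * (q[:, :1] * uv + uuv)

def transform_pc_batch(pc, feat, anchor=False):
    # pc: (P, 3) template; feat: (B, 10) box features
    B, P = feat.size(0), pc.size(0)
    pts = pc.unsqueeze(0).repeat(B, 1, 1)
    if not anchor:
        pts = pts * feat[:, 3:6].unsqueeze(1)    # scale by box size
    quat = feat[:, 6:].unsqueeze(1).repeat(1, P, 1)
    pts = qrot(quat.reshape(-1, 4), pts.reshape(-1, 3)).view(B, P, 3)
    if not anchor:
        pts = pts + feat[:, :3].unsqueeze(1)     # translate to box center
    return pts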
Example #3
    def anchorLossEstimator(self, box_feature, gt_box_feature):
        # rotation-sensitive anchor-point version of the box loss
        pred_anchor_pc = transform_pc_batch(self.anchor, box_feature, anchor=True)
        gt_anchor_pc = transform_pc_batch(self.anchor, gt_box_feature, anchor=True)
        dist1, dist2 = self.chamferLoss(gt_anchor_pc, pred_anchor_pc)
        loss = (dist1.mean(dim=1) + dist2.mean(dim=1)) / 2
        return loss

    def boxLossEstimator(self, box_feature, gt_box_feature):
        pred_box_pc = transform_pc_batch(self.unit_cube, box_feature)
        with torch.no_grad():
            # face-area weights are detached so they act as constants in the loss
            pred_reweight = get_surface_reweighting_batch(box_feature[:, 3:6], self.unit_cube.size(0))
        gt_box_pc = transform_pc_batch(self.unit_cube, gt_box_feature)
        with torch.no_grad():
            gt_reweight = get_surface_reweighting_batch(gt_box_feature[:, 3:6], self.unit_cube.size(0))
        dist1, dist2 = self.chamferLoss(gt_box_pc, pred_box_pc)
        loss1 = (dist1 * gt_reweight).sum(dim=1) / (gt_reweight.sum(dim=1) + 1e-12)
        loss2 = (dist2 * pred_reweight).sum(dim=1) / (pred_reweight.sum(dim=1) + 1e-12)
        loss = (loss1 + loss2) / 2
        return loss
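`get_surface_reweighting_batch` receives the box sizes (`[:, 3:6]`) and the number of template points, and its output weights the per-point Chamfer terms. A plausible sketch consistent with that usage: weight each point by the area of the cube face it lies on, so degenerate (thin) faces contribute little. The six-way face ordering of the template is an assumption:

import torch

def get_surface_reweighting_batch(size, cube_num_point):
    # size: (B, 3) box extents; returns (B, cube_num_point) normalized weights
    assert cube_num_point % 6 == 0  # template assumed split into six equal face groups
    n = cube_num_point // 6
    x, y, z = size[:, 0], size[:, 1], size[:, 2]
    out = torch.cat([
        (x * y).unsqueeze(1).repeat(1, 2 * n),   # +z / -z faces
        (y * z).unsqueeze(1).repeat(1, 2 * n),   # +x / -x faces
        (x * z).unsqueeze(1).repeat(1, 2 * n),   # +y / -y faces
    ], dim=1)
    return out / (out.sum(dim=1, keepdim=True) + 1e-12)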
Example #5
    def get_box_loss(self, box_size, quat1, quat2, center1, center2):
        box1 = torch.cat([center1, box_size, quat1], dim=-1)
        box2 = torch.cat([center2, box_size, quat2], dim=-1)

        box1_pc = transform_pc_batch(self.unit_cube, box1)
        box2_pc = transform_pc_batch(self.unit_cube, box2)

        with torch.no_grad():
            box1_reweight = get_surface_reweighting_batch(
                box1[:, 3:6], self.unit_cube.size(0))
            box2_reweight = get_surface_reweighting_batch(
                box2[:, 3:6], self.unit_cube.size(0))

        d1, d2 = chamfer_distance(box1_pc, box2_pc, transpose=False)
        loss_per_data = (d1 * box1_reweight).sum(dim=1) / (box1_reweight.sum(dim=1) + 1e-12) + \
                (d2 * box2_reweight).sum(dim=1) / (box2_reweight.sum(dim=1) + 1e-12)
        return loss_per_data
    def node_recon_loss(self, node_latent, gt_node):
        if gt_node.is_leaf:
            box = self.box_decoder(node_latent)
            box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
            anchor_loss = self.anchorLossEstimator(box, gt_node.get_box_quat().view(1, -1))
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))
            return {'box': box_loss, 'leaf': is_leaf_loss, 'anchor': anchor_loss, 
                    'exists': torch.zeros_like(box_loss), 'semantic': torch.zeros_like(box_loss),
                    'edge_exists': torch.zeros_like(box_loss),
                    'sym': torch.zeros_like(box_loss), 'adj': torch.zeros_like(box_loss)}, box, box
        else:
            child_feats, child_sem_logits, child_exists_logits, edge_exists_logits = \
                    self.child_decoder(node_latent)

            # generate box prediction for each child
            feature_len = node_latent.size(1)
            child_pred_boxes = self.box_decoder(child_feats.view(-1, feature_len))
            num_child_parts = child_pred_boxes.size(0)

            # perform hungarian matching between pred boxes and gt boxes
            with torch.no_grad():
                child_gt_boxes = torch.cat([child_node.get_box_quat().view(1, -1) for child_node in gt_node.children], dim=0)
                num_gt = child_gt_boxes.size(0)

                child_pred_boxes_tiled = child_pred_boxes.unsqueeze(dim=0).repeat(num_gt, 1, 1)
                child_gt_boxes_tiled = child_gt_boxes.unsqueeze(dim=1).repeat(1, num_child_parts, 1)

                dist_mat = self.boxLossEstimator(child_gt_boxes_tiled.view(-1, 10), child_pred_boxes_tiled.view(-1, 10)).view(-1, num_gt, num_child_parts)

                _, matched_gt_idx, matched_pred_idx = linear_assignment(dist_mat)

                # get edge ground truth
                edge_type_list_gt, edge_indices_gt = gt_node.edge_tensors(
                    edge_types=self.conf.edge_types, device=child_feats.device, type_onehot=False)

                gt2pred = {gt_idx: pred_idx for gt_idx, pred_idx in zip(matched_gt_idx, matched_pred_idx)}
                edge_exists_gt = torch.zeros_like(edge_exists_logits)

                sym_from = []; sym_to = []; sym_mat = []; sym_type = []; adj_from = []; adj_to = []
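                # GT edges are presumably listed in both directions, hence only the first half is traversed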
                for i in range(edge_indices_gt.shape[1]//2):
                    if edge_indices_gt[0, i, 0].item() not in gt2pred or edge_indices_gt[0, i, 1].item() not in gt2pred:
                        """
                            one of the adjacent nodes of the current gt edge was not matched 
                            to any node in the prediction, ignore this edge
                        """
                        continue
                    
                    # correlate gt edges to pred edges
                    edge_from_idx = gt2pred[edge_indices_gt[0, i, 0].item()]
                    edge_to_idx = gt2pred[edge_indices_gt[0, i, 1].item()]
                    edge_exists_gt[:, edge_from_idx, edge_to_idx, edge_type_list_gt[0:1, i]] = 1
                    edge_exists_gt[:, edge_to_idx, edge_from_idx, edge_type_list_gt[0:1, i]] = 1

                    # compute binary edge parameters for each matched pred edge
                    if edge_type_list_gt[0, i].item() == 0: # ADJ
                        adj_from.append(edge_from_idx)
                        adj_to.append(edge_to_idx)
                    else:   # SYM
                        if edge_type_list_gt[0, i].item() == 1: # ROT_SYM
                            mat1to2, mat2to1 = compute_sym.compute_rot_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                        elif edge_type_list_gt[0, i].item() == 2: # TRANS_SYM
                            mat1to2, mat2to1 = compute_sym.compute_trans_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                        else:   # REF_SYM
                            mat1to2, mat2to1 = compute_sym.compute_ref_sym(child_pred_boxes[edge_from_idx].cpu().detach().numpy(), child_pred_boxes[edge_to_idx].cpu().detach().numpy())
                        sym_from.append(edge_from_idx)
                        sym_to.append(edge_to_idx)
                        sym_mat.append(torch.tensor(mat1to2, dtype=torch.float32, device=self.conf.device).view(1, 3, 4))
                        sym_type.append(edge_type_list_gt[0, i].item())

            # train the current node to be non-leaf
            is_leaf_logit = self.leaf_classifier(node_latent)
            is_leaf_loss = self.isLeafLossEstimator(is_leaf_logit, is_leaf_logit.new_tensor(gt_node.is_leaf).view(1, -1))

            # train the current node box to gt
            all_boxes = []; all_leaf_boxes = []
            box = self.box_decoder(node_latent)
            all_boxes.append(box)
            box_loss = self.boxLossEstimator(box, gt_node.get_box_quat().view(1, -1))
            anchor_loss = self.anchorLossEstimator(box, gt_node.get_box_quat().view(1, -1))

            # gather information
            child_sem_gt_labels = []
            child_sem_pred_logits = []
            child_box_gt = []
            child_box_pred = []
            child_exists_gt = torch.zeros_like(child_exists_logits)
            for i in range(len(matched_gt_idx)):
                child_sem_gt_labels.append(gt_node.children[matched_gt_idx[i]].get_semantic_id())
                child_sem_pred_logits.append(child_sem_logits[0, matched_pred_idx[i], :].view(1, -1))
                child_box_gt.append(gt_node.children[matched_gt_idx[i]].get_box_quat().view(1, -1))
                child_box_pred.append(child_pred_boxes[matched_pred_idx[i], :].view(1, -1))
                child_exists_gt[:, matched_pred_idx[i], :] = 1

            # train semantic labels
            child_sem_pred_logits = torch.cat(child_sem_pred_logits, dim=0)
            child_sem_gt_labels = torch.tensor(child_sem_gt_labels, dtype=torch.int64, \
                    device=child_sem_pred_logits.device)
            semantic_loss = self.semCELoss(child_sem_pred_logits, child_sem_gt_labels)
            semantic_loss = semantic_loss.sum()

            # train unused boxes to zeros
            unmatched_boxes = []
            for i in range(num_child_parts):
                if i not in matched_pred_idx:
                    unmatched_boxes.append(child_pred_boxes[i, 3:6].view(1, -1))
            if len(unmatched_boxes) > 0:
                unmatched_boxes = torch.cat(unmatched_boxes, dim=0)
                unused_box_loss = unmatched_boxes.pow(2).sum() * 0.01
            else:
                unused_box_loss = 0.0

            # train exist scores
            child_exists_loss = F.binary_cross_entropy_with_logits(\
                input=child_exists_logits, target=child_exists_gt, reduction='none')
            child_exists_loss = child_exists_loss.sum()

            # train edge exists scores
            edge_exists_loss = F.binary_cross_entropy_with_logits(\
                    input=edge_exists_logits, target=edge_exists_gt, reduction='none')
            edge_exists_loss = edge_exists_loss.sum()
            # rescale to make it comparable to the other losses,
            # which are on the order of the number of child nodes
            edge_exists_loss = edge_exists_loss / (edge_exists_gt.shape[2] * edge_exists_gt.shape[3])

            # compute and train binary losses
            sym_loss = 0
            if len(sym_from) > 0:
                sym_from_th = torch.tensor(sym_from, dtype=torch.long, device=self.conf.device)
                obb_from = child_pred_boxes[sym_from_th, :]
                with torch.no_grad():
                    reweight_from = get_surface_reweighting_batch(obb_from[:, 3:6], self.unit_cube.size(0))
                pc_from = transform_pc_batch(self.unit_cube, obb_from)
                sym_to_th = torch.tensor(sym_to, dtype=torch.long, device=self.conf.device)
                obb_to = child_pred_boxes[sym_to_th, :]
                with torch.no_grad():
                    reweight_to = get_surface_reweighting_batch(obb_to[:, 3:6], self.unit_cube.size(0))
                pc_to = transform_pc_batch(self.unit_cube, obb_to)
                sym_mat_th = torch.cat(sym_mat, dim=0)
                transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat_th[:, :, :3], 1, 2)) + \
                        sym_mat_th[:, :, 3].unsqueeze(dim=1).repeat(1, pc_from.size(1), 1)
                dist1, dist2 = self.chamferLoss(transformed_pc_from, pc_to)
                loss1 = (dist1 * reweight_from).sum(dim=1) / (reweight_from.sum(dim=1) + 1e-12)
                loss2 = (dist2 * reweight_to).sum(dim=1) / (reweight_to.sum(dim=1) + 1e-12)
                loss = loss1 + loss2
                sym_loss = loss.sum()

            adj_loss = 0
            if len(adj_from) > 0:
                adj_from_th = torch.tensor(adj_from, dtype=torch.long, device=self.conf.device)
                obb_from = child_pred_boxes[adj_from_th, :]
                pc_from = transform_pc_batch(self.unit_cube, obb_from)
                adj_to_th = torch.tensor(adj_to, dtype=torch.long, device=self.conf.device)
                obb_to = child_pred_boxes[adj_to_th, :]
                pc_to = transform_pc_batch(self.unit_cube, obb_to)
                dist1, dist2 = self.chamferLoss(pc_from, pc_to)
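                # adjacent parts only need to touch somewhere, so penalize the closest-point distance (min, not mean)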
                loss = (dist1.min(dim=1)[0] + dist2.min(dim=1)[0])
                adj_loss = loss.sum()

            # call children + aggregate losses
            pred2allboxes = dict(); pred2allleafboxes = dict()
            for i in range(len(matched_gt_idx)):
                child_losses, child_all_boxes, child_all_leaf_boxes = self.node_recon_loss(
                    child_feats[:, matched_pred_idx[i], :], gt_node.children[matched_gt_idx[i]])
                pred2allboxes[matched_pred_idx[i]] = child_all_boxes
                pred2allleafboxes[matched_pred_idx[i]] = child_all_leaf_boxes
                all_boxes.append(child_all_boxes)
                all_leaf_boxes.append(child_all_leaf_boxes)
                box_loss = box_loss + child_losses['box']
                anchor_loss = anchor_loss + child_losses['anchor'] 
                is_leaf_loss = is_leaf_loss + child_losses['leaf']
                child_exists_loss = child_exists_loss + child_losses['exists']
                semantic_loss = semantic_loss + child_losses['semantic']
                edge_exists_loss = edge_exists_loss + child_losses['edge_exists']
                sym_loss = sym_loss + child_losses['sym']
                adj_loss = adj_loss + child_losses['adj']

            # for sym-edges, train subtree to be symmetric
            for i in range(len(sym_from)):
                s1 = pred2allboxes[sym_from[i]].size(0)
                s2 = pred2allboxes[sym_to[i]].size(0)
                if s1 > 1 and s2 > 1:
                    obbs_from = pred2allboxes[sym_from[i]][1:, :]
                    obbs_to = pred2allboxes[sym_to[i]][1:, :]
                    pc_from = transform_pc_batch(self.unit_cube, obbs_from).view(-1, 3)
                    pc_to = transform_pc_batch(self.unit_cube, obbs_to).view(-1, 3)
                    transformed_pc_from = pc_from.matmul(torch.transpose(sym_mat[i][0, :, :3], 0, 1)) + \
                            sym_mat[i][0, :, 3].unsqueeze(dim=0).repeat(pc_from.size(0), 1)
                    dist1, dist2 = self.chamferLoss(transformed_pc_from.view(1, -1, 3), pc_to.view(1, -1, 3))
                    sym_loss += (dist1.mean() + dist2.mean()) * (s1 + s2) / 2

            # for adj-edges, train leaf-nodes in subtrees to be adjacent
            for i in range(len(adj_from)):
                if pred2allboxes[adj_from[i]].size(0) > pred2allleafboxes[adj_from[i]].size(0) \
                        or pred2allboxes[adj_to[i]].size(0) > pred2allleafboxes[adj_to[i]].size(0):
                    obbs_from = pred2allleafboxes[adj_from[i]]
                    obbs_to = pred2allleafboxes[adj_to[i]]
                    pc_from = transform_pc_batch(self.unit_cube, obbs_from).view(1, -1, 3)
                    pc_to = transform_pc_batch(self.unit_cube, obbs_to).view(1, -1, 3)
                    dist1, dist2 = self.chamferLoss(pc_from, pc_to)
                    adj_loss += dist1.min() + dist2.min()

            return {'box': box_loss + unused_box_loss, 'leaf': is_leaf_loss, 'anchor': anchor_loss, 
                    'exists': child_exists_loss, 'semantic': semantic_loss,
                    'edge_exists': edge_exists_loss, 'sym': sym_loss, 'adj': adj_loss}, \
                            torch.cat(all_boxes, dim=0), torch.cat(all_leaf_boxes, dim=0)
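For reference, the `linear_assignment` call above solves the Hungarian matching between GT and predicted child boxes on the (1, num_gt, num_pred) cost tensor. A minimal sketch built on SciPy's solver; the batched input and the three-list return layout (batch, GT, prediction indices) are inferred from the call site, not confirmed:

from scipy.optimize import linear_sum_assignment

def linear_assignment(dist_mat):
    # dist_mat: (B, num_gt, num_pred) cost tensor
    batch_ind, gt_ind, pred_ind = [], [], []
    for b in range(dist_mat.size(0)):
        cost = dist_mat[b].detach().cpu().numpy()
        rows, cols = linear_sum_assignment(cost)  # minimizes total matching cost
        batch_ind += [b] * len(rows)
        gt_ind += rows.tolist()
        pred_ind += cols.tolist()
    return batch_ind, gt_ind, pred_ind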