Example #1
    def get_shape_chamfer_loss(self, pts, quat1, quat2, center1, center2,
                               part_cnt):
        num_point = pts.shape[1]
        # pose every part point cloud with its two candidate (quat, center) pairs
        part_pcs1 = qrot(quat1.unsqueeze(1).repeat(1, num_point, 1),
                         pts) + center1.unsqueeze(1).repeat(1, num_point, 1)
        part_pcs2 = qrot(quat2.unsqueeze(1).repeat(1, num_point, 1),
                         pts) + center2.unsqueeze(1).repeat(1, num_point, 1)
        t = 0
        shape_pcs1 = []
        shape_pcs2 = []
        # group the posed parts into whole shapes, then use farthest point
        # sampling (no gradients needed) to downsample each shape to 2048 points
        for cnt in part_cnt:
            cur_shape_pc1 = part_pcs1[t:t + cnt].view(1, -1, 3)
            cur_shape_pc2 = part_pcs2[t:t + cnt].view(1, -1, 3)
            with torch.no_grad():
                idx1 = furthest_point_sample(cur_shape_pc1, 2048).long()[0]
                idx2 = furthest_point_sample(cur_shape_pc2, 2048).long()[0]
            shape_pcs1.append(cur_shape_pc1[:, idx1])
            shape_pcs2.append(cur_shape_pc2[:, idx2])
            t += cnt
        shape_pcs1 = torch.cat(shape_pcs1, dim=0)  # num_shapes x 2048 x 3
        shape_pcs2 = torch.cat(shape_pcs2, dim=0)  # num_shapes x 2048 x 3

        dist1, dist2 = chamfer_distance(shape_pcs1,
                                        shape_pcs2,
                                        transpose=False)
        loss_per_data = torch.mean(dist1, dim=1) + torch.mean(dist2, dim=1)

        return loss_per_data
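
A note on qrot: it is a project helper for batched quaternion rotation, not shown on this page. A minimal pure-PyTorch stand-in, assuming unit quaternions in (w, x, y, z) order (the QuaterNet convention); this sketch is an assumption, not the repository's implementation:

import torch

def qrot(q, v):
    # rotate points v (..., 3) by unit quaternions q (..., 4) in (w, x, y, z)
    # order (assumed convention): v' = v + 2*(w*(u x v) + u x (u x v))
    u = q[..., 1:]
    uv = torch.cross(u, v, dim=-1)
    uuv = torch.cross(u, uv, dim=-1)
    return v + 2 * (q[..., :1] * uv + uuv)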
Example #2
    def get_quat_loss(self, pts, quat1, quat2):
        num_point = pts.shape[1]

        # rotation-only poses: compare the two quaternions via the chamfer
        # distance between the point clouds they produce (this is invariant
        # to the q / -q sign ambiguity of quaternions)
        pts1 = qrot(quat1.unsqueeze(1).repeat(1, num_point, 1), pts)
        pts2 = qrot(quat2.unsqueeze(1).repeat(1, num_point, 1), pts)

        dist1, dist2 = chamfer_distance(pts1, pts2, transpose=False)
        loss_per_data = torch.mean(dist1, dim=1) + torch.mean(dist2, dim=1)
        return loss_per_data
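
chamfer_distance is likewise project code (it accepts transpose and, in Example #6, sqrt keyword arguments). A brute-force stand-in for inputs already shaped (batch, num_points, 3), offered as an illustrative sketch rather than the repository's kernel:

import torch

def chamfer_distance(a, b):
    # a: (B, N, 3), b: (B, M, 3); squared nearest-neighbor distances both ways
    d = torch.cdist(a, b) ** 2   # (B, N, M) pairwise squared distances
    dist1 = d.min(dim=2)[0]      # for each point of a, its nearest point in b
    dist2 = d.min(dim=1)[0]      # for each point of b, its nearest point in a
    return dist1, dist2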
Example #3
    def get_anchor_loss(self, box_size, quat1, quat2, center1, center2):
        box1 = torch.cat([center1, box_size, quat1], dim=-1)
        box2 = torch.cat([center2, box_size, quat2], dim=-1)

        # pose a small set of canonical anchor points with each box's pose
        anchor1_pc = transform_pc_batch(self.unit_anchor, box1, anchor=True)
        anchor2_pc = transform_pc_batch(self.unit_anchor, box2, anchor=True)

        d1, d2 = chamfer_distance(anchor1_pc, anchor2_pc, transpose=False)
        loss_per_data = (d1.mean(dim=1) + d2.mean(dim=1)) / 2
        return loss_per_data
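
transform_pc_batch is a utility in the StructureNet style: it poses a canonical point set with each box's parameters. A sketch under the assumption that box features are packed as [center(3), size(3), quat(4)] (which matches the torch.cat calls above) and that anchor=True skips the scaling step; both details are assumptions, and the sketch reuses the hypothetical qrot from Example #1:

def transform_pc_batch(pc, box_feat, anchor=False):
    # pc: (P, 3) canonical points; box_feat: (B, 10) = [center, size, quat]
    B, P = box_feat.shape[0], pc.shape[0]
    pc = pc.unsqueeze(0).repeat(B, 1, 1)                    # (B, P, 3)
    center = box_feat[:, 0:3].unsqueeze(1)                  # (B, 1, 3)
    size = box_feat[:, 3:6].unsqueeze(1)                    # (B, 1, 3)
    quat = box_feat[:, 6:10].unsqueeze(1).repeat(1, P, 1)   # (B, P, 4)
    if not anchor:
        pc = pc * size    # scale the unit shape to the box extents
    return qrot(quat, pc) + center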
Example #4
    def get_box_loss(self, box_size, quat1, quat2, center1, center2):
        box1 = torch.cat([center1, box_size, quat1], dim=-1)
        box2 = torch.cat([center2, box_size, quat2], dim=-1)

        box1_pc = transform_pc_batch(self.unit_cube, box1)
        box2_pc = transform_pc_batch(self.unit_cube, box2)

        with torch.no_grad():
            # per-point weights so that the chamfer terms approximate a
            # surface-area-weighted distance over the box faces
            box1_reweight = get_surface_reweighting_batch(
                box1[:, 3:6], self.unit_cube.size(0))
            box2_reweight = get_surface_reweighting_batch(
                box2[:, 3:6], self.unit_cube.size(0))

        d1, d2 = chamfer_distance(box1_pc, box2_pc, transpose=False)
        loss_per_data = (d1 * box1_reweight).sum(dim=1) / (box1_reweight.sum(dim=1) + 1e-12) + \
                        (d2 * box2_reweight).sum(dim=1) / (box2_reweight.sum(dim=1) + 1e-12)
        return loss_per_data
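
get_surface_reweighting_batch plausibly weights each cube sample by the area of the face it lies on, so the weighted chamfer sum approximates a surface integral over the box. A from-scratch sketch, assuming the unit cube is sampled with num_point // 6 points per face and that box[:, 3:6] holds the box sizes (as the call above suggests); both are assumptions of this sketch:

import torch

def get_surface_reweighting_batch(size, num_point):
    # size: (B, 3) box extents; num_point: samples on the unit cube, assumed
    # to be evenly split across the 6 faces (an assumption of this sketch)
    x, y, z = size[:, 0], size[:, 1], size[:, 2]
    npf = num_point // 6
    w = torch.cat([(x * y).unsqueeze(1).repeat(1, npf * 2),   # +/- z faces
                   (y * z).unsqueeze(1).repeat(1, npf * 2),   # +/- x faces
                   (x * z).unsqueeze(1).repeat(1, npf * 2)],  # +/- y faces
                  dim=1)
    return w / (w.sum(dim=1, keepdim=True) + 1e-12)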
Example #5
    def get_adj_loss(self, pts, quat, center, adjs):
        num_point = pts.shape[1]

        part_pcs = qrot(quat.unsqueeze(1).repeat(1, num_point, 1),
                        pts) + center.unsqueeze(1).repeat(1, num_point, 1)

        loss = []
        for cur_shape_adj in adjs:
            cur_shape_loss = []
            for adj in cur_shape_adj:
                idx1, idx2 = adj
                dist1, dist2 = chamfer_distance(part_pcs[idx1].unsqueeze(0),
                                                part_pcs[idx2].unsqueeze(0),
                                                transpose=False)
                # min (not mean) over points: penalize only the closest pair,
                # so adjacent parts are encouraged to touch
                cur_loss = torch.min(dist1, dim=1)[0] + torch.min(dist2,
                                                                  dim=1)[0]
                cur_shape_loss.append(cur_loss)
            loss.append(torch.stack(cur_shape_loss).mean())
        return loss
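
The adjs argument is a list of per-shape lists of index pairs into the flattened part dimension. A hypothetical call (model, shapes, and indices invented for illustration):

# two shapes: parts 0-2 form the first, parts 3-4 the second
adjs = [[(0, 1), (1, 2)], [(3, 4)]]
per_shape_losses = model.get_adj_loss(pts, quat, center, adjs)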
Example #6
                fn = input_shape_id[t] + '-' + input_view_id[t] + '.npy'
                np.save(os.path.join(test_res_matched_dir, fn), to_save)
                print('saved', fn)
                t += cnt

        # get metric stats
        pred_pts = qrot(
            matched_pred_quat2_all.unsqueeze(1).repeat(1, num_point, 1),
            input_pts) + matched_pred_center2_all.unsqueeze(1).repeat(
                1, num_point, 1)
        gt_pts = qrot(
            matched_gt_quat_all.unsqueeze(1).repeat(1, num_point, 1),
            input_pts) + matched_gt_center_all.unsqueeze(1).repeat(
                1, num_point, 1)
        dist1, dist2 = chamfer_distance(gt_pts,
                                        pred_pts,
                                        transpose=False,
                                        sqrt=True)
        # symmetric chamfer distance per part: average the two directions
        dist = torch.mean(dist1, dim=1).cpu().numpy() + torch.mean(
            dist2, dim=1).cpu().numpy()
        dist /= 2.0
        batch_size = input_box_size.shape[0]
        correct = 0
        cur_batch_accu = []
        cur_batch_correct = []
        t = 0
        for cnt in input_total_part_cnt:
            cur_shape_correct = 0
            for i in range(cnt):
                print('part dist', dist[t + i], 'thresh', eval_conf.thresh,
                      'correct', dist[t + i] < eval_conf.thresh)
                if dist[t + i] < eval_conf.thresh:
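                    # hypothetical continuation of the truncated snippet,
                    # consistent with the counters initialized above but
                    # not taken from the source:
                    correct += 1
                    cur_shape_correct += 1
            cur_batch_correct.append(cur_shape_correct)
            cur_batch_accu.append(cur_shape_correct / cnt)
            t += cnt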
Example #7
    def linear_assignment(self, mask1, mask2, similar_cnt, pts, centers1,
                          quats1, centers2, quats2):
        '''
            mask1, mask2:
                part_cnt x 224 x 224 part masks
            similar_cnt:
                shape: max_mask x 2
                first column: part index, repeated once per similar instance,
                    e.g. 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 6, ...
                second column: size of that part's similarity group,
                    e.g. 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4, ...
        '''

        inds1 = []
        inds2 = []
        img_size = mask1.shape[-1]
        t = 0

        num_point = pts.shape[1]
        max_num_part = centers1.shape[1]

        # part_cnt = [item for sublist in part_cnt for item in sublist]

        with torch.no_grad():
            while t < similar_cnt.shape[0]:

                cnt = similar_cnt[t].item()
                # base offset of this similarity group in the flat part list
                bids = [t] * cnt
                cur_mask1 = mask1[t:t + cnt].unsqueeze(1).repeat(
                    1, cnt, 1, 1).view(-1, img_size, img_size)
                cur_mask2 = mask2[t:t + cnt].unsqueeze(0).repeat(
                    cnt, 1, 1, 1).view(-1, img_size, img_size)

                dist_mat_mask = self.get_mask_loss(cur_mask1,
                                                   cur_mask2).view(cnt, cnt)

                dist_mat_mask = torch.clamp(dist_mat_mask, max=-0.1)

                cur_pts = pts[t:t + cnt]

                cur_quats1 = quats1[t:t + cnt].unsqueeze(1).repeat(
                    1, num_point, 1)
                cur_centers1 = centers1[t:t + cnt].unsqueeze(1).repeat(
                    1, num_point, 1)
                cur_pts1 = qrot(cur_quats1, cur_pts) + cur_centers1

                cur_quats2 = quats2[t:t + cnt].unsqueeze(1).repeat(
                    1, num_point, 1)
                cur_centers2 = centers2[t:t + cnt].unsqueeze(1).repeat(
                    1, num_point, 1)
                cur_pts2 = qrot(cur_quats2, cur_pts) + cur_centers2

                cur_pts1 = cur_pts1.unsqueeze(1).repeat(1, cnt, 1, 1).view(
                    -1, num_point, 3)
                cur_pts2 = cur_pts2.unsqueeze(0).repeat(cnt, 1, 1, 1).view(
                    -1, num_point, 3)

                dist1, dist2 = chamfer_distance(cur_pts1,
                                                cur_pts2,
                                                transpose=False)
                dist_mat_pts = (dist1.mean(1) + dist2.mean(1)).view(cnt, cnt)

                dist_mat_pts = torch.clamp(dist_mat_pts, max=1) * 0.1

                dist_mat = torch.add(dist_mat_mask, dist_mat_pts)

                t += cnt
                rind, cind = linear_sum_assignment(dist_mat.cpu().numpy())

                # map the local assignment back to global (flattened) indices
                ids1 = list(rind)
                ids2 = list(cind)
                inds1 += [bids[i] + ids1[i] for i in range(len(ids1))]
                inds2 += [bids[i] + ids2[i] for i in range(len(ids2))]
        return inds1, inds2
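
The Hungarian matching itself comes from scipy.optimize.linear_sum_assignment, which returns row and column indices minimizing the total cost. A small standalone illustration:

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.9, 0.1, 0.5],
                 [0.2, 0.8, 0.7],
                 [0.6, 0.4, 0.05]])
rind, cind = linear_sum_assignment(cost)
# rind = [0, 1, 2], cind = [1, 0, 2]: the pairing with total cost
# 0.1 + 0.2 + 0.05, the minimum over all one-to-one assignments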
Example #8
    def forward(self, part_pcs, part_feat_old, part_cnt, equiv_edge_indices):
        # extract per-part geometry features and fuse them with the old
        # per-part features before the graph convolution
        after_part_feat_3d = self.pointnet(part_pcs)

        part_feat_8d = torch.cat([part_feat_old, after_part_feat_3d], dim=1)
        part_feat_8d = torch.relu(self.mlp1(part_feat_8d))

        # perform graph-conv
        t = 0
        i = 0
        output_part_feat_8d = []
        for cnt in part_cnt:
            child_feats = part_feat_8d[t:t + cnt]
            iter_feats = [child_feats]

            cur_equiv_edge_indices = equiv_edge_indices[i]
            cur_equiv_edge_feats = cur_equiv_edge_indices.new_zeros(
                cur_equiv_edge_indices.shape[0], 1)

            # detect adjacency: link each part to its top-k nearest neighbors
            # by one-directional chamfer distance, computed without gradients
            topk = min(5, cnt - 1)
            with torch.no_grad():
                cur_part_pcs = part_pcs[t:t + cnt]
                A = cur_part_pcs.unsqueeze(0).repeat(cnt, 1, 1,
                                                     1).view(cnt * cnt, -1, 3)
                B = cur_part_pcs.unsqueeze(1).repeat(1, cnt, 1,
                                                     1).view(cnt * cnt, -1, 3)
                dist1, dist2 = chamfer_distance(A, B, transpose=False)
                dist = dist1.min(dim=1)[0].view(cnt, cnt)
                cur_adj_edge_indices = []
                for j in range(cnt):
                    # index 0 of the argsort is part j itself, so skip it
                    for k in dist[j].argsort()[1:topk + 1]:
                        cur_adj_edge_indices.append([j, k.item()])
                cur_adj_edge_indices = torch.Tensor(
                    cur_adj_edge_indices).long().to(
                        cur_equiv_edge_indices.device)
                cur_adj_edge_feats = cur_adj_edge_indices.new_ones(
                    cur_adj_edge_indices.shape[0], 1)

            cur_edge_indices = torch.cat(
                [cur_equiv_edge_indices, cur_adj_edge_indices], dim=0)
            cur_edge_feats = torch.cat(
                [cur_equiv_edge_feats, cur_adj_edge_feats], dim=0).float()

            for j in range(self.num_iteration):
                node_edge_feats = torch.cat([
                    child_feats[cur_edge_indices[:, 0], :],
                    child_feats[cur_edge_indices[:, 1], :], cur_edge_feats
                ], dim=1)
                node_edge_feats = torch.relu(
                    self.node_edge_op[j](node_edge_feats))
                new_child_feats = child_feats.new_zeros(cnt, 256)
                new_child_feats = torch_scatter.scatter_mean(
                    node_edge_feats,
                    cur_edge_indices[:, 0],
                    dim=0,
                    out=new_child_feats)
                child_feats = new_child_feats
                iter_feats.append(child_feats)

            all_iter_feat = torch.cat(iter_feats, dim=1)
            output_part_feat_8d.append(all_iter_feat)
            t += cnt
            i += 1

        feat = torch.cat(output_part_feat_8d, dim=0)

        center, quat = self.pose_decoder(feat)

        return center, quat
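
Message aggregation uses torch_scatter.scatter_mean, which averages all rows of node_edge_feats that share the same target node index. A small standalone illustration:

import torch
import torch_scatter

src = torch.tensor([[1.0], [3.0], [5.0], [7.0]])
index = torch.tensor([0, 0, 1, 1])
out = torch_scatter.scatter_mean(src, index, dim=0)
# out == [[2.0], [6.0]]: rows 0,1 average into slot 0, rows 2,3 into slot 1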