def test_search_key(self):
    '''Check ocnn.octree_search_key on a batch of two copies of 'octree_1'.

    The queries are pre-encoded keys (key_is_xyz=False) searched at depth 5.
    Each case lists the query keys, the node indices expected back (-1 for a
    miss), and the `nempty` flag passed to the search.
    '''
    batch = ocnn.octree_samples(['octree_1', 'octree_1'])
    octree = ocnn.octree_batch(batch).cuda()

    # NOTE(review): the (1 << 48) term presumably selects the second octree
    # of the batch (batch id in the high bits) — confirm against
    # ocnn.octree_encode_key.
    cases = (
        ([28673, (1 << 48) | 28679, 10], [1, 15, -1], False),
        ([28672, 28673, (1 << 48) | 28672, 10], [0, -1, 1, -1], True),
    )
    for keys, expected, nempty in cases:
        query = torch.cuda.LongTensor(keys)
        want = torch.cuda.IntTensor(expected)
        got = ocnn.octree_search_key(
            query, octree, 5, key_is_xyz=False, nempty=nempty)
        self.assertTrue((got == want).cpu().numpy().all())
def octree_trilinear_pts(data, octree, depth, pts):
    ''' Trilinear interpolation of octree features at the input points.

    Args:
      data: features of shape (1, C, H, 1); H indexes the octree nodes
        searched at `depth`.
      octree: the octree passed through to ocnn.octree_search_key.
      depth: octree depth at which the neighbors are searched.
      pts: (N, 4), i.e. N x (x, y, z, id).

    Returns:
      Interpolated features of shape (1, C, N, 1). Each point's value is the
      weighted average over those of its 8 surrounding voxels that exist in
      the octree; the weights are renormalized over the voxels found.

    !!! Note: the pts should be scaled into [0, 2^depth]
    '''
    device = pts.device
    # Offsets (dx, dy, dz) of the 8 voxels surrounding a point.
    mask = torch.tensor(
        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
         [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]],
        dtype=torch.float32, device=device)
    # The same offsets as increments of an encoded key: dx -> 1,
    # dy -> 2**16, dz -> 2**32. This assumes the field layout produced by
    # ocnn.octree_encode_key — confirm if that encoding ever changes.
    masku = torch.tensor(
        [0, 4294967296, 65536, 4295032832, 1, 4294967297, 65537, 4295032833],
        dtype=torch.int64, device=device)

    # 1. Neighborhood searching
    xyzf, ids = torch.split(pts, [3, 1], 1)
    xyzf = xyzf - 0.5  # since the value is defined on the center of each voxel
    xyzi = torch.floor(xyzf)  # the integer part (N, 3)
    frac = xyzf - xyzi  # the fraction part (N, 3)
    # NOTE(review): points within 0.5 of the lower boundary give xyzi == -1.
    # Whether adding the masku increments to such a key stays inside its
    # coordinate field depends on how octree_encode_key packs negative
    # shorts — confirm.
    key = torch.cat([xyzi, ids], dim=1).short()  # (N, 4)
    key = ocnn.octree_encode_key(key).long()  # (N, )
    key = (torch.unsqueeze(key, dim=1) + masku).view(-1)  # (N, 8) -> (8*N, )
    idx = ocnn.octree_search_key(key, octree, depth, True)
    flgs = idx > -1  # valid indices; -1 marks a voxel missing from the octree
    idx = idx[flgs]

    # 2. Build the sparse (N, H) weight matrix
    npt = pts.shape[0]
    rows = torch.arange(npt, device=device).view(-1, 1).repeat(1, 8).view(-1)
    rows = rows[flgs]
    indices = torch.stack([rows, idx.long()], dim=0)  # (2, nnz)
    # Per axis, the weight is (1 - frac) for the lower corner (mask 0) and
    # frac for the upper corner (mask 1). (1 - mask) - frac produces
    # (1 - frac) or (-frac), so the abs of the product restores the sign.
    diff = (1 - mask) - torch.unsqueeze(frac, dim=1)  # (N, 8, 3)
    weight = torch.abs(torch.prod(diff, dim=2)).view(-1)
    weight = weight[flgs]
    h = data.shape[2]
    mat = torch.sparse_coo_tensor(indices, weight, (npt, h))

    # 3. Interpolation
    feat = torch.squeeze(torch.squeeze(data, dim=0), dim=-1)  # (C, H)
    feat = torch.transpose(feat, 0, 1)  # (H, C)
    output = torch.sparse.mm(mat, feat)  # (N, C)
    # Sum of the weights actually found per point; the epsilon keeps points
    # with no neighbor at all from dividing by zero (their output stays 0).
    ones = torch.ones(h, 1, dtype=mat.dtype, device=mat.device)
    norm = torch.sparse.mm(mat, ones)
    output = torch.div(output, norm + 1e-10)
    output = torch.unsqueeze(torch.unsqueeze(output.t(), dim=0), dim=-1)
    return output
def octree_nearest_pts(data, octree, depth, pts, nempty=False):
    ''' Nearest-neighbor lookup of octree features at the input points.

    Args:
      data: features of shape (1, C, H, 1); H indexes the octree nodes
        searched at `depth`.
      octree: the octree passed through to ocnn.octree_search_key.
      depth: octree depth at which the points are searched.
      pts: (N, 4), i.e. N x (x, y, z, id); converted with `.short()`, which
        truncates toward zero.
      nempty: forwarded to ocnn.octree_search_key.

    Returns:
      Features of shape (1, C, N, 1). Points whose key is not found
      (search index -1) get all-zero features.
    '''
    key = pts.short()  # (x, y, z, id)
    key = ocnn.octree_encode_key(key).long()  # (N, )
    idx = ocnn.octree_search_key(key, octree, depth, True, nempty)
    flgs = idx > -1  # valid indices
    idx = idx * flgs  # map misses (-1) to 0 so the gather stays in range
    # Squeeze only the batch and trailing axes. A bare squeeze() would also
    # drop C or H when either equals 1, leaving a 1-D tensor that makes the
    # gather/mask below broadcast to the wrong (N, N) shape.
    feat = torch.squeeze(torch.squeeze(data, dim=0), dim=-1).t()  # (H, C)
    output = feat[idx.long()] * flgs.unsqueeze(-1)  # (N, C); zero for misses
    output = torch.unsqueeze(torch.unsqueeze(output.t(), dim=0), dim=-1)
    return output