def test_apply(self):
    """``apply`` must transform the top-level tensors and the tensors of every scale."""
    base_x = torch.tensor([1])
    base_pos = torch.tensor([1])
    # Scale k holds k times the base values (k = 2, 3).
    levels = [Data(x=factor * base_x, pos=factor * base_pos) for factor in (2, 3)]
    data = MultiScaleData(x=base_x, pos=base_pos, multiscale=levels)

    data.apply(lambda tensor: 2 * tensor)

    self.assertEqual(data.x[0], 2)
    self.assertEqual(data.pos[0], 2)
    # After doubling, scale 0 (factor 2) holds 4 and scale 1 (factor 3) holds 6.
    for index, expected in enumerate((4, 6)):
        level = data.multiscale[index]
        self.assertEqual(level.pos[0], expected)
        self.assertEqual(level.x[0], expected)
def _multiscale_compute_fn(self, batch, collate_fn=None, precompute_multi_scale=False, num_scales=0, sample_method='random'):
    """Collate ``batch`` and optionally precompute multi-scale neighbour / pooling indices.

    Args:
        batch: the samples to merge with ``collate_fn``.
        collate_fn: callable that merges the samples. When ``None`` (the default)
            ``batch`` is used unchanged instead of raising ``TypeError``.
        precompute_multi_scale: when False, the collated batch is returned as-is.
        num_scales: number of scales to build; scale ``i`` reads
            ``self.kernel_size[i]`` and ``self.ratio[i]``.
        sample_method: ``'random'`` or ``'fps'`` (case-insensitive).

    Returns:
        The collated batch, or a ``MultiScaleData`` carrying the batch's ``x``, ``y``,
        ``point_idx``, ``cloud_idx`` and one ``Data(pos, neighbor_idx, sub_idx, up_idx)``
        per scale.

    Raises:
        NotImplementedError: if ``sample_method`` is unknown and ``num_scales > 0``.
    """
    if collate_fn is not None:
        batch = collate_fn(batch)
    if not precompute_multi_scale:
        return batch

    method = sample_method.lower()  # loop-invariant; hoisted out of the scale loop
    multiscale = []
    pos = batch.pos  # [B, N, 3] per the original annotation
    for i in range(num_scales):
        neighbor_idx = self._knn_search(pos, pos, self.kernel_size[i])  # [B, N, K]
        num_points = pos.shape[1]
        sample_num = num_points // self.ratio[i]
        if method == 'random':
            # One permutation shared by every cloud in the batch.
            choice = torch.randperm(num_points)[:sample_num]
            sub_pos = pos[:, choice, :]  # random sampled pos [B, S, 3]
            sub_idx = neighbor_idx[:, choice, :]  # the pool idx [B, S, K]
        elif method == 'fps':
            # Furthest point sampling runs on GPU; indices are brought back to CPU.
            choice = tpcuda.furthest_point_sampling(pos.cuda(), sample_num).to(torch.long).cpu()
            sub_pos = pos.gather(dim=1, index=choice.unsqueeze(-1).repeat(1, 1, pos.shape[-1]))
            sub_idx = neighbor_idx.gather(dim=1, index=choice.unsqueeze(-1).repeat(1, 1, neighbor_idx.shape[-1]))
        else:
            raise NotImplementedError('Only `random` or `fps` sampling method is implemented!')
        # For every point of this scale, its nearest sampled point (used to upsample).
        up_idx = self._knn_search(sub_pos, pos, 1)  # [B, N, 1]
        multiscale.append(Data(pos=pos, neighbor_idx=neighbor_idx, sub_idx=sub_idx, up_idx=up_idx))
        pos = sub_pos  # the next scale operates on the subsampled points
    return MultiScaleData(x=batch.x, y=batch.y, point_idx=batch.point_idx, cloud_idx=batch.cloud_idx, multiscale=multiscale)
def __call__(self, data: Data) -> MultiScaleData:
    """Precompute, layer by layer, the sampling / neighbour / upsample data for ``data``.

    Layer ``index`` uses ``self.strategies["sampler"][index]`` and
    ``self.strategies["neighbour_finder"][index]``. Its support is the previous
    layer's query; layer 0's support holds only ``data.pos``. When a sampler is
    configured, the query is the sampler's output on a copy of the support. If
    any ``upsample_op`` is configured, the next one is also precomputed between
    query and support. Otherwise the query has the same positions as the support.

    Args:
        data: input sample; ``data.pos`` seeds the first layer.

    Returns:
        ``MultiScaleData`` built from ``data``. ``multiscale`` holds one query per
        layer, and ``upsample`` holds the precomputed upsample data with the
        innermost layer first.

    Raises:
        ValueError: if upsample ops are configured but fewer of them exist than
            layers with a sampler.
    """
    # Compute sequentially multi_scale indexes on cpu
    data.contiguous()
    ms_data = MultiScaleData.from_data(data)
    # Seed entry: layer 0's support. It is dropped from the output (see precomputed[1:]).
    precomputed = [Data(pos=data.pos)]
    upsample = []
    upsample_index = 0  # next unused entry of self.strategies["upsample_op"]
    for index in range(self.num_layers):
        sampler, neighbour_finder = self.strategies["sampler"][
            index], self.strategies["neighbour_finder"][index]
        support = precomputed[index]  # the previous layer's query
        new_data = Data(pos=support.pos)
        if sampler:
            # The sampler gets a clone so the support's positions are not touched.
            query = sampler(new_data.clone())
            query.contiguous()

            if len(self.strategies["upsample_op"]):
                if upsample_index >= len(self.strategies["upsample_op"]):
                    raise ValueError(
                        "You are missing some upsample blocks in your network"
                    )

                upsampler = self.strategies["upsample_op"][upsample_index]
                upsample_index += 1
                pre_up = upsampler.precompute(query, support)
                upsample.append(pre_up)
                # Node counts handed to the __inc__ wrapper. Presumably these offset the
                # x/y index attributes of pre_up when samples are collated; confirm
                # against __inc__wrapper.
                special_params = {}
                special_params["x_idx"] = query.num_nodes
                special_params["y_idx"] = support.num_nodes
                setattr(
                    pre_up, "__inc__", self.__inc__wrapper(pre_up.__inc__, special_params))
        else:
            # No sampler: this layer keeps the support's positions.
            query = new_data

        s_pos, q_pos = support.pos, query.pos
        if hasattr(query, "batch"):
            s_batch, q_batch = support.batch, query.batch
        else:
            # No batch vector: every point is assigned to batch 0.
            s_batch, q_batch = (
                torch.zeros((s_pos.shape[0]), dtype=torch.long),
                torch.zeros((q_pos.shape[0]), dtype=torch.long),
            )

        idx_neighboors = neighbour_finder(s_pos, q_pos, batch_x=s_batch, batch_y=q_batch)
        # Size of the support (s_pos.shape[0]) for the idx_neighboors increment.
        # Presumably the neighbour indices point into the support and must be
        # offset by it on collation; confirm against __inc__wrapper.
        special_params = {}
        special_params["idx_neighboors"] = s_pos.shape[0]
        setattr(query, "idx_neighboors", idx_neighboors)
        setattr(query, "__inc__", self.__inc__wrapper(query.__inc__, special_params))
        precomputed.append(query)
    ms_data.multiscale = precomputed[1:]
    upsample.reverse()  # Switch to inner layer first
    ms_data.upsample = upsample
    return ms_data
def test_batch(self):
    """Collating two multi-scale samples must batch the top level and every scale."""

    def make_sample(value):
        # x and pos share one tensor. The second scale holds 4x the base value.
        tensor = torch.tensor([value])
        scales = [
            Data(x=tensor, pos=tensor),
            Data(x=4 * tensor, pos=4 * tensor),
        ]
        return MultiScaleData(x=tensor, pos=tensor, multiscale=scales)

    batch = MultiScaleBatch.from_data_list([make_sample(1), make_sample(2)])

    tt.assert_allclose(batch.x, torch.tensor([1, 2]))
    tt.assert_allclose(batch.batch, torch.tensor([0, 1]))

    scale_batches = batch.multiscale
    # Each scale gets its own batch vector: one point per sample.
    for level in (0, 1):
        tt.assert_allclose(scale_batches[level].batch, torch.tensor([0, 1]))
    tt.assert_allclose(scale_batches[1].x, torch.tensor([4, 8]))