Code example #1: a unit test checking that MultiScaleData.apply is also applied to the nested Data objects stored in multiscale.
    def test_apply(self):
        x = torch.tensor([1])
        pos = torch.tensor([1])
        d1 = Data(x=2 * x, pos=2 * pos)
        d2 = Data(x=3 * x, pos=3 * pos)
        data = MultiScaleData(x=x, pos=pos, multiscale=[d1, d2])
        data.apply(lambda x: 2 * x)
        # apply must double the top-level tensors and every nested scale
        self.assertEqual(data.x[0], 2)
        self.assertEqual(data.pos[0], 2)
        self.assertEqual(data.multiscale[0].pos[0], 4)
        self.assertEqual(data.multiscale[0].x[0], 4)
        self.assertEqual(data.multiscale[1].pos[0], 6)
        self.assertEqual(data.multiscale[1].x[0], 6)
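The contract this test pins down can be summarised with a small sketch. This is an assumption about the expected behaviour, not the library's implementation: apply has to transform every tensor attribute and forward the call into each Data stored in the multiscale list.

    import torch
    from torch_geometric.data import Data

    def apply_sketch(data, func):
        # sketch of the assumed contract, not the actual MultiScaleData.apply:
        # transform tensor attributes in place and recurse into nested Data objects
        for key, item in data:
            if torch.is_tensor(item):
                data[key] = func(item)
            elif isinstance(item, Data):
                apply_sketch(item, func)
            elif isinstance(item, list):
                for sub in item:
                    if isinstance(sub, Data):
                        apply_sketch(sub, func)
        return data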
Code example #2: a collate-time helper that precomputes, for every scale, the neighbour indices, a subsampled point cloud and the corresponding upsampling indices, and returns them wrapped in a MultiScaleData.
    def _multiscale_compute_fn(self,
                               batch,
                               collate_fn=None,
                               precompute_multi_scale=False,
                               num_scales=0,
                               sample_method='random'):
        """Collate the batch and, if requested, precompute per-scale neighbour,
        subsampling and upsampling indices for a multi-scale network."""
        batch = collate_fn(batch)
        if not precompute_multi_scale:
            return batch
        multiscale = []
        pos = batch.pos     # [B, N, 3]
        for i in range(num_scales):
            neighbor_idx = self._knn_search(pos, pos, self.kernel_size[i])      # [B, N, K]
            sample_num = pos.shape[1] // self.ratio[i]
            if sample_method.lower() == 'random':
                choice = torch.randperm(pos.shape[1])[:sample_num]
                sub_pos = pos[:, choice, :]             # random sampled pos   [B, S, 3]
                sub_idx = neighbor_idx[:, choice, :]    # the pool idx  [B, S, K]
            elif sample_method.lower() == 'fps':
                choice = tpcuda.furthest_point_sampling(pos.cuda(), sample_num).to(torch.long).cpu()
                sub_pos = pos.gather(dim=1, index=choice.unsqueeze(-1).repeat(1, 1, pos.shape[-1]))
                sub_idx = neighbor_idx.gather(dim=1, index=choice.unsqueeze(-1).repeat(1, 1, neighbor_idx.shape[-1]))
            else:
                raise NotImplementedError('Only `random` and `fps` sampling methods are implemented!')

            up_idx = self._knn_search(sub_pos, pos, 1)      # nearest subsampled point for every original point  [B, N, 1]
            multiscale.append(Data(pos=pos, neighbor_idx=neighbor_idx, sub_idx=sub_idx, up_idx=up_idx))
            pos = sub_pos   # the next scale operates on the subsampled cloud

        return MultiScaleData(x=batch.x,
                              y=batch.y,
                              point_idx=batch.point_idx,
                              cloud_idx=batch.cloud_idx,
                              multiscale=multiscale)
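The helper above relies on a _knn_search method that is not shown. Judging from the call sites and the shape comments, it takes (support, query, k) and returns, for every query point, the indices of its k nearest support points, with shape [B, N, k]. A minimal brute-force sketch of that assumed interface, written as a free function and using torch.cdist rather than whatever kernel the original code uses:

    import torch

    def _knn_search(support, query, k):
        # brute-force stand-in for the original helper (signature assumed from the call sites)
        # support: [B, S, 3], query: [B, N, 3] -> [B, N, k] indices into support
        dist = torch.cdist(query, support)                    # [B, N, S] pairwise distances
        return dist.topk(k, dim=-1, largest=False).indices    # k smallest distances per query point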
Code example #3: a transform whose __call__ sequentially precomputes, on the CPU, the sampling, neighbour and upsampling information for each layer and stores it in a MultiScaleData.
    def __call__(self, data: Data) -> MultiScaleData:
        # Sequentially compute the multi-scale indices on the CPU
        data.contiguous()
        ms_data = MultiScaleData.from_data(data)
        precomputed = [Data(pos=data.pos)]
        upsample = []
        upsample_index = 0
        for index in range(self.num_layers):
            sampler = self.strategies["sampler"][index]
            neighbour_finder = self.strategies["neighbour_finder"][index]
            support = precomputed[index]
            new_data = Data(pos=support.pos)
            if sampler:
                query = sampler(new_data.clone())
                query.contiguous()

                if len(self.strategies["upsample_op"]):
                    if upsample_index >= len(self.strategies["upsample_op"]):
                        raise ValueError(
                            "You are missing some upsample blocks in your network"
                        )

                    upsampler = self.strategies["upsample_op"][upsample_index]
                    upsample_index += 1
                    pre_up = upsampler.precompute(query, support)
                    upsample.append(pre_up)
                    # when batching, these upsampling indices must be offset by
                    # the number of query / support nodes respectively
                    special_params = {}
                    special_params["x_idx"] = query.num_nodes
                    special_params["y_idx"] = support.num_nodes
                    setattr(
                        pre_up, "__inc__",
                        self.__inc__wrapper(pre_up.__inc__, special_params))
            else:
                query = new_data

            s_pos, q_pos = support.pos, query.pos
            if hasattr(query, "batch"):
                s_batch, q_batch = support.batch, query.batch
            else:
                s_batch, q_batch = (
                    torch.zeros((s_pos.shape[0]), dtype=torch.long),
                    torch.zeros((q_pos.shape[0]), dtype=torch.long),
                )

            idx_neighboors = neighbour_finder(s_pos,
                                              q_pos,
                                              batch_x=s_batch,
                                              batch_y=q_batch)
            # when batching, neighbour indices are offset by the number of support points
            special_params = {}
            special_params["idx_neighboors"] = s_pos.shape[0]
            setattr(query, "idx_neighboors", idx_neighboors)
            setattr(query, "__inc__",
                    self.__inc__wrapper(query.__inc__, special_params))
            precomputed.append(query)
        ms_data.multiscale = precomputed[1:]
        upsample.reverse()  # Switch to inner layer first
        ms_data.upsample = upsample
        return ms_data
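The transform also assumes a __inc__wrapper helper that is not shown. A sketch of the assumed contract, not the library code: it wraps Data.__inc__ so that, when samples are collated into a batch, the keys listed in special_params are incremented by the stored sizes (for example idx_neighboors is shifted by the number of support points) instead of by the default increment.

    def __inc__wrapper(self, inc_func, special_params):
        # sketch only: assumed to live on the same transform class as __call__ above
        def new_inc(key, *args, **kwargs):
            # batching increments for the listed keys come from special_params
            if key in special_params:
                return special_params[key]
            return inc_func(key, *args, **kwargs)
        return new_inc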
Code example #4: a unit test checking that MultiScaleBatch.from_data_list batches the top-level attributes like a regular batch and additionally batches each scale level across samples.
    def test_batch(self):
        x = torch.tensor([1])
        pos = x
        d1 = Data(x=x, pos=pos)
        d2 = Data(x=4 * x, pos=4 * pos)
        data1 = MultiScaleData(x=x, pos=pos, multiscale=[d1, d2])

        x = torch.tensor([2])
        pos = x
        d1 = Data(x=x, pos=pos)
        d2 = Data(x=4 * x, pos=4 * pos)
        data2 = MultiScaleData(x=x, pos=pos, multiscale=[d1, d2])

        batch = MultiScaleBatch.from_data_list([data1, data2])
        tt.assert_allclose(batch.x, torch.tensor([1, 2]))
        tt.assert_allclose(batch.batch, torch.tensor([0, 1]))

        # each scale level is itself a batch over the two input samples
        ms_batches = batch.multiscale
        tt.assert_allclose(ms_batches[0].batch, torch.tensor([0, 1]))
        tt.assert_allclose(ms_batches[1].batch, torch.tensor([0, 1]))
        tt.assert_allclose(ms_batches[1].x, torch.tensor([4, 8]))
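For reference, a minimal sketch (not the library implementation) of the batching behaviour this test relies on: the top-level attributes are collated like an ordinary torch_geometric Batch, while scale i of every sample is collated into its own Batch, which is why each entry of batch.multiscale carries its own batch vector.

    from torch_geometric.data import Batch, Data

    def from_data_list_sketch(data_list):
        # sketch only, not MultiScaleBatch.from_data_list
        num_scales = len(data_list[0].multiscale)
        # collate scale i of every sample together so each scale keeps its own batch vector
        multiscale = [
            Batch.from_data_list([d.multiscale[i] for d in data_list])
            for i in range(num_scales)
        ]
        # collate the top-level attributes (only x and pos in this sketch) as usual
        batch = Batch.from_data_list([Data(x=d.x, pos=d.pos) for d in data_list])
        batch.multiscale = multiscale
        return batch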