def test_gpu(self):
        """Check that CUDA and CPU interpolation produce matching features.

        Random inputs are created on the GPU: positions of shape
        (16, 100, 3), skip positions of shape (16, 500, 3) and features of
        shape (16, 30, 100).  The same pipeline is then run on the GPU
        tensors and on their CPU copies, and the two results are compared
        with ``torch.testing.assert_allclose``.
        """

        def _interpolate(src_pos, skip_pos, feats):
            # NOTE(review): three_nn presumably returns distances/indices of
            # the three nearest `src_pos` points for each `skip_pos` point --
            # confirm against its definition.
            nn_dist, nn_idx = three_nn(skip_pos, src_pos)
            # The small epsilon keeps the reciprocal finite for zero distances.
            inv_dist = 1.0 / (nn_dist + 1e-8)
            # Normalise so the weights sum to one along dim 2.
            total = torch.sum(inv_dist, dim=2, keepdim=True)
            weights = inv_dist / total
            return three_interpolate(feats, nn_idx, weights)

        src_gpu = torch.randn([16, 100, 3]).cuda()
        skip_gpu = torch.randn([16, 500, 3]).cuda()
        feats_gpu = torch.randn([16, 30, 100], requires_grad=True).cuda()

        # Device result first, then the reference result on CPU copies of
        # exactly the same data.
        out_gpu = _interpolate(src_gpu, skip_gpu, feats_gpu)
        out_cpu = _interpolate(src_gpu.cpu(), skip_gpu.cpu(), feats_gpu.cpu())

        torch.testing.assert_allclose(out_cpu, out_gpu.cpu())
예제 #2
0
    def conv(self, pos, pos_skip, x):
        """Propagate features ``x`` onto the points in ``pos_skip``.

        Args:
            pos: Source point positions, or ``None``.  When ``None``, ``x`` is
                simply expanded (broadcast) to ``pos_skip.size(1)`` entries
                along its third dimension.
            pos_skip: Target point positions; dimension 2 must have size 3.
            x: Features to propagate.

        Returns:
            The interpolated features when ``pos`` is given, otherwise the
            expanded view of ``x``.

        Raises:
            AssertionError: If ``pos_skip.shape[2]`` is not 3.
        """
        assert pos_skip.shape[2] == 3

        if pos is None:
            # No source positions: keep the first two dims of x and expand
            # to one entry per point in pos_skip.
            target_shape = x.size()[0:2] + (pos_skip.size(1), )
            return x.expand(*target_shape)

        # NOTE(review): tp.three_nn presumably yields the three nearest
        # `pos` neighbours of each `pos_skip` point -- confirm in `tp`.
        nn_dist, nn_idx = tp.three_nn(pos_skip, pos)
        # The epsilon guards against division by zero for coincident points.
        inv_dist = 1.0 / (nn_dist + 1e-8)
        # Inverse-distance weights normalised to sum to one along dim 2.
        weights = inv_dist / inv_dist.sum(dim=2, keepdim=True)
        return tp.three_interpolate(x, nn_idx, weights)