def test_gpu(self):
    """Check that the CUDA and CPU paths of three_nn/three_interpolate agree.

    Random target points ``pos_skip`` (16 x 500 x 3) are interpolated from
    random support points ``pos`` (16 x 100 x 3) carrying features ``x``
    (16 x 30 x 100). This is done once with GPU inputs and once with the same
    inputs moved to the CPU. The two results must match within the default
    float32 tolerances of ``torch.testing.assert_close``.

    The test is skipped when no CUDA device is available.
    """
    import unittest

    if not torch.cuda.is_available():
        # SkipTest is honoured by both unittest and pytest runners.
        raise unittest.SkipTest("CUDA is not available")

    def interpolate(support_pos, target_pos, feats):
        # Shared pipeline for both devices. Weights are the reciprocal of the
        # distances reported by three_nn, normalised to sum to 1 along dim 2.
        # The 1e-8 keeps the reciprocal finite when a distance is zero.
        dist, idx = three_nn(target_pos, support_pos)
        dist_recip = 1.0 / (dist + 1e-8)
        weight = dist_recip / torch.sum(dist_recip, dim=2, keepdim=True)
        return three_interpolate(feats, idx, weight)

    pos = torch.randn([16, 100, 3]).cuda()
    pos_skip = torch.randn([16, 500, 3]).cuda()
    # requires_grad_() after the device move keeps x a leaf tensor.
    # The original requires_grad=True followed by .cuda() produced a non-leaf.
    x = torch.randn([16, 30, 100]).cuda().requires_grad_()

    gpu_feats = interpolate(pos, pos_skip, x)
    cpu_feats = interpolate(pos.cpu(), pos_skip.cpu(), x.cpu())

    # assert_allclose is deprecated in favour of assert_close, which also
    # checks dtype and device, so both sides are compared on the CPU.
    torch.testing.assert_close(cpu_feats.detach(), gpu_feats.detach().cpu())
def conv(self, pos, pos_skip, x):
    """Propagate features ``x`` from the points ``pos`` onto ``pos_skip``.

    When ``pos`` is given, each target point receives a blend of features
    selected by ``tp.three_nn``. The blend weights are the reciprocal of the
    returned distances, normalised along dim 2, and the result is produced by
    ``tp.three_interpolate``.

    When ``pos`` is None, ``x`` is broadcast (as a view) along a last
    dimension of size ``pos_skip.size(1)``. NOTE(review): this presumably
    expects ``x`` to be (B, C, 1) global features; confirm with callers.

    Args:
        pos: support point coordinates, or None.
        pos_skip: target point coordinates; dimension 2 must have size 3.
        x: features associated with ``pos``.

    Returns:
        The interpolated (or expanded) feature tensor.

    Raises:
        AssertionError: if ``pos_skip.shape[2] != 3``.
    """
    assert pos_skip.shape[2] == 3

    if pos is None:
        # No support points: tile the features over every target point.
        target_size = (*x.size()[0:2], pos_skip.size(1))
        return x.expand(*target_size)

    dist, idx = tp.three_nn(pos_skip, pos)
    # The epsilon keeps the reciprocal finite for zero distances.
    inv_dist = torch.reciprocal(dist + 1e-8)
    weight = inv_dist / inv_dist.sum(dim=2, keepdim=True)
    return tp.three_interpolate(x, idx, weight)