def test_gather_scatter(self):
    """
    Check that the cuda, cpu-extension and python versions of
    gather_scatter all give the same results, for both undirected
    and directed edges.
    """
    device = get_random_cuda_device()
    mesh = ico_sphere()
    verts = mesh.verts_packed()
    edges = mesh.edges_packed()
    w0 = nn.Linear(3, 1)
    # Named `feats` rather than `input` to avoid shadowing the builtin.
    feats = w0(verts)

    # The undirected and directed cases are identical apart from the flag,
    # so run both through one loop instead of copy-pasted branches.
    for directed in (False, True):
        output_python = gather_scatter_python(feats, edges, directed)
        output_cuda = _C.gather_scatter(
            feats.to(device=device), edges.to(device=device), directed, False
        )
        self.assertClose(output_cuda.cpu(), output_python)
        output_cpu = _C.gather_scatter(feats.cpu(), edges.cpu(), directed, False)
        self.assertClose(output_cpu, output_python)
def test_gather_scatter(self):
    """
    Check gather_scatter cuda and python versions give the same results.
    Check that gather_scatter cuda version throws an error if cpu tensors
    are given as input.
    """
    device = torch.device("cuda:0")
    mesh = ico_sphere()
    verts = mesh.verts_packed()
    edges = mesh.edges_packed()
    w0 = nn.Linear(3, 1)
    # Named `feats` rather than `input` to avoid shadowing the builtin.
    feats = w0(verts)

    # Undirected: cuda result must match the python reference.
    output_python = gather_scatter_python(feats, edges, False)
    output_cuda = _C.gather_scatter(
        feats.to(device=device), edges.to(device=device), False, False
    )
    self.assertClose(output_cuda.cpu(), output_python)

    # CPU tensors are not supported by the extension and must raise.
    with self.assertRaises(Exception) as err:
        _C.gather_scatter(feats.cpu(), edges.cpu(), False, False)
    # assertIn gives a readable diff on failure, unlike assertTrue(... in ...).
    self.assertIn("Not implemented on the CPU", str(err.exception))

    # Directed: cuda result must match the python reference.
    output_python = gather_scatter_python(feats, edges, True)
    output_cuda = _C.gather_scatter(
        feats.to(device=device), edges.to(device=device), True, False
    )
    self.assertClose(output_cuda.cpu(), output_python)
def test_backward(self):
    """
    Check that gradients flowing back through the cuda gather_scatter
    match those of the python reference implementation.
    """
    device = torch.device("cuda:0")
    sphere = ico_sphere()
    verts_ref = sphere.verts_packed()
    edge_index = sphere.edges_packed()

    # Independent leaf tensors for the two implementations.
    verts_gpu = verts_ref.clone().to(device)
    edges_gpu = edge_index.clone().to(device)
    verts_ref.requires_grad = True
    verts_gpu.requires_grad = True

    # Forward + backward on each path; a plain sum gives a gradient of
    # ones through the output.
    gather_scatter(verts_gpu, edges_gpu, False).sum().backward()
    gather_scatter_python(verts_ref, edge_index, False).sum().backward()

    self.assertClose(verts_ref.grad.cpu(), verts_gpu.grad.cpu())
def test_backward(self):
    """
    Check that gradients agree across the cuda, cpu-extension and
    python versions of gather_scatter.
    """
    device = get_random_cuda_device()
    sphere = ico_sphere()
    verts = sphere.verts_packed()
    edges = sphere.edges_packed()

    verts_cpu, edges_cpu = verts.clone(), edges.clone()
    verts_gpu, edges_gpu = verts.clone().to(device), edges.clone().to(device)
    for leaf in (verts, verts_cpu, verts_gpu):
        leaf.requires_grad = True

    out_gpu = gather_scatter(verts_gpu, edges_gpu, False)
    out_cpu = gather_scatter(verts_cpu, edges_cpu, False)
    out_ref = gather_scatter_python(verts, edges, False)

    # Weight each output element randomly before summing so the gradient
    # comparison is sensitive to per-element errors, not just the total.
    weights = torch.rand_like(out_ref)
    (out_gpu * weights.to(device)).sum().backward()
    (out_cpu * weights).sum().backward()
    (out_ref * weights).sum().backward()

    self.assertClose(verts.grad, verts_gpu.grad.cpu())
    self.assertClose(verts.grad, verts_cpu.grad)