def test_backward(self):
    """Check gather_scatter gradients on CUDA and CPU against the Python reference.

    The same ico-sphere vertices/edges are fed to the CUDA kernel, the CPU
    kernel and ``gather_scatter_python`` (all with ``directed=False``). Each
    output is weighted by one shared random tensor before summing and
    backpropagating, and the resulting vertex gradients are compared.
    """
    device = get_random_cuda_device()
    sphere = ico_sphere()
    ref_verts = sphere.verts_packed()
    ref_edges = sphere.edges_packed()

    # Independent copies for each backend; cloned before requires_grad is
    # set so every copy is its own autograd leaf with its own .grad.
    cpu_verts, cpu_edges = ref_verts.clone(), ref_edges.clone()
    cuda_verts = ref_verts.clone().to(device)
    cuda_edges = ref_edges.clone().to(device)
    for leaf in (ref_verts, cpu_verts, cuda_verts):
        leaf.requires_grad = True

    out_cuda = gather_scatter(cuda_verts, cuda_edges, False)
    out_cpu = gather_scatter(cpu_verts, cpu_edges, False)
    out_ref = gather_scatter_python(ref_verts, ref_edges, False)

    # A shared random upstream gradient (rather than all ones) so the
    # comparison is sensitive to which entries each gradient is routed to.
    weights = torch.rand_like(out_ref)
    weighted_outputs = (
        (out_cuda, weights.to(device)),
        (out_cpu, weights),
        (out_ref, weights),
    )
    for output, weight in weighted_outputs:
        (output * weight).sum().backward()

    self.assertClose(ref_verts.grad, cuda_verts.grad.cpu())
    self.assertClose(ref_verts.grad, cpu_verts.grad)
def test_backward(self):
    """Check gather_scatter CUDA gradients against the Python reference.

    Runs ``gather_scatter`` on CUDA and ``gather_scatter_python`` on CPU
    (both with ``directed=False``) over an ico-sphere. It backpropagates a
    shared random upstream gradient through both and asserts the vertex
    gradients match.
    """
    device = torch.device("cuda:0")
    mesh = ico_sphere()
    verts = mesh.verts_packed()
    edges = mesh.edges_packed()

    # Clone before enabling grad so each copy is an independent leaf.
    verts_cuda = verts.clone().to(device)
    edges_cuda = edges.clone().to(device)
    verts.requires_grad = True
    verts_cuda.requires_grad = True

    neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False)
    neighbor_sums = gather_scatter_python(verts, edges, False)

    # Use a random upstream gradient instead of a plain .sum() (all ones).
    # With all ones, the gradient of a neighbor sum at each vertex depends
    # only on its degree, so a backward pass that reads the upstream gradient
    # from the wrong edge endpoint would still pass.
    randoms = torch.rand_like(neighbor_sums)
    (neighbor_sums_cuda * randoms.to(device)).sum().backward()
    (neighbor_sums * randoms).sum().backward()

    # verts lives on CPU already; only the CUDA grad needs moving.
    self.assertClose(verts.grad, verts_cuda.grad.cpu())