def test_gather_scatter(self):
    """
    Compare the CUDA gather_scatter op with the Python implementation.

    The op is run on "cuda:0" for undirected and directed edges and each
    result is compared against gather_scatter_python. In between, calling
    the op with CPU tensors must raise an exception whose message contains
    "Not implemented on the CPU".
    """
    cuda = torch.device("cuda:0")
    sphere = ico_sphere()
    vertex_pos = sphere.verts_packed()
    edge_index = sphere.edges_packed()
    projection = nn.Linear(3, 1)
    feats = projection(vertex_pos)

    def run_on_cuda(directed):
        # Moves both tensors to the CUDA device on every call; the trailing
        # False is the op's `backward` flag.
        return _C.gather_scatter(
            feats.to(device=cuda), edge_index.to(device=cuda), directed, False
        )

    # Undirected edges.
    expected = gather_scatter_python(feats, edge_index, False)
    self.assertClose(run_on_cuda(False).cpu(), expected)

    # CPU tensors are expected to be rejected by the op.
    with self.assertRaises(Exception) as err:
        _C.gather_scatter(feats.cpu(), edge_index.cpu(), False, False)
    self.assertTrue("Not implemented on the CPU" in str(err.exception))

    # Directed edges.
    expected = gather_scatter_python(feats, edge_index, True)
    self.assertClose(run_on_cuda(True).cpu(), expected)
def test_gather_scatter(self):
    """
    Compare the CUDA and CPU gather_scatter ops with the Python implementation.

    For undirected and then directed edges, the op is run once on a CUDA
    device and once on CPU tensors; both results must be close to the
    output of gather_scatter_python.
    """
    cuda = get_random_cuda_device()
    sphere = ico_sphere()
    vertex_pos = sphere.verts_packed()
    edge_index = sphere.edges_packed()
    projection = nn.Linear(3, 1)
    feats = projection(vertex_pos)

    # First pass: undirected edges; second pass: directed edges.
    for directed in (False, True):
        expected = gather_scatter_python(feats, edge_index, directed)

        # The trailing False is the op's `backward` flag.
        from_cuda = _C.gather_scatter(
            feats.to(device=cuda), edge_index.to(device=cuda), directed, False
        )
        self.assertClose(from_cuda.cpu(), expected)

        from_cpu = _C.gather_scatter(
            feats.cpu(), edge_index.cpu(), directed, False
        )
        self.assertClose(from_cpu, expected)
def forward(ctx, input, edges, directed=False):
    """
    Validate the arguments and run the forward gather_scatter op.

    Args:
        ctx: Context object used to calculate gradients. `directed` and
            the contiguous `edges` tensor are stored on it for backward.
        input: Tensor of shape (num_vertices, input_dim); must be
            torch.float32.
        edges: Tensor of edge indices of shape (num_edges, 2).
        directed: Bool indicating if edges are directed.

    Returns:
        output: Tensor of same shape as input.

    Raises:
        ValueError: If `input` or `edges` is not 2-dimensional, if `edges`
            does not have exactly 2 columns, or if `input` is not
            torch.float32.
    """
    # Guard clauses, checked in this order so the first failing check wins.
    if input.dim() != 2:
        raise ValueError("input can only have 2 dimensions.")
    if edges.dim() != 2:
        raise ValueError("edges can only have 2 dimensions.")
    if edges.shape[1] != 2:
        raise ValueError("edges must be of shape (num_edges, 2).")
    if input.dtype != torch.float32:
        raise ValueError("input has to be of type torch.float32.")

    ctx.directed = directed
    contiguous_edges = edges.contiguous()
    ctx.save_for_backward(contiguous_edges)

    # The final argument is the op's `backward` flag; False for the forward pass.
    return _C.gather_scatter(input.contiguous(), contiguous_edges, directed, False)
def backward(ctx, grad_output):
    """
    Run the gather_scatter op in backward mode on the incoming gradient.

    Args:
        ctx: Context object; the first saved tensor is read as the edges
            and `ctx.directed` as the directed flag.
        grad_output: Gradient tensor passed to the op; made contiguous
            before the call.

    Returns:
        Tuple (grad_input, grad_edges, grad_directed) where grad_edges and
        grad_directed are always None.
    """
    saved_edges = ctx.saved_tensors[0]
    # The final argument is the op's `backward` flag; True for the backward pass.
    grad_input = _C.gather_scatter(
        grad_output.contiguous(), saved_edges, ctx.directed, True
    )
    return grad_input, None, None