Code example #1 (score: 0)
File: test_graph_conv.py — Project: yamCG/pytorch3d
    def test_gather_scatter(self):
        """
        Check gather_scatter cuda and python versions give the same results.
        Check that gather_scatter cuda version throws an error if cpu tensors
        are given as input.
        """
        device = torch.device("cuda:0")
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
        w0 = nn.Linear(3, 1)
        input = w0(verts)

        # undirected: python reference vs CUDA kernel must agree
        output_python = gather_scatter_python(input, edges, False)
        output_cuda = _C.gather_scatter(input.to(device=device),
                                        edges.to(device=device), False, False)
        self.assertClose(output_cuda.cpu(), output_python)
        # the CUDA op must reject CPU tensors with an explicit error
        with self.assertRaises(Exception) as err:
            _C.gather_scatter(input.cpu(), edges.cpu(), False, False)
        self.assertTrue("Not implemented on the CPU" in str(err.exception))

        # directed: same comparison with directed=True
        output_python = gather_scatter_python(input, edges, True)
        output_cuda = _C.gather_scatter(input.to(device=device),
                                        edges.to(device=device), True, False)
        self.assertClose(output_cuda.cpu(), output_python)
Code example #2 (score: 0)
File: test_graph_conv.py — Project: r23/pytorch3d
    def test_gather_scatter(self):
        """
        Check gather_scatter cuda and python versions give the same results.
        Check that gather_scatter cuda version throws an error if cpu tensors
        are given as input.
        """
        device = get_random_cuda_device()
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
        linear = nn.Linear(3, 1)
        feats = linear(verts)

        # Compare both the CUDA and CPU C++ kernels against the python
        # reference, for undirected and directed edges alike.
        for directed in (False, True):
            expected = gather_scatter_python(feats, edges, directed)

            cuda_out = _C.gather_scatter(
                feats.to(device=device), edges.to(device=device), directed, False
            )
            self.assertClose(cuda_out.cpu(), expected)

            cpu_out = _C.gather_scatter(feats.cpu(), edges.cpu(), directed, False)
            self.assertClose(cpu_out, expected)
Code example #3 (score: 0)
File: graph_conv.py — Project: yamCG/pytorch3d
    def forward(ctx, input, edges, directed=False):
        """
        Run the gather_scatter C++ op, saving what backward needs.

        Args:
            ctx: Context object used to calculate gradients.
            input: Tensor of shape (num_vertices, input_dim)
            edges: Tensor of edge indices of shape (num_edges, 2)
            directed: Bool indicating if edges are directed.

        Returns:
            output: Tensor of same shape as input.
        """
        # Validate shapes/dtype up front so the C++ kernel only ever
        # sees well-formed tensors.
        if input.dim() != 2:
            raise ValueError("input can only have 2 dimensions.")
        if edges.dim() != 2:
            raise ValueError("edges can only have 2 dimensions.")
        if edges.shape[1] != 2:
            raise ValueError("edges must be of shape (num_edges, 2).")
        if input.dtype != torch.float32:
            raise ValueError("input has to be of type torch.float32.")

        ctx.directed = directed
        input = input.contiguous()
        edges = edges.contiguous()
        ctx.save_for_backward(edges)
        # The final False selects the forward (not backward) kernel path.
        return _C.gather_scatter(input, edges, directed, False)
Code example #4 (score: 0)
File: graph_conv.py — Project: yamCG/pytorch3d
 def backward(ctx, grad_output):
     """
     Backward pass of gather_scatter: propagate grad_output through the
     same C++ op with the backward flag set.

     Returns a gradient for each forward input (input, edges, directed);
     edges and directed are not differentiable, so those slots are None.
     """
     (edges,) = ctx.saved_tensors
     # The trailing True selects the backward kernel path in the C++ op.
     grad_input = _C.gather_scatter(
         grad_output.contiguous(), edges, ctx.directed, True
     )
     return grad_input, None, None