Ejemplo n.º 1
0
    def test_gather_scatter(self):
        """
        Check that the CUDA and CPU kernels of _C.gather_scatter give the same
        results as the python reference gather_scatter_python, for both
        undirected (directed=False) and directed (directed=True) edges.
        """
        device = get_random_cuda_device()
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
        w0 = nn.Linear(3, 1)
        feats = w0(verts)

        # Move the inputs once; they are reused for both edge modes.
        feats_cuda = feats.to(device=device)
        edges_cuda = edges.to(device=device)
        feats_cpu = feats.cpu()
        edges_cpu = edges.cpu()

        for directed in (False, True):
            output_python = gather_scatter_python(feats, edges, directed)

            # NOTE(review): the trailing False is presumably the op's
            # "backward" flag (forward pass here) — confirm against _C.
            output_cuda = _C.gather_scatter(
                feats_cuda, edges_cuda, directed, False
            )
            self.assertClose(output_cuda.cpu(), output_python)

            output_cpu = _C.gather_scatter(
                feats_cpu, edges_cpu, directed, False
            )
            self.assertClose(output_cpu, output_python)
Ejemplo n.º 2
0
    def test_gather_scatter(self):
        """
        Check that the CUDA kernel of _C.gather_scatter gives the same results
        as the python reference gather_scatter_python, for both undirected
        and directed edges, and that calling _C.gather_scatter with CPU
        tensors raises a "Not implemented on the CPU" error.
        """
        device = torch.device("cuda:0")
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
        w0 = nn.Linear(3, 1)
        feats = w0(verts)

        feats_cuda = feats.to(device=device)
        edges_cuda = edges.to(device=device)

        # directed=False is the undirected case, directed=True the directed one.
        for directed in (False, True):
            output_python = gather_scatter_python(feats, edges, directed)
            output_cuda = _C.gather_scatter(
                feats_cuda, edges_cuda, directed, False
            )
            self.assertClose(output_cuda.cpu(), output_python)

            # This build of the extension is expected to reject CPU inputs.
            with self.assertRaises(Exception) as err:
                _C.gather_scatter(feats.cpu(), edges.cpu(), directed, False)
            self.assertIn("Not implemented on the CPU", str(err.exception))
Ejemplo n.º 3
0
    def test_backward(self):
        """
        Check that gradients flowing through the CUDA gather_scatter match the
        gradients of the python reference gather_scatter_python (undirected).
        """
        device = torch.device("cuda:0")
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
        verts_cuda = verts.clone().to(device)
        edges_cuda = edges.clone().to(device)
        verts.requires_grad = True
        verts_cuda.requires_grad = True

        neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False)
        neighbor_sums = gather_scatter_python(verts, edges, False)

        # Weight the outputs with random upstream gradients. A plain .sum()
        # feeds an all-ones gradient, which makes the check insensitive to
        # which output each contribution is routed to.
        randoms = torch.rand_like(neighbor_sums)
        (neighbor_sums_cuda * randoms.to(device)).sum().backward()
        (neighbor_sums * randoms).sum().backward()

        self.assertClose(verts.grad, verts_cuda.grad.cpu())
Ejemplo n.º 4
0
    def test_backward(self):
        """
        Compare gradients of the CUDA, CPU and pure-python gather_scatter.

        Each implementation gets its own clone of the same ico_sphere
        vertices. The outputs are multiplied by one shared random tensor
        before reduction, so the upstream gradient is not uniform.
        """
        cuda_dev = get_random_cuda_device()
        sphere = ico_sphere()
        ref_verts = sphere.verts_packed()
        ref_edges = sphere.edges_packed()

        cpu_verts, cpu_edges = ref_verts.clone(), ref_edges.clone()
        gpu_verts = ref_verts.clone().to(cuda_dev)
        gpu_edges = ref_edges.clone().to(cuda_dev)
        for leaf in (ref_verts, cpu_verts, gpu_verts):
            leaf.requires_grad = True

        gpu_out = gather_scatter(gpu_verts, gpu_edges, False)
        cpu_out = gather_scatter(cpu_verts, cpu_edges, False)
        ref_out = gather_scatter_python(ref_verts, ref_edges, False)

        weights = torch.rand_like(ref_out)
        (gpu_out * weights.to(cuda_dev)).sum().backward()
        (cpu_out * weights).sum().backward()
        (ref_out * weights).sum().backward()

        self.assertClose(ref_verts.grad, gpu_verts.grad.cpu())
        self.assertClose(ref_verts.grad, cpu_verts.grad)