Exemplo n.º 1
0
    def test_backward(self):
        """Check that gather_scatter gradients match the Python implementation.

        The same mesh vertices/edges are fed to gather_scatter on a CUDA
        device and on the original tensors' device, and to
        gather_scatter_python. Each output is reduced to a scalar with one
        shared set of random weights, back-propagated, and the resulting
        vertex gradients are compared.
        """
        device = get_random_cuda_device()
        sphere = ico_sphere()
        ref_verts = sphere.verts_packed()
        ref_edges = sphere.edges_packed()

        # Clone *before* enabling gradients so every copy is its own leaf
        # tensor and accumulates an independent .grad.
        cpu_verts, cpu_edges = ref_verts.clone(), ref_edges.clone()
        gpu_verts = ref_verts.clone().to(device)
        gpu_edges = ref_edges.clone().to(device)
        for leaf in (ref_verts, cpu_verts, gpu_verts):
            leaf.requires_grad = True

        out_gpu = gather_scatter(gpu_verts, gpu_edges, False)
        out_cpu = gather_scatter(cpu_verts, cpu_edges, False)
        out_ref = gather_scatter_python(ref_verts, ref_edges, False)

        # Weighting the outputs with random values makes the upstream
        # gradient non-uniform; the same weights are reused for all three.
        weights = torch.rand_like(out_ref)
        weighted_outputs = (
            (out_gpu, weights.to(device)),
            (out_cpu, weights),
            (out_ref, weights),
        )
        for output, weight in weighted_outputs:
            (output * weight).sum().backward()

        expected_grad = ref_verts.grad
        self.assertClose(expected_grad, gpu_verts.grad.cpu())
        self.assertClose(expected_grad, cpu_verts.grad)
Exemplo n.º 2
0
    def test_backward(self):
        """Check CUDA gather_scatter gradients against the Python version.

        Runs gather_scatter on "cuda:0" and gather_scatter_python on the
        original tensors, back-propagates the plain sum of each output, and
        compares the gradients with respect to the vertices.
        """
        cuda_device = torch.device("cuda:0")
        sphere = ico_sphere()
        ref_verts, ref_edges = sphere.verts_packed(), sphere.edges_packed()

        # Copies are made before gradients are enabled so that the CUDA
        # vertices are a separate leaf tensor with their own .grad.
        gpu_verts = ref_verts.clone().to(cuda_device)
        gpu_edges = ref_edges.clone().to(cuda_device)
        for leaf in (ref_verts, gpu_verts):
            leaf.requires_grad = True

        # The two autograd graphs are independent of each other.
        gather_scatter(gpu_verts, gpu_edges, False).sum().backward()
        gather_scatter_python(ref_verts, ref_edges, False).sum().backward()

        expected_grad = ref_verts.grad.cpu()
        actual_grad = gpu_verts.grad.cpu()
        self.assertClose(expected_grad, actual_grad)