def test_interpolate_attributes_grad(self):
    """
    Check that gradients flow from interpolated texels back to the
    per-vertex colors.

    The single pixel below samples face 0 ([2, 1, 0]) with barycentric
    weights [0.5, 0.3, 0.2] and face 1 ([3, 1, 0]) with weights
    [0.3, 0.6, 0.1]. For a barycentric blend followed by ``.sum()``, the
    gradient of each vertex color is the total weight that vertex receives
    across both samples:
        v0 = 0.2 + 0.1 = 0.3, v1 = 0.3 + 0.6 = 0.9, v2 = 0.5, v3 = 0.3,
    repeated for each of the three color channels.
    """
    verts = torch.randn((4, 3), dtype=torch.float32)
    faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
    # Leaf tensor of per-vertex colors; its .grad is what this test checks.
    vert_tex = torch.tensor(
        [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]],
        dtype=torch.float32,
        requires_grad=True,
    )
    tex = Textures(verts_rgb=vert_tex[None, :])
    mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
    # Shape (1, 1, 1, 2): one pixel with two face samples
    # (presumably laid out as (N, H, W, K) — confirm against Fragments).
    pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
    barycentric_coords = torch.tensor(
        [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
    ).view(1, 1, 1, 2, -1)
    fragments = Fragments(
        pix_to_face=pix_to_face,
        bary_coords=barycentric_coords,
        zbuf=torch.ones_like(pix_to_face),
        dists=torch.ones_like(pix_to_face),
    )
    grad_vert_tex = torch.tensor(
        [
            [0.3, 0.3, 0.3],
            [0.9, 0.9, 0.9],
            [0.5, 0.5, 0.5],
            [0.3, 0.3, 0.3],
        ],
        dtype=torch.float32,
    )
    texels = interpolate_vertex_colors(fragments, mesh)
    texels.sum().backward()
    # Every torch.Tensor has a ``grad`` attribute (None until a backward
    # pass fills it), so hasattr() could never fail; require a real gradient.
    self.assertIsNotNone(vert_tex.grad)
    # vert_tex is a (4, 3) leaf, so its gradient matches grad_vert_tex's
    # shape exactly; compare directly instead of relying on broadcasting.
    self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex))
 def test_interpolate_attributes(self):
     """
     Check interpolate_vertex_colors against hand-computed barycentric
     blends of the per-vertex colors.

     face 0 = [2, 1, 0], weights [0.5, 0.3, 0.2]:
         0.5*[1, 1, 0] + 0.3*[0, 1, 1] + 0.2*[0, 1, 0] = [0.5, 1.0, 0.3]
     face 1 = [3, 1, 0], weights [0.3, 0.6, 0.1]:
         0.3*[1, 1, 1] + 0.6*[0, 1, 1] + 0.1*[0, 1, 0] = [0.3, 1.0, 0.9]

     NOTE(review): an earlier docstring said this also tests
     interpolate_face_attributes. That function is not called here, so it
     is only covered if interpolate_vertex_colors delegates to it — confirm.
     """
     verts = torch.randn((4, 3), dtype=torch.float32)
     faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
     vert_tex = torch.tensor(
         [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]],
         dtype=torch.float32,
     )
     tex = Textures(verts_rgb=vert_tex[None, :])
     mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
     # Shape (1, 1, 1, 2): one pixel with two face samples
     # (presumably laid out as (N, H, W, K) — confirm against Fragments).
     pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
     barycentric_coords = torch.tensor(
         [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
     ).view(1, 1, 1, 2, -1)
     # One expected color per face sample, laid out like barycentric_coords.
     expected_vals = torch.tensor(
         [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
     ).view(1, 1, 1, 2, -1)
     fragments = Fragments(
         pix_to_face=pix_to_face,
         bary_coords=barycentric_coords,
         zbuf=torch.ones_like(pix_to_face),
         dists=torch.ones_like(pix_to_face),
     )
     texels = interpolate_vertex_colors(fragments, mesh)
     # expected_vals is already 5-D; the former extra ``[None, :]`` made it
     # 6-D and only passed because torch.allclose broadcasts its inputs.
     self.assertTrue(torch.allclose(texels, expected_vals))