def test_detach(self):
    """Check that ``TexturesUV.detach`` yields equal-valued tensors without grad."""
    tex = TexturesUV(
        maps=torch.ones((5, 16, 16, 3), requires_grad=True),
        faces_uvs=torch.rand(size=(5, 10, 3)),
        verts_uvs=torch.rand(size=(5, 15, 2)),
    )
    # Call the list accessors up front; the private ``_*_list`` attributes are
    # read directly further down (presumably these calls are what cache them).
    tex.faces_uvs_list()
    tex.verts_uvs_list()
    tex_detached = tex.detach()

    # Padded representations: no gradient tracking, values unchanged.
    for attr in ("_maps_padded", "_verts_uvs_padded", "_faces_uvs_padded"):
        detached_padded = getattr(tex_detached, attr)
        self.assertFalse(detached_padded.requires_grad)
        self.assertClose(getattr(tex, attr), detached_padded)

    # Per-mesh list representations: same two checks for every batch element.
    for i in range(tex._N):
        for attr in ("_verts_uvs_list", "_faces_uvs_list"):
            detached_item = getattr(tex_detached, attr)[i]
            self.assertFalse(detached_item.requires_grad)
            self.assertClose(getattr(tex, attr)[i], detached_item)
        # tex._maps_list is not used anywhere so it's not stored.
        # We call maps_list() explicitly instead.
        self.assertFalse(tex_detached.maps_list()[i].requires_grad)
        self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
def test_clone(self):
    """Check that ``TexturesUV.clone`` returns separate tensors with equal values."""
    tex = TexturesUV(
        maps=torch.ones((5, 16, 16, 3)),
        faces_uvs=torch.rand(size=(5, 10, 3)),
        verts_uvs=torch.rand(size=(5, 15, 2)),
    )
    # Call the list accessors up front; the private ``_*_list`` attributes are
    # read directly further down (presumably these calls are what cache them).
    tex.faces_uvs_list()
    tex.verts_uvs_list()
    tex_cloned = tex.clone()

    # Padded representations must be separate tensors holding the same values.
    for attr in ("_faces_uvs_padded", "_verts_uvs_padded", "_maps_padded"):
        source = getattr(tex, attr)
        duplicate = getattr(tex_cloned, attr)
        self.assertSeparate(source, duplicate)
        self.assertClose(source, duplicate)

    # ``valid`` is compared element-wise with ``eq`` rather than assertClose.
    self.assertSeparate(tex.valid, tex_cloned.valid)
    valid_matches = tex.valid.eq(tex_cloned.valid)
    self.assertTrue(valid_matches.all())

    # Per-mesh list representations: same checks for every batch element.
    for i in range(tex._N):
        for attr in ("_faces_uvs_list", "_verts_uvs_list"):
            source = getattr(tex, attr)[i]
            duplicate = getattr(tex_cloned, attr)[i]
            self.assertSeparate(source, duplicate)
            self.assertClose(source, duplicate)
        # tex._maps_list is not used anywhere so it's not stored.
        # We call maps_list() explicitly instead.
        self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
        self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
def test_padded_to_packed(self):
    """Check list/padded UV accessors of ``TexturesUV`` against the inputs.

    Covers a texture built from per-mesh lists and a texture built from the
    resulting padded tensors.
    """
    # Case where each face in the mesh has 3 unique uv vertex indices
    # - i.e. even if a vertex is shared between multiple faces it will
    # have a unique uv coordinate for each face.
    N = 2
    # Two meshes with 3 and 2 faces; the padded form is (N, 3, 3).
    faces_uvs_list = [
        torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
        torch.tensor([[0, 1, 2], [3, 4, 5]]),
    ]
    verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
    num_faces_per_mesh = [faces.shape[0] for faces in faces_uvs_list]
    num_verts_per_mesh = [verts.shape[0] for verts in verts_uvs_list]

    tex = TexturesUV(
        maps=torch.ones((N, 16, 16, 3)),
        faces_uvs=faces_uvs_list,
        verts_uvs=verts_uvs_list,
    )

    # This is set inside Meshes when textures is passed as an input.
    # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicitly.
    tex1 = tex.clone()
    tex1._num_faces_per_mesh = num_faces_per_mesh
    tex1._num_verts_per_mesh = num_verts_per_mesh
    verts_list = tex1.verts_uvs_list()
    verts_padded = tex1.verts_uvs_padded()
    faces_list = tex1.faces_uvs_list()
    faces_padded = tex1.faces_uvs_padded()

    # The list accessors must give back exactly what was passed in.
    for actual, expected in zip(faces_list, faces_uvs_list):
        same = (actual == expected).all().item()
        self.assertTrue(same)

    for actual, expected in zip(verts_list, verts_uvs_list):
        same = (actual == expected).all().item()
        self.assertTrue(same)

    # Padded tensors are sized by the largest mesh in the batch.
    self.assertTrue(faces_padded.shape == (2, 3, 3))
    self.assertTrue(verts_padded.shape == (2, 9, 2))

    # Case where num_faces_per_mesh is not set and faces_verts_uvs
    # are initialized with a padded tensor.
    tex2 = TexturesUV(
        maps=torch.ones((N, 16, 16, 3)),
        verts_uvs=verts_padded,
        faces_uvs=faces_padded,
    )
    faces_list = tex2.faces_uvs_list()
    verts_list = tex2.verts_uvs_list()

    # Compare only the first ``count`` rows of each returned entry; entries may
    # be longer than the inputs because the texture was built from padded data.
    for actual, expected, count in zip(
        faces_list, faces_uvs_list, num_faces_per_mesh
    ):
        same = (actual[:count] == expected).all().item()
        self.assertTrue(same)

    for actual, expected, count in zip(
        verts_list, verts_uvs_list, num_verts_per_mesh
    ):
        same = (actual[:count] == expected).all().item()
        self.assertTrue(same)