def test_render_normal(self, face_vertices_camera, face_vertices_image,
                       face_camera_normals_z, height, width, dtype, device):
    """Rasterize unit face normals and compare them to reference images.

    Each face's unit normal is copied onto its three vertices, rasterized with
    ``dibr_rasterization``, remapped from [-1, 1] to [0, 1] and compared
    element-wise against ``vertex_normal_{i}.png`` (one 8-bit step tolerance).
    Up to 8 mismatching elements are tolerated for ``torch.double``, none
    otherwise.
    """
    n_batch = face_vertices_camera.shape[0]

    # One unit normal per face, replicated on each of the face's 3 vertices.
    per_vertex_normals = face_normals(
        face_vertices_camera, unit=True).unsqueeze(-2).repeat(1, 1, 3, 1)

    # dibr_rasterization returns (interpolated features, soft mask,
    # face-index map); the face-index map starts from 1, 0 meaning void.
    # Only the interpolated features are checked here.
    rendered, _soft_mask, _face_index_map = dibr_rasterization(
        height, width, face_vertices_camera[:, :, :, 2],
        face_vertices_image, per_vertex_normals, face_camera_normals_z)
    rendered = (rendered + 1) / 2

    reference = []
    for i in range(n_batch):
        path = os.path.join(SAMPLE_DIR, f'vertex_normal_{i}.png')
        reference.append(torch.from_numpy(np.array(Image.open(path))))
    reference = torch.stack(reference, dim=0).to(device, dtype) / 255.

    tolerated = 8 if dtype == torch.double else 0
    mismatches = (~torch.isclose(rendered, reference, atol=1. / 255.)).sum()
    assert mismatches <= tolerated
def test_prepare_vertices(self, vertices, faces, camera_rot, camera_trans,
                          camera_proj, face_vertices_camera,
                          face_vertices_image):
    """``prepare_vertices`` must exactly reproduce the per-step fixtures.

    Checks, with exact equality, the camera-space face vertices, the
    image-space face vertices and the unit face normals computed from the
    camera-space fixture.
    """
    cam_faces, img_faces, unit_normals = prepare_vertices(
        vertices, faces, camera_proj, camera_rot, camera_trans)
    for expected, actual in ((face_vertices_camera, cam_faces),
                             (face_vertices_image, img_faces)):
        assert torch.equal(expected, actual)
    assert torch.equal(face_normals(face_vertices_camera, unit=True),
                       unit_normals)
def test_render_texture_with_light(self, uvs, faces, texture_maps, lights,
                                   face_vertices_camera, face_vertices_image,
                                   face_camera_normals_z, height, width,
                                   dtype, device):
    """Render a textured, spherical-harmonics-lit mesh and compare to PNGs.

    A constant-one mask, the uv coordinates and flat face normals are
    rasterized together. The texture is sampled with nearest filtering, the
    mask, texture colour and lighting coefficient are multiplied, and the
    result is clamped to [0, 1]. It is then compared against
    ``texture_light_{i}.png`` with a one 8-bit step tolerance. Up to 74
    mismatching elements are tolerated for ``torch.double``, none otherwise.
    """
    # Note: in this example the uv faces are the same as the mesh faces,
    # but they could be different.
    face_uvs = index_vertices_by_faces(uvs, faces)

    # Flat shading: each face's unit normal is copied onto its 3 vertices.
    face_normals_unit = face_normals(face_vertices_camera, unit=True)
    face_normals_unit = face_normals_unit.unsqueeze(-2).repeat(1, 1, 3, 1)

    # Rasterize all attributes in one pass: a constant 1 (coverage mask),
    # the uv coordinates and the normals.
    face_attributes = [
        torch.ones((*face_uvs.shape[:-1], 1), device=device, dtype=dtype),
        face_uvs,
        face_normals_unit,
    ]
    (texmask, texcoord, imnormal), _improb, _imidx = dibr_rasterization(
        height, width, face_vertices_camera[:, :, :, 2], face_vertices_image,
        face_attributes, face_camera_normals_z)

    texcolor = texture_mapping(texcoord, texture_maps, mode='nearest')
    coef = spherical_harmonic_lighting(imnormal, lights)
    images = torch.clamp(texmask * texcolor * coef.unsqueeze(-1), 0, 1)

    # Take the batch size from the rendered batch. The leading dimension of
    # `faces` (the face-index tensor given to index_vertices_by_faces)
    # indexes faces, not batch items.
    batch_size = images.shape[0]

    if dtype == torch.double:
        num_pix_diff_tol = 74  # (over 2 x 256 x 512 x 3 pixels)
    else:
        num_pix_diff_tol = 0

    images_gt = [
        torch.from_numpy(np.array(Image.open(
            os.path.join(SAMPLE_DIR, f'texture_light_{bs}.png'))))
        for bs in range(batch_size)
    ]
    images_gt = torch.stack(images_gt, dim=0).to(device, dtype) / 255.

    num_pix_diff = torch.sum(
        ~torch.isclose(images, images_gt, atol=1. / 255.))
    assert num_pix_diff <= num_pix_diff_tol
def test_sample_points(self, vertices, faces, device, dtype):
    """Sampled points must lie on, and inside, the face they were drawn from.

    The fixture mesh is expected to have two faces (0 and 1), with face 0
    receiving about a third of the ``num_samples`` samples (within 10%).
    """
    batch_size = vertices.shape[0]
    num_samples = 1000

    points, face_choices = mesh.sample_points(vertices, faces, num_samples)
    check_tensor(points, shape=(batch_size, num_samples, 3),
                 dtype=dtype, device=device)
    check_tensor(face_choices, shape=(batch_size, num_samples),
                 dtype=torch.long, device=device)

    # check that all faces are sampled
    num_0 = torch.sum(face_choices == 0, dim=1)
    assert torch.all(num_0 + torch.sum(face_choices == 1, dim=1) == num_samples)
    sampling_prob = num_samples / 3.
    tolerance = sampling_prob * 0.1
    assert torch.all(num_0 < sampling_prob + tolerance) and \
        torch.all(num_0 > sampling_prob - tolerance)

    # Vertices of the face picked for every sample:
    # (batch_size, num_samples, 3, 3).
    face_vertices = mesh.index_vertices_by_faces(vertices, faces)
    face_vertices_choices = torch.gather(
        face_vertices, 1, face_choices[:, :, None, None].repeat(1, 1, 3, 3))
    v0 = face_vertices_choices[:, :, 0]
    v1 = face_vertices_choices[:, :, 1]
    v2 = face_vertices_choices[:, :, 2]
    face_normals = mesh.face_normals(face_vertices_choices, unit=True)

    def batched_dot(a, b):
        # Per-sample dot product of two (batch_size, num_samples, 3) tensors.
        return torch.matmul(a.reshape(-1, 1, 3),
                            b.reshape(-1, 3, 1)).reshape(batch_size, num_samples)

    if dtype == torch.half:
        atol = 1e-2
        rtol = 1e-3
    else:
        atol = 1e-4
        rtol = 1e-5

    # Signed distance from each point to the plane of its face. Taking the dot
    # product with the unit normal directly avoids dividing by |p - v0|, which
    # produces NaN when a sample lands exactly on v0.
    point_to_face_dist = batched_dot(points - v0, face_normals)
    assert torch.allclose(point_to_face_dist,
                          torch.zeros_like(point_to_face_dist),
                          atol=atol, rtol=rtol)

    # A point is inside the triangle when, for every edge, the triangle formed
    # by that edge and the point has the same orientation as the face.
    margin = -5e-3 if dtype == torch.half else 0.
    for start, end in ((v0, v1), (v1, v2), (v2, v0)):
        # dim=-1 must be explicit: without it torch.cross uses the first
        # dimension of size 3, which is wrong if batch_size or num_samples is 3.
        edge_normal = torch.cross(end - start, points - start, dim=-1)
        assert torch.all(batched_dot(edge_normal, face_normals) >= margin)
def test_packed_sample_points(self, packed_vertices_info, packed_faces_info,
                              device, dtype):
    """Packed sampling: points must lie on, and inside, their sampled face.

    ``face_choices`` holds indices into the packed face tensor. All samples of
    mesh 1 must come from packed face 2. Mesh 0 splits between faces 0 and 1,
    with face 0 getting about a third of the samples (within 10%).
    """
    vertices, first_idx_vertices = packed_vertices_info
    faces, num_faces_per_mesh = packed_faces_info
    total_num_faces = faces.shape[0]
    batch_size = num_faces_per_mesh.shape[0]
    num_samples = 1000

    points, face_choices = mesh.packed_sample_points(
        vertices, first_idx_vertices, faces, num_faces_per_mesh, num_samples)
    check_tensor(points, shape=(batch_size, num_samples, 3),
                 dtype=dtype, device=device)
    check_tensor(face_choices, shape=(batch_size, num_samples),
                 dtype=torch.long, device=device)

    # check that all faces are sampled
    assert torch.all(face_choices[1] == 2)
    num_0 = torch.sum(face_choices[0] == 0)
    assert num_0 + torch.sum(face_choices[0] == 1) == num_samples
    sampling_prob = num_samples / 3.
    tolerance = sampling_prob * 0.1
    assert (num_0 < sampling_prob + tolerance) and \
        (num_0 > sampling_prob - tolerance)

    # Shift each mesh's face indices by that mesh's first vertex index so they
    # address the packed vertex tensor. Presumably tile_to_packed repeats each
    # offset once per face of its mesh.
    merged_faces = faces + tile_to_packed(
        first_idx_vertices[:-1].to(vertices.device), num_faces_per_mesh)
    face_vertices = torch.index_select(
        vertices, 0, merged_faces.reshape(-1)).reshape(total_num_faces, 3, 3)
    face_vertices_choices = torch.gather(
        face_vertices, 0,
        face_choices.reshape(-1, 1, 1).repeat(1, 3, 3)
    ).reshape(batch_size, num_samples, 3, 3)
    v0 = face_vertices_choices[:, :, 0]
    v1 = face_vertices_choices[:, :, 1]
    v2 = face_vertices_choices[:, :, 2]
    face_normals = mesh.face_normals(face_vertices_choices, unit=True)

    def batched_dot(a, b):
        # Per-sample dot product of two (batch_size, num_samples, 3) tensors.
        return torch.matmul(a.reshape(-1, 1, 3),
                            b.reshape(-1, 3, 1)).reshape(batch_size, num_samples)

    if dtype == torch.half:
        atol = 1e-2
        rtol = 1e-3
    else:
        atol = 1e-4
        rtol = 1e-5

    # Signed distance from each point to the plane of its face. Taking the dot
    # product with the unit normal directly avoids dividing by |p - v0|, which
    # produces NaN when a sample lands exactly on v0.
    point_to_face_dist = batched_dot(points - v0, face_normals)
    assert torch.allclose(point_to_face_dist,
                          torch.zeros_like(point_to_face_dist),
                          atol=atol, rtol=rtol)

    # A point is inside the triangle when, for every edge, the triangle formed
    # by that edge and the point has the same orientation as the face.
    margin = -2e-3 if dtype == torch.half else 0.
    for start, end in ((v0, v1), (v1, v2), (v2, v0)):
        # dim=-1 must be explicit: without it torch.cross uses the first
        # dimension of size 3, which is wrong if batch_size or num_samples is 3.
        edge_normal = torch.cross(end - start, points - start, dim=-1)
        assert torch.all(batched_dot(edge_normal, face_normals) >= margin)
def face_camera_normals_z(self, face_vertices_camera):
    """Z component of each face's unit normal, from camera-space vertices."""
    unit_normals = face_normals(face_vertices_camera, unit=True)
    return unit_normals[:, :, 2]
def test_sample_points(self, vertices, faces, face_features, use_features,
                       device, dtype):
    """Sampled points (and optional features) must be consistent with faces.

    Checks that every point lies on, and inside, the face it was drawn from.
    The fixture has two faces, each expected to get about half the samples
    (within 20%). When ``use_features`` is set, it also checks that the
    interpolated features equal the face features blended with the point's
    barycentric weights.
    """
    batch_size = vertices.shape[0]
    num_samples = 1000

    if use_features:
        points, face_choices, interpolated_features = mesh.sample_points(
            vertices, faces, num_samples, face_features=face_features)
    else:
        points, face_choices = mesh.sample_points(vertices, faces, num_samples)

    check_tensor(points, shape=(batch_size, num_samples, 3),
                 dtype=dtype, device=device)
    check_tensor(face_choices, shape=(batch_size, num_samples),
                 dtype=torch.long, device=device)

    # check that all faces are sampled
    num_0 = torch.sum(face_choices == 0, dim=1)
    assert torch.all(num_0 + torch.sum(face_choices == 1, dim=1) == num_samples)
    sampling_prob = num_samples / 2
    tolerance = sampling_prob * 0.2
    assert torch.all(num_0 < sampling_prob + tolerance) and \
        torch.all(num_0 > sampling_prob - tolerance)

    # Vertices of the face picked for every sample:
    # (batch_size, num_samples, 3, 3).
    face_vertices = mesh.index_vertices_by_faces(vertices, faces)
    face_vertices_choices = torch.gather(
        face_vertices, 1, face_choices[:, :, None, None].repeat(1, 1, 3, 3))
    v0 = face_vertices_choices[:, :, 0]
    v1 = face_vertices_choices[:, :, 1]
    v2 = face_vertices_choices[:, :, 2]
    face_normals = mesh.face_normals(face_vertices_choices, unit=True)

    def batched_dot(a, b):
        # Per-sample dot product of two (batch_size, num_samples, 3) tensors.
        return torch.matmul(a.reshape(-1, 1, 3),
                            b.reshape(-1, 3, 1)).reshape(batch_size, num_samples)

    if dtype == torch.half:
        atol = 1e-2
        rtol = 1e-3
    else:
        atol = 1e-4
        rtol = 1e-5

    # Signed distance from each point to the plane of its face. Taking the dot
    # product with the unit normal directly avoids dividing by |p - v0|, which
    # produces NaN when a sample lands exactly on v0.
    point_to_face_dist = batched_dot(points - v0, face_normals)
    assert torch.allclose(point_to_face_dist,
                          torch.zeros_like(point_to_face_dist),
                          atol=atol, rtol=rtol)

    # A point is inside the triangle when, for every edge, the triangle formed
    # by that edge and the point has the same orientation as the face.
    margin = -5e-3 if dtype == torch.half else 0.
    for start, end in ((v0, v1), (v1, v2), (v2, v0)):
        # dim=-1 must be explicit: without it torch.cross uses the first
        # dimension of size 3, which is wrong if batch_size or num_samples is 3.
        edge_normal = torch.cross(end - start, points - start, dim=-1)
        assert torch.all(batched_dot(edge_normal, face_normals) >= margin)

    if use_features:
        feat_dim = face_features.shape[-1]
        check_tensor(interpolated_features,
                     shape=(batch_size, num_samples, feat_dim),
                     dtype=dtype, device=device)

        # Barycentric weights from the x/y coordinates only. Cramer's rule
        # solves p - a = w1 * (b - a) + w2 * (c - a), with determinant k3.
        # This assumes no sampled face is degenerate when projected onto the
        # xy plane.
        ax = face_vertices_choices[:, :, 0, 0]
        ay = face_vertices_choices[:, :, 0, 1]
        bx = face_vertices_choices[:, :, 1, 0]
        by = face_vertices_choices[:, :, 1, 1]
        cx = face_vertices_choices[:, :, 2, 0]
        cy = face_vertices_choices[:, :, 2, 1]
        m = bx - ax
        p = by - ay
        n = cx - ax
        q = cy - ay
        s = points[:, :, 0] - ax
        t = points[:, :, 1] - ay
        k1 = s * q - n * t
        k2 = m * t - s * p
        k3 = m * q - n * p
        # The epsilon only prevents division by an exactly-zero determinant.
        w1 = k1 / (k3 + 1e-7)
        w2 = k2 / (k3 + 1e-7)
        w0 = (1. - w1) - w2
        weights = torch.stack([w0, w1, w2], dim=-1)

        # The weights must reconstruct the sampled points...
        gt_points = torch.sum(face_vertices_choices * weights.unsqueeze(-1),
                              dim=-2)
        assert torch.allclose(points, gt_points, atol=atol, rtol=rtol)

        # ...and interpolate the chosen face's per-vertex features. The gather
        # index shape implies face_features is
        # (batch_size, num_faces, 3, feat_dim) -- TODO confirm against fixture.
        _face_choices = face_choices[..., None, None].repeat(1, 1, 3, feat_dim)
        face_features_choices = torch.gather(face_features, 1, _face_choices)
        gt_interpolated_features = torch.sum(
            face_features_choices * weights.unsqueeze(-1), dim=-2)
        assert torch.allclose(interpolated_features, gt_interpolated_features,
                              atol=atol, rtol=rtol)
def test_optimize_vertex_position(self, vertices, faces, vertex_colors,
                                  vertices_image, camera_rot, camera_trans,
                                  camera_proj, height, width, dtype, device):
    """Recover a displaced vertex by optimizing through dibr_rasterization.

    Vertex 0 of every mesh is shifted by 0.4 in x and y. Adam (lr=5e-3, 100
    steps) then minimizes the mean absolute difference between the rendered
    vertex colours and ``vertex_color_{i}.png``. Finally, the optimized
    vertices must project back onto ``vertices_image``.
    """
    camera_rot = camera_rot.to(device, dtype)
    camera_trans = camera_trans.to(device, dtype)
    camera_proj = camera_proj.to(device, dtype)
    # Per-face vertex colours: the attribute being rasterized.
    face_attributes = index_vertices_by_faces(vertex_colors.to(device, dtype),
                                              faces)

    vertices = vertices.to(device, dtype).clone().detach()
    # Take the batch size from the batched vertices (indexed [:, 0, :2]
    # below). The leading dimension of `faces` (the face-index tensor given
    # to index_vertices_by_faces) indexes faces, not batch items.
    batch_size = vertices.shape[0]

    # Displace vertex 0 of every mesh; this copy is the only optimized tensor.
    moved_vertices = vertices.clone()
    moved_vertices[:, 0, :2] += 0.4
    moved_vertices = moved_vertices.detach().requires_grad_(True)

    images_gt = [
        torch.from_numpy(np.array(Image.open(
            os.path.join(SAMPLE_DIR, f'vertex_color_{bs}.png'))))
        for bs in range(batch_size)
    ]
    images_gt = torch.stack(images_gt, dim=0).to(device, dtype) / 255.

    def project(verts):
        # Return (camera-space, image-space) coordinates of `verts`.
        verts_camera = rotate_translate_points(verts, camera_rot, camera_trans)
        return verts_camera, perspective_camera(verts_camera, camera_proj)

    def render(verts):
        # Rasterize the vertex colours for the mesh deformed to `verts`.
        verts_camera, verts_image = project(verts)
        face_verts_camera = index_vertices_by_faces(verts_camera, faces)
        face_verts_image = index_vertices_by_faces(verts_image, faces)
        face_normals_z = face_normals(face_verts_camera, unit=True)[:, :, 2]
        imfeat, _, _ = dibr_rasterization(
            height, width, face_verts_camera[:, :, :, 2], face_verts_image,
            face_attributes, face_normals_z)
        return imfeat

    # The displacement must be large enough for the final check to fail
    # without optimization.
    _, moved_vertices_image = project(moved_vertices)
    assert not torch.allclose(moved_vertices_image, vertices_image,
                              atol=1e-2, rtol=1e-2)

    with torch.no_grad():
        original_loss = torch.mean(torch.abs(render(moved_vertices) - images_gt))
    # test that the loss is high enough
    assert original_loss > 0.01

    optimizer = torch.optim.Adam([moved_vertices], lr=5e-3)
    for _ in range(100):
        optimizer.zero_grad()
        loss = torch.mean(torch.abs(render(moved_vertices) - images_gt))
        loss.backward()
        optimizer.step()

    # test that the loss went down (`loss` is from the last forward pass)
    assert loss < 0.001

    # Only the image-plane projection is checked: the camera does not change
    # during optimization, so depth is not expected to be recovered.
    # We could probably fine-tune the test to have a lower tolerance
    # (TODO: cfujitsang)
    with torch.no_grad():
        _, moved_vertices_image = project(moved_vertices)
    assert torch.allclose(moved_vertices_image, vertices_image,
                          atol=1e-2, rtol=1e-2)