    def test_render_normal(self, face_vertices_camera, face_vertices_image,
                           face_camera_normals_z, height, width, dtype,
                           device):
        batch_size = face_vertices_camera.shape[0]
        face_normals_unit = face_normals(face_vertices_camera, unit=True)
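        # replicate each face's unit normal over its 3 vertices so it can be
        # interpolated by the rasterizer as a per-vertex attribute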
        face_attributes = face_normals_unit.unsqueeze(-2).repeat(1, 1, 3, 1)

        # imfeat: the interpolated per-pixel features
        # improb: the soft visibility mask
        # imfaceidx: the face index map, i.e. which face covers each pixel;
        #            indices start at 1, with 0 marking empty (void) pixels
        imfeat, improb, imfaceidx = dibr_rasterization(
            height, width, face_vertices_camera[:, :, :, 2],
            face_vertices_image, face_attributes, face_camera_normals_z)
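        # remap unit normals from [-1, 1] to [0, 1] for comparison with the
        # 8-bit reference images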
        images = (imfeat + 1) / 2
        images_gt = [
            torch.from_numpy(
                np.array(
                    Image.open(
                        os.path.join(SAMPLE_DIR, f'vertex_normal_{bs}.png'))))
            for bs in range(batch_size)
        ]
        images_gt = torch.stack(images_gt, dim=0).to(device, dtype) / 255.
        if dtype == torch.double:
            num_pix_diff_tol = 8
        else:
            num_pix_diff_tol = 0
        num_pix_diff = torch.sum(
            ~torch.isclose(images, images_gt, atol=1. / 255.))
        assert num_pix_diff <= num_pix_diff_tol
    def test_prepare_vertices(self, vertices, faces, camera_rot, camera_trans,
                              camera_proj, face_vertices_camera,
                              face_vertices_image):
        _face_vertices_camera, _face_vertices_image, _face_normals = \
            prepare_vertices(vertices, faces, camera_proj, camera_rot,
                             camera_trans)
        assert torch.equal(face_vertices_camera, _face_vertices_camera)
        assert torch.equal(face_vertices_image, _face_vertices_image)
        assert torch.equal(face_normals(face_vertices_camera, unit=True),
                           _face_normals)
    def test_render_texture_with_light(self, uvs, faces, texture_maps, lights,
                                       face_vertices_camera,
                                       face_vertices_image,
                                       face_camera_normals_z, height, width,
                                       dtype, device):
        batch_size = faces.shape[0]
        # Note: in this example the UV faces are the same as the mesh faces,
        # but in general they can differ
        face_uvs = index_vertices_by_faces(uvs, faces)

        # normal
        face_normals_unit = face_normals(face_vertices_camera, unit=True)
        face_normals_unit = face_normals_unit.unsqueeze(-2).repeat(1, 1, 3, 1)

        # merge them together
        face_attributes = [
            torch.ones((*face_uvs.shape[:-1], 1), device=device, dtype=dtype),
            face_uvs, face_normals_unit
        ]

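        # dibr_rasterization interpolates the attribute list and returns the
        # results in the same order: hard mask, UV coordinates, normals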
        (texmask, texcoord, imnormal), improb, imidx = dibr_rasterization(
            height, width, face_vertices_camera[:, :, :, 2],
            face_vertices_image, face_attributes, face_camera_normals_z)

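        # shade: sample the texture at the interpolated UVs, light with
        # spherical harmonics evaluated on the per-pixel normals, then mask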
        texcolor = texture_mapping(texcoord, texture_maps, mode='nearest')
        coef = spherical_harmonic_lighting(imnormal, lights)
        images = torch.clamp(texmask * texcolor * coef.unsqueeze(-1), 0, 1)

        if dtype == torch.double:
            num_pix_diff_tol = 74  # (over 2 x 256 x 512 x 3 pixels)
        else:
            num_pix_diff_tol = 0

        images_gt = [
            torch.from_numpy(
                np.array(
                    Image.open(
                        os.path.join(SAMPLE_DIR, f'texture_light_{bs}.png'))))
            for bs in range(batch_size)
        ]
        images_gt = torch.stack(images_gt, dim=0).to(device, dtype) / 255.

        num_pix_diff = torch.sum(
            ~torch.isclose(images, images_gt, atol=1. / 255.))
        assert num_pix_diff <= num_pix_diff_tol
# Example #4
    def test_sample_points(self, vertices, faces, device, dtype):
        batch_size, num_vertices = vertices.shape[:2]
        num_faces = faces.shape[0]
        num_samples = 1000

        points, face_choices = mesh.sample_points(vertices, faces, num_samples)

        check_tensor(points,
                     shape=(batch_size, num_samples, 3),
                     dtype=dtype,
                     device=device)
        check_tensor(face_choices,
                     shape=(batch_size, num_samples),
                     dtype=torch.long,
                     device=device)

        # check that both faces are sampled, with frequency proportional to
        # face area (face 0 is expected about 1/3 of the time)
        num_0 = torch.sum(face_choices == 0, dim=1)
        assert torch.all(num_0 +
                         torch.sum(face_choices == 1, dim=1) == num_samples)
        sampling_prob = num_samples / 3.
        tolerance = sampling_prob * 0.1
        assert torch.all(num_0 < sampling_prob + tolerance) and \
               torch.all(num_0 > sampling_prob - tolerance)

        face_vertices = mesh.index_vertices_by_faces(vertices, faces)

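        # gather the 3 vertices of the face each sample was drawn from,
        # giving a (batch_size, num_samples, 3, 3) tensor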
        face_vertices_choices = torch.gather(
            face_vertices, 1,
            face_choices[:, :, None, None].repeat(1, 1, 3, 3))

        # compute the distance from each point to the plane of the picked face
        face_normals = mesh.face_normals(face_vertices_choices, unit=True)

        # v0_p: batch_size x num_points x 3
        v0_p = points - face_vertices_choices[:, :, 0]
        len_v0_p = torch.sqrt(torch.sum(v0_p**2, dim=-1))
        cos_a = torch.matmul(v0_p.reshape(-1, 1, 3),
                             face_normals.reshape(-1, 3, 1)).reshape(
                                 batch_size, num_samples) / len_v0_p
        point_to_face_dist = len_v0_p * cos_a

        if dtype == torch.half:
            atol = 1e-2
            rtol = 1e-3
        else:
            atol = 1e-4
            rtol = 1e-5

        # check that the points are close to the plane
        assert torch.allclose(point_to_face_dist,
                              torch.zeros((batch_size, num_samples),
                                          device=device,
                                          dtype=dtype),
                              atol=atol,
                              rtol=rtol)

        # check that the points lie inside the triangle
        edges0 = (face_vertices_choices[:, :, 1] -
                  face_vertices_choices[:, :, 0])
        edges1 = (face_vertices_choices[:, :, 2] -
                  face_vertices_choices[:, :, 1])
        edges2 = (face_vertices_choices[:, :, 0] -
                  face_vertices_choices[:, :, 2])

        v0_p = points - face_vertices_choices[:, :, 0]
        v1_p = points - face_vertices_choices[:, :, 1]
        v2_p = points - face_vertices_choices[:, :, 2]

        # Normals of the triangles formed by each edge and the point
        normals1 = torch.cross(edges0, v0_p, dim=-1)
        normals2 = torch.cross(edges1, v1_p, dim=-1)
        normals3 = torch.cross(edges2, v2_p, dim=-1)
        # the dot product of those normals with the face normals must be
        # non-negative (up to a small margin for half precision)
        margin = -5e-3 if dtype == torch.half else 0.
        assert torch.all(
            torch.matmul(normals1.reshape(-1, 1, 3),
                         face_normals.reshape(-1, 3, 1)) >= margin)
        assert torch.all(
            torch.matmul(normals2.reshape(-1, 1, 3),
                         face_normals.reshape(-1, 3, 1)) >= margin)
        assert torch.all(
            torch.matmul(normals3.reshape(-1, 1, 3),
                         face_normals.reshape(-1, 3, 1)) >= margin)
# Example #5
    def test_packed_sample_points(self, packed_vertices_info,
                                  packed_faces_info, device, dtype):
        vertices, first_idx_vertices = packed_vertices_info
        faces, num_faces_per_mesh = packed_faces_info

        total_num_vertices = vertices.shape[0]
        total_num_faces = faces.shape[0]
        batch_size = num_faces_per_mesh.shape[0]
        num_samples = 1000

        points, face_choices = mesh.packed_sample_points(
            vertices, first_idx_vertices, faces, num_faces_per_mesh,
            num_samples)

        check_tensor(points,
                     shape=(batch_size, num_samples, 3),
                     dtype=dtype,
                     device=device)
        check_tensor(face_choices,
                     shape=(batch_size, num_samples),
                     dtype=torch.long,
                     device=device)

        # check the sampling: mesh 1 has a single face (packed index 2);
        # mesh 0's two faces are drawn in proportion to their area
        assert torch.all(face_choices[1] == 2)
        num_0 = torch.sum(face_choices[0] == 0)
        assert num_0 + torch.sum(face_choices[0] == 1) == num_samples
        sampling_prob = num_samples / 3.
        tolerance = sampling_prob * 0.1
        assert (num_0 < sampling_prob + tolerance) and \
               (num_0 > sampling_prob - tolerance)

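        # offset each mesh's face indices by the mesh's first vertex index so
        # they index directly into the packed vertex tensor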
        merged_faces = faces + tile_to_packed(
            first_idx_vertices[:-1].to(vertices.device), num_faces_per_mesh)

        face_vertices = torch.index_select(vertices, 0,
                                           merged_faces.reshape(-1)).reshape(
                                               total_num_faces, 3, 3)

        face_vertices_choices = torch.gather(
            face_vertices, 0,
            face_choices.reshape(-1, 1, 1).repeat(1, 3, 3)).reshape(
                batch_size, num_samples, 3, 3)

        # compute the distance from each point to the plane of the picked face
        face_normals = mesh.face_normals(face_vertices_choices, unit=True)
        # v0_p: batch_size x num_points x 3
        v0_p = points - face_vertices_choices[:, :, 0]
        len_v0_p = torch.sqrt(torch.sum(v0_p**2, dim=-1))
        cos_a = torch.matmul(v0_p.reshape(-1, 1, 3),
                             face_normals.reshape(-1, 3, 1)).reshape(
                                 batch_size, num_samples) / len_v0_p
        point_to_face_dist = len_v0_p * cos_a

        if dtype == torch.half:
            atol = 1e-2
            rtol = 1e-3
        else:
            atol = 1e-4
            rtol = 1e-5

        # check that the points are close to the plane
        assert torch.allclose(point_to_face_dist,
                              torch.zeros((batch_size, num_samples),
                                          device=device,
                                          dtype=dtype),
                              atol=atol,
                              rtol=rtol)

        # check that the points lie inside the triangle
        edges0 = (face_vertices_choices[:, :, 1] -
                  face_vertices_choices[:, :, 0])
        edges1 = (face_vertices_choices[:, :, 2] -
                  face_vertices_choices[:, :, 1])
        edges2 = (face_vertices_choices[:, :, 0] -
                  face_vertices_choices[:, :, 2])

        v0_p = points - face_vertices_choices[:, :, 0]
        v1_p = points - face_vertices_choices[:, :, 1]
        v2_p = points - face_vertices_choices[:, :, 2]

        # Normals of the triangles formed by each edge and the point
        normals1 = torch.cross(edges0, v0_p, dim=-1)
        normals2 = torch.cross(edges1, v1_p, dim=-1)
        normals3 = torch.cross(edges2, v2_p, dim=-1)
        # the dot product of those normals with the face normals must be
        # non-negative (up to a small margin for half precision)
        margin = -2e-3 if dtype == torch.half else 0.
        assert torch.all(
            torch.matmul(normals1.reshape(-1, 1, 3),
                         face_normals.reshape(-1, 3, 1)) >= margin)
        assert torch.all(
            torch.matmul(normals2.reshape(-1, 1, 3),
                         face_normals.reshape(-1, 3, 1)) >= margin)
        assert torch.all(
            torch.matmul(normals3.reshape(-1, 1, 3),
                         face_normals.reshape(-1, 3, 1)) >= margin)
    def face_camera_normals_z(self, face_vertices_camera):
        face_normals_unit = face_normals(face_vertices_camera, unit=True)
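        # z-component of the unit face normals in camera space; its sign lets
        # dibr_rasterization tell front-facing faces from back-facing ones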
        return face_normals_unit[:, :, 2]
# Example #7
    def test_sample_points(self, vertices, faces, face_features, use_features,
                           device, dtype):
        batch_size, num_vertices = vertices.shape[:2]
        num_faces = faces.shape[0]
        num_samples = 1000

        if use_features:
            points, face_choices, interpolated_features = mesh.sample_points(
                vertices, faces, num_samples, face_features=face_features)
        else:
            points, face_choices = mesh.sample_points(vertices, faces,
                                                      num_samples)

        check_tensor(points,
                     shape=(batch_size, num_samples, 3),
                     dtype=dtype,
                     device=device)
        check_tensor(face_choices,
                     shape=(batch_size, num_samples),
                     dtype=torch.long,
                     device=device)

        # check that both faces are sampled, with frequency proportional to
        # face area (face 0 is expected about half the time)
        num_0 = torch.sum(face_choices == 0, dim=1)
        assert torch.all(num_0 +
                         torch.sum(face_choices == 1, dim=1) == num_samples)
        sampling_prob = num_samples / 2
        tolerance = sampling_prob * 0.2
        assert torch.all(num_0 < sampling_prob + tolerance) and \
               torch.all(num_0 > sampling_prob - tolerance)

        face_vertices = mesh.index_vertices_by_faces(vertices, faces)

        face_vertices_choices = torch.gather(
            face_vertices, 1,
            face_choices[:, :, None, None].repeat(1, 1, 3, 3))

        # compute the distance from each point to the plane of the picked face
        face_normals = mesh.face_normals(face_vertices_choices, unit=True)

        # v0_p: batch_size x num_points x 3
        v0_p = points - face_vertices_choices[:, :, 0]
        len_v0_p = torch.sqrt(torch.sum(v0_p**2, dim=-1))
        cos_a = torch.matmul(v0_p.reshape(-1, 1, 3),
                             face_normals.reshape(-1, 3, 1)).reshape(
                                 batch_size, num_samples) / len_v0_p
        point_to_face_dist = len_v0_p * cos_a

        if dtype == torch.half:
            atol = 1e-2
            rtol = 1e-3
        else:
            atol = 1e-4
            rtol = 1e-5

        # check that the points are close to the plane
        assert torch.allclose(point_to_face_dist,
                              torch.zeros((batch_size, num_samples),
                                          device=device,
                                          dtype=dtype),
                              atol=atol,
                              rtol=rtol)

        # check that the points lie inside the triangle
        edges0 = (face_vertices_choices[:, :, 1] -
                  face_vertices_choices[:, :, 0])
        edges1 = (face_vertices_choices[:, :, 2] -
                  face_vertices_choices[:, :, 1])
        edges2 = (face_vertices_choices[:, :, 0] -
                  face_vertices_choices[:, :, 2])

        v0_p = points - face_vertices_choices[:, :, 0]
        v1_p = points - face_vertices_choices[:, :, 1]
        v2_p = points - face_vertices_choices[:, :, 2]

        # Normals of the triangles formed by each edge and the point
        normals1 = torch.cross(edges0, v0_p, dim=-1)
        normals2 = torch.cross(edges1, v1_p, dim=-1)
        normals3 = torch.cross(edges2, v2_p, dim=-1)
        # the dot product of those normals with the face normals must be
        # non-negative (up to a small margin for half precision)
        margin = -5e-3 if dtype == torch.half else 0.
        assert torch.all(
            torch.matmul(normals1.reshape(-1, 1, 3),
                         face_normals.reshape(-1, 3, 1)) >= margin)
        assert torch.all(
            torch.matmul(normals2.reshape(-1, 1, 3),
                         face_normals.reshape(-1, 3, 1)) >= margin)
        assert torch.all(
            torch.matmul(normals3.reshape(-1, 1, 3),
                         face_normals.reshape(-1, 3, 1)) >= margin)
        if use_features:
            feat_dim = face_features.shape[-1]
            check_tensor(interpolated_features,
                         shape=(batch_size, num_samples, feat_dim),
                         dtype=dtype,
                         device=device)
            # face_vertices_choices (batch_size, num_samples, 3, 3)
            # points (batch_size, num_samples, 3)
            ax = face_vertices_choices[:, :, 0, 0]
            ay = face_vertices_choices[:, :, 0, 1]
            bx = face_vertices_choices[:, :, 1, 0]
            by = face_vertices_choices[:, :, 1, 1]
            cx = face_vertices_choices[:, :, 2, 0]
            cy = face_vertices_choices[:, :, 2, 1]
            m = bx - ax
            p = by - ay
            n = cx - ax
            q = cy - ay
            s = points[:, :, 0] - ax
            t = points[:, :, 1] - ay

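            # 2D barycentric coordinates via Cramer's rule: solve
            #     s = w1 * m + w2 * n
            #     t = w1 * p + w2 * q
            # for the weights w1, w2 of vertices b and c (w0 is the
            # remainder); 1e-7 guards against degenerate triangles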
            k1 = s * q - n * t
            k2 = m * t - s * p
            k3 = m * q - n * p
            w1 = k1 / (k3 + 1e-7)
            w2 = k2 / (k3 + 1e-7)
            w0 = (1. - w1) - w2
            weights = torch.stack([w0, w1, w2], dim=-1)

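            # reconstructing the samples from the barycentric weights should
            # recover the points returned by sample_points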
            gt_points = torch.sum(face_vertices_choices *
                                  weights.unsqueeze(-1),
                                  dim=-2)
            assert torch.allclose(points, gt_points, atol=atol, rtol=rtol)

            _face_choices = face_choices[..., None, None].repeat(
                1, 1, 3, feat_dim)
            face_features_choices = torch.gather(face_features, 1,
                                                 _face_choices)

            gt_interpolated_features = torch.sum(face_features_choices *
                                                 weights.unsqueeze(-1),
                                                 dim=-2)
            assert torch.allclose(interpolated_features,
                                  gt_interpolated_features,
                                  atol=atol,
                                  rtol=rtol)
# Example #8
    def test_optimize_vertex_position(self, vertices, faces, vertex_colors, vertices_image,
                                      camera_rot, camera_trans, camera_proj,
                                      height, width, dtype, device):
        batch_size = faces.shape[0]
        # gather per-face vertex colors to use as rasterization attributes
        camera_rot = camera_rot.to(device, dtype)
        camera_trans = camera_trans.to(device, dtype)
        camera_proj = camera_proj.to(device, dtype)
        face_attributes = index_vertices_by_faces(vertex_colors.to(device, dtype), faces)
        vertices = vertices.to(device, dtype).clone().detach()
        vertices.requires_grad = False
        moved_vertices = vertices.to(device, dtype).clone()
        moved_vertices[:, 0, :2] += 0.4
        moved_vertices = moved_vertices.detach()
        moved_vertices.requires_grad = True

        images_gt = [torch.from_numpy(np.array(Image.open(
                        os.path.join(SAMPLE_DIR, f'vertex_color_{bs}.png'))))
                     for bs in range(batch_size)]
        images_gt = torch.stack(images_gt, dim=0).to(device, dtype) / 255.

        moved_vertices_camera = rotate_translate_points(moved_vertices, camera_rot, camera_trans)
        moved_vertices_image = perspective_camera(moved_vertices_camera, camera_proj)

        # check that the moved vertices are far enough to fail the closeness test
        assert not torch.allclose(moved_vertices_image, vertices_image, atol=1e-2, rtol=1e-2)

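        # baseline: render once without tracking gradients to measure the
        # initial reconstruction loss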
        with torch.no_grad():
            moved_vertices_camera = rotate_translate_points(moved_vertices, camera_rot, camera_trans)
            moved_vertices_image = perspective_camera(moved_vertices_camera, camera_proj)
            face_moved_vertices_camera = index_vertices_by_faces(moved_vertices_camera, faces)
            face_moved_vertices_image = index_vertices_by_faces(moved_vertices_image, faces)
            face_moved_normals_z = face_normals(face_moved_vertices_camera,
                                                unit=True)[:, :, 2]
            imfeat, _, _ = dibr_rasterization(height,
                                              width,
                                              face_moved_vertices_camera[:, :, :, 2],
                                              face_moved_vertices_image,
                                              face_attributes,
                                              face_moved_normals_z)
            original_loss = torch.mean(torch.abs(imfeat - images_gt))

        # test that the loss is high enough
        assert original_loss > 0.01
        optimizer = torch.optim.Adam([moved_vertices], lr=5e-3)

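        # optimize the perturbed vertices with Adam, minimizing the L1
        # photometric difference to the reference renders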
        for i in range(100):
            optimizer.zero_grad()
            moved_vertices_camera = rotate_translate_points(moved_vertices, camera_rot, camera_trans)
            moved_vertices_image = perspective_camera(moved_vertices_camera, camera_proj)
            face_moved_vertices_camera = index_vertices_by_faces(moved_vertices_camera, faces)
            face_moved_vertices_image = index_vertices_by_faces(moved_vertices_image, faces)
            face_moved_normals_z = face_normals(face_moved_vertices_camera,
                                                unit=True)[:, :, 2]
            imfeat, _, _ = dibr_rasterization(height,
                                              width,
                                              face_moved_vertices_camera[:, :, :, 2],
                                              face_moved_vertices_image,
                                              face_attributes,
                                              face_moved_normals_z)
            loss = torch.mean(torch.abs(imfeat - images_gt))
            loss.backward()

            optimizer.step()

        moved_vertices_camera = rotate_translate_points(moved_vertices, camera_rot, camera_trans)
        moved_vertices_image = perspective_camera(moved_vertices_camera, camera_proj)

        # test that the loss went down
        assert loss < 0.001
        # We only test on the image plane: since we don't change the camera
        # angle during training, we don't expect depth to be correct.
        # We could probably fine-tune the test to have a lower tolerance (TODO: cfujitsang)
        assert torch.allclose(moved_vertices_image, vertices_image, atol=1e-2, rtol=1e-2)