def vertices_image(self, vertices_camera, camera_proj):
    """Return ``perspective_camera(vertices_camera, camera_proj)``.

    NOTE(review): presumably registered as a pytest fixture (a test in this
    file takes a ``vertices_image`` argument); the decorator is not visible
    here -- confirm.
    """
    projected_vertices = perspective_camera(vertices_camera, camera_proj)
    return projected_vertices
 def vertices_images_zoom(self, vertices_camera, camera_proj_zoom):
     """Return ``perspective_camera(vertices_camera, camera_proj_zoom)``.

     NOTE(review): presumably a pytest fixture whose decorator is not
     visible here -- confirm.
     """
     # the two faces are fully covering the camera
     zoomed_vertices = perspective_camera(vertices_camera, camera_proj_zoom)
     return zoomed_vertices
    def test_optimize_vertex_position(self, vertices, faces, vertex_colors,
                                      vertices_image, camera_rot, camera_trans,
                                      camera_proj, height, width, dtype,
                                      device):
        """Check that optimizing displaced vertices recovers their image-plane position.

        The first entry along dim 1 of ``vertices`` (presumably vertex 0 of
        every mesh -- confirm against the fixture) has its first two
        coordinates shifted by 0.4. The displaced mesh is rendered with
        ``dibr_rasterization`` and compared (mean absolute error) against the
        ``vertex_color_{i}.png`` images stored in ``SAMPLE_DIR``. After 100
        Adam steps on the displaced vertices the loss must drop below 0.001
        and the projected vertices must match ``vertices_image`` within
        ``atol=rtol=1e-2``.
        """
        batch_size = faces.shape[0]

        # Camera parameters are constants of the optimization. ``detach()`` is
        # used instead of flipping ``requires_grad`` so that the incoming
        # tensors are never mutated (``.to`` may return its input unchanged)
        # and non-leaf tensors are handled as well.
        camera_rot = camera_rot.to(device, dtype).detach()
        camera_trans = camera_trans.to(device, dtype).detach()
        camera_proj = camera_proj.to(device, dtype).detach()

        # face_vertex_colors
        face_attributes = index_vertices_by_faces(
            vertex_colors.to(device, dtype), faces)

        # Displaced copy of the vertices; this is the only optimized tensor.
        moved_vertices = vertices.to(device, dtype).detach().clone()
        moved_vertices[:, 0, :2] += 0.4
        moved_vertices.requires_grad = True

        images_gt = []
        for bs in range(batch_size):
            # The context manager releases the file handle as soon as the
            # pixel data has been copied into the array.
            with Image.open(
                    os.path.join(SAMPLE_DIR, f'vertex_color_{bs}.png')) as image:
                images_gt.append(torch.from_numpy(np.array(image)))
        # Scale 8-bit pixel values to [0, 1].
        images_gt = torch.stack(images_gt, dim=0).to(device, dtype) / 255.

        def _project_to_image():
            # Image-plane coordinates of the current ``moved_vertices``
            # (evaluation only, no autograd graph needed).
            with torch.no_grad():
                points_camera = rotate_translate_points(
                    moved_vertices, camera_rot, camera_trans)
                return perspective_camera(points_camera, camera_proj)

        def _render_loss():
            # Mean absolute error between the render of the current
            # ``moved_vertices`` and the ground-truth images.
            face_vertices_camera, face_vertices_image, face_normals = \
                prepare_vertices(moved_vertices, faces, camera_proj,
                                 camera_rot, camera_trans)
            rendered, _, _ = dibr_rasterization(
                height, width, face_vertices_camera[:, :, :, 2],
                face_vertices_image, face_attributes,
                face_normals[:, :, 2])
            return torch.mean(torch.abs(rendered - images_gt))

        # test that the vertices are far enough to fail the test.
        assert not torch.allclose(
            _project_to_image(), vertices_image, atol=1e-2, rtol=1e-2)

        with torch.no_grad():
            original_loss = _render_loss()

        # test that the loss is high enough
        assert original_loss > 0.01
        optimizer = torch.optim.Adam([moved_vertices], lr=5e-3)

        for _ in range(100):
            optimizer.zero_grad()
            loss = _render_loss()
            loss.backward()
            optimizer.step()

        moved_vertices_image = _project_to_image()

        # test that the loss went down
        # (``loss`` comes from the last forward pass, i.e. before the final
        # optimizer step.)
        assert loss < 0.001
        # We only test on the image plane: since we don't change the camera
        # angle during training we don't expect depth to be correct.
        # We could probably fine-tune the test to have a lower tolerance (TODO: cfujitsang)
        assert torch.allclose(moved_vertices_image,
                              vertices_image,
                              atol=1e-2,
                              rtol=1e-2)