Example 1
import os

import numpy as np
import pytest
import torch
from PIL import Image

# Module paths below follow Kaolin's public API (assumed; adjust to your
# installed version).
from kaolin.ops.mesh import index_vertices_by_faces
from kaolin.render.camera import (generate_rotate_translate_matrices,
                                  generate_transformation_matrix,
                                  perspective_camera, rotate_translate_points)
from kaolin.render.mesh import dibr_rasterization, prepare_vertices

# The methods below are excerpted from a pytest test class; the class wrapper,
# the fixtures for `vertices`, `faces`, `vertex_colors`, `vertices_image`,
# `camera_*`, `height`, `width`, `dtype`, and `device`, and the SAMPLE_DIR
# constant are defined elsewhere in the original test suite.
    def test_cmp_transform_rotate_translate(self, batch_size, device, width,
                                            height, camera_pos, object_pos,
                                            camera_up, vertices):
        # The split rotation/translation matrices and the fused transformation
        # matrix must move the vertices to the same camera space.
        mtx_rot, mtx_trans = generate_rotate_translate_matrices(
            camera_pos, object_pos, camera_up)
        mtx_transform = generate_transformation_matrix(camera_pos, object_pos,
                                                       camera_up)
        vertices_camera = rotate_translate_points(vertices, mtx_rot, mtx_trans)
        # Append a homogeneous 1 so a single matmul can apply the translation.
        padded_vertices = torch.nn.functional.pad(vertices, (0, 1),
                                                  mode='constant',
                                                  value=1.)
        vertices_camera2 = padded_vertices @ mtx_transform
        assert torch.allclose(vertices_camera, vertices_camera2)
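
    # A minimal standalone sketch (not Kaolin code) of the identity the test
    # above relies on: with row vectors, padding each point with a homogeneous
    # 1 and multiplying by the stacked (4, 3) matrix [[R], [t]] is the same as
    # v @ R + t. All tensors here are made-up stand-ins.
    @staticmethod
    def _transform_identity_sketch():
        v = torch.randn(2, 5, 3)  # (batch, num_vertices, 3)
        R = torch.randn(2, 3, 3)  # stand-in for the rotation block
        t = torch.randn(2, 1, 3)  # stand-in for the translation row
        mtx = torch.cat([R, t], dim=1)                          # (batch, 4, 3)
        padded = torch.nn.functional.pad(v, (0, 1), value=1.)   # (batch, N, 4)
        assert torch.allclose(padded @ mtx, v @ R + t, atol=1e-6)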
    # Presumably a pytest fixture (decorator assumed lost in extraction).
    @pytest.fixture
    def vertices_camera(self, vertices, camera_rot, camera_trans):
        return rotate_translate_points(vertices, camera_rot, camera_trans)
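
    # The `vertices_image` fixture consumed by the next test is presumably the
    # perspective projection of these camera-space vertices, i.e. something
    # like perspective_camera(vertices_camera, camera_proj) in a sibling
    # fixture.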
    def test_optimize_vertex_position(self, vertices, faces, vertex_colors,
                                      vertices_image, camera_rot, camera_trans,
                                      camera_proj, height, width, dtype,
                                      device):
        batch_size = faces.shape[0]
        # Camera parameters stay fixed; only the vertices are optimized.
        camera_rot = camera_rot.to(device, dtype)
        camera_rot.requires_grad = False
        camera_trans = camera_trans.to(device, dtype)
        camera_trans.requires_grad = False
        camera_proj = camera_proj.to(device, dtype)
        camera_proj.requires_grad = False
        # Gather per-vertex colors into per-face attributes for rasterization.
        face_attributes = index_vertices_by_faces(
            vertex_colors.to(device, dtype), faces)
        # Ground-truth vertices, detached from any graph.
        vertices = vertices.to(device, dtype).clone().detach()
        vertices.requires_grad = False
        # Perturb the first vertex in x/y to give the optimizer work to do.
        moved_vertices = vertices.to(device, dtype).clone()
        moved_vertices[:, 0, :2] += 0.4
        moved_vertices = moved_vertices.detach()
        moved_vertices.requires_grad = True

        # Load the golden renderings of the unperturbed mesh.
        images_gt = [
            torch.from_numpy(
                np.array(
                    Image.open(
                        os.path.join(SAMPLE_DIR, f'vertex_color_{bs}.png'))))
            for bs in range(batch_size)
        ]
        images_gt = torch.stack(images_gt, dim=0).to(device, dtype) / 255.

        with torch.no_grad():
            moved_vertices_camera = rotate_translate_points(
                moved_vertices, camera_rot, camera_trans)
            moved_vertices_image = perspective_camera(moved_vertices_camera,
                                                      camera_proj)

        # Sanity check: the perturbed vertices project far enough from the
        # ground truth that the final assertion would fail without optimization.
        assert not torch.allclose(
            moved_vertices_image, vertices_image, atol=1e-2, rtol=1e-2)

        with torch.no_grad():
            face_moved_vertices_camera, face_moved_vertices_image, face_moved_normals = \
                prepare_vertices(moved_vertices, faces, camera_proj, camera_rot, camera_trans)
            # Only the z component of the face normals is passed to DIB-R.
            face_moved_normals_z = face_moved_normals[:, :, 2]

            imfeat, _, _ = dibr_rasterization(
                height, width, face_moved_vertices_camera[:, :, :, 2],
                face_moved_vertices_image, face_attributes,
                face_moved_normals_z)
            original_loss = torch.mean(torch.abs(imfeat - images_gt))

        # Sanity check: the initial loss is high enough to leave room to optimize.
        assert original_loss > 0.01
        optimizer = torch.optim.Adam([moved_vertices], lr=5e-3)

        # Recover the original positions by minimizing an L1 image loss
        # against the golden renderings.
        for _ in range(100):
            optimizer.zero_grad()
            face_moved_vertices_camera, face_moved_vertices_image, face_moved_normals = \
                prepare_vertices(moved_vertices, faces, camera_proj, camera_rot, camera_trans)
            face_moved_normals_z = face_moved_normals[:, :, 2]
            imfeat, _, _ = dibr_rasterization(
                height, width, face_moved_vertices_camera[:, :, :, 2],
                face_moved_vertices_image, face_attributes,
                face_moved_normals_z)
            loss = torch.mean(torch.abs(imfeat - images_gt))
            loss.backward()
            optimizer.step()

        moved_vertices_camera = rotate_translate_points(
            moved_vertices, camera_rot, camera_trans)
        moved_vertices_image = perspective_camera(moved_vertices_camera,
                                                  camera_proj)

        # The loss should have dropped well below its starting value.
        assert loss < 0.001
        # We only test in the image plane: the camera does not move during
        # training, so we don't expect depth to be recovered.
        # We could probably fine-tune the test to a lower tolerance (TODO: cfujitsang)
        assert torch.allclose(moved_vertices_image,
                              vertices_image,
                              atol=1e-2,
                              rtol=1e-2)
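
    # Rough sketch (not Kaolin's implementation) of what `prepare_vertices` is
    # expected to return, judging from how the test consumes its outputs:
    # camera-space and image-space vertices gathered per face, plus unit face
    # normals. The edge cross product used for the normals is an assumption.
    @staticmethod
    def _prepare_vertices_sketch(vertices, faces, camera_proj, camera_rot,
                                 camera_trans):
        vertices_camera = rotate_translate_points(vertices, camera_rot,
                                                  camera_trans)
        vertices_image = perspective_camera(vertices_camera, camera_proj)
        # (batch, num_faces, 3 vertices, coords) after gathering by face.
        face_vertices_camera = index_vertices_by_faces(vertices_camera, faces)
        face_vertices_image = index_vertices_by_faces(vertices_image, faces)
        v0, v1, v2 = torch.unbind(face_vertices_camera, dim=2)
        face_normals = torch.nn.functional.normalize(
            torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
        return face_vertices_camera, face_vertices_image, face_normals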