Code example #1
    def run_optimization(
        self,
        silhouettes: torch.Tensor,
        R: torch.Tensor,
        T: torch.Tensor,
        writer=None,
        camera_settings=None,
        step: int = 0,
    ):
        """
        Function:
            Runs a batched optimization procedure that minimizes 3 reconstruction losses:
                -Silhouette IoU loss: between the input silhouettes and the re-projected mesh
                -Mesh edge consistency
                -Mesh normal smoothing
            Mini-batching:
                If the number of silhouettes is greater than the allowed batch size, a random set of images/poses is sampled for supervision at each step
        Returns:
            -Reconstruction losses: 3 reconstruction losses measured during optimization
            -Timing:
                -Iterations / second
                -Total time elapsed in seconds
        """

        if len(R.shape) == 4:
            R = R.squeeze(1)
            T = T.squeeze(1)

        tf_smaller = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(self.params.img_size),
            transforms.ToTensor(),
        ])

        images_gt = torch.stack([
            tf_smaller(s.cpu()).to(self.device) for s in silhouettes
        ]).squeeze(1)

        if images_gt.max() > 1.0:
            images_gt = images_gt / 255.0

        loop = tqdm_notebook(range(self.params.mesh_steps))

        start_time = time.time()
        for i in loop:
            batch_indices = (random.choices(list(range(images_gt.shape[0])),
                                            k=self.params.mesh_batch_size) if
                             images_gt.shape[0] > self.params.mesh_batch_size
                             else list(range(images_gt.shape[0])))
            batch_silhouettes = images_gt[batch_indices]

            batch_R, batch_T = R[batch_indices], T[batch_indices]
            # right-multiply Twv to account for the coordinate-system shift from EVIMO to PyTorch3D
            if self.params.is_real_data:
                init_R = quaternion_to_matrix(self.init_camera_R)
                batch_R = _broadcast_bmm(batch_R, init_R)
                batch_T = (
                    _broadcast_bmm(batch_T[:, None, :], init_R) +
                    self.init_camera_t.expand(batch_R.shape[0], 1, 3))[:, 0, :]
                focal_length = (torch.tensor([
                    camera_settings[0, 0], camera_settings[1, 1]
                ])[None]).expand(batch_R.shape[0], 2)
                principal_point = (torch.tensor([
                    camera_settings[0, 2], camera_settings[1, 2]
                ])[None]).expand(batch_R.shape[0], 2)
                # FIXME: in this PyTorch3D version, image_size in RasterizationSettings is (W, H), while in PerspectiveCameras it is (H, W)
                # If a future PyTorch3D version changes this convention, update the settings here
                # We hope PyTorch3D will resolve this inconsistency in the future
                batch_cameras = PerspectiveCameras(
                    device=self.device,
                    R=batch_R,
                    T=batch_T,
                    focal_length=focal_length,
                    principal_point=principal_point,
                    image_size=((self.params.img_size[1],
                                 self.params.img_size[0]), ))
            else:
                batch_cameras = PerspectiveCameras(device=self.device,
                                                   R=batch_R,
                                                   T=batch_T)

            mesh, laplacian_loss, flatten_loss = self.forward(
                self.params.mesh_batch_size)

            images_pred = self.renderer(mesh,
                                        device=self.device,
                                        cameras=batch_cameras)[..., -1]

            iou_loss = IOULoss().forward(batch_silhouettes, images_pred)

            loss = (iou_loss * self.params.lambda_iou +
                    laplacian_loss * self.params.lambda_laplacian +
                    flatten_loss * self.params.lambda_flatten)

            loop.set_description("Optimizing (loss %.4f)" % loss.item())

            self.losses["iou"].append(iou_loss * self.params.lambda_iou)
            self.losses["laplacian"].append(laplacian_loss *
                                            self.params.lambda_laplacian)
            self.losses["flatten"].append(flatten_loss *
                                          self.params.lambda_flatten)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if (i % max(self.params.mesh_show_step // 2, 1) == 0
                    and self.params.mesh_log):
                logging.info(
                    f'Iteration: {i} IOU Loss: {iou_loss.item()} Flatten Loss: {flatten_loss.item()} Laplacian Loss: {laplacian_loss.item()}'
                )

            if i % self.params.mesh_show_step == 0 and self.params.im_show:
                # Write images
                image = images_pred.detach().cpu().numpy()[0]

                if writer:
                    writer.append_data((255 * image).astype(np.uint8))
                plt.imshow(image)
                plt.show()
                plt.imshow(batch_silhouettes.detach().cpu().numpy()[0])
                plt.show()
                plot_pointcloud(mesh[0], 'Mesh deformed')
                logging.info(
                    f'Pose of init camera: {self.init_camera_R.detach().cpu().numpy()}, {self.init_camera_t.detach().cpu().numpy()}'
                )

        # Set the final optimized mesh as an internal variable
        self.final_mesh = mesh[0].clone()
        elapsed = time.time() - start_time
        results = dict(
            silhouette_loss=self.losses["iou"][-1].detach().cpu().numpy().tolist(),
            laplacian_loss=self.losses["laplacian"][-1].detach().cpu().numpy().tolist(),
            flatten_loss=self.losses["flatten"][-1].detach().cpu().numpy().tolist(),
            iterations_per_second=self.params.mesh_steps / elapsed,
            total_time_s=elapsed,
        )
        if self.params.is_real_data:
            self.init_pose_R = self.init_camera_R.detach().cpu().numpy()
            self.init_pose_t = self.init_camera_t.detach().cpu().numpy()

        torch.cuda.empty_cache()

        return results
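The IOULoss class used above is defined elsewhere in the project. As a reference, here is a minimal sketch of a soft silhouette IoU loss matching the call site (IOULoss().forward(gt, pred)); the eps value and mean reduction are assumptions, not the original implementation.

import torch
import torch.nn as nn


class IOULoss(nn.Module):
    """Soft intersection-over-union loss between two silhouette batches in [0, 1]."""

    def forward(self, gt: torch.Tensor, pred: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Flatten the spatial dimensions so each sample yields one IoU term.
        gt = gt.flatten(start_dim=1)
        pred = pred.flatten(start_dim=1)
        intersection = (gt * pred).sum(dim=1)
        union = (gt + pred - gt * pred).sum(dim=1)
        # 1 - IoU so that perfect overlap gives zero loss.
        return (1.0 - intersection / (union + eps)).mean()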
Code example #2
    def render_final_mesh(self,
                          poses,
                          mode: str,
                          out_size: list,
                          camera_settings=None) -> dict:
        """Renders the final mesh obtained through optimization
            Supports two modes:
                -predict: renders both silhouettes and flat shaded images
                -train: only renders silhouettes
            Returns:
                -dict of renders {'silhouettes': tensor, 'images': tensor}
        """
        R, T = poses
        if len(R.shape) == 4:
            R = R.squeeze(1)
            T = T.squeeze(1)

        sil_renderer = silhouette_renderer(out_size, self.device)
        image_renderer = flat_renderer(out_size, self.device)

        # Create a silhouette projection of the mesh across all views
        all_silhouettes = []
        all_images = []
        for i in range(0, R.shape[0]):
            batch_R, batch_T = R[[i]], T[[i]]
            if self.params.is_real_data:
                init_R = quaternion_to_matrix(self.init_camera_R)
                batch_R = _broadcast_bmm(batch_R, init_R)
                batch_T = (
                    _broadcast_bmm(batch_T[:, None, :], init_R) +
                    self.init_camera_t.expand(batch_R.shape[0], 1, 3))[:, 0, :]
                focal_length = torch.tensor(
                    [camera_settings[0, 0], camera_settings[1, 1]])[None]
                principal_point = torch.tensor(
                    [camera_settings[0, 2], camera_settings[1, 2]])[None]
                t_cameras = PerspectiveCameras(
                    device=self.device,
                    R=batch_R,
                    T=batch_T,
                    focal_length=focal_length,
                    principal_point=principal_point,
                    image_size=((self.params.img_size[1],
                                 self.params.img_size[0]), ))
            else:
                t_cameras = PerspectiveCameras(device=self.device,
                                               R=batch_R,
                                               T=batch_T)
            all_silhouettes.append(
                sil_renderer(self._final_mesh,
                             device=self.device,
                             cameras=t_cameras).detach().cpu()[..., -1])

            if mode == "predict":
                all_images.append(
                    torch.clamp(
                        image_renderer(self._final_mesh,
                                       device=self.device,
                                       cameras=t_cameras),
                        0,
                        1,
                    ).detach().cpu()[..., :3])
            torch.cuda.empty_cache()
        renders = dict(
            silhouettes=torch.cat(all_silhouettes).unsqueeze(-1).permute(
                0, 3, 1, 2),
            images=torch.cat(all_images) if all_images else [],
        )

        return renders
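The silhouette_renderer factory called above is also defined elsewhere. A plausible sketch using the standard PyTorch3D soft-silhouette setup follows; the blur radius and faces_per_pixel values are illustrative assumptions, and cameras are supplied per call exactly as in render_final_mesh.

import numpy as np
from pytorch3d.renderer import (BlendParams, MeshRasterizer, MeshRenderer,
                                RasterizationSettings, SoftSilhouetteShader)


def silhouette_renderer(out_size, device):
    # device is unused here; the cameras (which carry the device) are passed per call.
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
    raster_settings = RasterizationSettings(
        image_size=out_size,
        # A wide blur keeps the soft silhouette differentiable.
        blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
        faces_per_pixel=50,
    )
    return MeshRenderer(
        rasterizer=MeshRasterizer(raster_settings=raster_settings),
        shader=SoftSilhouetteShader(blend_params=blend_params),
    )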
Code example #3
File: reconstruction_mesh.py  Project: walaa5/Face-X
def render_img(face_shape,
               face_color,
               facemodel,
               image_size=224,
               fx=1015.0,
               fy=1015.0,
               px=112.0,
               py=112.0,
               device='cuda:0'):
    '''
        ref: https://github.com/facebookresearch/pytorch3d/issues/184
        The rendering function (for testing only)
        Input:
            face_shape:  Tensor[1, 35709, 3]
            face_color: Tensor[1, 35709, 3] in [0, 1]
            facemodel: contains `tri` (triangle indices [70789, 3], 1-based)
    '''
    from pytorch3d.structures import Meshes
    from pytorch3d.renderer.mesh.textures import TexturesVertex
    from pytorch3d.renderer import (PerspectiveCameras, PointLights,
                                    RasterizationSettings, MeshRenderer,
                                    MeshRasterizer, SoftPhongShader,
                                    BlendParams)

    face_color = TexturesVertex(verts_features=face_color.to(device))
    face_buf = torch.from_numpy(facemodel.tri - 1)  # indices are 1-based; convert to 0-based
    face_idx = face_buf.unsqueeze(0)

    mesh = Meshes(face_shape.to(device), face_idx.to(device), face_color)

    R = torch.eye(3).view(1, 3, 3).to(device)
    R[0, 0, 0] *= -1.0
    T = torch.zeros([1, 3]).to(device)

    half_size = (image_size - 1.0) / 2
    focal_length = torch.tensor([fx / half_size, fy / half_size],
                                dtype=torch.float32).reshape(1, 2).to(device)
    principal_point = torch.tensor([(half_size - px) / half_size,
                                    (py - half_size) / half_size],
                                   dtype=torch.float32).reshape(1,
                                                                2).to(device)

    cameras = PerspectiveCameras(device=device,
                                 R=R,
                                 T=T,
                                 focal_length=focal_length,
                                 principal_point=principal_point)

    raster_settings = RasterizationSettings(image_size=image_size,
                                            blur_radius=0.0,
                                            faces_per_pixel=1)

    lights = PointLights(device=device,
                         ambient_color=((1.0, 1.0, 1.0), ),
                         diffuse_color=((0.0, 0.0, 0.0), ),
                         specular_color=((0.0, 0.0, 0.0), ),
                         location=((0.0, 0.0, 1e5), ))

    blend_params = BlendParams(background_color=(0.0, 0.0, 0.0))

    renderer = MeshRenderer(rasterizer=MeshRasterizer(
        cameras=cameras, raster_settings=raster_settings),
                            shader=SoftPhongShader(device=device,
                                                   cameras=cameras,
                                                   lights=lights,
                                                   blend_params=blend_params))
    images = renderer(mesh)
    images = torch.clamp(images, 0.0, 1.0)
    return images
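A hypothetical usage sketch for render_img with a single dummy triangle instead of the BFM face model; DummyFaceModel and all tensor values are invented for illustration, and a CUDA device is assumed (pass device='cpu' otherwise).

import numpy as np
import torch


class DummyFaceModel:
    # Only the `tri` attribute (1-based triangle indices) is used by render_img.
    tri = np.array([[1, 2, 3]], dtype=np.int64)


# One triangle roughly 10 units in front of the camera, mid-gray vertex colors.
face_shape = torch.tensor([[[0.0, 0.0, 10.0],
                            [1.0, 0.0, 10.0],
                            [0.0, 1.0, 10.0]]])
face_color = torch.full_like(face_shape, 0.5)

images = render_img(face_shape, face_color, DummyFaceModel())
print(images.shape)  # (1, 224, 224, 4): RGB plus alpha from the soft blending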
Code example #4
def batch_render(
    verts,
    faces,
    faces_per_pixel=10,
    K=None,
    rot=None,
    trans=None,
    colors=None,
    color=(0.53, 0.53, 0.8),  # light_purple
    ambient_col=0.5,
    specular_col=0.2,
    diffuse_col=0.3,
    face_colors=None,
    # color = (0.74117647, 0.85882353, 0.65098039),  # light_blue
    image_sizes=None,
    out_res=512,
    bin_size=0,
    shading="soft",
    mode="rgb",
    blend_gamma=1e-4,
    min_depth=None,
):
    device = torch.device("cuda:0")
    K = K.to(device)
    width, height = image_sizes[0]
    out_size = int(max(image_sizes[0]))
    raster_settings = RasterizationSettings(
        image_size=out_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        bin_size=bin_size,
    )

    fx = K[:, 0, 0]
    fy = K[:, 1, 1]
    focals = torch.stack([fx, fy], 1)
    px = K[:, 0, 2]
    py = K[:, 1, 2]
    principal_point = torch.stack([width - px, height - py], 1)
    if rot is None:
        rot = torch.eye(3).unsqueeze(0).to(device)
    if trans is None:
        trans = torch.zeros(3).unsqueeze(0).to(device)
    cameras = PerspectiveCameras(
        device=device,
        focal_length=focals,
        principal_point=principal_point,
        image_size=[(out_size, out_size) for _ in range(len(verts))],
        R=rot,
        T=trans,
    )
    if mode == "rgb" and shading == "soft":
        lights = DirectionalLights(
            device=device,
            direction=((0.6, -0.6, -0.6), ),
            ambient_color=((ambient_col, ambient_col, ambient_col), ),
            diffuse_color=((diffuse_col, diffuse_col, diffuse_col), ),
            specular_color=((specular_col, specular_col, specular_col), ),
        )
        shader = SoftPhongShader(device=device, cameras=cameras, lights=lights)
    elif mode == "silh":
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        shader = SoftSilhouetteShader(blend_params=blend_params)
    elif shading == "faceidx":
        shader = FaceIdxShader()
    elif (mode == "facecolor") and (shading == "hard"):
        shader = FaceColorShader(face_colors=face_colors)
    elif (mode == "facecolor") and (shading == "soft"):
        shader = SoftFaceColorShader(face_colors=face_colors,
                                     blend_gamma=blend_gamma)
    else:
        raise ValueError(
            f"Unhandled mode {mode} and shading {shading} combination")

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=raster_settings),
        shader=shader,
    )
    if min_depth is not None:
        verts = torch.cat([verts[:, :, :2], verts[:, :, 2:].clamp(min_depth)],
                          2)
    if mode == "rgb":
        if colors is None:
            colors = get_colors(verts, color)
        tex = textures.TexturesVertex(verts_features=colors)

        meshes = Meshes(verts=verts, faces=faces, textures=tex)
    elif mode in ["silh", "facecolor"]:
        meshes = Meshes(verts=verts, faces=faces)
    else:
        raise ValueError(f"Render mode {mode} not in [rgb|silh|facecolor]")

    square_images = renderer(meshes, cameras=cameras)
    height_off = int(width - height)
    # from matplotlib import pyplot as plt
    # plt.imshow(square_images.cpu()[0, :, :, 0])
    # plt.savefig("tmp.png")
    images = torch.flip(square_images, (1, 2))[:, height_off:]
    return images
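A hypothetical call to batch_render in silhouette mode, assuming the module-level imports of the original file are available and a CUDA device is present; the intrinsics, image size, and geometry below are invented for illustration.

import torch

# One triangle in front of a 640x480 pinhole camera with intrinsics K.
verts = torch.tensor([[[0.0, 0.0, 0.5],
                       [0.1, 0.0, 0.5],
                       [0.0, 0.1, 0.5]]]).cuda()
faces = torch.tensor([[[0, 1, 2]]]).cuda()
K = torch.tensor([[[500.0, 0.0, 320.0],
                   [0.0, 500.0, 240.0],
                   [0.0, 0.0, 1.0]]])

silh = batch_render(verts, faces, K=K, image_sizes=[(640, 480)], mode="silh")
print(silh.shape)  # (1, 480, 640, 4); the last channel holds the soft silhouette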
Code example #5
    def forward(self,
                verts,        # under general camera coordinate rXdYfZ,  N*V*3    
                faces,        # indices in verts to define triangles,    N*F*3
                verts_uvs,    # uv coordinate of corresponding verts,    N*V*2
                faces_uvs,    # indices in verts_uvs to define triangles, N*F*3
                tex_image,    # under GCcd,                            N*H*W*3
                R,            # under GCcd,                              N*3*3 
                T,            # under GCcd,                              N*3
                f,            # in pixel/m,                              N*1
                C,            # camera center,                           N*2
                imgres,       # int
                lightLoc = None):
        
        assert verts.shape[0] == 1,\
            'due to issues in PyTorch3D, render 1 mesh per forward call'
        
        # only one of (R, T) or verts needs to be converted; we convert R and T here
        if self.convertToPytorch3D:
            R = torch.matmul(self.GCcdToPytorch3D, R)
            T = torch.matmul(self.GCcdToPytorch3D, T.unsqueeze(-1)).squeeze(-1)
        
        # prepare textures and mesh to render
        tex = TexturesUV(
            verts_uvs = verts_uvs, 
            faces_uvs = faces_uvs, 
            maps = tex_image
        ) 
        mesh = Meshes(verts = verts, faces = faces, textures=tex)
        
        # Initialize a camera. The world coordinate is +Y up, +X left and +Z in. 
        cameras = PerspectiveCameras(
            focal_length=f,
            principal_point=C,
            R=R, 
            T=T, 
            image_size=((imgres,imgres),),
            device=self.device
        )

        # Define the settings for rasterization and shading. 
        raster_settings = RasterizationSettings(
            image_size=imgres, 
            blur_radius=0.0, 
            faces_per_pixel=1, 
        )

        # Create a simple renderer by composing a rasterizer and a shader.
        # The simple textured shader interpolates the texture uv coordinates
        # for each pixel and samples from the texture image. This renderer could
        # support lighting easily, but we do not implement it.
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras, 
                raster_settings=raster_settings
            ),
            shader=SimpleShader(
                device=self.device
            )
        )
        
        # render the image(s)
        images = renderer(mesh)
        
        return images
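The SimpleShader referenced above is not a PyTorch3D built-in. One plausible implementation, assuming it only samples the UV texture per pixel and applies hard blending with no lighting (the class name and device argument follow the call site; the body is a reconstruction):

import torch.nn as nn
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend


class SimpleShader(nn.Module):
    """Texture-only shader: sample the mesh texture per pixel, no lighting."""

    def __init__(self, device="cpu", blend_params=None):
        super().__init__()
        # device is accepted for interface parity; this shader has no parameters.
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def forward(self, fragments, meshes, **kwargs):
        # Interpolate UVs and sample the texture image for every visible pixel.
        texels = meshes.sample_textures(fragments)
        blend_params = kwargs.get("blend_params", self.blend_params)
        # Hard blending keeps only the closest face per pixel and appends alpha.
        return hard_rgb_blend(texels, fragments, blend_params)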