def run_optimization(
        self,
        silhouettes: torch.Tensor,
        R: torch.Tensor,
        T: torch.Tensor,
        writer=None,
        camera_settings=None,
        step: int = 0,
):
    """
    Runs a batched optimization procedure that minimizes three reconstruction losses:
        - Silhouette IoU loss: between the input silhouettes and the re-projected mesh
        - Mesh edge (Laplacian) consistency
        - Mesh normal smoothing

    Mini-batching: if the number of silhouettes exceeds the allowed batch size,
    a random subset of images/poses is sampled for supervision at each step.

    Returns:
        - Reconstruction losses: the three losses measured during optimization
        - Timing: iterations per second and total time elapsed in seconds
    """
    if len(R.shape) == 4:
        R = R.squeeze(1)
        T = T.squeeze(1)

    tf_smaller = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(self.params.img_size),
        transforms.ToTensor(),
    ])

    images_gt = torch.stack(
        [tf_smaller(s.cpu()).to(self.device) for s in silhouettes]).squeeze(1)
    if images_gt.max() > 1.0:
        images_gt = images_gt / 255.0

    loop = tqdm_notebook(range(self.params.mesh_steps))
    start_time = time.time()
    for i in loop:
        # Sample a random mini-batch of views when there are more views than
        # the allowed batch size; otherwise use all views.
        batch_indices = (random.choices(list(range(images_gt.shape[0])),
                                        k=self.params.mesh_batch_size)
                         if images_gt.shape[0] > self.params.mesh_batch_size
                         else list(range(images_gt.shape[0])))
        batch_silhouettes = images_gt[batch_indices]
        batch_R, batch_T = R[batch_indices], T[batch_indices]

        # Apply a right transform on Twv to adjust for the coordinate-system
        # shift from EVIMO to PyTorch3D.
        if self.params.is_real_data:
            init_R = quaternion_to_matrix(self.init_camera_R)
            batch_R = _broadcast_bmm(batch_R, init_R)
            batch_T = (
                _broadcast_bmm(batch_T[:, None, :], init_R) +
                self.init_camera_t.expand(batch_R.shape[0], 1, 3))[:, 0, :]
            focal_length = torch.tensor(
                [camera_settings[0, 0],
                 camera_settings[1, 1]])[None].expand(batch_R.shape[0], 2)
            principal_point = torch.tensor(
                [camera_settings[0, 2],
                 camera_settings[1, 2]])[None].expand(batch_R.shape[0], 2)
            # FIXME: in this PyTorch3D version, image_size in
            # RasterizationSettings is (W, H), while in PerspectiveCameras it
            # is (H, W). If a future PyTorch3D release unifies the convention,
            # update the settings here accordingly.
            batch_cameras = PerspectiveCameras(
                device=self.device,
                R=batch_R,
                T=batch_T,
                focal_length=focal_length,
                principal_point=principal_point,
                image_size=((self.params.img_size[1],
                             self.params.img_size[0]), ))
        else:
            batch_cameras = PerspectiveCameras(device=self.device,
                                               R=batch_R,
                                               T=batch_T)

        mesh, laplacian_loss, flatten_loss = self.forward(
            self.params.mesh_batch_size)

        images_pred = self.renderer(mesh,
                                    device=self.device,
                                    cameras=batch_cameras)[..., -1]

        iou_loss = IOULoss()(batch_silhouettes, images_pred)

        loss = (iou_loss * self.params.lambda_iou +
                laplacian_loss * self.params.lambda_laplacian +
                flatten_loss * self.params.lambda_flatten)

        loop.set_description("Optimizing (loss %.4f)" % loss.item())

        # Detach before logging so the stored losses do not keep the
        # computation graph (and its GPU memory) alive across iterations.
        self.losses["iou"].append(
            (iou_loss * self.params.lambda_iou).detach())
        self.losses["laplacian"].append(
            (laplacian_loss * self.params.lambda_laplacian).detach())
        self.losses["flatten"].append(
            (flatten_loss * self.params.lambda_flatten).detach())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if i % (self.params.mesh_show_step // 2) == 0 and self.params.mesh_log:
            logging.info(
                f'Iteration: {i} IOU Loss: {iou_loss.item()} '
                f'Flatten Loss: {flatten_loss.item()} '
                f'Laplacian Loss: {laplacian_loss.item()}')

        if i % self.params.mesh_show_step == 0 and self.params.im_show:
            # Write images
            image = images_pred.detach().cpu().numpy()[0]
            if writer:
                writer.append_data((255 * image).astype(np.uint8))
            plt.imshow(image)
            plt.show()
            plt.imshow(batch_silhouettes.detach().cpu().numpy()[0])
            plt.show()
            plot_pointcloud(mesh[0], 'Mesh deformed')
            logging.info(
                f'Pose of init camera: '
                f'{self.init_camera_R.detach().cpu().numpy()}, '
                f'{self.init_camera_t.detach().cpu().numpy()}')

    # Set the final optimized mesh as an internal variable
    self.final_mesh = mesh[0].clone()

    results = dict(
        silhouette_loss=self.losses["iou"][-1].cpu().numpy().tolist(),
        laplacian_loss=self.losses["laplacian"][-1].cpu().numpy().tolist(),
        flatten_loss=self.losses["flatten"][-1].cpu().numpy().tolist(),
        iterations_per_second=self.params.mesh_steps /
        (time.time() - start_time),
        total_time_s=time.time() - start_time,
    )

    if self.params.is_real_data:
        self.init_pose_R = self.init_camera_R.detach().cpu().numpy()
        self.init_pose_t = self.init_camera_t.detach().cpu().numpy()

    torch.cuda.empty_cache()
    return results
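# The IOULoss used above is not defined in this file. For reference, a minimal
# sketch of the standard differentiable soft-IoU formulation it presumably
# follows (the class name and mean reduction are assumptions, not the actual
# module):
class SoftIoULoss(torch.nn.Module):
    """1 - soft IoU between predicted and target silhouettes, both in [0, 1]."""

    def forward(self, target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        # Per-image intersection and union over the spatial dimensions.
        intersection = (pred * target).sum(dim=(-2, -1))
        union = (pred + target - pred * target).sum(dim=(-2, -1))
        return 1.0 - (intersection / union.clamp(min=1e-6)).mean()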
def render_final_mesh(self,
                      poses,
                      mode: str,
                      out_size: list,
                      camera_settings=None) -> dict:
    """Renders the final mesh obtained through optimization.

    Supports two modes:
        - predict: renders both silhouettes and flat-shaded images
        - train: renders silhouettes only

    Returns:
        dict of renders {'silhouettes': tensor, 'images': tensor}
    """
    R, T = poses
    if len(R.shape) == 4:
        R = R.squeeze(1)
        T = T.squeeze(1)

    sil_renderer = silhouette_renderer(out_size, self.device)
    image_renderer = flat_renderer(out_size, self.device)

    # Create a silhouette projection of the mesh across all views
    all_silhouettes = []
    all_images = []
    for i in range(R.shape[0]):
        batch_R, batch_T = R[[i]], T[[i]]
        if self.params.is_real_data:
            init_R = quaternion_to_matrix(self.init_camera_R)
            batch_R = _broadcast_bmm(batch_R, init_R)
            batch_T = (
                _broadcast_bmm(batch_T[:, None, :], init_R) +
                self.init_camera_t.expand(batch_R.shape[0], 1, 3))[:, 0, :]
            focal_length = torch.tensor(
                [camera_settings[0, 0], camera_settings[1, 1]])[None]
            principal_point = torch.tensor(
                [camera_settings[0, 2], camera_settings[1, 2]])[None]
            t_cameras = PerspectiveCameras(
                device=self.device,
                R=batch_R,
                T=batch_T,
                focal_length=focal_length,
                principal_point=principal_point,
                image_size=((self.params.img_size[1],
                             self.params.img_size[0]), ))
        else:
            t_cameras = PerspectiveCameras(device=self.device,
                                           R=batch_R,
                                           T=batch_T)

        all_silhouettes.append(
            sil_renderer(self._final_mesh,
                         device=self.device,
                         cameras=t_cameras).detach().cpu()[..., -1])
        if mode == "predict":
            all_images.append(
                torch.clamp(
                    image_renderer(self._final_mesh,
                                   device=self.device,
                                   cameras=t_cameras),
                    0,
                    1,
                ).detach().cpu()[..., :3])
        torch.cuda.empty_cache()

    renders = dict(
        silhouettes=torch.cat(all_silhouettes).unsqueeze(-1).permute(
            0, 3, 1, 2),
        images=torch.cat(all_images) if all_images else [],
    )
    return renders
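# Hypothetical usage of render_final_mesh once run_optimization has converged.
# `model` denotes an instance of the surrounding class with synthetic-data
# settings (is_real_data=False); the pose values and out_size are made up:
R = torch.eye(3).expand(4, 3, 3).clone()                 # 4 camera rotations
T = torch.tensor([[0.0, 0.0, 2.5]]).expand(4, 3).clone()  # 4 translations
renders = model.render_final_mesh((R, T), mode="predict", out_size=[128, 128])
silhouettes = renders["silhouettes"]  # (4, 1, H, W), channel-first
images = renders["images"]            # flat-shaded RGB, only in "predict" mode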
def render_img(face_shape,
               face_color,
               facemodel,
               image_size=224,
               fx=1015.0,
               fy=1015.0,
               px=112.0,
               py=112.0,
               device='cuda:0'):
    '''
    Ref: https://github.com/facebookresearch/pytorch3d/issues/184
    The rendering function (just for test).

    Input:
        face_shape: Tensor[1, 35709, 3]
        face_color: Tensor[1, 35709, 3] in [0, 1]
        facemodel: contains `tri` (triangles [70789, 3], index starts from 1)
    '''
    from pytorch3d.structures import Meshes
    from pytorch3d.renderer.mesh.textures import TexturesVertex
    from pytorch3d.renderer import (PerspectiveCameras, PointLights,
                                    RasterizationSettings, MeshRenderer,
                                    MeshRasterizer, SoftPhongShader,
                                    BlendParams)

    face_color = TexturesVertex(verts_features=face_color.to(device))
    face_buf = torch.from_numpy(facemodel.tri - 1)  # index starts from 1
    face_idx = face_buf.unsqueeze(0)

    mesh = Meshes(face_shape.to(device), face_idx.to(device), face_color)

    R = torch.eye(3).view(1, 3, 3).to(device)
    R[0, 0, 0] *= -1.0
    T = torch.zeros([1, 3]).to(device)

    # Convert pixel-space intrinsics to PyTorch3D NDC conventions.
    half_size = (image_size - 1.0) / 2
    focal_length = torch.tensor([fx / half_size, fy / half_size],
                                dtype=torch.float32).reshape(1, 2).to(device)
    principal_point = torch.tensor(
        [(half_size - px) / half_size, (py - half_size) / half_size],
        dtype=torch.float32).reshape(1, 2).to(device)

    cameras = PerspectiveCameras(device=device,
                                 R=R,
                                 T=T,
                                 focal_length=focal_length,
                                 principal_point=principal_point)

    raster_settings = RasterizationSettings(image_size=image_size,
                                            blur_radius=0.0,
                                            faces_per_pixel=1)

    lights = PointLights(device=device,
                         ambient_color=((1.0, 1.0, 1.0), ),
                         diffuse_color=((0.0, 0.0, 0.0), ),
                         specular_color=((0.0, 0.0, 0.0), ),
                         location=((0.0, 0.0, 1e5), ))

    blend_params = BlendParams(background_color=(0.0, 0.0, 0.0))

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=raster_settings),
        shader=SoftPhongShader(device=device,
                               cameras=cameras,
                               lights=lights,
                               blend_params=blend_params))

    images = renderer(mesh)
    images = torch.clamp(images, 0.0, 1.0)
    return images
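# The pixel-to-NDC intrinsics conversion used in render_img, spelled out with
# the default values above (variable names here are ours, for illustration):
image_size = 224
half_size = (image_size - 1.0) / 2        # 111.5 pixels
fx_ndc = 1015.0 / half_size               # focal length in NDC units (~9.10)
px_ndc = (half_size - 112.0) / half_size  # principal point; the sign flips
                                          # because PyTorch3D NDC has +X left
                                          # while pixel +x points right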
def batch_render(
    verts,
    faces,
    faces_per_pixel=10,
    K=None,
    rot=None,
    trans=None,
    colors=None,
    color=(0.53, 0.53, 0.8),  # light purple
    ambient_col=0.5,
    specular_col=0.2,
    diffuse_col=0.3,
    face_colors=None,
    # color=(0.74117647, 0.85882353, 0.65098039),  # light blue
    image_sizes=None,
    out_res=512,
    bin_size=0,
    shading="soft",
    mode="rgb",
    blend_gamma=1e-4,
    min_depth=None,
):
    device = torch.device("cuda:0")
    K = K.to(device)
    width, height = image_sizes[0]
    # Render a square image at the larger image dimension, then crop at the end.
    out_size = int(max(image_sizes[0]))
    raster_settings = RasterizationSettings(
        image_size=out_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        bin_size=bin_size,
    )

    # Extract focal lengths and the principal point from the intrinsics matrix.
    fx = K[:, 0, 0]
    fy = K[:, 1, 1]
    focals = torch.stack([fx, fy], 1)
    px = K[:, 0, 2]
    py = K[:, 1, 2]
    principal_point = torch.stack([width - px, height - py], 1)
    if rot is None:
        rot = torch.eye(3).unsqueeze(0).to(device)
    if trans is None:
        trans = torch.zeros(3).unsqueeze(0).to(device)
    cameras = PerspectiveCameras(
        device=device,
        focal_length=focals,
        principal_point=principal_point,
        image_size=[(out_size, out_size) for _ in range(len(verts))],
        R=rot,
        T=trans,
    )
    if mode == "rgb" and shading == "soft":
        lights = DirectionalLights(
            device=device,
            direction=((0.6, -0.6, -0.6), ),
            ambient_color=((ambient_col, ambient_col, ambient_col), ),
            diffuse_color=((diffuse_col, diffuse_col, diffuse_col), ),
            specular_color=((specular_col, specular_col, specular_col), ),
        )
        shader = SoftPhongShader(device=device, cameras=cameras, lights=lights)
    elif mode == "silh":
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        shader = SoftSilhouetteShader(blend_params=blend_params)
    elif shading == "faceidx":
        shader = FaceIdxShader()
    elif (mode == "facecolor") and (shading == "hard"):
        shader = FaceColorShader(face_colors=face_colors)
    elif (mode == "facecolor") and (shading == "soft"):
        shader = SoftFaceColorShader(face_colors=face_colors,
                                     blend_gamma=blend_gamma)
    else:
        raise ValueError(
            f"Unhandled mode {mode} and shading {shading} combination")

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=raster_settings),
        shader=shader,
    )
    if min_depth is not None:
        verts = torch.cat(
            [verts[:, :, :2], verts[:, :, 2:].clamp(min_depth)], 2)
    if mode == "rgb":
        if colors is None:
            colors = get_colors(verts, color)
        tex = textures.TexturesVertex(verts_features=colors)
        meshes = Meshes(verts=verts, faces=faces, textures=tex)
    elif mode in ["silh", "facecolor"]:
        meshes = Meshes(verts=verts, faces=faces)
    else:
        raise ValueError(f"Render mode {mode} not in [rgb|silh|facecolor]")

    square_images = renderer(meshes, cameras=cameras)
    # Flip the square render and crop it back to the original (height, width).
    height_off = int(width - height)
    images = torch.flip(square_images, (1, 2))[:, height_off:]
    return images
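# Hypothetical call to batch_render for a single mesh seen by a 640x480
# pinhole camera. `verts` is (1, V, 3) and `faces` is (1, F, 3); the intrinsic
# values are made up for illustration:
K = torch.tensor([[[500.0, 0.0, 320.0],
                   [0.0, 500.0, 240.0],
                   [0.0, 0.0, 1.0]]])
rgba = batch_render(verts, faces, K=K, image_sizes=[(640, 480)],
                    mode="rgb", shading="soft")
# The square 640x640 render is flipped and cropped back to (1, 480, 640, 4).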
def forward(self,
            verts,       # under the general camera coordinate rXdYfZ, N*V*3
            faces,       # indices into verts defining triangles, N*F*3
            verts_uvs,   # uv coordinates of the corresponding verts, N*V*2
            faces_uvs,   # indices into verts_uvs defining triangles, N*F*3
            tex_image,   # under GCcd, N*H*W*3
            R,           # under GCcd, N*3*3
            T,           # under GCcd, N*3
            f,           # focal length in pixel/m, N*1
            C,           # camera center, N*2
            imgres,      # int
            lightLoc=None):
    assert verts.shape[0] == 1, \
        'due to issues in PyTorch3D, render 1 mesh per forward pass'

    # Only R and T (or, alternatively, verts) need converting;
    # we convert R and T here.
    if self.convertToPytorch3D:
        R = torch.matmul(self.GCcdToPytorch3D, R)
        T = torch.matmul(self.GCcdToPytorch3D, T.unsqueeze(-1)).squeeze(-1)

    # Prepare the textures and mesh to render.
    tex = TexturesUV(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=tex_image)
    mesh = Meshes(verts=verts, faces=faces, textures=tex)

    # Initialize a camera. The world coordinate is +Y up, +X left and +Z in.
    cameras = PerspectiveCameras(focal_length=f,
                                 principal_point=C,
                                 R=R,
                                 T=T,
                                 image_size=((imgres, imgres), ),
                                 device=self.device)

    # Define the settings for rasterization and shading.
    raster_settings = RasterizationSettings(
        image_size=imgres,
        blur_radius=0.0,
        faces_per_pixel=1,
    )

    # Compose a rasterizer and a shader into a simple renderer. The simple
    # textured shader interpolates the uv coordinates for each pixel and
    # samples from the texture image. It could easily support lighting,
    # but we do not implement that here.
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=raster_settings),
        shader=SimpleShader(device=self.device))

    # Render the image(s).
    images = renderer(mesh)

    return images
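# self.GCcdToPytorch3D is defined elsewhere in the class. Given that rXdYfZ
# (+X right, +Y down, +Z forward) matches the OpenCV camera convention, and
# PyTorch3D cameras use +X left, +Y up, +Z in, the change of basis negates
# the first two axes. A plausible definition, stated as an assumption:
GCcdToPytorch3D = torch.tensor([[-1.0, 0.0, 0.0],
                                [0.0, -1.0, 0.0],
                                [0.0, 0.0, 1.0]])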