def __init__(self, cfg, output_dir="./vis"):
    """
    Args:
        cfg (CfgNode):
    """
    self.predictor = VoxMeshHead(cfg)

    # Load pretrained weights into model
    cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
    state_dict = clean_state_dict(cp["best_states"]["model"])
    self.predictor.load_state_dict(state_dict)

    os.makedirs(output_dir, exist_ok=True)
    self.output_dir = output_dir
def __init__(self, cfg, checkpoint_lp_model, output_dir="./vis"):
    """
    Args:
        cfg (CfgNode):
    """
    self.predictor = VoxMeshHead(cfg)
    self.device = torch.device('cuda')

    # Load pretrained weights into model
    cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
    state_dict = clean_state_dict(cp["best_states"]["model"])
    self.predictor.load_state_dict(state_dict)

    self.loss_predictor = LossPredictionModule()
    # Path to trained prediction module
    state_dict = torch.load(checkpoint_lp_model, map_location='cuda:0')
    self.loss_predictor.load_state_dict(state_dict)

    self.predictor.to(self.device)
    self.loss_predictor.to(self.device)

    os.makedirs(output_dir, exist_ok=True)
    self.output_dir = output_dir
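# Hedged illustration (not part of the original file): clean_state_dict above is
# assumed to normalize checkpoint keys saved from nn.DataParallel, e.g. by
# stripping a leading "module." prefix so the weights load into a plain module.
# The hypothetical helper below is a minimal stand-in showing that idea; the
# repo's own clean_state_dict may do more.
def _strip_module_prefix(state_dict):
    """Return a copy of state_dict with any leading 'module.' removed from keys."""
    return {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state_dict.items()
    }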
class VisualizationDemo(object):
    def __init__(self, cfg, output_dir="./vis"):
        """
        Args:
            cfg (CfgNode):
        """
        self.predictor = VoxMeshHead(cfg)
        self.device = torch.device('cuda')

        # Load pretrained weights into model
        cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
        state_dict = clean_state_dict(cp["best_states"]["model"])
        self.predictor.load_state_dict(state_dict)

        self.loss_predictor = LossPredictionModule()
        # Path to trained prediction module
        state_dict = torch.load(
            './output_prediction_module/lr1e-4/prediction_module_1500.pth',
            map_location='cuda:0')
        self.loss_predictor.load_state_dict(state_dict)

        self.predictor.to(self.device)
        self.loss_predictor.to(self.device)

        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def run_on_image(self, image, id_str, gt_verts, gt_faces):
        deprocess = imagenet_deprocess(rescale_image=False)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(image.to(self.device))

        sid, mid, iid = id_str.split('-')
        iid = int(iid)

        # Transform vertex space
        metadata_path = os.path.join(
            './datasets/shapenet/ShapeNetV1processed', sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RTs = metadata["extrinsics"].to(self.device)
        rot_y_90 = torch.tensor([[0, 0, 1, 0],
                                 [0, 1, 0, 0],
                                 [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        mesh = meshes_pred[-1][0]

        # For some strange reason all classes (except the vehicle class) require
        # a 90 degree rotation about the y-axis for the GT mesh
        invRT = torch.inverse(RTs[iid].mm(rot_y_90))
        invRT_no_rot = torch.inverse(RTs[iid])
        mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT)

        # Get look-at view extrinsics
        render_metadata_path = os.path.join(
            'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid,
            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics'].to(self.device)

        verts, faces = mesh.get_mesh_verts_faces(0)
        verts_rgb = torch.ones_like(verts)[None]
        textures = Textures(verts_rgb=verts_rgb.to(self.device))
        mesh.textures = textures

        plt.figure(figsize=(10, 10))

        # Silhouette renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )

        gt_verts = gt_verts.to(self.device)
        gt_faces = gt_faces.to(self.device)
        verts_rgb = torch.ones_like(gt_verts)[None]
        textures = Textures(verts_rgb=verts_rgb)

        # Invert without the rotation for the vehicle class
        if sid == '02958343':
            gt_verts = project_verts(gt_verts, invRT_no_rot.to(self.device))
        else:
            gt_verts = project_verts(gt_verts, invRT.to(self.device))
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

        probability_map = 0.01 * torch.ones((1, 24))
        viewgrid = torch.zeros(
            (1, 24, render_image_size, render_image_size)).to(self.device)

        fig = plt.figure(1)
        ax_pred = [fig.add_subplot(5, 5, i + 1) for i in range(24)]
        #fig = plt.figure(2)
        #ax_gt = [fig.add_subplot(5, 5, i + 1) for i in range(24)]

        for i in range(len(render_RTs)):
            if i == iid:  # Don't include the current view
                continue

            R = render_RTs[i][:3, :3].unsqueeze(0)
            T = render_RTs[i][:3, 3].unsqueeze(0)
            cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)

            silhouette_renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings),
                shader=SoftSilhouetteShader(blend_params=blend_params))

            ref_image = (silhouette_renderer(
                meshes_world=gt_mesh, R=R, T=T) > 0).float()
            silhouette_image = (silhouette_renderer(
                meshes_world=mesh, R=R, T=T) > 0).float()

            # MSE loss between both silhouettes
            silh_loss = torch.sum(
                (silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3]) ** 2)
            probability_map[0, i] = silh_loss.detach()

            viewgrid[0, i] = silhouette_image[..., -1]

            #ax_gt[i].imshow(ref_image[0, :, :, 3].cpu().numpy())
            #ax_gt[i].set_title(i)
            ax_pred[i].imshow(silhouette_image[0, :, :, 3].cpu().numpy())
            ax_pred[i].set_title(i)

        img = image_to_numpy(deprocess(image[0]))
        #ax_gt[iid].imshow(img)
        ax_pred[iid].imshow(img)

        #fig = plt.figure(3)
        #ax = fig.add_subplot(111)
        #ax.imshow(img)

        pred_prob_map = self.loss_predictor(viewgrid)
        print('Highest actual loss: {}'.format(torch.argmax(probability_map)))
        print('Highest predicted loss: {}'.format(torch.argmax(pred_prob_map)))
        plt.show()
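# Hedged, self-contained sketch of the per-view scoring done in run_on_image
# above: each candidate view is scored by the sum of squared differences between
# the binarized predicted and ground-truth silhouettes, and the view with the
# largest error is the "hardest" one. The helper name, tensor shapes, and random
# masks below are illustrative assumptions, not data from the repo.
def _silhouette_mse_example():
    import torch
    num_views, H, W = 24, 256, 256
    probability_map = 0.01 * torch.ones(1, num_views)
    for i in range(num_views):
        pred_silh = (torch.rand(H, W) > 0.5).float()  # stand-in predicted mask
        gt_silh = (torch.rand(H, W) > 0.5).float()    # stand-in GT mask
        probability_map[0, i] = torch.sum((pred_silh - gt_silh) ** 2)
    print('Highest actual loss at view:', torch.argmax(probability_map).item())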
class VisualizationDemo(object):
    def __init__(self, cfg, checkpoint_lp_model, output_dir="./vis"):
        """
        Args:
            cfg (CfgNode):
        """
        self.predictor = VoxMeshHead(cfg)
        self.device = torch.device('cuda')

        # Load pretrained weights into model
        cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
        state_dict = clean_state_dict(cp["best_states"]["model"])
        self.predictor.load_state_dict(state_dict)

        self.loss_predictor = LossPredictionModule()
        # Path to trained prediction module
        state_dict = torch.load(checkpoint_lp_model, map_location='cuda:0')
        self.loss_predictor.load_state_dict(state_dict)

        self.predictor.to(self.device)
        self.loss_predictor.to(self.device)

        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def run_on_image(self, imgs, id_strs, meshes_gt, render_RTs, RTs):
        deprocess = imagenet_deprocess(rescale_image=False)

        # Silhouette renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )

        rot_y_90 = torch.tensor([[0, 0, 1, 0],
                                 [0, 1, 0, 0],
                                 [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(imgs)

        B, _, H, W = imgs.shape
        probability_map = 0.01 * torch.ones((B, 24)).to(self.device)
        # batch size x 24 x render_image_size x render_image_size
        viewgrid = torch.zeros(
            (B, 24, render_image_size, render_image_size)).to(self.device)
        _meshes_pred = meshes_pred[-1]

        for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(
                zip(meshes_gt, _meshes_pred)):
            sid = id_strs[b].split('-')[0]
            RT = RTs[b]

            # For some strange reason all classes (except the vehicle class)
            # require a 90 degree rotation about the y-axis for the GT mesh
            invRT = torch.inverse(RT.mm(rot_y_90))
            invRT_no_rot = torch.inverse(RT)

            cur_pred_mesh._verts_list[0] = project_verts(
                cur_pred_mesh._verts_list[0], invRT)

            # Invert without the rotation for the vehicle class
            if sid == '02958343':
                cur_gt_mesh._verts_list[0] = project_verts(
                    cur_gt_mesh._verts_list[0], invRT_no_rot)
            else:
                cur_gt_mesh._verts_list[0] = project_verts(
                    cur_gt_mesh._verts_list[0], invRT)

            '''
            plt.figure(figsize=(10, 10))
            fig = plt.figure(1)
            ax_pred = [fig.add_subplot(5, 5, i + 1) for i in range(24)]
            fig = plt.figure(2)
            ax_gt = [fig.add_subplot(5, 5, i + 1) for i in range(24)]
            '''

            # Loop over the 24 candidate views of this example
            for iid in range(len(render_RTs[b])):
                R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)

                silhouette_renderer = MeshRenderer(
                    rasterizer=MeshRasterizer(cameras=cameras,
                                              raster_settings=raster_settings),
                    shader=SoftSilhouetteShader(blend_params=blend_params))

                ref_image = (silhouette_renderer(
                    meshes_world=cur_gt_mesh, R=R, T=T) > 0).float()
                silhouette_image = (silhouette_renderer(
                    meshes_world=cur_pred_mesh, R=R, T=T) > 0).float()

                # Add image silhouette to viewgrid
                viewgrid[b, iid] = silhouette_image[..., -1]

                # MSE loss between both silhouettes
                silh_loss = torch.sum(
                    (silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3]) ** 2)
                probability_map[b, iid] = silh_loss.detach()

                '''
                ax_gt[iid].imshow(ref_image[0, :, :, 3].cpu().numpy())
                ax_gt[iid].set_title(iid)
                ax_pred[iid].imshow(silhouette_image[0, :, :, 3].cpu().numpy())
                ax_pred[iid].set_title(iid)
                '''

            '''
            img = image_to_numpy(deprocess(imgs[b]))
            ax_gt[iid].imshow(img)
            ax_pred[iid].imshow(img)
            fig = plt.figure(3)
            ax = fig.add_subplot(111)
            ax.imshow(img)
            #plt.show()
            '''

        # Normalize per example, then softmax across the 24 views
        probability_map = probability_map / (
            torch.max(probability_map, dim=1)[0].unsqueeze(1))
        probability_map = torch.nn.functional.softmax(probability_map, dim=1)

        pred_prob_map = self.loss_predictor(viewgrid)
        gt_max = torch.argmax(probability_map, dim=1)
        pred_max = torch.argmax(pred_prob_map, dim=1)

        #print('--' * 30)
        #print('Item: {}'.format(id_str))
        #print('Highest actual loss: {}'.format(gt_max))
        #print('Highest predicted loss: {}'.format(pred_max))
        #print('GT prob map: {}'.format(probability_map.squeeze()))
        #print('Pred prob map: {}'.format(pred_prob_map.squeeze()))

        correct = torch.sum(pred_max == gt_max).item()
        total = len(pred_prob_map)

        return correct, total
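# Hedged sketch of the evaluation step at the end of run_on_image above: the raw
# per-view silhouette losses are divided by their per-example maximum, passed
# through a softmax across views, and accuracy counts how often the
# loss-prediction module's argmax view matches the rendered ground truth's.
# The helper name and the random inputs are illustrative assumptions.
def _view_selection_accuracy_example():
    import torch
    import torch.nn.functional as F
    B, num_views = 4, 24
    raw_losses = torch.rand(B, num_views)        # stand-in silhouette losses
    predicted_scores = torch.rand(B, num_views)  # stand-in LossPredictionModule output
    prob_map = raw_losses / torch.max(raw_losses, dim=1)[0].unsqueeze(1)
    prob_map = F.softmax(prob_map, dim=1)
    gt_max = torch.argmax(prob_map, dim=1)
    pred_max = torch.argmax(predicted_scores, dim=1)
    correct = torch.sum(pred_max == gt_max).item()
    print('{}/{} examples agree on the hardest view'.format(correct, B))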
class VisualizationDemo(object):
    def __init__(self, cfg, output_dir="./vis"):
        """
        Args:
            cfg (CfgNode):
        """
        self.predictor = VoxMeshHead(cfg)

        # Load pretrained weights into model
        cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
        state_dict = clean_state_dict(cp["best_states"]["model"])
        self.predictor.load_state_dict(state_dict)

        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def run_on_image(self, image, id_str, gt_verts, gt_faces):
        deprocess = imagenet_deprocess(rescale_image=False)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(image)

        sid, mid, iid = id_str.split('-')
        iid = int(iid)

        # Transform vertex space
        metadata_path = os.path.join(
            './datasets/shapenet/ShapeNetV1processed', sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RTs = metadata["extrinsics"]
        rot_y_90 = torch.tensor([[0, 0, 1, 0],
                                 [0, 1, 0, 0],
                                 [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        mesh = meshes_pred[-1][0]

        # For some strange reason all classes (except the vehicle class) require
        # a 90 degree rotation about the y-axis for the GT mesh
        invRT = torch.inverse(RTs[iid].mm(rot_y_90))
        invRT_no_rot = torch.inverse(RTs[iid])
        mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT.cpu())

        # Get look-at view extrinsics
        render_metadata_path = os.path.join(
            'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid,
            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics']

        plt.figure(figsize=(10, 10))

        R = render_RTs[iid][:3, :3].unsqueeze(0)
        T = render_RTs[iid][:3, 3].unsqueeze(0)
        cameras = OpenGLPerspectiveCameras(R=R, T=T)

        # Phong renderer
        lights = PointLights(location=[[0.0, 0.0, -3.0]])
        raster_settings = RasterizationSettings(image_size=256,
                                                blur_radius=0.0,
                                                faces_per_pixel=1,
                                                bin_size=0)
        phong_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras,
                                      raster_settings=raster_settings),
            shader=HardPhongShader(lights=lights))

        # Silhouette renderer
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=256,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        silhouette_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras,
                                      raster_settings=raster_settings),
            shader=SoftSilhouetteShader(blend_params=blend_params))

        verts, faces = mesh.get_mesh_verts_faces(0)
        verts_rgb = torch.ones_like(verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        mesh.textures = textures

        verts_rgb = torch.ones_like(gt_verts)[None]
        textures = Textures(verts_rgb=verts_rgb)

        # Invert without the rotation for the vehicle class
        if sid == '02958343':
            gt_verts = project_verts(gt_verts, invRT_no_rot.cpu())
        else:
            gt_verts = project_verts(gt_verts, invRT.cpu())
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

        img = image_to_numpy(deprocess(image[0]))
        mesh_image = phong_renderer(meshes_world=mesh, R=R, T=T)
        gt_silh_image = (silhouette_renderer(
            meshes_world=gt_mesh, R=R, T=T) > 0).float()
        silhouette_image = (silhouette_renderer(
            meshes_world=mesh, R=R, T=T) > 0).float()

        plt.subplot(2, 2, 1)
        plt.imshow(img)
        plt.title('input image')
        plt.subplot(2, 2, 2)
        plt.imshow(mesh_image[0, ..., :3].cpu().numpy())
        plt.title('rendered mesh')
        plt.subplot(2, 2, 3)
        plt.imshow(gt_silh_image[0, ..., 3].cpu().numpy())
        plt.title('silhouette of gt mesh')
        plt.subplot(2, 2, 4)
        plt.imshow(silhouette_image[0, ..., 3].cpu().numpy())
        plt.title('silhouette of rendered mesh')
        plt.show()
        #plt.savefig('./output_demo/figures/' + id_str + '.png')

        vis_utils.visualize_prediction(id_str, img, mesh, self.output_dir)
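# Hedged sketch (an assumption, not the repo's project_verts): the vertex
# transforms applied with invRT / invRT_no_rot above amount to multiplying
# (N, 3) vertices by a 4x4 matrix in homogeneous coordinates. The hypothetical
# helper below illustrates that idea with a row-vector convention; the repo's
# project_verts may differ in convention and batching.
def _apply_homogeneous_transform_example():
    import torch
    rot_y_90 = torch.tensor([[0., 0., 1., 0.],
                             [0., 1., 0., 0.],
                             [-1., 0., 0., 0.],
                             [0., 0., 0., 1.]])
    verts = torch.rand(10, 3)                    # stand-in vertices
    ones = torch.ones(verts.shape[0], 1)
    verts_h = torch.cat([verts, ones], dim=1)    # (N, 4) homogeneous coordinates
    out = verts_h @ torch.inverse(rot_y_90).t()  # apply the inverse transform
    verts_out = out[:, :3] / out[:, 3:4]         # back to (N, 3)
    print(verts_out.shape)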