# Imports required by the examples below. The PyTorch3D paths follow the
# pre-0.3 API, in which `Textures(verts_rgb=...)` lives in pytorch3d.structures;
# the project-specific module paths are assumptions based on the Mesh R-CNN
# ShapeNet codebase, and LossPredictionModule is a local module.
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from fvcore.common.file_io import PathManager
from pytorch3d.renderer import (
    BlendParams,
    HardPhongShader,
    MeshRasterizer,
    MeshRenderer,
    OpenGLPerspectiveCameras,
    PointLights,
    RasterizationSettings,
    SoftSilhouetteShader,
)
from pytorch3d.structures import Meshes, Textures

from shapenet.modeling import VoxMeshHead  # assumed module path
from shapenet.utils.checkpoint import clean_state_dict  # assumed module path
from shapenet.utils.coords import project_verts  # assumed module path
from shapenet.data.utils import image_to_numpy, imagenet_deprocess  # assumed
import shapenet.utils.vis as vis_utils  # assumed module path
from loss_prediction_module import LossPredictionModule  # local module (assumed)
class VisualizationDemo(object):
    def __init__(self, cfg, output_dir="./vis"):
        """
        Args:
            cfg (CfgNode):
        """
        self.predictor = VoxMeshHead(cfg)

        self.device = torch.device('cuda')
        #Load pretrained weights into model
        cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
        state_dict = clean_state_dict(cp["best_states"]["model"])
        self.predictor.load_state_dict(state_dict)

        self.loss_predictor = LossPredictionModule()
        #Path to trained prediction module
        state_dict = torch.load(
            './output_prediction_module/lr1e-4/prediction_module_1500.pth',
            map_location='cuda:0')
        self.loss_predictor.load_state_dict(state_dict)

        self.predictor.to(self.device)
        self.loss_predictor.to(self.device)
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def run_on_image(self, image, id_str, gt_verts, gt_faces):
        deprocess = imagenet_deprocess(rescale_image=False)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(image.to(self.device))

        sid, mid, iid = id_str.split('-')
        iid = int(iid)

        #Transform vertex space
        metadata_path = os.path.join('./datasets/shapenet/ShapeNetV1processed',
                                     sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RTs = metadata["extrinsics"].to(self.device)
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        mesh = meshes_pred[-1][0]
        #For some strange reason all classes (except the vehicle class) require
        #a 90 degree rotation about the y-axis for the GT mesh
        invRT = torch.inverse(RTs[iid].mm(rot_y_90))
        invRT_no_rot = torch.inverse(RTs[iid])
        mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT)

        #Get look at view extrinsics
        render_metadata_path = os.path.join(
            'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid,
            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics'].to(self.device)

        verts, faces = mesh.get_mesh_verts_faces(0)
        verts_rgb = torch.ones_like(verts)[None]
        textures = Textures(verts_rgb=verts_rgb.to(self.device))
        mesh.textures = textures

        plt.figure(figsize=(10, 10))

        #Silhouette Renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
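        # blur_radius below is set so the sigmoid used for soft blending decays
        # to 1e-4 at the edge of the blur region (SoftRasterizer-style heuristic)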
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )

        gt_verts = gt_verts.to(self.device)
        gt_faces = gt_faces.to(self.device)
        verts_rgb = torch.ones_like(gt_verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        #Invert without the rotation for the vehicle class
        if sid == '02958343':
            gt_verts = project_verts(gt_verts, invRT_no_rot.to(self.device))
        else:
            gt_verts = project_verts(gt_verts, invRT.to(self.device))
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

        # probability_map records the per-view silhouette loss for the 24
        # candidate views; viewgrid stacks the rendered silhouettes that
        # LossPredictionModule scores
        probability_map = 0.01 * torch.ones((1, 24))
        viewgrid = torch.zeros(
            (1, 24, render_image_size, render_image_size)).to(self.device)
        fig = plt.figure(1)
        ax_pred = [fig.add_subplot(5, 5, i + 1) for i in range(24)]
        #fig = plt.figure(2)
        #ax_gt = [fig.add_subplot(5,5,i+1) for i in range(24)]

        for i in range(len(render_RTs)):
            if i == iid:  #Don't include current view
                continue

            R = render_RTs[i][:3, :3].unsqueeze(0)
            T = render_RTs[i][:3, 3].unsqueeze(0)
            cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)

            silhouette_renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings),
                shader=SoftSilhouetteShader(blend_params=blend_params))

            ref_image = (silhouette_renderer(meshes_world=gt_mesh, R=R, T=T) >
                         0).float()
            silhouette_image = (silhouette_renderer(
                meshes_world=mesh, R=R, T=T) > 0).float()

            # MSE Loss between both silhouettes
            silh_loss = torch.sum(
                (silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3])**2)
            probability_map[0, i] = silh_loss.detach()

            viewgrid[0, i] = silhouette_image[..., -1]

            #ax_gt[i].imshow(ref_image[0,:,:,3].cpu().numpy())
            #ax_gt[i].set_title(i)

            ax_pred[i].imshow(silhouette_image[0, :, :, 3].cpu().numpy())
            ax_pred[i].set_title(i)

        img = image_to_numpy(deprocess(image[0]))
        #ax_gt[iid].imshow(img)
        ax_pred[iid].imshow(img)
        #fig = plt.figure(3)
        #ax = fig.add_subplot(111)
        #ax.imshow(img)

        pred_prob_map = self.loss_predictor(viewgrid)
        print('Highest actual loss: {}'.format(torch.argmax(probability_map)))
        print('Highest predicted loss: {}'.format(torch.argmax(pred_prob_map)))
        plt.show()
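
# Usage sketch for the class above (not part of the original source):
# `load_cfg`, `image`, `gt_verts`, `gt_faces`, and `id_str` are hypothetical
# stand-ins, and the hard-coded checkpoint and dataset paths inside the class
# must exist locally.
#
#     cfg = load_cfg("configs/voxmesh_R50.yaml")  # hypothetical config loader
#     demo = VisualizationDemo(cfg)
#     # image: (1, 3, H, W) ImageNet-preprocessed tensor; id_str: "<sid>-<mid>-<iid>"
#     demo.run_on_image(image, id_str, gt_verts, gt_faces)
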
class VisualizationDemo(object):
    def __init__(self, cfg, checkpoint_lp_model, output_dir="./vis"):
        """
        Args:
            cfg (CfgNode): model configuration; cfg.MODEL.CHECKPOINT must point
                to a trained VoxMeshHead checkpoint.
            checkpoint_lp_model (str): path to a trained LossPredictionModule
                checkpoint.
            output_dir (str): directory where visualizations are written.
        """
        self.predictor = VoxMeshHead(cfg)

        self.device = torch.device('cuda')
        #Load pretrained weights into model
        cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
        state_dict = clean_state_dict(cp["best_states"]["model"])
        self.predictor.load_state_dict(state_dict)

        self.loss_predictor = LossPredictionModule()
        #Path to trained prediction module
        state_dict = torch.load(checkpoint_lp_model, map_location='cuda:0')
        self.loss_predictor.load_state_dict(state_dict)

        self.predictor.to(self.device)
        self.loss_predictor.to(self.device)
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def run_on_image(self, imgs, id_strs, meshes_gt, render_RTs, RTs):
        """Render GT and predicted silhouettes for every candidate view and
        compare the loss predictor's view ranking against the actual
        silhouette losses. Returns (correct, total) over the batch."""
        deprocess = imagenet_deprocess(rescale_image=False)

        #Silhouette Renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )

        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)
        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(imgs.to(self.device))

        B, _, H, W = imgs.shape
        probability_map = 0.01 * torch.ones((B, 24)).to(self.device)
        viewgrid = torch.zeros((B, 24, render_image_size,
                                render_image_size)).to(self.device)  # B x 24 x H x W
        _meshes_pred = meshes_pred[-1]

        for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(zip(meshes_gt, _meshes_pred)):
            sid = id_strs[b].split('-')[0]

            RT = RTs[b]

            #For some strange reason all classes (except the vehicle class)
            #require a 90 degree rotation about the y-axis for the GT mesh
            invRT = torch.inverse(RT.mm(rot_y_90))
            invRT_no_rot = torch.inverse(RT)

            cur_pred_mesh._verts_list[0] = project_verts(cur_pred_mesh._verts_list[0], invRT)
            #Invert without the rotation for the vehicle class
            if sid == '02958343':
                cur_gt_mesh._verts_list[0] = project_verts(
                    cur_gt_mesh._verts_list[0], invRT_no_rot)
            else:
                cur_gt_mesh._verts_list[0] = project_verts(
                    cur_gt_mesh._verts_list[0], invRT)

            '''
            plt.figure(figsize=(10, 10))

            fig = plt.figure(1)
            ax_pred = [fig.add_subplot(5,5,i+1) for i in range(24)]
            fig = plt.figure(2)
            ax_gt = [fig.add_subplot(5,5,i+1) for i in range(24)]
            '''

            for iid in range(len(render_RTs[b])):  # 24 candidate views per model

                R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)
                silhouette_renderer = MeshRenderer(
                    rasterizer=MeshRasterizer(cameras=cameras,
                                              raster_settings=raster_settings),
                    shader=SoftSilhouetteShader(blend_params=blend_params))

                ref_image = (silhouette_renderer(
                    meshes_world=cur_gt_mesh, R=R, T=T) > 0).float()
                silhouette_image = (silhouette_renderer(
                    meshes_world=cur_pred_mesh, R=R, T=T) > 0).float()

                #Add image silhouette to viewgrid
                viewgrid[b,iid] = silhouette_image[...,-1]

                # MSE Loss between both silhouettes
                silh_loss = torch.sum((silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3]) ** 2)
                probability_map[b, iid] = silh_loss.detach()

                '''
                ax_gt[iid].imshow(ref_image[0,:,:,3].cpu().numpy())
                ax_gt[iid].set_title(iid)
                ax_pred[iid].imshow(silhouette_image[0,:,:,3].cpu().numpy())
                ax_pred[iid].set_title(iid)
                '''

            '''
            img = image_to_numpy(deprocess(imgs[b]))
            ax_gt[iid].imshow(img)
            ax_pred[iid].imshow(img)
            fig = plt.figure(3)
            ax = fig.add_subplot(111)
            ax.imshow(img)

            #plt.show()
            '''

        probability_map = probability_map / torch.max(
            probability_map, dim=1)[0].unsqueeze(1)  # Normalize per image
        probability_map = torch.nn.functional.softmax(
            probability_map, dim=1)  # Softmax over the 24 views
        pred_prob_map = self.loss_predictor(viewgrid)

        gt_max = torch.argmax(probability_map, dim=1)
        pred_max = torch.argmax(pred_prob_map, dim=1)

        #print('--'*30)
        #print('Item: {}'.format(id_str))
        #print('Highest actual loss: {}'.format(gt_max))
        #print('Highest predicted loss: {}'.format(pred_max))

        #print('GT prob map: {}'.format(probability_map.squeeze()))
        #print('Pred prob map: {}'.format(pred_prob_map.squeeze()))

        correct = torch.sum(pred_max == gt_max).item()
        total = len(pred_prob_map)

        return correct, total
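
# Usage sketch for the batched evaluator above (not part of the original
# source): `cfg` and `loader` are hypothetical, with `loader` yielding the
# argument tuple of run_on_image; accuracy is accumulated from the
# (correct, total) return values.
#
#     demo = VisualizationDemo(cfg, checkpoint_lp_model="prediction_module.pth")
#     n_correct, n_total = 0, 0
#     for imgs, id_strs, meshes_gt, render_RTs, RTs in loader:
#         correct, total = demo.run_on_image(imgs, id_strs, meshes_gt,
#                                            render_RTs, RTs)
#         n_correct += correct
#         n_total += total
#     print("view-ranking accuracy: {:.3f}".format(n_correct / n_total))
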
class VisualizationDemo(object):
    def __init__(self, cfg, output_dir="./vis"):
        """
        Args:
            cfg (CfgNode):
        """
        self.predictor = VoxMeshHead(cfg)

        #Load pretrained weights into model
        cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
        state_dict = clean_state_dict(cp["best_states"]["model"])
        self.predictor.load_state_dict(state_dict)

        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def run_on_image(self, image, id_str, gt_verts, gt_faces):
        deprocess = imagenet_deprocess(rescale_image=False)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(image)

        sid, mid, iid = id_str.split('-')
        iid = int(iid)

        #Transform vertex space
        metadata_path = os.path.join('./datasets/shapenet/ShapeNetV1processed',
                                     sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RTs = metadata["extrinsics"]
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        mesh = meshes_pred[-1][0]
        #For some strange reason all classes (except the vehicle class) require
        #a 90 degree rotation about the y-axis for the GT mesh
        invRT = torch.inverse(RTs[iid].mm(rot_y_90))
        invRT_no_rot = torch.inverse(RTs[iid])
        mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT.cpu())

        #Get look at view extrinsics
        render_metadata_path = os.path.join(
            'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid,
            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics']

        plt.figure(figsize=(10, 10))
        R = render_RTs[iid][:3, :3].unsqueeze(0)
        T = render_RTs[iid][:3, 3].unsqueeze(0)
        cameras = OpenGLPerspectiveCameras(R=R, T=T)

        #Phong Renderer
        lights = PointLights(location=[[0.0, 0.0, -3.0]])
        raster_settings = RasterizationSettings(image_size=256,
                                                blur_radius=0.0,
                                                faces_per_pixel=1,
                                                bin_size=0)
        phong_renderer = MeshRenderer(rasterizer=MeshRasterizer(
            cameras=cameras, raster_settings=raster_settings),
                                      shader=HardPhongShader(lights=lights))

        #Silhouette Renderer
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=256,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        silhouette_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras,
                                      raster_settings=raster_settings),
            shader=SoftSilhouetteShader(blend_params=blend_params))

        verts, faces = mesh.get_mesh_verts_faces(0)
        verts_rgb = torch.ones_like(verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        mesh.textures = textures

        verts_rgb = torch.ones_like(gt_verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        #Invert without the rotation for the vehicle class
        if sid == '02958343':
            gt_verts = project_verts(gt_verts, invRT_no_rot.cpu())
        else:
            gt_verts = project_verts(gt_verts, invRT.cpu())
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

        img = image_to_numpy(deprocess(image[0]))
        mesh_image = phong_renderer(meshes_world=mesh, R=R, T=T)
        gt_silh_image = (silhouette_renderer(meshes_world=gt_mesh, R=R, T=T) >
                         0).float()
        silhouette_image = (silhouette_renderer(meshes_world=mesh, R=R, T=T) >
                            0).float()

        plt.subplot(2, 2, 1)
        plt.imshow(img)
        plt.title('input image')
        plt.subplot(2, 2, 2)
        plt.imshow(mesh_image[0, ..., :3].cpu().numpy())
        plt.title('rendered mesh')
        plt.subplot(2, 2, 3)
        plt.imshow(gt_silh_image[0, ..., 3].cpu().numpy())
        plt.title('silhouette of gt mesh')
        plt.subplot(2, 2, 4)
        plt.imshow(silhouette_image[0, ..., 3].cpu().numpy())
        plt.title('silhouette of rendered mesh')

        plt.show()
        #plt.savefig('./output_demo/figures/'+id_str+'.png')

        vis_utils.visualize_prediction(id_str, img, mesh, self.output_dir)
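
# Usage sketch for the CPU visualization class above (not part of the original
# source); names are hypothetical stand-ins as in the first example.
#
#     demo = VisualizationDemo(cfg, output_dir="./vis")
#     demo.run_on_image(image, id_str, gt_verts, gt_faces)
#     # Shows a 2x2 figure (input image, phong render, GT and predicted
#     # silhouettes) and saves the mesh via vis_utils.visualize_prediction.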