Example no. 1
def render_mesh(verts, faces):
    device = verts[0].get_device()
    N = len(verts)
    num_verts_per_mesh = []
    for i in range(N):
        num_verts_per_mesh.append(verts[i].shape[0])
    verts_rgb = torch.ones((N, np.max(num_verts_per_mesh), 3),
                           requires_grad=False,
                           device=device)
    for i in range(N):
        verts_rgb[i, num_verts_per_mesh[i]:, :] = -1
    textures = Textures(verts_rgb=verts_rgb)

    meshes = Meshes(verts=verts, faces=faces, textures=textures)
    elev = torch.rand(N) * 30 - 15
    azim = torch.rand(N) * 360 - 180
    R, T = look_at_view_transform(dist=2, elev=elev, azim=azim)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
    sigma = 1e-4
    raster_settings = RasterizationSettings(
        image_size=128,
        blur_radius=np.log(1. / 1e-4 - 1.) * sigma,
        faces_per_pixel=40,
        perspective_correct=False)
    renderer = MeshRenderer(rasterizer=MeshRasterizer(
        cameras=cameras, raster_settings=raster_settings),
                            shader=SoftSilhouetteShader())
    return renderer(meshes)
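A hedged usage sketch for the function above (the vertex/face tensor names below are illustrative, not from the original): SoftSilhouetteShader writes the soft mask into the alpha channel of its (N, H, W, 4) output.

# Hypothetical call; verts0/verts1 and faces0/faces1 are CUDA tensors of
# shape (V_i, 3) and (F_i, 3) respectively.
images = render_mesh([verts0, verts1], [faces0, faces1])  # (2, 128, 128, 4)
silhouettes = images[..., 3]  # the alpha channel holds the soft silhouette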
Example no. 2
    def render_sil(self, meshes):
        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.cameras,
                raster_settings=self.text_raster_settings),
            shader=SoftSilhouetteShader(blend_params=self.blend_params))
        return self.renderer(meshes_world=meshes)
Example no. 3
    def __init__(self, meshes: Meshes, image_size=256, device='cuda'):
        """
        Initialization of MaskRenderer. The renderer is initialized with a predefined rasterizer and shader.
        A soft silhouette shader is used to compute the projection mask.

        :param meshes: A batch of meshes. pytorch3d.structures.Meshes. Dimension meshes \in R^N.
                        View https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/structures/meshes.py
                        for additional information.
                        In our case it is usually only one batch which is the template for a certain category.
        :param device: The device, on which the computation is done.
        :param image_size: Image size for the rasterization. Default is 256.
        """
        super(MaskRenderer, self).__init__()

        self.device = device
        self._meshes = meshes

        cameras = OpenGLOrthographicCameras(device=device)

        # parameter settings as of Pytorch3D Tutorial
        # (https://pytorch3d.org/tutorials/camera_position_optimization_with_differentiable_rendering)

        self._rasterizer = MeshRasterizer(
            cameras=cameras,
            raster_settings=RasterizationSettings(image_size=image_size))

        self._shader = SoftSilhouetteShader(
            blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
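The class above only stores a rasterizer and a shader; below is a minimal sketch of how they might be composed in a forward pass. This method is not part of the original snippet: PyTorch3D shaders take the rasterizer's Fragments plus the meshes and return an (N, H, W, 4) image whose alpha channel is the soft mask.

    def forward(self, meshes: Meshes) -> torch.Tensor:
        # Rasterize, then shade (sketch only, not from the original code).
        fragments = self._rasterizer(meshes)
        images = self._shader(fragments, meshes)
        return images[..., 3]  # soft silhouette / projection mask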
Example no. 4
    def _set_renderer(self):
        if self.cameras is None:
            raise ValueError('cameras is None in pytorch3D renderer!')

        rasterizer = MeshRasterizer(cameras=self.cameras,
                                    raster_settings=self.raster_settings)

        silhouette_shader = SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
        self.mask_renderer = MeshRenderer(rasterizer=rasterizer,
                                          shader=silhouette_shader)
Example no. 5
    def init_silhouette_renderer(self):
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=self.size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=100)

        silhouette_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=self.camera,
                                      raster_settings=raster_settings),
            shader=SoftSilhouetteShader(blend_params=blend_params))
        return silhouette_renderer
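For reference, the blur_radius formula used above (and in several other snippets here) follows the PyTorch3D tutorial convention: it appears to be chosen so that a face's sigmoid blending weight, roughly sigmoid(-dist / sigma), has decayed to the target value (1e-4 here) at the blur boundary. A small sanity check of that relationship, stated as an assumption:

import numpy as np

sigma = 1e-4  # BlendParams sigma used above
eps = 1e-4    # assumed target blending weight at the blur boundary
blur_radius = np.log(1. / eps - 1.) * sigma
# sigmoid(-blur_radius / sigma) recovers eps, i.e. faces beyond the blur
# radius contribute at most ~eps to a pixel's silhouette value.
assert np.isclose(1.0 / (1.0 + np.exp(blur_radius / sigma)), eps)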
Example no. 6
def define_camera(image_size=640,
                  image_height=480,
                  image_width=640,
                  fx=500,
                  fy=500,
                  cx=320,
                  cy=240,
                  device="cuda:0"):
    # define camera
    cameras = OpenGLRealPerspectiveCameras(
        focal_length=((fx, fy), ),  # Nx2
        principal_point=((cx, cy), ),  # Nx2
        x0=0,
        y0=0,
        w=image_size,
        h=image_size,
        znear=0.0001,
        zfar=100.0,
        device=device)

    # We will also create a phong renderer. This is simpler and only needs to render one face per pixel.
    phong_raster_settings = RasterizationSettings(image_size=image_size,
                                                  blur_radius=0.0,
                                                  faces_per_pixel=1)

    # We can add a point light in front of the object.
    # lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
    phong_renderer = MeshRendererDepth(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=phong_raster_settings),
        shader=TexturedSoftPhongShader(device=device)).to(device)

    # Set the blending parameters which control the opacity and the sharpness of edges. Refer to blending.py for more details.
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

    # Define the settings for rasterization and shading. The output image is image_size x image_size; only the nearest face per pixel is kept here (the soft blur_radius from the tutorial is left commented out). Refer to rasterize_meshes.py for an explanation of these parameters.
    silhouette_raster_settings = RasterizationSettings(
        image_size=image_size,  # longer side or scaled longer side
        # blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
        blur_radius=0.0,
        # The nearest faces_per_pixel points along the z-axis.
        faces_per_pixel=1)

    # Create a silhouette mesh renderer by composing a rasterizer and a shader
    silhouette_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=silhouette_raster_settings),
        shader=SoftSilhouetteShader(blend_params=blend_params)).to(device)
    return phong_renderer, silhouette_renderer
Example no. 7
def silhouette_renderer(img_size: tuple, device: str):

    blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

    raster_settings = RasterizationSettings(
        image_size=img_size,
        blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
        faces_per_pixel=100,
        perspective_correct=False)

    # Create a silhouette mesh renderer by composing a rasterizer and a shader.
    silhouette_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(raster_settings=raster_settings),
        shader=SoftSilhouetteShader(blend_params=blend_params),
    )

    return silhouette_renderer
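Note that the rasterizer above is built without cameras, so they have to be supplied when the renderer is called. A hedged usage sketch (the camera, mesh and R/T names are illustrative):

renderer = silhouette_renderer((256, 256), "cuda")
cameras = FoVPerspectiveCameras(device="cuda", R=R, T=T)
# PyTorch3D renderers accept cameras as a call-time keyword argument.
silhouettes = renderer(meshes, cameras=cameras)[..., 3]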
Example no. 8
  def __init__(self, opt):
    super(PtRender, self).__init__()
    self.opt = opt
    self.input_res = opt.input_res
    model_path = os.path.join(os.path.dirname(__file__), '..', opt.BFM)
    self.BFM = BFM(model_path)

    f = 1015.
    self.f = f
    c = self.input_res / 2

    K = [[f,  0., c],
         [0., f,  c],
         [0., 0., 1.]]
    self.register_buffer('K', torch.FloatTensor(K))
    self.register_buffer('inv_K', torch.inverse(self.K).unsqueeze(0))
    self.K = self.K.unsqueeze(0)
    self.set_Illu_consts()

    # for pytorch3d
    self.t = torch.zeros([1, 3], dtype=torch.float32)
    self.pt = torch.zeros([1, 2], dtype=torch.float32)
    self.fl = f * 2 / self.input_res,
    ptR = [[[-1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]]]
    self.ptR = torch.FloatTensor(ptR)

    blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0, 0, 0))
    raster_settings = RasterizationSettings(
      image_size=self.input_res,
      blur_radius=0,
      faces_per_pixel=1,
      max_faces_per_bin=1000000,
    )

    # renderer
    cameras = SfMPerspectiveCameras(focal_length=self.fl,
                                    R=self.ptR.expand(opt.batch_size, -1, -1),
                                    device='cuda')
    rasterizer = MeshRasterizer(raster_settings=raster_settings)
    shader_rgb = HardFlatShader(blend_params=blend_params)

    self.renderer = Renderer(rasterizer, shader_rgb, SoftSilhouetteShader(), cameras)
Example no. 9
    def _set_renderer(self):
        if self.cameras is None:
            raise ValueError('cameras is None in pytorch3D renderer!')

        rasterizer = MeshRasterizer(cameras=self.cameras,
                                    raster_settings=self.raster_settings)

        texture_shader = TexturedSoftPhongShader(device=self.device,
                                                 cameras=self.cameras,
                                                 lights=self.lights)

        silhouette_shader = SoftSilhouetteShader(
            blend_params=BlendParams(sigma=1e-4, gamma=1e-4))

        self.mesh_renderer = MeshRenderer(rasterizer=rasterizer,
                                          shader=texture_shader)

        self.mask_renderer = MeshRenderer(rasterizer=rasterizer,
                                          shader=silhouette_shader)
Example no. 10
    def __init__(self, image_size, device):
        super(Renderer, self).__init__()

        self.image_size = image_size
        R, T = look_at_view_transform(2.7, 0, 0, device=device) 
        self.cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        self.mesh_color = torch.FloatTensor(config.MESH_COLOR).to(device)[None, None, :] / 255.0

        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=self.image_size, 
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
            faces_per_pixel=100, 
        )

        self.silhouette_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.cameras, 
                raster_settings=raster_settings
            ),
            shader=SoftSilhouetteShader(blend_params=blend_params)
        )

        raster_settings_color = RasterizationSettings(
            image_size=self.image_size, 
            blur_radius=0.0, 
            faces_per_pixel=1, 
        )
        
        lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])

        self.color_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.cameras, 
                raster_settings=raster_settings_color
            ),
            shader=HardPhongShader(
                device=device, 
                cameras=self.cameras,
                lights=lights,
            )
        )
Example no. 11
def render_mesh(mesh, R, T, device, img_size=512, silhouette=False):
    cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)


    if silhouette:
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=img_size, 
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
            faces_per_pixel=100, 
        )
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras, 
                raster_settings=raster_settings
            ),
            shader=SoftSilhouetteShader(blend_params=blend_params)
        )
    else:
        raster_settings = RasterizationSettings(
            image_size=img_size, 
            blur_radius=0.0, 
            faces_per_pixel=1, 
        )
        lights = PointLights(device=device, location=[[0.0, 5.0, -10.0]])
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras, 
                raster_settings=raster_settings
            ),
            shader=SoftPhongShader(
                device=device, 
                cameras=cameras,
                lights=lights
            )
        )

    rendered_images = renderer(mesh, cameras=cameras)
    return rendered_images
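A hedged usage sketch for render_mesh above (mesh and device are assumed to exist; the viewpoint values are illustrative):

R, T = look_at_view_transform(dist=2.7, elev=10.0, azim=30.0, device=device)
masks = render_mesh(mesh, R, T, device, img_size=256, silhouette=True)[..., 3]
rgb = render_mesh(mesh, R, T, device, img_size=256)[..., :3]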
Example no. 12
    def __init__(self, meshes: Meshes, image_size=256):
        """
        Initialization of the Renderer class. Instances of the mask and depth renderer are created on the
        corresponding device.

        :param device: The device, on which the computation is done.
        :param image_size: Image size for the rasterization. Default is 256.
        """
        super().__init__()
        self.meshes = meshes
        device = meshes.device

        # TODO: check how to implement weak perspective (scaled orthographic).
        cameras = OpenGLOrthographicCameras(device=device)

        self._rasterizer = MeshRasterizer(
            cameras=cameras,
            raster_settings=RasterizationSettings(image_size=image_size,
                                                  faces_per_pixel=100))

        self._shader = SoftSilhouetteShader(
            blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
Example no. 13
    def __init__(self, meshes: Meshes, device: str, image_size: int = 256):
        """
        Initialization of DepthRenderer. The default mesh rasterizer and a silhouette
        shader are used for simplicity.

        :param device: The device on which the computation is done, e.g. cpu or cuda.
        :param image_size: Image size for the rasterization. Default is 256.
        :param meshes: A batch of meshes (pytorch3d.structures.Meshes). Dimension meshes in R^N.
             View https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/structures/meshes.py
             for additional information.
        """
        super(DepthRenderer, self).__init__()
        self._meshes = meshes

        # TODO: check how to implement weak perspective (scaled orthographic).
        cameras = OpenGLOrthographicCameras(device=device)

        raster_settings = RasterizationSettings(image_size=image_size)
        self._rasterizer = MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings)

        self._shader = SoftSilhouetteShader(
            blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
Example no. 14
    def __init__(self,
                 batch_size=100,
                 latent_size=128,
                 img_size=128,
                 seed_sphere_divisions=3):
        super(Generator, self).__init__()
        src_mesh = ico_sphere(seed_sphere_divisions)
        output_shape = src_mesh.verts_packed().shape
        self.src_meshes = src_mesh.extend(batch_size).cuda()
        self.layers = LinearNet(latent_size, output_shape).cuda()
        """ Setup rendering. """
        num_views = batch_size
        self.lights = PointLights(location=[[0.0, 0.0, -3.0]], device=device)

        sigma = 1e-4
        raster_settings_silhouette = RasterizationSettings(
            image_size=128,
            blur_radius=np.log(1. / 1e-4 - 1.) * sigma,
            faces_per_pixel=50,
        )
        self.renderer_silhouette = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=None, raster_settings=raster_settings_silhouette),
            shader=SoftSilhouetteShader()).cuda()
Example no. 15
    def initRender(self, method, image_size):
        cameras = OpenGLPerspectiveCameras(device=self.device, fov=15)

        if (method == "soft-silhouette"):
            blend_params = BlendParams(sigma=1e-7, gamma=1e-7)

            raster_settings = RasterizationSettings(
                image_size=image_size,
                blur_radius=np.log(1. / 1e-7 - 1.) * blend_params.sigma,
                faces_per_pixel=self.faces_per_pixel)

            renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings),
                shader=SoftSilhouetteShader(blend_params=blend_params))
        elif (method == "hard-silhouette"):
            blend_params = BlendParams(sigma=1e-7, gamma=1e-7)

            raster_settings = RasterizationSettings(
                image_size=image_size,
                blur_radius=np.log(1. / 1e-7 - 1.) * blend_params.sigma,
                faces_per_pixel=1)

            renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings),
                shader=SoftSilhouetteShader(blend_params=blend_params))
        elif (method == "soft-depth"):
            # Soft Rasterizer - from https://github.com/facebookresearch/pytorch3d/issues/95
            #blend_params = BlendParams(sigma=1e-7, gamma=1e-7)
            blend_params = BlendParams(sigma=1e-3, gamma=1e-4)
            raster_settings = RasterizationSettings(
                image_size=image_size,
                #blur_radius= np.log(1. / 1e-7 - 1.) * blend_params.sigma,
                blur_radius=np.log(1. / 1e-3 - 1.) * blend_params.sigma,
                faces_per_pixel=self.faces_per_pixel)

            renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings),
                shader=SoftDepthShader(blend_params=blend_params))
        elif (method == "hard-depth"):
            raster_settings = RasterizationSettings(image_size=image_size,
                                                    blur_radius=0,
                                                    faces_per_pixel=20)

            renderer = MeshRenderer(rasterizer=MeshRasterizer(
                cameras=cameras, raster_settings=raster_settings),
                                    shader=HardDepthShader())
        elif (method == "blurry-depth"):
            # Soft Rasterizer - from https://github.com/facebookresearch/pytorch3d/issues/95
            blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
            raster_settings = RasterizationSettings(
                image_size=image_size,
                blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
                faces_per_pixel=self.faces_per_pixel)

            renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings),
                shader=SoftDepthShader(blend_params=blend_params))
        elif (method == "soft-phong"):
            blend_params = BlendParams(sigma=1e-3, gamma=1e-3)

            raster_settings = RasterizationSettings(
                image_size=image_size,
                blur_radius=np.log(1. / 1e-3 - 1.) * blend_params.sigma,
                faces_per_pixel=self.faces_per_pixel)

            # lights = DirectionalLights(device=self.device,
            #                            ambient_color=[[0.25, 0.25, 0.25]],
            #                            diffuse_color=[[0.6, 0.6, 0.6]],
            #                            specular_color=[[0.15, 0.15, 0.15]],
            #                            direction=[[0.0, 1.0, 0.0]])

            lights = DirectionalLights(device=self.device,
                                       direction=[[0.0, 1.0, 0.0]])

            renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings),
                shader=SoftPhongShader(device=self.device,
                                       blend_params=blend_params,
                                       lights=lights))

        elif (method == "hard-phong"):
            blend_params = BlendParams(sigma=1e-8, gamma=1e-8)

            raster_settings = RasterizationSettings(image_size=image_size,
                                                    blur_radius=0.0,
                                                    faces_per_pixel=1)

            lights = DirectionalLights(device=self.device,
                                       ambient_color=[[0.25, 0.25, 0.25]],
                                       diffuse_color=[[0.6, 0.6, 0.6]],
                                       specular_color=[[0.15, 0.15, 0.15]],
                                       direction=[[-1.0, -1.0, 1.0]])
            renderer = MeshRenderer(rasterizer=MeshRasterizer(
                cameras=cameras, raster_settings=raster_settings),
                                    shader=HardPhongShader(device=self.device,
                                                           lights=lights))

        else:
            print("Unknown render method!")
            return None
        return renderer
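A hedged usage sketch for the method above (obj stands for an instance of the surrounding class, which is assumed to define device and faces_per_pixel):

renderer = obj.initRender(method="soft-silhouette", image_size=128)
silhouettes = renderer(meshes)[..., 3]  # meshes: a pytorch3d Meshes batch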
Example no. 16
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device,
                  loss_fn):

    #if comm.is_main_process():
    #    wandb.init(project='MeshRCNN', config=cfg, name='prediction_module')

    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)

    # Zhengyuan modification
    loss_predictor = LossPredictionModule().to(device)
    loss_pred_optim = torch.optim.Adam(loss_predictor.parameters(), lr=1e-5)

    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" %
                        (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)

        # Config settings for renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).float().to(device)

        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()

            batch = loaders["train"].postprocess(batch, device)
            if dataset == 'MeshVoxMulti':
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt, id_strs, _, render_RTs, RTs = batch
            else:
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch

            with inference_context(model):
                # NOTE: _imgs contains all of the other images belonging to this model
                # We have to select the next-best-view from that list of images

                model_kwargs = {}
                if cfg.MODEL.VOXEL_ON and cp.t < cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS:
                    model_kwargs["voxel_only"] = True
                with Timer("Forward"):
                    voxel_scores, meshes_pred = model(imgs, **model_kwargs)

            total_silh_loss = torch.tensor(
                0.)  # Total silhouette loss, to be added to "loss" below
            # Voxel only training for first few iterations
            if meshes_gt is not None and not model_kwargs.get(
                    "voxel_only", False):
                _meshes_pred = meshes_pred[-1].clone()
                _meshes_gt = meshes_gt[-1].clone()

                # Render masks from predicted mesh for each view
                # GT probability map to supervise prediction module
                B = len(meshes_gt)
                probability_map = 0.01 * torch.ones(
                    (B, 24)).to(device)  # batch size x 24
                viewgrid = torch.zeros(
                    (B, 24, render_image_size,
                     render_image_size)).to(device)  # batch size x 24 x H x W
                for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(
                        zip(meshes_gt, _meshes_pred)):
                    # Maybe computationally expensive, but need to transform back to world space based on rendered image viewpoint
                    RT = RTs[b]
                    # Rotate 90 degrees about y-axis and invert
                    invRT = torch.inverse(RT.mm(rot_y_90))
                    invRT_no_rot = torch.inverse(RT)  # Just invert

                    cur_pred_mesh._verts_list[0] = project_verts(
                        cur_pred_mesh._verts_list[0], invRT)
                    sid = id_strs[b].split('-')[0]

                    # For some strange reason all classes (except the vehicle class) require a 90 degree rotation about the y-axis
                    if sid == '02958343':
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT_no_rot)
                    else:
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT)

                    for iid in range(len(render_RTs[b])):

                        R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                        T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                        cameras = OpenGLPerspectiveCameras(device=device,
                                                           R=R,
                                                           T=T)
                        silhouette_renderer = MeshRenderer(
                            rasterizer=MeshRasterizer(
                                cameras=cameras,
                                raster_settings=raster_settings),
                            shader=SoftSilhouetteShader(
                                blend_params=blend_params))

                        ref_image = (silhouette_renderer(
                            meshes_world=cur_gt_mesh, R=R, T=T) > 0).float()
                        image = (silhouette_renderer(
                            meshes_world=cur_pred_mesh, R=R, T=T) > 0).float()

                        #Add image silhouette to viewgrid
                        viewgrid[b, iid] = image[..., -1]
                        '''
                        import matplotlib.pyplot as plt
                        plt.subplot(1,2,1)
                        plt.imshow(ref_image[0,:,:,3].detach().cpu().numpy())
                        plt.subplot(1,2,2)
                        plt.imshow(image[0,:,:,3].detach().cpu().numpy())
                        plt.show()
                        '''

                        # MSE Loss between both silhouettes
                        silh_loss = torch.sum(
                            (image[0, :, :, 3] - ref_image[0, :, :, 3])**2)
                        probability_map[b, iid] = silh_loss.detach()

                        total_silh_loss += silh_loss

                probability_map = probability_map / (torch.max(
                    probability_map, dim=1)[0].unsqueeze(1))  # Normalize

                probability_map = torch.nn.functional.softmax(
                    probability_map, dim=1).to(device)  # Softmax across images
                #nbv_idx = torch.argmax(probability_map, dim=1)  # Next-best view indices
                #nbv_imgs = _imgs[torch.arange(B), nbv_idx]  # Next-best view images

                # NOTE: Do a second forward pass through the model? This time for multi-view reconstruction
                # The input should be the first image and the next-best view
                #voxel_scores, meshes_pred = model(nbv_imgs, **model_kwargs)

                # Zhengyuan step loss_prediction
                predictor_loss = loss_predictor.train_batch(
                    viewgrid, probability_map, loss_pred_optim)
                if comm.is_main_process():
                    #wandb.log({'prediction module loss':predictor_loss})

                    if cp.t % 50 == 0:
                        print('{} predictor_loss: {}'.format(
                            cp.t, predictor_loss))

                    #Save checkpoint every t iteration
                    if cp.t % 500 == 0:
                        print(
                            'Saving loss prediction module at iter {}'.format(
                                cp.t))
                        os.makedirs('./output_prediction_module',
                                    exist_ok=True)
                        torch.save(
                            loss_predictor.state_dict(),
                            './output_prediction_module/prediction_module_' +
                            str(cp.t) + '.pth')

            cp.step()

            if cp.t % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
                eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()
    eval_and_save(model, loaders, optimizer, scheduler, cp)

    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(cfg, dataset, "test", multigpu=False)
        evaluate_test(model, test_loader)
Example no. 17
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Define the settings for rasterization and shading. Here we set the output image to be of size
# 256x256. To form the blended image we use 100 faces for each pixel. Refer to rasterize_meshes.py
# for an explanation of this parameter.
raster_settings = RasterizationSettings(image_size=256,
                                        blur_radius=np.log(1. / 1e-4 - 1.) *
                                        blend_params.sigma,
                                        faces_per_pixel=100,
                                        bin_size=0)

# Create a silhouette mesh renderer by composing a rasterizer and a shader.
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(cameras=cameras,
                              raster_settings=raster_settings),
    shader=SoftSilhouetteShader(blend_params=blend_params))

# We will also create a phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(image_size=256,
                                        blur_radius=0.0,
                                        faces_per_pixel=1,
                                        bin_size=0)
# We can add a point light in front of the object.
lights = PointLights(device=device, location=((2.0, 2.0, -2.0), ))
phong_renderer = MeshRenderer(rasterizer=MeshRasterizer(
    cameras=cameras, raster_settings=raster_settings),
                              shader=HardPhongShader(device=device,
                                                     lights=lights))

# Select the viewpoint using spherical angles
viewpoint = [1.5, 240.0, 10.0]  # distance, elevation, azimuth
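A hedged sketch of how the spherical viewpoint above could be turned into camera extrinsics and rendered; mesh and device are assumed from earlier in the original script, and look_at_view_transform is assumed to be imported from pytorch3d.renderer.

distance, elevation, azimuth = viewpoint
R, T = look_at_view_transform(dist=distance, elev=elevation, azim=azimuth,
                              device=device)
silhouette = silhouette_renderer(meshes_world=mesh, R=R, T=T)  # (1, 256, 256, 4)
image_ref = phong_renderer(meshes_world=mesh, R=R, T=T)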
Example no. 18
def main():
    # Set the cuda device
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    mesh = load_objs_as_meshes(obj_paths, device=device)
    texture_image = mesh.textures.maps_padded()

    cameras = OpenGLRealPerspectiveCameras(
        focal_length=((K[0, 0], K[1, 1]), ),  # Nx2
        principal_point=((K[0, 2], K[1, 2]), ),  # Nx2
        x0=0,
        y0=0,
        w=H,
        h=H,  # HEIGHT,
        znear=ZNEAR,
        zfar=ZFAR,
        device=device,
    )

    # To blend the 100 faces we set a few parameters which control the opacity and the sharpness of
    # edges. Refer to blending.py for more details.
    blend_params = BlendParams(sigma=1e-4,
                               gamma=1e-4,
                               background_color=(0.0, 0.0, 0.0))

    # Define the settings for rasterization and shading. Here we set the output image to be of size
    # 640x640. To form the blended image we use 100 faces for each pixel. Refer to rasterize_meshes.py
    # for an explanation of this parameter.
    silhouette_raster_settings = RasterizationSettings(
        image_size=IMG_SIZE,  # longer side or scaled longer side
        blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
        faces_per_pixel=100,  # the nearest faces_per_pixel points along the z-axis.
        bin_size=0,
    )
    # Create a silhouette mesh renderer by composing a rasterizer and a shader.
    silhouette_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=silhouette_raster_settings),
        shader=SoftSilhouetteShader(blend_params=blend_params),
    )
    # We will also create a phong renderer. This is simpler and only needs to render one face per pixel.
    phong_raster_settings = RasterizationSettings(image_size=IMG_SIZE,
                                                  blur_radius=0.0,
                                                  faces_per_pixel=1,
                                                  bin_size=0)
    # We can add a point light in front of the object.
    lights = PointLights(device=device, location=((2.0, 2.0, -2.0), ))
    phong_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=phong_raster_settings),
        shader=HardPhongShader(device=device,
                               cameras=cameras,
                               lights=lights,
                               blend_params=blend_params),
    )

    batch_R = torch.tensor(np.stack([R]), device=device,
                           dtype=torch.float32).permute(0, 2, 1)  # Bx3x3
    batch_T = torch.tensor(np.stack([t]), device=device,
                           dtype=torch.float32)  # Bx3

    silhouette = silhouette_renderer(meshes_world=mesh, R=batch_R, T=batch_T)
    image_ref = phong_renderer(meshes_world=mesh, R=batch_R, T=batch_T)
    # crop results
    silhouette = silhouette[:, :H, :W, :].cpu().numpy()
    image_ref = image_ref[:, :H, :W, :3].cpu().numpy()

    pred_images = image_ref

    opengl = mmcv.imread(osp.join("Render/OpenGL.png"), "color") / 255.0

    for i in range(pred_images.shape[0]):
        pred_mask = silhouette[i, :, :, 3].astype("float32")

        print("num rendered images", pred_images.shape[0])
        image = pred_images[i]

        diff_opengl = np.abs(opengl[:, :, ::-1].astype("float32") -
                             image.astype("float32"))
        print("image", image.shape, image.min(), image.max())

        print("dr mask area: ", pred_mask.sum())

        print_stat(pred_mask, "pred_mask")
        show_ims = [image, diff_opengl, opengl[:, :, ::-1]]
        show_titles = ["image", "diff_opengl", "opengl"]
        grid_show(show_ims, show_titles, row=1, col=3)
Example no. 19
    def run_on_image(self, image, id_str, gt_verts, gt_faces):
        deprocess = imagenet_deprocess(rescale_image=False)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(image.to(self.device))

        sid, mid, iid = id_str.split('-')
        iid = int(iid)

        #Transform vertex space
        metadata_path = os.path.join('./datasets/shapenet/ShapeNetV1processed',
                                     sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RTs = metadata["extrinsics"].to(self.device)
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        mesh = meshes_pred[-1][0]
        #For some strange reason all classes (except the vehicle class) require a 90 degree rotation about the y-axis
        #for the GT mesh
        invRT = torch.inverse(RTs[iid].mm(rot_y_90))
        invRT_no_rot = torch.inverse(RTs[iid])
        mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT)

        #Get look at view extrinsics
        render_metadata_path = os.path.join(
            'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid,
            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics'].to(self.device)

        verts, faces = mesh.get_mesh_verts_faces(0)
        verts_rgb = torch.ones_like(verts)[None]
        textures = Textures(verts_rgb=verts_rgb.to(self.device))
        mesh.textures = textures

        plt.figure(figsize=(10, 10))

        #Silhouette Renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )

        gt_verts = gt_verts.to(self.device)
        gt_faces = gt_faces.to(self.device)
        verts_rgb = torch.ones_like(gt_verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        #Invert without the rotation for the vehicle class
        if sid == '02958343':
            gt_verts = project_verts(gt_verts, invRT_no_rot.to(self.device))
        else:
            gt_verts = project_verts(gt_verts, invRT.to(self.device))
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

        probability_map = 0.01 * torch.ones((1, 24))
        viewgrid = torch.zeros(
            (1, 24, render_image_size, render_image_size)).to(self.device)
        fig = plt.figure(1)
        ax_pred = [fig.add_subplot(5, 5, i + 1) for i in range(24)]
        #fig = plt.figure(2)
        #ax_gt = [fig.add_subplot(5,5,i+1) for i in range(24)]

        for i in range(len(render_RTs)):
            if i == iid:  #Don't include current view
                continue

            R = render_RTs[i][:3, :3].unsqueeze(0)
            T = render_RTs[i][:3, 3].unsqueeze(0)
            cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)

            silhouette_renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings),
                shader=SoftSilhouetteShader(blend_params=blend_params))

            ref_image = (silhouette_renderer(meshes_world=gt_mesh, R=R, T=T) >
                         0).float()
            silhouette_image = (silhouette_renderer(
                meshes_world=mesh, R=R, T=T) > 0).float()

            # MSE Loss between both silhouettes
            silh_loss = torch.sum(
                (silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3])**2)
            probability_map[0, i] = silh_loss.detach()

            viewgrid[0, i] = silhouette_image[..., -1]

            #ax_gt[i].imshow(ref_image[0,:,:,3].cpu().numpy())
            #ax_gt[i].set_title(i)

            ax_pred[i].imshow(silhouette_image[0, :, :, 3].cpu().numpy())
            ax_pred[i].set_title(i)

        img = image_to_numpy(deprocess(image[0]))
        #ax_gt[iid].imshow(img)
        ax_pred[iid].imshow(img)
        #fig = plt.figure(3)
        #ax = fig.add_subplot(111)
        #ax.imshow(img)

        pred_prob_map = self.loss_predictor(viewgrid)
        print('Highest actual loss: {}'.format(torch.argmax(probability_map)))
        print('Highest predicted loss: {}'.format(torch.argmax(pred_prob_map)))
        plt.show()
Example no. 20
    def init_differential_renderer(self):

        distance = 0.3
        R, T = look_at_view_transform(distance, 0, 0)
        cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)
        raster_settings = RasterizationSettings(image_size=self.opt.crop_size,
                                                blur_radius=0.0,
                                                faces_per_pixel=1,
                                                perspective_correct=True,
                                                cull_backfaces=True)
        silhouette_raster_settings = RasterizationSettings(
            image_size=self.opt.crop_size,
            blur_radius=0.0,
            faces_per_pixel=1,
            perspective_correct=True,
        )
        # Change specular color to green and change material shininess
        self.materials = Materials(
            device=self.device,
            ambient_color=[[1.0, 1.0, 1.0]],
            specular_color=[[0.0, 0.0, 0.0]],
            diffuse_color=[[1.0, 1.0, 1.0]],
        )
        bp = BlendParams(background_color=(0, 0, 0))  # black
        # bp = BlendParams(background_color=(1, 1, 1))  # white is default

        lights = PointLights(device=self.device, location=((0.0, 0.0, 2.0), ))

        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras,
                                      raster_settings=raster_settings),
            shader=TexturedSoftPhongShader(
                # blend_params=bp,
                device=self.device,
                lights=lights,
                cameras=cameras,
            ))
        import cv2

        # segmentation_texture_map = cv2.imread(str(Path('resources') / 'part_segmentation_map_2048_gray_n_h.png'))[...,
        segmentation_texture_map = cv2.imread(
            str(Path('resources') /
                'Color_Map_Sag_symmetric.png'))[..., ::-1].astype(np.uint8)
        import matplotlib.pyplot as plt
        plt.imshow(segmentation_texture_map)
        plt.show()

        segmentation_texture_map = (torch.from_numpy(
            np.array(segmentation_texture_map))).unsqueeze(0).float()
        self.segmentation_3d_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras,
                                      raster_settings=raster_settings),
            shader=UVsCorrespondenceShader(blend_params=bp,
                                           device=self.device,
                                           cameras=cameras,
                                           colormap=segmentation_texture_map))

        # Create a silhouette mesh renderer by composing a rasterizer and a shader.
        self.silhouette_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras, raster_settings=silhouette_raster_settings),
            shader=SoftSilhouetteShader(
                blend_params=BlendParams(sigma=1e-10, gamma=1e-4)))
        self.negative_silhouette_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras,
                                      raster_settings=raster_settings),
            shader=SoftSilhouetteShader(
                blend_params=BlendParams(sigma=1e-10, gamma=1e-4)))
        self.texture_data = np.load('smpl_model/texture_data.npy',
                                    allow_pickle=True,
                                    encoding='latin1').item()
        self.verts_uvs1 = torch.tensor(self.texture_data['vt'],
                                       dtype=torch.float32).unsqueeze(0).cuda(
                                           self.device)
        self.faces_uvs1 = torch.tensor(
            self.texture_data['ft'].astype(np.int64),
            dtype=torch.int64).unsqueeze(0).cuda(self.device)
Example no. 21
def batch_render(
    verts,
    faces,
    faces_per_pixel=10,
    K=None,
    rot=None,
    trans=None,
    colors=None,
    color=(0.53, 0.53, 0.8),  # light_purple
    ambient_col=0.5,
    specular_col=0.2,
    diffuse_col=0.3,
    face_colors=None,
    # color = (0.74117647, 0.85882353, 0.65098039),  # light_blue
    image_sizes=None,
    out_res=512,
    bin_size=0,
    shading="soft",
    mode="rgb",
    blend_gamma=1e-4,
    min_depth=None,
):
    device = torch.device("cuda:0")
    K = K.to(device)
    width, height = image_sizes[0]
    out_size = int(max(image_sizes[0]))
    raster_settings = RasterizationSettings(
        image_size=out_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        bin_size=bin_size,
    )

    fx = K[:, 0, 0]
    fy = K[:, 1, 1]
    focals = torch.stack([fx, fy], 1)
    px = K[:, 0, 2]
    py = K[:, 1, 2]
    principal_point = torch.stack([width - px, height - py], 1)
    if rot is None:
        rot = torch.eye(3).unsqueeze(0).to(device)
    if trans is None:
        trans = torch.zeros(3).unsqueeze(0).to(device)
    cameras = PerspectiveCameras(
        device=device,
        focal_length=focals,
        principal_point=principal_point,
        image_size=[(out_size, out_size) for _ in range(len(verts))],
        R=rot,
        T=trans,
    )
    if mode == "rgb":

        lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
        lights = DirectionalLights(
            device=device,
            direction=((0.6, -0.6, -0.6), ),
            ambient_color=((ambient_col, ambient_col, ambient_col), ),
            diffuse_color=((diffuse_col, diffuse_col, diffuse_col), ),
            specular_color=((specular_col, specular_col, specular_col), ),
        )
        if shading == "soft":
            shader = SoftPhongShader(device=device,
                                     cameras=cameras,
                                     lights=lights)
        elif shading == "hard":
            shader = HardPhongShader(device=device,
                                     cameras=cameras,
                                     lights=lights)
        else:
            raise ValueError(
                f"Shading {shading} for mode rgb not in [sort|hard]")
    elif mode == "silh":
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        shader = SoftSilhouetteShader(blend_params=blend_params)
    elif shading == "faceidx":
        shader = FaceIdxShader()
    elif (mode == "facecolor") and (shading == "hard"):
        shader = FaceColorShader(face_colors=face_colors)
    elif (mode == "facecolor") and (shading == "soft"):
        shader = SoftFaceColorShader(face_colors=face_colors,
                                     blend_gamma=blend_gamma)
    else:
        raise ValueError(
            f"Unhandled mode {mode} and shading {shading} combination")

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=raster_settings),
        shader=shader,
    )
    if min_depth is not None:
        verts = torch.cat([verts[:, :, :2], verts[:, :, 2:].clamp(min_depth)],
                          2)
    if mode == "rgb":
        if colors is None:
            colors = get_colors(verts, color)
        tex = textures.TexturesVertex(verts_features=colors)

        meshes = Meshes(verts=verts, faces=faces, textures=tex)
    elif mode in ["silh", "facecolor"]:
        meshes = Meshes(verts=verts, faces=faces)
    else:
        raise ValueError(f"Render mode {mode} not in [rgb|silh]")

    square_images = renderer(meshes, cameras=cameras)
    square_images = torch.flip(square_images, (1, 2))
    height_off = abs(int(width - height))
    if width > height:
        images = square_images[:, height_off:, :]
    else:
        images = square_images[:, :, height_off:]
    return images
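A hedged call sketch for batch_render above; the tensor shapes follow the function signature and the names are illustrative. Inputs are assumed to already live on the GPU, since only K is moved to the device inside the function.

# verts: (B, V, 3), faces: (B, F, 3), K: (B, 3, 3) pinhole intrinsics.
silh = batch_render(verts, faces, K=K, image_sizes=[(640, 480)], mode="silh")
rgb = batch_render(verts, faces, K=K, image_sizes=[(640, 480)], mode="rgb",
                   shading="soft")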
Example no. 22
    def run_on_image(self, image, id_str, gt_verts, gt_faces):
        deprocess = imagenet_deprocess(rescale_image=False)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(image)

        sid, mid, iid = id_str.split('-')
        iid = int(iid)

        #Transform vertex space
        metadata_path = os.path.join('./datasets/shapenet/ShapeNetV1processed',
                                     sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RTs = metadata["extrinsics"]
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        mesh = meshes_pred[-1][0]
        #For some strange reason all classes (except the vehicle class) require a 90 degree rotation about the y-axis
        #for the GT mesh
        invRT = torch.inverse(RTs[iid].mm(rot_y_90))
        invRT_no_rot = torch.inverse(RTs[iid])
        mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT.cpu())

        #Get look at view extrinsics
        render_metadata_path = os.path.join(
            'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid,
            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics']

        plt.figure(figsize=(10, 10))
        R = render_RTs[iid][:3, :3].unsqueeze(0)
        T = render_RTs[iid][:3, 3].unsqueeze(0)
        cameras = OpenGLPerspectiveCameras(R=R, T=T)

        #Phong Renderer
        lights = PointLights(location=[[0.0, 0.0, -3.0]])
        raster_settings = RasterizationSettings(image_size=256,
                                                blur_radius=0.0,
                                                faces_per_pixel=1,
                                                bin_size=0)
        phong_renderer = MeshRenderer(rasterizer=MeshRasterizer(
            cameras=cameras, raster_settings=raster_settings),
                                      shader=HardPhongShader(cameras=cameras,
                                                             lights=lights))

        #Silhouette Renderer
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=256,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        silhouette_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras,
                                      raster_settings=raster_settings),
            shader=SoftSilhouetteShader(blend_params=blend_params))

        verts, faces = mesh.get_mesh_verts_faces(0)
        verts_rgb = torch.ones_like(verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        mesh.textures = textures

        verts_rgb = torch.ones_like(gt_verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        #Invert without the rotation for the vehicle class
        if sid == '02958343':
            gt_verts = project_verts(gt_verts, invRT_no_rot.cpu())
        else:
            gt_verts = project_verts(gt_verts, invRT.cpu())
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

        img = image_to_numpy(deprocess(image[0]))
        mesh_image = phong_renderer(meshes_world=mesh, R=R, T=T)
        gt_silh_image = (silhouette_renderer(meshes_world=gt_mesh, R=R, T=T) >
                         0).float()
        silhouette_image = (silhouette_renderer(meshes_world=mesh, R=R, T=T) >
                            0).float()

        plt.subplot(2, 2, 1)
        plt.imshow(img)
        plt.title('input image')
        plt.subplot(2, 2, 2)
        plt.imshow(mesh_image[0, ..., :3].cpu().numpy())
        plt.title('rendered mesh')
        plt.subplot(2, 2, 3)
        plt.imshow(gt_silh_image[0, ..., 3].cpu().numpy())
        plt.title('silhouette of gt mesh')
        plt.subplot(2, 2, 4)
        plt.imshow(silhouette_image[0, ..., 3].cpu().numpy())
        plt.title('silhouette of rendered mesh')

        plt.show()
        #plt.savefig('./output_demo/figures/'+id_str+'.png')

        vis_utils.visualize_prediction(id_str, img, mesh, self.output_dir)
Example no. 23
raster_settings_soft = RasterizationSettings(
    image_size=128,
    blur_radius=np.log(1. / 1e-4 - 1.) * sigma,
    faces_per_pixel=50,
)

renderer_textured = MeshRenderer(rasterizer=MeshRasterizer(
    cameras=camera, raster_settings=raster_settings_soft),
                                 shader=SoftPhongShader(device=device,
                                                        cameras=camera,
                                                        lights=lights))

# Silhouette renderer
renderer_silhouette = MeshRenderer(rasterizer=MeshRasterizer(
    cameras=camera, raster_settings=raster_settings_silhouette),
                                   shader=SoftSilhouetteShader())

# Render silhouette images.  The 4th channel (index 3) of the rendering
# output is the alpha/silhouette channel
silhouette_images = renderer_silhouette(meshes, cameras=cameras, lights=lights)
target_silhouette = [silhouette_images[i, ..., 3] for i in range(num_views)]

# Visualize silhouette images
image_grid(silhouette_images.cpu().numpy(), rows=4, cols=5, rgb=False)
plt.show()


def visualize_prediction(predicted_mesh,
                         renderer=renderer_silhouette,
                         target_image=target_rgb[1],
                         title='',
Example no. 24
def main_function(experiment_directory, continue_from, iterations,
                  marching_cubes_resolution, regularize):

    device = torch.device('cuda:0')
    specs = ws.load_experiment_specifications(experiment_directory)

    print("Reconstruction from experiment description: \n" +
          ' '.join([str(elem) for elem in specs["Description"]]))

    data_source = specs["DataSource"]
    test_split_file = specs["TestSplit"]

    arch_encoder = __import__("lib.models." + specs["NetworkEncoder"],
                              fromlist=["ResNet"])
    arch_decoder = __import__("lib.models." + specs["NetworkDecoder"],
                              fromlist=["DeepSDF"])
    latent_size = specs["CodeLength"]

    encoder = arch_encoder.ResNet(latent_size,
                                  specs["Depth"],
                                  norm_type=specs["NormType"]).cuda()
    decoder = arch_decoder.DeepSDF(latent_size, **specs["NetworkSpecs"]).cuda()

    encoder = torch.nn.DataParallel(encoder)
    decoder = torch.nn.DataParallel(decoder)

    print("testing with {} GPU(s)".format(torch.cuda.device_count()))

    num_samp_per_scene = specs["SamplesPerScene"]
    with open(test_split_file, "r") as f:
        test_split = json.load(f)

    sdf_dataset_test = lib.data.RGBA2SDF(data_source,
                                         test_split,
                                         num_samp_per_scene,
                                         is_train=False,
                                         num_views=specs["NumberOfViews"])
    torch.manual_seed(int(time.time() * 1000.0))
    sdf_loader_test = data_utils.DataLoader(
        sdf_dataset_test,
        batch_size=1,
        shuffle=True,
        num_workers=1,
        drop_last=False,
    )

    num_scenes = len(sdf_loader_test)
    print("There are {} scenes".format(num_scenes))

    print('Loading epoch "{}"'.format(continue_from))

    ws.load_model_parameters(experiment_directory, continue_from, encoder,
                             decoder)
    encoder.eval()

    optimization_meshes_dir = os.path.join(experiment_directory,
                                           ws.reconstructions_subdir,
                                           str(continue_from))

    if not os.path.isdir(optimization_meshes_dir):
        os.makedirs(optimization_meshes_dir)

    for sdf_data, image, intrinsic, extrinsic, name in sdf_loader_test:

        out_name = name[0].split("/")[-1]
        # store input stuff
        image_filename = os.path.join(optimization_meshes_dir, out_name,
                                      "input.png")
        # skip if it is already there
        if os.path.exists(os.path.dirname(image_filename)):
            print(name[0], " exists already ")
            continue
        print('Reconstructing {}...'.format(out_name))

        if not os.path.exists(os.path.dirname(image_filename)):
            os.makedirs(os.path.dirname(image_filename))

        image_export = 255 * image[0].permute(1, 2, 0).cpu().numpy()
        imageio.imwrite(image_filename, image_export.astype(np.uint8))

        image_filename = os.path.join(optimization_meshes_dir, out_name,
                                      "input_silhouette.png")
        image_export = 255 * image[0].permute(1, 2, 0).cpu().numpy()[..., 3]
        imageio.imwrite(image_filename, image_export.astype(np.uint8))

        # get latent code from image
        latent = encoder(image)
        # get estimated mesh
        verts, faces, samples, next_indices = lib.mesh.create_mesh(
            decoder, latent, N=marching_cubes_resolution, output_mesh=True)

        # store raw output
        mesh_filename = os.path.join(optimization_meshes_dir, out_name,
                                     "predicted.ply")
        lib.mesh.write_verts_faces_to_file(verts, faces, mesh_filename)

        verts_dr = torch.tensor(verts[None, :, :].copy(),
                                dtype=torch.float32,
                                requires_grad=False).cuda()
        faces_dr = torch.tensor(faces[None, :, :].copy()).cuda()

        IMG_SIZE = image.shape[-1]
        K_cuda = torch.tensor(intrinsic[:, 0:3, 0:3]).float().cuda()
        R_cuda = torch.tensor(extrinsic[:, 0:3,
                                        0:3]).float().cuda().permute(0, 2, 1)
        t_cuda = torch.tensor(extrinsic[:, 0:3, 3]).float().cuda()
        lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
        cameras = PerspectiveCameras(device=device,
                                     focal_length=-K_cuda[:, 0, 0] /
                                     K_cuda[:, 0, 2],
                                     image_size=((IMG_SIZE, IMG_SIZE), ),
                                     R=R_cuda,
                                     T=t_cuda)
        raster_settings = RasterizationSettings(
            image_size=IMG_SIZE,
            blur_radius=0.000001,
            faces_per_pixel=1,
        )
        raster_settings_soft = RasterizationSettings(
            image_size=IMG_SIZE,
            blur_radius=np.log(1. / 1e-4 - 1.) * 1e-5,
            faces_per_pixel=25,
        )
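        # The soft blur radius above follows the heuristic used in the PyTorch3D
        # tutorials: with sigmoid blending a face only influences pixels where
        # sigmoid(-d^2 / sigma) > eps, so blur_radius = log(1/eps - 1) * sigma
        # (here eps = 1e-4, sigma = 1e-5) bounds that region of influence.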

        # instantiate renderers
        silhouette_renderer = MeshRenderer(rasterizer=MeshRasterizer(
            cameras=cameras, raster_settings=raster_settings_soft),
                                           shader=SoftSilhouetteShader())
        depth_renderer = MeshRasterizer(cameras=cameras,
                                        raster_settings=raster_settings)
        renderer = Renderer(silhouette_renderer,
                            depth_renderer,
                            image_size=IMG_SIZE)

        meshes = Meshes(verts_dr, faces_dr)
        verts_shape = meshes.verts_packed().shape
        verts_rgb = torch.full([1, verts_shape[0], 3],
                               0.5,
                               device=device,
                               requires_grad=False)
        meshes.textures = TexturesVertex(verts_features=verts_rgb)

        with torch.no_grad():

            normal_out, silhouette_out = renderer(meshes_world=meshes,
                                                  cameras=cameras,
                                                  lights=lights)

            image_out_export = 255 * silhouette_out.detach().cpu().numpy()[0]
            image_out_filename = os.path.join(optimization_meshes_dir,
                                              out_name,
                                              "predicted_silhouette.png")
            imageio.imwrite(image_out_filename,
                            image_out_export.astype(np.uint8))

            image_out_export = 255 * normal_out.detach().cpu().numpy()[0]
            image_out_filename = os.path.join(optimization_meshes_dir,
                                              out_name, "predicted.png")
            imageio.imwrite(image_out_filename,
                            image_out_export.astype(np.uint8))

        # load ground truth mesh for metrics
        mesh_filename = os.path.join(
            data_source, name[0].replace("samples", "meshes") + ".obj")

        mesh = trimesh.load(mesh_filename)
        vertices = torch.tensor(mesh.vertices).float().cuda()
        faces = torch.tensor(mesh.faces).float().cuda()

        vertices = vertices.unsqueeze(0)
        faces = faces.unsqueeze(0)
        meshes_gt = Meshes(vertices, faces)

        with torch.no_grad():
            normal_tgt, _ = renderer(meshes_world=meshes_gt,
                                     cameras=cameras,
                                     lights=lights)

        latent_for_optim = latent.detach().clone().requires_grad_(True)
        lr = 5e-5
        optimizer = torch.optim.Adam([latent_for_optim], lr=lr)

        decoder.eval()
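        # Only the latent code is refined: the optimizer holds latent_for_optim
        # alone, so the decoder weights (kept in eval mode) are never updated.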

        log_silhouette = []
        log_latent = []
        log_chd = []
        log_nc = []

        for e in range(iterations + 1):

            optimizer.zero_grad()

            # first create mesh
            verts, faces, samples, next_indices = lib.mesh.create_mesh_optim_fast(
                samples,
                next_indices,
                decoder,
                latent_for_optim,
                N=marching_cubes_resolution)

            # now assemble loss function
            xyz_upstream = torch.tensor(verts.astype(float),
                                        requires_grad=True,
                                        dtype=torch.float32,
                                        device=device)
            faces_upstream = torch.tensor(faces.astype(float),
                                          requires_grad=False,
                                          dtype=torch.float32,
                                          device=device)

            meshes_dr = Meshes(xyz_upstream.unsqueeze(0),
                               faces_upstream.unsqueeze(0))
            verts_shape = meshes_dr.verts_packed().shape
            verts_rgb = torch.full([1, verts_shape[0], 3],
                                   0.5,
                                   device=device,
                                   requires_grad=False)
            meshes_dr.textures = TexturesVertex(verts_features=verts_rgb)

            normal, silhouette = renderer(meshes_world=meshes_dr,
                                          cameras=cameras,
                                          lights=lights)
            # compute loss
            loss_silhouette = (torch.abs(silhouette -
                                         image[:, 3].cuda())).mean()
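            # image[:, 3] is the alpha channel of the RGBA input, i.e. the target
            # silhouette that the soft rendering is compared against.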

            # now store upstream gradients
            loss_silhouette.backward()
            dL_dx_i = xyz_upstream.grad
            # take care of weird stuff possibly happening
            dL_dx_i[torch.isnan(dL_dx_i)] = 0

            # log stuff
            with torch.no_grad():
                log_silhouette.append(loss_silhouette.detach().cpu().numpy())

                meshes_gt_pts = sample_points_from_meshes(meshes_gt)
                meshes_dr_pts = sample_points_from_meshes(meshes_dr)
                metric_chd, _ = chamfer_distance(meshes_gt_pts, meshes_dr_pts)
                log_chd.append(metric_chd.detach().cpu().numpy())

                log_nc.append(compute_normal_consistency(normal_tgt, normal))

                log_latent.append(
                    torch.mean(
                        (latent_for_optim).pow(2)).detach().cpu().numpy())

            # use vertices to compute full backward pass
            optimizer.zero_grad()
            xyz = torch.tensor(verts.astype(float),
                               requires_grad=True,
                               dtype=torch.float32,
                               device=torch.device('cuda:0'))
            latent_inputs = latent_for_optim.expand(xyz.shape[0], -1)
            #first compute normals
            pred_sdf = decoder(latent_inputs, xyz)
            loss_normals = torch.sum(pred_sdf)
            loss_normals.backward(retain_graph=True)
            normals = xyz.grad / torch.norm(xyz.grad, 2, 1).unsqueeze(-1)
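            # Differentiable iso-surface step (MeshSDF-style): for a vertex on the
            # zero level set, increasing the SDF value s moves it along the inward
            # normal, dx/ds = -n, so dL/ds_i = -dL/dx_i . n_i. The dot product below
            # combines the upstream vertex gradients with the normalised SDF
            # gradients; multiplying by pred_sdf then routes these gradients to the
            # latent code.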
            # now assemble inflow derivative
            optimizer.zero_grad()
            dL_ds_i = -torch.matmul(dL_dx_i.unsqueeze(1),
                                    normals.unsqueeze(-1)).squeeze(-1)
            # finally assemble full backward pass
            loss_backward = torch.sum(
                dL_ds_i * pred_sdf) + regularize * torch.mean(
                    (latent_for_optim).pow(2))
            loss_backward.backward()
            # and update params
            optimizer.step()

        # store all
        with torch.no_grad():
            verts, faces, samples, next_indices = lib.mesh.create_mesh_optim_fast(
                samples,
                next_indices,
                decoder,
                latent_for_optim,
                N=marching_cubes_resolution)
            mesh_filename = os.path.join(optimization_meshes_dir, out_name,
                                         "refined.ply")
            lib.mesh.write_verts_faces_to_file(verts, faces, mesh_filename)
            xyz_upstream = torch.tensor(verts.astype(float),
                                        requires_grad=True,
                                        dtype=torch.float32,
                                        device=device)
            faces_upstream = torch.tensor(faces.astype(float),
                                          requires_grad=False,
                                          dtype=torch.float32,
                                          device=device)

            meshes_dr = Meshes(xyz_upstream.unsqueeze(0),
                               faces_upstream.unsqueeze(0))
            verts_shape = meshes_dr.verts_packed().shape
            verts_rgb = torch.full([1, verts_shape[0], 3],
                                   0.5,
                                   device=device,
                                   requires_grad=False)
            meshes_dr.textures = TexturesVertex(verts_features=verts_rgb)

            normal, silhouette = renderer(meshes_world=meshes_dr,
                                          cameras=cameras,
                                          lights=lights)

            image_out_export = 255 * silhouette.detach().cpu().numpy()[0]
            image_out_filename = os.path.join(optimization_meshes_dir,
                                              out_name,
                                              "refined_silhouette.png")
            imageio.imwrite(image_out_filename,
                            image_out_export.astype(np.uint8))
            image_out_export = 255 * normal.detach().cpu().numpy()[0]
            image_out_filename = os.path.join(optimization_meshes_dir,
                                              out_name, "refined.png")
            imageio.imwrite(image_out_filename,
                            image_out_export.astype(np.uint8))

        log_filename = os.path.join(optimization_meshes_dir, out_name,
                                    "log_silhouette.npy")
        np.save(log_filename, log_silhouette)
        log_filename = os.path.join(optimization_meshes_dir, out_name,
                                    "log_chd.npy")
        np.save(log_filename, log_chd)
        log_filename = os.path.join(optimization_meshes_dir, out_name,
                                    "log_nc.npy")
        np.save(log_filename, log_nc)

        log_filename = os.path.join(optimization_meshes_dir, out_name,
                                    "log_latent.npy")
        np.save(log_filename, log_latent)
        print('Done with refinement.')
        print('Improvement in CHD {:.2f} %'.format(
            100 * (log_chd[0] - log_chd[-1]) / log_chd[0]))
        print('Improvement in NC {:.2f} %'.format(
            100 * (log_nc[-1] - log_nc[0]) / log_nc[0]))
def generate_cow_renders(num_views: int = 40,
                         data_dir: str = DATA_DIR,
                         azimuth_range: float = 180):
    """
    This function generates `num_views` renders of a cow mesh.
    The renders are generated from viewpoints sampled at uniformly distributed
    azimuth intervals. The elevation is kept constant so that the camera's
    vertical position coincides with the equator.

    For a more detailed explanation of this code, please refer to the
    docs/tutorials/fit_textured_mesh.ipynb notebook.

    Args:
        num_views: The number of generated renders.
        data_dir: The folder that contains the cow mesh files. If the cow mesh
            files do not exist in the folder, this function will automatically
            download them.
        azimuth_range: Half-width, in degrees, of the azimuth interval over
            which the viewpoints are uniformly sampled.
    Returns:
        cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
            images are rendered.
        images: A tensor of shape `(num_views, height, width, 3)` containing
            the rendered images.
        silhouettes: A tensor of shape `(num_views, height, width)` containing
            the rendered silhouettes.
    """

    # set the paths

    # download the cow mesh if not done before
    cow_mesh_files = [
        os.path.join(data_dir, fl)
        for fl in ("cow.obj", "cow.mtl", "cow_texture.png")
    ]
    if any(not os.path.isfile(f) for f in cow_mesh_files):
        os.makedirs(data_dir, exist_ok=True)
        os.system(
            f"wget -P {data_dir} " +
            "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj")
        os.system(
            f"wget -P {data_dir} " +
            "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl")
        os.system(
            f"wget -P {data_dir} " +
            "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
        )

    # Setup
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    # Load obj file
    obj_filename = os.path.join(data_dir, "cow.obj")
    mesh = load_objs_as_meshes([obj_filename], device=device)

    # We scale, normalize and center the target mesh to fit in a sphere of radius 1
    # centered at (0,0,0). (scale, center) will be used to bring the predicted mesh
    # to its original center and scale.  Note that normalizing the target mesh
    # speeds up the optimization but is not necessary!
    verts = mesh.verts_packed()
    N = verts.shape[0]
    center = verts.mean(0)
    scale = max((verts - center).abs().max(0)[0])
    mesh.offset_verts_(-(center.expand(N, 3)))
    mesh.scale_verts_((1.0 / float(scale)))

    # Get a batch of viewing angles.
    elev = torch.linspace(0, 0, num_views)  # keep constant
    azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0

    # Place a point light in front of the object. As mentioned above, the front of
    # the cow is facing the -z direction.
    lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

    # Initialize a batch of FoV perspective cameras representing the different
    # viewing angles. All the camera helper methods support mixed type inputs and
    # broadcasting, so we can view the mesh from a distance of dist=2.7 and
    # specify elevation and azimuth angles for each viewpoint as tensors.
    R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

    # Define the settings for rasterization and shading. Here we set the output
    # image to be of size 128x128. As we are rendering images for visualization
    # purposes only we will set faces_per_pixel=1 and blur_radius=0.0. Refer to
    # rasterize_meshes.py for explanations of these parameters.  We also leave
    # bin_size and max_faces_per_bin to their default values of None, which sets
    # their values using heuristics and ensures that the faster coarse-to-fine
    # rasterization method is used.  Refer to docs/notes/renderer.md for an
    # explanation of the difference between naive and coarse-to-fine rasterization.
    raster_settings = RasterizationSettings(image_size=128,
                                            blur_radius=0.0,
                                            faces_per_pixel=1)

    # Create a phong renderer by composing a rasterizer and a shader. The textured
    # phong shader will interpolate the texture uv coordinates for each vertex,
    # sample from a texture image and apply the Phong lighting model
    blend_params = BlendParams(sigma=1e-4,
                               gamma=1e-4,
                               background_color=(0.0, 0.0, 0.0))
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=raster_settings),
        shader=SoftPhongShader(device=device,
                               cameras=cameras,
                               lights=lights,
                               blend_params=blend_params),
    )

    # Create a batch of meshes by repeating the cow mesh and associated textures.
    # Meshes has a useful `extend` method which allows us do this very easily.
    # This also extends the textures.
    meshes = mesh.extend(num_views)

    # Render the cow mesh from each viewing angle
    target_images = renderer(meshes, cameras=cameras, lights=lights)

    # Rasterization settings for silhouette rendering
    sigma = 1e-4
    raster_settings_silhouette = RasterizationSettings(
        image_size=128,
        blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma,
        faces_per_pixel=50)

    # Silhouette renderer
    renderer_silhouette = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=raster_settings_silhouette),
        shader=SoftSilhouetteShader(),
    )

    # Render silhouette images.  The 4th channel (index 3) of the rendering
    # output is the alpha/silhouette channel
    silhouette_images = renderer_silhouette(meshes,
                                            cameras=cameras,
                                            lights=lights)

    # binary silhouettes
    silhouette_binary = (silhouette_images[..., 3] > 1e-4).float()

    return cameras, target_images[..., :3], silhouette_binary
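
# Minimal usage sketch for the helper above (illustrative only; shapes follow the
# image_size=128 settings used in the function):
#
#   cameras, images, silhouettes = generate_cow_renders(num_views=20)
#   images.shape       -> (20, 128, 128, 3)   RGB renders
#   silhouettes.shape  -> (20, 128, 128)      binary masks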
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device,
                  loss_fn):

    if comm.is_main_process():
        wandb.init(project='MeshRCNN', config=cfg, name='meshrcnn')

    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)

    # Zhengyuan modification
    loss_predictor = LossPredictionModule().to(device)
    loss_pred_optim = torch.optim.SGD(loss_predictor.parameters(),
                                      lr=1e-3,
                                      momentum=0.9)

    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" %
                        (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)

        # Config settings for renderer
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=256,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).float().to(device)
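        # rot_y_90 is a homogeneous +90 degree rotation about the y-axis; below it
        # is folded into each pose RT before inverting, to bring the meshes back
        # into the frame of the rendered viewpoints.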

        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()

            batch = loaders["train"].postprocess(batch, device)
            if dataset == 'MeshVoxMulti':
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt, id_strs, _imgs, render_RTs, RTs = batch
            else:
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch

            # NOTE: _imgs contains all of the other images belonging to this model.
            # We have to select the next-best-view from that list of images.

            num_infinite_params = 0
            for p in params:
                num_infinite_params += (torch.isfinite(
                    p.data) == 0).sum().item()
            if num_infinite_params > 0:
                msg = "ERROR: Model has %d non-finite params (before forward!)"
                logger.info(msg % num_infinite_params)
                return

            model_kwargs = {}
            if cfg.MODEL.VOXEL_ON and cp.t < cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS:
                model_kwargs["voxel_only"] = True
            with Timer("Forward"):
                voxel_scores, meshes_pred = model(imgs, **model_kwargs)

            num_infinite = 0
            for cur_meshes in meshes_pred:
                cur_verts = cur_meshes.verts_packed()
                num_infinite += (torch.isfinite(cur_verts) == 0).sum().item()
            if num_infinite > 0:
                logger.info("ERROR: Got %d non-finite verts" % num_infinite)
                return

            total_silh_loss = torch.tensor(
                0.)  # Total silhouette loss, to be added to "loss" below
            # Voxel only training for first few iterations
            if meshes_gt is not None and not model_kwargs.get(
                    "voxel_only", False):
                _meshes_pred = meshes_pred[-1].clone()
                _meshes_gt = meshes_gt[-1].clone()

                # Render masks from predicted mesh for each view
                # GT probability map to supervise prediction module
                B = len(meshes_gt)
                probability_map = 0.01 * torch.ones((B, 24))  # batch size x 24
                for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(
                        zip(meshes_gt, _meshes_pred)):
                    # Maybe computationally expensive, but need to transform back to world space based on rendered image viewpoint
                    RT = RTs[b]
                    # Rotate 90 degrees about y-axis and invert
                    invRT = torch.inverse(RT.mm(rot_y_90))
                    invRT_no_rot = torch.inverse(RT)  # Just invert

                    cur_pred_mesh._verts_list[0] = project_verts(
                        cur_pred_mesh._verts_list[0], invRT)
                    sid = id_strs[b].split('-')[0]

                    # For some strange reason all classes (except the vehicle class) require a 90 degree rotation about the y-axis
                    if sid == '02958343':
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT_no_rot)
                    else:
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT)

                    for iid in range(len(render_RTs[b])):

                        R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                        T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                        cameras = OpenGLPerspectiveCameras(device=device,
                                                           R=R,
                                                           T=T)
                        silhouette_renderer = MeshRenderer(
                            rasterizer=MeshRasterizer(
                                cameras=cameras,
                                raster_settings=raster_settings),
                            shader=SoftSilhouetteShader(
                                blend_params=blend_params))

                        ref_image = silhouette_renderer(
                            meshes_world=cur_gt_mesh, R=R, T=T)
                        image = silhouette_renderer(meshes_world=cur_pred_mesh,
                                                    R=R,
                                                    T=T)
                        '''
                        import matplotlib.pyplot as plt
                        plt.subplot(1,2,1)
                        plt.imshow(ref_image[0,:,:,3].detach().cpu().numpy())
                        plt.subplot(1,2,2)
                        plt.imshow(image[0,:,:,3].detach().cpu().numpy())
                        plt.show()
                        '''

                        # MSE Loss between both silhouettes
                        silh_loss = torch.sum(
                            (image[0, :, :, 3] - ref_image[0, :, :, 3])**2)
                        probability_map[b, iid] = silh_loss.detach()

                        total_silh_loss += silh_loss

                probability_map = torch.nn.functional.softmax(
                    probability_map, dim=1)  # Softmax across images
                nbv_idx = torch.argmax(probability_map,
                                       dim=1)  # Next-best view indices
                nbv_imgs = _imgs[torch.arange(B),
                                 nbv_idx]  # Next-best view images

                # NOTE: Do a second forward pass through the model? This time for multi-view reconstruction
                # The input should be the first image and the next-best view
                #voxel_scores, meshes_pred = model(nbv_imgs, **model_kwargs)

            loss, losses = None, {}
            if num_infinite == 0:
                loss, losses = loss_fn(voxel_scores, meshes_pred, voxels_gt,
                                       (points_gt, normals_gt))
            skip = loss is None
            if loss is None or (torch.isfinite(loss) == 0).sum().item() > 0:
                logger.info("WARNING: Got non-finite loss %s" % loss)
                skip = True

            # Add silhouette loss to total loss
            silh_weight = 1.0  # TODO: Add a weight for the silhouette loss?
            if loss is not None:
                loss = loss + total_silh_loss * silh_weight
            losses['silhouette'] = total_silh_loss

            if model_kwargs.get("voxel_only", False):
                for k, v in losses.items():
                    if k != "voxel":
                        losses[k] = 0.0 * v

            if loss is not None and cp.t % cfg.SOLVER.LOGGING_PERIOD == 0:
                if comm.is_main_process():
                    cp.store_metric(loss=loss.item())
                    str_out = "Iteration: %d, epoch: %d, lr: %.5f," % (
                        cp.t,
                        cp.epoch,
                        optimizer.param_groups[0]["lr"],
                    )
                    for k, v in losses.items():
                        str_out += "  %s loss: %.4f," % (k, v.item())
                    str_out += "  total loss: %.4f," % loss.item()

                    # memory allocated
                    if torch.cuda.is_available():
                        max_mem_mb = torch.cuda.max_memory_allocated(
                        ) / 1024.0 / 1024.0
                        str_out += " mem: %d" % max_mem_mb

                    if len(meshes_pred) > 0:
                        mean_V = meshes_pred[-1].num_verts_per_mesh().float(
                        ).mean().item()
                        mean_F = meshes_pred[-1].num_faces_per_mesh().float(
                        ).mean().item()
                        str_out += ", mesh size = (%d, %d)" % (mean_V, mean_F)
                    logger.info(str_out)

                # Log with Weights & Biases (comment out if wandb is not installed)
                if comm.is_main_process():
                    wandb.log(losses)

            if loss_moving_average is None and loss is not None:
                loss_moving_average = loss.item()

            # Skip backprop for this batch if the loss is above the skip factor times
            # the moving average for losses
            if loss is None:
                pass
            elif loss.item(
            ) > cfg.SOLVER.SKIP_LOSS_THRESH * loss_moving_average:
                logger.info("Warning: Skipping loss %f on GPU %d" %
                            (loss.item(), comm.get_rank()))
                cp.store_metric(losses_skipped=loss.item())
                skip = True
            else:
                # Update the moving average of our loss
                gamma = cfg.SOLVER.LOSS_SKIP_GAMMA
                loss_moving_average *= gamma
                loss_moving_average += (1.0 - gamma) * loss.item()
                cp.store_data("loss_moving_average", loss_moving_average)

            if skip:
                logger.info("Dummy backprop on GPU %d" % comm.get_rank())
                loss = 0.0 * sum(p.sum() for p in params)

            # Backprop and step
            scheduler.step()
            optimizer.zero_grad()
            with Timer("Backward"):
                loss.backward()

            # Zhengyuan step loss_prediction
            loss_predictor.train_batch(image, probability_map, loss_pred_optim)

            # When training with normal loss, sometimes I get NaNs in gradient that
            # cause the model to explode. Check for this before performing a gradient
            # update. This is safe in multi-GPU since gradients have already been
            # summed, so each GPU has the same gradients.
            num_infinite_grad = 0
            for p in params:
                num_infinite_grad += (torch.isfinite(p.grad) == 0).sum().item()
            if num_infinite_grad == 0:
                optimizer.step()
            else:
                msg = "WARNING: Got %d non-finite elements in gradient; skipping update"
                logger.info(msg % num_infinite_grad)
            cp.step()

            if cp.t % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
                eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()
    eval_and_save(model, loaders, optimizer, scheduler, cp)

    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(cfg, dataset, "test", multigpu=False)
        evaluate_test(model, test_loader)
    def run_on_image(self, imgs, id_strs, meshes_gt, render_RTs, RTs):
        deprocess = imagenet_deprocess(rescale_image=False)

        #Silhouette Renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size, 
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
            faces_per_pixel=50, 
        )   

        rot_y_90 = torch.tensor([[0, 0, 1, 0],
                                 [0, 1, 0, 0],
                                 [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)
        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(imgs)

        B,_,H,W = imgs.shape
        probability_map = 0.01 * torch.ones((B, 24)).to(self.device)
        viewgrid = torch.zeros((B, 24, render_image_size, render_image_size)).to(self.device)  # batch size x 24 x H x W
        _meshes_pred = meshes_pred[-1]

        for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(zip(meshes_gt, _meshes_pred)):
            sid = id_strs[b].split('-')[0]

            RT = RTs[b]

            # For some strange reason all classes (except the vehicle class) require a 90 degree rotation about the y-axis
            # for the GT mesh
            invRT = torch.inverse(RT.mm(rot_y_90))
            invRT_no_rot = torch.inverse(RT)

            cur_pred_mesh._verts_list[0] = project_verts(cur_pred_mesh._verts_list[0], invRT)
            #Invert without the rotation for the vehicle class
            if sid == '02958343':
                cur_gt_mesh._verts_list[0] = project_verts(
                    cur_gt_mesh._verts_list[0], invRT_no_rot)
            else:
                cur_gt_mesh._verts_list[0] = project_verts(
                    cur_gt_mesh._verts_list[0], invRT)

            '''
            plt.figure(figsize=(10, 10))

            fig = plt.figure(1)
            ax_pred = [fig.add_subplot(5,5,i+1) for i in range(24)]
            fig = plt.figure(2)
            ax_gt = [fig.add_subplot(5,5,i+1) for i in range(24)]
            '''

            for iid in range(len(render_RTs[b])):

                R = render_RTs[b][iid][:3,:3].unsqueeze(0)
                T = render_RTs[b][iid][:3,3].unsqueeze(0)
                cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)
                silhouette_renderer = MeshRenderer(
                    rasterizer=MeshRasterizer(
                        cameras=cameras,
                        raster_settings=raster_settings),
                    shader=SoftSilhouetteShader(blend_params=blend_params))

                ref_image        = (silhouette_renderer(meshes_world=cur_gt_mesh, R=R, T=T)>0).float()
                silhouette_image = (silhouette_renderer(meshes_world=cur_pred_mesh, R=R, T=T)>0).float()

                #Add image silhouette to viewgrid
                viewgrid[b,iid] = silhouette_image[...,-1]

                # MSE Loss between both silhouettes
                silh_loss = torch.sum((silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3]) ** 2)
                probability_map[b, iid] = silh_loss.detach()

                '''
                ax_gt[iid].imshow(ref_image[0,:,:,3].cpu().numpy())
                ax_gt[iid].set_title(iid)
                ax_pred[iid].imshow(silhouette_image[0,:,:,3].cpu().numpy())
                ax_pred[iid].set_title(iid)
                '''

            '''
            img = image_to_numpy(deprocess(imgs[b]))
            ax_gt[iid].imshow(img)
            ax_pred[iid].imshow(img)
            fig = plt.figure(3)
            ax = fig.add_subplot(111)
            ax.imshow(img)

            #plt.show()
            '''

        probability_map = probability_map/(torch.max(probability_map, dim=1)[0].unsqueeze(1)) # Normalize
        probability_map = torch.nn.functional.softmax(probability_map, dim=1) # Softmax across images
        pred_prob_map = self.loss_predictor(viewgrid)

        gt_max = torch.argmax(probability_map, dim=1)
        pred_max = torch.argmax(pred_prob_map, dim=1)

        #print('--'*30)
        #print('Item: {}'.format(id_str))
        #print('Highest actual loss: {}'.format(gt_max))
        #print('Highest predicted loss: {}'.format(pred_max))

        #print('GT prob map: {}'.format(probability_map.squeeze()))
        #print('Pred prob map: {}'.format(pred_prob_map.squeeze()))

        correct = torch.sum(pred_max == gt_max).item()
        total   = len(pred_prob_map)

        return correct, total