Ejemplo n.º 1
0
 def _get_lights(self, lightning):
     """Build a batch of directional lights from per-sample lighting values.

     Args:
         lightning: mapping with the keys ``"ambient"``, ``"diffuse"`` and
             ``"direction"``. ``"ambient"`` must be a 2-D tensor (its shape
             is unpacked as ``(batch, _)``).
             NOTE(review): the ``repeat(1, 3)`` calls presumably expand a
             single intensity column into three colour channels, i.e.
             ambient/diffuse are expected to be ``(batch, 1)`` -- confirm.

     Returns:
         A ``DirectionalLights`` object created on ``self.device`` with no
         specular component.
     """
     # Rescale the raw values with x / 2 + 0.5.
     ambient = lightning["ambient"] / 2. + 0.5
     diffuse = lightning["diffuse"] / 2. + 0.5
     # The direction stored in the mapping is negated before use.
     direction = -lightning["direction"]

     batch_size, _ = ambient.shape
     # Allocate the specular zeros and the lights directly on the target
     # device instead of building everything on the default (CPU) device
     # and moving it afterwards.
     return DirectionalLights(
         ambient_color=ambient.repeat(1, 3),
         diffuse_color=diffuse.repeat(1, 3),
         specular_color=torch.zeros((batch_size, 3), device=self.device),
         direction=direction,
         device=self.device,
     )
Ejemplo n.º 2
0
    def set_renderer(self):
        """Create the textured Phong mesh renderer and store it on ``self.renderer``.

        Camera parameters (``fov``, ``znear``, ``zfar``) and the image size
        (first entry of ``VIEW['viewport']``) are read from the module-level
        ``VIEW`` mapping. Everything is created on ``self.cuda_device``.
        """
        target = self.cuda_device

        cameras = OpenGLPerspectiveCameras(
            device=target,
            degrees=True,
            fov=VIEW['fov'],
            znear=VIEW['znear'],
            zfar=VIEW['zfar'],
        )

        # One face per pixel, no blur; bin_size=0 is passed through unchanged.
        raster_settings = RasterizationSettings(
            image_size=VIEW['viewport'][0],
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=0,
        )

        # Equal ambient and diffuse terms, no specular highlight.
        half_grey = ((0.5, 0.5, 0.5), )
        lights = DirectionalLights(
            device=target,
            direction=((-40, 200, 100), ),
            ambient_color=half_grey,
            diffuse_color=half_grey,
            specular_color=((0.0, 0.0, 0.0), ),
        )

        rasterizer = MeshRasterizer(cameras=cameras,
                                    raster_settings=raster_settings)
        shader = TexturedSoftPhongShader(device=target,
                                         cameras=cameras,
                                         lights=lights)
        self.renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
Ejemplo n.º 3
0
 def setup(self, device):
     """Create the soft Phong mesh renderer and store it on ``self.renderer``.

     The camera pose comes from ``self.viewpoint_distance``,
     ``self.viewpoint_elevation`` and ``self.viewpoint_azimuth``;
     rasterization, lighting and blending options are read from
     ``self.opt``.

     Args:
         device: device on which cameras, lights and shader are created.
     """
     R, T = look_at_view_transform(self.viewpoint_distance,
                                   self.viewpoint_elevation,
                                   self.viewpoint_azimuth,
                                   device=device)
     cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
     raster_settings = RasterizationSettings(
         image_size=self.opt.fast_image_size,
         blur_radius=self.opt.raster_blur_radius,
         faces_per_pixel=self.opt.raster_faces_per_pixel,
     )
     rasterizer = MeshRasterizer(cameras=cameras,
                                 raster_settings=raster_settings)
     # Only the directional light is used by the shader. The point light
     # that used to be created here was overwritten immediately and never
     # reached the renderer, so it has been removed.
     lights = DirectionalLights(device=device,
                                direction=[self.opt.lights_direction])
     shader = SoftPhongShader(
         device=device,
         cameras=cameras,
         lights=lights,
         blend_params=BlendParams(
             self.opt.blend_params_sigma,
             self.opt.blend_params_gamma,
             self.opt.blend_params_background_color,
         ),
     )
     self.renderer = MeshRenderer(
         rasterizer=rasterizer,
         shader=shader,
     )
Ejemplo n.º 4
0
    def init_renderer(self):
        """Create the rasterizer and textured Phong renderer.

        Stores the rasterizer on ``self.rasterizer`` and the full renderer on
        ``self.renderer``. The camera is placed with
        ``look_at_view_transform(10, 0, 0)`` and everything lives on
        ``self.device``.
        """
        # nsh_face_mesh = meshio.Mesh('data/mesh/nsh_bfm_face.obj')
        # self.nsh_face_tri = torch.from_numpy(nsh_face_mesh.triangles).type(
        #     torch.int64).to(self.device)

        R, T = look_at_view_transform(10, 0, 0)
        cameras = OpenGLPerspectiveCameras(
            device=self.device,
            R=R,
            T=T,
            fov=12.5936,
            degrees=True,
            aspect_ratio=1.0,
            znear=0.001,
            zfar=30.0,
        )

        # Hard rasterization: one face per pixel, no blur, back faces culled.
        raster_settings = RasterizationSettings(
            image_size=self.im_size,
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=0,
            cull_backfaces=True,
        )
        self.rasterizer = MeshRasterizer(cameras=cameras,
                                         raster_settings=raster_settings)

        # Default directional light; only the device is specified.
        shader = TexturedSoftPhongShader(
            device=self.device,
            cameras=cameras,
            lights=DirectionalLights(device=self.device),
        )
        self.renderer = MeshRenderer(rasterizer=self.rasterizer, shader=shader)
Ejemplo n.º 5
0
    def __init__(self, cfgs):
        """Read renderer settings from ``cfgs`` and build intrinsics and rasterizer.

        Args:
            cfgs: mapping queried with ``.get`` for ``device``,
                ``image_size``, ``min_depth``, ``max_depth``,
                ``rot_center_depth``, ``border_depth`` and ``fov``; each key
                has a default.

        Sets ``self.K`` / ``self.inv_K`` (each with a leading dimension of 1)
        and ``self.rasterizer_torch``.
        """
        super().__init__()
        self.device = cfgs.get('device', 'cpu')
        self.image_size = cfgs.get('image_size', 64)
        self.min_depth = cfgs.get('min_depth', 0.9)
        self.max_depth = cfgs.get('max_depth', 1.1)
        # Defaults for the two derived depths are computed from min/max depth.
        mid_depth = (self.min_depth + self.max_depth) / 2
        self.rot_center_depth = cfgs.get('rot_center_depth', mid_depth)
        default_border = 0.3 * self.min_depth + 0.7 * self.max_depth
        self.border_depth = cfgs.get('border_depth', default_border)
        self.fov = cfgs.get('fov', 10)

        #### camera intrinsics
        #             (u)   (x)
        #    d * K^-1 (v) = (y)
        #             (1)   (z)

        ## renderer for visualization
        # Identity rotation and zero translation; the rotation is rebound by
        # look_at_rotation further down.
        R = torch.FloatTensor(
            [[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]).to(self.device)
        t = torch.zeros(1, 3, dtype=torch.float32).to(self.device)

        # fx == fy and cx == cy: both come from the same image size and fov.
        centre = (self.image_size - 1) / 2
        focal = centre / (math.tan(self.fov / 2 * math.pi / 180))
        K = torch.FloatTensor([[focal, 0., centre],
                               [0., focal, centre],
                               [0., 0., 1.]]).to(self.device)
        self.inv_K = torch.inverse(K).unsqueeze(0)
        self.K = K.unsqueeze(0)

        # Initialize an OpenGL perspective camera looking from the origin
        # towards +z with -y as the up vector.
        R = look_at_rotation(((0, 0, 0), ),
                             at=((0, 0, 1), ),
                             up=((0, -1, 0), ))
        cameras = OpenGLPerspectiveCameras(device=self.device,
                                           fov=self.fov,
                                           R=R)

        # Ambient-only lighting: diffuse and specular are zero.
        black = ((0.0, 0.0, 0.0), )
        lights = DirectionalLights(
            device=self.device,
            direction=((0, 1, 0), ),
            ambient_color=((1.0, 1.0, 1.0), ),
            diffuse_color=black,
            specular_color=black,
        )

        raster_settings = RasterizationSettings(
            image_size=self.image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
        )
        self.rasterizer_torch = MeshRasterizer(cameras=cameras,
                                               raster_settings=raster_settings)
def load_lights(moon_light: MoonLight):
    """Create a CUDA ``DirectionalLights`` from a ``MoonLight`` description.

    The first three components of ``position_light``, ``ambient_light`` and
    ``diffuse_light`` are used as direction, ambient colour and diffuse
    colour respectively. The specular colour is always zero.
    """
    light_settings = {
        'direction': (moon_light.position_light[:3], ),
        'ambient_color': (moon_light.ambient_light[:3], ),
        'diffuse_color': (moon_light.diffuse_light[:3], ),
        'specular_color': ((0.0, 0.0, 0.0), ),
    }
    return DirectionalLights(device=torch.device('cuda'), **light_settings)
Ejemplo n.º 7
0
def build_renderer(_image_size):
    """Return a textured soft-Phong ``MeshRenderer`` for the given image size.

    The camera field of view and the device come from the module-level
    ``FOV`` and ``DEVICE`` values.
    """
    # Initialize an OpenGL perspective camera.
    cameras = OpenGLPerspectiveCameras(
        device=DEVICE,
        degrees=True,
        fov=FOV,
        znear=1e-4,
        zfar=100,
    )

    raster_settings = RasterizationSettings(
        image_size=_image_size,
        blur_radius=0.0,
        faces_per_pixel=1,
        bin_size=0,
    )

    # Equal ambient and diffuse grey, no specular term.
    grey = ((0.5, 0.5, 0.5),)
    lights = DirectionalLights(
        device=DEVICE,
        direction=((-40, 200, 100),),
        ambient_color=grey,
        diffuse_color=grey,
        specular_color=((0.0, 0.0, 0.0),),
    )

    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
    shader = TexturedSoftPhongShader(device=DEVICE, cameras=cameras, lights=lights)
    return MeshRenderer(rasterizer=rasterizer, shader=shader)
Ejemplo n.º 8
0
 def __init__(self, rasterizer, shader_rgb, shader_mask, cameras):
   """Store the rendering components and create default lights/materials.

   Args:
       rasterizer: stored as ``self.rasterizer``.
       shader_rgb: stored as ``self.shader_rgb``.
       shader_mask: stored as ``self.shader_mask``.
       cameras: stored as ``self.cameras``.

   The lights are ambient-only (diffuse and specular are zero). Both the
   lights and the materials are created on ``'cuda'``.
   """
   super(Renderer, self).__init__()
   self.cameras = cameras
   self.rasterizer = rasterizer
   self.shader_rgb = shader_rgb
   self.shader_mask = shader_mask

   no_contribution = ((0., 0., 0.),)
   self.lights = DirectionalLights(
       device='cuda',
       direction=((0, 0, 1),),
       ambient_color=((1, 1, 1),),
       diffuse_color=no_contribution,
       specular_color=no_contribution,
   )
   self.materials = Materials(device='cuda')
Ejemplo n.º 9
0
    def rendering(self, light_params, coeffs, vertices, gen_uvmaps,
                  face_model):
        """Render each sample together with a mirrored copy of it.

        The batch is doubled along dim 0: the second half uses a light
        direction and translation with the first component negated, and
        angles with the last two components negated.

        Args:
            light_params: 2-D tensor; columns 0:3, 3:6, 6:9 and 9:12 are read
                as ambient, diffuse, specular and direction.
            coeffs: passed to ``utils.split_bfm09_coeff``; only the returned
                angles and translation are used.
            vertices: repeated twice along dim 0 and passed to
                ``self.rotate_vert``.
            gen_uvmaps: 4-D tensor of UV maps; rescaled with ``x / 2 + 0.5``
                and flipped along dims 2 and 3.
            face_model: key into ``self.meshes`` selecting the template mesh.

        Returns:
            The renderer output permuted to channel-first layout, with the
            first three channels rescaled by ``x * 2 - 1`` and the last
            channel binarised.
        """
        dev = self.device
        negate_first = torch.tensor([[-1, 1, 1]], dtype=torch.float, device=dev)
        negate_last_two = torch.tensor([[1, -1, -1]],
                                       dtype=torch.float,
                                       device=dev)

        def with_reflection(batch, signs):
            # Original rows followed by their sign-flipped counterparts.
            return torch.cat([batch, batch * signs], dim=0)

        def doubled_colour(scale, columns):
            # clamp(scale + scale * x) repeated for the mirrored half.
            return torch.clamp(scale + scale * columns, 0, 1).repeat(2, 1)

        self.renderer.shader.lights = DirectionalLights(
            ambient_color=doubled_colour(0.5, light_params[:, 0:3]),
            diffuse_color=doubled_colour(0.5, light_params[:, 3:6]),
            specular_color=doubled_colour(0.2, light_params[:, 6:9]),
            direction=with_reflection(light_params[:, 9:12], negate_first),
            device=dev)

        _, _, _, angles, _, trans = utils.split_bfm09_coeff(coeffs)
        rotated_vert = self.rotate_vert(
            vertices.repeat(2, 1, 1),
            with_reflection(angles, negate_last_two),
            with_reflection(trans, negate_first))

        template = self.meshes[face_model]
        uv_maps = torch.flip(gen_uvmaps / 2 + 0.5, (2, 3))
        uv_maps = uv_maps.repeat(2, 1, 1, 1).permute(0, 2, 3, 1)
        texture = Textures(maps=uv_maps,
                           faces_uvs=template.textures.faces_uvs_padded(),
                           verts_uvs=template.textures.verts_uvs_padded())
        meshes = Meshes(rotated_vert, template.faces_padded(), texture)

        renders = self.renderer(meshes)

        # In-place post-processing of the renderer output.
        renders[..., :3] = renders[..., :3] * 2 - 1
        renders[..., -1] = (renders[..., -1] > 0).float()
        return renders.permute(0, 3, 1, 2).contiguous()
Ejemplo n.º 10
0
    def setup_light(self, **kwargs):
        """Create ``self.light`` according to ``self.light_type``.

        Keyword Args:
            position: used as the location when ``self.light_type`` is
                ``'point'``.
            direction: used as the direction when ``self.light_type`` is
                ``'directional'``.

        If the keyword matching the light type is missing, or the light type
        is neither of the two, ``self.light`` is left untouched. Lights are
        created on ``'cuda'``.
        """

        def colours():
            # Evaluated lazily so the colour attributes are only read when a
            # light is actually built.
            return dict(ambient_color=self.ambient_color,
                        diffuse_color=self.diffuse_color,
                        specular_color=self.specular_color)

        if self.light_type == 'point' and 'position' in kwargs:
            self.light = PointLights(**colours(),
                                     location=[kwargs['position']],
                                     device='cuda')

        if self.light_type == 'directional' and 'direction' in kwargs:
            self.light = DirectionalLights(**colours(),
                                           direction=kwargs['direction'],
                                           device='cuda')
Ejemplo n.º 11
0
def differentiable_face_render(vert, tri, colors, bg_img, h, w):
    """Render vertex-coloured meshes with an orthographic camera.

    vert: (N, nver, 3)
    tri: (ntri, 3) numpy array of face indices
    colors: (N, nver, 3)
    bg_img: (N, 3, H, W), or None to skip background compositing
    h, w: output size; must be equal

    Returns images of shape (N, H, W, 3). When ``bg_img`` is given, pixels
    whose rendered channels sum to exactly 0 are replaced by the background.
    """
    assert h == w
    device = vert.device
    N, _, _ = vert.shape
    ntri = tri.shape[0]

    def vec3(values):
        # Float constant shaped to broadcast over (N, nver, 3).
        return torch.tensor(values, dtype=torch.float,
                            device=device).view(1, 1, 3)

    faces = torch.from_numpy(tri).to(device).unsqueeze(0).expand(N, ntri, 3)

    # Transform to Pytorch3D world space: shift x/y by 0.5, then flip x and z.
    world_verts = (vert + vec3((0.5, 0.5, 0))) * vec3((-1, 1, -1))
    mesh_torch = Meshes(verts=world_verts,
                        faces=faces,
                        textures=TexturesVertex(verts_features=colors))

    # Camera
    R = look_at_rotation(camera_position=((0, 0, -300),))
    R = R.to(device).expand(N, 3, 3)
    T = torch.tensor((0, 0, 300), dtype=torch.float, device=device)
    T = T.view(1, 3).expand(N, 3)
    focal = torch.tensor((2. / float(w), 2. / float(h)),
                         dtype=torch.float, device=device)
    focal = focal.view(1, 2).expand(N, 2)
    cameras = OrthographicCameras(device=device, R=R, T=T, focal_length=focal)

    # Ambient-only lighting so the vertex colours are reproduced unshaded.
    zero_rgb = ((0., 0., 0.),)
    lights = DirectionalLights(ambient_color=((1., 1., 1.),),
                               diffuse_color=zero_rgb,
                               specular_color=zero_rgb,
                               direction=((0, 0, 1),),
                               device=device)

    rasterizer = MeshRasterizer(
        cameras=cameras,
        raster_settings=RasterizationSettings(image_size=h,
                                              blur_radius=0.0,
                                              faces_per_pixel=1),
    )
    shader = SoftPhongShader(
        device=device,
        cameras=cameras,
        lights=lights,
        blend_params=BlendParams(background_color=(0, 0, 0)),
    )
    renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)

    images = renderer(mesh_torch)[:, :, :, :3]        # (N, H, W, 3)
    if bg_img is None:
        return images

    # Add background wherever the render is exactly black.
    background = bg_img.permute(0, 2, 3, 1)           # (N, H, W, 3)
    is_empty = torch.eq(
        images.sum(dim=3, keepdim=True).expand(N, h, w, 3), 0)
    return torch.where(is_empty, background, images)
Ejemplo n.º 12
0
def batch_render(
    verts,
    faces,
    faces_per_pixel=10,
    K=None,
    rot=None,
    trans=None,
    colors=None,
    color=(0.53, 0.53, 0.8),  # light_purple
    ambient_col=0.5,
    specular_col=0.2,
    diffuse_col=0.3,
    face_colors=None,
    # color = (0.74117647, 0.85882353, 0.65098039),  # light_blue
    image_sizes=None,
    out_res=512,
    bin_size=0,
    shading="soft",
    mode="rgb",
    blend_gamma=1e-4,
    min_depth=None,
):
    """Render a batch of meshes on ``cuda:0`` and crop to the image size.

    The scene is rendered into a square of side ``max(image_sizes[0])``,
    flipped along both image dimensions, and then cropped along the shorter
    dimension by dropping its leading rows or columns.

    Args:
        verts: batched vertices; ``len(verts)`` is used as the batch size.
        faces: batched faces passed to ``Meshes``.
        faces_per_pixel: forwarded to ``RasterizationSettings``.
        K: batched camera intrinsics indexed as ``K[:, i, j]``. Required.
        rot: camera rotation; identity of shape (1, 3, 3) when ``None``.
        trans: camera translation; zeros of shape (1, 3) when ``None``.
        colors: per-vertex colours for ``mode="rgb"``; computed with
            ``get_colors(verts, color)`` when ``None``.
        color: base colour handed to ``get_colors``.
        ambient_col, specular_col, diffuse_col: grey levels of the
            directional light used for ``mode="rgb"``.
        face_colors: forwarded to the face-colour shaders.
        image_sizes: sequence whose first entry is ``(width, height)``; only
            that first entry is used. Required.
        out_res: accepted for interface compatibility; not used.
        bin_size: forwarded to ``RasterizationSettings``.
        shading: ``"soft"`` or ``"hard"`` (``"faceidx"`` selects
            ``FaceIdxShader`` for modes other than rgb/silh).
        mode: ``"rgb"``, ``"silh"`` or ``"facecolor"``.
        blend_gamma: forwarded to ``SoftFaceColorShader``.
        min_depth: when given, the z coordinate of ``verts`` is clamped to
            be at least this value.

    Returns:
        The flipped and cropped renderer output.

    Raises:
        ValueError: if ``K`` or ``image_sizes`` is missing, or the
            ``mode`` / ``shading`` combination is not supported.
    """
    # K and image_sizes default to None only for signature compatibility;
    # fail with a clear message instead of an opaque attribute/type error.
    if K is None:
        raise ValueError("K (camera intrinsics) must be provided")
    if image_sizes is None:
        raise ValueError("image_sizes must be provided")

    device = torch.device("cuda:0")
    K = K.to(device)
    width, height = image_sizes[0]
    out_size = int(max(image_sizes[0]))
    raster_settings = RasterizationSettings(
        image_size=out_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        bin_size=bin_size,
    )

    # Camera intrinsics: focal lengths and a principal point mirrored
    # against the image size.
    fx = K[:, 0, 0]
    fy = K[:, 1, 1]
    focals = torch.stack([fx, fy], 1)
    px = K[:, 0, 2]
    py = K[:, 1, 2]
    principal_point = torch.stack([width - px, height - py], 1)
    if rot is None:
        rot = torch.eye(3).unsqueeze(0).to(device)
    if trans is None:
        trans = torch.zeros(3).unsqueeze(0).to(device)
    cameras = PerspectiveCameras(
        device=device,
        focal_length=focals,
        principal_point=principal_point,
        image_size=[(out_size, out_size) for _ in range(len(verts))],
        R=rot,
        T=trans,
    )

    # Shader selection.
    if mode == "rgb":
        # A single grey directional light. The point light that used to be
        # created here was overwritten before use and has been removed.
        lights = DirectionalLights(
            device=device,
            direction=((0.6, -0.6, -0.6), ),
            ambient_color=((ambient_col, ambient_col, ambient_col), ),
            diffuse_color=((diffuse_col, diffuse_col, diffuse_col), ),
            specular_color=((specular_col, specular_col, specular_col), ),
        )
        if shading == "soft":
            shader = SoftPhongShader(device=device,
                                     cameras=cameras,
                                     lights=lights)
        elif shading == "hard":
            shader = HardPhongShader(device=device,
                                     cameras=cameras,
                                     lights=lights)
        else:
            raise ValueError(
                f"Shading {shading} for mode rgb not in [soft|hard]")
    elif mode == "silh":
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        shader = SoftSilhouetteShader(blend_params=blend_params)
    # NOTE(review): this branch tests `shading`, not `mode`, unlike its
    # neighbours -- confirm whether a "faceidx" *mode* was intended.
    elif shading == "faceidx":
        shader = FaceIdxShader()
    elif (mode == "facecolor") and (shading == "hard"):
        shader = FaceColorShader(face_colors=face_colors)
    elif (mode == "facecolor") and (shading == "soft"):
        shader = SoftFaceColorShader(face_colors=face_colors,
                                     blend_gamma=blend_gamma)
    else:
        raise ValueError(
            f"Unhandled mode {mode} and shading {shading} combination")

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=raster_settings),
        shader=shader,
    )
    if min_depth is not None:
        # Keep x/y untouched and clamp only the depth coordinate.
        verts = torch.cat([verts[:, :, :2], verts[:, :, 2:].clamp(min_depth)],
                          2)
    if mode == "rgb":
        if colors is None:
            colors = get_colors(verts, color)
        tex = textures.TexturesVertex(verts_features=colors)

        meshes = Meshes(verts=verts, faces=faces, textures=tex)
    elif mode in ["silh", "facecolor"]:
        meshes = Meshes(verts=verts, faces=faces)
    else:
        raise ValueError(f"Render mode {mode} not in [rgb|silh|facecolor]")

    square_images = renderer(meshes, cameras=cameras)
    square_images = torch.flip(square_images, (1, 2))
    # Crop the square render back to width x height by dropping the leading
    # rows (landscape) or leading columns (portrait / square).
    height_off = abs(int(width - height))
    if width > height:
        images = square_images[:, height_off:, :]
    else:
        images = square_images[:, :, height_off:]
    return images
Ejemplo n.º 13
0
def load_lights():
    """Return the fixed directional light on ``config.cuda.device``.

    The light has ambient 0.7, diffuse 0.8 and no specular component.
    """
    light_settings = {
        'direction': ((-40.0, 200.0, 100.0), ),
        'ambient_color': ((0.7, 0.7, 0.7), ),
        'diffuse_color': ((0.8, 0.8, 0.8), ),
        'specular_color': ((0.0, 0.0, 0.0), ),
    }
    return DirectionalLights(device=config.cuda.device, **light_settings)
Ejemplo n.º 14
0
# model_verts = scene.verts_padded()
# scene = scene.update_padded(t.transform_points(model_verts))
# teapot_mesh = join_meshes_as_scene(scene)

# Initialize an orthographic camera with an explicit x/y view volume.
cameras = FoVOrthographicCameras(
    device=device,
    min_x=-80.0,
    max_x=80.0,
    min_y=-93.0,
    max_y=93.0,
    scale_xyz=((1, 1, 1), ),
)

# Ambient-only directional light (diffuse and specular are zero).
# The device is passed as a string here.
lights = DirectionalLights(
    device=str(device),
    direction=((0., 0., 1.), ),
    ambient_color=((0.9, 0.9, 0.9), ),
    diffuse_color=((0, 0, 0), ),
    specular_color=((0, 0, 0), ),
)

# Ten faces per pixel with a very small blur radius.
raster_settings = RasterizationSettings(
    image_size=128,
    blur_radius=1e-9,
    faces_per_pixel=10,
)

# Background colour given as 8-bit RGB values scaled by 1/255.
background_color = torch.tensor([144 / 255, 71 / 255, 16 / 255],
                                device=device)
blend_params = BlendParams(
    sigma=1e-4,
    gamma=1e-4,
    background_color=background_color,
)
Ejemplo n.º 15
0
    def __init__(self, config, crop_size, device):
        """Load the texture model and build a textured Phong renderer.

        Args:
            config: forwarded to the parent constructor; ``config.texture_path``
                is loaded with ``np.load`` and must provide the entries
                ``'mean'``, ``'tex_dir'``, ``'ft'`` and ``'vt'``.
            crop_size: stored as ``self.crop_size`` and used as the render
                image size.
            device: device for the camera, lights and shader.

        Registers the buffers ``texture_mean``, ``texture_dir``,
        ``faces_uvs``, ``verts_uvs``, ``renderer_R`` and ``renderer_T``, and
        the learnable parameter ``texture_params``. ``self.dtype`` and
        ``self.batch_size`` are read here; presumably the parent constructor
        sets them -- confirm.
        """
        super(TexturedFLAME, self).__init__(config)
        self.crop_size = crop_size
        texture_model = np.load(config.texture_path)

        tex_mean = texture_model['mean']
        tex_basis = texture_model['tex_dir']
        self.texture_shape = tex_mean.shape
        self.texture_num_pc = tex_basis.shape[-1]

        # Mean texture flattened to a single row.
        flat_mean = torch.reshape(
            torch.from_numpy(tex_mean).to(dtype=self.dtype), (1, -1))
        self.register_buffer('texture_mean', flat_mean)

        # Texture basis reshaped to (-1, num_pc) and transposed.
        basis_rows = torch.reshape(
            torch.from_numpy(tex_basis).to(dtype=self.dtype),
            (-1, self.texture_num_pc)).t()
        self.register_buffer('texture_dir', basis_rows)

        # UV face indices and UV coordinates, stacked batch_size times
        # along a new leading dimension.
        single_faces_uvs = torch.from_numpy(
            np.int64(texture_model['ft'])).unsqueeze(0)
        self.register_buffer('faces_uvs',
                             torch.cat([single_faces_uvs] * self.batch_size))
        single_verts_uvs = torch.from_numpy(
            texture_model['vt']).to(dtype=self.dtype).unsqueeze(0)
        self.register_buffer('verts_uvs',
                             torch.cat([single_verts_uvs] * self.batch_size))

        # Learnable texture coefficients, initialised to zero.
        initial_params = torch.zeros((1, self.texture_num_pc),
                                     dtype=self.dtype,
                                     requires_grad=True)
        self.register_parameter(
            'texture_params',
            nn.Parameter(initial_params, requires_grad=True))

        raster_settings = RasterizationSettings(image_size=crop_size,
                                                blur_radius=0.0,
                                                faces_per_pixel=1,
                                                bin_size=None,
                                                max_faces_per_bin=None)

        # Camera pose is kept in buffers.
        R, T = look_at_view_transform(1.0, 0.5, 0)
        self.register_buffer('renderer_R', R)
        self.register_buffer('renderer_T', T)
        self.renderer_camera = OpenGLPerspectiveCameras(device=device,
                                                        R=self.renderer_R,
                                                        T=self.renderer_T,
                                                        fov=20)

        # Directional light without specular highlights.
        # (Alternative noted previously: PointLights(location=[[0.0, 0.0, -1.0]]))
        renderer_lights = DirectionalLights(device=device,
                                            direction=[[0.0, 0.0, 2.0]],
                                            specular_color=[[0.0, 0.0, 0.0]])

        rasterizer = MeshRasterizer(cameras=self.renderer_camera,
                                    raster_settings=raster_settings)
        shader = TexturedSoftPhongShader(device=device,
                                         cameras=self.renderer_camera,
                                         lights=renderer_lights)
        self.renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
# Script fragment: render `mesh` with a hard Phong shader and save the result.
# `device`, `mesh`, `R` and `t` are expected to be defined before this point.

# Initialize an OpenGL perspective camera.
cameras = OpenGLPerspectiveCameras(device=device, fov=25.0, aspect_ratio=1.1)

# We will also create a phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(image_size=540,
                                        blur_radius=0.0,
                                        faces_per_pixel=1,
                                        bin_size=0)
# Alternative kept for reference: a point light in front of the object.
#lights = PointLights(device=device, location=((-1.0, -1.0, -2.0),))
# Light parameters: "ambient_color", "diffuse_color", "specular_color".
# Other values noted here: 'ambient':0.4,'diffuse':0.8, 'specular':0.3
lights = DirectionalLights(device=device,
                           ambient_color=[[0.25, 0.25, 0.25]],
                           diffuse_color=[[0.6, 0.6, 0.6]],
                           specular_color=[[0.15, 0.15, 0.15]],
                           direction=[[-1.0, -1.0, 1.0]])
# NOTE(review): the shader is built without `cameras`; confirm that the
# installed PyTorch3D version supplies cameras to the shader some other way.
phong_renderer = MeshRenderer(rasterizer=MeshRasterizer(
    cameras=cameras, raster_settings=raster_settings),
                              shader=HardPhongShader(device=device,
                                                     lights=lights))

# Add a leading batch dimension to the externally provided pose; the names
# `R` and `T` are (re)bound here.
R = torch.tensor(R, device=device).unsqueeze(0)
T = torch.tensor(t, device=device).unsqueeze(0)
image = phong_renderer(meshes_world=mesh, R=R, T=T)
# Take the first image of the batch, keep the first three channels and
# scale by 255 before the uint8 cast below.
image = image.cpu().numpy()[0, :, :, :3] * 255
# NOTE(review): cv2.imwrite treats 3-channel arrays as BGR; if the renderer
# output is RGB, the red and blue channels are swapped in the saved file --
# confirm.
cv2.imwrite("pytorch-render.png", image.astype(np.uint8))

# Print the camera's projection matrix for inspection.
print(cameras.get_projection_transform().get_matrix())
    def initRender(self, method, image_size):
        """Build a ``MeshRenderer`` for the requested rendering method.

        Args:
            method: one of ``"soft-silhouette"``, ``"hard-silhouette"``,
                ``"soft-depth"``, ``"hard-depth"``, ``"blurry-depth"``,
                ``"soft-phong"`` or ``"hard-phong"``.
            image_size: forwarded to ``RasterizationSettings``.

        Returns:
            The configured renderer, or ``None`` (after printing a message)
            when ``method`` is not recognised.
        """
        cameras = OpenGLPerspectiveCameras(device=self.device, fov=15)

        def blur_for(eps, blend):
            # Blur radius derived from the blend sigma:
            # log(1 / eps - 1) * sigma.
            return np.log(1. / eps - 1.) * blend.sigma

        def rasterizer_for(blur_radius, faces_per_pixel):
            settings = RasterizationSettings(image_size=image_size,
                                             blur_radius=blur_radius,
                                             faces_per_pixel=faces_per_pixel)
            return MeshRasterizer(cameras=cameras, raster_settings=settings)

        if method == "soft-silhouette":
            blend_params = BlendParams(sigma=1e-7, gamma=1e-7)
            rasterizer = rasterizer_for(blur_for(1e-7, blend_params),
                                        self.faces_per_pixel)
            shader = SoftSilhouetteShader(blend_params=blend_params)
        elif method == "hard-silhouette":
            # Same shader as the soft variant, but a single face per pixel.
            blend_params = BlendParams(sigma=1e-7, gamma=1e-7)
            rasterizer = rasterizer_for(blur_for(1e-7, blend_params), 1)
            shader = SoftSilhouetteShader(blend_params=blend_params)
        elif method == "soft-depth":
            # Soft Rasterizer - from https://github.com/facebookresearch/pytorch3d/issues/95
            blend_params = BlendParams(sigma=1e-3, gamma=1e-4)
            rasterizer = rasterizer_for(blur_for(1e-3, blend_params),
                                        self.faces_per_pixel)
            shader = SoftDepthShader(blend_params=blend_params)
        elif method == "hard-depth":
            rasterizer = rasterizer_for(0, 20)
            shader = HardDepthShader()
        elif method == "blurry-depth":
            # Soft Rasterizer - from https://github.com/facebookresearch/pytorch3d/issues/95
            blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
            rasterizer = rasterizer_for(blur_for(1e-4, blend_params),
                                        self.faces_per_pixel)
            shader = SoftDepthShader(blend_params=blend_params)
        elif method == "soft-phong":
            blend_params = BlendParams(sigma=1e-3, gamma=1e-3)
            rasterizer = rasterizer_for(blur_for(1e-3, blend_params),
                                        self.faces_per_pixel)
            # Only the direction is set; the light colours are left at the
            # library defaults.
            lights = DirectionalLights(device=self.device,
                                       direction=[[0.0, 1.0, 0.0]])
            shader = SoftPhongShader(device=self.device,
                                     blend_params=blend_params,
                                     lights=lights)
        elif method == "hard-phong":
            # These blend parameters are created but not handed to the shader.
            blend_params = BlendParams(sigma=1e-8, gamma=1e-8)
            rasterizer = rasterizer_for(0.0, 1)
            lights = DirectionalLights(device=self.device,
                                       ambient_color=[[0.25, 0.25, 0.25]],
                                       diffuse_color=[[0.6, 0.6, 0.6]],
                                       specular_color=[[0.15, 0.15, 0.15]],
                                       direction=[[-1.0, -1.0, 1.0]])
            shader = HardPhongShader(device=self.device, lights=lights)
        else:
            print("Unknown render method!")
            return None

        return MeshRenderer(rasterizer=rasterizer, shader=shader)