def testRendersTwoCubesInBatch(self):
        """Rasterizes the cube from two camera poses in a single batch.

        Both rendered images are compared against the stored baselines
        ``Unlit_Cube_0.png`` and ``Unlit_Cube_1.png`` found in
        ``self.test_data_directory``.
        """
        positions = self.cube_vertex_positions
        # Per-vertex colour is the position remapped by x * 0.5 + 0.5, with an
        # extra channel of ones appended as alpha.
        rgba = tf.concat([positions * 0.5 + 0.5, tf.ones([8, 1])], axis=1)

        origin = self.tf_float([[0.0, 0.0, 0.0]])
        up_axis = self.tf_float([[0.0, 1.0, 0.0]])
        eye_points = ([[2.0, 3.0, 6.0]], [[-3.0, 1.0, 6.0]])
        view_matrices = [
            camera_utils.look_at(self.tf_float(eye), origin, up_axis)
            for eye in eye_points
        ]
        # One clip-space transform per view, concatenated along axis 0.
        projection = tf.concat(
            [tf.matmul(self.perspective, view) for view in view_matrices],
            axis=0)
        background_value = [0.0] * 4

        rendered = rasterize_triangles.rasterize(
            tf.stack([positions, positions]), tf.stack([rgba, rgba]),
            self.cube_triangles, projection, self.image_width,
            self.image_height, background_value)

        with self.test_session() as sess:
            images = sess.run(rendered, feed_dict={})
            for view_index in range(2):
                baseline_image_path = os.path.join(
                    self.test_data_directory,
                    'Unlit_Cube_{}.png'.format(view_index))
                test_utils.expect_image_file_and_render_are_near(
                    self, sess, baseline_image_path,
                    images[view_index, :, :, :])
    def testRendersTwoCubesInBatch(self):
        """Rasterizes the cube from two camera poses in a single batch.

        Both rendered images are compared against the stored baselines
        ``Unlit_Cube_0.png`` and ``Unlit_Cube_1.png`` found in
        ``self.test_data_directory``.
        """
        positions = self.cube_vertex_positions
        # Per-vertex colour is the position remapped by x * 0.5 + 0.5, with an
        # extra channel of ones appended as alpha.
        rgba = torch.cat([positions * 0.5 + 0.5, torch.ones([8, 1])], dim=1)

        origin = torch.tensor([[0, 0, 0]], dtype=torch.float32)
        up_axis = torch.tensor([[0, 1, 0]], dtype=torch.float32)
        eye_points = ([[2, 3, 6]], [[-3, 1, 6]])
        view_matrices = [
            camera_utils.look_at(
                torch.tensor(eye, dtype=torch.float32), origin, up_axis)
            for eye in eye_points
        ]
        # One clip-space transform per view, concatenated along dim 0.
        projection = torch.cat(
            [torch.matmul(self.perspective, view) for view in view_matrices],
            dim=0)
        background_value = torch.Tensor([0., 0., 0., 0.])

        rendered = rasterize(
            torch.stack([positions, positions]), torch.stack([rgba, rgba]),
            self.cube_triangles, projection, self.image_width,
            self.image_height, background_value)

        for view_index in range(2):
            baseline_image_path = os.path.join(
                self.test_data_directory,
                "Unlit_Cube_{}.png".format(view_index))
            test_utils.expect_image_file_and_render_are_near(
                self, baseline_image_path, rendered[view_index, :, :, :])
Пример #3
0
    def setUp(self):
        """Creates the cube mesh and camera projection shared by the tests.

        Sets ``test_data_directory``, ``cube_vertex_positions``,
        ``cube_triangles``, ``image_width``, ``image_height`` and
        ``projection`` on the test case.
        """
        self.test_data_directory = 'mesh_renderer/test_data/'

        # Reset first so the constants below land in a fresh default graph.
        tf.reset_default_graph()

        corner_positions = [
            [-1, -1, 1], [-1, -1, -1], [-1, 1, -1], [-1, 1, 1],
            [1, -1, 1], [1, -1, -1], [1, 1, -1], [1, 1, 1],
        ]
        self.cube_vertex_positions = tf.constant(
            corner_positions, dtype=tf.float32)

        # Two triangles per cube face, indexing into corner_positions.
        face_triangles = [
            [0, 1, 2], [2, 3, 0], [3, 2, 6], [6, 7, 3],
            [7, 6, 5], [5, 4, 7], [4, 5, 1], [1, 0, 4],
            [5, 6, 2], [2, 1, 5], [7, 4, 0], [0, 3, 7],
        ]
        self.cube_triangles = tf.constant(face_triangles, dtype=tf.int32)

        def as_float32(values):
            return tf.constant(values, dtype=tf.float32)

        # Camera placement.
        camera_eye = as_float32([[2.0, 3.0, 6.0]])
        camera_target = as_float32([[0.0, 0.0, 0.0]])
        camera_up = as_float32([[0.0, 1.0, 0.0]])

        self.image_width = 640
        self.image_height = 480

        view_matrix = camera_utils.look_at(camera_eye, camera_target,
                                           camera_up)
        aspect_ratio = self.image_width / self.image_height
        clip_matrix = camera_utils.perspective(
            aspect_ratio, as_float32([40.0]), as_float32([0.01]),
            as_float32([10.0]))
        self.projection = tf.matmul(clip_matrix, view_matrix)
  def testRendersSimpleCube(self):
    """Rasterizes the cube from one camera pose and checks the baseline.

    The single rendered image is compared against ``Unlit_Cube_0.png`` in
    ``self.test_data_directory``.
    """

    def as_float32(values):
      return tf.constant(values, dtype=tf.float32)

    # Camera placement.
    camera_eye = as_float32([[2.0, 3.0, 6.0]])
    camera_target = as_float32([[0.0, 0.0, 0.0]])
    camera_up = as_float32([[0.0, 1.0, 0.0]])
    width, height = 640, 480

    view_matrix = camera_utils.look_at(camera_eye, camera_target, camera_up)
    clip_matrix = camera_utils.perspective(
        width / height, as_float32([40.0]), as_float32([0.01]),
        as_float32([10.0]))

    positions = self.cube_vertex_positions
    # Per-vertex colour is the position remapped by x * 0.5 + 0.5, with an
    # extra channel of ones appended as alpha.
    rgba = tf.concat([positions * 0.5 + 0.5, tf.ones([8, 1])], axis=1)

    projection = tf.matmul(clip_matrix, view_matrix)
    background_value = [0.0] * 4

    # The rasterizer is called with a leading batch axis of size one.
    rendered = rasterize_triangles.rasterize_triangles(
        tf.expand_dims(positions, axis=0), tf.expand_dims(rgba, axis=0),
        self.cube_triangles, projection, width, height, background_value)

    with self.test_session() as sess:
      batch = sess.run(rendered, feed_dict={})
      image = batch[0, ...]
      baseline_image_path = os.path.join(self.test_data_directory,
                                         'Unlit_Cube_0.png')
      test_utils.expect_image_file_and_render_are_near(
          self, sess, baseline_image_path, image)
def mesh_renderer_2(vertices,
                  triangles,
                  normals,
                  diffuse_colors,
                  camera_position,
                  camera_lookat,
                  camera_up,
                  light_positions,
                  light_intensities,
                  image_width,
                  image_height,
                  specular_colors=None,
                  shininess_coefficients=None,
                  ambient_color=None,
                  fov_y=40.0,
                  near_clip=0.01,
                  far_clip=10.0):
  """Rasterizes a scene and returns per-pixel geometry instead of an image.

  This variant takes the same arguments as the phong renderer and validates
  them the same way, but it performs no shading: the light, ambient and
  color arguments are only shape-checked (the colors are also rasterized as
  vertex attributes, but the interpolated values are not returned).

  Args:
    vertices: 3-D float32 tensor with shape [batch_size, vertex_count, 3]. Each
        triplet is an xyz position in world space.
    triangles: 2-D int32 tensor with shape [triangle_count, 3]. Each triplet
        should contain vertex indices describing a triangle such that the
        triangle's normal points toward the viewer if the forward order of the
        triplet defines a clockwise winding of the vertices. Gradients with
        respect to this tensor are not available.
    normals: 3-D float32 tensor with shape [batch_size, vertex_count, 3]. Each
        triplet is the xyz vertex normal for its corresponding vertex. Each
        vector is assumed to be already normalized.
    diffuse_colors: 3-D float32 tensor with shape [batch_size,
        vertex_count, 3]. The RGB diffuse reflection in the range [0,1] for
        each vertex.
    camera_position: 2-D tensor with shape [batch_size, 3] or 1-D tensor with
        shape [3] specifying the XYZ world space camera position.
    camera_lookat: 2-D tensor with shape [batch_size, 3] or 1-D tensor with
        shape [3] containing an XYZ point along the center of the camera's gaze.
    camera_up: 2-D tensor with shape [batch_size, 3] or 1-D tensor with shape
        [3] containing the up direction for the camera. The camera will have no
        tilt with respect to this direction.
    light_positions: a 3-D tensor with shape [batch_size, light_count, 3]. The
        XYZ position of each light in the scene. Only its rank is checked.
    light_intensities: a 3-D tensor with shape [batch_size, light_count, 3]. The
        RGB intensity values for each light. Only its rank is checked.
    image_width: int specifying desired output image width in pixels.
    image_height: int specifying desired output image height in pixels.
    specular_colors: 3-D float32 tensor with shape [batch_size,
        vertex_count, 3]. The RGB specular reflection in the range [0, 1] for
        each vertex. If supplied, shininess_coefficients is expected too.
    shininess_coefficients: a 0D-2D float32 tensor with maximum shape
       [batch_size, vertex_count], or a Python float/int. The phong shininess
       coefficient of each vertex. A 0D tensor or Python number gives a
       constant shininess coefficient across all batches and images. A 1D
       tensor must have shape [batch_size].
    ambient_color: a 2D tensor with shape [batch_size, 3], or None. Only its
        shape is checked.
    fov_y: float, int, 0D tensor, or 1D tensor with shape [batch_size]
        specifying desired output image y field of view in degrees.
    near_clip: float, int, 0D tensor, or 1D tensor with shape [batch_size]
        specifying near clipping plane distance.
    far_clip: float, int, 0D tensor, or 1D tensor with shape [batch_size]
        specifying far clipping plane distance.

  Returns:
    A tuple (pixel_positions, pixel_normals, vertex_ids,
    barycentric_coordinates):
      pixel_positions: channels 3:6 of the rasterized attribute buffer, i.e.
        the interpolated `vertices` attribute at each pixel.
      pixel_normals: channels 0:3 of the attribute buffer (the interpolated
        `normals`), L2-normalized along the last axis.
      vertex_ids, barycentric_coordinates: returned unchanged from
        rasterize_triangles.rasterize_2. NOTE(review): presumably the
        per-pixel triangle vertex indices and barycentric weights; confirm
        against that function.
    The value -1 is passed to the rasterizer as the background value for
    every attribute channel.
  Raises:
    ValueError: An invalid argument to the method is detected.
  """
  if len(vertices.shape) != 3:
    raise ValueError('Vertices must have shape [batch_size, vertex_count, 3].')
  batch_size = vertices.shape[0].value
  if len(normals.shape) != 3:
    raise ValueError('Normals must have shape [batch_size, vertex_count, 3].')
  if len(light_positions.shape) != 3:
    raise ValueError(
        'Light_positions must have shape [batch_size, light_count, 3].')
  if len(light_intensities.shape) != 3:
    raise ValueError(
        'Light_intensities must have shape [batch_size, light_count, 3].')
  if len(diffuse_colors.shape) != 3:
    raise ValueError(
        'vertex_diffuse_colors must have shape [batch_size, vertex_count, 3].')
  if (ambient_color is not None and
      ambient_color.get_shape().as_list() != [batch_size, 3]):
    raise ValueError('Ambient_color must have shape [batch_size, 3].')

  # Camera vectors given once as shape [3] are tiled across the batch.
  if camera_position.get_shape().as_list() == [3]:
    camera_position = tf.tile(
        tf.expand_dims(camera_position, axis=0), [batch_size, 1])
  elif camera_position.get_shape().as_list() != [batch_size, 3]:
    raise ValueError('Camera_position must have shape [batch_size, 3]')
  if camera_lookat.get_shape().as_list() == [3]:
    camera_lookat = tf.tile(
        tf.expand_dims(camera_lookat, axis=0), [batch_size, 1])
  elif camera_lookat.get_shape().as_list() != [batch_size, 3]:
    raise ValueError('Camera_lookat must have shape [batch_size, 3]')
  if camera_up.get_shape().as_list() == [3]:
    camera_up = tf.tile(tf.expand_dims(camera_up, axis=0), [batch_size, 1])
  elif camera_up.get_shape().as_list() != [batch_size, 3]:
    raise ValueError('Camera_up must have shape [batch_size, 3]')

  # Scalar camera parameters (Python numbers or 0D tensors) are expanded to
  # one value per batch element.
  if isinstance(fov_y, (int, float)):
    fov_y = tf.constant(batch_size * [fov_y], dtype=tf.float32)
  elif not fov_y.get_shape().as_list():
    fov_y = tf.tile(tf.expand_dims(fov_y, 0), [batch_size])
  elif fov_y.get_shape().as_list() != [batch_size]:
    raise ValueError('Fov_y must be a float, a 0D tensor, or a 1D tensor with '
                     'shape [batch_size]')
  if isinstance(near_clip, (int, float)):
    near_clip = tf.constant(batch_size * [near_clip], dtype=tf.float32)
  elif not near_clip.get_shape().as_list():
    near_clip = tf.tile(tf.expand_dims(near_clip, 0), [batch_size])
  elif near_clip.get_shape().as_list() != [batch_size]:
    raise ValueError('Near_clip must be a float, a 0D tensor, or a 1D tensor '
                     'with shape [batch_size]')
  if isinstance(far_clip, (int, float)):
    far_clip = tf.constant(batch_size * [far_clip], dtype=tf.float32)
  elif not far_clip.get_shape().as_list():
    far_clip = tf.tile(tf.expand_dims(far_clip, 0), [batch_size])
  elif far_clip.get_shape().as_list() != [batch_size]:
    raise ValueError('Far_clip must be a float, a 0D tensor, or a 1D tensor '
                     'with shape [batch_size]')

  if specular_colors is not None and shininess_coefficients is None:
    raise ValueError(
        'Specular colors were supplied without shininess coefficients.')
  if shininess_coefficients is not None and specular_colors is None:
    raise ValueError(
        'Shininess coefficients were supplied without specular colors.')
  if specular_colors is not None:
    # Since a 0-D float32 tensor is accepted, also accept a Python number.
    if isinstance(shininess_coefficients, (int, float)):
      shininess_coefficients = tf.constant(
          shininess_coefficients, dtype=tf.float32)
    if len(specular_colors.shape) != 3:
      raise ValueError('The specular colors must have shape [batch_size, '
                       'vertex_count, 3].')
    if len(shininess_coefficients.shape) > 2:
      raise ValueError('The shininess coefficients must have shape at most '
                       '[batch_size, vertex_count].')
    # Only per-vertex (2D) shininess needs to be interpolated as an extra
    # vertex attribute; lower-rank shininess is left out of the buffer.
    if len(shininess_coefficients.shape) < 2:
      vertex_attributes = tf.concat(
          [normals, vertices, diffuse_colors, specular_colors], axis=2)
    else:
      vertex_attributes = tf.concat(
          [
              normals, vertices, diffuse_colors, specular_colors,
              tf.expand_dims(shininess_coefficients, axis=2)
          ],
          axis=2)
  else:
    vertex_attributes = tf.concat([normals, vertices, diffuse_colors], axis=2)

  camera_matrices = camera_utils.look_at(camera_position, camera_lookat,
                                         camera_up)

  perspective_transforms = camera_utils.perspective(image_width / image_height,
                                                    fov_y, near_clip, far_clip)

  clip_space_transforms = tf.matmul(perspective_transforms, camera_matrices)

  # The last argument supplies -1 for every attribute channel as the
  # background value.
  (pixel_attributes, vertex_ids,
   barycentric_coordinates) = rasterize_triangles.rasterize_2(
       vertices, vertex_attributes, triangles, clip_space_transforms,
       image_width, image_height, [-1] * vertex_attributes.shape[2].value)

  # The attribute buffer is laid out in concatenation order: normals occupy
  # channels 0:3 and world-space positions channels 3:6.
  pixel_normals = tf.nn.l2_normalize(pixel_attributes[:, :, :, 0:3], dim=3)
  pixel_positions = pixel_attributes[:, :, :, 3:6]
  return pixel_positions, pixel_normals, vertex_ids, barycentric_coordinates
Пример #6
0
def mesh_renderer(vertices,
                  triangles,
                  normals,
                  diffuse_colors,
                  camera_position,
                  camera_lookat,
                  camera_up,
                  light_positions,
                  light_intensities,
                  image_width,
                  image_height,
                  specular_colors=None,
                  shininess_coefficients=None,
                  ambient_color=None,
                  fov_y=40.0,
                  near_clip=0.01,
                  far_clip=10.0):
    """Renders an input scene using phong shading, and returns an output image.

    Args:
      vertices: 3D float32 tensor with shape [batch_size, vertex_count, 3]. Each
        triplet is an xyz position in world space.
      triangles: 2D int32 tensor with shape [triangle_count, 3]. Each triplet
        should contain vertex indices describing a triangle such that the
        triangle's normal points toward the viewer if the forward order of the
        triplet defines a clockwise winding of the vertices. Gradients with
        respect to this tensor are not available.
      normals: 3D float32 tensor with shape [batch_size, vertex_count, 3]. Each
        triplet is the xyz vertex normal for its corresponding vertex. Each
        vector is assumed to be already normalized.
      diffuse_colors: 3D float32 tensor with shape [batch_size,
        vertex_count, 3]. The RGB diffuse reflection in the range [0, 1] for
        each vertex.
      camera_position: 2D tensor with shape [batch_size, 3] or 1D tensor with
        shape [3] specifying the XYZ world space camera position.
      camera_lookat: 2D tensor with shape [batch_size, 3] or 1D tensor with
        shape [3] containing an XYZ point along the center of the camera's gaze.
      camera_up: 2D tensor with shape [batch_size, 3] or 1D tensor with shape
        [3] containing the up direction for the camera. The camera will have
        no tilt with respect to this direction.
      light_positions: a 3D tensor with shape [batch_size, light_count, 3]. The
        XYZ position of each light in the scene. In the same coordinate space as
        pixel_positions.
      light_intensities: a 3D tensor with shape [batch_size, light_count, 3].
        The RGB intensity values for each light. Intensities may be above 1.
      image_width: int specifying desired output image width in pixels.
      image_height: int specifying desired output image height in pixels.
      specular_colors: 3D float32 tensor with shape [batch_size,
        vertex_count, 3]. The RGB specular reflection in the range [0, 1] for
        each vertex. If supplied, specular reflections will be computed, and
        both specular colors and shininess_coefficients are expected.
      shininess_coefficients: a 0D-2D float32 tensor with maximum shape
        [batch_size, vertex_count], or a Python float/int. The phong shininess
        coefficient of each vertex. A 0D tensor or Python number gives a
        constant shininess coefficient of all vertices across all batches and
        images. A 1D tensor must have shape [batch_size], and a single
        shininess coefficient per image is used.
      ambient_color: a 2D tensor with shape [batch_size, 3]. The RGB ambient
        color, which is added to each pixel in the scene. If None, it is
        assumed to be black.
      fov_y: float, int, 0D tensor, or 1D tensor with shape [batch_size]
        specifying desired output image y field of view in degrees.
      near_clip: float, int, 0D tensor, or 1D tensor with shape [batch_size]
        specifying near clipping plane distance.
      far_clip: float, int, 0D tensor, or 1D tensor with shape [batch_size]
        specifying far clipping plane distance.

    Returns:
      A 4D float32 tensor of shape [batch_size, image_height, image_width, 4]
      containing the lit RGBA color values for each image at each pixel. RGB
      colors are the intensity values before tonemapping and can be in the range
      [0, infinity]. Clipping to the range [0, 1] with np.clip is likely
      reasonable for both viewing and training most scenes. More complex scenes
      with multiple lights should tone map color values for display only. One
      simple tonemapping approach is to rescale color values as x/(1+x); gamma
      compression is another common technique. Alpha values are zero for
      background pixels and near one for mesh pixels.
    Raises:
      ValueError: An invalid argument to the method is detected.
    """
    if len(vertices.shape) != 3 or vertices.shape[-1] != 3:
        raise ValueError(
            "Vertices must have shape [batch_size, vertex_count, 3].")
    batch_size = vertices.shape[0]
    if len(normals.shape) != 3 or normals.shape[-1] != 3:
        raise ValueError(
            "Normals must have shape [batch_size, vertex_count, 3].")
    if len(light_positions.shape) != 3 or light_positions.shape[-1] != 3:
        raise ValueError(
            "light_positions must have shape [batch_size, light_count, 3].")
    if len(light_intensities.shape) != 3 or light_intensities.shape[-1] != 3:
        raise ValueError(
            "light_intensities must have shape [batch_size, light_count, 3].")
    if len(diffuse_colors.shape) != 3 or diffuse_colors.shape[-1] != 3:
        raise ValueError(
            "diffuse_colors must have shape [batch_size, vertex_count, 3].")
    if (ambient_color is not None
            and list(ambient_color.shape) != [batch_size, 3]):
        raise ValueError("ambient_color must have shape [batch_size, 3].")

    # Camera vectors given once as shape [3] are repeated across the batch.
    if list(camera_position.shape) == [3]:
        camera_position = torch.unsqueeze(camera_position,
                                          0).repeat(batch_size, 1)
    elif list(camera_position.shape) != [batch_size, 3]:
        raise ValueError(
            "camera_position must have shape [batch_size, 3] or [3].")
    if list(camera_lookat.shape) == [3]:
        camera_lookat = torch.unsqueeze(camera_lookat, 0).repeat(batch_size, 1)
    elif list(camera_lookat.shape) != [batch_size, 3]:
        raise ValueError(
            "camera_lookat must have shape [batch_size, 3] or [3].")
    if list(camera_up.shape) == [3]:
        camera_up = torch.unsqueeze(camera_up, 0).repeat(batch_size, 1)
    elif list(camera_up.shape) != [batch_size, 3]:
        raise ValueError("camera_up must have shape [batch_size, 3] or [3].")

    # Scalar camera parameters (Python numbers or 0D tensors) are expanded to
    # one value per batch element.
    if isinstance(fov_y, (int, float)):
        fov_y = torch.tensor(batch_size * [fov_y], dtype=torch.float32)
    elif len(fov_y.shape) == 0:
        fov_y = torch.unsqueeze(fov_y, 0).repeat(batch_size)
    elif list(fov_y.shape) != [batch_size]:
        raise ValueError("fov_y must be a float, a 0D tensor, or a 1D tensor "
                         "with shape [batch_size].")
    if isinstance(near_clip, (int, float)):
        near_clip = torch.tensor(batch_size * [near_clip], dtype=torch.float32)
    elif len(near_clip.shape) == 0:
        near_clip = torch.unsqueeze(near_clip, 0).repeat(batch_size)
    elif list(near_clip.shape) != [batch_size]:
        raise ValueError("near_clip must be a float, a 0D tensor, or a 1D "
                         "tensor with shape [batch_size].")
    if isinstance(far_clip, (int, float)):
        far_clip = torch.tensor(batch_size * [far_clip], dtype=torch.float32)
    elif len(far_clip.shape) == 0:
        far_clip = torch.unsqueeze(far_clip, 0).repeat(batch_size)
    elif list(far_clip.shape) != [batch_size]:
        raise ValueError("far_clip must be a float, a 0D tensor, or a 1D "
                         "tensor with shape [batch_size].")

    if specular_colors is not None and shininess_coefficients is None:
        raise ValueError(
            "Specular colors were supplied without shininess coefficients.")
    if shininess_coefficients is not None and specular_colors is None:
        raise ValueError(
            "Shininess coefficients were supplied without specular colors.")
    if specular_colors is not None:
        # Since a 0D float32 tensor is accepted, also accept a Python number.
        if isinstance(shininess_coefficients, (int, float)):
            shininess_coefficients = torch.tensor(shininess_coefficients,
                                                  dtype=torch.float32)
        if len(specular_colors.shape) != 3:
            raise ValueError(
                "The specular colors must have shape [batch_size, "
                "vertex_count, 3].")
        if len(shininess_coefficients.shape) > 2:
            raise ValueError("The shininess coefficients must have shape at "
                             "most [batch_size, vertex_count].")
        # If we don't have per-vertex coefficients, we can just reshape the
        # input shininess to broadcast later, rather than interpolating an
        # additional vertex attribute:
        if len(shininess_coefficients.shape) < 2:
            vertex_attributes = torch.cat(
                [normals, vertices, diffuse_colors, specular_colors], 2)
        else:
            vertex_attributes = torch.cat([
                normals, vertices, diffuse_colors, specular_colors,
                torch.unsqueeze(shininess_coefficients, 2)
            ], 2)
    else:
        vertex_attributes = torch.cat([normals, vertices, diffuse_colors], 2)

    camera_matrices = camera_utils.look_at(camera_position, camera_lookat,
                                           camera_up)

    perspective_transforms = camera_utils.perspective(
        image_width / image_height, fov_y, near_clip, far_clip)

    clip_space_transforms = torch.matmul(perspective_transforms,
                                         camera_matrices)

    # The last argument supplies -1 for every attribute channel as the
    # background value; the pixel mask below relies on that.
    pixel_attributes = rasterize(vertices, vertex_attributes, triangles,
                                 clip_space_transforms, image_width,
                                 image_height,
                                 [-1] * vertex_attributes.shape[2])

    # Extract the interpolated vertex attributes from the pixel buffer and
    # supply them to the shader:
    pixel_normals = torch.nn.functional.normalize(pixel_attributes[:, :, :,
                                                                   0:3],
                                                  p=2,
                                                  dim=3)
    pixel_positions = pixel_attributes[:, :, :, 3:6]
    diffuse_colors = pixel_attributes[:, :, :, 6:9]
    if specular_colors is not None:
        specular_colors = pixel_attributes[:, :, :, 9:12]
        # Retrieve the interpolated shininess coefficients if necessary, or just
        # reshape our input for broadcasting:
        if len(shininess_coefficients.shape) == 2:
            shininess_coefficients = pixel_attributes[:, :, :, 12]
        else:
            shininess_coefficients = torch.reshape(shininess_coefficients,
                                                   [-1, 1, 1])

    # A pixel counts as covered when any diffuse channel is non-negative;
    # background pixels carry -1 in every channel. torch.Tensor has no
    # `reduce` method, so the reduction is spelled with torch.any.
    pixel_mask = torch.any(diffuse_colors >= 0.0, dim=3).type(torch.float32)

    renders = phong_shader(normals=pixel_normals,
                           alphas=pixel_mask,
                           pixel_positions=pixel_positions,
                           light_positions=light_positions,
                           light_intensities=light_intensities,
                           diffuse_colors=diffuse_colors,
                           camera_position=camera_position
                           if specular_colors is not None else None,
                           specular_colors=specular_colors,
                           shininess_coefficients=shininess_coefficients,
                           ambient_color=ambient_color)
    return renders