def test_number_and_shape_of_scales_match_channels_last(self, num_levels):
  """Checks subband count and leading subband shapes for NHWC input.

  `num_levels` is presumably supplied by a parameterization decorator that
  lies outside this view.
  """
  pyramid = transform.NLP(num_levels=num_levels, data_format="channels_last")
  outputs = pyramid(tf.zeros((1, 16, 16, 2)))
  self.assertLen(outputs, num_levels)
  # `zip` stops at the shorter sequence, so only the first
  # min(num_levels, 2) subbands are shape-checked.
  expected = ((1, 16, 16, 2), (1, 8, 8, 2))
  for actual, want in zip(outputs, expected):
    self.assertEqual(actual.shape, want)
# Example No. 2 (score: 0)
def nlpd(image_a,
         image_b,
         num_levels=6,
         cdm_min=5,
         cdm_max=180,
         data_format="channels_last"):
    """Computes the normalized Laplacian pyramid distance (NLPD).

    Reference:

    > "Perceptually optimized image rendering"</br>
    > V. Laparra, A. Berardino, J. Ballé and E. P. Simoncelli</br>
    > https://doi.org/10.1364/JOSAA.34.001511

    Both inputs are treated as sRGB images shown on a display whose luminance
    spans `cdm_min` to `cdm_max` cd/m^2. Before the pyramid is applied, pixel
    values are divided by 255, raised to the power 2.4, and mapped linearly
    onto that luminance range.

    Args:
      image_a: `Tensor` holding the first image. Integer dtypes are cast to
        float32.
      image_b: `Tensor` holding the second image. Integer dtypes are cast to
        float32.
      num_levels: Integer. Number of pyramid levels, counting the lowpass
        residual.
      cdm_min: Float. Assumed minimum display luminance in cd/m^2.
      cdm_max: Float. Assumed maximum display luminance in cd/m^2.
      data_format: String, either `'channels_first'` or `'channels_last'`.

    Returns:
      A `Tensor` of NLPD values, one per non-spatial index (e.g. shaped NC for
      NHWC inputs).

    Raises:
      ValueError: If `0 <= cdm_min < cdm_max` does not hold.
    """
    if not 0 <= cdm_min < cdm_max:
        raise ValueError("Must have `0 <= cdm_min < cdm_max`.")

    def as_float(image):
        # Integer pixels are promoted so the division/power below is float math.
        return tf.cast(image, tf.float32) if image.dtype.is_integer else image

    image_a = as_float(image_a)
    image_b = as_float(image_b)

    display_range = cdm_max - cdm_min

    def to_luminance(image):
        # Power-law (exponent 2.4) linearization, then scale to display cd/m^2.
        linear = (image / 255) ** 2.4
        return linear * display_range + cdm_min

    pyramid = transform.NLP(num_levels=num_levels, data_format=data_format)
    subbands_a, subbands_b = [
        pyramid(to_luminance(img)) for img in (image_a, image_b)]

    spatial_axes = (-2, -1) if data_format == "channels_first" else (-3, -2)
    return lp_norm(subbands_a, subbands_b, spatial_axes=spatial_axes)
# Example No. 3 (score: 0)
def nlpd_fast(image_a, image_b, num_levels=6, data_format="channels_last"):
    """Computes a quick approximation of the normalized Laplacian pyramid distance.

    The full metric is defined in:

    > "Perceptually optimized image rendering"</br>
    > V. Laparra, A. Berardino, J. Ballé and E. P. Simoncelli</br>
    > https://doi.org/10.1364/JOSAA.34.001511

    Inputs are assumed to be sRGB images. This approximation skips the
    colorspace/display conversion. The images go straight into a
    `transform.NLP` built with `gamma=None`.

    Args:
      image_a: `Tensor` holding the first image. Integer dtypes are cast to
        float32.
      image_b: `Tensor` holding the second image. Integer dtypes are cast to
        float32.
      num_levels: Integer. Number of pyramid levels, counting the lowpass
        residual.
      data_format: String, either `'channels_first'` or `'channels_last'`.

    Returns:
      A `Tensor` of NLPD values, one per non-spatial index (e.g. shaped NC for
      NHWC inputs).
    """
    def as_float(image):
        # Promote integer pixels to float32; leave other dtypes untouched.
        return tf.cast(image, tf.float32) if image.dtype.is_integer else image

    image_a = as_float(image_a)
    image_b = as_float(image_b)

    pyramid = transform.NLP(
        num_levels=num_levels, gamma=None, data_format=data_format)
    subbands_a = pyramid(image_a)
    subbands_b = pyramid(image_b)

    channels_first = data_format == "channels_first"
    spatial_axes = (-2, -1) if channels_first else (-3, -2)
    return lp_norm(subbands_a, subbands_b, spatial_axes=spatial_axes)
 def test_invalid_shapes_fail_channels_last(self):
   """Building a channels_last NLP for a rank-2 shape raises ValueError."""
   layer = transform.NLP(data_format="channels_last")
   self.assertRaises(ValueError, layer.build, (16, 16))
 def test_invalid_data_format_fails(self):
   """A non-string `data_format` is rejected when NLP is constructed."""
   self.assertRaises(ValueError, transform.NLP, data_format=3)
 def test_invalid_gamma_fails(self):
   """A negative `gamma` is rejected when NLP is constructed."""
   self.assertRaises(ValueError, transform.NLP, gamma=-1)
 def test_invalid_num_levels_fails(self, num_levels):
   """Constructing NLP with an invalid `num_levels` raises ValueError.

   `num_levels` is presumably supplied by a parameterization decorator that
   lies outside this view.
   """
   self.assertRaises(ValueError, transform.NLP, num_levels=num_levels)