コード例 #1
0
    def testDecompositionIsNonRedundant(self):
        """Test that wavelet construction is not redundant.

        If the wavelet decomposition is not redundant, then we should be able
        to
        1) construct a wavelet decomposition,
        2) alter a single coefficient in the decomposition,
        3) collapse that decomposition into an image and back,
        and the two wavelet decompositions should be the same.

        Every random trial (4 per wavelet type) goes through the
        perturb / collapse / reconstruct / compare cycle.
        """
        for wavelet_type in wavelet.generate_filters():
            for _ in range(4):
                # Construct an image and a wavelet decomposition of it.
                num_levels = np.int32(np.ceil(4 * np.random.uniform()))
                # Clamp each spatial side so `num_levels` levels fit.
                sz_clamp = 2**(num_levels - 1) + 1
                sz = np.maximum(
                    np.int32(
                        np.ceil(
                            np.array([2, 32, 32]) *
                            np.random.uniform(size=3))),
                    np.array([0, sz_clamp, sz_clamp]))
                im = np.random.uniform(size=sz)
                # Convert to a list so individual levels can be replaced.
                pyr = list(wavelet.construct(im, num_levels, wavelet_type))

                # Pick a coefficient at random in the decomposition to alter.
                # This must happen inside the trial loop, otherwise only the
                # last random trial would ever be checked.
                d = np.int32(np.floor(np.random.uniform() * len(pyr)))
                v = np.random.uniform()
                if d == (len(pyr) - 1):
                    # The last level is indexed as a single 3-D tensor.
                    if np.prod(pyr[d].shape) > 0:
                        c, i, j = np.int32(
                            np.floor(
                                np.random.uniform(size=3) *
                                pyr[d].shape)).tolist()
                        pyr[d] = pyr[d].numpy()
                        pyr[d][c, i, j] = v
                else:
                    # Other levels are sequences of 3-D tensors; pick one.
                    b = np.int32(np.floor(np.random.uniform() * len(pyr[d])))
                    if np.prod(pyr[d][b].shape) > 0:
                        c, i, j = np.int32(
                            np.floor(
                                np.random.uniform(size=3) *
                                pyr[d][b].shape)).tolist()
                        pyr[d] = list(pyr[d])
                        pyr[d][b] = pyr[d][b].numpy()
                        pyr[d][b][c, i, j] = v

                # Collapse and then reconstruct the wavelet decomposition, and
                # check that it is unchanged.
                recon = wavelet.collapse(pyr, wavelet_type)
                pyr_again = wavelet.construct(recon, num_levels, wavelet_type)
                self._assert_pyramids_close(pyr, pyr_again, 1e-8)
コード例 #2
0
 def testConstructMatchesGoldenData(self):
     """Checks wavelet.construct() output against stored golden data."""
     image, golden_pyr, wtype = self._load_golden_data()
     # The golden pyramid has one more entry than the number of levels.
     levels = len(golden_pyr) - 1
     computed = wavelet.construct(image, levels, wtype)
     with self.session() as session:
         computed = session.run(computed)
     self._assert_pyramids_close(computed, golden_pyr, 1e-5)
コード例 #3
0
 def testWaveletTransformationIsVolumePreserving(self):
     """Tests that construct() is volume preserving when size is a power of 2.

     For each wavelet type, the Jacobian of flatten(construct(im)) with respect
     to `im` is assembled one output coordinate at a time (as columns). The
     test then asserts that its determinant is close to 1.
     """
     for wavelet_type in wavelet.generate_filters():
         sz = (1, 4, 4)
         num_levels = 2
         im = np.float32(np.random.uniform(0., 1., sz))
         im_ph = tf.placeholder(tf.float32, im.shape)
         # Build the flattened decomposition graph once; only the gradient
         # target changes per Jacobian column.
         flat = tf.reshape(
             wavelet.flatten(
                 wavelet.construct(im_ph, num_levels, wavelet_type)), [-1])
         columns = [
             tf.reshape(tf.gradients(flat[d], im_ph)[0], [-1])
             for d in range(im.size)
         ]
         jacobian = tf.stack(columns, 1)
         with self.session() as sess:
             jacobian = sess.run(jacobian, {im_ph: im})
         # Assert that the determinant of the Jacobian is close to 1.
         det = np.linalg.det(jacobian)
         self.assertAllClose(det, 1., atol=1e-5, rtol=1e-5)
コード例 #4
0
 def _construct_preserves_dtype(self, float_dtype):
   """Asserts that construct() keeps the floating-point precision of its input.

   Args:
     float_dtype: callable used both to cast the random input and as the
       expected dtype of the flattened output (presumably a numpy float type
       such as np.float32 -- confirm against callers).
   """
   signal = float_dtype(np.random.normal(size=(3, 16, 16)))
   for wtype in wavelet.generate_filters():
     with self.session():
       pyr = wavelet.construct(signal, 3, wtype)
       flat = wavelet.flatten(pyr).eval()
     self.assertDTypeEqual(flat, float_dtype)
コード例 #5
0
 def testRescaleAndUnrescaleReproducesInput(self):
   """Tests that rescale(rescale(x, k), 1/k) = x."""
   image = np.random.uniform(size=(2, 32, 32))
   base = np.exp(np.random.normal())
   original = wavelet.construct(image, 4, 'LeGall5/3')
   # Rescale by `base`, then undo it with the reciprocal base.
   scaled = wavelet.rescale(original, base)
   round_trip = wavelet.rescale(scaled, 1. / base)
   self._assert_pyramids_close(original, round_trip, 1e-8)
コード例 #6
0
 def testRescaleDoesNotAffectTheFirstLevel(self):
     """Tests that rescale(x, s)[0] = x[0] for any s."""
     image = np.random.uniform(size=(2, 32, 32))
     pyr = wavelet.construct(image, 4, 'LeGall5/3')
     random_base = np.exp(np.random.normal())
     rescaled = wavelet.rescale(pyr, random_base)
     with self.session() as session:
         pyr_np, rescaled_np = session.run([pyr, rescaled])
     # Only the first level is compared.
     self._assert_pyramids_close(pyr_np[:1], rescaled_np[:1], 1e-8)
コード例 #7
0
 def testRescaleOneIsANoOp(self):
     """Tests that rescale(x, 1) = x."""
     image = np.random.uniform(size=(2, 32, 32))
     pyr = wavelet.construct(image, 4, 'LeGall5/3')
     unchanged = wavelet.rescale(pyr, 1.)
     with self.session() as session:
         pyr_np, unchanged_np = session.run([pyr, unchanged])
     self._assert_pyramids_close(pyr_np, unchanged_np, 1e-8)
コード例 #8
0
  def testAccurateRoundTripWithSmallRandomImages(self):
    """Tests that collapse(construct(x)) == x for x = [1, k, k], k in [1, 4]."""
    for wtype in wavelet.generate_filters():
      for k in range(1, 5):
        shape = [1, k, k]
        # Use as many levels as this tiny image supports.
        levels = wavelet.get_max_num_levels(shape)
        image = np.random.uniform(size=shape)
        decomposition = wavelet.construct(image, levels, wtype)
        self.assertAllClose(
            wavelet.collapse(decomposition, wtype), image,
            atol=1e-8, rtol=1e-8)
コード例 #9
0
 def testRescaleOneHalfIsNormalized(self):
     """Tests that rescale(construct(k), 0.5)[-1] = k for constant image k."""
     for levels in range(5):
         k = np.random.uniform()
         constant_image = np.full((2, 32, 32), k)
         decomposition = wavelet.construct(constant_image, levels, 'LeGall5/3')
         rescaled = wavelet.rescale(decomposition, 0.5)
         # The last entry of the rescaled pyramid should equal k everywhere.
         last = rescaled[-1]
         self.assertAllClose(last, k * np.ones_like(last),
                             atol=1e-8, rtol=1e-8)
コード例 #10
0
 def testAccurateRoundTripWithLargeRandomImages(self):
   """Tests that collapse(construct(x)) == x for large random x's."""
   for wtype in wavelet.generate_filters():
     for _ in range(4):
       levels = np.int32(np.ceil(4 * np.random.uniform()))
       # Clamp each spatial side to at least 2**(levels - 1) + 1.
       min_side = 2**(levels - 1) + 1
       drawn = np.int32(
           np.ceil(np.array([2, 32, 32]) * np.random.uniform(size=3)))
       shape = np.maximum(drawn, np.array([0, min_side, min_side]))
       image = np.random.uniform(size=shape)
       decomposition = wavelet.construct(image, levels, wtype)
       recon = wavelet.collapse(decomposition, wtype)
       self.assertAllClose(recon, image, atol=1e-8, rtol=1e-8)
コード例 #11
0
 def fun(z):
   """Returns the flattened wavelet decomposition of `z`."""
   # pylint: disable=cell-var-from-loop
   # `num_levels` and `wavelet_type` are looked up in the enclosing scope at
   # call time (presumably loop variables, hence the pylint suppression).
   decomposition = wavelet.construct(z, num_levels, wavelet_type)
   return wavelet.flatten(decomposition)
コード例 #12
0
ファイル: adaptive.py プロジェクト: Leedeng/LineCounter
  def __call__(self, x):
    """Evaluates the loss function on a batch of images.

    Args:
      x: The image residuals for which the loss is being computed, which is
        expected to be the differences between RGB images. Must be a rank-4
        tensor, where the outermost (first) dimension is the batch index, and
        the remaining 3 dimensions must equal `self._image_size`
        (width, height, num_channels): two spatial, one channel.

    Returns:
      A TF tensor of the same type and shape as input `x`, containing
      the loss at each element of `x` as a function of `x`, `alpha`, and
      `scale`. These "losses" are actually negative log-likelihoods (as produced
      by distribution.nllfun()) and so they are not actually bounded from below
      by zero --- it is okay if they go negative! You'll probably want to
      minimize their sum or mean. Note that when a wavelet or DCT
      representation is used, the loss is computed on the transformed
      coefficients and merely reshaped to `x`'s shape, so entries are not
      aligned with individual pixels.

    Raises:
      The rank check (tf.debugging.assert_rank) and the shape check
      (tf.Assert) fail if `x` is not rank 4 or its non-batch shape differs
      from `self._image_size`.
    """
    x = tf.convert_to_tensor(x)
    tf.debugging.assert_rank(x, 4)
    # Ops built inside this context depend on the shape check succeeding.
    with tf.control_dependencies([
        tf.Assert(
            tf.reduce_all(tf.equal(x.shape[1:], self._image_size)),
            [x.shape[1:], self._image_size])
    ]):
      if self._color_space == 'YUV':
        x = util.rgb_to_syuv(x)
      # If `color_space` == 'RGB', do nothing.

      # Reshape `x` from
      #   (num_batches, width, height, num_channels) to
      #   (num_batches * num_channels, width, height)
      width, height, num_channels = self._image_size
      x_stack = tf.reshape(
          tf.transpose(x, perm=(0, 3, 1, 2)), (-1, width, height))

      # Turn each channel in `x_stack` into the spatial representation specified
      # by `representation`.
      if self._representation in wavelet.generate_filters():
        x_stack = wavelet.flatten(
            wavelet.rescale(
                wavelet.construct(x_stack, self._wavelet_num_levels,
                                  self._representation),
                self._wavelet_scale_base))
      elif self._representation == 'DCT':
        x_stack = util.image_dct(x_stack)
      # If `representation` == 'PIXEL', do nothing.

      # Reshape `x_stack` from
      #   (num_batches * num_channels, width, height) to
      #   (num_batches, width * height * num_channels), with the channel index
      #   varying fastest (channels are moved last before flattening).
      x_mat = tf.reshape(
          tf.transpose(
              tf.reshape(x_stack, [-1, num_channels, width, height]),
              perm=[0, 2, 3, 1]), [-1, width * height * num_channels])

      # Evaluate the per-element loss on the
      # (num_batches, width * height * num_channels) matrix.
      loss_mat = self._lossfun(x_mat)

      # Reshape the loss function's outputs to have the shapes as the input.
      loss = tf.reshape(loss_mat, [-1, width, height, num_channels])

      if self._summarize_loss:
        # Summarize the `alpha` and `scale` parameters as images (normalized to
        # [0, 1]) and histograms.
        # Note that these may look unintuitive unless the colorspace is 'RGB'
        # and the image representation is 'PIXEL', as the image summaries
        # (like most images) are rendered as RGB pixels.
        # `[tf.newaxis]` prepends a batch axis of size 1 for tf.summary.image;
        # the 1e-10 avoids division by zero when the parameter is constant.
        log_scale = tf.math.log(self.scale())
        log_scale_min = tf.reduce_min(log_scale)
        log_scale_max = tf.reduce_max(log_scale)
        tf.summary.image('/log_scale',
                         (log_scale[tf.newaxis] - log_scale_min) /
                         (log_scale_max - log_scale_min + 1e-10))
        tf.summary.histogram('/log_scale', log_scale)

        if not self._use_students_t:
          alpha = self.alpha()
          alpha_min = tf.reduce_min(alpha)
          alpha_max = tf.reduce_max(alpha)
          tf.summary.image('/alpha', (alpha[tf.newaxis] - alpha_min) /
                           (alpha_max - alpha_min + 1e-10))
          tf.summary.histogram('/alpha', alpha)

      return loss
コード例 #13
0
def image_lossfun(x,
                  color_space='YUV',
                  representation='CDF9/7',
                  wavelet_num_levels=5,
                  wavelet_scale_base=1,
                  use_students_t=False,
                  summarize_loss=True,
                  **kwargs):
    """Computes the adaptive form of the robust loss on a set of images.

    This function is a wrapper around lossfun() (or lossfun_students() when
    `use_students_t` is True). Like lossfun(), this function is not
    "stateless" --- it requires inputs of a specific shape and size, and
    constructs TF variables describing each non-batch dimension in `x`. `x` is
    expected to be the difference between sets of RGB images, and the other
    arguments to this function allow for the color space and spatial
    representation of `x` to be changed before the loss is imposed. By
    default, this function uses a CDF9/7 wavelet decomposition in a YUV color
    space, which often works well. This function also returns handles to the
    scale and shape parameters (both in the shape of images) being optimized
    over, and summarizes both parameters in TensorBoard.

    Args:
      x: A set of image residuals for which the loss is being computed. Must
        be a rank-4 tensor (or a value accepted by tf.convert_to_tensor) of
        size (num_batches, width, height, color_channels), whose non-batch
        dimensions are statically known. This is assumed to be a set of
        differences between RGB images.
      color_space: The color space that `x` will be transformed into before
        computing the loss. Must be 'RGB' (in which case no transformation is
        applied) or 'YUV' (in which case we actually use a volume-preserving
        scaled YUV colorspace so that log-likelihoods still have meaning, see
        util.rgb_to_syuv()). Note that changing this argument does not change
        the assumption that `x` is the set of differences between RGB images,
        it just changes what color space `x` is converted to from RGB when
        computing the loss.
      representation: The spatial image representation that `x` will be
        transformed into after converting the color space and before computing
        the loss. If this is a valid type of wavelet according to
        wavelet.generate_filters() then that is what will be used, but we also
        support setting this to 'DCT' which applies a 2D DCT to the images,
        and to 'PIXEL' which applies no transformation to the image, thereby
        causing the loss to be imposed directly on pixels.
      wavelet_num_levels: If `representation` is a kind of wavelet, this is
        the number of levels used when constructing wavelet representations.
        Otherwise this is ignored. Should probably be set to as large as
        possible a value that is supported by the input resolution, such as
        that produced by wavelet.get_max_num_levels().
      wavelet_scale_base: If `representation` is a kind of wavelet, this is
        the base of the scaling used when constructing wavelet
        representations. Otherwise this is ignored. For image_lossfun() to be
        volume preserving (a useful property when evaluating generative
        models) this value must be == 1. If the goal of this loss isn't proper
        statistical modeling, then modifying this value (say, setting it to
        0.5 or 2) may significantly improve performance.
      use_students_t: If true, use the NLL of Student's T-distribution instead
        of the adaptive loss. This causes all `alpha_*` inputs to be ignored.
      summarize_loss: Whether or not to make TF summaries describing the
        latent state of the loss function. True by default.
      **kwargs: Arguments to be passed to the underlying lossfun().

    Returns:
      A tuple of the form (`loss`, `alpha`, `scale`). If use_students_t ==
      True, then `log(df)` is returned instead of `alpha`.

      `loss`: a TF tensor of the same type and shape as input `x`, containing
      the loss at each element of `x` as a function of `x`, `alpha`, and
      `scale`. These "losses" are actually negative log-likelihoods (as
      produced by distribution.nllfun()) and so they are not actually bounded
      from below by zero. You'll probably want to minimize their sum or mean.

      `scale`: a TF tensor of the same type as x, of size
      (width, height, color_channels), as we construct a scale variable for
      each spatial and color dimension of `x` but not for each batch element.
      This contains the current estimated scale parameter for each dimension,
      and will change during optimization.

      `alpha`: a TF tensor of the same type as x, of size
      (width, height, color_channels), as we construct an alpha variable for
      each spatial and color dimension of `x` but not for each batch element.
      This contains the current estimated alpha parameter for each dimension,
      and will change during optimization.

    Raises:
      ValueError: if `color_space` or `representation` are unsupported color
        spaces or image representations, respectively.
    """
    color_spaces = ['RGB', 'YUV']
    if color_space not in color_spaces:
        raise ValueError('`color_space` must be in {}, but is {!r}'.format(
            color_spaces, color_space))
    # Query the wavelet names once. `list()` makes the `+` below valid for any
    # iterable of names that generate_filters() may return.
    wavelet_types = list(wavelet.generate_filters())
    representations = wavelet_types + ['DCT', 'PIXEL']
    if representation not in representations:
        raise ValueError('`representation` must be in {}, but is {!r}'.format(
            representations, representation))

    # `x.shape.as_list()` below needs a tf.Tensor; without this, numpy input
    # with color_space='RGB' would fail (a tuple shape has no `as_list`).
    x = tf.convert_to_tensor(x)
    assert_ops = [tf.Assert(tf.equal(tf.rank(x), 4), [tf.rank(x)])]
    with tf.control_dependencies(assert_ops):
        if color_space == 'YUV':
            x = util.rgb_to_syuv(x)
        # If `color_space` == 'RGB', do nothing.

        # Reshape `x` from
        #   (num_batches, width, height, num_channels) to
        #   (num_batches * num_channels, width, height)
        _, width, height, num_channels = x.shape.as_list()
        x_stack = tf.reshape(tf.transpose(x, (0, 3, 1, 2)),
                             (-1, width, height))

        # Turn each channel in `x_stack` into the spatial representation
        # specified by `representation`.
        if representation in wavelet_types:
            x_stack = wavelet.flatten(
                wavelet.rescale(
                    wavelet.construct(x_stack, wavelet_num_levels,
                                      representation), wavelet_scale_base))
        elif representation == 'DCT':
            x_stack = util.image_dct(x_stack)
        # If `representation` == 'PIXEL', do nothing.

        # Reshape `x_stack` from
        #   (num_batches * num_channels, width, height) to
        #   (num_batches, width * height * num_channels)
        x_mat = tf.reshape(
            tf.transpose(
                tf.reshape(x_stack, [-1, num_channels, width, height]),
                [0, 2, 3, 1]), [-1, width * height * num_channels])

        # Set up the adaptive loss. Note, if `use_students_t` == True then
        # `alpha_mat` actually contains "log(df)" values.
        if use_students_t:
            loss_mat, alpha_mat, scale_mat = lossfun_students(x_mat, **kwargs)
        else:
            loss_mat, alpha_mat, scale_mat = lossfun(x_mat, **kwargs)

        # Reshape the loss function's outputs to have the shapes as the input.
        loss = tf.reshape(loss_mat, [-1, width, height, num_channels])
        alpha = tf.reshape(alpha_mat, [width, height, num_channels])
        scale = tf.reshape(scale_mat, [width, height, num_channels])

        if summarize_loss:
            # Summarize the `alpha` and `scale` parameters as images
            # (normalized to [0, 1]) and histograms.
            # Note that these may look unintuitive unless the colorspace is
            # 'RGB' and the image representation is 'PIXEL', as the image
            # summaries (like most images) are rendered as RGB pixels.
            # The 1e-10 avoids dividing by zero for constant parameters.
            alpha_min = tf.reduce_min(alpha)
            alpha_max = tf.reduce_max(alpha)
            tf.summary.image('robust/alpha', (alpha[tf.newaxis] - alpha_min) /
                             (alpha_max - alpha_min + 1e-10))
            tf.summary.histogram('robust/alpha', alpha)
            log_scale = tf.math.log(scale)
            log_scale_min = tf.reduce_min(log_scale)
            log_scale_max = tf.reduce_max(log_scale)
            tf.summary.image('robust/log_scale',
                             (log_scale[tf.newaxis] - log_scale_min) /
                             (log_scale_max - log_scale_min + 1e-10))
            tf.summary.histogram('robust/log_scale', log_scale)

        return loss, alpha, scale