Example #1
    def call(self, x, training):
        """Computes rate and distortion losses."""
        entropy_model = tfc.LocationScaleIndexedEntropyModel(tfc.NoisyNormal,
                                                             self.num_scales,
                                                             self.scale_fn,
                                                             coding_rank=3,
                                                             compression=False)
        side_entropy_model = tfc.ContinuousBatchedEntropyModel(
            self.hyperprior, coding_rank=3, compression=False)

        x = tf.cast(
            x, self.compute_dtype)  # TODO(jonycgn): Why is this necessary?
        y = self.analysis_transform(x)
        z = self.hyper_analysis_transform(abs(y))
        z_hat, side_bits = side_entropy_model(z, training=training)
        indexes = self.hyper_synthesis_transform(z_hat)
        y_hat, bits = entropy_model(y, indexes, training=training)
        x_hat = self.synthesis_transform(y_hat)

        # Total number of bits divided by total number of pixels.
        num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), bits.dtype)
        bpp = (tf.reduce_sum(bits) + tf.reduce_sum(side_bits)) / num_pixels
        # Mean squared error across pixels.
        mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
        mse = tf.cast(mse, bpp.dtype)
        # The rate-distortion Lagrangian.
        loss = bpp + self.lmbda * mse
        return loss, bpp, mse
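
A minimal training-step sketch built around the call() method above. The optimizer, learning rate, and the assumption that the surrounding class is a tf.keras.Model are illustrative, not part of the original example:

    import tensorflow as tf

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)  # assumed optimizer

    @tf.function
    def train_step(model, x):
        # Forward pass returns the rate-distortion loss and its components.
        with tf.GradientTape() as tape:
            loss, bpp, mse = model(x, training=True)
        # Backpropagate through the transforms and both entropy models.
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss, bpp, mse
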
Example #2
 def __init__(self, lmbda, context_length, num_filters, num_scales,
              scale_min, scale_max):
     super().__init__()
     self.lmbda = lmbda
     self.num_scales = num_scales
     offset = tf.math.log(scale_min)
     factor = (tf.math.log(scale_max) -
               tf.math.log(scale_min)) / (num_scales - 1.)
     self.context_model = ContextModel(context_length=context_length,
                                       num_filters=num_filters)
     self.scale_fn = lambda i: tf.math.exp(offset + factor * i)
      self.hyperprior = tfc.NoisyDeepFactorized(batch_shape=(num_filters,))
     self.analysis_transform = AnalysisTransform(
         context_length=context_length, num_filters=num_filters)
     self.hyper_analysis_transform = HyperAnalysisTransform(
         context_length=context_length, num_filters=num_filters)
     self.hyper_synthesis_transform = HyperSynthesisTransform(
         context_length=context_length, num_filters=num_filters)
     self.synthesis_transform = SynthesisTransform(
         context_length=context_length, num_filters=num_filters)
     self.entropy_model = tfc.LocationScaleIndexedEntropyModel(
         tfc.NoisyNormal,
         self.num_scales,
         self.scale_fn,
         coding_rank=3,
         compression=False)
     self.side_entropy_model = tfc.ContinuousBatchedEntropyModel(
         self.hyperprior, coding_rank=3, compression=False)
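
The scale_fn defined above maps an integer index onto a log-linear grid running from scale_min to scale_max, so the indexed entropy model can look up one of num_scales Gaussian scales. A quick sketch to verify the endpoints; the concrete values are illustrative assumptions (they match constants commonly used in the tensorflow_compression example models):

    import tensorflow as tf

    scale_min, scale_max, num_scales = 0.11, 256., 64  # assumed values
    offset = tf.math.log(scale_min)
    factor = (tf.math.log(scale_max) - tf.math.log(scale_min)) / (num_scales - 1.)
    scale_fn = lambda i: tf.math.exp(offset + factor * i)
    # The first and last grid points recover scale_min and scale_max.
    print(scale_fn(0.).numpy())               # ~0.11
    print(scale_fn(num_scales - 1.).numpy())  # ~256.0
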
Example #3
 def fit(self, *args, **kwargs):
   retval = super().fit(*args, **kwargs)
   # After training, fix range coding tables.
   self.em_z = tfc.ContinuousBatchedEntropyModel(
       self.hyperprior, coding_rank=3, compression=True,
       offset_heuristic=False)
   self.em_y = tfc.LocationScaleIndexedEntropyModel(
       tfc.NoisyNormal, num_scales=self.num_scales, scale_fn=self.scale_fn,
       coding_rank=3, compression=True)
   return retval
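
Rebuilding the entropy models with compression=True makes them precompute integer range coding tables from the now-trained prior, which is why this happens only after fit() returns. A hedged roundtrip sketch with the rebuilt em_z; `model` and the latent shape are assumptions:

    import tensorflow as tf

    z = tf.random.normal([1, 8, 8, 192])  # stand-in hyper-latent, assumed shape
    z_string = model.em_z.compress(z)     # `model` assumed to be an instance of the class above
    z_hat = model.em_z.decompress(z_string, tf.shape(z)[1:-1])
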
Example #4
 def fit(self, *args, **kwargs):
     retval = super().fit(*args, **kwargs)
     # After training, fix range coding tables.
     self.entropy_model = tfc.LocationScaleIndexedEntropyModel(
         tfc.NoisyNormal,
         self.num_scales,
         self.scale_fn,
         coding_rank=3,
         compression=True)
     self.side_entropy_model = tfc.ContinuousBatchedEntropyModel(
         self.hyperprior, coding_rank=3, compression=True)
     return retval
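
Once the range coding tables are fixed, the full compression path mirrors the forward pass from Example #1. The sketch below is an outline under assumptions: `model` is an instance of the class above and `x` is an input image batch.

    y = model.analysis_transform(x)
    z = model.hyper_analysis_transform(abs(y))
    z_string = model.side_entropy_model.compress(z)
    z_hat = model.side_entropy_model.decompress(z_string, tf.shape(z)[1:-1])
    indexes = model.hyper_synthesis_transform(z_hat)
    y_string = model.entropy_model.compress(y, indexes)
    y_hat = model.entropy_model.decompress(y_string, indexes)
    x_hat = model.synthesis_transform(y_hat)
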
Example #5
    def _run(self, mode, x=None, bit_strings=None):
        """Run model according to `mode` (train, compress, or decompress)."""
        training = (mode == "train")

        if mode == "decompress":
            x_shape, y_shape, z_shape, z_string = bit_strings[:4]
            y_strings = bit_strings[4:]
            assert len(y_strings) == NUM_SLICES
        else:
            y_strings = []
            x_shape = tf.shape(x)[1:-1]

            # Build the encoder (analysis) half of the hierarchical autoencoder.
            y = self.analysis_transform(x)
            y_shape = tf.shape(y)[1:-1]

            z = self.hyper_analysis_transform(y)
            z_shape = tf.shape(z)[1:-1]

        if mode == "train":
            num_pixels = self.args.batchsize * self.args.patchsize**2
        else:
            num_pixels = tf.cast(tf.reduce_prod(x_shape), tf.float32)

        # Build the entropy model for the hyperprior (z).
        em_z = tfc.ContinuousBatchedEntropyModel(self.entropy_bottleneck,
                                                 coding_rank=3,
                                                 compression=not training,
                                                 no_variables=True)

        if mode != "decompress":
            # When training, z_bpp is based on the noisy version of z (z_tilde).
            _, z_bits = em_z(z, training=training)
            z_bpp = tf.reduce_mean(z_bits) / num_pixels

        if training:
            # Use rounding (instead of uniform noise) to modify z before passing it
            # to the hyper-synthesis transforms. Note that quantize() overrides the
            # gradient to create a straight-through estimator.
            z_hat = em_z.quantize(z)
            z_string = None
        else:
            if mode == "compress":
                z_string = em_z.compress(z)
            z_hat = em_z.decompress(z_string, z_shape)

        # Build the decoder (synthesis) half of the hierarchical autoencoder.
        latent_scales = self.hyper_synthesis_scale_transform(z_hat)
        latent_means = self.hyper_synthesis_mean_transform(z_hat)

        # En/Decode each slice conditioned on hyperprior and previous slices.
        y_slices = (y_strings if mode == "decompress" else tf.split(
            y, NUM_SLICES, axis=-1))
        y_hat_slices = []
        y_bpps = []
        for slice_index, y_slice in enumerate(y_slices):
            # Model may condition on only a subset of previous slices.
            support_slices = (y_hat_slices if MAX_SUPPORT_SLICES < 0 else
                              y_hat_slices[:MAX_SUPPORT_SLICES])

            # Predict mu and sigma for the current slice.
            mean_support = tf.concat([latent_means] + support_slices, axis=-1)
            mu = self.cc_mean_transforms[slice_index](mean_support)
            mu = mu[:, :y_shape[0], :y_shape[1], :]

            # Note that in this implementation, `sigma` represents scale indices,
            # not actual scale values.
            scale_support = tf.concat([latent_scales] + support_slices,
                                      axis=-1)
            sigma = self.cc_scale_transforms[slice_index](scale_support)
            sigma = sigma[:, :y_shape[0], :y_shape[1], :]

            # Build the conditional entropy model for this slice.
            em_y = tfc.LocationScaleIndexedEntropyModel(
                tfc.NoisyNormal,
                num_scales=SCALES_LEVELS,
                scale_fn=scale_fn,
                coding_rank=3,
                compression=not training,
                no_variables=True)

            if mode == "decompress":
                y_hat_slice = em_y.decompress(y_slice, sigma, loc=mu)
            else:
                _, slice_bits = em_y(y_slice, sigma, loc=mu, training=training)
                slice_bpp = tf.reduce_mean(slice_bits) / num_pixels
                y_bpps.append(slice_bpp)

                if training:
                    # For the synthesis transform, use rounding. Note that quantize()
                    # overrides the gradient to create a straight-through estimator.
                    y_hat_slice = em_y.quantize(y_slice, sigma, loc=mu)
                else:
                    assert mode == "compress"
                    slice_string = em_y.compress(y_slice, sigma, loc=mu)
                    y_strings.append(slice_string)
                    y_hat_slice = em_y.decompress(slice_string, sigma, loc=mu)

            # Add latent residual prediction (LRP).
            lrp_support = tf.concat([mean_support, y_hat_slice], axis=-1)
            lrp = self.lrp_transforms[slice_index](lrp_support)
            lrp = 0.5 * tf.math.tanh(lrp)
            y_hat_slice += lrp

            y_hat_slices.append(y_hat_slice)

        # Merge slices and generate the image reconstruction.
        y_hat = tf.concat(y_hat_slices, axis=-1)
        x_hat = self.synthesis_transform(y_hat)
        x_hat = x_hat[:, :x_shape[0], :x_shape[1], :]

        if mode != "decompress":
            # Total bpp is sum of bpp from hyperprior and all slices.
            total_bpp = tf.add_n(y_bpps + [z_bpp])

        # Mean squared error across pixels.
        if training:
            # Don't clip or round pixel values while training.
            mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
            mse *= 255**2  # multiply by 255^2 to correct for rescaling
        else:
            x_hat = tf.clip_by_value(x_hat, 0, 1)
            x_hat = tf.round(x_hat * 255)
            if mode == "compress":
                mse = tf.reduce_mean(tf.math.squared_difference(
                    x * 255, x_hat))

        if mode == "train":
            # Calculate and return the rate-distortion loss: R + lambda * D.
            loss = total_bpp + self.args.lmbda * mse

            tf.summary.scalar("bpp", total_bpp)
            tf.summary.scalar("mse", mse)
            tf.summary.scalar("loss", loss)
            tf.summary.image("original", quantize_image(x))
            tf.summary.image("reconstruction", quantize_image(x_hat))

            return loss
        elif mode == "compress":
            # Create `pack`, a list of (tensor, value) pairs for serialization.
            tensors = [x_shape, y_shape, z_shape, z_string] + y_strings
            pack = [(v, v.numpy()) for v in tensors]
            return mse, total_bpp, x_hat, pack
        elif mode == "decompress":
            return x_hat
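
The three modes of _run() above could be driven as follows; the variable names (`batch`, `image`, `bit_strings`) are placeholders, not part of the original example:

    loss = model._run("train", x=batch)                        # training step
    mse, bpp, x_hat, pack = model._run("compress", x=image)    # encode
    x_hat = model._run("decompress", bit_strings=bit_strings)  # decode
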
Example #6
  def call(self, x, training):
    """Computes rate and distortion losses."""
    # Build the encoder (analysis) half of the hierarchical autoencoder.
    y = self.analysis_transform(x)
    y_shape = tf.shape(y)[1:-1]

    z = self.hyper_analysis_transform(y)

    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[1:-1]), tf.float32)

    # Build the entropy model for the hyperprior (z).
    em_z = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior, coding_rank=3, compression=False,
        offset_heuristic=False)

    # When training, z_bpp is based on the noisy version of z (z_tilde).
    _, z_bits = em_z(z, training=training)
    z_bpp = tf.reduce_mean(z_bits) / num_pixels

    # Use rounding (instead of uniform noise) to modify z before passing it
    # to the hyper-synthesis transforms. Note that quantize() overrides the
    # gradient to create a straight-through estimator.
    z_hat = em_z.quantize(z)

    # Build the decoder (synthesis) half of the hierarchical autoencoder.
    latent_scales = self.hyper_synthesis_scale_transform(z_hat)
    latent_means = self.hyper_synthesis_mean_transform(z_hat)

    # Build a conditional entropy model for the slices.
    em_y = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal, num_scales=self.num_scales, scale_fn=self.scale_fn,
        coding_rank=3, compression=False)

    # En/Decode each slice conditioned on hyperprior and previous slices.
    y_slices = tf.split(y, self.num_slices, axis=-1)
    y_hat_slices = []
    y_bpps = []
    for slice_index, y_slice in enumerate(y_slices):
      # Model may condition on only a subset of previous slices.
      support_slices = (y_hat_slices if self.max_support_slices < 0 else
                        y_hat_slices[:self.max_support_slices])

      # Predict mu and sigma for the current slice.
      mean_support = tf.concat([latent_means] + support_slices, axis=-1)
      mu = self.cc_mean_transforms[slice_index](mean_support)
      mu = mu[:, :y_shape[0], :y_shape[1], :]

      # Note that in this implementation, `sigma` represents scale indices,
      # not actual scale values.
      scale_support = tf.concat([latent_scales] + support_slices, axis=-1)
      sigma = self.cc_scale_transforms[slice_index](scale_support)
      sigma = sigma[:, :y_shape[0], :y_shape[1], :]

      _, slice_bits = em_y(y_slice, sigma, loc=mu, training=training)
      slice_bpp = tf.reduce_mean(slice_bits) / num_pixels
      y_bpps.append(slice_bpp)

      # For the synthesis transform, use rounding. Note that quantize()
      # overrides the gradient to create a straight-through estimator.
      y_hat_slice = em_y.quantize(y_slice, loc=mu)

      # Add latent residual prediction (LRP).
      lrp_support = tf.concat([mean_support, y_hat_slice], axis=-1)
      lrp = self.lrp_transforms[slice_index](lrp_support)
      lrp = 0.5 * tf.math.tanh(lrp)
      y_hat_slice += lrp

      y_hat_slices.append(y_hat_slice)

    # Merge slices and generate the image reconstruction.
    y_hat = tf.concat(y_hat_slices, axis=-1)
    x_hat = self.synthesis_transform(y_hat)

    # Total bpp is sum of bpp from hyperprior and all slices.
    total_bpp = tf.add_n(y_bpps + [z_bpp])

    # Mean squared error across pixels.
    # Don't clip or round pixel values while training.
    mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))

    # Calculate and return the rate-distortion loss: R + lambda * D.
    loss = total_bpp + self.lmbda * mse

    return loss, total_bpp, mse
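
One constraint worth noting for the slicing above: tf.split along the channel axis requires the channel count of y to be an integer multiple of num_slices. A quick self-contained check with illustrative shapes:

    import tensorflow as tf

    y = tf.zeros([1, 16, 16, 320])     # illustrative latent: 320 channels
    slices = tf.split(y, 10, axis=-1)  # 10 slices of 32 channels each
    assert all(s.shape[-1] == 32 for s in slices)
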