Example #1
    def call(self, x, training):
        """Computes rate and distortion losses."""
        entropy_model = tfc.LocationScaleIndexedEntropyModel(tfc.NoisyNormal,
                                                             self.num_scales,
                                                             self.scale_fn,
                                                             coding_rank=3,
                                                             compression=False)
        side_entropy_model = tfc.ContinuousBatchedEntropyModel(
            self.hyperprior, coding_rank=3, compression=False)

        x = tf.cast(
            x, self.compute_dtype)  # TODO(jonycgn): Why is this necessary?
        y = self.analysis_transform(x)
        z = self.hyper_analysis_transform(abs(y))
        z_hat, side_bits = side_entropy_model(z, training=training)
        indexes = self.hyper_synthesis_transform(z_hat)
        y_hat, bits = entropy_model(y, indexes, training=training)
        x_hat = self.synthesis_transform(y_hat)

        # Total number of bits divided by total number of pixels.
        num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), bits.dtype)
        bpp = (tf.reduce_sum(bits) + tf.reduce_sum(side_bits)) / num_pixels
        # Mean squared error across pixels.
        mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
        mse = tf.cast(mse, bpp.dtype)
        # The rate-distortion Lagrangian.
        loss = bpp + self.lmbda * mse
        return loss, bpp, mse
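
The loss returned above is the rate-distortion Lagrangian, so it can be minimized directly with any gradient-based optimizer. Below is a minimal training-step sketch, assuming `model` is an instance of this class and `optimizer` is a Keras optimizer (these names are assumptions, not taken from the example):

import tensorflow as tf

def train_step(model, optimizer, batch):
    # One gradient step on the rate-distortion loss returned by call().
    with tf.GradientTape() as tape:
        loss, bpp, mse = model(batch, training=True)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss, bpp, mse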
Example #2
 def __init__(self, lmbda, context_length, num_filters, num_scales,
              scale_min, scale_max):
     super().__init__()
     self.lmbda = lmbda
     self.num_scales = num_scales
     offset = tf.math.log(scale_min)
     factor = (tf.math.log(scale_max) -
               tf.math.log(scale_min)) / (num_scales - 1.)
     self.context_model = ContextModel(context_length=context_length,
                                       num_filters=num_filters)
     self.scale_fn = lambda i: tf.math.exp(offset + factor * i)
     self.hyperprior = tfc.NoisyDeepFactorized(batch_shape=(num_filters, ))
     self.analysis_transform = AnalysisTransform(
         context_length=context_length, num_filters=num_filters)
     self.hyper_analysis_transform = HyperAnalysisTransform(
         context_length=context_length, num_filters=num_filters)
     self.hyper_synthesis_transform = HyperSynthesisTransform(
         context_length=context_length, num_filters=num_filters)
     self.synthesis_transform = SynthesisTransform(
         context_length=context_length, num_filters=num_filters)
     self.entropy_model = tfc.LocationScaleIndexedEntropyModel(
         tfc.NoisyNormal,
         self.num_scales,
         self.scale_fn,
         coding_rank=3,
         compression=False)
     self.side_entropy_model = tfc.ContinuousBatchedEntropyModel(
         self.hyperprior, coding_rank=3, compression=False)
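
The `scale_fn` defined above maps an integer index to one of `num_scales` log-spaced scale values between `scale_min` and `scale_max`; the indexed entropy model then looks scales up from this table instead of coding them directly. A small standalone illustration, with hypothetical argument values since the caller's actual values are not shown:

import tensorflow as tf

scale_min, scale_max, num_scales = 0.11, 256.0, 64  # assumed values for illustration
offset = tf.math.log(scale_min)
factor = (tf.math.log(scale_max) - tf.math.log(scale_min)) / (num_scales - 1.)
scale_fn = lambda i: tf.math.exp(offset + factor * i)

print(scale_fn(0.))                # ~0.11, i.e. scale_min
print(scale_fn(num_scales - 1.))   # ~256.0, i.e. scale_max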
Example #3
 def fit(self, *args, **kwargs):
   retval = super().fit(*args, **kwargs)
   # After training, fix range coding tables.
   self.em_z = tfc.ContinuousBatchedEntropyModel(
       self.hyperprior, coding_rank=3, compression=True,
       offset_heuristic=False)
   self.em_y = tfc.LocationScaleIndexedEntropyModel(
       tfc.NoisyNormal, num_scales=self.num_scales, scale_fn=self.scale_fn,
       coding_rank=3, compression=True)
   return retval
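
Rebuilding the entropy models with `compression=True` computes the probability tables the range coder needs, which is what makes `compress()` and `decompress()` usable. A standalone round-trip sketch of what that enables for a batched entropy model (toy shapes, not from the example):

import tensorflow as tf
import tensorflow_compression as tfc

prior = tfc.NoisyDeepFactorized(batch_shape=(8,))
em = tfc.ContinuousBatchedEntropyModel(prior, coding_rank=3, compression=True)

z = tf.random.normal((1, 4, 4, 8))        # a toy latent with 8 channels
strings = em.compress(z)                  # one range-coded string per batch element
z_hat = em.decompress(strings, (4, 4))    # needs the spatial (broadcast) shape back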
Example #4
 def fit(self, *args, **kwargs):
     retval = super().fit(*args, **kwargs)
     # After training, fix range coding tables.
     self.entropy_model = tfc.LocationScaleIndexedEntropyModel(
         tfc.NoisyNormal,
         self.num_scales,
         self.scale_fn,
         coding_rank=3,
         compression=True)
     self.side_entropy_model = tfc.ContinuousBatchedEntropyModel(
         self.hyperprior, coding_rank=3, compression=True)
     return retval
Example #5
 def call(self, x, training):
     """Computes rate and distortion losses."""
     entropy_model = tfc.ContinuousBatchedEntropyModel(self.prior,
                                                       coding_rank=3,
                                                       compression=False)
     y = self.analysis_transform(x)
     y_hat, bits = entropy_model(y, training=training)
     x_hat = self.synthesis_transform(y_hat)
     # Total number of bits divided by total number of pixels.
     num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), bits.dtype)
     bpp = tf.reduce_sum(bits) / num_pixels
     # Mean squared error across pixels.
     mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
     # The rate-distortion Lagrangian.
     loss = bpp + self.lmbda * mse
     return loss, bpp, mse
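
Once the entropy model has been rebuilt with `compression=True` after training (as in the `fit()` overrides in the other examples), inference-time coding only needs the analysis and synthesis transforms wrapped around `compress()`/`decompress()`. A hedged sketch, assuming the model exposes `analysis_transform`, `synthesis_transform`, and a compression-enabled `entropy_model`; the helper names are hypothetical:

import tensorflow as tf

def compress_image(model, x):
    # Hypothetical helper: encode one HxWxC image in [0, 1] into bit strings.
    y = model.analysis_transform(x[tf.newaxis])          # add a batch dimension
    return model.entropy_model.compress(y), tf.shape(y)[1:-1]

def decompress_image(model, strings, y_shape):
    # Hypothetical helper: decode the bit strings back into an image.
    y_hat = model.entropy_model.decompress(strings, y_shape)
    return tf.squeeze(model.synthesis_transform(y_hat), 0)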
Example #6
    def _run(self, mode, x=None, bit_strings=None):
        """Run model according to `mode` (train, compress, or decompress)."""
        training = (mode == "train")

        if mode == "decompress":
            x_shape, y_shape, z_shape, z_string = bit_strings[:4]
            y_strings = bit_strings[4:]
            assert len(y_strings) == NUM_SLICES
        else:
            y_strings = []
            x_shape = tf.shape(x)[1:-1]

            # Build the encoder (analysis) half of the hierarchical autoencoder.
            y = self.analysis_transform(x)
            y_shape = tf.shape(y)[1:-1]

            z = self.hyper_analysis_transform(y)
            z_shape = tf.shape(z)[1:-1]

        if mode == "train":
            num_pixels = self.args.batchsize * self.args.patchsize**2
        else:
            num_pixels = tf.cast(tf.reduce_prod(x_shape), tf.float32)

        # Build the entropy model for the hyperprior (z).
        em_z = tfc.ContinuousBatchedEntropyModel(self.entropy_bottleneck,
                                                 coding_rank=3,
                                                 compression=not training,
                                                 no_variables=True)

        if mode != "decompress":
            # When training, z_bpp is based on the noisy version of z (z_tilde).
            _, z_bits = em_z(z, training=training)
            z_bpp = tf.reduce_mean(z_bits) / num_pixels

        if training:
            # Use rounding (instead of uniform noise) to modify z before passing it
            # to the hyper-synthesis transforms. Note that quantize() overrides the
            # gradient to create a straight-through estimator.
            z_hat = em_z.quantize(z)
            z_string = None
        else:
            if mode == "compress":
                z_string = em_z.compress(z)
            z_hat = em_z.decompress(z_string, z_shape)

        # Build the decoder (synthesis) half of the hierarchical autoencoder.
        latent_scales = self.hyper_synthesis_scale_transform(z_hat)
        latent_means = self.hyper_synthesis_mean_transform(z_hat)

        # En/Decode each slice conditioned on hyperprior and previous slices.
        y_slices = (y_strings if mode == "decompress" else tf.split(
            y, NUM_SLICES, axis=-1))
        y_hat_slices = []
        y_bpps = []
        for slice_index, y_slice in enumerate(y_slices):
            # Model may condition on only a subset of previous slices.
            support_slices = (y_hat_slices if MAX_SUPPORT_SLICES < 0 else
                              y_hat_slices[:MAX_SUPPORT_SLICES])

            # Predict mu and sigma for the current slice.
            mean_support = tf.concat([latent_means] + support_slices, axis=-1)
            mu = self.cc_mean_transforms[slice_index](mean_support)
            mu = mu[:, :y_shape[0], :y_shape[1], :]

            # Note that in this implementation, `sigma` represents scale indices,
            # not actual scale values.
            scale_support = tf.concat([latent_scales] + support_slices,
                                      axis=-1)
            sigma = self.cc_scale_transforms[slice_index](scale_support)
            sigma = sigma[:, :y_shape[0], :y_shape[1], :]

            # Build the conditional entropy model for this slice.
            em_y = tfc.LocationScaleIndexedEntropyModel(
                tfc.NoisyNormal,
                num_scales=SCALES_LEVELS,
                scale_fn=scale_fn,
                coding_rank=3,
                compression=not training,
                no_variables=True)

            if mode == "decompress":
                y_hat_slice = em_y.decompress(y_slice, sigma, loc=mu)
            else:
                _, slice_bits = em_y(y_slice, sigma, loc=mu, training=training)
                slice_bpp = tf.reduce_mean(slice_bits) / num_pixels
                y_bpps.append(slice_bpp)

                if training:
                    # For the synthesis transform, use rounding. Note that quantize()
                    # overrides the gradient to create a straight-through estimator.
                    y_hat_slice = em_y.quantize(y_slice, sigma, loc=mu)
                else:
                    assert mode == "compress"
                    slice_string = em_y.compress(y_slice, sigma, mu)
                    y_strings.append(slice_string)
                    y_hat_slice = em_y.decompress(slice_string, sigma, mu)

            # Add latent residual prediction (LRP).
            lrp_support = tf.concat([mean_support, y_hat_slice], axis=-1)
            lrp = self.lrp_transforms[slice_index](lrp_support)
            lrp = 0.5 * tf.math.tanh(lrp)
            y_hat_slice += lrp

            y_hat_slices.append(y_hat_slice)

        # Merge slices and generate the image reconstruction.
        y_hat = tf.concat(y_hat_slices, axis=-1)
        x_hat = self.synthesis_transform(y_hat)
        x_hat = x_hat[:, :x_shape[0], :x_shape[1], :]

        if mode != "decompress":
            # Total bpp is sum of bpp from hyperprior and all slices.
            total_bpp = tf.add_n(y_bpps + [z_bpp])

        # Mean squared error across pixels.
        if training:
            # Don't clip or round pixel values while training.
            mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
            mse *= 255**2  # multiply by 255^2 to correct for rescaling
        else:
            x_hat = tf.clip_by_value(x_hat, 0, 1)
            x_hat = tf.round(x_hat * 255)
            if mode == "compress":
                mse = tf.reduce_mean(tf.math.squared_difference(
                    x * 255, x_hat))

        if mode == "train":
            # Calculate and return the rate-distortion loss: R + lambda * D.
            loss = total_bpp + self.args.lmbda * mse

            tf.summary.scalar("bpp", total_bpp)
            tf.summary.scalar("mse", mse)
            tf.summary.scalar("loss", loss)
            tf.summary.image("original", quantize_image(x))
            tf.summary.image("reconstruction", quantize_image(x_hat))

            return loss
        elif mode == "compress":
            # Create `pack` as a list of (tensor, value) pairs for serialization.
            tensors = [x_shape, y_shape, z_shape, z_string] + y_strings
            pack = [(v, v.numpy()) for v in tensors]
            return mse, total_bpp, x_hat, pack
        elif mode == "decompress":
            return x_hat
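
Since `_run` handles all three modes, the public interface is presumably a set of thin wrappers that select the mode and pass the right inputs. One possible shape for those wrappers (not shown in the source; the method names are assumptions):

    def train_step(self, x):
        return self._run("train", x=x)

    def compress(self, x):
        return self._run("compress", x=x)

    def decompress(self, bit_strings):
        return self._run("decompress", bit_strings=bit_strings)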
Example #7
  def call(self, x, training):
    """Computes rate and distortion losses."""
    # Build the encoder (analysis) half of the hierarchical autoencoder.
    y = self.analysis_transform(x)
    y_shape = tf.shape(y)[1:-1]

    z = self.hyper_analysis_transform(y)

    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[1:-1]), tf.float32)

    # Build the entropy model for the hyperprior (z).
    em_z = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior, coding_rank=3, compression=False,
        offset_heuristic=False)

    # When training, z_bpp is based on the noisy version of z (z_tilde).
    _, z_bits = em_z(z, training=training)
    z_bpp = tf.reduce_mean(z_bits) / num_pixels

    # Use rounding (instead of uniform noise) to modify z before passing it
    # to the hyper-synthesis transforms. Note that quantize() overrides the
    # gradient to create a straight-through estimator.
    z_hat = em_z.quantize(z)

    # Build the decoder (synthesis) half of the hierarchical autoencoder.
    latent_scales = self.hyper_synthesis_scale_transform(z_hat)
    latent_means = self.hyper_synthesis_mean_transform(z_hat)

    # Build a conditional entropy model for the slices.
    em_y = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal, num_scales=self.num_scales, scale_fn=self.scale_fn,
        coding_rank=3, compression=False)

    # En/Decode each slice conditioned on hyperprior and previous slices.
    y_slices = tf.split(y, self.num_slices, axis=-1)
    y_hat_slices = []
    y_bpps = []
    for slice_index, y_slice in enumerate(y_slices):
      # Model may condition on only a subset of previous slices.
      support_slices = (y_hat_slices if self.max_support_slices < 0 else
                        y_hat_slices[:self.max_support_slices])

      # Predict mu and sigma for the current slice.
      mean_support = tf.concat([latent_means] + support_slices, axis=-1)
      mu = self.cc_mean_transforms[slice_index](mean_support)
      mu = mu[:, :y_shape[0], :y_shape[1], :]

      # Note that in this implementation, `sigma` represents scale indices,
      # not actual scale values.
      scale_support = tf.concat([latent_scales] + support_slices, axis=-1)
      sigma = self.cc_scale_transforms[slice_index](scale_support)
      sigma = sigma[:, :y_shape[0], :y_shape[1], :]

      _, slice_bits = em_y(y_slice, sigma, loc=mu, training=training)
      slice_bpp = tf.reduce_mean(slice_bits) / num_pixels
      y_bpps.append(slice_bpp)

      # For the synthesis transform, use rounding. Note that quantize()
      # overrides the gradient to create a straight-through estimator.
      y_hat_slice = em_y.quantize(y_slice, loc=mu)

      # Add latent residual prediction (LRP).
      lrp_support = tf.concat([mean_support, y_hat_slice], axis=-1)
      lrp = self.lrp_transforms[slice_index](lrp_support)
      lrp = 0.5 * tf.math.tanh(lrp)
      y_hat_slice += lrp

      y_hat_slices.append(y_hat_slice)

    # Merge slices and generate the image reconstruction.
    y_hat = tf.concat(y_hat_slices, axis=-1)
    x_hat = self.synthesis_transform(y_hat)

    # Total bpp is sum of bpp from hyperprior and all slices.
    total_bpp = tf.add_n(y_bpps + [z_bpp])

    # Mean squared error across pixels.
    # Don't clip or round pixel values while training.
    mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))

    # Calculate and return the rate-distortion loss: R + lambda * D.
    loss = total_bpp + self.lmbda * mse

    return loss, total_bpp, mse
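
The `max_support_slices` setting bounds how much already decoded context each slice conditions on: slice k sees min(k, max_support_slices) previous slices, or all of them when the setting is negative. A small illustration with assumed parameter values:

num_slices, max_support_slices = 10, 5   # hypothetical values
for k in range(num_slices):
    n_support = k if max_support_slices < 0 else min(k, max_support_slices)
    print(f"slice {k}: conditioned on {n_support} previous slice(s)")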
Example #8
 def fit(self, *args, **kwargs):
   retval = super().fit(*args, **kwargs)
   # After training, fix range coding tables.
   self.entropy_model = tfc.ContinuousBatchedEntropyModel(
       self.prior, coding_rank=3, compression=True)
   return retval
Example #9
    def _run(self, mode, x=None, feature1=None, feature2=None, feature3=None,
             feature4=None, feature5=None, feature6=None, feature7=None,
             feature8=None, bit_strings=None):
        """Run model according to `mode` (train, compress, or decompress)."""
        training = (mode == "train")

        if mode == "decompress":
            x_shape, x_encoded_shape, disp_encoded_shape = bit_strings[:3]
            x_encoded_string, disp1_encoded_string, disp2_encoded_string = bit_strings[3:6]
            disp3_encoded_string, disp4_encoded_string, disp5_encoded_string = bit_strings[6:9]
            disp6_encoded_string, disp7_encoded_string, disp8_encoded_string = bit_strings[9:]

        else:
            x_shape = tf.shape(x)[1:-1]

            # Build the encoder (analysis) half of the color module.
            x_encoded = self.color_analysis_transform(x)

            # Build the encoder (analysis) half of the disparity module.
            disp1_encoded = self.disp1_analysis_transform(feature1)
            disp2_encoded = self.disp2_analysis_transform(feature2)
            disp3_encoded = self.disp3_analysis_transform(feature3)
            disp4_encoded = self.disp4_analysis_transform(feature4)
            disp5_encoded = self.disp5_analysis_transform(feature5)
            disp6_encoded = self.disp6_analysis_transform(feature6)
            disp7_encoded = self.disp7_analysis_transform(feature7) 
            disp8_encoded = self.disp8_analysis_transform(feature8)

            x_encoded_shape = tf.shape(x_encoded)[1:-1]
            disp_encoded_shape = tf.shape(disp1_encoded)[1:-1]

        if mode == "train":
            num_pixels = tf.cast(self.args.batch_size * 8 * 64 ** 2, tf.float32)
            num_pixels_disp = num_pixels/2.0
        else:
            num_pixels = tf.cast(tf.reduce_prod(x_shape), tf.float32)
            num_pixels_disp = num_pixels/2.0

        # Build the entropy models for the latents.
        em_color = tfc.ContinuousBatchedEntropyModel(
            self.entropy_bottleneck_color, coding_rank=4,
            compression=not training, no_variables=True)
        em_disp = tfc.ContinuousBatchedEntropyModel(
            self.entropy_bottleneck_disp, coding_rank=4,
            compression=not training, no_variables=True)

        if mode != "decompress":
            # When training, *_bpp is based on the noisy version of the latents.
            _, x_encoded_bits = em_color(x_encoded, training=training)
            _, disp1_encoded_bits = em_disp(disp1_encoded, training=training)
            _, disp2_encoded_bits = em_disp(disp2_encoded, training=training)
            _, disp3_encoded_bits = em_disp(disp3_encoded, training=training)
            _, disp4_encoded_bits = em_disp(disp4_encoded, training=training)
            _, disp5_encoded_bits = em_disp(disp5_encoded, training=training)
            _, disp6_encoded_bits = em_disp(disp6_encoded, training=training)
            _, disp7_encoded_bits = em_disp(disp7_encoded, training=training)
            _, disp8_encoded_bits = em_disp(disp8_encoded, training=training)
            x_encoded_bpp = tf.reduce_mean(x_encoded_bits) / num_pixels
            disp1_encoded_bpp = tf.reduce_mean(disp1_encoded_bits) / num_pixels_disp
            disp2_encoded_bpp = tf.reduce_mean(disp2_encoded_bits) / num_pixels_disp
            disp3_encoded_bpp = tf.reduce_mean(disp3_encoded_bits) / num_pixels_disp
            disp4_encoded_bpp = tf.reduce_mean(disp4_encoded_bits) / num_pixels_disp
            disp5_encoded_bpp = tf.reduce_mean(disp5_encoded_bits) / num_pixels_disp
            disp6_encoded_bpp = tf.reduce_mean(disp6_encoded_bits) / num_pixels_disp
            disp7_encoded_bpp = tf.reduce_mean(disp7_encoded_bits) / num_pixels_disp
            disp8_encoded_bpp = tf.reduce_mean(disp8_encoded_bits) / num_pixels_disp
            total_bpp = (x_encoded_bpp + disp1_encoded_bpp + disp2_encoded_bpp +
                         disp3_encoded_bpp + disp4_encoded_bpp + disp5_encoded_bpp +
                         disp6_encoded_bpp + disp7_encoded_bpp + disp8_encoded_bpp)


        if training:
            # Use rounding (instead of uniform noise) to modify latents before passing them
            # to their respective synthesis transforms. Note that quantize() overrides the
            # gradient to create a straight-through estimator.
            x_encoded_hat = em_color.quantize(x_encoded)
            disp1_encoded_hat = em_disp.quantize(disp1_encoded)
            disp2_encoded_hat = em_disp.quantize(disp2_encoded)
            disp3_encoded_hat = em_disp.quantize(disp3_encoded)
            disp4_encoded_hat = em_disp.quantize(disp4_encoded)
            disp5_encoded_hat = em_disp.quantize(disp5_encoded)
            disp6_encoded_hat = em_disp.quantize(disp6_encoded)
            disp7_encoded_hat = em_disp.quantize(disp7_encoded)
            disp8_encoded_hat = em_disp.quantize(disp8_encoded)
            x_encoded_string = None
            disp1_encoded_string = None
            disp2_encoded_string = None
            disp3_encoded_string = None
            disp4_encoded_string = None
            disp5_encoded_string = None
            disp6_encoded_string = None
            disp7_encoded_string = None
            disp8_encoded_string = None
        else:
            if mode == "compress":
                x_encoded_string = em_color.compress(x_encoded)
                disp1_encoded_string = em_disp.compress(disp1_encoded)
                disp2_encoded_string = em_disp.compress(disp2_encoded)
                disp3_encoded_string = em_disp.compress(disp3_encoded)
                disp4_encoded_string = em_disp.compress(disp4_encoded)
                disp5_encoded_string = em_disp.compress(disp5_encoded)
                disp6_encoded_string = em_disp.compress(disp6_encoded)
                disp7_encoded_string = em_disp.compress(disp7_encoded)
                disp8_encoded_string = em_disp.compress(disp8_encoded)
            x_encoded_hat = em_color.decompress(x_encoded_string, x_encoded_shape)
            disp1_encoded_hat = em_disp.decompress(disp1_encoded_string, disp_encoded_shape)
            disp2_encoded_hat = em_disp.decompress(disp2_encoded_string, disp_encoded_shape)
            disp3_encoded_hat = em_disp.decompress(disp3_encoded_string, disp_encoded_shape)
            disp4_encoded_hat = em_disp.decompress(disp4_encoded_string, disp_encoded_shape)
            disp5_encoded_hat = em_disp.decompress(disp5_encoded_string, disp_encoded_shape)
            disp6_encoded_hat = em_disp.decompress(disp6_encoded_string, disp_encoded_shape)
            disp7_encoded_hat = em_disp.decompress(disp7_encoded_string, disp_encoded_shape)
            disp8_encoded_hat = em_disp.decompress(disp8_encoded_string, disp_encoded_shape)

        # Build the decoder (synthesis) half of the color module.
        x_tilde = self.color_synthesis_transform(x_encoded_hat)

        # Build the decoder (synthesis) half of the disparity module.
        dispmap1 = tf.squeeze(self.disp1_synthesis_transform(disp1_encoded_hat), axis=1)
        dispmap2 = tf.squeeze(self.disp2_synthesis_transform(disp2_encoded_hat), axis=1)
        dispmap3 = tf.squeeze(self.disp3_synthesis_transform(disp3_encoded_hat), axis=1)
        dispmap4 = tf.squeeze(self.disp4_synthesis_transform(disp4_encoded_hat), axis=1)
        dispmap5 = tf.squeeze(self.disp5_synthesis_transform(disp5_encoded_hat), axis=1)
        dispmap6 = tf.squeeze(self.disp6_synthesis_transform(disp6_encoded_hat), axis=1)
        dispmap7 = tf.squeeze(self.disp7_synthesis_transform(disp7_encoded_hat), axis=1)
        dispmap8 = tf.squeeze(self.disp8_synthesis_transform(disp8_encoded_hat), axis=1)

        # Perform warping with disparity maps of respective slices of x_tilde.
        x_hat1 = dense_image_warp(x_tilde[:, 0, :, :, :], dispmap1)
        x_hat2 = dense_image_warp(x_tilde[:, 1, :, :, :], dispmap2)
        x_hat3 = dense_image_warp(x_tilde[:, 2, :, :, :], dispmap3)
        x_hat4 = dense_image_warp(x_tilde[:, 3, :, :, :], dispmap4)
        x_hat5 = dense_image_warp(x_tilde[:, 4, :, :, :], dispmap5)
        x_hat6 = dense_image_warp(x_tilde[:, 5, :, :, :], dispmap6)
        x_hat7 = dense_image_warp(x_tilde[:, 6, :, :, :], dispmap7)
        x_hat8 = dense_image_warp(x_tilde[:, 7, :, :, :], dispmap8)
        x_hat = tf.stack([x_hat1, x_hat2, x_hat3, x_hat4,
                          x_hat5, x_hat6, x_hat7, x_hat8], axis=1)

        # Mean squared error across pixels.
        if training:
            # Don't clip or round pixel values while training.
            mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
            mse *= 255 ** 2  # multiply by 255^2 to correct for rescaling
        else:
            x_hat = tf.clip_by_value(x_hat, 0, 1)
            x_hat = tf.round(x_hat * 255)
            if mode == "compress":
                mse = tf.reduce_mean(tf.math.squared_difference(x * 255, x_hat))
        if mode == "train":
            # Calculate and return the rate-distortion loss: R + lmbda * D.
            loss = total_bpp + self.args.lmbda * mse
            return loss, total_bpp, mse
        elif mode == "compress":
            # Create `pack` as a list of (tensor, value) pairs for serialization.
            tensors = [x_shape, x_encoded_shape, disp_encoded_shape, x_encoded_string,
                       disp1_encoded_string, disp2_encoded_string, disp3_encoded_string,
                       disp4_encoded_string, disp5_encoded_string, disp6_encoded_string,
                       disp7_encoded_string, disp8_encoded_string]
            pack = [(v, v.numpy()) for v in tensors]
            return mse, total_bpp, x_hat, pack
        elif mode == "decompress":
            return x_hat
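
The eight disparity branches above differ only in which transform and which feature they touch, so they could be collapsed into loops if the per-view transforms were stored in lists. A refactoring sketch for the training path only, assuming hypothetical attributes `self.disp_analysis_transforms` and `self.disp_synthesis_transforms` holding the eight transforms in view order:

        features = [feature1, feature2, feature3, feature4,
                    feature5, feature6, feature7, feature8]
        disp_encoded = [t(f) for t, f in
                        zip(self.disp_analysis_transforms, features)]
        disp_bits = [em_disp(d, training=training)[1] for d in disp_encoded]
        disp_bpp = tf.add_n([tf.reduce_mean(b) for b in disp_bits]) / num_pixels_disp
        disp_hat = [em_disp.quantize(d) for d in disp_encoded]
        dispmaps = [tf.squeeze(t(d), axis=1) for t, d in
                    zip(self.disp_synthesis_transforms, disp_hat)]
        x_hat = tf.stack([dense_image_warp(x_tilde[:, i], m)
                          for i, m in enumerate(dispmaps)], axis=1)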