def call(self, x, training):
    """Return the rate-distortion loss of a scale-hyperprior model.

    Both entropy models are freshly built with `compression=False`, i.e. they
    only estimate bits and do not use range-coding tables.

    Args:
      x: Input batch; cast to `self.compute_dtype` first. The pixel count is
        the product of all but the last axis of its shape (batch included).
      training: Forwarded to both entropy models.

    Returns:
      A tuple `(loss, bpp, mse)` where `bpp` is the total number of bits of
      both latents divided by the total number of pixels in the batch, `mse`
      is the mean squared error cast to `bpp.dtype`, and
      `loss = bpp + self.lmbda * mse`.
    """
    latent_model = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal, self.num_scales, self.scale_fn, coding_rank=3,
        compression=False)
    hyper_model = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior, coding_rank=3, compression=False)

    images = tf.cast(x, self.compute_dtype)  # TODO(jonycgn): Why is this necessary?

    # Encoder, hyper-encoder, hyperprior coding, and scale-index prediction.
    latents = self.analysis_transform(images)
    hyper_latents = self.hyper_analysis_transform(abs(latents))
    hyper_hat, hyper_bits = hyper_model(hyper_latents, training=training)
    scale_indexes = self.hyper_synthesis_transform(hyper_hat)

    # Code the main latents with the predicted scale indexes, then decode.
    latents_hat, latent_bits = latent_model(
        latents, scale_indexes, training=training)
    reconstruction = self.synthesis_transform(latents_hat)

    # Rate: all bits (main + side) over all pixels in the batch.
    pixel_count = tf.cast(
        tf.reduce_prod(tf.shape(images)[:-1]), latent_bits.dtype)
    total_bits = tf.reduce_sum(latent_bits) + tf.reduce_sum(hyper_bits)
    bpp = total_bits / pixel_count

    # Distortion, aligned with the rate's dtype before combining them.
    squared_error = tf.math.squared_difference(images, reconstruction)
    mse = tf.cast(tf.reduce_mean(squared_error), bpp.dtype)

    return bpp + self.lmbda * mse, bpp, mse
def __init__(self, lmbda, context_length, num_filters, num_scales, scale_min,
             scale_max):
    """Create the transforms, the context model and the entropy models.

    Args:
      lmbda: Weight of the distortion term; stored as `self.lmbda`.
      context_length: Passed to `ContextModel` and to all four transforms.
      num_filters: Passed to `ContextModel` and to the transforms; also the
        batch shape of the `tfc.NoisyDeepFactorized` hyperprior.
      num_scales: Number of scale levels of the indexed entropy model. Must be
        at least 2: the scale table interpolates log-linearly from
        `scale_min` (index 0) to `scale_max` (index `num_scales - 1`), which
        divides by `num_scales - 1`.
      scale_min: Scale for index 0. Must be positive (its log is taken).
        Ints are accepted and converted to float.
      scale_max: Scale for index `num_scales - 1`. Must be positive.

    Raises:
      ValueError: If `num_scales < 2` or a scale bound is not positive.
    """
    super().__init__()
    # With num_scales == 1 the step below is 0/0, silently yielding NaN scales.
    if num_scales < 2:
        raise ValueError(
            "num_scales must be at least 2, got {}.".format(num_scales))
    # tf.math.log rejects integer tensors, so accept ints by converting here.
    scale_min = float(scale_min)
    scale_max = float(scale_max)
    if scale_min <= 0 or scale_max <= 0:
        raise ValueError(
            "scale_min and scale_max must be positive, got {} and {}.".format(
                scale_min, scale_max))

    self.lmbda = lmbda
    self.num_scales = num_scales
    offset = tf.math.log(scale_min)
    factor = (tf.math.log(scale_max) - tf.math.log(scale_min)) / (
        num_scales - 1.)
    self.context_model = ContextModel(
        context_length=context_length, num_filters=num_filters)
    # Index i -> exp(log(scale_min) + i * step): log-linear scale table.
    self.scale_fn = lambda i: tf.math.exp(offset + factor * i)
    self.hyperprior = tfc.NoisyDeepFactorized(batch_shape=(num_filters,))
    self.analysis_transform = AnalysisTransform(
        context_length=context_length, num_filters=num_filters)
    self.hyper_analysis_transform = HyperAnalysisTransform(
        context_length=context_length, num_filters=num_filters)
    self.hyper_synthesis_transform = HyperSynthesisTransform(
        context_length=context_length, num_filters=num_filters)
    self.synthesis_transform = SynthesisTransform(
        context_length=context_length, num_filters=num_filters)
    # Bit-estimation (non-compressing) entropy models used during training.
    self.entropy_model = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal, self.num_scales, self.scale_fn, coding_rank=3,
        compression=False)
    self.side_entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior, coding_rank=3, compression=False)
def fit(self, *args, **kwargs):
    """Run the parent `fit`, then rebuild the entropy models for compression.

    After training the weights are fixed, so the range-coding tables can be
    fixed too. `self.em_z` (hyperprior) and `self.em_y` (location-scale
    Gaussian) are replaced by `compression=True` instances.

    Returns:
      Whatever the parent class's `fit` returned, unchanged.
    """
    history = super().fit(*args, **kwargs)

    hyper_model = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior,
        coding_rank=3,
        compression=True,
        offset_heuristic=False,
    )
    latent_model = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal,
        num_scales=self.num_scales,
        scale_fn=self.scale_fn,
        coding_rank=3,
        compression=True,
    )
    self.em_z = hyper_model
    self.em_y = latent_model

    return history
def fit(self, *args, **kwargs):
    """Train via the parent `fit`, then switch to compressing entropy models.

    Once training is done the range-coding tables can be fixed, so
    `self.entropy_model` and `self.side_entropy_model` are replaced by
    instances built with `compression=True`.

    Returns:
      The parent class's `fit` return value, unchanged.
    """
    history = super().fit(*args, **kwargs)

    main_model = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal,
        self.num_scales,
        self.scale_fn,
        coding_rank=3,
        compression=True,
    )
    side_model = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior,
        coding_rank=3,
        compression=True,
    )
    self.entropy_model = main_model
    self.side_entropy_model = side_model

    return history
def call(self, x, training):
    """Computes rate and distortion losses.

    Args:
      x: Input batch. It is cast to `self.compute_dtype` so that it can be
        compared against the floating-point reconstruction.
      training: Forwarded to the entropy model.

    Returns:
      A tuple `(loss, bpp, mse)`:
        bpp: total bits divided by the total number of pixels in the batch
          (product of all but the last axis of `tf.shape(x)`).
        mse: mean squared error, cast to `bpp.dtype`.
        loss: `bpp + self.lmbda * mse`.
    """
    entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.prior, coding_rank=3, compression=False)
    # Without this cast, a non-float input, or a compute dtype that differs
    # from the input dtype (mixed precision), makes the squared difference
    # below fail on mismatched dtypes.
    x = tf.cast(x, self.compute_dtype)
    y = self.analysis_transform(x)
    y_hat, bits = entropy_model(y, training=training)
    x_hat = self.synthesis_transform(y_hat)
    # Total number of bits divided by total number of pixels.
    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), bits.dtype)
    bpp = tf.reduce_sum(bits) / num_pixels
    # Mean squared error across pixels.
    mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
    # mse follows the compute dtype while bpp follows the bits' dtype; align
    # them so the Lagrangian below is well-typed.
    mse = tf.cast(mse, bpp.dtype)
    # The rate-distortion Lagrangian.
    loss = bpp + self.lmbda * mse
    return loss, bpp, mse
def _run(self, mode, x=None, bit_strings=None):
    """Run model according to `mode` (train, compress, or decompress).

    Args:
      mode: One of "train", "compress" or "decompress".
      x: Input batch (unused when decompressing). Its spatial size is read
        from `tf.shape(x)[1:-1]` and the reconstruction is cropped to it, so
        presumably (batch, height, width, channels) -- TODO confirm.
      bit_strings: Only for "decompress": `[x_shape, y_shape, z_shape,
        z_string]` followed by NUM_SLICES slice strings, i.e. the tensors
        packed by "compress".

    Returns:
      "train": the rate-distortion loss (scalars/images also go to
        `tf.summary`).
      "compress": `(mse, total_bpp, x_hat, pack)` where `pack` is a list of
        `(tensor, tensor.numpy())` pairs.
      "decompress": `x_hat`, clipped to [0, 1], scaled by 255 and rounded.
    """
    training = (mode == "train")

    if mode == "decompress":
        x_shape, y_shape, z_shape, z_string = bit_strings[:4]
        y_strings = bit_strings[4:]
        assert len(y_strings) == NUM_SLICES
    else:
        y_strings = []
        x_shape = tf.shape(x)[1:-1]

        # Build the encoder (analysis) half of the hierarchical autoencoder.
        y = self.analysis_transform(x)
        y_shape = tf.shape(y)[1:-1]

        z = self.hyper_analysis_transform(y)
        z_shape = tf.shape(z)[1:-1]

    # With coding_rank=3 the entropy models return one bit count per batch
    # element (assuming 4-D latents), and those counts are averaged over the
    # batch with tf.reduce_mean below. So divide by the pixel count of ONE
    # image in every mode; dividing by batchsize * patchsize**2 during
    # training would understate the rate by a factor of the batch size.
    num_pixels = tf.cast(tf.reduce_prod(x_shape), tf.float32)

    # Build the entropy model for the hyperprior (z).
    em_z = tfc.ContinuousBatchedEntropyModel(
        self.entropy_bottleneck, coding_rank=3, compression=not training,
        no_variables=True)

    if mode != "decompress":
        # When training, z_bpp is based on the noisy version of z (z_tilde).
        _, z_bits = em_z(z, training=training)
        z_bpp = tf.reduce_mean(z_bits) / num_pixels

    if training:
        # Use rounding (instead of uniform noise) to modify z before passing it
        # to the hyper-synthesis transforms. Note that quantize() overrides the
        # gradient to create a straight-through estimator.
        z_hat = em_z.quantize(z)
        z_string = None
    else:
        if mode == "compress":
            z_string = em_z.compress(z)
        z_hat = em_z.decompress(z_string, z_shape)

    # Build the decoder (synthesis) half of the hierarchical autoencoder.
    latent_scales = self.hyper_synthesis_scale_transform(z_hat)
    latent_means = self.hyper_synthesis_mean_transform(z_hat)

    # En/Decode each slice conditioned on hyperprior and previous slices.
    y_slices = (y_strings if mode == "decompress" else
                tf.split(y, NUM_SLICES, axis=-1))
    y_hat_slices = []
    y_bpps = []
    for slice_index, y_slice in enumerate(y_slices):
        # Model may condition on only a subset of previous slices.
        support_slices = (y_hat_slices if MAX_SUPPORT_SLICES < 0 else
                          y_hat_slices[:MAX_SUPPORT_SLICES])

        # Predict mu and sigma for the current slice, cropped to y's size.
        mean_support = tf.concat([latent_means] + support_slices, axis=-1)
        mu = self.cc_mean_transforms[slice_index](mean_support)
        mu = mu[:, :y_shape[0], :y_shape[1], :]

        # Note that in this implementation, `sigma` represents scale indices,
        # not actual scale values.
        scale_support = tf.concat([latent_scales] + support_slices, axis=-1)
        sigma = self.cc_scale_transforms[slice_index](scale_support)
        sigma = sigma[:, :y_shape[0], :y_shape[1], :]

        # Build the conditional entropy model for this slice.
        em_y = tfc.LocationScaleIndexedEntropyModel(
            tfc.NoisyNormal, num_scales=SCALES_LEVELS, scale_fn=scale_fn,
            coding_rank=3, compression=not training, no_variables=True)

        if mode == "decompress":
            y_hat_slice = em_y.decompress(y_slice, sigma, loc=mu)
        else:
            _, slice_bits = em_y(y_slice, sigma, loc=mu, training=training)
            slice_bpp = tf.reduce_mean(slice_bits) / num_pixels
            y_bpps.append(slice_bpp)

            if training:
                # For the synthesis transform, use rounding. Note that
                # quantize() overrides the gradient to create a
                # straight-through estimator.
                y_hat_slice = em_y.quantize(y_slice, sigma, loc=mu)
            else:
                assert mode == "compress"
                slice_string = em_y.compress(y_slice, sigma, mu)
                y_strings.append(slice_string)
                y_hat_slice = em_y.decompress(slice_string, sigma, mu)

        # Add latent residual prediction (LRP).
        lrp_support = tf.concat([mean_support, y_hat_slice], axis=-1)
        lrp = self.lrp_transforms[slice_index](lrp_support)
        lrp = 0.5 * tf.math.tanh(lrp)
        y_hat_slice += lrp

        y_hat_slices.append(y_hat_slice)

    # Merge slices and generate the image reconstruction, cropped to x's size.
    y_hat = tf.concat(y_hat_slices, axis=-1)
    x_hat = self.synthesis_transform(y_hat)
    x_hat = x_hat[:, :x_shape[0], :x_shape[1], :]

    if mode != "decompress":
        # Total bpp is sum of bpp from hyperprior and all slices.
        total_bpp = tf.add_n(y_bpps + [z_bpp])

    # Mean squared error across pixels.
    if training:
        # Don't clip or round pixel values while training.
        mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
        mse *= 255 ** 2  # multiply by 255^2 to correct for rescaling
    else:
        x_hat = tf.clip_by_value(x_hat, 0, 1)
        x_hat = tf.round(x_hat * 255)
        if mode == "compress":
            mse = tf.reduce_mean(tf.math.squared_difference(x * 255, x_hat))

    if mode == "train":
        # Calculate and return the rate-distortion loss: R + lambda * D.
        loss = total_bpp + self.args.lmbda * mse
        tf.summary.scalar("bpp", total_bpp)
        tf.summary.scalar("mse", mse)
        tf.summary.scalar("loss", loss)
        tf.summary.image("original", quantize_image(x))
        tf.summary.image("reconstruction", quantize_image(x_hat))
        return loss
    elif mode == "compress":
        # Create `pack` dict mapping tensors to values.
        tensors = [x_shape, y_shape, z_shape, z_string] + y_strings
        pack = [(v, v.numpy()) for v in tensors]
        return mse, total_bpp, x_hat, pack
    elif mode == "decompress":
        return x_hat
def call(self, x, training):
    """Computes rate and distortion losses.

    Args:
      x: Input batch; its spatial size is read from `tf.shape(x)[1:-1]`, so
        presumably (batch, height, width, channels) -- TODO confirm.
      training: Forwarded to both entropy models.

    Returns:
      A tuple `(loss, total_bpp, mse)` with
      `loss = total_bpp + self.lmbda * mse`. Rates are per-image bits
      (averaged over the batch) divided by the pixel count of one image.
    """
    # Build the encoder (analysis) half of the hierarchical autoencoder.
    y = self.analysis_transform(x)
    y_shape = tf.shape(y)[1:-1]

    z = self.hyper_analysis_transform(y)

    x_shape = tf.shape(x)[1:-1]
    num_pixels = tf.cast(tf.reduce_prod(x_shape), tf.float32)

    # Build the entropy model for the hyperprior (z).
    em_z = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior, coding_rank=3, compression=False,
        offset_heuristic=False)

    # When training, z_bpp is based on the noisy version of z (z_tilde).
    _, z_bits = em_z(z, training=training)
    z_bpp = tf.reduce_mean(z_bits) / num_pixels

    # Use rounding (instead of uniform noise) to modify z before passing it
    # to the hyper-synthesis transforms. Note that quantize() overrides the
    # gradient to create a straight-through estimator.
    z_hat = em_z.quantize(z)

    # Build the decoder (synthesis) half of the hierarchical autoencoder.
    latent_scales = self.hyper_synthesis_scale_transform(z_hat)
    latent_means = self.hyper_synthesis_mean_transform(z_hat)

    # Build a conditional entropy model for the slices.
    em_y = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal, num_scales=self.num_scales, scale_fn=self.scale_fn,
        coding_rank=3, compression=False)

    # En/Decode each slice conditioned on hyperprior and previous slices.
    y_slices = tf.split(y, self.num_slices, axis=-1)
    y_hat_slices = []
    y_bpps = []
    for slice_index, y_slice in enumerate(y_slices):
        # Model may condition on only a subset of previous slices.
        support_slices = (y_hat_slices if self.max_support_slices < 0 else
                          y_hat_slices[:self.max_support_slices])

        # Predict mu and sigma for the current slice, cropped to y's size.
        mean_support = tf.concat([latent_means] + support_slices, axis=-1)
        mu = self.cc_mean_transforms[slice_index](mean_support)
        mu = mu[:, :y_shape[0], :y_shape[1], :]

        # Note that in this implementation, `sigma` represents scale indices,
        # not actual scale values.
        scale_support = tf.concat([latent_scales] + support_slices, axis=-1)
        sigma = self.cc_scale_transforms[slice_index](scale_support)
        sigma = sigma[:, :y_shape[0], :y_shape[1], :]

        _, slice_bits = em_y(y_slice, sigma, loc=mu, training=training)
        slice_bpp = tf.reduce_mean(slice_bits) / num_pixels
        y_bpps.append(slice_bpp)

        # For the synthesis transform, use rounding. Note that quantize()
        # overrides the gradient to create a straight-through estimator.
        y_hat_slice = em_y.quantize(y_slice, loc=mu)

        # Add latent residual prediction (LRP).
        lrp_support = tf.concat([mean_support, y_hat_slice], axis=-1)
        lrp = self.lrp_transforms[slice_index](lrp_support)
        lrp = 0.5 * tf.math.tanh(lrp)
        y_hat_slice += lrp

        y_hat_slices.append(y_hat_slice)

    # Merge slices and generate the image reconstruction.
    y_hat = tf.concat(y_hat_slices, axis=-1)
    x_hat = self.synthesis_transform(y_hat)
    # Like mu/sigma above, the synthesis output can be spatially larger than
    # its target when the input size is not a multiple of the transforms'
    # downsampling factor; crop so the MSE compares matching shapes. This is
    # a no-op when the sizes already agree.
    x_hat = x_hat[:, :x_shape[0], :x_shape[1], :]

    # Total bpp is sum of bpp from hyperprior and all slices.
    total_bpp = tf.add_n(y_bpps + [z_bpp])

    # Mean squared error across pixels.
    # Don't clip or round pixel values while training.
    mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))

    # Calculate and return the rate-distortion loss: R + lambda * D.
    loss = total_bpp + self.lmbda * mse

    return loss, total_bpp, mse
def fit(self, *args, **kwargs):
    """Train with the parent `fit`, then fix the range-coding tables.

    The bit-estimation entropy model is replaced by a
    `tfc.ContinuousBatchedEntropyModel` over `self.prior` built with
    `compression=True`, stored as `self.entropy_model`.

    Returns:
      The parent class's `fit` return value, unchanged.
    """
    history = super().fit(*args, **kwargs)
    compressing_model = tfc.ContinuousBatchedEntropyModel(
        self.prior,
        coding_rank=3,
        compression=True,
    )
    self.entropy_model = compressing_model
    return history
def _run(self, mode, x=None, feature1=None, feature2=None, feature3=None,
         feature4=None, feature5=None, feature6=None, feature7=None,
         feature8=None, bit_strings=None):
    """Run model according to `mode` (train, compress, or decompress).

    A color latent is coded from `x` and eight disparity latents are coded
    from `feature1` ... `feature8`. View `i` of the color reconstruction is
    then warped with disparity map `i` via `dense_image_warp`.

    Args:
      mode: One of "train", "compress" or "decompress".
      x: Input tensor (unused when decompressing). The reconstruction is
        built by stacking eight views on axis 1, so presumably
        (batch, views, height, width, channels) -- TODO confirm.
      feature1 ... feature8: Inputs to the eight disparity analysis
        transforms (unused when decompressing).
      bit_strings: Only for "decompress": the 12 values packed by "compress",
        `[x_shape, x_encoded_shape, disp_encoded_shape, x_encoded_string,
        disp1_encoded_string, ..., disp8_encoded_string]`.

    Returns:
      "train": `(loss, total_bpp, mse)`.
      "compress": `(mse, total_bpp, x_hat, pack)` where `pack` is a list of
        `(tensor, tensor.numpy())` pairs for the 12 values listed above.
      "decompress": `x_hat`, clipped to [0, 1], scaled by 255 and rounded.

    Raises:
      ValueError: If decompressing and `bit_strings` does not hold exactly
        12 values.
    """
    num_views = 8
    training = (mode == "train")

    if mode == "decompress":
        if len(bit_strings) != 4 + num_views:
            raise ValueError(
                "Expected {} bit strings, got {}.".format(
                    4 + num_views, len(bit_strings)))
        x_shape, x_encoded_shape, disp_encoded_shape = bit_strings[:3]
        x_encoded_string = bit_strings[3]
        disp_strings = list(bit_strings[4:])
    else:
        x_shape = tf.shape(x)[1:-1]
        # Build the encoder (analysis) half of the color module.
        x_encoded = self.color_analysis_transform(x)
        # Build the encoder (analysis) half of the disparity module: one
        # analysis transform per view, applied in order 1..8.
        features = (feature1, feature2, feature3, feature4,
                    feature5, feature6, feature7, feature8)
        disp_encoded = [
            getattr(self, "disp%d_analysis_transform" % (i + 1))(feature)
            for i, feature in enumerate(features)]
        x_encoded_shape = tf.shape(x_encoded)[1:-1]
        disp_encoded_shape = tf.shape(disp_encoded[0])[1:-1]

    # With coding_rank=4, the entropy models return one bit count per batch
    # element (assuming 5-D latents), and tf.reduce_mean below averages those
    # over the batch. So divide by the pixel count of ONE sample
    # (prod(x_shape)) in every mode. A hard-coded
    # batch_size * 8 * 64**2 in training would understate the rate by a
    # factor of the batch size and only fit 64x64 patches.
    num_pixels = tf.cast(tf.reduce_prod(x_shape), tf.float32)
    num_pixels_disp = num_pixels / 2.0

    # Build the entropy models for the latents.
    em_color = tfc.ContinuousBatchedEntropyModel(
        self.entropy_bottleneck_color, coding_rank=4,
        compression=not training, no_variables=True)
    em_disp = tfc.ContinuousBatchedEntropyModel(
        self.entropy_bottleneck_disp, coding_rank=4,
        compression=not training, no_variables=True)

    if mode != "decompress":
        # When training, *_bpp is based on the noisy version of the latents.
        _, x_encoded_bits = em_color(x_encoded, training=training)
        disp_bits = [em_disp(d, training=training)[1] for d in disp_encoded]
        total_bpp = tf.reduce_mean(x_encoded_bits) / num_pixels
        for bits in disp_bits:
            total_bpp += tf.reduce_mean(bits) / num_pixels_disp

    if training:
        # Use rounding (instead of uniform noise) to modify latents before
        # passing them to their respective synthesis transforms. Note that
        # quantize() overrides the gradient to create a straight-through
        # estimator.
        x_encoded_hat = em_color.quantize(x_encoded)
        disp_encoded_hat = [em_disp.quantize(d) for d in disp_encoded]
    else:
        if mode == "compress":
            x_encoded_string = em_color.compress(x_encoded)
            disp_strings = [em_disp.compress(d) for d in disp_encoded]
        x_encoded_hat = em_color.decompress(x_encoded_string, x_encoded_shape)
        disp_encoded_hat = [
            em_disp.decompress(s, disp_encoded_shape) for s in disp_strings]

    # Build the decoder (synthesis) half of the color module.
    x_tilde = self.color_synthesis_transform(x_encoded_hat)
    # Build the decoder (synthesis) half of the disparity module; axis 1 of
    # each synthesis output is squeezed away.
    dispmaps = [
        tf.squeeze(
            getattr(self, "disp%d_synthesis_transform" % (i + 1))(hat),
            axis=1)
        for i, hat in enumerate(disp_encoded_hat)]
    # Warp each view of x_tilde with its own disparity map, then restack.
    x_hat = tf.stack(
        [dense_image_warp(x_tilde[:, i, :, :, :], dispmap)
         for i, dispmap in enumerate(dispmaps)],
        axis=1)

    # Mean squared error across pixels.
    if training:
        # Don't clip or round pixel values while training.
        mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
        mse *= 255 ** 2  # multiply by 255^2 to correct for rescaling
    else:
        x_hat = tf.clip_by_value(x_hat, 0, 1)
        x_hat = tf.round(x_hat * 255)
        if mode == "compress":
            mse = tf.reduce_mean(tf.math.squared_difference(x * 255, x_hat))

    if mode == "train":
        # Calculate and return the rate-distortion loss: R + lmbda * D.
        loss = total_bpp + self.args.lmbda * mse
        return loss, total_bpp, mse
    elif mode == "compress":
        # Create `pack` list mapping tensors to values.
        tensors = [x_shape, x_encoded_shape, disp_encoded_shape,
                   x_encoded_string] + disp_strings
        pack = [(v, v.numpy()) for v in tensors]
        return mse, total_bpp, x_hat, pack
    elif mode == "decompress":
        return x_hat