def call(self, x, training):
    """Computes rate and distortion losses.

    Returns:
      A tuple `(loss, bpp, mse)`:
      - `bpp` counts both latent (y) bits and side-information (z) bits.
      - `loss = bpp + self.lmbda * mse`.
    """
    latent_model = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal, self.num_scales, self.scale_fn,
        coding_rank=3, compression=False)
    hyper_model = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior, coding_rank=3, compression=False)
    # TODO(jonycgn): Why is this necessary?
    x = tf.cast(x, self.compute_dtype)

    # Encoder: image -> latents y -> hyper-latents z (computed from |y|).
    y = self.analysis_transform(x)
    z = self.hyper_analysis_transform(abs(y))

    # The reconstructed hyper-latents predict the scale indexes that
    # parameterize the entropy model of the latents.
    z_hat, z_bits = hyper_model(z, training=training)
    scale_indexes = self.hyper_synthesis_transform(z_hat)
    y_hat, y_bits = latent_model(y, scale_indexes, training=training)
    x_hat = self.synthesis_transform(y_hat)

    # Rate: total bits divided by the product of all axes of `x` except the
    # last (presumably batch * height * width -- TODO confirm layout).
    pixel_count = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), y_bits.dtype)
    bpp = (tf.reduce_sum(y_bits) + tf.reduce_sum(z_bits)) / pixel_count

    # Distortion: mean squared error, cast to the rate's dtype so the two
    # terms can be added in the Lagrangian.
    mse = tf.cast(
        tf.reduce_mean(tf.math.squared_difference(x, x_hat)), bpp.dtype)

    return bpp + self.lmbda * mse, bpp, mse
def __init__(self, lmbda, context_length, num_filters, num_scales, scale_min,
             scale_max):
    """Builds the transforms, the context model and the entropy models.

    Args:
      lmbda: Rate-distortion trade-off weight, stored as `self.lmbda`.
      context_length: Forwarded to the context model and to every transform.
      num_filters: Forwarded to the context model and to every transform.
        Also used as the batch shape of the factorized hyperprior.
      num_scales: Number of discrete scales of the indexed entropy model.
        Must be at least 2, because the scale grid divides by
        `num_scales - 1`.
      scale_min: Smallest scale; `self.scale_fn(0) == scale_min`. Must be
        positive because the grid is spaced uniformly in log space.
      scale_max: Largest scale; `self.scale_fn(num_scales - 1) == scale_max`.
        Must be positive.

    Raises:
      ValueError: If `num_scales < 2`, or if `scale_min` or `scale_max` is
        not positive. Either case would otherwise yield inf/NaN scales.
    """
    if num_scales < 2:
        raise ValueError(f"num_scales must be >= 2, got {num_scales}.")
    if scale_min <= 0 or scale_max <= 0:
        raise ValueError(
            "scale_min and scale_max must be positive, got "
            f"{scale_min} and {scale_max}.")
    super().__init__()
    self.lmbda = lmbda
    self.num_scales = num_scales
    # Log-uniform scale grid: scale_fn(i) = scale_min * (scale_max /
    # scale_min) ** (i / (num_scales - 1)).
    offset = tf.math.log(scale_min)
    factor = (tf.math.log(scale_max) - tf.math.log(scale_min)) / (
        num_scales - 1.)
    self.context_model = ContextModel(
        context_length=context_length, num_filters=num_filters)
    self.scale_fn = lambda i: tf.math.exp(offset + factor * i)
    self.hyperprior = tfc.NoisyDeepFactorized(batch_shape=(num_filters,))
    self.analysis_transform = AnalysisTransform(
        context_length=context_length, num_filters=num_filters)
    self.hyper_analysis_transform = HyperAnalysisTransform(
        context_length=context_length, num_filters=num_filters)
    self.hyper_synthesis_transform = HyperSynthesisTransform(
        context_length=context_length, num_filters=num_filters)
    self.synthesis_transform = SynthesisTransform(
        context_length=context_length, num_filters=num_filters)
    # Training-time (compression=False) entropy models.
    self.entropy_model = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal, self.num_scales, self.scale_fn, coding_rank=3,
        compression=False)
    self.side_entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior, coding_rank=3, compression=False)
def fit(self, *args, **kwargs):
    """Trains via the base-class `fit`, then builds compression-ready models.

    All positional and keyword arguments are forwarded unchanged to the
    base-class `fit`, and its return value is handed back to the caller.
    """
    fit_result = super().fit(*args, **kwargs)
    # Now that training is done, rebuild both entropy models with
    # `compression=True` so that their range coding tables are fixed.
    self.em_z = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior,
        coding_rank=3,
        compression=True,
        offset_heuristic=False,
    )
    self.em_y = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal,
        num_scales=self.num_scales,
        scale_fn=self.scale_fn,
        coding_rank=3,
        compression=True,
    )
    return fit_result
def fit(self, *args, **kwargs):
    """Trains via the base-class `fit`, then swaps in compression-ready models.

    Arguments are passed straight through to the base-class `fit`, whose
    return value is returned unchanged.
    """
    fit_result = super().fit(*args, **kwargs)
    # Training is finished: replace the training-time entropy models with
    # `compression=True` ones so the range coding tables are fixed.
    self.entropy_model = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal,
        num_scales=self.num_scales,
        scale_fn=self.scale_fn,
        coding_rank=3,
        compression=True,
    )
    self.side_entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior,
        coding_rank=3,
        compression=True,
    )
    return fit_result
def _run(self, mode, x=None, bit_strings=None):
    """Run model according to `mode` (train, compress, or decompress).

    Args:
      mode: One of "train", "compress" or "decompress".
      x: Input image batch. Required unless `mode == "decompress"`. The
        compress path compares `x * 255` against a reconstruction clipped to
        [0, 1] and rounded to 0..255, so presumably `x` is in [0, 1] -- TODO
        confirm against the input pipeline.
      bit_strings: Required when `mode == "decompress"`. A sequence laid out
        as `[x_shape, y_shape, z_shape, z_string, *y_strings]` with
        NUM_SLICES entries in `y_strings`.

    Returns:
      - "train": the rate-distortion loss `total_bpp + lmbda * mse`. Scalar
        summaries and image summaries are also emitted.
      - "compress": `(mse, total_bpp, x_hat, pack)`, where `pack` pairs each
        tensor to be stored with its `.numpy()` value.
      - "decompress": the reconstructed image `x_hat`.
    """
    training = (mode == "train")

    if mode == "decompress":
        x_shape, y_shape, z_shape, z_string = bit_strings[:4]
        y_strings = bit_strings[4:]
        assert len(y_strings) == NUM_SLICES
    else:
        y_strings = []
        x_shape = tf.shape(x)[1:-1]

        # Build the encoder (analysis) half of the hierarchical autoencoder.
        y = self.analysis_transform(x)
        y_shape = tf.shape(y)[1:-1]

        z = self.hyper_analysis_transform(y)
        z_shape = tf.shape(z)[1:-1]

    # Every rate below is `tf.reduce_mean` over the batch, i.e. bits per
    # image. To get bits per pixel, normalize by the pixel count of ONE image.
    # Using the whole batch's count (batchsize * patchsize**2) would shrink
    # bpp, and the loss's rate term, by a factor of the batch size.
    num_pixels = tf.cast(tf.reduce_prod(x_shape), tf.float32)

    # Build the entropy model for the hyperprior (z).
    em_z = tfc.ContinuousBatchedEntropyModel(
        self.entropy_bottleneck, coding_rank=3,
        compression=not training, no_variables=True)

    if mode != "decompress":
        # When training, z_bpp is based on the noisy version of z (z_tilde).
        _, z_bits = em_z(z, training=training)
        z_bpp = tf.reduce_mean(z_bits) / num_pixels

    if training:
        # Use rounding (instead of uniform noise) to modify z before passing it
        # to the hyper-synthesis transforms. Note that quantize() overrides the
        # gradient to create a straight-through estimator.
        z_hat = em_z.quantize(z)
        z_string = None
    else:
        if mode == "compress":
            z_string = em_z.compress(z)
        z_hat = em_z.decompress(z_string, z_shape)

    # Build the decoder (synthesis) half of the hierarchical autoencoder.
    latent_scales = self.hyper_synthesis_scale_transform(z_hat)
    latent_means = self.hyper_synthesis_mean_transform(z_hat)

    # En/Decode each slice conditioned on hyperprior and previous slices.
    y_slices = (y_strings if mode == "decompress" else
                tf.split(y, NUM_SLICES, axis=-1))
    y_hat_slices = []
    y_bpps = []
    for slice_index, y_slice in enumerate(y_slices):
        # Model may condition on only a subset of previous slices.
        support_slices = (y_hat_slices if MAX_SUPPORT_SLICES < 0 else
                          y_hat_slices[:MAX_SUPPORT_SLICES])

        # Predict mu and sigma for the current slice.
        mean_support = tf.concat([latent_means] + support_slices, axis=-1)
        mu = self.cc_mean_transforms[slice_index](mean_support)
        mu = mu[:, :y_shape[0], :y_shape[1], :]

        # Note that in this implementation, `sigma` represents scale indices,
        # not actual scale values.
        scale_support = tf.concat([latent_scales] + support_slices, axis=-1)
        sigma = self.cc_scale_transforms[slice_index](scale_support)
        sigma = sigma[:, :y_shape[0], :y_shape[1], :]

        # Build the conditional entropy model for this slice.
        em_y = tfc.LocationScaleIndexedEntropyModel(
            tfc.NoisyNormal, num_scales=SCALES_LEVELS, scale_fn=scale_fn,
            coding_rank=3, compression=not training, no_variables=True)

        if mode == "decompress":
            y_hat_slice = em_y.decompress(y_slice, sigma, loc=mu)
        else:
            _, slice_bits = em_y(y_slice, sigma, loc=mu, training=training)
            slice_bpp = tf.reduce_mean(slice_bits) / num_pixels
            y_bpps.append(slice_bpp)

            if training:
                # For the synthesis transform, use rounding. Note that
                # quantize() overrides the gradient to create a
                # straight-through estimator.
                y_hat_slice = em_y.quantize(y_slice, sigma, loc=mu)
            else:
                assert mode == "compress"
                slice_string = em_y.compress(y_slice, sigma, mu)
                y_strings.append(slice_string)
                y_hat_slice = em_y.decompress(slice_string, sigma, mu)

        # Add latent residual prediction (LRP).
        lrp_support = tf.concat([mean_support, y_hat_slice], axis=-1)
        lrp = self.lrp_transforms[slice_index](lrp_support)
        lrp = 0.5 * tf.math.tanh(lrp)
        y_hat_slice += lrp

        y_hat_slices.append(y_hat_slice)

    # Merge slices and generate the image reconstruction.
    y_hat = tf.concat(y_hat_slices, axis=-1)
    x_hat = self.synthesis_transform(y_hat)
    # Crop any padding introduced by the transforms back to the input size.
    x_hat = x_hat[:, :x_shape[0], :x_shape[1], :]

    if mode != "decompress":
        # Total bpp is sum of bpp from hyperprior and all slices.
        total_bpp = tf.add_n(y_bpps + [z_bpp])

    # Mean squared error across pixels.
    if training:
        # Don't clip or round pixel values while training.
        mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
        mse *= 255 ** 2  # multiply by 255^2 to correct for rescaling
    else:
        x_hat = tf.clip_by_value(x_hat, 0, 1)
        x_hat = tf.round(x_hat * 255)
        if mode == "compress":
            mse = tf.reduce_mean(tf.math.squared_difference(x * 255, x_hat))

    if mode == "train":
        # Calculate and return the rate-distortion loss: R + lambda * D.
        loss = total_bpp + self.args.lmbda * mse
        tf.summary.scalar("bpp", total_bpp)
        tf.summary.scalar("mse", mse)
        tf.summary.scalar("loss", loss)
        tf.summary.image("original", quantize_image(x))
        tf.summary.image("reconstruction", quantize_image(x_hat))
        return loss
    elif mode == "compress":
        # Create `pack` dict mapping tensors to values.
        tensors = [x_shape, y_shape, z_shape, z_string] + y_strings
        pack = [(v, v.numpy()) for v in tensors]
        return mse, total_bpp, x_hat, pack
    elif mode == "decompress":
        return x_hat
def call(self, x, training):
    """Computes rate and distortion losses.

    Runs the channel-conditional hierarchical autoencoder on `x`. The loss
    uses unclipped, unrounded reconstructions.

    Returns:
      A tuple `(loss, total_bpp, mse)` with
      `loss = total_bpp + self.lmbda * mse`.
    """
    # --- Encoder (analysis) half. ---
    y = self.analysis_transform(x)
    latent_hw = tf.shape(y)[1:-1]
    z = self.hyper_analysis_transform(y)
    # Pixels in one image (first and last axes of `x` excluded). Rates below
    # are batch means, so dividing by this gives bits per pixel.
    pixels_per_image = tf.cast(tf.reduce_prod(tf.shape(x)[1:-1]), tf.float32)

    # --- Rate of the hyper-latents z. ---
    hyper_model = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior, coding_rank=3, compression=False,
        offset_heuristic=False)
    # With training=True the rate is measured on the noisy z (z_tilde).
    z_bits = hyper_model(z, training=training)[1]
    z_bpp = tf.reduce_mean(z_bits) / pixels_per_image
    # Round z for the hyper-synthesis transforms; quantize() supplies a
    # straight-through gradient.
    z_hat = hyper_model.quantize(z)

    # --- Decoder (synthesis) half. ---
    latent_scales = self.hyper_synthesis_scale_transform(z_hat)
    latent_means = self.hyper_synthesis_mean_transform(z_hat)

    slice_model = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal, num_scales=self.num_scales, scale_fn=self.scale_fn,
        coding_rank=3, compression=False)

    decoded = []
    rates = []
    for idx, chunk in enumerate(tf.split(y, self.num_slices, axis=-1)):
        # Condition on all previous slices, or only the first
        # `max_support_slices` of them when that limit is non-negative.
        if self.max_support_slices < 0:
            context = decoded
        else:
            context = decoded[:self.max_support_slices]

        # Location prediction, cropped to the latent's spatial extent.
        mean_ctx = tf.concat([latent_means, *context], axis=-1)
        mu = self.cc_mean_transforms[idx](mean_ctx)
        mu = mu[:, :latent_hw[0], :latent_hw[1], :]

        # `sigma` holds scale *indices* for the indexed entropy model, not
        # actual scale values.
        scale_ctx = tf.concat([latent_scales, *context], axis=-1)
        sigma = self.cc_scale_transforms[idx](scale_ctx)
        sigma = sigma[:, :latent_hw[0], :latent_hw[1], :]

        chunk_bits = slice_model(chunk, sigma, loc=mu, training=training)[1]
        rates.append(tf.reduce_mean(chunk_bits) / pixels_per_image)

        # Rounded slice for synthesis (straight-through gradient).
        chunk_hat = slice_model.quantize(chunk, loc=mu)

        # Latent residual prediction, bounded to (-0.5, 0.5) by 0.5 * tanh.
        lrp = self.lrp_transforms[idx](
            tf.concat([mean_ctx, chunk_hat], axis=-1))
        chunk_hat += 0.5 * tf.math.tanh(lrp)

        decoded.append(chunk_hat)

    x_hat = self.synthesis_transform(tf.concat(decoded, axis=-1))

    # Total rate: every slice's bpp plus the hyperprior's bpp.
    total_bpp = tf.add_n(rates + [z_bpp])
    # Distortion on raw (not clipped or rounded) reconstructions.
    mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))

    loss = total_bpp + self.lmbda * mse
    return loss, total_bpp, mse