예제 #1
0
  def call(self, inputs, training):
    """Runs `inputs` through the bottleneck and evaluates their likelihoods.

    Args:
      inputs: Value to process; it is converted to a tensor of dtype
        `self.dtype`.
      training: Truthy to quantize in "noise" mode (the differentiable
        training path), falsy to quantize in "dequantize" mode.

    Returns:
      A `(values, likelihood)` tuple: the output of `self._quantize` and the
      result of `self._likelihood` on it, the latter lower-bounded by
      `self.likelihood_bound` when that bound is positive.

    Raises:
      ValueError: If the converted tensor has an integer dtype.
    """
    tensor = tf.convert_to_tensor(inputs, dtype=self.dtype)
    if tensor.dtype.is_integer:
      message = "{} can't take integer inputs.".format(type(self).__name__)
      raise ValueError(message)

    # Select the quantization mode from the training flag.
    if training:
      mode = "noise"
    else:
      mode = "dequantize"
    values = self._quantize(tensor, mode)
    assert values.dtype == self.dtype

    likelihood = self._likelihood(values)
    if self.likelihood_bound > 0:
      floor = tf.constant(self.likelihood_bound, dtype=self.dtype)
      likelihood = math_ops.lower_bound(likelihood, floor)

    if tf.executing_eagerly():
      return values, likelihood

    # In graph mode, attach the static shapes computed from the input shape.
    static_shapes = self.compute_output_shape(tensor.shape)
    values_shape, likelihood_shape = static_shapes
    values.set_shape(values_shape)
    likelihood.set_shape(likelihood_shape)
    return values, likelihood
예제 #2
0
  def _test_lower_bound(self, gradient):
    """Checks values and gradients of `math_ops.lower_bound` with bound 0.

    Args:
      gradient: Gradient mode forwarded to `math_ops.lower_bound`. The modes
        "disconnected" and "identity" have dedicated expectations; any other
        value uses the fallback expectations.
    """
    inputs = tf.placeholder(dtype=tf.float32)
    outputs = math_ops.lower_bound(inputs, 0, gradient=gradient)

    # Back-propagate once with +1 and once with -1 upstream gradients.
    upstream_pos = tf.ones_like(inputs)
    pgrads, = tf.gradients([outputs], [inputs], [upstream_pos])
    upstream_neg = -tf.ones_like(inputs)
    ngrads, = tf.gradients([outputs], [inputs], [upstream_neg])

    # Expected (positive-upstream, negative-upstream) gradients per mode.
    expected_by_mode = {
        "disconnected": ([0, 1], [0, -1]),
        "identity": ([1, 1], [-1, -1]),
    }
    fallback = ([0, 1], [-1, -1])
    pgrads_expected, ngrads_expected = expected_by_mode.get(gradient, fallback)

    with self.test_session() as sess:
      fetched = sess.run([outputs, pgrads, ngrads], {inputs: [-1, 1]})
      outputs_val, pgrads_val, ngrads_val = fetched
      self.assertAllEqual(outputs_val, [0, 1])
      self.assertAllEqual(pgrads_val, pgrads_expected)
      self.assertAllEqual(ngrads_val, ngrads_expected)
예제 #3
0
  def build(self, input_shape):
    """Builds the entropy model.

    This function precomputes the quantized CDF table based on the scale table.
    This can be done at graph construction time. Then, it creates the graph for
    computing the indexes into that table based on the scale tensor, and then
    uses this index tensor to determine the starting positions of the PMFs for
    each scale.

    The table and its per-row lengths are stored in the non-trainable int32
    variables "quantized_cdf" and "cdf_length". The index computation is
    skipped when the instance already has an `_indexes` attribute.

    Args:
      input_shape: Shape of the input tensor.

    Raises:
      ValueError: If `input_shape` is not compatible with
        `self.input_spec.shape` (raised by `assert_is_compatible_with`).
    """
    input_shape = tf.TensorShape(input_shape)
    input_shape.assert_is_compatible_with(self.input_spec.shape)

    scale_table = tf.constant(self.scale_table, dtype=self.dtype)

    # Lower bound scales. We need to do this here, and not in __init__, because
    # the dtype may not yet be known there.
    # With no explicit bound, the first scale table entry is used as the bound;
    # a non-positive explicit bound leaves the scales untouched.
    if self.scale_bound is None:
      self._scale = math_ops.lower_bound(self._scale, scale_table[0])
    elif self.scale_bound > 0:
      self._scale = math_ops.lower_bound(self._scale, self.scale_bound)

    # Per table entry: half-width (`pmf_center`) and full width (`pmf_length`)
    # of the PMF support, derived from the quantile at half the tail mass.
    multiplier = -self._standardized_quantile(self.tail_mass / 2)
    pmf_center = np.ceil(np.array(self.scale_table) * multiplier).astype(int)
    pmf_length = 2 * pmf_center + 1
    max_length = np.max(pmf_length)

    # This assumes that the standardized cumulative has the property
    # 1 - c(x) = c(-x), which means we can compute differences equivalently in
    # the left or right tail of the cumulative. The point is to only compute
    # differences in the left tail. This increases numerical stability: c(x) is
    # 1 for large x, 0 for small x. Subtracting two numbers close to 0 can be
    # done with much higher precision than subtracting two numbers close to 1.
    # `samples` holds, for every table row, the absolute distance of each of the
    # `max_length` positions from that row's center.
    samples = abs(np.arange(max_length, dtype=int) - pmf_center[:, None])
    samples = tf.constant(samples, dtype=self.dtype)
    samples_scale = tf.expand_dims(scale_table, 1)
    upper = self._standardized_cumulative((.5 - samples) / samples_scale)
    lower = self._standardized_cumulative((-.5 - samples) / samples_scale)
    pmf = upper - lower

    # Compute out-of-range (tail) masses.
    # The first column of `lower` is the cumulative below the leftmost bin;
    # doubling it covers both tails under the symmetry assumption above.
    tail_mass = 2 * lower[:, :1]

    def cdf_initializer(shape, dtype=None, partition_info=None):
      # Variable initializer that ignores its arguments except for sanity
      # checks and returns the CDF table computed from the tensors above.
      del partition_info  # unused
      assert tuple(shape) == (len(pmf_length), max_length + 2)
      assert dtype == tf.int32
      return self._pmf_to_cdf(
          pmf, tail_mass,
          tf.constant(pmf_length, dtype=tf.int32), max_length)

    # One table row per scale table entry, `max_length + 2` columns wide.
    quantized_cdf = self.add_weight(
        "quantized_cdf", shape=(len(pmf_length), max_length + 2),
        initializer=cdf_initializer, dtype=tf.int32, trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    # Number of valid entries per table row: PMF length plus two.
    cdf_length = self.add_weight(
        "cdf_length", shape=(len(pmf_length),),
        initializer=tf.initializers.constant(pmf_length + 2),
        dtype=tf.int32, trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    # Works around a weird TF issue with reading variables inside a loop.
    self._quantized_cdf = tf.identity(quantized_cdf)
    self._cdf_length = tf.identity(cdf_length)

    # Now, if they haven't been overridden, compute the indexes into the table
    # for each of the passed-in scales.
    if not hasattr(self, "_indexes"):
      # Prevent tensors from bouncing back and forth between host and GPU.
      with tf.device("/cpu:0"):
        # Every index starts at the last table entry ...
        fill = tf.constant(
            len(self.scale_table) - 1, dtype=tf.int32)
        initializer = tf.fill(tf.shape(self.scale), fill)

        # ... and is decremented once for each table entry (the last one
        # excluded) that is greater than or equal to the element's scale.
        def loop_body(indexes, scale):
          return indexes - tf.cast(self.scale <= scale, tf.int32)

        self._indexes = tf.foldr(
            loop_body, scale_table[:-1],
            initializer=initializer, back_prop=False, name="compute_indexes")

    # Per-row offset: the negated PMF half-width.
    self._offset = tf.constant(-pmf_center, dtype=tf.int32)

    super(SymmetricConditional, self).build(input_shape)
예제 #4
0
 def __call__(self) -> tf.Tensor:
     """Returns `square(lower_bound(variable, bound)) - pedestal` as a tensor.

     The ops are created inside this object's name scope.
     """
     with self.name_scope:
         bounded = math_ops.lower_bound(self.variable, self._bound)
         squared = tf.math.square(bounded)
         return squared - self._pedestal
예제 #5
0
  def call(self, inputs, training):
    """Perturbs or quantizes `inputs` and evaluates their likelihoods.

    Args:
      inputs: The tensor to be passed through the bottleneck.
      training: If truthy, uniform noise drawn from [-0.5, 0.5) is added to the
        values. Otherwise the values are rounded; when
        `self.optimize_integer_offset` is set, rounding is done relative to
        `self._medians`.

    Returns:
      A `(values, likelihood)` tuple. Both tensors are reshaped and transposed
      back to the layout of `inputs`. `likelihood` is lower-bounded by
      `self.likelihood_bound` when that bound is positive.
    """
    tensor = ops.convert_to_tensor(inputs)
    rank = self.input_spec.ndim
    ch_axis = self._channel_axis(rank)
    half = constant_op.constant(.5, dtype=self.dtype)

    # Move the channel axis to the front, then flatten everything else so the
    # working layout is (channels, 1, batch).
    to_front = list(range(rank))
    to_front.pop(ch_axis)
    to_front.insert(0, ch_axis)
    values = array_ops.transpose(tensor, to_front)
    transposed_shape = array_ops.shape(values)
    values = array_ops.reshape(values, (transposed_shape[0], 1, -1))

    # Perturb with noise (training) or round (evaluation).
    if training:
      noise_shape = array_ops.shape(values)
      noise = random_ops.random_uniform(noise_shape, -half, half)
      values = math_ops.add_n([values, noise])
    elif self.optimize_integer_offset:
      centered = values - self._medians
      values = math_ops.round(centered) + self._medians
    else:
      values = math_ops.round(values)

    # Likelihood = difference of the sigmoid of the cumulative logits at
    # value +/- 0.5. sigmoid(x) saturates at 1 for large x and at 0 for small
    # x, and differences of numbers near 0 are far more precise than
    # differences of numbers near 1, so both logits are flipped towards the
    # left tail whenever their sum is positive.
    lower = self._logits_cumulative(values - half, stop_gradient=False)
    upper = self._logits_cumulative(values + half, stop_gradient=False)
    logits_sum = math_ops.add_n([lower, upper])
    flip = -math_ops.sign(logits_sum)
    flip = array_ops.stop_gradient(flip)
    upper_prob = math_ops.sigmoid(flip * upper)
    lower_prob = math_ops.sigmoid(flip * lower)
    likelihood = abs(upper_prob - lower_prob)
    if self.likelihood_bound > 0:
      floor = constant_op.constant(self.likelihood_bound, dtype=self.dtype)
      likelihood = tfc_math_ops.lower_bound(likelihood, floor)

    # Undo the flattening and put the channel axis back where it was.
    restore = list(range(1, rank))
    restore.insert(ch_axis, 0)
    values = array_ops.reshape(values, transposed_shape)
    values = array_ops.transpose(values, restore)
    likelihood = array_ops.reshape(likelihood, transposed_shape)
    likelihood = array_ops.transpose(likelihood, restore)

    if context.executing_eagerly():
      return values, likelihood

    # In graph mode, attach the static shapes computed from the input shape.
    static_shapes = self.compute_output_shape(tensor.shape)
    values_shape, likelihood_shape = static_shapes
    values.set_shape(values_shape)
    likelihood.set_shape(likelihood_shape)
    return values, likelihood
예제 #6
0
 def reparam(var):
     """Maps `var` to `square(lower_bound(var, bound)) - pedestal`.

     `bound` and `pedestal` are taken from the enclosing scope.
     """
     bounded = math_ops.lower_bound(var, bound)
     return tf.math.square(bounded) - pedestal
예제 #7
0
  def build(self, input_shape):
    """Builds the entropy model.

    This function precomputes the quantized CDF table based on the scale table.
    This can be done at graph construction time. Then, it creates the graph for
    computing the indexes into that table based on the scale tensor, and then
    uses this index tensor to determine the starting positions of the PMFs for
    each scale.

    The table and its per-row lengths are stored in the non-trainable int32
    variables "quantized_cdf" and "cdf_length". The index computation is
    skipped when the instance already has an `_indexes` attribute.

    Args:
      input_shape: Shape of the input tensor.

    Raises:
      ValueError: If `input_shape` is not compatible with
        `self.input_spec.shape` (raised by `assert_is_compatible_with`).
    """
    input_shape = tf.TensorShape(input_shape)
    input_shape.assert_is_compatible_with(self.input_spec.shape)

    scale_table = tf.constant(self.scale_table, dtype=self.dtype)

    # Lower bound scales. We need to do this here, and not in __init__, because
    # the dtype may not yet be known there.
    # With no explicit bound, the first scale table entry is used as the bound;
    # a non-positive explicit bound leaves the scales untouched.
    if self.scale_bound is None:
      self._scale = math_ops.lower_bound(self._scale, scale_table[0])
    elif self.scale_bound > 0:
      self._scale = math_ops.lower_bound(self._scale, self.scale_bound)

    # Per table entry: half-width (`pmf_center`) and full width (`pmf_length`)
    # of the PMF support, derived from the quantile at half the tail mass.
    multiplier = -self._standardized_quantile(self.tail_mass / 2)
    pmf_center = np.ceil(np.array(self.scale_table) * multiplier).astype(int)
    pmf_length = 2 * pmf_center + 1
    max_length = np.max(pmf_length)

    # This assumes that the standardized cumulative has the property
    # 1 - c(x) = c(-x), which means we can compute differences equivalently in
    # the left or right tail of the cumulative. The point is to only compute
    # differences in the left tail. This increases numerical stability: c(x) is
    # 1 for large x, 0 for small x. Subtracting two numbers close to 0 can be
    # done with much higher precision than subtracting two numbers close to 1.
    # `samples` holds, for every table row, the absolute distance of each of the
    # `max_length` positions from that row's center.
    samples = abs(np.arange(max_length, dtype=int) - pmf_center[:, None])
    samples = tf.constant(samples, dtype=self.dtype)
    samples_scale = tf.expand_dims(scale_table, 1)
    upper = self._standardized_cumulative((.5 - samples) / samples_scale)
    lower = self._standardized_cumulative((-.5 - samples) / samples_scale)
    pmf = upper - lower

    # Compute out-of-range (tail) masses.
    # The first column of `lower` is the cumulative below the leftmost bin;
    # doubling it covers both tails under the symmetry assumption above.
    tail_mass = 2 * lower[:, :1]

    def cdf_initializer(shape, dtype=None, partition_info=None):
      # Variable initializer that ignores its arguments except for sanity
      # checks and returns the CDF table computed from the tensors above.
      del partition_info  # unused
      assert tuple(shape) == (len(pmf_length), max_length + 2)
      assert dtype == tf.int32
      return self._pmf_to_cdf(
          pmf, tail_mass,
          tf.constant(pmf_length, dtype=tf.int32), max_length)

    # `add_weight` replaces the deprecated `add_variable` alias; the arguments
    # are unchanged.
    # One table row per scale table entry, `max_length + 2` columns wide.
    quantized_cdf = self.add_weight(
        "quantized_cdf", shape=(len(pmf_length), max_length + 2),
        initializer=cdf_initializer, dtype=tf.int32, trainable=False)
    # Number of valid entries per table row: PMF length plus two.
    cdf_length = self.add_weight(
        "cdf_length", shape=(len(pmf_length),),
        initializer=tf.initializers.constant(pmf_length + 2),
        dtype=tf.int32, trainable=False)
    # Works around a weird TF issue with reading variables inside a loop.
    self._quantized_cdf = tf.identity(quantized_cdf)
    self._cdf_length = tf.identity(cdf_length)

    # Now, if they haven't been overridden, compute the indexes into the table
    # for each of the passed-in scales.
    if not hasattr(self, "_indexes"):
      # Prevent tensors from bouncing back and forth between host and GPU.
      with tf.device("/cpu:0"):
        # Every index starts at the last table entry ...
        fill = tf.constant(
            len(self.scale_table) - 1, dtype=tf.int32)
        initializer = tf.fill(tf.shape(self.scale), fill)

        # ... and is decremented once for each table entry (the last one
        # excluded) that is greater than or equal to the element's scale.
        def loop_body(indexes, scale):
          return indexes - tf.cast(self.scale <= scale, tf.int32)

        self._indexes = tf.foldr(
            loop_body, scale_table[:-1],
            initializer=initializer, back_prop=False, name="compute_indexes")

    # Per-row offset: the negated PMF half-width.
    self._offset = tf.constant(-pmf_center, dtype=tf.int32)

    super(SymmetricConditional, self).build(input_shape)
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format.

    For every batch, the latents `y` and `z` are initialised from the analysis
    transforms and then optimised directly with a NumPy Adam optimiser on the
    rate-distortion loss, using a straight-through rounding operator. The
    optimised latents are rounded, fed back into the graph to evaluate
    distortion and estimated bit-rate, and the per-image results are saved to
    an .npz file in `args.results_dir`.

    Reads `args.input_file`, `args.num_filters`, `args.lmbda`, `args.runname`,
    `args.checkpoint_dir` and `args.results_dir`; `args.lmbda` is overwritten
    when it is negative. Also relies on the module-level names `stop_early`,
    `save_opt_record`, `SCALES_MIN`, `SCALES_MAX` and `SCALES_LEVELS`.

    Args:
      args: Namespace-like object carrying the attributes listed above.
    """
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])  # NOTE(review): not read again in this function
    # Product of all dimensions between the batch axis and the last axis.
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    # Rescale pixel values from 0..255 to 0..1.
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # To avoid that, the iterator output is only fetched once per batch (x_next) and then fed back in
    # through the placeholder below.
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    # y and z are placeholders so that the NumPy optimiser below can feed the
    # current latents and read back gradients with respect to them.
    y = tf.placeholder('float32', y_init.shape)
    from utils import round_with_identity_STE as round_with_STE
    y_tilde = round_with_STE(y)
    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # # sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute the pdf of z_tilde under the flexible prior
    # # p(z_tilde) ("z_likelihoods")
    # z_tilde, z_likelihoods = entropy_bottleneck(z, training=training)
    z = tf.placeholder('float32', z_init.shape)
    z_tilde = round_with_STE(z)

    _ = entropy_bottleneck(
        z, training=False
    )  # dummy call to ensure entropy_bottleneck is properly built
    z_likelihoods = entropy_bottleneck._likelihood(z_tilde)  # p(\tilde z)
    if entropy_bottleneck.likelihood_bound > 0:
        likelihood_bound = entropy_bottleneck.likelihood_bound
        z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_bound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    # Scale table: SCALES_LEVELS values spaced evenly in log space between SCALES_MIN and SCALES_MAX.
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # graph = build_graph(args, x, training=False)

    # Total number of bits divided by number of pixels:
    # - log p(\tilde y | \tilde z) - log p(\tilde z), converted from nats to bits.
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        # A negative lmbda means "parse the value out of the run name",
        # which must then contain a 'lmbda=<value>-' segment.
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        # lmbda == 0: optimise the rate only.
        rd_loss = train_bpp
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    # `x *= 255` rebinds the local name x to a new tensor; x_ph still refers
    # to the placeholder that gets fed.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # eval_fields[i] names the result of eval_tensors[i].
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        # Log (and, with stop_early, check for convergence) more often when
        # an optimisation record or early stopping is requested.
        log_itv = 100
        if save_opt_record or stop_early:
            log_itv = 10
        rd_lr = 0.0001
        rd_opt_its = 2000
        from adam import Adam

        batch_idx = 0
        while True:
            try:
                # Fetch the next batch exactly once; the iterator raises
                # OutOfRangeError when the dataset is exhausted.
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur, z_cur = sess.run([y_init, z_init],
                                        feed_dict=x_feed_dict)  # np arrays
                # A fresh optimiser per batch, so optimiser state is not
                # carried over between batches.
                adam_optimizer = Adam(lr=rd_lr)

                if stop_early:
                    obj_prev = np.inf
                    y_prev, z_prev = None, None
                # Recreated for every batch, so the record saved after the
                # loop only covers the last batch.
                opt_record = {
                    'its': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z: z_cur,
                            **x_feed_dict
                        })
                    y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        print(
                            'it=%d, rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f' %
                            (it, obj, mse_, train_bpp_, psnr_))
                        if stop_early:
                            # NOTE(review): obj was evaluated before the Adam
                            # step above, while y_cur/z_cur are already the
                            # updated latents, so the snapshot kept here is
                            # one step ahead of the loss it is paired with.
                            # Confirm this is intended.
                            if obj >= obj_prev:  # no longer improving
                                y_cur, z_cur = y_prev, z_prev
                                break
                            else:
                                obj_prev = obj
                                y_prev, z_prev = y_cur, z_cur

                        opt_record['its'].append(it)
                        opt_record['rd_loss'].append(obj)
                        # NOTE(review): the same `obj` is stored under both
                        # keys; no separate after-rounding loss is computed.
                        opt_record['rd_loss_after_rounding'].append(obj)
                print()

                y_tilde_cur = np.round(
                    y_cur)  # this is the latents we end up transmitting
                z_tilde_cur = np.round(z_cur)

                # Evaluate with the rounded latents fed directly into the
                # y_tilde / z_tilde tensors.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_tilde: z_tilde_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        # The run name is expected to start with the name of the script that
        # trained the model, followed by '-'.
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        if script_name != trained_script_name:
            # Evaluated by a different script than the one that trained the
            # model: record this script's name and lmbda in the file name.
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record:
            # save optimization record
            # NOTE(review): opt_record is only defined if at least one batch
            # was processed, and holds the last batch's record only.
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        # Print the mean of every metric over all images.
        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
예제 #9
0
 def test_lower_bound_invalid(self):
     """`lower_bound` must raise `ValueError` for an unknown gradient mode."""
     with self.assertRaises(ValueError):
         zeros = tf.zeros((1, 2))
         math_ops.lower_bound(zeros, 0, gradient="invalid")
예제 #10
0
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format.

    For every batch, the mean and log-variance of a Gaussian posterior over
    the hyper-latent `z` are initialised from the hyper analysis transform and
    then optimised with a NumPy Adam optimiser on the estimated bit-rate,
    which includes a "bits back" term for the posterior. The resulting
    distortion and bit-rate estimates are saved per image to an .npz file in
    `args.results_dir`.

    Reads `args.input_file`, `args.num_filters`, `args.runname`,
    `args.checkpoint_dir` and `args.results_dir`. Also relies on the
    module-level names `likelihood_lowerbound`, `SCALES_MIN`, `SCALES_MAX`
    and `SCALES_LEVELS`.

    Args:
      args: Namespace-like object carrying the attributes listed above.
    """
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])  # NOTE(review): not read again in this function
    # Product of all dimensions between the batch axis and the last axis.
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    # Rescale pixel values from 0..255 to 0..1.
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # To avoid that, the iterator output is only fetched once per batch (x_next) and then fed back in
    # through the placeholder below.
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    from utils import log_normal_pdf
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(args.num_filters, dims=(3, 3, 3))

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    # The hyper analysis transform outputs twice the filters because its
    # output is split into a mean and a log-variance below.
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)
    y_tilde = tf.round(y)

    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean_init, z_logvar_init = tf.split(hyper_analysis_transform(y_tilde),
                                          num_or_size_splits=2,
                                          axis=-1)
    # z_mean / z_logvar are placeholders so that the NumPy optimiser below can
    # feed the current posterior parameters and read back their gradients.
    z_mean = tf.placeholder(
        'float32',
        z_mean_init.shape)  # initialize to inference network results
    z_logvar = tf.placeholder('float32', z_logvar_init.shape)

    # Reparameterised sample: z_tilde = mean + std * eps, with std = exp(logvar / 2).
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    # Scale table: SCALES_LEVELS values spaced evenly in log space between SCALES_MIN and SCALES_MAX.
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels:
    # - log p(\tilde y | \tilde z) - log p(\tilde z) - (- log q(\tilde z | \tilde y)),
    # i.e. the bits-back term (bpp_back) is subtracted from the coding cost.
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    bpp_back = tf.reduce_sum(
        -log_q_z_tilde, axis=axes_except_batch) / (np.log(2) * img_num_pixels)
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp - bpp_back  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Gradients of the mean bit-rate with respect to the posterior parameters.
    local_gradients = tf.gradients(train_bpp, [z_mean, z_logvar])

    # Bring both images back to 0..255 range.
    # `x *= 255` rebinds the local name x to a new tensor; x_ph still refers
    # to the placeholder that gets fed.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # eval_fields[i] names the result of eval_tensors[i].
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp', 'est_bpp_back'
        ]
        eval_tensors = [
            mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp, bpp_back
        ]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        batch_idx = 0
        while True:
            try:
                # Fetch the next batch exactly once; the iterator raises
                # OutOfRangeError when the dataset is exhausted.
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init],
                    feed_dict=x_feed_dict)  # np arrays

                # NOTE(review): these two histories are filled below but not
                # read again in this function.
                opt_obj_hist = []
                opt_grad_hist = []
                lr = 0.005
                local_its = 1000
                from adam import Adam
                # A fresh optimiser per batch, so optimiser state is not
                # carried over between batches.
                np_adam_optimizer = Adam(lr=lr)
                for it in range(local_its):
                    grads, obj = sess.run([local_gradients, train_bpp],
                                          feed_dict={
                                              z_mean: z_mean_cur,
                                              z_logvar: z_logvar_cur,
                                              **x_feed_dict
                                          })
                    z_mean_cur, z_logvar_cur = np_adam_optimizer.update(
                        [z_mean_cur, z_logvar_cur], grads)
                    if it % 100 == 0:
                        print('negative local ELBO', obj)
                    opt_obj_hist.append(obj)
                    # Mean absolute gradient over both parameter arrays.
                    opt_grad_hist.append(np.mean(np.abs(grads)))
                print()

                # Evaluate with the optimised posterior parameters.
                # NOTE(review): z_tilde is re-sampled from random noise in
                # this run, so the evaluation is stochastic.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         z_mean: z_mean_cur,
                                         z_logvar: z_logvar_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        # The run name is expected to start with the name of the script that
        # trained the model, followed by '-'.
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension
        save_file = 'rd-%s-input=%s.npz' % (args.runname, input_file)
        if script_name != trained_script_name:
            # Evaluated by a different script than the one that trained the
            # model: record this script's name in the file name.
            save_file = 'rd-%s+%s-input=%s.npz' % (script_name, args.runname,
                                                   input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        # Print the mean of every metric over all images.
        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
예제 #11
0
def compress(args):
    """Compress an image (or a batch of same-shaped images in .npy format) by
    iteratively refining its latents, then report rate-distortion statistics.

    Steps, as implemented below:
      1. Load the input into an array X: a .npy file is used as loaded; any
         other file is opened with PIL, converted to RGB and given a leading
         batch axis. X is cast to float32 and divided by 255.
      2. Build a TF1 graph in which the latents `y` and hyper-latents `z` are
         placeholders. Each is soft-rounded between its floor and ceil with a
         temperature-controlled softmax and scored under the entropy models.
      3. Restore the latest checkpoint found in
         `os.path.join(args.checkpoint_dir, args.runname)`.
      4. For every batch, start from the encoder outputs and run `rd_opt_its`
         optimizer steps on the rate-distortion loss w.r.t. `y` and `z`
         (temperature annealed per step), then round the results and evaluate
         the metrics.
      5. Save the per-image metrics (and, if `save_opt_record` is truthy, the
         optimization record) as .npz files in `args.results_dir`, and print
         the mean of every metric.

    Arguments:
      args: namespace-like object. Attributes read here: `input_file`,
        `num_filters`, `lmbda`, `runname`, `checkpoint_dir`, `results_dir`
        and `verbose`. `args.lmbda` is overwritten in place when it is
        negative.

    Returns:
      None. Results are written to disk and printed.

    Note: `save_opt_record`, `SCALES_MIN`, `SCALES_MAX` and `SCALES_LEVELS`
    are not defined inside this function; presumably they are module-level
    names -- confirm against the rest of the file.
    """
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shape, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    # NOTE(review): num_images is not used anywhere else in this function.
    num_images = int(X.shape[0])
    # Product of all axes except the first and the last (H * W for [N, H, W, 3] input).
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # the iterator output. Therefore the batch is fetched once per loop iteration (sess.run(x_next)) and then
    # fed back into the graph through the placeholder below, so that many sess.run calls can share one batch.
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    # `x` is rebound further down (`x *= 255`), so `x_ph` keeps a handle on the
    # original placeholder for use as a feed_dict key.
    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization: the encoder outputs for the current batch.
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    # y is a free variable of the optimization: it is fed as a numpy array on
    # every step rather than being computed from x.
    y = tf.placeholder('float32', y_init.shape)
    T = tf.placeholder('float32', shape=[], name='temperature')
    # Soft rounding: y_tilde is a softmax-weighted mix of floor(y) and ceil(y).
    # The logit of each bound is -atanh(distance to that bound) / T, so the
    # nearer bound receives the larger weight.
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    epsilon = 1e-5  # keeps the atanh argument strictly inside (-1, 1)
    ry_logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(y - y_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim are logits for DOWN or UP; clip to prevent NaN as temperature -> 0
    ry = tf.nn.softmax(ry_logits, axis=-1)
    y_tilde = tf.reduce_sum(y_bds * ry, axis=-1)  # inner product in last dim
    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # Same soft-rounding construction for the hyper-latents z. (During training
    # the equivalent step was
    # `z_tilde, z_likelihoods = entropy_bottleneck(z, training=training)`.)
    z = tf.placeholder('float32', z_init.shape)
    z_floor = tf.floor(z)
    z_ceil = tf.ceil(z)
    z_bds = tf.stack([z_floor, z_ceil], axis=-1)
    rz_logits = tf.stack([
        -tf.math.atanh(tf.clip_by_value(z - z_floor, -1 + epsilon,
                                        1 - epsilon)) / T,
        -tf.math.atanh(tf.clip_by_value(z_ceil - z, -1 + epsilon, 1 - epsilon))
        / T
    ],
                         axis=-1)  # last dim are logits for DOWN or UP
    rz = tf.nn.softmax(rz_logits, axis=-1)
    z_tilde = tf.reduce_sum(z_bds * rz, axis=-1)  # inner product in last dim

    # The entropy bottleneck must be built before its private `_likelihood`
    # can be used. Calling `entropy_bottleneck.build(z_tilde.shape)` directly
    # does not work: the resulting variables lack the proper name scope (they
    # are named "matrix_0", "bias_0", ... instead of
    # "entropy_bottleneck/matrix_0", "entropy_bottleneck/bias_0", ...), which
    # breaks checkpoint restoring ("Key bias_0 not found in checkpoint").
    # A dummy call on a tensor builds the layer with the expected names.
    # entropy_bottleneck.build(z_tilde.shape)
    _ = entropy_bottleneck(
        z, training=False
    )  # dummy call to ensure entropy_bottleneck is properly built
    z_likelihoods = entropy_bottleneck._likelihood(z_tilde)  # p(\tilde z)
    # Apply the same lower bound that the layer's own call() applies.
    if entropy_bottleneck.likelihood_bound > 0:
        likelihood_bound = entropy_bottleneck.likelihood_bound
        z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_bound)

    # compute parameters of p(y_tilde|z_tilde): the hyper-synthesis output is
    # split in half along the last axis into mu and (pre-exp) sigma
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    # SCALES_LEVELS values spaced evenly in log-space between SCALES_MIN and SCALES_MAX.
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels, per image:
    # (- log2 p(\tilde y | \tilde z) - log2 p(\tilde z)) / img_num_pixels
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels. Built here, while `x` and `x_tilde`
    # are still on the 0..1 scale and `x_tilde` is still unclipped/unrounded.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        # A negative lmbda means "use the value encoded in the run name"
        # (the text between 'lmbda=' and the next '-').
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        # lmbda == 0: optimize the rate only.
        rd_loss = train_bpp
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    # (`x *= 255` rebinds `x` to a new tensor; the placeholder stays in `x_ph`.)
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # eval_fields[i] names the tensor eval_tensors[i]; keep the two in sync.
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        # Optimization hyper-parameters.
        log_itv = 100  # log every log_itv iterations
        if save_opt_record:
            log_itv = 10  # denser logging when the record is going to be saved
        rd_lr = 0.005  # learning rate handed to the optimizer
        rd_opt_its = 2000  # optimization steps per batch
        annealing_rate = 4e-3
        T_ub = 0.2  # upper bound on the temperature

        def annealed_temperature(t, r, ub, lb=1e-8, backend=np):
            """Return exp(-r * t) clipped to the interval [lb, ub].

            With `backend=None` the builtin min/max are used; otherwise the
            `minimum`/`maximum`/`exp` functions of `backend` (numpy by
            default) are used.
            """
            # Using the exp schedule from section 4.2 of Jang et. al., ICLR2017
            if backend is None:
                return min(max(np.exp(-r * t), lb), ub)
            else:
                return backend.minimum(
                    backend.maximum(backend.exp(-r * t), lb), ub)

        from adam import Adam

        # NOTE(review): batch_idx is only incremented, never read, in this function.
        batch_idx = 0
        while True:
            try:
                # Pull the next batch once; OutOfRangeError here ends the loop.
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur, z_cur = sess.run([y_init, z_init],
                                        feed_dict=x_feed_dict)  # np arrays
                # A fresh optimizer (and a fresh record) is created for every batch.
                adam_optimizer = Adam(lr=rd_lr)
                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it,
                                                       r=annealing_rate,
                                                       ub=T_ub)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z: z_cur,
                            **x_feed_dict, T: temperature
                        })
                    y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            # Feed the hard-rounded latents directly into
                            # y_tilde / z_tilde, bypassing the soft rounding
                            # (so no temperature needs to be fed).
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_tilde: np.round(z_cur),
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_, rd_loss_after_rounding,
                                   bpp_after_rounding, psnr_after_rounding))
                            # Only recorded in verbose mode; otherwise this list stays empty.
                            opt_record['rd_loss_after_rounding'].append(
                                rd_loss_after_rounding)
                        else:
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_))
                        opt_record['its'].append(it)
                        opt_record['T'].append(temperature)
                        opt_record['rd_loss'].append(obj)

                print()

                y_tilde_cur = np.round(
                    y_cur)  # this is the latents we end up transmitting
                z_tilde_cur = np.round(z_cur)

                # Evaluate all metrics at the hard-rounded latents.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_tilde: z_tilde_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                # The one-shot iterator is exhausted: all batches are done.
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        # The run name is expected to start with the training script's name,
        # followed by '-'.
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        if script_name != trained_script_name:
            # Evaluated by a different script than the one that trained the
            # model: include this script's name and lmbda in the file name.
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record:
            # save optimization record
            # NOTE(review): opt_record is re-created for every batch, so only
            # the last batch's record is saved here (and it is undefined if
            # no batch was processed) -- confirm this is intended.
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
예제 #12
0
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest.

    The hyper-latents z_tilde are drawn as a reparameterized Gaussian sample (mean and log-variance come from the
    hyper-analysis transform). The latents y are passed through a tfc.GaussianConditional with the given `training`
    flag; per the original author's note, that samples from box-shaped posteriors during training and is
    approximated by rounding during compression.

    Arguments:
      args: namespace-like object; only `args.num_filters` is read here.
      x: input image tensor (see above).
      training: forwarded to the conditional bottleneck. When True, sigma is additionally upper-bounded; when
        False, mu/sigma and x_tilde are cropped to the spatial sizes of y and x respectively.

    Returns:
      `locals()`: a dict mapping the name of EVERY local variable of this function (including the arguments and
      the locally imported helpers) to its value. Renaming, adding or removing a local therefore changes the
      return value seen by callers.

    Note: `likelihood_lowerbound`, `variance_upperbound`, `SCALES_MIN`, `SCALES_MAX` and `SCALES_LEVELS` are not
    defined in this function; presumably they are module-level names -- confirm against the rest of the file.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    # Twice the filters on the output: it is split below into a mean and a log-variance.
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    # Twice the filters on the output: it is split below into mu and (pre-exp) sigma.
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)

    # z_tilde ~ q(z_tilde | x) = q(z_tilde | h_a(y))
    z_mean, z_logvar = tf.split(hyper_analysis_transform(y),
                                num_or_size_splits=2,
                                axis=-1)
    # Reparameterized sample: mean + std * eps, with std = exp(logvar / 2).
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean
    from utils import log_normal_pdf
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3))
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    # Bound the likelihoods from below (callers presumably take their log -- confirm).
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if training:
        # Cap the standard deviation at sqrt(variance_upperbound) during training only.
        sigma = math_ops.upper_bound(sigma, variance_upperbound**0.5)
    if not training:  # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    # SCALES_LEVELS values spaced evenly in log-space between SCALES_MIN and SCALES_MAX.
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # sample y_tilde from q(y_tilde|x) = U(y-0.5, y+0.5) = U(g_a(x)-0.5, g_a(x)+0.5), and then compute the pdf of
    # y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde) = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    # Note that at test/compression time, the resulting y_tilde doesn't simply
    # equal round(y); instead, the conditional_bottleneck does something
    # smarter and slightly more optimal: y_hat=floor(y + 0.5 - prior_mean), so
    # that the mean (mu) of the prior coincides with one of the quantization bins.
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=training)

    x_tilde = synthesis_transform(y_tilde)
    if not training:
        x_shape = tf.shape(x)
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[
            2], :]  # crop reconstruction to have the same shape as input

    # Expose every local defined above to the caller as a name -> value dict.
    return locals()
def compress(args):
    """Compress an image (or a batch of same-shaped images in .npy format) with
    a two-stage per-batch optimization, then report rate-distortion statistics.

    Steps, as implemented below:
      1. Load the input into an array X: a .npy file is used as loaded; any
         other file is opened with PIL, converted to RGB and given a leading
         batch axis. X is cast to float32 and divided by 255.
      2. Build a TF1 graph in which the latents `y` and the Gaussian
         posterior parameters `z_mean` / `z_logvar` of the hyper-latents are
         placeholders. `y` is stochastically soft-rounded between its floor
         and ceil using a relaxed one-hot sample; `z_tilde` is a
         reparameterized Gaussian sample.
      3. Restore the latest checkpoint found in
         `os.path.join(args.checkpoint_dir, args.runname)`.
      4. For every batch:
           a. "RD optimization": `rd_opt_its` optimizer steps on the
              rate-distortion loss w.r.t. `y`, `z_mean` and `z_logvar`,
              with an annealed temperature.
           b. "Rate optimization": round `y`, re-initialize `z_mean` /
              `z_logvar` from the rounded latents, and run `r_opt_its`
              optimizer steps on the bit-rate alone w.r.t. `z_mean` and
              `z_logvar`.
           c. Evaluate the metrics at the rounded latents.
      5. Save the per-image metrics as an .npz file in `args.results_dir`
         and print the mean of every metric.

    Arguments:
      args: namespace-like object. Attributes read here: `input_file`,
        `num_filters`, `lmbda`, `runname`, `checkpoint_dir`, `results_dir`,
        `verbose`, `annealing_rate` and `t0`. `args.lmbda` is overwritten in
        place when it is negative.

    Returns:
      None. Results are written to disk and printed.

    Note: `seed`, `likelihood_lowerbound`, `SCALES_MIN`, `SCALES_MAX` and
    `SCALES_LEVELS` are not defined inside this function; presumably they are
    module-level names -- confirm against the rest of the file.
    """
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shape, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    # NOTE(review): num_images is not used anywhere else in this function.
    num_images = int(X.shape[0])
    # Product of all axes except the first and the last (H * W for [N, H, W, 3] input).
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # the iterator output. Therefore the batch is fetched once per loop iteration (sess.run(x_next)) and then
    # fed back into the graph through the placeholder below, so that many sess.run calls can share one batch.
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    # `x` is rebound further down (`x *= 255`), so `x_ph` keeps a handle on the
    # original placeholder for use as a feed_dict key.
    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    from utils import log_normal_pdf
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(args.num_filters, dims=(3, 3, 3))

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    # Twice the filters on the output: it is split below into a mean and a log-variance.
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    # Twice the filters on the output: it is split below into mu and (pre-exp) sigma.
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial optimization (where we still have access to x)
    # Soft-to-hard rounding with a relaxed one-hot (Gumbel-softmax / Concrete) sample: for each element of y, let R
    # be a 2D auxiliary relaxed one-hot random vector, such that R=[1, 0] means rounding DOWN and [0, 1] means
    # rounding UP.
    # The logits of the two outcomes are -atanh(y - floor(y)) / T and -atanh(ceil(y) - y) / T, so the nearer bound
    # is the more likely outcome.
    # y_tilde = R[0] * floor(y) + R[1] * ceil(y); the intent is that y_tilde -> round(y) as T -> 0.
    import tensorflow_probability as tfp
    T = tf.placeholder('float32', shape=[], name='temperature')
    y_init = analysis_transform(x)
    # y is a free variable of the optimization: it is fed as a numpy array on
    # every step rather than being computed from x.
    y = tf.placeholder('float32', y_init.shape)
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    epsilon = 1e-5  # keeps the atanh argument strictly inside (-1, 1)
    logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(y - y_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim are logits for DOWN or UP; clip to prevent NaN as temperature -> 0
    # T is used twice: it scales the logits above and is also the relaxation
    # temperature of the distribution here.
    rounding_dist = tfp.distributions.RelaxedOneHotCategorical(
        T,
        logits=logits)  # technically we can use a different temperature here
    sample_concrete = rounding_dist.sample()
    y_tilde = tf.reduce_sum(y_bds * sample_concrete,
                            axis=-1)  # inner product in last dim
    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    # The inference-network outputs are used only to initialize the numpy
    # values that are fed into the z_mean / z_logvar placeholders.
    z_mean_init, z_logvar_init = tf.split(hyper_analysis_transform(y_tilde),
                                          num_or_size_splits=2,
                                          axis=-1)
    z_mean = tf.placeholder(
        'float32',
        z_mean_init.shape)  # initialize to inference network results
    z_logvar = tf.placeholder('float32', z_logvar_init.shape)

    # Reparameterized sample: mean + std * eps, with std = exp(logvar / 2).
    # A fresh eps is drawn on every sess.run that needs z_tilde.
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    # SCALES_LEVELS values spaced evenly in log-space between SCALES_MIN and SCALES_MAX.
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels, per image:
    # (- log2 p(\tilde y | \tilde z) - log2 p(\tilde z) + log2 q(\tilde z | \tilde y)) / img_num_pixels
    # The last term is the "bits back" credit: bpp_back = -log2 q / img_num_pixels is subtracted below.
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    batch_log_q_z_tilde = tf.reduce_sum(log_q_z_tilde, axis=axes_except_batch)
    bpp_back = -batch_log_q_z_tilde / (np.log(2) * img_num_pixels)
    batch_log_cond_p_y_tilde = tf.reduce_sum(tf.log(y_likelihoods),
                                             axis=axes_except_batch)
    y_bpp = -batch_log_cond_p_y_tilde / (np.log(2) * img_num_pixels)
    batch_log_p_z_tilde = tf.reduce_sum(tf.log(z_likelihoods),
                                        axis=axes_except_batch)
    z_bpp = -batch_log_p_z_tilde / (np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp - bpp_back  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels. Built here, while `x` and `x_tilde`
    # are still on the 0..1 scale and `x_tilde` is still unclipped/unrounded.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        # A negative lmbda means "use the value encoded in the run name"
        # (the text between 'lmbda=' and the next '-').
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        # lmbda == 0: optimize the rate only.
        rd_loss = train_bpp
    # Stage 1 gradients (full RD loss) and stage 2 gradients (rate only).
    rd_gradients = tf.gradients(rd_loss, [y, z_mean, z_logvar])
    r_gradients = tf.gradients(train_bpp, [z_mean, z_logvar])

    # Bring both images back to 0..255 range, for evaluation only.
    # (`x *= 255` rebinds `x` to a new tensor; the placeholder stays in `x_ph`.)
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # eval_fields[i] names the tensor eval_tensors[i]; keep the two in sync.
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp', 'est_bpp_back'
        ]
        eval_tensors = [
            mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp, bpp_back
        ]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        # NOTE(review): `plt` is only referenced by the commented-out plotting
        # code below; the Agg backend is still selected as a side effect.
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        # Optimization hyper-parameters.
        log_itv = 100  # log every log_itv iterations
        rd_lr = 0.005  # stage 1 learning rate
        # rd_opt_its = args.sga_its
        rd_opt_its = 2000  # stage 1 steps per batch
        annealing_scheme = 'exp0'
        annealing_rate = args.annealing_rate  # default annealing_rate = 1e-3
        t0 = args.t0  # default t0 = 700
        T_ub = 0.5  # max/initial temperature
        from utils import annealed_temperature
        r_lr = 0.003  # stage 2 learning rate
        r_opt_its = 2000  # stage 2 steps per batch
        from adam import Adam

        # NOTE(review): batch_idx is only read by the commented-out plotting code.
        batch_idx = 0
        while True:
            try:
                # Pull the next batch once; OutOfRangeError here ends the loop.
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur = sess.run(y_init, feed_dict=x_feed_dict)  # np arrays
                # Initialize the posterior parameters from the (unrounded)
                # encoder output by feeding it directly into y_tilde.
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init], feed_dict={y_tilde: y_cur})
                # NOTE(review): rd_loss_hist is only consumed by the
                # commented-out plotting code below.
                rd_loss_hist = []
                adam_optimizer = Adam(lr=rd_lr)

                # NOTE(review): opt_record is created but never filled or
                # saved in this function.
                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it,
                                                       r=annealing_rate,
                                                       ub=T_ub,
                                                       scheme=annealing_scheme,
                                                       t0=t0)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z_mean: z_mean_cur,
                            z_logvar: z_logvar_cur,
                            **x_feed_dict, T: temperature
                        })
                    y_cur, z_mean_cur, z_logvar_cur = adam_optimizer.update(
                        [y_cur, z_mean_cur, z_logvar_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            # Feed the hard-rounded latents directly into
                            # y_tilde, bypassing the soft rounding (so no
                            # temperature needs to be fed).
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_mean: z_mean_cur,
                                    z_logvar: z_logvar_cur,
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_, rd_loss_after_rounding,
                                   bpp_after_rounding, psnr_after_rounding))
                        else:
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_))
                    rd_loss_hist.append(obj)
                print()

                # 2. Fix y_tilde, perform rate optimization w.r.t. z_mean and z_logvar.
                y_tilde_cur = np.round(
                    y_cur)  # this is the latents we end up transmitting
                # x is not needed in this stage: the rate does not depend on
                # x once y_tilde is fed.
                # rate_feed_dict = {y_tilde: y_tilde_cur, **x_feed_dict}
                rate_feed_dict = {y_tilde: y_tilde_cur}
                # NOTE(review): `seed` is not defined in this function, and
                # the graph-level seed is set after the random ops were
                # created -- confirm this has the intended effect.
                np.random.seed(seed)
                tf.set_random_seed(seed)
                print('----Rate Optimization----')
                # Reinitialize based on the value of y_tilde. This discards
                # the z_mean / z_logvar values optimized in stage 1.
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init],
                    feed_dict=rate_feed_dict)  # np arrays

                # NOTE(review): r_loss_hist is only consumed by the
                # commented-out plotting code below.
                r_loss_hist = []
                # rate_grad_hist = []

                adam_optimizer = Adam(lr=r_lr)
                for it in range(r_opt_its):
                    grads, obj = sess.run(
                        [r_gradients, train_bpp],
                        feed_dict={
                            z_mean: z_mean_cur,
                            z_logvar: z_logvar_cur,
                            **rate_feed_dict
                        })
                    z_mean_cur, z_logvar_cur = adam_optimizer.update(
                        [z_mean_cur, z_logvar_cur], grads)
                    if it % log_itv == 0 or it + 1 == r_opt_its:
                        print('it=', it, '\trate=', obj)
                    r_loss_hist.append(obj)
                    # rate_grad_hist.append(np.mean(np.abs(grads)))
                print()

                # fig, axes = plt.subplots(nrows=2, sharex=True)
                # axes[0].plot(rd_loss_hist)
                # axes[0].set_ylabel('RD loss')
                # axes[1].plot(r_loss_hist)
                # axes[1].set_ylabel('Rate loss')
                # axes[1].set_xlabel('SGD iterations')
                # plt.savefig('plots/local_q_opt_hist-%s-input=%s-b=%d.png' %
                #             (args.runname, os.path.basename(args.input_file), batch_idx))

                # Evaluate all metrics at the rounded y and the optimized
                # posterior parameters. z_tilde is not fed, so it is sampled
                # once more for this evaluation.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_mean: z_mean_cur,
                                         z_logvar: z_logvar_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                # The one-shot iterator is exhausted: all batches are done.
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        # The run name is expected to start with the training script's name,
        # followed by '-'.
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension
        save_file = 'rd-%s-input=%s.npz' % (args.runname, input_file)
        if script_name != trained_script_name:
            # Evaluated by a different script than the one that trained the
            # model: include this script's name and lmbda in the file name.
            save_file = 'rd-%s-lmbda=%g+%s-input=%s.npz' % (
                script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
def train(args):
    """Fit a `BMSHJ2018Prior` density model to a data array by maximum likelihood.

    The whole array loaded from `args.data_path` is fed on every step
    (full-batch); the objective is the mean negative log of the model pdf,
    with the pdf bounded from below at 1e-10 before taking the log, and is
    minimized with Adam.

    Arguments:
      args: argparse-style namespace. Attributes read: `num_channels`, `dims`,
        `init_scale`, `data_path`, `lr`, `its` (maximum number of steps),
        `tol` (relative loss-change threshold for early stopping),
        `logging_freq`, `plot`, `checkpoint_dir`.

    Side effects:
      Creates `<checkpoint_dir>/<runname>/` and writes `args.json`, the model
      weights (prefix `prior_model`), `record.json`, and, when `args.plot` is
      truthy, one PNG of the fitted per-channel densities per logged step.
    """
    import json
    import os

    tf.reset_default_graph()

    num_channels = args.num_channels
    model = BMSHJ2018Prior(num_channels,
                           dims=args.dims,
                           init_scale=args.init_scale)

    runname = get_runname(vars(args))
    data = np.load(args.data_path)

    lr = args.lr
    its = args.its
    tol = args.tol
    logging_freq = args.logging_freq
    plot = args.plot
    if plot:
        import matplotlib.pyplot as plt

    save_dir = os.path.join(args.checkpoint_dir, runname)
    os.makedirs(save_dir, exist_ok=True)
    model_name = os.path.join(save_dir, 'prior_model')

    with open(os.path.join(save_dir, 'args.json'), 'w') as f:  # will overwrite existing
        json.dump(vars(args), f, indent=4, sort_keys=True)

    assert not tf.executing_eagerly()
    # Rows are samples, columns are channels.
    X = tf.placeholder(data.dtype, [None, num_channels])
    pdf_lower_bound = 1e-10
    pdf = model.pdf(X, stop_gradient=False)

    # Bound the density away from zero so the log below stays finite.
    pdf = math_ops.lower_bound(pdf, pdf_lower_bound)
    log_likelihood = tf.math.log(pdf)
    print(log_likelihood)
    loss = - tf.reduce_mean(log_likelihood)

    optimizer = tf.train.AdamOptimizer(lr)
    train_step = optimizer.minimize(loss)

    record = []

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        prev_loss = float('inf')
        for it in range(its):
            sess.run(train_step, feed_dict={X: data})
            # Loss is re-evaluated after the parameter update.
            loss_ = float(sess.run(loss, feed_dict={X: data}))

            # Stop once the relative change in loss drops below `tol`. Written
            # as a product so that a loss of exactly 0 cannot divide by zero;
            # the first iteration never triggers because prev_loss is inf.
            if abs(prev_loss - loss_) < tol * abs(loss_):
                break
            # Remember this loss for the next comparison; without this update
            # the stopping test above could never fire.
            prev_loss = loss_

            if it % logging_freq == 0 or it + 1 == its:
                print('it=%d,\t\tloss=%g' % (it, loss_))
                record.append(dict(it=it, loss=loss_))

                if plot:
                    # Plot the model density q(x) per channel against a
                    # histogram of the data for that channel.
                    xlim = [-5, 5]
                    xs = np.linspace(*xlim)
                    fig = plt.figure(figsize=(12, 8))

                    xs_feed = np.tile(xs[..., None], num_channels)  # len(xs) by num_channels
                    q_xs = sess.run(pdf, feed_dict={X: xs_feed})

                    h, v = 2, 4
                    # Show at most h * v channels, and never more than exist.
                    for k in range(min(h * v, num_channels)):
                        plt.subplot(h, v, k + 1)
                        plt.plot(xs, q_xs[:, k], label='$q(x)$')

                        bins = 31
                        plt.hist(data[:, k].ravel(), bins=bins, density=True,
                                 alpha=0.4, label=r'$\hat q(z)$')

                        plt.xlim(xlim)
                        plt.title('channel %d, it %d' % (k, it))

                    plt.legend()
                    plt.tight_layout()
                    plt.savefig(os.path.join(save_dir, runname + '_it=%d.png' % it))
                    # Release the figure; otherwise one accumulates per logged step.
                    plt.close(fig)

        model.save_weights(model_name)

    with open(os.path.join(save_dir, 'record.json'), 'w') as f:  # will overwrite existing
        json.dump(record, f, indent=4, sort_keys=True)
예제 #15
0
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format.

    Builds a graph in which the latents `y` and hyper-latents `z` are
    placeholders, restores the latest checkpoint found under
    `<args.checkpoint_dir>/<args.runname>`, and then, for every batch of input
    images, runs `rd_opt_its` gradient steps on `y` and `z` (starting from the
    analysis / hyper-analysis transform outputs) under a temperature that is
    re-computed each step by `annealed_temperature`. The optimized latents are
    then rounded and the evaluation metrics are computed from them.

    Arguments:
      args: argparse-style namespace. Attributes read: `input_file`,
        `num_filters`, `lmbda`, `runname`, `checkpoint_dir`, `results_dir`,
        `annealing_rate`, `t0`, `verbose`. `args.lmbda` is overwritten with
        the value parsed out of `args.runname` when it is negative.

    Side effects:
      Prints optimization progress and average metrics, and writes an
      `rd-*.npz` file to `args.results_dir`; additionally writes `opt-*.npz`
      when `save_opt_record` is truthy and `recon-*.png` when
      `save_reconstruction` is truthy.

    NOTE(review): this function reads several names it does not define
    (`epsilon`, `save_opt_record`, `save_reconstruction`, `SCALES_MIN`,
    `SCALES_MAX`, `SCALES_LEVELS`, `write_png`, the transform classes);
    presumably these are module-level definitions -- confirm.
    """
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])
    # Product of all axes except the first (batch) and last (channel) one.
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    # To avoid that, each batch is pulled once via x_next and then fed back in through the placeholder below.
    x_next = dataset.make_one_shot_iterator().get_next()

    x_shape = (None, *X.shape[1:])
    x_ph = x = tf.placeholder('float32',
                              x_shape)  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    # Soft-to-hard rounding with Gumbel-softmax trick; for each element of z_tilde, let R be a 2D auxiliary one-hot
    # random vector, such that R=[1, 0] means rounding DOWN and [0, 1] means rounding UP.
    # The logits of the two outcomes are -atanh(z - z_floor) / T and -atanh(z_ceil - z) / T, i.e. the closer
    # integer gets the larger logit.
    # z_tilde is the inner product of a relaxed one-hot sample with [floor(z), ceil(z)].
    import tensorflow_probability as tfp
    T = tf.placeholder('float32', shape=[], name='temperature')

    z = tf.placeholder(
        'float32', z_init.shape
    )  # interface ("proxy") variable for SGA (to be annealed to int)
    z_floor = tf.floor(z)
    z_ceil = tf.ceil(z)
    z_bds = tf.stack([z_floor, z_ceil], axis=-1)
    rz_logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(z - z_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(z_ceil - z, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim are logits for DOWN or UP; clip to prevent NaN as temperature -> 0
    rz_dist = tfp.distributions.RelaxedOneHotCategorical(
        T, logits=rz_logits
    )  # technically we can use a different temperature here
    rz_sample = rz_dist.sample()
    z_tilde = tf.reduce_sum(z_bds * rz_sample,
                            axis=-1)  # inner product in last dim

    _ = entropy_bottleneck(
        z, training=False
    )  # dummy call to ensure entropy_bottleneck is properly built
    z_likelihoods = entropy_bottleneck._likelihood(z_tilde)  # p(\tilde z)
    if entropy_bottleneck.likelihood_bound > 0:
        likelihood_bound = entropy_bottleneck.likelihood_bound
        z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_bound)

    # compute parameters of conditional prior p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # set up SGA for low-level latents (same construction as for z above)
    y = tf.placeholder(
        'float32', y_init.shape
    )  # interface ("proxy") variable for SGA (to be annealed to int)
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    ry_logits = tf.stack([
        -tf.math.atanh(tf.clip_by_value(y - y_floor, -1 + epsilon,
                                        1 - epsilon)) / T,
        -tf.math.atanh(tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon))
        / T
    ],
                         axis=-1)  # last dim are logits for DOWN or UP
    ry_dist = tfp.distributions.RelaxedOneHotCategorical(
        T, logits=ry_logits
    )  # technically we can use a different temperature here
    ry_sample = ry_dist.sample()
    y_tilde = tf.reduce_sum(y_bds * ry_sample,
                            axis=-1)  # inner product in last dim
    x_tilde = synthesis_transform(y_tilde)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # graph = build_graph(args, x, training=False)

    # Total number of bits divided by number of pixels:
    # -log2 p(\tilde y | \tilde z) - log2 p(\tilde z), summed over the non-batch axes.
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    # Gradients with respect to the proxy placeholders; the updates themselves are applied outside the graph.
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    # From here on `x` names the rescaled tensor; `x_ph` still names the placeholder.
    x *= 255
    if save_reconstruction:
        # Keep the unclipped, unrounded reconstruction for writing the PNG later.
        x_tilde_float = x_tilde
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        # NOTE(review): latest_checkpoint presumably returns None when save_dir holds no
        # checkpoint, in which case restore() would fail -- confirm and consider an explicit check.
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # eval_fields and eval_tensors are parallel lists (zipped together below).
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        log_itv = 100
        if save_opt_record:
            log_itv = 10
        rd_lr = 0.005
        # rd_opt_its = args.sga_its
        rd_opt_its = 2000
        annealing_scheme = 'exp0'
        annealing_rate = args.annealing_rate  # default annealing_rate = 1e-3
        t0 = args.t0  # default t0 = 700
        T_ub = 0.5  # max/initial temperature
        from utils import annealed_temperature
        from adam import Adam

        batch_idx = 0
        # Loop until the one-shot iterator is exhausted (signalled by OutOfRangeError).
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur, z_cur = sess.run([y_init, z_init],
                                        feed_dict=x_feed_dict)  # np arrays
                adam_optimizer = Adam(lr=rd_lr)
                # NOTE(review): opt_record is recreated for every batch but only saved after the
                # loop, so the saved record covers the last batch only.
                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it,
                                                       r=annealing_rate,
                                                       ub=T_ub,
                                                       scheme=annealing_scheme,
                                                       t0=t0)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z: z_cur,
                            **x_feed_dict, T: temperature
                        })
                    y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            # Feed the rounded latents directly for y_tilde / z_tilde, overriding
                            # the sampled relaxation (so T is not needed in this feed).
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_tilde: np.round(z_cur),
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_, rd_loss_after_rounding,
                                   bpp_after_rounding, psnr_after_rounding))
                            opt_record['rd_loss_after_rounding'].append(
                                rd_loss_after_rounding)
                        else:
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_))
                        opt_record['its'].append(it)
                        opt_record['T'].append(temperature)
                        opt_record['rd_loss'].append(obj)

                print()

                y_tilde_cur = np.round(
                    y_cur)  # this is the latents we end up transmitting
                z_tilde_cur = np.round(z_cur)

                # Evaluate all metrics with the rounded latents fed in for y_tilde / z_tilde.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_tilde: z_tilde_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        # The part of runname before the first '-' is taken as the name of the training script.
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        if script_name != trained_script_name:
            # Evaluating with a different script than the one named in runname: tag the file with both.
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record:
            # save optimization record
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        if save_reconstruction:
            # Only a single input image is supported here: y_tilde_cur holds the last batch only.
            assert num_images == 1
            prefix = 'recon'
            save_file = '%s-%s-input=%s.png' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g-rd_opt_its=%d+%s-input=%s.png' % (
                    prefix, script_name, args.lmbda, rd_opt_its, args.runname,
                    input_file)
            # Write reconstructed image out as a PNG file.
            save_file = os.path.join(args.results_dir, save_file)
            print("Saving image reconstruction to ", save_file)
            save_png_op = write_png(save_file, x_tilde_float[0])
            sess.run(save_png_op, feed_dict={y_tilde: y_tilde_cur})

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
예제 #16
0
 def reparam(var):
   """Apply `math_ops.lower_bound` with `bound`, then square and subtract `pedestal`.

   `bound` and `pedestal` are taken from the enclosing scope.
   """
   bounded = math_ops.lower_bound(var, bound)
   return tf.math.square(bounded) - pedestal