Code example #1
 def _normalize_indexes(self, indexes):
     indexes = math_ops.lower_bound(indexes, 0)
     if isinstance(self.index_ranges, int):
         indexes = math_ops.upper_bound(indexes, self.index_ranges - 1)
     else:
         axes = [1] * indexes.shape.rank
         axes[self.channel_axis] = len(self.index_ranges)
         bounds = tf.reshape([s - 1 for s in self.index_ranges], axes)
         indexes = math_ops.upper_bound(indexes, bounds)
     return indexes
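
The pattern above is just elementwise clipping with gradient-aware bounds. A minimal usage sketch, assuming `math_ops` refers to `tensorflow_compression.python.ops.math_ops` (the import path is an assumption):

import tensorflow as tf
from tensorflow_compression.python.ops import math_ops  # assumed import path

indexes = tf.constant([-2.0, 0.5, 7.0])
# Forward pass behaves like tf.clip_by_value into [0, 4]; the bound ops differ
# only in how they treat gradients of clipped elements.
clipped = math_ops.upper_bound(math_ops.lower_bound(indexes, 0.0), 4.0)
# clipped == [0.0, 0.5, 4.0]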
Code example #2
 def _normalize_indexes(self, indexes):
     indexes = math_ops.lower_bound(indexes, 0)
     if self.channel_axis is None:
         index_range, = self.index_ranges
         bounds = index_range - 1
     else:
         axes = [1] * indexes.shape.rank
         axes[self.channel_axis] = len(self.index_ranges)
         bounds = tf.reshape([s - 1 for s in self.index_ranges], axes)
     return math_ops.upper_bound(indexes, tf.cast(bounds, indexes.dtype))
Code example #3
 def _normalize_indexes(self, indexes):
   """See base class."""
   num_indexes = indexes.shape[-1]  # Last dim of `indexes` should be static.
   if num_indexes == len(self.index_ranges):
     # Indexes have offsets.
     index_ranges = self.index_ranges
   else:
     # Indexes do not have offsets.
     index_ranges = self.index_ranges_without_offsets
     assert num_indexes == len(index_ranges)
   indexes = math_ops.lower_bound(indexes, 0)
   axes = [1] * indexes.shape.rank
   axes[self.channel_axis] = len(index_ranges)
   bounds = tf.reshape([s - 1 for s in index_ranges], axes)
   return math_ops.upper_bound(indexes, tf.cast(bounds, indexes.dtype))
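
In the multi-channel case above, the per-channel bounds are reshaped so that they broadcast against `indexes` along `channel_axis` only. A small self-contained sketch with made-up values (the `index_ranges`, `channel_axis`, and import path are assumptions, not taken from the class):

import tensorflow as tf
from tensorflow_compression.python.ops import math_ops  # assumed import path

index_ranges = (4, 8, 16)   # hypothetical per-channel ranges
channel_axis = -1
indexes = tf.constant([[[3., 9., 20.],
                        [-1., 2., 5.]]])           # shape [1, 2, 3]

axes = [1] * indexes.shape.rank                    # [1, 1, 1]
axes[channel_axis] = len(index_ranges)             # [1, 1, 3]
bounds = tf.reshape([s - 1 for s in index_ranges], axes)
clipped = math_ops.upper_bound(math_ops.lower_bound(indexes, 0),
                               tf.cast(bounds, indexes.dtype))
# clipped == [[[3., 7., 15.],
#              [0., 2., 5.]]]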
Code example #4
 def test_upper_bound_has_correct_outputs_and_gradients(self, gradient):
     inputs = tf.constant([-1, 1], dtype=tf.float32)
     with tf.GradientTape(persistent=True) as tape:
         tape.watch(inputs)
         outputs = math_ops.upper_bound(inputs, 0, gradient=gradient)
     pgrads = tape.gradient(outputs, inputs, tf.ones_like(inputs))
     ngrads = tape.gradient(outputs, inputs, -tf.ones_like(inputs))
     self.assertAllEqual(outputs, [-1, 0])
     if gradient == "disconnected":
         self.assertAllEqual(pgrads, [1, 0])
         self.assertAllEqual(ngrads, [-1, 0])
     elif gradient == "identity":
         self.assertAllEqual(pgrads, [1, 1])
         self.assertAllEqual(ngrads, [-1, -1])
     else:
         self.assertAllEqual(pgrads, [1, 1])
         self.assertAllEqual(ngrads, [-1, 0])
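
The test exercises three gradient modes for the bound; the unnamed `else` branch presumably covers the remaining mode ("identity_if_towards" in tensorflow_compression), where the gradient only passes through when it points back toward the feasible region. A minimal sketch of that behaviour using `tf.custom_gradient`; this is an illustration, not the library's actual implementation:

import tensorflow as tf

def upper_bound_identity_if_towards(inputs, bound):
    """Clips `inputs` to at most `bound`; gradients pass through wherever the
    input is within the bound, or where they would push a clipped value back
    toward it."""
    bound = tf.convert_to_tensor(bound, dtype=inputs.dtype)

    @tf.custom_gradient
    def _bounded(x):
        def grad(dy):
            # Pass the gradient where x is within bounds, or where a positive
            # upstream gradient would move a clipped value back down
            # (gradient descent subtracts the gradient).
            pass_through = tf.logical_or(x <= bound, dy > 0)
            return tf.where(pass_through, dy, tf.zeros_like(dy))
        return tf.minimum(x, bound), grad

    return _bounded(inputs)

# With inputs [-1, 1] and bound 0 this reproduces the expectations in the
# `else` branch above: outputs [-1, 0], pgrads [1, 1], ngrads [-1, 0].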
Code example #5
File: math_ops_test.py   Project: phymucs/compression
    def _test_upper_bound(self, gradient):
        inputs = tf.placeholder(dtype=tf.float32)
        outputs = math_ops.upper_bound(inputs, 0, gradient=gradient)
        pgrads, = tf.gradients([outputs], [inputs], [tf.ones_like(inputs)])
        ngrads, = tf.gradients([outputs], [inputs], [-tf.ones_like(inputs)])

        inputs_feed = [-1, 1]
        outputs_expected = [-1, 0]
        if gradient == "disconnected":
            pgrads_expected = [1, 0]
            ngrads_expected = [-1, 0]
        elif gradient == "identity":
            pgrads_expected = [1, 1]
            ngrads_expected = [-1, -1]
        else:
            pgrads_expected = [1, 1]
            ngrads_expected = [-1, 0]

        with self.test_session() as sess:
            outputs, pgrads, ngrads = sess.run([outputs, pgrads, ngrads],
                                               {inputs: inputs_feed})
            self.assertAllEqual(outputs, outputs_expected)
            self.assertAllEqual(pgrads, pgrads_expected)
            self.assertAllEqual(ngrads, ngrads_expected)
Code example #6
  def _test_upper_bound(self, gradient):
    inputs = tf.placeholder(dtype=tf.float32)
    outputs = math_ops.upper_bound(inputs, 0, gradient=gradient)
    pgrads, = tf.gradients([outputs], [inputs], [tf.ones_like(inputs)])
    ngrads, = tf.gradients([outputs], [inputs], [-tf.ones_like(inputs)])

    inputs_feed = [-1, 1]
    outputs_expected = [-1, 0]
    if gradient == "disconnected":
      pgrads_expected = [1, 0]
      ngrads_expected = [-1, 0]
    elif gradient == "identity":
      pgrads_expected = [1, 1]
      ngrads_expected = [-1, -1]
    else:
      pgrads_expected = [1, 1]
      ngrads_expected = [-1, 0]

    with self.test_session() as sess:
      outputs, pgrads, ngrads = sess.run(
          [outputs, pgrads, ngrads], {inputs: inputs_feed})
      self.assertAllEqual(outputs, outputs_expected)
      self.assertAllEqual(pgrads, pgrads_expected)
      self.assertAllEqual(ngrads, ngrads_expected)
Code example #7
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest.
    During training we sample from box-shaped posteriors; during compression this is approximated by rounding.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(
        args.num_filters, num_output_filters=2 * args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(
        args.num_filters, num_output_filters=2 * args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)

    # y_tilde ~ q(y_tilde | y = g_a(x))
    half = tf.constant(.5, dtype=y.dtype)
    if training:
        noise = tf.random.uniform(tf.shape(y), -half, half)
        y_tilde = y + noise
    else:
        # Approximately sample from q(y_tilde|x) by rounding. We can't simply do
        # y_hat = floor(y + 0.5 - prior_mean) as in Balle's model (ultimately
        # implemented by conditional_bottleneck._quantize), because we don't have
        # the prior p(y_tilde | z_tilde) yet; with bits-back (bb) we have to sample
        # z_tilde given y_tilde, whereas in BMSHJ2018 z_tilde is obtained
        # conditioned on x.
        y_tilde = tf.round(y)

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean, z_logvar = tf.split(hyper_analysis_transform(y_tilde),
                                num_or_size_splits=2,
                                axis=-1)
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    from utils import log_normal_pdf
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3))
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if training:
        sigma = math_ops.upper_bound(sigma, variance_upperbound**0.5)
    if not training:  # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)

    x_tilde = synthesis_transform(y_tilde)
    if not training:
        x_shape = tf.shape(x)
        # Crop the reconstruction to the same spatial shape as the input.
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]

    return locals()
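
Since `build_graph` returns `locals()`, the caller gets the likelihood tensors and the reconstruction back in a dict. Below is a sketch of how these quantities are typically combined into a rate-distortion objective with the bits-back correction from `log_q_z_tilde`; the function name and the `lmbda` weight are hypothetical, and the exact normalization may differ from the original training script:

import numpy as np
import tensorflow as tf

def rate_distortion_loss(graph, x, lmbda=0.01):
    # Rate in bits per pixel: code length of y_tilde and z_tilde under their
    # priors, minus the bits recovered via bits-back from q(z_tilde | ...).
    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), tf.float32)
    y_bits = tf.reduce_sum(-tf.math.log(graph["y_likelihoods"])) / np.log(2.)
    z_bits = tf.reduce_sum(-tf.math.log(graph["z_likelihoods"])) / np.log(2.)
    bits_back = tf.reduce_sum(-graph["log_q_z_tilde"]) / np.log(2.)
    bpp = (y_bits + z_bits - bits_back) / num_pixels
    # Distortion: mean squared error between input and reconstruction.
    mse = tf.reduce_mean(tf.math.squared_difference(x, graph["x_tilde"]))
    return bpp + lmbda * mse, bpp, mse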
Code example #8
 def test_upper_bound_invalid(self):
     with self.assertRaises(ValueError):
         math_ops.upper_bound(tf.zeros((1, 2)), 0, gradient="invalid")
Code example #9
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest.
    During training we sample from box-shaped posteriors; during compression this is approximated by rounding.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(
        args.num_filters, num_output_filters=2 * args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(
        args.num_filters, num_output_filters=2 * args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)

    # z_tilde ~ q(z_tilde | x) = q(z_tilde | h_a(y))
    z_mean, z_logvar = tf.split(hyper_analysis_transform(y),
                                num_or_size_splits=2,
                                axis=-1)
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean
    from utils import log_normal_pdf
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3))
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if training:
        sigma = math_ops.upper_bound(sigma, variance_upperbound**0.5)
    if not training:  # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # sample y_tilde from q(y_tilde|x) = U(y-0.5, y+0.5) = U(g_a(x)-0.5, g_a(x)+0.5), and then compute the pdf of
    # y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde) = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    # Note that at test/compression time, the resulting y_tilde doesn't simply
    # equal round(y); instead, the conditional_bottleneck does something
    # smarter and slightly more optimal: y_hat=floor(y + 0.5 - prior_mean), so
    # that the mean (mu) of the prior coincides with one of the quantization bins.
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=training)

    x_tilde = synthesis_transform(y_tilde)
    if not training:
        x_shape = tf.shape(x)
        # Crop the reconstruction to the same spatial shape as the input.
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]

    return locals()
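
The key difference from the earlier `build_graph` variant is the order of sampling: here `z_tilde` is inferred from `y` directly, and `conditional_bottleneck(y, training=training)` handles quantization of `y`, adding uniform noise during training and rounding relative to the prior mean at test time, as the comment above describes. A minimal sketch of that quantization rule, written out for illustration rather than taken from `tfc.GaussianConditional`:

import tensorflow as tf

def quantize_latents(y, prior_mean=None, training=True):
    if training:
        # Box-shaped posterior: additive uniform noise in [-0.5, 0.5).
        return y + tf.random.uniform(tf.shape(y), -0.5, 0.5, dtype=y.dtype)
    if prior_mean is not None:
        # Round relative to the prior mean so that mu coincides with a
        # quantization bin (cf. the y_hat = floor(y + 0.5 - prior_mean)
        # comment above), then shift back into the latent domain.
        return tf.round(y - prior_mean) + prior_mean
    return tf.round(y)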