Exemplo n.º 1
0
 def testImageLossfunPreservesDtype(self, float_dtype):
     """Checks that image_lossfun keeps the floating-point precision of its input.

     Args:
       float_dtype: A float type used both to cast the random input and as the
         expected dtype of every output (presumably a numpy type such as
         np.float32 / np.float64 — TODO confirm against the parameterization).
     """
     images = float_dtype(np.random.uniform(size=(10, 64, 64, 3)))
     loss_t, alpha_t, scale_t = adaptive.image_lossfun(images)
     with self.session() as sess:
         sess.run(tf.global_variables_initializer())
         fetched = sess.run([loss_t, alpha_t, scale_t])
     # Check loss, alpha and scale (in that order) against the input's dtype.
     for value in fetched:
         self.assertDTypeEqual(value, float_dtype)
Exemplo n.º 2
0
    def testFittingImageDataIsCorrect(self, image_data_callback):
        """Tests that minimizing the adaptive image loss recovers the true model.

        Here we generate a stack of color images drawn from a normal
        distribution, and then minimize image_lossfun() with respect to the
        mean and scale of each distribution, and check that after minimization
        the estimated means are close to the true means.

        Args:
          image_data_callback: The function used to generate the training data
            and parameters used during optimization.
        """
        # Generate toy data.
        image_width = 4
        num_samples = 10
        wavelet_num_levels = 2  # Ignored by _generate_pixel_toy_image_data().
        (samples, reference, color_space,
         representation) = image_data_callback(image_width, num_samples,
                                               wavelet_num_levels)

        # Construct the loss. alpha_lo == alpha_hi presumably pins alpha to 2,
        # so only the returned loss is needed; alpha and scale are not checked.
        prediction = tf.Variable(tf.zeros(tf.shape(reference), tf.float64))
        x = samples - prediction[tf.newaxis, :]
        loss, _, _ = adaptive.image_lossfun(
            x,
            color_space=color_space,
            representation=representation,
            wavelet_num_levels=wavelet_num_levels,
            alpha_lo=2,
            alpha_hi=2)
        loss = tf.reduce_mean(loss)

        # Minimize the loss with Adam. The learning rate is interpolated in log
        # space from init_rate (step 0) to final_rate (step num_iters - 1).
        with self.session() as sess:
            init_rate = 0.1
            final_rate = 0.01
            num_iters = 201
            global_step = tf.Variable(0, trainable=False)
            t = tf.cast(global_step, tf.float32) / (num_iters - 1)
            rate = tf.math.exp(
                tf.math.log(init_rate) * (1. - t) +
                tf.math.log(final_rate) * t)
            optimizer = tf.train.AdamOptimizer(learning_rate=rate,
                                               beta1=0.5,
                                               beta2=0.9,
                                               epsilon=1e-08)
            step = optimizer.minimize(loss, global_step=global_step)
            sess.run(tf.global_variables_initializer())
            for _ in range(num_iters):
                sess.run(step)
            # Only the fitted prediction is checked, so only it is fetched.
            prediction = sess.run(prediction)
        self.assertAllClose(prediction, reference, rtol=0.01, atol=0.01)
Exemplo n.º 3
0
def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an estimator.

  Encodes `features` into an approximate posterior, decodes posterior samples,
  and scores the reconstruction residual with `adaptive.image_lossfun` to get a
  per-pixel negative log-likelihood (summed into the "distortion"). Together
  with the posterior-vs-prior term (the "rate") this forms an ELBO. The
  training loss is the negative ELBO, or the negative "bilbo" objective when
  `FLAGS.bilbo` is set. Scalar and image summaries are emitted along the way.

  Arguments:
    features: The input features for the estimator. The residual against the
      decoder output is reshaped to [-1, IMAGE_SIZE, IMAGE_SIZE, 3], so these
      are presumably 3-channel images (TODO confirm against the input_fn).
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some parameters, unused here.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance. Its `train_op` is
    only built in TRAIN mode and is None otherwise.

  Raises:
    NotImplementedError: If FLAGS request an unsupported combination of
      `analytic_kl`, `floating_prior`, `fitted_samples`, `bilbo`,
      `unit_posterior` and `mixture_components`.
  """
  del labels, params, config

  if FLAGS.analytic_kl and FLAGS.mixture_components != 1:
    raise NotImplementedError(
        "Using `analytic_kl` is only supported when `mixture_components = 1` "
        "since there's no closed form otherwise.")
  if FLAGS.floating_prior and not (FLAGS.unit_posterior and
                                   FLAGS.mixture_components == 1):
    raise NotImplementedError(
        "Using `floating_prior` is only supported when `unit_posterior` = True "
        "since there's a scale ambiguity otherwise, and when "
        "`mixture_components = 1` since there's no closed form otherwise.")
  if FLAGS.fitted_samples and FLAGS.mixture_components != 1:
    raise NotImplementedError(
        "Using `fitted_samples` is only supported when "
        "`mixture_components = 1` since there's no closed form otherwise.")
  if FLAGS.bilbo and not FLAGS.floating_prior:
    # bilbo reads posterior_batch_mean/variance, which are only computed
    # below when floating_prior (or fitted_samples) is set.
    raise NotImplementedError(
        "Using `bilbo` is only supported when `floating_prior = True`.")

  activation = tf.nn.leaky_relu
  encoder = make_encoder(activation, FLAGS.latent_size, FLAGS.base_depth)
  decoder = make_decoder(activation, FLAGS.latent_size, [IMAGE_SIZE] * 2 + [3],
                         FLAGS.base_depth)

  approx_posterior = encoder(features)
  # sample(n) prepends a sample axis, so decoder_mu's leading axis indexes
  # the FLAGS.n_samples posterior samples.
  approx_posterior_sample = approx_posterior.sample(FLAGS.n_samples)
  decoder_mu = decoder(approx_posterior_sample)

  if FLAGS.floating_prior or FLAGS.fitted_samples:
    # Zero-mean diagonal Gaussian fitted to the batch of posteriors. Note that
    # despite its name, posterior_batch_mean is the batch average of the
    # *squared* posterior means; adding the average posterior variance gives
    # the per-dimension second moment used as the prior's variance.
    posterior_batch_mean = tf.reduce_mean(approx_posterior.mean()**2, [0])
    posterior_batch_variance = tf.reduce_mean(approx_posterior.stddev()**2, [0])
    posterior_scale = posterior_batch_mean + posterior_batch_variance
    floating_prior = tfd.MultivariateNormalDiag(
        tf.zeros(FLAGS.latent_size), tf.sqrt(posterior_scale))
    tf.summary.scalar("posterior_scale", tf.reduce_sum(posterior_scale))

  if FLAGS.floating_prior:
    latent_prior = floating_prior
  else:
    latent_prior = make_mixture_prior(FLAGS.latent_size,
                                      FLAGS.mixture_components)

  # Decode samples from the prior (or the fitted Gaussian) for visualization.
  sample_distribution = floating_prior if FLAGS.fitted_samples else latent_prior
  num_viz_samples = VIZ_GRID_SIZE**2
  random_mu = decoder(sample_distribution.sample(num_viz_samples))

  residual = tf.reshape(features - decoder_mu, [-1] + [IMAGE_SIZE] * 2 + [3])

  # Arguments shared by both loss variants; the alpha_* settings are only
  # forwarded when the Student's t variant is off.
  lossfun_kwargs = dict(
      color_space=FLAGS.color_space,
      representation=FLAGS.representation,
      wavelet_num_levels=FLAGS.wavelet_num_levels,
      wavelet_scale_base=FLAGS.wavelet_scale_base,
      use_students_t=FLAGS.use_students_t,
      scale_lo=FLAGS.scale_lo,
      scale_init=FLAGS.scale_init)
  if not FLAGS.use_students_t:
    lossfun_kwargs.update(
        alpha_lo=FLAGS.alpha_lo,
        alpha_hi=FLAGS.alpha_hi,
        alpha_init=FLAGS.alpha_init)
  # image_lossfun returns (loss, alpha, scale); only the loss is used here.
  nll = adaptive.image_lossfun(residual, **lossfun_kwargs)[0]

  # Restore decoder_mu's two leading axes (sample axis, then presumably the
  # batch axis) in front of the per-pixel image dimensions.
  nll = tf.reshape(nll, [tf.shape(decoder_mu)[0],
                         tf.shape(decoder_mu)[1]] + [IMAGE_SIZE] * 2 + [3])

  # Clipping to prevent the loss from nanning out.
  max_val = np.finfo(np.float32).max
  nll = tf.clip_by_value(nll, -max_val, max_val)

  viz_n_inputs = np.int32(np.minimum(VIZ_MAX_N_INPUTS, FLAGS.batch_size))
  viz_n_samples = np.int32(np.minimum(VIZ_MAX_N_SAMPLES, FLAGS.n_samples))

  image_tile_summary("input", tf.to_float(features), rows=1, cols=viz_n_inputs)

  image_tile_summary(
      "recon/mean",
      decoder_mu[:viz_n_samples, :viz_n_inputs],
      rows=viz_n_samples,
      cols=viz_n_inputs)

  img_summary_input = image_tile_summary(
      "input1", tf.to_float(features), rows=viz_n_inputs, cols=1)
  img_summary_recon = image_tile_summary(
      "recon1", decoder_mu[:1, :viz_n_inputs], rows=viz_n_inputs, cols=1)

  image_tile_summary(
      "random/mean", random_mu, rows=VIZ_GRID_SIZE, cols=VIZ_GRID_SIZE)

  # Sum the per-pixel NLL over each image's height, width and channel axes.
  distortion = tf.reduce_sum(nll, axis=[2, 3, 4])

  avg_distortion = tf.reduce_mean(distortion)
  tf.summary.scalar("distortion", avg_distortion)

  if FLAGS.analytic_kl:
    rate = tfd.kl_divergence(approx_posterior, latent_prior)
  else:
    # Single-sample Monte Carlo estimate of the KL term per posterior sample.
    rate = (
        approx_posterior.log_prob(approx_posterior_sample) -
        latent_prior.log_prob(approx_posterior_sample))
  avg_rate = tf.reduce_mean(rate)
  tf.summary.scalar("rate", avg_rate)

  elbo_local = -(rate + distortion)

  elbo = tf.reduce_mean(elbo_local)
  tf.summary.scalar("elbo", elbo)

  if FLAGS.bilbo:
    bilbo = -0.5 * tf.reduce_sum(
        tf.log1p(
            posterior_batch_mean / posterior_batch_variance)) - avg_distortion
    tf.summary.scalar("bilbo", bilbo)
    loss = -bilbo
  else:
    loss = -elbo

  # log-mean-exp of the per-sample ELBO over the sample axis (axis 0).
  importance_weighted_elbo = tf.reduce_mean(
      tf.reduce_logsumexp(elbo_local, axis=0) -
      tf.math.log(tf.to_float(FLAGS.n_samples)))
  tf.summary.scalar("elbo/importance_weighted", importance_weighted_elbo)

  # Perform variational inference by minimizing the loss. The learning rate is
  # constant until decay_start * max_steps, then cosine-decays over the rest.
  global_step = tf.train.get_or_create_global_step()
  decay_start_step = int(FLAGS.decay_start * FLAGS.max_steps)
  # cosine_decay divides by decay_steps; clamp to 1 so decay_start >= 1.0
  # (i.e. "never decay") does not produce a 0/0 = NaN learning rate.
  decay_steps = max(1, int((1. - FLAGS.decay_start) * FLAGS.max_steps))
  learning_rate = tf.train.cosine_decay(
      FLAGS.learning_rate,
      tf.maximum(tf.cast(0, tf.int64), global_step - decay_start_step),
      decay_steps)
  tf.summary.scalar("learning_rate", learning_rate)
  optimizer = tf.train.AdamOptimizer(learning_rate)

  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimizer.minimize(loss, global_step=global_step)
  else:
    train_op = None

  eval_metric_ops = {
      "elbo": tf.metrics.mean(elbo),
      "elbo/importance_weighted": tf.metrics.mean(importance_weighted_elbo),
      "rate": tf.metrics.mean(avg_rate),
      "distortion": tf.metrics.mean(avg_distortion),
      # This ugly hackery is necessary to get TF to visualize when running
      # the eval set, apparently: the image summaries are passed as
      # (value, update_op) pairs with a no-op update.
      "img_summary_input": (img_summary_input, tf.no_op()),
      "img_summary_recon": (img_summary_recon, tf.no_op()),
  }
  eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
  )