Example #1
 def testImageLossfunChecksShape(self):
     """Tests that the image lossfun's checks input shapes."""
     x1 = np.ones((10, 16, 24, 3), np.float32)
     x2 = np.ones((10, 16, 16, 3), np.float32)
     lossfun = adaptive.AdaptiveImageLossFunction(x1.shape[1:], np.float32)
     with self.assertRaises(tf.errors.InvalidArgumentError):
         lossfun(x2)
Example #2
 def testImageLossfunPreservesDtype(self, float_dtype):
     """Tests that the image lossfun's outputs precisions match its input."""
     x = float_dtype(np.random.uniform(size=(10, 64, 64, 3)))
     lossfun = adaptive.AdaptiveImageLossFunction(x.shape[1:], float_dtype)
     loss = lossfun(x).numpy()
     alpha = lossfun.alpha().numpy()
     scale = lossfun.scale().numpy()
     self.assertDTypeEqual(loss, float_dtype)
     self.assertDTypeEqual(alpha, float_dtype)
     self.assertDTypeEqual(scale, float_dtype)
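The two tests above exercise the basic contract of AdaptiveImageLossFunction: it is constructed with the per-image shape and a float dtype, it validates the shape of its input, and it returns per-pixel losses (along with alpha() and scale() images) in the same precision as its input. Below is a minimal standalone usage sketch, assuming the module is importable as robust_loss.adaptive; the tensor names are illustrative.

import numpy as np
import tensorflow as tf
from robust_loss import adaptive  # import path assumed

# Per-pixel robust losses for a batch of 64x64 RGB residuals. The loss's
# alpha and scale images are trainable variables owned by the loss object.
residuals = np.float32(np.random.normal(size=(8, 64, 64, 3)))
lossfun = adaptive.AdaptiveImageLossFunction((64, 64, 3), np.float32)
loss = tf.reduce_mean(lossfun(residuals))  # scalar training loss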
Example #3
    def testFittingImageDataIsCorrect(self, image_data_callback):
        """Tests that minimizing the adaptive image loss recovers the true model.

    Here we generate a stack of color images drawn from a normal distribution,
    and then minimize image_lossfun() with respect to the mean and scale of each
    distribution, and check that after minimization the estimated means are
    close to the true means.

    Args:
      image_data_callback: The function used to generate the training data and
        parameters used during optimization.
    """
        # Generate toy data.
        image_width = 4
        num_samples = 10
        wavelet_num_levels = 2  # Ignored by generate_pixel_toy_image_data().
        (samples, reference, color_space,
         representation) = image_data_callback(image_width, num_samples,
                                               wavelet_num_levels)

        # Construct the loss.
        mu = tf.Variable(tf.zeros(tf.shape(reference), samples.dtype))
        image_lossfun = adaptive.AdaptiveImageLossFunction(
            [image_width, image_width, 3],
            samples.dtype,
            color_space=color_space,
            representation=representation,
            wavelet_num_levels=wavelet_num_levels,
            alpha_lo=2,
            alpha_hi=2)
        trainable_variables = list(image_lossfun.trainable_variables) + [mu]

        init_rate = 1.
        final_rate = 0.01
        num_iters = 201
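        # Decay the learning rate geometrically at every step so that it falls
        # from init_rate to roughly final_rate over num_iters steps.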
        learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
            init_rate, 1, (final_rate / init_rate)**(1. / num_iters))
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                             beta_1=0.5,
                                             beta_2=0.9,
                                             epsilon=1e-08)
        for _ in range(num_iters):
            optimizer.minimize(
                lambda: tf.reduce_mean(
                    image_lossfun(samples - mu[tf.newaxis, :])),
                trainable_variables)
        mu = mu.numpy()
        self.assertAllClose(mu, reference, rtol=0.01, atol=0.01)
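For reference, here is a hedged sketch of what an image_data_callback for the pixel representation could look like. The real generate_pixel_toy_image_data() is not shown on this page, so the noise level and the returned color space and representation ('RGB', 'PIXEL') are assumptions that only mirror the call signature used above.

def toy_pixel_image_data(image_width, num_samples, wavelet_num_levels):
    del wavelet_num_levels  # unused for a pixel representation
    reference = np.float32(
        np.random.uniform(size=(image_width, image_width, 3)))
    samples = reference[np.newaxis] + np.float32(
        0.1 * np.random.normal(size=(num_samples,) + reference.shape))
    return samples, reference, 'RGB', 'PIXEL'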
Example #4
 def __init__(self, imw, imh):
     # Fix alpha to 1 (Charbonnier loss) for now, as that usually works well
     # and is similar to L1.
     alpha = 1
     scale = 0.01  # because pixels are in [0, 1]
     # This hyperparameter can have a huge effect on how low-frequency errors
     # are weighted against high-frequency errors. Try setting it to 0.5 and 2
     # as well, and see what works best.
     wavelet_scale_base = 1
     self.func = adaptive.AdaptiveImageLossFunction(
         (imh, imw, 3),
         tf.float32,
         color_space='YUV',
         representation='CDF9/7',
         summarize_loss=False,
         wavelet_num_levels=5,
         wavelet_scale_base=wavelet_scale_base,
         alpha_lo=alpha,
         alpha_hi=alpha,
         scale_lo=scale,
         scale_init=scale)
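A hedged usage sketch for this wrapper (the __call__ method and tensor names below are assumptions, not part of the original class): the per-pixel losses from self.func are averaged into a scalar training loss.

 def __call__(self, pred, target):
     # pred and target: [batch, imh, imw, 3] tensors with values in [0, 1].
     return tf.reduce_mean(self.func(pred - target))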
Example #5
def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an estimator.

  Arguments:
    features: The input features for the estimator.
    labels: The labels, unused here.
    mode: Signifies whether the mode is train, eval, or predict.
    params: Some parameters, unused here.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
  del labels, params, config

  if FLAGS.analytic_kl and FLAGS.mixture_components != 1:
    raise NotImplementedError(
        "Using `analytic_kl` is only supported when `mixture_components = 1` "
        "since there's no closed form otherwise.")
  if FLAGS.floating_prior and not (FLAGS.unit_posterior and
                                   FLAGS.mixture_components == 1):
    raise NotImplementedError(
        "Using `floating_prior` is only supported when `unit_posterior` = True "
        "since there's a scale ambiguity otherwise, and when "
        "`mixture_components = 1` since there's no closed form otherwise.")
  if FLAGS.fitted_samples and FLAGS.mixture_components != 1:
    raise NotImplementedError(
        "Using `fitted_samples` is only supported when "
        "`mixture_components = 1` since there's no closed form otherwise.")
  if FLAGS.bilbo and not FLAGS.floating_prior:
    raise NotImplementedError(
        "Using `bilbo` is only supported when `floating_prior = True`.")

  activation = tf.nn.leaky_relu
  encoder = make_encoder(activation, FLAGS.latent_size, FLAGS.base_depth)
  decoder = make_decoder(activation, FLAGS.latent_size, [IMAGE_SIZE] * 2 + [3],
                         FLAGS.base_depth)

  approx_posterior = encoder(features)
  approx_posterior_sample = approx_posterior.sample(FLAGS.n_samples)
  decoder_mu = decoder(approx_posterior_sample)

  if FLAGS.floating_prior or FLAGS.fitted_samples:
    posterior_batch_mean = tf.reduce_mean(approx_posterior.mean()**2, [0])
    posterior_batch_variance = tf.reduce_mean(approx_posterior.stddev()**2, [0])
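    # posterior_scale is the per-dimension second moment of the aggregate
    # posterior (E[mu^2] + E[sigma^2]); it is used below as the variance of a
    # zero-mean "floating" prior fit to the posterior.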
    posterior_scale = posterior_batch_mean + posterior_batch_variance
    floating_prior = tfd.MultivariateNormalDiag(
        tf.zeros(FLAGS.latent_size), tf.sqrt(posterior_scale))
    tf.summary.scalar("posterior_scale", tf.reduce_sum(posterior_scale))

  if FLAGS.floating_prior:
    latent_prior = floating_prior
  else:
    latent_prior = make_mixture_prior(FLAGS.latent_size,
                                      FLAGS.mixture_components)

  # Decode samples from the prior for visualization.
  if FLAGS.fitted_samples:
    sample_distribution = floating_prior
  else:
    sample_distribution = latent_prior

  n_samples = VIZ_GRID_SIZE**2
  random_mu = decoder(sample_distribution.sample(n_samples))

  residual = tf.reshape(features - decoder_mu, [-1] + [IMAGE_SIZE] * 2 + [3])

  if FLAGS.use_students_t:
    lossfun = adaptive.AdaptiveImageLossFunction(
        residual.shape[1:],
        residual.dtype,
        color_space=FLAGS.color_space,
        representation=FLAGS.representation,
        wavelet_num_levels=FLAGS.wavelet_num_levels,
        wavelet_scale_base=FLAGS.wavelet_scale_base,
        use_students_t=FLAGS.use_students_t,
        scale_lo=FLAGS.scale_lo,
        scale_init=FLAGS.scale_init)
  else:
    lossfun = adaptive.AdaptiveImageLossFunction(
        residual.shape[1:],
        residual.dtype,
        color_space=FLAGS.color_space,
        representation=FLAGS.representation,
        wavelet_num_levels=FLAGS.wavelet_num_levels,
        wavelet_scale_base=FLAGS.wavelet_scale_base,
        use_students_t=FLAGS.use_students_t,
        alpha_lo=FLAGS.alpha_lo,
        alpha_hi=FLAGS.alpha_hi,
        alpha_init=FLAGS.alpha_init,
        scale_lo=FLAGS.scale_lo,
        scale_init=FLAGS.scale_init)

  nll = lossfun(residual)

  nll = tf.reshape(nll, [tf.shape(decoder_mu)[0],
                         tf.shape(decoder_mu)[1]] + [IMAGE_SIZE] * 2 + [3])

  # Clip to prevent the loss from becoming NaN.
  max_val = np.finfo(np.float32).max
  nll = tf.clip_by_value(nll, -max_val, max_val)

  viz_n_inputs = np.int32(np.minimum(VIZ_MAX_N_INPUTS, FLAGS.batch_size))
  viz_n_samples = np.int32(np.minimum(VIZ_MAX_N_SAMPLES, FLAGS.n_samples))

  image_tile_summary("input", tf.to_float(features), rows=1, cols=viz_n_inputs)

  image_tile_summary(
      "recon/mean",
      decoder_mu[:viz_n_samples, :viz_n_inputs],
      rows=viz_n_samples,
      cols=viz_n_inputs)

  img_summary_input = image_tile_summary(
      "input1", tf.to_float(features), rows=viz_n_inputs, cols=1)
  img_summary_recon = image_tile_summary(
      "recon1", decoder_mu[:1, :viz_n_inputs], rows=viz_n_inputs, cols=1)

  image_tile_summary(
      "random/mean", random_mu, rows=VIZ_GRID_SIZE, cols=VIZ_GRID_SIZE)

  distortion = tf.reduce_sum(nll, axis=[2, 3, 4])

  avg_distortion = tf.reduce_mean(distortion)
  tf.summary.scalar("distortion", avg_distortion)

  if FLAGS.analytic_kl:
    rate = tfd.kl_divergence(approx_posterior, latent_prior)
  else:
    rate = (
        approx_posterior.log_prob(approx_posterior_sample) -
        latent_prior.log_prob(approx_posterior_sample))
  avg_rate = tf.reduce_mean(rate)
  tf.summary.scalar("rate", avg_rate)

  elbo_local = -(rate + distortion)

  elbo = tf.reduce_mean(elbo_local)
  tf.summary.scalar("elbo", elbo)

  if FLAGS.bilbo:
    bilbo = -0.5 * tf.reduce_sum(
        tf.log1p(
            posterior_batch_mean / posterior_batch_variance)) - avg_distortion
    tf.summary.scalar("bilbo", bilbo)
    loss = -bilbo
  else:
    loss = -elbo

  importance_weighted_elbo = tf.reduce_mean(
      tf.reduce_logsumexp(elbo_local, axis=0) -
      tf.math.log(tf.to_float(FLAGS.n_samples)))
  tf.summary.scalar("elbo/importance_weighted", importance_weighted_elbo)

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.train.get_or_create_global_step()
  learning_rate = tf.train.cosine_decay(
      FLAGS.learning_rate,
      tf.maximum(
          tf.cast(0, tf.int64),
          global_step - int(FLAGS.decay_start * FLAGS.max_steps)),
      int((1. - FLAGS.decay_start) * FLAGS.max_steps))
  tf.summary.scalar("learning_rate", learning_rate)
  optimizer = tf.train.AdamOptimizer(learning_rate)

  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimizer.minimize(loss, global_step=global_step)
  else:
    train_op = None

  eval_metric_ops = {}
  eval_metric_ops["elbo"] = tf.metrics.mean(elbo)
  eval_metric_ops["elbo/importance_weighted"] = tf.metrics.mean(
      importance_weighted_elbo)
  eval_metric_ops["rate"] = tf.metrics.mean(avg_rate)
  eval_metric_ops["distortion"] = tf.metrics.mean(avg_distortion)
  # This workaround appears to be necessary to get TF to emit the image
  # summaries when running the eval set.
  eval_metric_ops["img_summary_input"] = (img_summary_input, tf.no_op())
  eval_metric_ops["img_summary_recon"] = (img_summary_recon, tf.no_op())
  eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
  )
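Finally, a hedged sketch of how a model_fn like this is typically handed to an Estimator. The input functions, model directory, and summary cadence below are illustrative, not taken from the original script.

# train_input_fn and eval_input_fn are assumed to yield batches of images.
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='/tmp/vae_model',
    config=tf.estimator.RunConfig(save_summary_steps=100))
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.max_steps)
metrics = estimator.evaluate(input_fn=eval_input_fn)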