def testImageLossfunPreservesDtype(self, float_dtype):
  """Checks that image_lossfun() emits loss, alpha and scale in the input dtype.

  Args:
    float_dtype: The floating point type used to build the input images; every
      output of image_lossfun() must come back with this same dtype.
  """
  images = float_dtype(np.random.uniform(size=(10, 64, 64, 3)))
  loss_t, alpha_t, scale_t = adaptive.image_lossfun(images)
  with self.session() as sess:
    sess.run(tf.global_variables_initializer())
    fetched = sess.run([loss_t, alpha_t, scale_t])
  # Check loss, alpha and scale (in that order) against the input dtype.
  for value in fetched:
    self.assertDTypeEqual(value, float_dtype)
def testFittingImageDataIsCorrect(self, image_data_callback):
  """Checks that minimizing image_lossfun() recovers the generating model.

  `image_data_callback` produces a stack of color images drawn from a normal
  distribution. image_lossfun() is then minimized with respect to the mean (a
  `prediction` variable shaped like `reference`) and the scale of each
  distribution. After optimization the estimated means must lie close to the
  true means in `reference`.

  Args:
    image_data_callback: The function used to generate the training data and
      the parameters (color space, representation) used during optimization.
  """
  # Toy problem dimensions.
  image_width = 4
  num_samples = 10
  wavelet_num_levels = 2  # Ignored by _generate_pixel_toy_image_data().
  samples, reference, color_space, representation = image_data_callback(
      image_width, num_samples, wavelet_num_levels)

  # Build the loss; alpha is pinned to 2 via identical lower/upper bounds.
  prediction = tf.Variable(tf.zeros(tf.shape(reference), tf.float64))
  residual = samples - prediction[tf.newaxis, :]
  loss, alpha, scale = adaptive.image_lossfun(
      residual,
      color_space=color_space,
      representation=representation,
      wavelet_num_levels=wavelet_num_levels,
      alpha_lo=2,
      alpha_hi=2)
  loss = tf.reduce_mean(loss)

  # Run Adam with a learning rate interpolated geometrically (linearly in
  # log-space) from init_rate to final_rate over the course of training.
  with self.session() as sess:
    init_rate = 0.1
    final_rate = 0.01
    num_iters = 201
    global_step = tf.Variable(0, trainable=False)
    progress = tf.cast(global_step, tf.float32) / (num_iters - 1)
    log_rate = (tf.math.log(init_rate) * (1. - progress) +
                tf.math.log(final_rate) * progress)
    rate = tf.math.exp(log_rate)
    optimizer = tf.train.AdamOptimizer(
        learning_rate=rate, beta1=0.5, beta2=0.9, epsilon=1e-08)
    train_step = optimizer.minimize(loss, global_step=global_step)
    sess.run(tf.global_variables_initializer())
    for _ in range(num_iters):
      sess.run(train_step)
    scale, alpha, prediction = sess.run([scale, alpha, prediction])
  self.assertAllClose(prediction, reference, rtol=0.01, atol=0.01)
def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an estimator.

  Args:
    features: The input features for the estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some parameters, unused here.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.

  Raises:
    NotImplementedError: If an unsupported combination of `analytic_kl`,
      `floating_prior`, `unit_posterior`, `fitted_samples`, `bilbo` and
      `mixture_components` flags is set.
  """
  del labels, params, config

  # Reject flag combinations that this model does not support.
  if FLAGS.analytic_kl and FLAGS.mixture_components != 1:
    raise NotImplementedError(
        "Using `analytic_kl` is only supported when `mixture_components = 1` "
        "since there's no closed form otherwise.")
  if FLAGS.floating_prior and not (FLAGS.unit_posterior and
                                   FLAGS.mixture_components == 1):
    raise NotImplementedError(
        "Using `floating_prior` is only supported when `unit_posterior` = True "
        "since there's a scale ambiguity otherwise, and when "
        "`mixture_components = 1` since there's no closed form otherwise.")
  if FLAGS.fitted_samples and FLAGS.mixture_components != 1:
    raise NotImplementedError(
        "Using `fitted_samples` is only supported when "
        "`mixture_components = 1` since there's no closed form otherwise.")
  if FLAGS.bilbo and not FLAGS.floating_prior:
    raise NotImplementedError(
        "Using `bilbo` is only supported when `floating_prior = True`.")

  activation = tf.nn.leaky_relu
  encoder = make_encoder(activation, FLAGS.latent_size, FLAGS.base_depth)
  decoder = make_decoder(activation, FLAGS.latent_size, [IMAGE_SIZE] * 2 + [3],
                         FLAGS.base_depth)

  approx_posterior = encoder(features)
  approx_posterior_sample = approx_posterior.sample(FLAGS.n_samples)
  decoder_mu = decoder(approx_posterior_sample)

  if FLAGS.floating_prior or FLAGS.fitted_samples:
    # Averages over axis 0 (presumably the batch axis -- TODO confirm) of the
    # squared posterior means (despite its name, `posterior_batch_mean` holds
    # a second moment) and of the posterior variances.
    posterior_batch_mean = tf.reduce_mean(approx_posterior.mean()**2, [0])
    posterior_batch_variance = tf.reduce_mean(approx_posterior.stddev()**2, [0])
    # Their sum is used as the per-dimension variance of a zero-mean prior,
    # hence the sqrt below to obtain its scale.
    posterior_scale = posterior_batch_mean + posterior_batch_variance
    floating_prior = tfd.MultivariateNormalDiag(
        tf.zeros(FLAGS.latent_size), tf.sqrt(posterior_scale))
    tf.summary.scalar("posterior_scale", tf.reduce_sum(posterior_scale))

  if FLAGS.floating_prior:
    latent_prior = floating_prior
  else:
    latent_prior = make_mixture_prior(FLAGS.latent_size,
                                      FLAGS.mixture_components)

  # Decode samples from the prior for visualization.
  if FLAGS.fitted_samples:
    sample_distribution = floating_prior
  else:
    sample_distribution = latent_prior

  n_samples = VIZ_GRID_SIZE**2
  random_mu = decoder(sample_distribution.sample(n_samples))

  residual = tf.reshape(features - decoder_mu, [-1] + [IMAGE_SIZE] * 2 + [3])

  # Arguments shared by both likelihood variants. The alpha settings are only
  # passed when the Student's t variant is not in use.
  lossfun_kwargs = dict(
      color_space=FLAGS.color_space,
      representation=FLAGS.representation,
      wavelet_num_levels=FLAGS.wavelet_num_levels,
      wavelet_scale_base=FLAGS.wavelet_scale_base,
      use_students_t=FLAGS.use_students_t,
      scale_lo=FLAGS.scale_lo,
      scale_init=FLAGS.scale_init)
  if not FLAGS.use_students_t:
    lossfun_kwargs.update(
        alpha_lo=FLAGS.alpha_lo,
        alpha_hi=FLAGS.alpha_hi,
        alpha_init=FLAGS.alpha_init)
  # image_lossfun() returns (loss, alpha, scale); only the loss is used here.
  nll = adaptive.image_lossfun(residual, **lossfun_kwargs)[0]
  nll = tf.reshape(nll, [tf.shape(decoder_mu)[0],
                         tf.shape(decoder_mu)[1]] + [IMAGE_SIZE] * 2 + [3])

  # Clip to the finite float32 range so that infinite per-pixel losses cannot
  # make the summed loss blow up into NaNs downstream.
  max_val = np.finfo(np.float32).max
  nll = tf.clip_by_value(nll, -max_val, max_val)

  viz_n_inputs = np.int32(np.minimum(VIZ_MAX_N_INPUTS, FLAGS.batch_size))
  viz_n_samples = np.int32(np.minimum(VIZ_MAX_N_SAMPLES, FLAGS.n_samples))

  image_tile_summary(
      "input", tf.cast(features, tf.float32), rows=1, cols=viz_n_inputs)

  image_tile_summary(
      "recon/mean",
      decoder_mu[:viz_n_samples, :viz_n_inputs],
      rows=viz_n_samples,
      cols=viz_n_inputs)

  img_summary_input = image_tile_summary(
      "input1", tf.cast(features, tf.float32), rows=viz_n_inputs, cols=1)
  img_summary_recon = image_tile_summary(
      "recon1", decoder_mu[:1, :viz_n_inputs], rows=viz_n_inputs, cols=1)

  image_tile_summary(
      "random/mean", random_mu, rows=VIZ_GRID_SIZE, cols=VIZ_GRID_SIZE)

  # Sum the NLL over the IMAGE_SIZE x IMAGE_SIZE x 3 axes introduced by the
  # reshape above, leaving one value per entry of decoder_mu's first two axes.
  distortion = tf.reduce_sum(nll, axis=[2, 3, 4])

  avg_distortion = tf.reduce_mean(distortion)
  tf.summary.scalar("distortion", avg_distortion)

  if FLAGS.analytic_kl:
    rate = tfd.kl_divergence(approx_posterior, latent_prior)
  else:
    # Monte Carlo estimate of the KL using the posterior samples drawn above.
    rate = (
        approx_posterior.log_prob(approx_posterior_sample) -
        latent_prior.log_prob(approx_posterior_sample))
  avg_rate = tf.reduce_mean(rate)
  tf.summary.scalar("rate", avg_rate)

  elbo_local = -(rate + distortion)

  elbo = tf.reduce_mean(elbo_local)
  tf.summary.scalar("elbo", elbo)

  if FLAGS.bilbo:
    # Only reachable with floating_prior (validated above), so the posterior
    # batch statistics are guaranteed to have been computed.
    bilbo = -0.5 * tf.reduce_sum(
        tf.math.log1p(
            posterior_batch_mean / posterior_batch_variance)) - avg_distortion
    tf.summary.scalar("bilbo", bilbo)
    loss = -bilbo
  else:
    loss = -elbo

  importance_weighted_elbo = tf.reduce_mean(
      tf.reduce_logsumexp(elbo_local, axis=0) -
      tf.math.log(tf.cast(FLAGS.n_samples, tf.float32)))
  tf.summary.scalar("elbo/importance_weighted", importance_weighted_elbo)

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.train.get_or_create_global_step()
  # Hold the learning rate constant for the first `decay_start` fraction of
  # training, then cosine-decay it over the remaining steps. The decay length
  # is clamped to at least one step: if it rounded down to 0 (e.g.
  # decay_start = 1), cosine_decay would compute 0 / 0 and the learning rate
  # would become NaN.
  decay_begin_step = int(FLAGS.decay_start * FLAGS.max_steps)
  decay_steps = max(1, int((1. - FLAGS.decay_start) * FLAGS.max_steps))
  learning_rate = tf.train.cosine_decay(
      FLAGS.learning_rate,
      tf.maximum(tf.cast(0, tf.int64), global_step - decay_begin_step),
      decay_steps)
  tf.summary.scalar("learning_rate", learning_rate)
  optimizer = tf.train.AdamOptimizer(learning_rate)

  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimizer.minimize(loss, global_step=global_step)
  else:
    train_op = None

  eval_metric_ops = {}
  eval_metric_ops["elbo"] = tf.metrics.mean(elbo)
  eval_metric_ops["elbo/importance_weighted"] = tf.metrics.mean(
      importance_weighted_elbo)
  eval_metric_ops["rate"] = tf.metrics.mean(avg_rate)
  eval_metric_ops["distortion"] = tf.metrics.mean(avg_distortion)
  # This ugly hackery is necessary to get TF to visualize when running the
  # eval set, apparently. Each entry is a (value, update_op) pair whose update
  # op does nothing.
  eval_metric_ops["img_summary_input"] = (img_summary_input, tf.no_op())
  eval_metric_ops["img_summary_recon"] = (img_summary_recon, tf.no_op())
  eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
  )