def testImageLossfunChecksShape(self):
  """Tests that the image lossfun rejects inputs of the wrong shape."""
  # The loss is built for 16x24 RGB images ...
  image_shape = (16, 24, 3)
  lossfun = adaptive.AdaptiveImageLossFunction(image_shape, np.float32)
  # ... so a batch of 16x16 RGB images must be refused.
  wrong_shape_batch = np.ones((10, 16, 16, 3), np.float32)
  with self.assertRaises(tf.errors.InvalidArgumentError):
    lossfun(wrong_shape_batch)
def testImageLossfunPreservesDtype(self, float_dtype):
  """Tests that the loss, alpha and scale all come back in `float_dtype`."""
  images = float_dtype(np.random.uniform(size=(10, 64, 64, 3)))
  lossfun = adaptive.AdaptiveImageLossFunction(images.shape[1:], float_dtype)
  # Materialize all three outputs (loss, alpha, scale) before checking any.
  results = [
      lossfun(images).numpy(),
      lossfun.alpha().numpy(),
      lossfun.scale().numpy(),
  ]
  for value in results:
    self.assertDTypeEqual(value, float_dtype)
def testFittingImageDataIsCorrect(self, image_data_callback):
  """Tests that minimizing the adaptive image loss recovers the true means.

  `image_data_callback` produces a stack of color images drawn from a normal
  distribution. The adaptive image loss of the centered samples is then
  minimized jointly over a mean estimate and the loss's own trainable
  variables. Afterwards, the estimated mean must be close to the reference
  mean.

  Args:
    image_data_callback: The function used to generate the training data and
      the parameters (color space, representation) used during optimization.
  """
  # Toy problem size.
  width = 4
  n_samples = 10
  n_levels = 2  # Ignored by generate_pixel_toy_image_data().
  samples, reference, color_space, representation = image_data_callback(
      width, n_samples, n_levels)

  # The quantity being fit: a mean with the same shape as `reference`,
  # initialized to zero.
  mean_var = tf.Variable(tf.zeros(tf.shape(reference), samples.dtype))
  # alpha_lo == alpha_hi == 2 is intended to pin alpha, leaving the mean and
  # scale as the fitted quantities.
  lossfun = adaptive.AdaptiveImageLossFunction(
      [width, width, 3],
      samples.dtype,
      color_space=color_space,
      representation=representation,
      wavelet_num_levels=n_levels,
      alpha_lo=2,
      alpha_hi=2)
  variables = list(lossfun.trainable_variables) + [mean_var]

  # Learning rate decays geometrically (once per step) from start_lr towards
  # end_lr over `steps` updates.
  start_lr = 1.
  end_lr = 0.01
  steps = 201
  schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      start_lr, 1, (end_lr / start_lr)**(1. / steps))
  optimizer = tf.keras.optimizers.Adam(
      learning_rate=schedule, beta_1=0.5, beta_2=0.9, epsilon=1e-08)

  def objective():
    return tf.reduce_mean(lossfun(samples - mean_var[tf.newaxis, :]))

  for _ in range(steps):
    optimizer.minimize(objective, variables)

  self.assertAllClose(mean_var.numpy(), reference, rtol=0.01, atol=0.01)
def __init__(self,
             imw,
             imh,
             alpha=1,
             scale=0.01,
             wavelet_scale_base=1,
             wavelet_num_levels=5):
  """Builds an adaptive image loss for float32 images of size imh x imw x 3.

  The loss is stored on `self.func`. It uses the YUV color space and the
  CDF9/7 wavelet representation, and loss summaries are disabled.

  Args:
    imw: Image width in pixels.
    imh: Image height in pixels.
    alpha: Value passed as both `alpha_lo` and `alpha_hi`, which fixes the
      shape parameter. The default of 1 gives the Charbonnier loss. It usually
      works well and is similar to L1.
    scale: Value passed as both `scale_lo` and `scale_init`. The default of
      0.01 assumes pixel values in [0, 1].
    wavelet_scale_base: This hyperparameter can have a huge effect on how
      low-frequency errors are weighted against high-frequency errors. Values
      such as 0.5 and 2 are worth trying to see what works best.
    wavelet_num_levels: Number of wavelet decomposition levels.
  """
  self.func = adaptive.AdaptiveImageLossFunction(
      (imh, imw, 3),
      tf.float32,
      color_space='YUV',
      representation='CDF9/7',
      summarize_loss=False,
      wavelet_num_levels=wavelet_num_levels,
      wavelet_scale_base=wavelet_scale_base,
      alpha_lo=alpha,
      alpha_hi=alpha,
      scale_lo=scale,
      scale_init=scale)
def _check_flag_compatibility():
  """Raises NotImplementedError for unsupported combinations of FLAGS."""
  if FLAGS.analytic_kl and FLAGS.mixture_components != 1:
    raise NotImplementedError(
        "Using `analytic_kl` is only supported when `mixture_components = 1` "
        "since there's no closed form otherwise.")
  if FLAGS.floating_prior and not (FLAGS.unit_posterior and
                                   FLAGS.mixture_components == 1):
    raise NotImplementedError(
        "Using `floating_prior` is only supported when `unit_posterior` = True "
        "since there's a scale ambiguity otherwise, and when "
        "`mixture_components = 1` since there's no closed form otherwise.")
  if FLAGS.fitted_samples and FLAGS.mixture_components != 1:
    raise NotImplementedError(
        "Using `fitted_samples` is only supported when "
        "`mixture_components = 1` since there's no closed form otherwise.")
  if FLAGS.bilbo and not FLAGS.floating_prior:
    raise NotImplementedError(
        "Using `bilbo` is only supported when `floating_prior = True`.")


def _make_image_lossfun(residual):
  """Builds the FLAGS-configured AdaptiveImageLossFunction for `residual`.

  The alpha_* flags are forwarded only when `use_students_t` is off. All other
  flags are forwarded in both cases.
  """
  lossfun_kwargs = dict(
      color_space=FLAGS.color_space,
      representation=FLAGS.representation,
      wavelet_num_levels=FLAGS.wavelet_num_levels,
      wavelet_scale_base=FLAGS.wavelet_scale_base,
      use_students_t=FLAGS.use_students_t,
      scale_lo=FLAGS.scale_lo,
      scale_init=FLAGS.scale_init)
  if not FLAGS.use_students_t:
    lossfun_kwargs.update(
        alpha_lo=FLAGS.alpha_lo,
        alpha_hi=FLAGS.alpha_hi,
        alpha_init=FLAGS.alpha_init)
  return adaptive.AdaptiveImageLossFunction(residual.shape[1:], residual.dtype,
                                            **lossfun_kwargs)


def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an estimator.

  Arguments:
    features: The input features for the estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some parameters, unused here.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.

  Raises:
    NotImplementedError: If FLAGS request an unsupported combination of
      `analytic_kl`, `floating_prior`, `fitted_samples`, `bilbo`,
      `unit_posterior` and `mixture_components`.
  """
  del labels, params, config
  _check_flag_compatibility()

  activation = tf.nn.leaky_relu
  encoder = make_encoder(activation, FLAGS.latent_size, FLAGS.base_depth)
  decoder = make_decoder(activation, FLAGS.latent_size, [IMAGE_SIZE] * 2 + [3],
                         FLAGS.base_depth)

  approx_posterior = encoder(features)
  approx_posterior_sample = approx_posterior.sample(FLAGS.n_samples)
  decoder_mu = decoder(approx_posterior_sample)

  if FLAGS.floating_prior or FLAGS.fitted_samples:
    # Axis-0 averages of the squared posterior means and of the posterior
    # variances. Their sum sets the diagonal variance of the floating prior.
    posterior_batch_mean = tf.reduce_mean(approx_posterior.mean()**2, [0])
    posterior_batch_variance = tf.reduce_mean(approx_posterior.stddev()**2, [0])
    posterior_scale = posterior_batch_mean + posterior_batch_variance
    floating_prior = tfd.MultivariateNormalDiag(
        tf.zeros(FLAGS.latent_size), tf.sqrt(posterior_scale))
    tf.summary.scalar("posterior_scale", tf.reduce_sum(posterior_scale))

  if FLAGS.floating_prior:
    latent_prior = floating_prior
  else:
    latent_prior = make_mixture_prior(FLAGS.latent_size,
                                      FLAGS.mixture_components)

  # Decode samples from the prior for visualization.
  if FLAGS.fitted_samples:
    sample_distribution = floating_prior
  else:
    sample_distribution = latent_prior

  n_samples = VIZ_GRID_SIZE**2
  random_mu = decoder(sample_distribution.sample(n_samples))

  # Flatten all leading dimensions into one batch of images for the loss.
  residual = tf.reshape(features - decoder_mu, [-1] + [IMAGE_SIZE] * 2 + [3])
  lossfun = _make_image_lossfun(residual)
  nll = lossfun(residual)
  # Restore decoder_mu's two leading dimensions (sample, input - as indexed by
  # the visualization slicing below).
  nll = tf.reshape(nll, [tf.shape(decoder_mu)[0], tf.shape(decoder_mu)[1]] +
                   [IMAGE_SIZE] * 2 + [3])

  # Clipping to prevent the loss from nanning out.
  max_val = np.finfo(np.float32).max
  nll = tf.clip_by_value(nll, -max_val, max_val)

  viz_n_inputs = np.int32(np.minimum(VIZ_MAX_N_INPUTS, FLAGS.batch_size))
  viz_n_samples = np.int32(np.minimum(VIZ_MAX_N_SAMPLES, FLAGS.n_samples))

  image_tile_summary("input", tf.to_float(features), rows=1, cols=viz_n_inputs)

  image_tile_summary(
      "recon/mean",
      decoder_mu[:viz_n_samples, :viz_n_inputs],
      rows=viz_n_samples,
      cols=viz_n_inputs)

  img_summary_input = image_tile_summary(
      "input1", tf.to_float(features), rows=viz_n_inputs, cols=1)
  img_summary_recon = image_tile_summary(
      "recon1", decoder_mu[:1, :viz_n_inputs], rows=viz_n_inputs, cols=1)

  image_tile_summary(
      "random/mean", random_mu, rows=VIZ_GRID_SIZE, cols=VIZ_GRID_SIZE)

  # Sum the per-pixel NLL over height, width and channel.
  distortion = tf.reduce_sum(nll, axis=[2, 3, 4])

  avg_distortion = tf.reduce_mean(distortion)
  tf.summary.scalar("distortion", avg_distortion)

  if FLAGS.analytic_kl:
    rate = tfd.kl_divergence(approx_posterior, latent_prior)
  else:
    rate = (
        approx_posterior.log_prob(approx_posterior_sample) -
        latent_prior.log_prob(approx_posterior_sample))
  avg_rate = tf.reduce_mean(rate)
  tf.summary.scalar("rate", avg_rate)

  elbo_local = -(rate + distortion)

  elbo = tf.reduce_mean(elbo_local)
  tf.summary.scalar("elbo", elbo)

  if FLAGS.bilbo:
    # _check_flag_compatibility() guarantees floating_prior is set here, so
    # the posterior batch statistics above have been defined.
    bilbo = -0.5 * tf.reduce_sum(
        tf.log1p(
            posterior_batch_mean / posterior_batch_variance)) - avg_distortion
    tf.summary.scalar("bilbo", bilbo)
    loss = -bilbo
  else:
    loss = -elbo

  importance_weighted_elbo = tf.reduce_mean(
      tf.reduce_logsumexp(elbo_local, axis=0) -
      tf.math.log(tf.to_float(FLAGS.n_samples)))
  tf.summary.scalar("elbo/importance_weighted", importance_weighted_elbo)

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.train.get_or_create_global_step()
  decay_start_step = int(FLAGS.decay_start * FLAGS.max_steps)
  # cosine_decay divides by decay_steps. Clamp it to at least 1 so that
  # decay_start >= 1 (i.e. "never decay") keeps the base learning rate for
  # the whole run instead of producing a NaN / ill-defined schedule.
  decay_steps = max(1, int((1. - FLAGS.decay_start) * FLAGS.max_steps))
  learning_rate = tf.train.cosine_decay(
      FLAGS.learning_rate,
      tf.maximum(tf.cast(0, tf.int64), global_step - decay_start_step),
      decay_steps)
  tf.summary.scalar("learning_rate", learning_rate)
  optimizer = tf.train.AdamOptimizer(learning_rate)

  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimizer.minimize(loss, global_step=global_step)
  else:
    train_op = None

  eval_metric_ops = {}
  eval_metric_ops["elbo"] = tf.metrics.mean(elbo)
  eval_metric_ops["elbo/importance_weighted"] = tf.metrics.mean(
      importance_weighted_elbo)
  eval_metric_ops["rate"] = tf.metrics.mean(avg_rate)
  eval_metric_ops["distortion"] = tf.metrics.mean(avg_distortion)
  # This ugly hackery is necessary to get TF to visualize when running the
  # eval set, apparently.
  eval_metric_ops["img_summary_input"] = (img_summary_input, tf.no_op())
  eval_metric_ops["img_summary_recon"] = (img_summary_recon, tf.no_op())
  eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
  )