def importance_weighted_divergence_fn(q_samples):
    """Evaluate the discrepancy of the log-average importance weight.

    The log weights `target_log_prob - q_lp` are reshaped so that axis 1 has
    size `importance_sample_size`, and are then averaged in log space over
    that axis. When the doubly-reparameterized estimator is selected, the
    forward value is left unchanged but its gradient is replaced by the
    gradient of the DReG objective.

    Args:
      q_samples: Samples handed to `target_log_prob_fn` (through
        `nest_util.call_fn`) and to the surrogate's `log_prob`.

    Returns:
      `discrepancy_fn` applied to the log-average importance weights.
    """
    surrogate_lp = precomputed_surrogate_log_prob
    if surrogate_lp is None:
        surrogate_lp = surrogate_posterior.log_prob(q_samples)
    target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
    flat_log_weights = target_log_prob - surrogate_lp

    # Split the leading axis into `[-1, importance_sample_size]` so that the
    # importance samples live on their own axis (axis 1).
    grouped_shape = ps.concat(
        [[-1, importance_sample_size], ps.shape(flat_log_weights)[1:]],
        axis=0)
    log_weights = tf.reshape(flat_log_weights, grouped_shape)

    # log(mean(w)) == logsumexp(log w) - log(importance_sample_size).
    log_sum_weights = tf.reduce_logsumexp(log_weights, axis=1)
    log_num_samples = tf.math.log(
        tf.cast(importance_sample_size, dtype=log_weights.dtype))
    log_avg_weights = log_sum_weights - log_num_samples

    if gradient_estimator == GradientEstimators.DOUBLY_REPARAMETERIZED:
        # Adapted from original implementation at
        # https://github.com/google-research/google-research/blob/master/dreg_estimators/model.py
        normalized_weights = tf.stop_gradient(
            tf.nn.softmax(log_weights, axis=1))
        stopped_q_log_weights = tf.reshape(
            target_log_prob - stopped_surrogate_posterior.log_prob(q_samples),
            ps.shape(log_weights))
        dreg_objective = tf.reduce_sum(
            stopped_q_log_weights * tf.square(normalized_weights), axis=1)
        # `dreg_objective - stop_gradient(dreg_objective)` is zero in value
        # but carries the DReG gradient, so only the gradient is swapped.
        zero_with_dreg_grad = dreg_objective - tf.stop_gradient(dreg_objective)
        log_avg_weights = (
            tf.stop_gradient(log_avg_weights) + zero_with_dreg_grad)

    return discrepancy_fn(log_avg_weights)
def testArgsExpansion(self):
    """Expects `call_fn` to pass a structtuple's fields as separate args."""

    def add_pair(a, b):
        return a + b

    # The field names ('c', 'd') deliberately differ from the parameter
    # names of `add_pair`.
    pair_type = structural_tuple.structtuple(['c', 'd'])
    result = nest_util.call_fn(add_pair, pair_type(1, 2))
    self.assertEqual(3, result)
def divergence_fn(q_samples):
    """Evaluate the discrepancy of the per-sample log weights.

    Args:
      q_samples: Samples handed to `target_log_prob_fn` (through
        `nest_util.call_fn`) and to the surrogate's `log_prob`.

    Returns:
      `discrepancy_fn` applied to `target_log_prob - surrogate_log_prob`.
    """
    surrogate_lp = precomputed_surrogate_log_prob
    target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)

    if gradient_estimator == GradientEstimators.DOUBLY_REPARAMETERIZED:
        # Sticking-the-landing is the special case of doubly-reparameterized
        # gradients with `importance_sample_size=1`. Any precomputed log prob
        # is ignored here in favour of the stopped surrogate's.
        surrogate_lp = stopped_surrogate_posterior.log_prob(q_samples)
    elif surrogate_lp is None:
        surrogate_lp = surrogate_posterior.log_prob(q_samples)

    return discrepancy_fn(target_log_prob - surrogate_lp)
def divergence_fn(q_samples):
    """Evaluate the (optionally importance-weighted) discrepancy.

    When `importance_sample_size` is statically known to be 1, the
    discrepancy is applied to the raw log weights. Otherwise the log weights
    are reshaped so axis 1 has size `importance_sample_size` and averaged in
    log space over that axis first.

    Args:
      q_samples: Samples handed to `target_log_prob_fn` (through
        `nest_util.call_fn`) and to the surrogate's `log_prob`.

    Returns:
      `discrepancy_fn` applied to the raw or log-averaged weights.
    """
    surrogate_lp = precomputed_surrogate_log_prob
    if surrogate_lp is None:
        surrogate_lp = surrogate_posterior.log_prob(q_samples)
    target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
    log_weights = target_log_prob - surrogate_lp

    if tf.get_static_value(importance_sample_size) == 1:
        # Bypass importance weighting.
        return discrepancy_fn(log_weights)

    # Split the leading axis into `[-1, importance_sample_size]` so that the
    # importance samples live on their own axis (axis 1).
    grouped_shape = ps.concat(
        [[-1, importance_sample_size], ps.shape(log_weights)[1:]], axis=0)
    grouped_log_weights = tf.reshape(log_weights, grouped_shape)

    # log(mean(w)) == logsumexp(log w) - log(importance_sample_size).
    log_sum_weights = tf.reduce_logsumexp(grouped_log_weights, axis=1)
    log_num_samples = tf.math.log(
        tf.cast(importance_sample_size, dtype=grouped_log_weights.dtype))
    return discrepancy_fn(log_sum_weights - log_num_samples)
def test_target_log_prob_fn(self):
    """Check `call_fn` on a target log-prob built from a joint distribution.

    The last model variable ('a') is pinned to `1.`; the remaining
    variables ('c', 'b') are supplied as a structured test point.
    """

    def model_fn():
        c = yield Root(tfd.LogNormal(0., 1., name='c'))
        b = yield tfd.Normal(c, 1., name='b')
        yield tfd.Normal(c + b, 1., name='a')

    model = tfd.JointDistributionCoroutine(model_fn, validate_args=True)

    def target_log_prob_fn(*args):
        # Append the pinned value for the final variable.
        return model.log_prob(args + (1.,))

    # Drop the final (pinned) variable from the model's structure.
    latent_dtype = model.dtype[:-1]
    latent_event_shape = model.event_shape[:-1]
    expected_fields = ('c', 'b')
    self.assertAllEqual(expected_fields, latent_dtype._fields)
    self.assertAllEqual(expected_fields, latent_event_shape._fields)

    test_point = tf.nest.map_structure(
        tf.zeros, latent_event_shape, latent_dtype)
    lp_manual = model.log_prob(test_point + (1.,))
    lp_via_call_fn = nest_util.call_fn(target_log_prob_fn, test_point)
    self.assertAllClose(
        self.evaluate(lp_manual), self.evaluate(lp_via_call_fn))
def _build_module(self):
    """Build the module by calling `_base_class` on the output of `_args_fn`.

    `_args_fn` is invoked with the stored `_param_args` / `_param_kwargs`,
    and its result is forwarded to `_base_class` through
    `nest_util.call_fn`.

    Returns:
      Whatever `nest_util.call_fn(self._base_class, ...)` returns.
    """
    base_class = self._base_class
    constructor_args = self._args_fn(*self._param_args, **self._param_kwargs)
    return nest_util.call_fn(base_class, constructor_args)
def testCallFnTwoArgs(self, arg):
    """Expects `call_fn` to feed `arg` to a two-argument function as 1 + 2."""

    # Parameter names are kept as `arg1`/`arg2` in case `arg` is expanded
    # by keyword.
    def add_two(arg1, arg2):
        return arg1 + arg2

    total = nest_util.call_fn(add_two, arg)
    self.assertEqual(3, total)
def testCallFnOneArg(self, arg):
    """Expects `call_fn` to hand `arg` unchanged to a one-argument function."""

    def identity(arg):
        return arg

    # Compare flattened structures so any nest type of `arg` is handled.
    expected_flat = tf.nest.flatten(arg)
    actual_flat = tf.nest.flatten(nest_util.call_fn(identity, arg))
    self.assertEqual(expected_flat, actual_flat)
def csiszar_vimco(f,
                  p_log_prob,
                  q,
                  num_draws,
                  num_batch_draws=1,
                  seed=None,
                  name=None):
    """Use VIMCO to lower the variance of gradient[csiszar_function(log(Avg(u)))].

    This function generalizes VIMCO [(Mnih and Rezende, 2016)][1] to Csiszar
    f-Divergences.

    Note: if `q.reparameterization_type = tfd.FULLY_REPARAMETERIZED`,
    consider using `monte_carlo_variational_loss`.

    The VIMCO loss is:

    ```none
    vimco = f(log(Avg{u[i] : i=0,...,m-1}))
    where,
      logu[i] = log( p(x, h[i]) / q(h[i] | x) )
      h[i] iid~ q(H | x)
    ```

    Interestingly, the VIMCO gradient is not the naive gradient of `vimco`.
    Rather, it is characterized by:

    ```none
    grad[vimco] - variance_reducing_term
    where,
      variance_reducing_term = Sum{
          grad[log q(h[i] | x)] *
            (vimco - f(log Avg{h[j;i] : j=0,...,m-1}))
         : i=0, ..., m-1 }
      h[j;i] = { u[j]                             j!=i
               { GeometricAverage{ u[k] : k!=i}   j==i
    ```

    (We omitted `stop_gradient` for brevity. See implementation for more
    details.)

    The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the
    `i`-th element has been replaced by the leave-`i`-out Geometric-average.

    This implementation prefers numerical precision over efficiency, i.e.,
    `O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`.
    (The constant may be fairly large, perhaps around 12.)

    Args:
      f: Python `callable` representing a Csiszar-function in log-space.
      p_log_prob: Python `callable` representing the natural-log of the
        probability under distribution `p`. (In variational inference `p` is
        the joint distribution.)
      q: `tf.Distribution`-like instance; must implement: `sample(n, seed)`,
        and `log_prob(x)`. (In variational inference `q` is the approximate
        posterior distribution.)
      num_draws: Integer scalar number of draws used to approximate the
        f-Divergence expectation. Must be at least 2.
      num_batch_draws: Integer scalar number of independent batches of
        `num_draws` draws; the per-batch objectives are averaged in the
        returned value.
      seed: Python `int` seed for `q.sample`.
      name: Python `str` name prefixed to Ops created by this function.

    Returns:
      vimco: The Csiszar f-Divergence generalized VIMCO objective.

    Raises:
      ValueError: if `num_draws < 2`.

    #### References

    [1]: Andriy Mnih and Danilo Rezende. Variational Inference for Monte
         Carlo objectives. In _International Conference on Machine
         Learning_, 2016. https://arxiv.org/abs/1602.06725
    """
    with tf.name_scope(name or 'csiszar_vimco'):
        if num_draws < 2:
            raise ValueError('Must specify num_draws > 1.')

        # Draw `[num_draws, num_batch_draws]` samples and detach them, so
        # gradients reach `q` only through `log_prob`.
        draws = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
        detached_draws = tf.nest.map_structure(tf.stop_gradient, draws)
        log_q = q.log_prob(detached_draws)
        log_u = nest_util.call_fn(p_log_prob, detached_draws) - log_q

        # `log_soomean_exp` yields the swap-one-out log-averages followed by
        # the ordinary log-average; `f` is applied to each, in that order.
        f_log_sooavg_u, f_log_avg_u = (
            f(term) for term in log_soomean_exp(log_u, axis=0))

        learning_signal = tf.stop_gradient(f_log_avg_u - f_log_sooavg_u)
        # Sum over iid samples.
        score_term = tf.reduce_sum(log_q * learning_signal, axis=0)

        # Rewrite the objective so that
        #   grad[objective] := grad[f_log_avg_u + score_term]
        # while its value stays `f_log_avg_u`. This relies on
        # `y - stop_gradient(y)` being exactly zero in value yet having
        # gradient grad[y]. IEEE754 specifies `x - x == 0.` and
        # `x + 0. == x`, so no precision is lost; see "Is there a floating
        # point value of x, for which x-x == 0 is false?"
        # http://stackoverflow.com/q/2686644
        objective = f_log_avg_u + score_term - tf.stop_gradient(score_term)

        # Avg over batches.
        return tf.reduce_mean(objective, axis=0)
def divergence_fn(q_samples):
    """Apply `f` to the log ratio `p_log_prob(q_samples) - q.log_prob(q_samples)`.

    Args:
      q_samples: Samples handed to `p_log_prob` (through
        `nest_util.call_fn`) and to `q.log_prob`.

    Returns:
      The result of `f` on the log ratio.
    """
    log_p = nest_util.call_fn(p_log_prob, q_samples)
    log_q = q.log_prob(q_samples)
    log_ratio = log_p - log_q
    return f(log_ratio)
def divergence_fn(q_samples):
    """Apply `discrepancy_fn` to the target-minus-surrogate log ratio.

    Args:
      q_samples: Samples handed to `target_log_prob_fn` (through
        `nest_util.call_fn`) and to `surrogate_posterior.log_prob`.

    Returns:
      The result of `discrepancy_fn` on the log ratio.
    """
    target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
    surrogate_log_prob = surrogate_posterior.log_prob(q_samples)
    log_ratio = target_log_prob - surrogate_log_prob
    return discrepancy_fn(log_ratio)
def divergence_fn(q_samples, q_lp=None):
    """Apply `discrepancy_fn` to the target-minus-surrogate log ratio.

    Args:
      q_samples: Samples handed to `target_log_prob_fn` (through
        `nest_util.call_fn`).
      q_lp: Optional surrogate log prob of `q_samples`. When `None`, it is
        computed with `surrogate_posterior.log_prob(q_samples)`.

    Returns:
      The result of `discrepancy_fn` on the log ratio.
    """
    target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
    surrogate_log_prob = (
        surrogate_posterior.log_prob(q_samples) if q_lp is None else q_lp)
    return discrepancy_fn(target_log_prob - surrogate_log_prob)
def _call_target_log_prob_fn(self, x):
    """Invoke `self.target_log_prob_fn` on `x` through `nest_util.call_fn`.

    Args:
      x: Argument (possibly a nested structure) forwarded to the target
        log-prob function.

    Returns:
      Whatever `self.target_log_prob_fn` returns for `x`.
    """
    target_fn = self.target_log_prob_fn
    return nest_util.call_fn(target_fn, x)