Example #1
0
    def importance_weighted_divergence_fn(q_samples):
        """Applies `discrepancy_fn` to importance-weighted log-average weights.

        Computes log importance weights `log p(q_samples) - log q(q_samples)`,
        averages them in probability space over groups of
        `importance_sample_size` draws (in log space, via logsumexp), and
        passes the resulting log-average weights to `discrepancy_fn`. When
        `gradient_estimator` is DOUBLY_REPARAMETERIZED, the forward value is
        unchanged but its gradient is replaced by the DReG objective's.

        Args:
          q_samples: Samples from `surrogate_posterior`. The leading axis must
            be divisible by `importance_sample_size` (required by the reshape
            below). Assumes consecutive runs of `importance_sample_size` draws
            along that axis belong to the same estimate -- TODO confirm
            against the sampling code outside this view.

        Returns:
          The result of `discrepancy_fn(log_avg_weights)`.
        """
        # Reuse the surrogate log-prob when the caller already computed it.
        q_lp = precomputed_surrogate_log_prob
        if q_lp is None:
            q_lp = surrogate_posterior.log_prob(q_samples)
        target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
        log_weights = target_log_prob - q_lp

        # Explicitly break out `importance_sample_size` as a separate axis.
        # New shape: [-1, importance_sample_size] + original shape[1:].
        log_weights = tf.reshape(
            log_weights,
            ps.concat([[-1, importance_sample_size],
                       ps.shape(log_weights)[1:]],
                      axis=0))
        # log(mean(w)) = logsumexp(log w) - log(K), evaluated stably in log
        # space over the importance axis (axis 1).
        log_sum_weights = tf.reduce_logsumexp(log_weights, axis=1)
        log_avg_weights = log_sum_weights - tf.math.log(
            tf.cast(importance_sample_size, dtype=log_weights.dtype))

        if gradient_estimator == GradientEstimators.DOUBLY_REPARAMETERIZED:
            # Adapted from original implementation at
            # https://github.com/google-research/google-research/blob/master/dreg_estimators/model.py
            # Self-normalized importance weights along the importance axis,
            # treated as constants for differentiation.
            normalized_weights = tf.stop_gradient(
                tf.nn.softmax(log_weights, axis=1))
            # Log weights recomputed with `stopped_surrogate_posterior`
            # (presumably the surrogate with its parameters under
            # stop_gradient -- confirm), reshaped to match `log_weights`.
            log_weights_with_stopped_q = tf.reshape(
                target_log_prob -
                stopped_surrogate_posterior.log_prob(q_samples),
                ps.shape(log_weights))
            # DReG surrogate objective: sum_k w_tilde_k**2 * log w_k.
            dreg_objective = tf.reduce_sum(log_weights_with_stopped_q *
                                           tf.square(normalized_weights),
                                           axis=1)
            # Replace the objective's gradient with the doubly-reparameterized
            # gradient. The forward value stays `log_avg_weights` because
            # `dreg_objective - stop_gradient(dreg_objective)` is zero, while
            # the gradient comes only from `dreg_objective`.
            log_avg_weights = tf.stop_gradient(log_avg_weights) + (
                dreg_objective - tf.stop_gradient(dreg_objective))

        return discrepancy_fn(log_avg_weights)
  def testArgsExpansion(self):
    """`call_fn` spreads a structtuple over the callee's positional params.

    The structtuple field names (`c`, `d`) deliberately differ from the
    callee's parameter names (`a`, `b`).
    """

    def add(a, b):
      return a + b

    pair_type = structural_tuple.structtuple(['c', 'd'])
    result = nest_util.call_fn(add, pair_type(1, 2))
    self.assertEqual(3, result)
Example #3
0
    def divergence_fn(q_samples):
        """Applies `discrepancy_fn` to the log importance weights of `q_samples`.

        The log weights are `target_log_prob_fn(q_samples) - log q(q_samples)`.
        How `log q` is obtained depends on `gradient_estimator`:

        * DOUBLY_REPARAMETERIZED: evaluated with `stopped_surrogate_posterior`
          (sticking-the-landing); `precomputed_surrogate_log_prob` is ignored.
        * otherwise: `precomputed_surrogate_log_prob` if provided, else
          `surrogate_posterior.log_prob(q_samples)`.

        Args:
          q_samples: Samples drawn from the surrogate posterior.

        Returns:
          The result of `discrepancy_fn(log_weights)`.
        """
        target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)

        if gradient_estimator == GradientEstimators.DOUBLY_REPARAMETERIZED:
            # Sticking-the-landing is the special case of doubly-reparameterized
            # gradients with `importance_sample_size=1`. It needs the log-prob
            # from `stopped_surrogate_posterior`, so a precomputed surrogate
            # log-prob is deliberately not reused here.
            q_lp = stopped_surrogate_posterior.log_prob(q_samples)
        else:
            # Reuse the surrogate log-prob when the caller already computed it.
            q_lp = precomputed_surrogate_log_prob
            if q_lp is None:
                q_lp = surrogate_posterior.log_prob(q_samples)

        # Shared by both estimators; computed exactly once.
        log_weights = target_log_prob - q_lp
        return discrepancy_fn(log_weights)
Example #4
0
  def divergence_fn(q_samples):
    """Applies `discrepancy_fn` to the log-average importance weight.

    Log weights are `target_log_prob_fn(q_samples) - log q(q_samples)`. When
    `importance_sample_size` is statically 1 they are passed straight through.
    Otherwise the leading axis is split into `[-1, importance_sample_size]`
    and the weights are averaged (in log space) over the importance axis.
    """
    surrogate_lp = precomputed_surrogate_log_prob
    if surrogate_lp is None:
      surrogate_lp = surrogate_posterior.log_prob(q_samples)

    log_w = nest_util.call_fn(target_log_prob_fn, q_samples) - surrogate_lp

    # A single draw per estimate needs no averaging.
    if tf.get_static_value(importance_sample_size) == 1:
      return discrepancy_fn(log_w)

    # Group the leading axis as [-1, importance_sample_size] + shape[1:].
    grouped_shape = ps.concat(
        [[-1, importance_sample_size], ps.shape(log_w)[1:]], axis=0)
    grouped_log_w = tf.reshape(log_w, grouped_shape)

    # log(mean(w)) = logsumexp(log w) - log(K).
    log_total_w = tf.reduce_logsumexp(grouped_log_w, axis=1)
    log_mean_w = log_total_w - tf.math.log(
        tf.cast(importance_sample_size, dtype=grouped_log_w.dtype))
    return discrepancy_fn(log_mean_w)
Example #5
0
    def test_target_log_prob_fn(self):
        """Builds a `target_log_prob_fn` by pinning a joint's last component."""
        def model_fn():
            c = yield Root(tfd.LogNormal(0., 1., name='c'))
            b = yield tfd.Normal(c, 1., name='b')
            yield tfd.Normal(c + b, 1., name='a')

        model = tfd.JointDistributionCoroutine(model_fn, validate_args=True)

        def target_log_prob_fn(*args):
            # The final component `a` is pinned to the value 1.
            return model.log_prob(args + (1., ))

        # Everything except the pinned `a` component.
        latent_dtype = model.dtype[:-1]
        latent_event_shape = model.event_shape[:-1]
        self.assertAllEqual(('c', 'b'), latent_dtype._fields)
        self.assertAllEqual(('c', 'b'), latent_event_shape._fields)

        origin = tf.nest.map_structure(
            tf.zeros, latent_event_shape, latent_dtype)
        expected_lp = model.log_prob(origin + (1., ))
        actual_lp = nest_util.call_fn(target_log_prob_fn, origin)

        self.assertAllClose(
            self.evaluate(expected_lp), self.evaluate(actual_lp))
Example #6
0
 def _build_module(self):
     """Calls `self._base_class` (via `nest_util.call_fn`) on the args
     produced by `self._args_fn(*self._param_args, **self._param_kwargs)`."""
     module_args = self._args_fn(*self._param_args, **self._param_kwargs)
     return nest_util.call_fn(self._base_class, module_args)
Example #7
0
    def testCallFnTwoArgs(self, arg):
        """`call_fn` should spread `arg` over a two-parameter callable.

        `arg` is presumably supplied by a parameterization decorator outside
        this view, with contents summing to 3.
        """
        def add_pair(arg1, arg2):
            return arg1 + arg2

        total = nest_util.call_fn(add_pair, arg)
        self.assertEqual(3, total)
Example #8
0
    def testCallFnOneArg(self, arg):
        """`call_fn` on a one-parameter identity fn returns the same leaves."""
        def identity(arg):
            return arg

        expected = tf.nest.flatten(arg)
        actual = tf.nest.flatten(nest_util.call_fn(identity, arg))
        self.assertEqual(expected, actual)
Example #9
0
def csiszar_vimco(f,
                  p_log_prob,
                  q,
                  num_draws,
                  num_batch_draws=1,
                  seed=None,
                  name=None):
  """Use VIMCO to lower the variance of grad[csiszar_function(log(Avg(u)))].

  This function generalizes VIMCO [(Mnih and Rezende, 2016)][1] to Csiszar
  f-Divergences.

  Note: if `q.reparameterization_type = tfd.FULLY_REPARAMETERIZED`,
  consider using `monte_carlo_variational_loss`.

  The VIMCO loss is:

  ```none
  vimco = f(log(Avg{u[i] : i=0,...,m-1}))
  where,
    logu[i] = log( p(x, h[i]) / q(h[i] | x) )
    h[i] iid~ q(H | x)
  ```

  Interestingly, the VIMCO gradient is not the naive gradient of `vimco`.
  Rather, it is characterized by:

  ```none
  grad[vimco] - variance_reducing_term
  where,
    variance_reducing_term = Sum{ grad[log q(h[i] | x)] *
                                    (vimco - f(log Avg{h[j;i] : j=0,...,m-1}))
                                 : i=0, ..., m-1 }
    h[j;i] = { u[j]                             j!=i
             { GeometricAverage{ u[k] : k!=i}   j==i
  ```

  (We omitted `stop_gradient` for brevity. See implementation for more details.)

  The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the `i`-th
  element has been replaced by the leave-`i`-out Geometric-average.

  This implementation prefers numerical precision over efficiency, i.e.,
  `O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`.
  (The constant may be fairly large, perhaps around 12.)

  Args:
    f: Python `callable` representing a Csiszar-function in log-space.
    p_log_prob: Python `callable` representing the natural-log of the
      probability under distribution `p`. (In variational inference `p` is the
      joint distribution.) It is invoked via `nest_util.call_fn`, so structured
      samples are unpacked into its arguments.
    q: `tf.Distribution`-like instance; must implement
      `sample(sample_shape, seed)` and `log_prob(x)`. (In variational inference
      `q` is the approximate posterior distribution.)
    num_draws: Integer scalar number of draws used to approximate the
      f-Divergence expectation. Must be at least 2.
    num_batch_draws: Integer scalar number of independent VIMCO estimates, each
      built from `num_draws` samples, which are averaged to form the returned
      objective.
    seed: Python `int` seed for `q.sample`.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    vimco: The Csiszar f-Divergence generalized VIMCO objective, averaged over
      the `num_batch_draws` estimates.

  Raises:
    ValueError: if `num_draws < 2`.

  #### References

  [1]: Andriy Mnih and Danilo Rezende. Variational Inference for Monte Carlo
       objectives. In _International Conference on Machine Learning_, 2016.
       https://arxiv.org/abs/1602.06725
  """
  with tf.name_scope(name or 'csiszar_vimco'):
    # The leave-one-out ("swap-out") average needs at least two draws.
    if num_draws < 2:
      raise ValueError('Must specify num_draws > 1.')
    stop = tf.stop_gradient  # For readability.

    # Axis 0 holds the iid draws of one estimate; axis 1 holds the
    # independent batch estimates.
    q_sample = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
    # Samples are treated as constants. Gradients w.r.t. `q`'s parameters flow
    # only through `logqx` (score-function style), so `q` need not be
    # reparameterized.
    x = tf.nest.map_structure(stop, q_sample)
    logqx = q.log_prob(x)
    logu = nest_util.call_fn(p_log_prob, x) - logqx
    # Per the unpacking order, `log_soomean_exp` yields the log swap-out
    # average and the log average along axis 0 (assumed; see its definition).
    f_log_sooavg_u, f_log_avg_u = map(f, log_soomean_exp(logu, axis=0))

    dotprod = tf.reduce_sum(
        logqx * stop(f_log_avg_u - f_log_sooavg_u),
        axis=0)  # Sum over iid samples.
    # We now rewrite f_log_avg_u so that:
    #   `grad[f_log_avg_u] := grad[f_log_avg_u + dotprod]`.
    # To achieve this, we use a trick that
    #   `f(x) - stop(f(x)) == zeros_like(f(x))`
    # but its gradient is grad[f(x)].
    # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence
    # this trick loses no precision. For more discussion regarding the relevant
    # portions of the IEEE754 standard, see the StackOverflow question,
    # "Is there a floating point value of x, for which x-x == 0 is false?"
    # http://stackoverflow.com/q/2686644
    # Following is same as adding zeros_like(dotprod).
    f_log_avg_u = f_log_avg_u + dotprod - stop(dotprod)
    return tf.reduce_mean(f_log_avg_u, axis=0)  # Avg over batches.
Example #10
0
 def divergence_fn(q_samples):
     """Returns `f(log p(q_samples) - log q(q_samples))`."""
     log_ratio = (nest_util.call_fn(p_log_prob, q_samples) -
                  q.log_prob(q_samples))
     return f(log_ratio)
 def divergence_fn(q_samples):
   """Returns `discrepancy_fn` of the target-minus-surrogate log-prob."""
   log_weights = (nest_util.call_fn(target_log_prob_fn, q_samples) -
                  surrogate_posterior.log_prob(q_samples))
   return discrepancy_fn(log_weights)
 def divergence_fn(q_samples, q_lp=None):
     """Returns `discrepancy_fn(log p(q_samples) - log q(q_samples))`.

     If `q_lp` is given it is used as `log q(q_samples)`; otherwise it is
     computed with `surrogate_posterior.log_prob`.
     """
     log_target = nest_util.call_fn(target_log_prob_fn, q_samples)
     log_surrogate = (surrogate_posterior.log_prob(q_samples)
                      if q_lp is None else q_lp)
     return discrepancy_fn(log_target - log_surrogate)
 def _call_target_log_prob_fn(self, x):
     """Evaluates `self.target_log_prob_fn` at `x` via `nest_util.call_fn`."""
     target_fn = self.target_log_prob_fn
     return nest_util.call_fn(target_fn, x)