import tensorflow as tf

# NOTE: `log_soomean_exp` and `nest_util` are TensorFlow Probability internals;
# these import paths are assumptions and may differ between TFP versions.
from tensorflow_probability.python.internal import nest_util
from tensorflow_probability.python.math.generic import log_soomean_exp


def csiszar_vimco_helper(logu, name=None):
  """Helper to `csiszar_vimco`; computes `log_avg_u`, `log_sooavg_u`.

  `axis = 0` of `logu` is presumed to correspond to iid samples from `q`, i.e.,

  ```none
  logu[j] = log(u[j])
  u[j] = p(x, h[j]) / q(h[j] | x)
  h[j] iid~ q(H | x)
  ```

  Args:
    logu: Floating-type `Tensor` representing `log(p(x, h) / q(h | x))`.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    log_avg_u: `logu.dtype` `Tensor` corresponding to the natural-log of the
      average of `u`. The sum of the gradient of `log_avg_u` is `1`.
    log_sooavg_u: `logu.dtype` `Tensor` characterized by the natural-log of the
      average of `u` except that the average swaps out `u[i]` for the
      leave-`i`-out Geometric-average. The mean of the gradient of
      `log_sooavg_u` is `1`. Mathematically `log_sooavg_u` is,
      ```none
      log_sooavg_u[i] = log(Avg{h[j ; i] : j=0, ..., m-1})
      h[j ; i] = { u[j]                              j!=i
                 { GeometricAverage{u[k] : k != i}   j==i
      ```

  """
  with tf.name_scope(name or 'csiszar_vimco_helper'):
    logu = tf.convert_to_tensor(logu, name='logu')
    return log_soomean_exp(logu, axis=0)[::-1]
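

# Illustration only (not part of the library): a brute-force reference for the
# two quantities returned by `csiszar_vimco_helper`, handy for sanity-checking
# intuition on small inputs. `log_avg_u` is the log of the arithmetic mean of
# `u`; `log_sooavg_u[i]` replaces `u[i]` with the leave-`i`-out geometric mean
# before averaging.
def _reference_vimco_helper(logu):
  logu = tf.convert_to_tensor(logu)
  m = logu.shape[0]
  log_m = tf.math.log(tf.cast(m, logu.dtype))
  log_avg_u = tf.reduce_logsumexp(logu, axis=0) - log_m
  log_soo = []
  for i in range(m):
    loo = tf.concat([logu[:i], logu[i + 1:]], axis=0)  # leave-i-out log(u)
    geo = tf.reduce_mean(loo, axis=0)                  # log of geometric mean
    swapped = tf.concat([loo, geo[tf.newaxis]], axis=0)
    log_soo.append(tf.reduce_logsumexp(swapped, axis=0) - log_m)
  log_sooavg_u = tf.stack(log_soo, axis=0)
  return log_avg_u, log_sooavg_u

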
def get_vimco_local_learning_signal(elbo_tensor):
    """Get vimco local learning signal from batched ELBO.

  Args:
    elbo_tensor: a `float` Tensor of the shape [num_samples, batch_size].

  Returns:
    local_learning_signal: a `float` Tensor of the same shape as `elbo_tensor`,
      containing the multiplicative factor described in Algorithm 1 of VIMCO,
      L_hat - L_hat^[-i].
  """
  assert_op = tf.debugging.assert_rank_at_least(
      elbo_tensor,
      rank=2,
      message='ELBO needs at least 2D, [sample, batch].')
  with tf.control_dependencies([assert_op]):
    # Calculate the log swap-one-out mean and log mean
    # log_soomean_f is of the same shape as f: [num_samples, batch]
    # log_mean_f is of the reduced shape: [1, batch]
    log_soomean_f, log_mean_f = log_soomean_exp(
        elbo_tensor, axis=0, keepdims=True)
    local_learning_signal = log_mean_f - log_soomean_f
    return local_learning_signal
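

# Minimal usage sketch (illustrative, assumes eager execution): the learning
# signal has the same [num_samples, batch_size] shape as the ELBO estimates
# that produced it. Values here are arbitrary; the shapes are what matter.
def _demo_local_learning_signal():
  elbo = tf.random.normal([5, 3])                  # 5 samples, batch of 3
  signal = get_vimco_local_learning_signal(elbo)   # shape [5, 3]
  return signal

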
def csiszar_vimco(f,
                  p_log_prob,
                  q,
                  num_draws,
                  num_batch_draws=1,
                  seed=None,
                  name=None):
  """Use VIMCO to lower the variance of gradient[csiszar_function(log(Avg(u))].

  This function generalizes VIMCO [(Mnih and Rezende, 2016)][1] to Csiszar
  f-Divergences.

  Note: if `q.reparameterization_type = tfd.FULLY_REPARAMETERIZED`,
  consider using `monte_carlo_variational_loss`.

  The VIMCO loss is:

  ```none
  vimco = f(log(Avg{u[i] : i=0,...,m-1}))
  where,
    logu[i] = log( p(x, h[i]) / q(h[i] | x) )
    h[i] iid~ q(H | x)
  ```

  Interestingly, the VIMCO gradient is not the naive gradient of `vimco`.
  Rather, it is characterized by:

  ```none
  grad[vimco] - variance_reducing_term
  where,
    variance_reducing_term = Sum{ grad[log q(h[i] | x)] *
                                    (vimco - f(log Avg{h[j;i] : j=0,...,m-1}))
                                 : i=0, ..., m-1 }
    h[j;i] = { u[j]                             j!=i
             { GeometricAverage{ u[k] : k!=i}   j==i
  ```

  (We omitted `stop_gradient` for brevity. See implementation for more details.)

  The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the `i`-th
  element has been replaced by the leave-`i`-out Geometric-average.

  This implementation prefers numerical precision over efficiency, i.e.,
  `O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`.
  (The constant may be fairly large, perhaps around 12.)

  Args:
    f: Python `callable` representing a Csiszar-function in log-space.
    p_log_prob: Python `callable` representing the natural-log of the
      probability under distribution `p`. (In variational inference `p` is the
      joint distribution.)
    q: `tf.Distribution`-like instance; must implement: `sample(n, seed)`, and
      `log_prob(x)`. (In variational inference `q` is the approximate posterior
      distribution.)
    num_draws: Integer scalar number of draws used to approximate the
      f-Divergence expectation.
    num_batch_draws: Integer scalar number of independent batch replicas of the
      `num_draws`-sample estimator; the objective is averaged over these
      batch draws.
    seed: Python `int` seed for `q.sample`.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    vimco: The Csiszar f-Divergence generalized VIMCO objective.

  Raises:
    ValueError: if `num_draws < 2`.

  #### References

  [1]: Andriy Mnih and Danilo Rezende. Variational Inference for Monte Carlo
       objectives. In _International Conference on Machine Learning_, 2016.
       https://arxiv.org/abs/1602.06725
  """
  with tf.name_scope(name or 'csiszar_vimco'):
    if num_draws < 2:
      raise ValueError('Must specify num_draws > 1.')
    stop = tf.stop_gradient  # For readability.

    q_sample = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
    x = tf.nest.map_structure(stop, q_sample)
    logqx = q.log_prob(x)
    logu = nest_util.call_fn(p_log_prob, x) - logqx
    f_log_sooavg_u, f_log_avg_u = map(f, log_soomean_exp(logu, axis=0))

    dotprod = tf.reduce_sum(
        logqx * stop(f_log_avg_u - f_log_sooavg_u),
        axis=0)  # Sum over iid samples.
    # We now rewrite f_log_avg_u so that:
    #   `grad[f_log_avg_u] := grad[f_log_avg_u + dotprod]`.
    # To achieve this, we use a trick that
    #   `f(x) - stop(f(x)) == zeros_like(f(x))`
    # but its gradient is grad[f(x)].
    # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence
    # this trick loses no precision. For more discussion regarding the relevant
    # portions of the IEEE754 standard, see the StackOverflow question,
    # "Is there a floating point value of x, for which x-x == 0 is false?"
    # http://stackoverflow.com/q/2686644
    # The following is the same as adding zeros_like(dotprod).
    f_log_avg_u = f_log_avg_u + dotprod - stop(dotprod)
    return tf.reduce_mean(f_log_avg_u, axis=0)  # Avg over batches.
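

# Toy usage sketch (illustrative, not from the library): with f(logu) = -logu,
# `csiszar_vimco` yields a VIMCO-style negative multi-sample ELBO. The model
# `p` and "posterior" `q` below are stand-ins chosen only to exercise the API.
def _demo_csiszar_vimco():
  import tensorflow_probability as tfp  # Only needed for this demo.
  tfd = tfp.distributions
  p = tfd.Normal(loc=0., scale=1.)
  q = tfd.Normal(loc=tf.Variable(0.5), scale=1.)
  loss = csiszar_vimco(
      f=lambda logu: -logu,        # reverse-KL-style Csiszar function
      p_log_prob=p.log_prob,
      q=q,
      num_draws=10,
      seed=17)
  return loss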