Example #1
def monte_carlo_csiszar_f_divergence(
    f,
    p_log_prob,
    q,
    num_draws,
    use_reparametrization=None,
    seed=None,
    name=None):
  """Monte-Carlo approximation of the Csiszar f-Divergence.

  A Csiszar-function is a member of,

  ```none
  F = { f:R_+ to R : f convex }.
  ```

  The Csiszar f-Divergence for Csiszar-function f is given by:

  ```none
  D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
                  ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
                             where x_j ~iid q(X)
  ```
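
  For concreteness, here is a minimal NumPy sketch of this estimator (purely
  illustrative; `p_pdf` and `q_pdf` are hypothetical scalar densities, not
  part of this module):

  ```python
  import numpy as np

  rng = np.random.RandomState(42)
  q_pdf = lambda x: np.exp(-0.5 * x**2) / np.sqrt(2. * np.pi)         # N(0, 1)
  p_pdf = lambda x: np.exp(-0.5 * (x - 1.)**2) / np.sqrt(2. * np.pi)  # N(1, 1)
  f = lambda u: -np.log(u)               # reverse-KL Csiszar-function
  x = rng.randn(10000)                   # x_j ~iid q
  d_f = np.mean(f(p_pdf(x) / q_pdf(x)))  # ~= KL[q || p] = 0.5
  ```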

  Tricks: Reparameterization and Score-Gradient

  When q is "reparameterized", i.e., a diffeomorphic transformation of a
  parameterless distribution (e.g.,
  `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
  expectation, i.e.,
  `grad[ Avg{ s_i : i=1...n } ] = Avg{ grad[s_i] : i=1...n }`, where
  `s_i = f(x_i)` and `x_i ~iid q(X)`.
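
  A hedged sketch of this trick with standard TF2 ops (illustrative only; the
  variable names are not part of this module):

  ```python
  import tensorflow as tf

  m = tf.Variable(0.5)
  s = tf.Variable(1.5)
  with tf.GradientTape() as tape:
    x = tf.random.normal([1000])        # parameterless base distribution
    y = s * x + m                       # reparameterized sample: Y ~ Normal(m, s)
    avg = tf.reduce_mean(tf.square(y))  # Avg{ s_i } with f(x) = x**2
  # Because sampling is a differentiable transform of `x`, gradients flow
  # through the samples to `m` and `s`; here they approximate [2m, 2s].
  grads = tape.gradient(avg, [m, s])
  ```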

  However, if q is not reparameterized, TensorFlow's gradient will be incorrect
  since the chain-rule stops at samples of unreparameterized distributions. In
  this circumstance using the Score-Gradient trick results in an unbiased
  gradient, i.e.,

  ```none
  grad[ E_q[f(X)] ]
  = grad[ int dx q(x) f(x) ]
  = int dx grad[ q(x) f(x) ]
  = int dx [ q'(x) f(x) + q(x) f'(x) ]
  = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
  = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
  = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]
  ```
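
  The last line can be written directly with `tf.stop_gradient`; a minimal
  sketch (illustrative, not this module's implementation):

  ```python
  import tensorflow as tf

  def score_gradient_surrogate(f, log_prob, x):
    lp = log_prob(x)
    # exp(lp - stop_grad(lp)) == q(x) / stop_grad[q(x)]: it equals 1 in
    # value, but contributes f(x) * grad[log q(x)] to the gradient.
    return f(x) * tf.exp(lp - tf.stop_gradient(lp))
  ```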

  When `q.reparameterization_type == tfd.FULLY_REPARAMETERIZED`, it is
  usually preferable to set `use_reparametrization = True`.

  Example Application:

  The Csiszar f-Divergence is a useful framework for variational inference.
  I.e., observe that,

  ```none
  f(p(x)) =  f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
          <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
          := D_f[p(x, Z), q(Z | x)]
  ```

  The inequality follows from the fact that the "perspective" of `f`, i.e.,
  `(s, t) |-> t f(s / t)`, is convex in `(s, t)` when `s/t in domain(f)` and
  `t > 0`. Since the above framework includes the popular Evidence Lower
  Bound (ELBO) as a special case, i.e., `f(u) = -log(u)`, we call this
  framework "Evidence Divergence Bound Optimization" (EDBO).

  Args:
    f: Python `callable` representing a Csiszar-function in log-space, i.e.,
      takes `p_log_prob(q_samples) - q.log_prob(q_samples)`.
    p_log_prob: Python `callable` taking (a batch of) samples from `q` and
      returning the natural-log of the probability under distribution `p`.
      (In variational inference `p` is the joint distribution.)
    q: `tf.Distribution`-like instance; must implement:
      `reparameterization_type`, `sample(n, seed)`, and `log_prob(x)`.
      (In variational inference `q` is the approximate posterior distribution.)
    num_draws: Integer scalar number of draws used to approximate the
      f-Divergence expectation.
    use_reparametrization: Python `bool`. When `None` (the default),
      automatically set to:
      `q.reparameterization_type == tfd.FULLY_REPARAMETERIZED`.
      When `True` uses the standard Monte-Carlo average. When `False` uses the
      score-gradient trick. (See above for details.)  When `False`, consider
      using `csiszar_vimco`.
    seed: Python `int` seed for `q.sample`.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    monte_carlo_csiszar_f_divergence: `float`-like `Tensor` Monte Carlo
      approximation of the Csiszar f-Divergence.

  Raises:
    ValueError: if `q` is not a reparameterized distribution and
      `use_reparametrization = True`. A distribution `q` is said to be
      "reparameterized" when its samples are generated by transforming the
      samples of another distribution which does not depend on the
      parameterization of `q`. This property ensures the gradient (with respect
      to parameters) is valid.
    TypeError: if `p_log_prob` is not a Python `callable`.
  """
  with tf.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]):
    if use_reparametrization is None:
      use_reparametrization = (q.reparameterization_type
                               == tfd.FULLY_REPARAMETERIZED)
    elif (use_reparametrization and
          q.reparameterization_type != tfd.FULLY_REPARAMETERIZED):
      # TODO(jvdillon): Consider only raising an exception if the gradient is
      # requested.
      raise ValueError(
          "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
          "transformation of a parameterless distribution. (Otherwise this "
          "function has a biased gradient.)")
    if not callable(p_log_prob):
      raise TypeError("`p_log_prob` must be a Python `callable` function.")
    return monte_carlo.expectation(
        f=lambda q_samples: f(p_log_prob(q_samples) - q.log_prob(q_samples)),
        samples=q.sample(num_draws, seed=seed),
        log_prob=q.log_prob,  # Only used if use_reparametrization=False.
        use_reparametrization=use_reparametrization)
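A hedged usage sketch for the function above (the distributions are
illustrative stand-ins, following the `tf`/`tfd` aliases this example
assumes):

```python
p = tfd.Normal(loc=0., scale=1.)               # target
q = tfd.Normal(loc=tf.Variable(1.), scale=1.)  # variational family
loss = monte_carlo_csiszar_f_divergence(
    f=lambda logu: -logu,  # reverse KL in log-space, i.e., the ELBO loss
    p_log_prob=p.log_prob,
    q=q,
    num_draws=1000,
    seed=42)
# Minimizing `loss` w.r.t. q's variables performs variational inference.
```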
Example #2
def monte_carlo_variational_loss(target_log_prob_fn,
                                 surrogate_posterior,
                                 sample_size=1,
                                 importance_sample_size=1,
                                 discrepancy_fn=kl_reverse,
                                 use_reparameterization=None,
                                 seed=None,
                                 name=None):
    """Monte-Carlo approximation of an f-Divergence variational loss.

  Variational losses measure the divergence between an unnormalized target
  distribution `p` (provided via `target_log_prob_fn`) and a surrogate
  distribution `q` (provided as `surrogate_posterior`). When the
  target distribution is an unnormalized posterior from conditioning a model on
  data, minimizing the loss with respect to the parameters of
  `surrogate_posterior` performs approximate posterior inference.

  This function defines losses of the form
  `E_q[discrepancy_fn(log(u))]`, where `u = p(z) / q(z)` in the (default) case
  where `importance_sample_size == 1`, and
  `u = mean([p(z[k]) / q(z[k]) for k in range(importance_sample_size)])` more
  generally. These losses are sometimes known as f-divergences [1, 2].

  The default behavior (`discrepancy_fn == tfp.vi.kl_reverse`, where
  `tfp.vi.kl_reverse = lambda logu: -logu`, and
  `importance_sample_size == 1`) computes an unbiased estimate of the standard
  evidence lower bound (ELBO) [3]. The bound may be tightened by setting
  `importance_sample_size > 1` [4], and the variance of the estimate reduced by
  setting `sample_size > 1`. Other discrepancies of interest
  available under `tfp.vi` include the forward `KL[p||q]`, total variation
  distance, Amari alpha-divergences, and [more](
  https://en.wikipedia.org/wiki/F-divergence).
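
  In log-space, the importance-weighted ratio `u` above is typically computed
  with a `logsumexp` for numerical stability; a hedged sketch (`log_u` is a
  hypothetical helper, not part of this function):

  ```python
  import tensorflow as tf

  def log_u(log_p, log_q, importance_sample_size):
    log_ratios = log_p - log_q  # shape [importance_sample_size, ...]
    return tf.reduce_logsumexp(log_ratios, axis=0) - tf.math.log(
        tf.cast(importance_sample_size, log_ratios.dtype))
  ```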

  Args:
    target_log_prob_fn: Python callable that takes a set of `Tensor` arguments
      and returns a `Tensor` log-density. Given
      `q_sample = surrogate_posterior.sample(sample_size)`, this
      will be called as `target_log_prob_fn(*q_sample)` if `q_sample` is a list
      or a tuple, `target_log_prob_fn(**q_sample)` if `q_sample` is a
      dictionary, or `target_log_prob_fn(q_sample)` if `q_sample` is a `Tensor`.
      It should support batched evaluation, i.e., should return a result of
      shape `[sample_size]`.
    surrogate_posterior: A `tfp.distributions.Distribution`
      instance defining a variational posterior (could be a
      `tfd.JointDistribution`). Crucially, the distribution's `log_prob` and
      (if reparameterizable) `sample` methods must directly invoke all ops
      that generate gradients to the underlying variables. One way to ensure
      this is to use `tfp.util.TransformedVariable` and/or
      `tfp.util.DeferredTensor` to represent any parameters defined as
      transformations of unconstrained variables, so that the transformations
      execute at runtime instead of at distribution creation.
    sample_size: Integer scalar number of Monte Carlo samples used to
      approximate the variational divergence. Larger values may stabilize
      the optimization, but at higher cost per step in time and memory.
      Default value: `1`.
    importance_sample_size: Python `int` number of terms used to define an
      importance-weighted divergence. If `importance_sample_size > 1`, then the
      `surrogate_posterior` is optimized to function as an importance-sampling
      proposal distribution. In this case it often makes sense to use
      importance sampling to approximate posterior expectations (see
      `tfp.vi.fit_surrogate_posterior` for an example).
      Default value: `1`.
    discrepancy_fn: Python `callable` representing a Csiszar `f` function in
      in log-space. That is, `discrepancy_fn(log(u)) = f(u)`, where `f` is
      convex in `u`.
      Default value: `tfp.vi.kl_reverse`.
    use_reparameterization: Python `bool`. When `None` (the default),
      automatically set to:
      `surrogate_posterior.reparameterization_type ==
      tfd.FULLY_REPARAMETERIZED`. When `True` uses the standard Monte-Carlo
      average. When `False` uses the score-gradient trick. (See the Tricks
      section below for details.) When `False`, consider using `csiszar_vimco`.
    seed: PRNG seed for `surrogate_posterior.sample`; see
      `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    monte_carlo_variational_loss: `float`-like `Tensor` Monte Carlo
      approximation of the Csiszar f-Divergence.

  Raises:
    ValueError: if `surrogate_posterior` is not a reparameterized
      distribution and `use_reparameterization = True`. A distribution is said
      to be "reparameterized" when its samples are generated by transforming the
      samples of another distribution that does not depend on the first
      distribution's parameters. This property ensures the gradient with respect
      to parameters is valid.
    TypeError: if `target_log_prob_fn` is not a Python `callable`.

  #### Csiszar f-divergences

  A Csiszar function `f` is a convex function from `R^+` (the positive reals)
  to `R`. The Csiszar f-Divergence is given by:

  ```none
  D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
                  ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
                             where x_j ~iid q(X)
  ```

  For example, `f = lambda u: -log(u)` recovers `KL[q||p]`, while `f =
  lambda u: u * log(u)` recovers the forward `KL[p||q]`. These and other
  functions are available in `tfp.vi`.
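
  In the log-space form that `discrepancy_fn` actually receives, these
  correspond to (a sketch; the `tfp.vi` versions are numerically careful):

  ```python
  import tensorflow as tf
  kl_reverse = lambda logu: -logu                # f(u) = -log(u)
  kl_forward = lambda logu: tf.exp(logu) * logu  # f(u) = u * log(u)
  ```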

  #### Tricks: Reparameterization and Score-Gradient

  When q is "reparameterized", i.e., a diffeomorphic transformation of a
  parameterless distribution (e.g.,
  `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
  expectation, i.e.,
  `grad[ Avg{ s_i : i=1...n } ] = Avg{ grad[s_i] : i=1...n }`, where
  `s_i = f(x_i)` and `x_i ~iid q(X)`.

  However, if q is not reparameterized, TensorFlow's gradient will be incorrect
  since the chain-rule stops at samples of unreparameterized distributions. In
  this circumstance using the Score-Gradient trick results in an unbiased
  gradient, i.e.,

  ```none
  grad[ E_q[f(X)] ]
  = grad[ int dx q(x) f(x) ]
  = int dx grad[ q(x) f(x) ]
  = int dx [ q'(x) f(x) + q(x) f'(x) ]
  = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
  = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
  = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]
  ```

  When `q.reparameterization_type == tfd.FULLY_REPARAMETERIZED`, it is
  usually preferable to set `use_reparameterization = True`.

  #### Example Application:

  The Csiszar f-Divergence is a useful framework for variational inference.
  I.e., observe that,

  ```none
  f(p(x)) =  f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
          <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
          := D_f[p(x, Z), q(Z | x)]
  ```

  The inequality follows from the fact that the "perspective" of `f`, i.e.,
  `(s, t) |-> t f(s / t)`, is convex in `(s, t)` when `s/t in domain(f)` and
  `t > 0`. Since the above framework includes the popular Evidence Lower
  Bound (ELBO) as a special case, i.e., `f(u) = -log(u)`, we call this
  framework "Evidence Divergence Bound Optimization" (EDBO).

  #### References:

  [1]: https://en.wikipedia.org/wiki/F-divergence

  [2]: Ali, Syed Mumtaz, and Samuel D. Silvey. "A general class of coefficients
       of divergence of one distribution from another." Journal of the Royal
       Statistical Society: Series B (Methodological) 28.1 (1966): 131-142.

  [3]: Christopher M. Bishop. Pattern Recognition and Machine Learning.
       Springer, 2006.

  [4]: Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance Weighted
       Autoencoders. In _International Conference on Learning
       Representations_, 2016. https://arxiv.org/abs/1509.00519

  """
    with tf.name_scope(name or 'monte_carlo_variational_loss'):
        reparameterization_types = tf.nest.flatten(
            surrogate_posterior.reparameterization_type)
        if use_reparameterization is None:
            use_reparameterization = all(
                reparameterization_type == FULLY_REPARAMETERIZED
                for reparameterization_type in reparameterization_types)
        elif (use_reparameterization and any(
                reparameterization_type != FULLY_REPARAMETERIZED
                for reparameterization_type in reparameterization_types)):
            # TODO(jvdillon): Consider only raising an exception if the gradient is
            # requested.
            raise ValueError(
                'Distribution `surrogate_posterior` must be reparameterized, i.e., '
                'a diffeomorphic transformation of a parameterless distribution. '
                '(Otherwise this function has a biased gradient.)')
        if not callable(target_log_prob_fn):
            raise TypeError('`target_log_prob_fn` must be a Python `callable` '
                            'function.')

        if use_reparameterization:
            # Attempt to avoid bijector inverses by computing the surrogate log prob
            # during the forward sampling pass.
            q_samples, q_lp = surrogate_posterior.experimental_sample_and_log_prob(
                [sample_size * importance_sample_size], seed=seed)
        else:
            # Score fn objective requires explicit gradients of `log_prob`.
            q_samples = surrogate_posterior.sample(
                [sample_size * importance_sample_size], seed=seed)
            q_lp = None

        return monte_carlo.expectation(
            f=_make_importance_weighted_divergence_fn(
                target_log_prob_fn,
                surrogate_posterior=surrogate_posterior,
                discrepancy_fn=discrepancy_fn,
                precomputed_surrogate_log_prob=q_lp,
                importance_sample_size=importance_sample_size),
            samples=q_samples,
            # Log-prob is only used if use_reparameterization=False.
            log_prob=surrogate_posterior.log_prob,
            use_reparameterization=use_reparameterization)
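A hedged usage sketch for the loss above (the target, surrogate, and training
loop are illustrative placeholders, assuming `tfp` is TensorFlow Probability):

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.Normal(loc=1., scale=2.)  # stand-in (unnormalized) target
surrogate = tfd.Normal(
    loc=tf.Variable(0.),
    scale=tfp.util.TransformedVariable(1., bijector=tfp.bijectors.Softplus()))
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)

for step in range(200):
  with tf.GradientTape() as tape:
    loss = monte_carlo_variational_loss(
        target_log_prob_fn=target.log_prob,
        surrogate_posterior=surrogate,
        sample_size=10,
        seed=(step, 0))
  grads = tape.gradient(loss, surrogate.trainable_variables)
  optimizer.apply_gradients(zip(grads, surrogate.trainable_variables))
```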
Example #4
def monte_carlo_variational_loss(target_log_prob_fn,
                                 surrogate_posterior,
                                 sample_size=1,
                                 importance_sample_size=1,
                                 discrepancy_fn=kl_reverse,
                                 use_reparameterization=None,
                                 gradient_estimator=None,
                                 stopped_surrogate_posterior=None,
                                 seed=None,
                                 name=None):
    """Monte-Carlo approximation of an f-Divergence variational loss.

  Variational losses measure the divergence between an unnormalized target
  distribution `p` (provided via `target_log_prob_fn`) and a surrogate
  distribution `q` (provided as `surrogate_posterior`). When the
  target distribution is an unnormalized posterior from conditioning a model on
  data, minimizing the loss with respect to the parameters of
  `surrogate_posterior` performs approximate posterior inference.

  This function defines losses of the form
  `E_q[discrepancy_fn(log(u))]`, where `u = p(z) / q(z)` in the (default) case
  where `importance_sample_size == 1`, and
  `u = mean([p(z[k]) / q(z[k]) for k in range(importance_sample_size)])` more
  generally. These losses are sometimes known as f-divergences [1, 2].

  The default behavior (`discrepancy_fn == tfp.vi.kl_reverse`, where
  `tfp.vi.kl_reverse = lambda logu: -logu`, and
  `importance_sample_size == 1`) computes an unbiased estimate of the standard
  evidence lower bound (ELBO) [3]. The bound may be tightened by setting
  `importance_sample_size > 1` [4], and the variance of the estimate reduced by
  setting `sample_size > 1`. Other discrepancies of interest
  available under `tfp.vi` include the forward `KL[p||q]`, total variation
  distance, Amari alpha-divergences, and [more](
  https://en.wikipedia.org/wiki/F-divergence).

  Args:
    target_log_prob_fn: Python callable that takes a set of `Tensor` arguments
      and returns a `Tensor` log-density. Given
      `q_sample = surrogate_posterior.sample(sample_size)`, this
      will be called as `target_log_prob_fn(*q_sample)` if `q_sample` is a list
      or a tuple, `target_log_prob_fn(**q_sample)` if `q_sample` is a
      dictionary, or `target_log_prob_fn(q_sample)` if `q_sample` is a `Tensor`.
      It should support batched evaluation, i.e., should return a result of
      shape `[sample_size]`.
    surrogate_posterior: A `tfp.distributions.Distribution`
      instance defining a variational posterior (could be a
      `tfd.JointDistribution`). If using `tf.Variable` parameters, the
      distribution's `log_prob` and (if reparameterizable) `sample` methods
      must directly invoke all ops that generate gradients to the underlying
      variables. One way to ensure this is to use `tfp.util.TransformedVariable`
      and/or `tfp.util.DeferredTensor` to represent any parameters defined as
      transformations of unconstrained variables, so that the transformations
      execute at runtime instead of at distribution creation.
    sample_size: Integer scalar number of Monte Carlo samples used to
      approximate the variational divergence. Larger values may stabilize
      the optimization, but at higher cost per step in time and memory.
      Default value: `1`.
    importance_sample_size: Python `int` number of terms used to define an
      importance-weighted divergence. If `importance_sample_size > 1`, then the
      `surrogate_posterior` is optimized to function as an importance-sampling
      proposal distribution. In this case it often makes sense to use
      importance sampling to approximate posterior expectations (see
      `tfp.vi.fit_surrogate_posterior` for an example).
      Default value: `1`.
    discrepancy_fn: Python `callable` representing a Csiszar `f` function in
      log-space. That is, `discrepancy_fn(log(u)) = f(u)`, where `f` is
      convex in `u`.
      Default value: `tfp.vi.kl_reverse`.
    use_reparameterization: Deprecated; use `gradient_estimator` instead.
    gradient_estimator: Optional element from `tfp.vi.GradientEstimators`
      specifying the stochastic gradient estimator to associate with the
      variational loss. If `None`, a default estimator (either score-function or
      reparameterization) is chosen based on
      `surrogate_posterior.reparameterization_type`.
      Default value: `None`.
    stopped_surrogate_posterior: Optional copy of `surrogate_posterior` with
      stopped gradients to the parameters, e.g.,
      `tfd.Normal(loc=tf.stop_gradient(loc), scale=tf.stop_gradient(scale))`.
      Required if and only if
      `gradient_estimator == tfp.vi.GradientEstimators.DOUBLY_REPARAMETERIZED`.
      Default value: `None`.
    seed: PRNG seed for `surrogate_posterior.sample`; see
      `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    monte_carlo_variational_loss: `float`-like `Tensor` Monte Carlo
      approximation of the Csiszar f-Divergence.

  Raises:
    ValueError: if `surrogate_posterior` is not a reparameterized
      distribution and `use_reparameterization = True`. A distribution is said
      to be "reparameterized" when its samples are generated by transforming the
      samples of another distribution that does not depend on the first
      distribution's parameters. This property ensures the gradient with respect
      to parameters is valid.
    TypeError: if `target_log_prob_fn` is not a Python `callable`.

  #### Csiszar f-divergences

  A Csiszar function `f` is a convex function from `R^+` (the positive reals)
  to `R`. The Csiszar f-Divergence is given by:

  ```none
  D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
                  ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
                             where x_j ~iid q(X)
  ```

  For example, `f = lambda u: -log(u)` recovers `KL[q||p]`, while `f =
  lambda u: u * log(u)` recovers the forward `KL[p||q]`. These and other
  functions are available in `tfp.vi`.

  #### Example Application:

  The Csiszar f-Divergence is a useful framework for variational inference.
  I.e., observe that,

  ```none
  f(p(x)) =  f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
          <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
          := D_f[p(x, Z), q(Z | x)]
  ```

  The inequality follows from the fact that the "perspective" of `f`, i.e.,
  `(s, t) |-> t f(s / t)`, is convex in `(s, t)` when `s/t in domain(f)` and
  `t > 0`. Since the above framework includes the popular Evidence Lower
  Bound (ELBO) as a special case, i.e., `f(u) = -log(u)`, we call this
  framework "Evidence Divergence Bound Optimization" (EDBO).

  #### References:

  [1]: https://en.wikipedia.org/wiki/F-divergence

  [2]: Ali, Syed Mumtaz, and Samuel D. Silvey. "A general class of coefficients
       of divergence of one distribution from another." Journal of the Royal
       Statistical Society: Series B (Methodological) 28.1 (1966): 131-142.

  [3]: Christopher M. Bishop. Pattern Recognition and Machine Learning.
       Springer, 2006.

  [4]: Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance Weighted
       Autoencoders. In _International Conference on Learning
       Representations_, 2016. https://arxiv.org/abs/1509.00519

  """
    with tf.name_scope(name or 'monte_carlo_variational_loss'):
        if not callable(target_log_prob_fn):
            raise TypeError('`target_log_prob_fn` must be a Python `callable` '
                            'function.')

        reparameterization_types = tf.nest.flatten(
            surrogate_posterior.reparameterization_type)
        if gradient_estimator is None:
            gradient_estimator = _choose_gradient_estimator(
                use_reparameterization=use_reparameterization,
                reparameterization_types=reparameterization_types)

        if gradient_estimator == GradientEstimators.VIMCO:
            return csiszar_vimco(f=discrepancy_fn,
                                 p_log_prob=target_log_prob_fn,
                                 q=surrogate_posterior,
                                 num_draws=importance_sample_size,
                                 num_batch_draws=sample_size,
                                 seed=seed)
        if gradient_estimator == GradientEstimators.SCORE_FUNCTION:
            if tf.get_static_value(importance_sample_size) != 1:
                # TODO(b/213378570): Support score function gradients for
                # importance-weighted bounds.
                raise ValueError(
                    'Score-function gradients are not supported for '
                    'losses with `importance_sample_size != 1`.')
            # Score fn objective requires explicit gradients of `log_prob`.
            q_samples = surrogate_posterior.sample(
                [sample_size * importance_sample_size], seed=seed)
            q_lp = None
        else:
            if any(reparameterization_type != FULLY_REPARAMETERIZED
                   for reparameterization_type in reparameterization_types):
                warnings.warn(
                    'Reparameterization gradients requested, but '
                    '`surrogate_posterior.reparameterization_type` is not fully '
                    'reparameterized (saw: {}). Gradient estimates may be '
                    'biased.'.format(
                        surrogate_posterior.reparameterization_type))
            # Attempt to avoid bijector inverses by computing the surrogate log prob
            # during the forward sampling pass.
            q_samples, q_lp = surrogate_posterior.experimental_sample_and_log_prob(
                [sample_size * importance_sample_size], seed=seed)

        return monte_carlo.expectation(
            f=_make_importance_weighted_divergence_fn(
                target_log_prob_fn,
                surrogate_posterior=surrogate_posterior,
                discrepancy_fn=discrepancy_fn,
                precomputed_surrogate_log_prob=q_lp,
                importance_sample_size=importance_sample_size,
                gradient_estimator=gradient_estimator,
                stopped_surrogate_posterior=stopped_surrogate_posterior),
            samples=q_samples,
            # Log-prob is only used if `gradient_estimator == SCORE_FUNCTION`.
            log_prob=surrogate_posterior.log_prob,
            use_reparameterization=(gradient_estimator !=
                                    GradientEstimators.SCORE_FUNCTION))
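A hedged sketch of the doubly-reparameterized path, showing how
`stopped_surrogate_posterior` pairs with `gradient_estimator` (the surrogate
construction is illustrative; availability of `tfp.vi.GradientEstimators`
depends on the TFP version):

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

loc = tf.Variable(0.)
scale = tf.Variable(1.)
surrogate = tfd.Normal(loc=loc, scale=scale)
# The same distribution, but with gradients to the parameters stopped:
stopped = tfd.Normal(loc=tf.stop_gradient(loc),
                     scale=tf.stop_gradient(scale))

loss = monte_carlo_variational_loss(
    target_log_prob_fn=tfd.Normal(1., 2.).log_prob,
    surrogate_posterior=surrogate,
    importance_sample_size=5,
    gradient_estimator=tfp.vi.GradientEstimators.DOUBLY_REPARAMETERIZED,
    stopped_surrogate_posterior=stopped,
    seed=(0, 0))
```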