def _as_trainable_family(distribution):
  """Substitutes prior distributions with more easily trainable ones.

  Args:
    distribution: A distribution instance (typically a prior).

  Returns:
    A distribution that can stand in for `distribution`:
      * `HalfNormal(scale)` becomes a `TruncatedNormal` with `loc=0.`, the same
        `scale`, and support `[0, 10 * scale]`.
      * `Uniform(low, high)` becomes `Beta(1, 1)` scaled by `high - low` and
        shifted by `low`. `Beta(1, 1)` is uniform on `[0, 1]`, so the
        substitute starts out matching the original density.
      * Any other distribution is returned unchanged.
  """
  with tf.name_scope('as_trainable_family'):

    if isinstance(distribution, half_normal.HalfNormal):
      return truncated_normal.TruncatedNormal(
          loc=0.,
          scale=distribution.scale,
          low=0.,
          high=distribution.scale * 10.)
    elif isinstance(distribution, uniform.Uniform):
      # Size the concentration from `mean()`, which broadcasts `low` and `high`
      # and so carries the full batch shape. `event_shape_tensor()` is scalar
      # for a Uniform, which would make all batch members share one
      # concentration. This matches the Uniform substitution registered via
      # `register_asvi_substitution_rule`.
      return shift.Shift(distribution.low)(
          scale_lib.Scale(distribution.high - distribution.low)(beta.Beta(
              concentration0=tf.ones_like(distribution.mean()),
              concentration1=1.)))
    else:
      return distribution
# Example #2 (a separate snippet; its opening lines are truncated).
  a specific `name`, if we had reason to think that a Normal distribution would
  be a good surrogate for some model variables but not others.

  """
    global ASVI_SURROGATE_SUBSTITUTIONS
    if inspect.isclass(condition):
        condition = lambda distribution, cls=condition: isinstance(  # pylint: disable=g-long-lambda
            distribution, cls)
    ASVI_SURROGATE_SUBSTITUTIONS[condition] = substitution_fn


# Default substitutions attempt to express distributions using the most
# flexible available parameterization. Each helper below maps an instance of
# a prior family onto an equivalent (or approximating) distribution in a
# richer family. They are registered in the order HalfNormal, Uniform,
# Exponential, Chi2.


def _asvi_half_normal_to_truncated_normal(dist):
  """Returns a zero-located TruncatedNormal on `[0, 10 * dist.scale]`."""
  return truncated_normal.TruncatedNormal(
      loc=0., scale=dist.scale, low=0., high=dist.scale * 10.)


def _asvi_uniform_to_scaled_beta(dist):
  """Returns `Beta(1, 1)` mapped affinely onto `[dist.low, dist.high]`."""
  # Build the pieces in the same order as the nested expression
  # Shift(low)(Scale(high - low)(Beta(...))).
  shift_bijector = shift.Shift(dist.low)
  scale_bijector = scale_lib.Scale(dist.high - dist.low)
  # `ones_like(mean())` takes its shape and dtype from the distribution's mean.
  unit_beta = beta.Beta(
      concentration0=tf.ones_like(dist.mean()), concentration1=1.)
  return shift_bijector(scale_bijector(unit_beta))


def _asvi_exponential_to_gamma(dist):
  """Returns `Gamma(1, dist.rate)`, which equals `Exponential(dist.rate)`."""
  return gamma.Gamma(concentration=1., rate=dist.rate)


def _asvi_chi2_to_gamma(dist):
  """Returns `Gamma(dist.df / 2, 1/2)`, which equals `Chi2(dist.df)`."""
  return gamma.Gamma(concentration=0.5 * dist.df, rate=0.5)


register_asvi_substitution_rule(
    half_normal.HalfNormal, _asvi_half_normal_to_truncated_normal)
register_asvi_substitution_rule(uniform.Uniform, _asvi_uniform_to_scaled_beta)
register_asvi_substitution_rule(
    exponential.Exponential, _asvi_exponential_to_gamma)
register_asvi_substitution_rule(chi2.Chi2, _asvi_chi2_to_gamma)


# TODO(kateslin): Add support for models with prior+likelihood written as
# a single JointDistribution.