def _as_trainable_family(distribution):
  """Substitutes prior distributions with more easily trainable ones.

  Args:
    distribution: A `tfd.Distribution` instance to (possibly) replace.

  Returns:
    A distribution from a more flexible, trainable family that contains the
    input distribution as a special case, or `distribution` itself when no
    substitution is registered for its type.
  """
  with tf.name_scope('as_trainable_family'):
    if isinstance(distribution, half_normal.HalfNormal):
      # TruncatedNormal with the same scale generalizes HalfNormal
      # (loc=0, low=0). `high` is pushed 10 scales out, where HalfNormal
      # mass is negligible, so the initial surrogate closely matches.
      return truncated_normal.TruncatedNormal(
          loc=0.,
          scale=distribution.scale,
          low=0.,
          high=distribution.scale * 10.)
    elif isinstance(distribution, uniform.Uniform):
      # Express Uniform(low, high) as an affine transform of Beta(1, 1),
      # which is uniform on [0, 1] but has trainable concentrations.
      # Use `tf.ones_like(distribution.mean())` so the Beta carries the
      # full batch shape (and dtype) of the prior; the previous
      # `tf.ones(distribution.event_shape_tensor(), ...)` dropped batch
      # dimensions for batched Uniform priors. This also matches the
      # registered `uniform.Uniform` substitution rule in this file.
      return shift.Shift(distribution.low)(
          scale_lib.Scale(distribution.high - distribution.low)(
              beta.Beta(
                  concentration0=tf.ones_like(distribution.mean()),
                  concentration1=1.)))
    else:
      # No substitution registered for this family; use the prior as-is.
      return distribution
a specific `name`, if we had reason to think that a Normal distribution
  would be a good surrogate for some model variables but not others.
  """
  global ASVI_SURROGATE_SUBSTITUTIONS
  # A bare class is sugar for an `isinstance` check against that class.
  if inspect.isclass(condition):
    condition = lambda distribution, cls=condition: isinstance(  # pylint: disable=g-long-lambda
        distribution, cls)
  ASVI_SURROGATE_SUBSTITUTIONS[condition] = substitution_fn


# Default substitutions attempt to express distributions using the most
# flexible available parameterization.
# pylint: disable=g-long-lambda
# HalfNormal -> TruncatedNormal: same scale, trainable loc/bounds; `high` is
# placed 10 scales out, where the prior's mass is negligible.
register_asvi_substitution_rule(
    half_normal.HalfNormal,
    lambda dist: truncated_normal.TruncatedNormal(
        loc=0., scale=dist.scale, low=0., high=dist.scale * 10.))
# Uniform(low, high) -> affine transform of Beta(1, 1), which is uniform on
# [0, 1] but has trainable concentrations. `ones_like(dist.mean())` gives the
# concentration the prior's full batch shape and dtype.
register_asvi_substitution_rule(
    uniform.Uniform,
    lambda dist: shift.Shift(dist.low)(
        scale_lib.Scale(dist.high - dist.low)(
            beta.Beta(concentration0=tf.ones_like(dist.mean()),
                      concentration1=1.))))
# Exponential(rate) == Gamma(concentration=1, rate): Gamma adds a trainable
# concentration parameter.
register_asvi_substitution_rule(
    exponential.Exponential,
    lambda dist: gamma.Gamma(concentration=1., rate=dist.rate))
# Chi2(df) == Gamma(concentration=df/2, rate=1/2): Gamma has a trainable rate.
register_asvi_substitution_rule(
    chi2.Chi2,
    lambda dist: gamma.Gamma(concentration=0.5 * dist.df, rate=0.5))
# pylint: enable=g-long-lambda

# TODO(kateslin): Add support for models with prior+likelihood written as
# a single JointDistribution.