def experimental_default_event_space_bijector(self, *args, **kwargs):
  """Returns a bijector pulling unpinned values back to unconstrained reals.

  Args:
    *args: Optional positional pins, forwarded to `experimental_pin`.
    **kwargs: Optional keyword pins, forwarded to `experimental_pin`.

  Returns:
    When no pins are given, the default joint bijector for this distribution
    (the auto-batched variant if `use_vectorized_map` is set). Otherwise, the
    default event space bijector of the further-pinned distribution.
  """
  if not (args or kwargs):
    # Nothing further to pin: build the bijector for this object directly.
    if self.use_vectorized_map:
      return _DefaultJointBijectorAutoBatchedWithPins(self)
    return joint_distribution._DefaultJointBijector(self)  # pylint: disable=protected-access

  # Apply the extra pins first, then let the pinned result choose its bijector.
  further_pinned = self.experimental_pin(*args, **kwargs)
  return further_pinned.experimental_default_event_space_bijector()
def __init__(self, jd, **kwargs):
  """Wraps the default joint bijector of `jd` to accept batched arguments.

  Args:
    jd: The joint distribution whose default joint bijector is wrapped.
    **kwargs: Extra keyword arguments forwarded to `_DefaultJointBijector`.
  """
  # Snapshot the constructor arguments before any other local is created.
  parameters = dict(locals())
  self._jd = jd
  self._bijector_kwargs = kwargs
  self._joint_bijector = joint_distribution_lib._DefaultJointBijector(
      jd=self._jd, **self._bijector_kwargs)
  unbatched = self._joint_bijector

  # Mirror the static properties of the wrapped (non-batched) bijector.
  super(_DefaultJointBijectorAutoBatched, self).__init__(
      name=unbatched.name,
      parameters=parameters,
      validate_args=unbatched.validate_args,
      forward_min_event_ndims=unbatched.forward_min_event_ndims,
      inverse_min_event_ndims=unbatched.inverse_min_event_ndims)

  # Each transformation method of the non-batched bijector is wrapped so that
  # it can be applied to batched arguments.
  # pylint: disable=protected-access
  vectorize = self._vectorize_member_fn
  self._forward = vectorize(
      lambda bij, x: bij._forward(x),
      core_ndims=[unbatched.forward_min_event_ndims])
  self._inverse = vectorize(
      lambda bij, y: bij._inverse(y),
      core_ndims=[unbatched.inverse_min_event_ndims])
  self._forward_log_det_jacobian = vectorize(
      lambda bij, x: bij._forward_log_det_jacobian(  # pylint: disable=g-long-lambda
          x, event_ndims=bij.forward_min_event_ndims),
      core_ndims=[unbatched.forward_min_event_ndims])
  self._inverse_log_det_jacobian = vectorize(
      lambda bij, y: bij._inverse_log_det_jacobian(  # pylint: disable=g-long-lambda
          y, event_ndims=bij.inverse_min_event_ndims),
      core_ndims=[unbatched.inverse_min_event_ndims])

  # Shape/dtype members are delegated to the wrapped bijector as-is.
  delegated_members = (
      '_forward_event_shape',
      '_forward_event_shape_tensor',
      '_inverse_event_shape',
      '_inverse_event_shape_tensor',
      '_forward_dtype',
      '_inverse_dtype',
      'forward_event_ndims',
      'inverse_event_ndims',
  )
  for member in delegated_members:
    setattr(self, member, getattr(unbatched, member))
def make_distribution_bijector(distribution, name='make_distribution_bijector'):
  """Builds a bijector to approximately transform `N(0, 1)` into `distribution`.

  This represents a distribution as a bijector that
  transforms a (multivariate) standard normal distribution into the distribution
  of interest.

  Args:
    distribution: A `tfd.Distribution` instance; this may be a joint
      distribution.
    name: Python `str` name for ops created by this method.

  Returns:
    distribution_bijector: a `tfb.Bijector` instance such that
      `distribution_bijector(tfd.Normal(0., 1.))` is approximately equivalent
      to `distribution`.

  Raises:
    NotImplementedError: if `distribution` has no default event space bijector
      (e.g., it is discrete), or if no bijector can be constructed for it
      because it does not implement an invertible CDF.

  #### Examples

  This method may be used to convert structured variational distributions into
  MCMC preconditioners. Consider a model containing
  [funnel geometry](https://crackedbassoon.com/writing/funneling), which may
  be difficult for an MCMC algorithm to sample directly.

  ```python
  model_with_funnel = tfd.JointDistributionSequentialAutoBatched([
      tfd.Normal(loc=-1., scale=2., name='z'),
      lambda z: tfd.Normal(loc=[0., 0., 0.], scale=tf.exp(z), name='x'),
      lambda x: tfd.Poisson(log_rate=x, name='y')])

  pinned_model = tfp.experimental.distributions.JointDistributionPinned(
      model_with_funnel, y=[1, 3, 0])
  ```

  We can approximate the posterior in this model using a structured variational
  surrogate distribution, which will capture the funnel geometry, but cannot
  exactly represent the (non-Gaussian) posterior.

  ```python
  # Build and fit a structured surrogate posterior distribution.
  surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(
      pinned_model)
  _ = tfp.vi.fit_surrogate_posterior(pinned_model.unnormalized_log_prob,
                                     surrogate_posterior=surrogate_posterior,
                                     optimizer=tf.optimizers.Adam(0.01),
                                     num_steps=200)
  ```

  Creating a preconditioning bijector allows us to obtain higher-quality
  posterior samples, without any Gaussianity assumption, by using the surrogate
  to guide an MCMC sampler.

  ```python
  surrogate_posterior_bijector = (
      tfp.experimental.bijectors.make_distribution_bijector(
          surrogate_posterior))
  samples, _ = tfp.mcmc.sample_chain(
      kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
          tfp.mcmc.TransformedTransitionKernel(
              tfp.mcmc.NoUTurnSampler(pinned_model.unnormalized_log_prob,
                                      step_size=0.1),
              bijector=surrogate_posterior_bijector),
          num_adaptation_steps=80),
      current_state=surrogate_posterior.sample(),
      num_burnin_steps=100,
      trace_fn=lambda _0, _1: [],
      num_results=500)
  ```

  #### Mathematical details

  The bijectors returned by this method generally follow the following
  principles, although the specific bijectors returned may vary without notice.

  Normal distributions are reparameterized by a location-scale transform.

  ```python
  b = tfp.experimental.bijectors.make_distribution_bijector(
      tfd.Normal(loc=10., scale=5.))
  # ==> tfb.Shift(10.)(tfb.Scale(5.))

  b = tfp.experimental.bijectors.make_distribution_bijector(
      tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril))
  # ==> tfb.Shift(loc)(tfb.ScaleMatvecTriL(scale_tril))
  ```

  The distribution's `quantile` function is used, when available:

  ```python
  d = tfd.Cauchy(loc=loc, scale=scale)
  b = tfp.experimental.bijectors.make_distribution_bijector(d)
  # ==> tfb.Inline(forward_fn=d.quantile, inverse_fn=d.cdf)(tfb.NormalCDF())
  ```

  Otherwise, a quantile function is derived by inverting the CDF:

  ```python
  d = tfd.Gamma(concentration=alpha, rate=beta)
  b = tfp.experimental.bijectors.make_distribution_bijector(d)
  # ==> tfb.Invert(
  #  tfp.experimental.bijectors.ScalarFunctionWithInferredInverse(fn=d.cdf))(
  #    tfb.NormalCDF())
  ```

  Transformed distributions are represented by chaining the
  transforming bijector with a preconditioning bijector for the base
  distribution:

  ```python
  b = tfp.experimental.bijectors.make_distribution_bijector(
      tfb.Exp()(tfd.Normal(loc=10., scale=5.)))
  # ==> tfb.Exp()(tfb.Shift(10.)(tfb.Scale(5.)))
  ```

  Joint distributions are represented by a joint bijector, which converts each
  component distribution to a bijector with parameters conditioned
  on the previous variables in the model. The joint bijector's inputs and
  outputs follow the structure of the joint distribution.

  ```python
  jd = tfd.JointDistributionNamed(
      {'a': tfd.InverseGamma(concentration=2., scale=1.),
       'b': lambda a: tfd.Normal(loc=3., scale=tf.sqrt(a))})
  b = tfp.experimental.bijectors.make_distribution_bijector(jd)
  whitened_jd = tfb.Invert(b)(jd)
  x = whitened_jd.sample()
  # x <=> {'a': tfd.Normal(0., 1.).sample(), 'b': tfd.Normal(0., 1.).sample()}
  ```

  """

  with tf.name_scope(name):
    event_space_bijector = (
        distribution.experimental_default_event_space_bijector())
    if event_space_bijector is None:  # Fail if the distribution is discrete.
      raise NotImplementedError(
          'Cannot transform distribution {} to a standard normal '
          'distribution.'.format(distribution))

    # Recurse over joint distributions.
    if isinstance(distribution, joint_distribution.JointDistribution):
      return joint_distribution._DefaultJointBijector(  # pylint: disable=protected-access
          distribution, bijector_fn=make_distribution_bijector)

    # Recurse through transformed distributions.
    if isinstance(distribution,
                  transformed_distribution.TransformedDistribution):
      return distribution.bijector(
          make_distribution_bijector(distribution.distribution))

    # If we've annotated a specific bijector for this distribution, use that.
    # The registry is keyed by class, so walk the MRO: an exact match is found
    # first, and an instance of a *subclass* of a registered distribution
    # resolves to its nearest registered ancestor instead of raising
    # `KeyError` from a lookup by `type(distribution)`.
    if isinstance(distribution, tuple(preconditioning_bijector_fns)):
      for cls in type(distribution).__mro__:
        if cls in preconditioning_bijector_fns:
          return preconditioning_bijector_fns[cls](distribution)

    # Otherwise, if this distribution implements a CDF and inverse CDF, build
    # a bijector from those.
    implements_cdf = False
    implements_quantile = False
    input_spec = tf.zeros(shape=distribution.event_shape,
                          dtype=distribution.dtype)
    try:
      callable_util.get_output_spec(distribution.cdf, input_spec)
      implements_cdf = True
    except NotImplementedError:
      pass
    try:
      callable_util.get_output_spec(distribution.quantile, input_spec)
      implements_quantile = True
    except NotImplementedError:
      pass
    if implements_cdf and implements_quantile:
      # This path will only trigger for scalar distributions, since multivariate
      # distributions have non-invertible CDF and so cannot define a `quantile`.
      return tfb.Inline(
          forward_fn=distribution.quantile,
          inverse_fn=distribution.cdf,
          forward_min_event_ndims=ps.rank_from_shape(
              distribution.event_shape_tensor,
              distribution.event_shape))(tfb.NormalCDF())

    # If the events are scalar, try to invert the CDF numerically.
    if implements_cdf and tf.get_static_value(distribution.is_scalar_event()):
      return tfb.Invert(
          scalar_function_with_inferred_inverse
          .ScalarFunctionWithInferredInverse(
              distribution.cdf,
              domain_constraint_fn=(event_space_bijector)))(tfb.NormalCDF())

    raise NotImplementedError('Could not automatically construct a '
                              'bijector for distribution type '
                              '{}; it does not implement an invertible '
                              'CDF.'.format(distribution))
def build_and_invoke_pinned_bijector(pins, *args):
  """Builds a joint bijector for the given pins and applies `member_fn` to it.

  Args:
    pins: Mapping of pinned values, expanded as keyword arguments to
      `experimental_pin` on the underlying distribution.
    *args: Positional arguments handed to `member_fn` after the bijector.

  Returns:
    Whatever `member_fn` returns for the freshly built bijector and `args`.
  """
  # Re-pin the underlying (unpinned) distribution with these pins.
  pinned_jd = self._jd.distribution.experimental_pin(**pins)
  pinned_bijector = joint_distribution._DefaultJointBijector(  # pylint: disable=protected-access
      pinned_jd, **self._bijector_kwargs)
  return member_fn(pinned_bijector, *args)
def get_fixed_topology_joint_bijector(
    model: tfd.JointDistribution,
    topology_pins: tp.Optional[tp.Dict[str, TensorflowTreeTopology]] = None,
) -> tfb.Composition:
  """Builds the default joint bijector for `model` with pinned topologies.

  Every component bijector is produced by `get_fixed_topology_bijector`, with
  `topology_pins` bound to it as a keyword argument.

  Args:
    model: Joint distribution to build the joint bijector for.
    topology_pins: Mapping from `str` keys to `TensorflowTreeTopology` values,
      forwarded unchanged to `get_fixed_topology_bijector`. Callers are
      expected to supply it; the `None` default only keeps the parameter
      optional at the call site.
      NOTE(review): how `get_fixed_topology_bijector` treats `None` is not
      visible here -- confirm before relying on the default.

  Returns:
    The `_DefaultJointBijector` for `model` using the bound bijector function.
  """
  bijector_fn = partial(get_fixed_topology_bijector,
                        topology_pins=topology_pins)
  return _DefaultJointBijector(model, bijector_fn=bijector_fn)
def __init__(self, jd, **kwargs):
  """Wraps the default joint bijector of `jd` to accept batched arguments.

  Args:
    jd: The joint distribution whose default joint bijector is wrapped.
    **kwargs: Extra keyword arguments forwarded to `_DefaultJointBijector`.
  """
  # Snapshot the constructor arguments before any other local is created.
  parameters = dict(locals())
  self._jd = jd
  self._bijector_kwargs = kwargs
  self._joint_bijector = joint_distribution_lib._DefaultJointBijector(
      jd=self._jd, **self._bijector_kwargs)
  unbatched = self._joint_bijector

  # Mirror the static properties of the wrapped (non-batched) bijector.
  super(_DefaultJointBijectorAutoBatched, self).__init__(
      name=unbatched.name,
      parameters=parameters,
      validate_args=unbatched.validate_args,
      forward_min_event_ndims=unbatched.forward_min_event_ndims,
      inverse_min_event_ndims=unbatched.inverse_min_event_ndims)

  # The JD's batch dimensions have to be counted as part of the core 'event'
  # seen by the autobatched bijector methods, because `vectorized_map` cannot
  # see the batch-vs-event semantics inside the functions it vectorizes.
  # Without this:
  #   1. `self.inverse_log_det_jacobian(y)` for `y` of shape `jd.event_shape`
  #      would in general produce a result of shape `jd.batch_shape`, as every
  #      batch member may define its own transformation.
  #   2. `vectorized_map` semantics then imply that a `y` of shape
  #      `concat([jd.batch_shape, jd.event_shape])` yields a result of shape
  #      `concat([jd.batch_shape, jd.batch_shape])` -- the batch shape shows
  #      up *twice*.
  #   3. That violates the TFP shape contract.
  # Requiring `y` to have at least the shape of `jd.sample()` avoids this.
  batch_rank = ps.rank_from_shape(jd.batch_shape_tensor())

  def _prepend_batch_rank(min_event_ndims):
    """Adds the JD batch rank to every entry of an ndims structure."""
    return tf.nest.map_structure(lambda nd: batch_rank + nd, min_event_ndims)

  forward_core_ndims = _prepend_batch_rank(self.forward_min_event_ndims)
  inverse_core_ndims = _prepend_batch_rank(self.inverse_min_event_ndims)

  # Each transformation method of the non-batched bijector is wrapped so that
  # it can be applied to batched arguments.
  # pylint: disable=protected-access
  vectorize = self._vectorize_member_fn
  self._forward = vectorize(
      lambda bij, x: bij._forward(x),
      core_ndims=[forward_core_ndims])
  self._inverse = vectorize(
      lambda bij, y: bij._inverse(y),
      core_ndims=[inverse_core_ndims])
  # A constant-Jacobian `bij` needs its LDJ broadcast explicitly to the JD
  # batch shape.
  self._forward_log_det_jacobian = vectorize(
      lambda bij, x: tf.broadcast_to(  # pylint: disable=g-long-lambda
          bij._forward_log_det_jacobian(
              x, event_ndims=self.forward_min_event_ndims),
          jd.batch_shape_tensor()),
      core_ndims=[forward_core_ndims])
  self._inverse_log_det_jacobian = vectorize(
      lambda bij, y: tf.broadcast_to(  # pylint: disable=g-long-lambda
          bij._inverse_log_det_jacobian(
              y, event_ndims=self.inverse_min_event_ndims),
          jd.batch_shape_tensor()),
      core_ndims=[inverse_core_ndims])

  # Shape/dtype members are delegated to the wrapped bijector as-is.
  delegated_members = (
      '_forward_event_shape',
      '_forward_event_shape_tensor',
      '_inverse_event_shape',
      '_inverse_event_shape_tensor',
      '_forward_dtype',
      '_inverse_dtype',
      'forward_event_ndims',
      'inverse_event_ndims',
  )
  for member in delegated_members:
    setattr(self, member, getattr(unbatched, member))
def _experimental_default_event_space_bijector(self, *args, **kwargs):
  """Returns a bijector pulling unpinned values back to unconstrained reals.

  Args:
    *args: Optional positional pins, forwarded to `experimental_pin`.
    **kwargs: Optional keyword pins, forwarded to `experimental_pin`.

  Returns:
    A `_DefaultJointBijector` built for this distribution, or for its
    further-pinned version when any pins are supplied.
  """
  # Only re-pin when the caller actually supplied pins.
  target = self.experimental_pin(*args, **kwargs) if (args or kwargs) else self
  return joint_distribution._DefaultJointBijector(target)  # pylint: disable=protected-access