def test_f64_state(self, method, method_kwargs):
  """Checks that a float64 target produces float64 states.

  `method` is traced (not executed) via `callable_util.get_output_spec`, and
  the dtype of the first element of its traced output is inspected.

  Args:
    self: The enclosing test case instance.
    method: Callable under test; invoked as
      `method(5, target, n_chains=2, num_adaptation_steps=100, seed=...,
      **method_kwargs)` and expected to return a pair whose first element
      carries the states.
    method_kwargs: Python `dict` of extra keyword arguments forwarded to
      `method`.
  """

  def run_method():
    # The target distribution is built inside the traced callable so that it
    # is constructed as part of the trace, as are the seed and the call itself.
    target = tfd.Normal(tf.constant(0., tf.float64), 1.)
    return method(
        5,
        target,
        n_chains=2,
        num_adaptation_steps=100,
        seed=test_util.test_seed(),
        **method_kwargs)

  states, _ = callable_util.get_output_spec(run_method)
  self.assertEqual(tf.float64, states.dtype)
def __init__(
    self,
    log_prob_increment,
    validate_args=False,
    allow_nan_stats=False,  # pylint: disable=unused-argument
    reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,  # pylint: disable=unused-argument
    log_prob_ratio_fn=None,
    name='IncrementLogProb',
    **kwargs):
  """Constructs an `IncrementLogProb` distribution-like object.

  The increment may be supplied either as a value (converted to a `Tensor`
  immediately) or as a callable (traced once to discover its dtype and shape,
  and wrapped so that it can be re-evaluated later with `**kwargs`).

  Args:
    log_prob_increment: Float `Tensor`, or a callable returning a float
      `Tensor`. The log probability/density to increment by.
    validate_args: Stored on the instance but otherwise ignored; accepted
      because some call sites expect a `Distribution`-like signature.
    allow_nan_stats: Ignored; accepted because some call sites expect a
      `Distribution`-like signature.
    reparameterization_type: Ignored; accepted because some call sites expect
      a `Distribution`-like signature.
    log_prob_ratio_fn: Optional callable with signature
      `(p_kwargs, q_kwargs) -> log_prob_ratio`, used to implement a custom
      `p_log_prob_increment - q_log_prob_increment` computation.
    name: Python `str` name prefixed to Ops created by this class.
    **kwargs: Forwarded to `log_prob_increment` when it is callable.

  Raises:
    ValueError: If `kwargs` are given but `log_prob_increment` is not
      callable.
  """
  # Must stay the first statement: it snapshots the constructor arguments
  # before any other local name is bound.
  self._parameters = dict(locals())

  with tf.name_scope(name) as name:
    if not callable(log_prob_increment):
      # A plain value: keyword arguments would have nowhere to go.
      if kwargs:
        raise ValueError(
            '`kwargs` is only valid when `log_prob_increment` is callable.'
        )
      log_prob_increment = tensor_util.convert_nonref_to_tensor(
          log_prob_increment)
      log_prob_increment_fn = lambda: log_prob_increment
      # The converted value itself supplies `dtype` and `shape`.
      spec = log_prob_increment
    else:

      def log_prob_increment_fn():
        return tensor_util.convert_nonref_to_tensor(
            log_prob_increment(**kwargs))

      # Trace the callable to learn the dtype/shape of what it returns.
      spec = callable_util.get_output_spec(log_prob_increment_fn)

    self._name = name
    self._validate_args = validate_args
    self._kwargs = kwargs
    self._log_prob_ratio_fn = log_prob_ratio_fn
    self._log_prob_increment = log_prob_increment
    self._log_prob_increment_fn = log_prob_increment_fn
    self._dtype = spec.dtype
    self._batch_shape = spec.shape
def test_get_output_spec_from_tensor_specs(self):
  """Checks `get_output_spec` when its inputs are `tf.TensorSpec`s.

  The traced callable is invoked with a nested structure of specs, passed
  partly positionally and partly by keyword. The result is expected to match
  the shapes and dtypes of all inputs, positional ones first.
  """
  scalar_spec = tf.TensorSpec([], dtype=tf.float32)
  nested_specs = (tf.TensorSpec([1, 1], dtype=tf.float32),
                  tf.TensorSpec([2], dtype=tf.float64))
  args = (scalar_spec, nested_specs)
  additional_args = (tf.TensorSpec([2, 1], dtype=tf.int32),)

  # Trace using both positional and keyword args.
  results = callable_util.get_output_spec(
      _return_args_from_infinite_loop,
      *args,
      additional_loop_vars=additional_args)

  all_inputs = args + additional_args

  def shape_of(x):
    return x.shape

  def dtype_of(x):
    return x.dtype

  self.assertAllEqualNested(
      tf.nest.map_structure(shape_of, all_inputs),
      tf.nest.map_structure(shape_of, results))
  self.assertAllAssertsNested(
      self.assertEqual,
      tf.nest.map_structure(dtype_of, all_inputs),
      tf.nest.map_structure(dtype_of, results))
def test_get_output_spec_loop(self):
  """Checks `get_output_spec` when its inputs are concrete values.

  Inputs mix a NumPy array with `Tensor`s in a nested structure, passed
  partly positionally and partly by keyword. The result is expected to match
  the shapes and dtypes the inputs have once converted to `Tensor`s.
  """
  nested_inputs = (tf.convert_to_tensor(0.),
                   tf.convert_to_tensor([1., 1.], dtype=tf.float64))
  args = (np.array(0., dtype=np.float64), nested_inputs)
  additional_args = (tf.convert_to_tensor([[3], [4]], dtype=tf.int32),)

  # Trace using both positional and keyword args.
  results = callable_util.get_output_spec(
      _return_args_from_infinite_loop,
      *args,
      additional_loop_vars=additional_args)

  all_inputs = args + additional_args

  def input_shape(x):
    # Inputs may be NumPy arrays, so convert before reading the shape.
    return tf.convert_to_tensor(x).shape

  def input_dtype(x):
    return tf.convert_to_tensor(x).dtype

  self.assertAllEqualNested(
      tf.nest.map_structure(input_shape, all_inputs),
      tf.nest.map_structure(lambda spec: spec.shape, results))
  self.assertAllAssertsNested(
      self.assertEqual,
      tf.nest.map_structure(input_dtype, all_inputs),
      tf.nest.map_structure(lambda spec: spec.dtype, results))
def make_distribution_bijector(distribution, name='make_distribution_bijector'):
  """Builds a bijector to approximately transform `N(0, 1)` into `distribution`.

  This represents a distribution as a bijector that
  transforms a (multivariate) standard normal distribution into the distribution
  of interest.

  Args:
    distribution: A `tfd.Distribution` instance; this may be a joint
      distribution.
    name: Python `str` name for ops created by this method.
  Returns:
    distribution_bijector: a `tfb.Bijector` instance such that
      `distribution_bijector(tfd.Normal(0., 1.))` is approximately equivalent
      to `distribution`.
  Raises:
    NotImplementedError: if `distribution` has no default event space bijector
      (e.g., it is discrete), or if no bijector could be constructed because
      the distribution has no registered bijector and does not implement an
      invertible CDF.

  #### Examples

  This method may be used to convert structured variational distributions into
  MCMC preconditioners. Consider a model containing
  [funnel geometry](https://crackedbassoon.com/writing/funneling), which may
  be difficult for an MCMC algorithm to sample directly.

  ```python
  model_with_funnel = tfd.JointDistributionSequentialAutoBatched([
      tfd.Normal(loc=-1., scale=2., name='z'),
      lambda z: tfd.Normal(loc=[0., 0., 0.], scale=tf.exp(z), name='x'),
      lambda x: tfd.Poisson(log_rate=x, name='y')])
  pinned_model = tfp.experimental.distributions.JointDistributionPinned(
      model_with_funnel, y=[1, 3, 0])
  ```

  We can approximate the posterior in this model using a structured variational
  surrogate distribution, which will capture the funnel geometry, but cannot
  exactly represent the (non-Gaussian) posterior.

  ```python
  # Build and fit a structured surrogate posterior distribution.
  surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(
    pinned_model)
  _ = tfp.vi.fit_surrogate_posterior(pinned_model.unnormalized_log_prob,
                                     surrogate_posterior=surrogate_posterior,
                                     optimizer=tf.optimizers.Adam(0.01),
                                     num_steps=200)
  ```

  Creating a preconditioning bijector allows us to obtain higher-quality
  posterior samples, without any Gaussianity assumption, by using the surrogate
  to guide an MCMC sampler.

  ```python
  surrogate_posterior_bijector = (
    tfp.experimental.bijectors.make_distribution_bijector(surrogate_posterior))
  samples, _ = tfp.mcmc.sample_chain(
    kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
      tfp.mcmc.TransformedTransitionKernel(
        tfp.mcmc.NoUTurnSampler(pinned_model.unnormalized_log_prob,
                                step_size=0.1),
        bijector=surrogate_posterior_bijector),
      num_adaptation_steps=80),
    current_state=surrogate_posterior.sample(),
    num_burnin_steps=100,
    trace_fn=lambda _0, _1: [],
    num_results=500)
  ```

  #### Mathematical details

  The bijectors returned by this method generally follow the following
  principles, although the specific bijectors returned may vary without notice.

  Normal distributions are reparameterized by a location-scale transform.

  ```python
  b = tfp.experimental.bijectors.make_distribution_bijector(
    tfd.Normal(loc=10., scale=5.))
  # ==> tfb.Shift(10.)(tfb.Scale(5.))

  b = tfp.experimental.bijectors.make_distribution_bijector(
    tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril))
  # ==> tfb.Shift(loc)(tfb.ScaleMatvecTriL(scale_tril))
  ```

  The distribution's `quantile` function is used, when available:

  ```python
  d = tfd.Cauchy(loc=loc, scale=scale)
  b = tfp.experimental.bijectors.make_distribution_bijector(d)
  # ==> tfb.Inline(forward_fn=d.quantile, inverse_fn=d.cdf)(tfb.NormalCDF())
  ```

  Otherwise, a quantile function is derived by inverting the CDF:

  ```python
  d = tfd.Gamma(concentration=alpha, rate=beta)
  b = tfp.experimental.bijectors.make_distribution_bijector(d)
  # ==> tfb.Invert(
  #  tfp.experimental.bijectors.ScalarFunctionWithInferredInverse(fn=d.cdf))(
  #    tfb.NormalCDF())
  ```

  Transformed distributions are represented by chaining the
  transforming bijector with a preconditioning bijector for the base
  distribution:

  ```python
  b = tfp.experimental.bijectors.make_distribution_bijector(
    tfb.Exp(tfd.Normal(loc=10., scale=5.)))
  # ==> tfb.Exp(tfb.Shift(10.)(tfb.Scale(5.)))
  ```

  Joint distributions are represented by a joint bijector, which converts each
  component distribution to a bijector with parameters conditioned
  on the previous variables in the model. The joint bijector's inputs and
  outputs follow the structure of the joint distribution.

  ```python
  jd = tfd.JointDistributionNamed(
      {'a': tfd.InverseGamma(concentration=2., scale=1.),
       'b': lambda a: tfd.Normal(loc=3., scale=tf.sqrt(a))})
  b = tfp.experimental.bijectors.make_distribution_bijector(jd)
  whitened_jd = tfb.Invert(b)(jd)
  x = whitened_jd.sample()
  # x <=> {'a': tfd.Normal(0., 1.).sample(), 'b': tfd.Normal(0., 1.).sample()}
  ```

  """

  with tf.name_scope(name):
    event_space_bijector = (
        distribution.experimental_default_event_space_bijector())
    if event_space_bijector is None:  # Fail if the distribution is discrete.
      raise NotImplementedError(
          'Cannot transform distribution {} to a standard normal '
          'distribution.'.format(distribution))

    # Recurse over joint distributions.
    if isinstance(distribution, joint_distribution.JointDistribution):
      return joint_distribution._DefaultJointBijector(  # pylint: disable=protected-access
          distribution, bijector_fn=make_distribution_bijector)

    # Recurse through transformed distributions.
    if isinstance(distribution,
                  transformed_distribution.TransformedDistribution):
      return distribution.bijector(
          make_distribution_bijector(distribution.distribution))

    # If we've annotated a specific bijector for this distribution, use that.
    # The registry is keyed by class, so walk the MRO and dispatch on the
    # nearest registered ancestor. Looking up `type(distribution)` directly
    # would raise `KeyError` for a subclass of a registered distribution.
    for distribution_type in type(distribution).__mro__:
      if distribution_type in preconditioning_bijector_fns:
        return preconditioning_bijector_fns[distribution_type](distribution)

    # Otherwise, if this distribution implements a CDF and inverse CDF, build
    # a bijector from those. Support is probed by tracing each method on a
    # dummy event; an unimplemented method raises `NotImplementedError`.
    input_spec = tf.zeros(shape=distribution.event_shape,
                          dtype=distribution.dtype)
    try:
      callable_util.get_output_spec(distribution.cdf, input_spec)
    except NotImplementedError:
      implements_cdf = False
    else:
      implements_cdf = True

    try:
      callable_util.get_output_spec(distribution.quantile, input_spec)
    except NotImplementedError:
      implements_quantile = False
    else:
      implements_quantile = True

    if implements_cdf and implements_quantile:
      # This path will only trigger for scalar distributions, since multivariate
      # distributions have non-invertible CDF and so cannot define a `quantile`.
      return tfb.Inline(
          forward_fn=distribution.quantile,
          inverse_fn=distribution.cdf,
          forward_min_event_ndims=ps.rank_from_shape(
              distribution.event_shape_tensor,
              distribution.event_shape))(tfb.NormalCDF())

    # If the events are scalar, try to invert the CDF numerically.
    if implements_cdf and tf.get_static_value(
        distribution.is_scalar_event()):
      return tfb.Invert(
          scalar_function_with_inferred_inverse.
          ScalarFunctionWithInferredInverse(
              distribution.cdf,
              domain_constraint_fn=(event_space_bijector)))(
                  tfb.NormalCDF())

    raise NotImplementedError('Could not automatically construct a '
                              'bijector for distribution type '
                              '{}; it does not implement an invertible '
                              'CDF.'.format(distribution))
def bracket_root(objective_fn,
                 dtype=tf.float32,
                 num_points=512,
                 name='bracket_root'):
  """Finds bounds that bracket a root of the objective function.

  This method attempts to return an interval bracketing a root of the objective
  function. It evaluates the objective in parallel at `num_points`
  locations, at exponentially increasing distance from the origin, and returns
  the first pair of adjacent points `[low, high]` such that the objective is
  finite and has a different sign at the two points. If no such pair was
  observed, it returns the trivial interval
  `[np.finfo(dtype).min, np.finfo(dtype).max]` containing all float values of
  the specified `dtype`. If the objective has multiple
  roots, the returned interval will contain at least one (but perhaps not all)
  of the roots.

  Args:
    objective_fn: Python callable for which roots are searched. It must be a
      continuous function that accepts a scalar `Tensor` of type `dtype` and
      returns a `Tensor` of shape `batch_shape`.
    dtype: Optional float `dtype` of inputs to `objective_fn`.
      Default value: `tf.float32`.
    num_points: Optional Python `int` number of points at which to evaluate
      the objective. The points are placed symmetrically about the origin
      (`num_points // 2` on each side), so an odd value evaluates
      `num_points - 1` points.
      Default value: `512`.
    name: Python `str` name given to ops created by this method.
  Returns:
    low: Float `Tensor` of shape `batch_shape` and dtype `dtype`. Lower bound
      on a root of `objective_fn`.
    high: Float `Tensor` of shape `batch_shape` and dtype `dtype`. Upper bound
      on a root of `objective_fn`.
  Raises:
    ValueError: if the rank of the objective's output cannot be statically
      inferred.
  """
  with tf.name_scope(name):
    # Build a symmetric, logarithmically spaced sequence of points covering
    # (almost) the full range of `dtype`.
    dtype_info = np.finfo(dtype_util.as_numpy_dtype(dtype))
    num_points_per_side = num_points // 2
    xs_positive = tf.exp(
        tf.linspace(tf.cast(-10., dtype),
                    tf.math.log(dtype_info.max),
                    num_points_per_side))
    xs = tf.concat([tf.reverse(-xs_positive, axis=[0]), xs_positive], axis=0)
    # `xs` always has an even number of elements. Use its actual length (and
    # not `num_points`) below, so that an odd `num_points` does not make the
    # reshape fail on a mismatched element count.
    num_evaluated_points = 2 * num_points_per_side

    # Evaluate the objective at all points. The objective function may return
    # a batch of values (e.g., `objective(x) = x - batch_of_roots`).
    if NUMPY_MODE:
      objective_output_spec = objective_fn(tf.zeros([], dtype=dtype))
    else:
      objective_output_spec = callable_util.get_output_spec(
          objective_fn,
          tf.convert_to_tensor(0., dtype=dtype))
    batch_ndims = tensorshape_util.rank(objective_output_spec.shape)
    if batch_ndims is None:
      raise ValueError('Cannot infer tensor rank of objective values.')
    # Append one singleton dimension per batch dimension so that the points
    # broadcast against the objective's batch shape.
    xs_pad_shape = ps.pad([num_evaluated_points],
                          paddings=[[0, batch_ndims]],
                          constant_values=1)
    ys = objective_fn(tf.reshape(xs, xs_pad_shape))

    # Find the smallest point where the objective is finite.
    is_finite = tf.math.is_finite(ys)
    ys_transposed = distribution_util.move_dimension(  # For batch gather.
        ys, 0, -1)
    first_finite_value = tf.gather(
        ys_transposed,
        tf.argmax(is_finite, axis=0),  # Index of smallest finite point.
        batch_dims=batch_ndims,
        axis=-1)
    # Select the next point where the objective has a different sign.
    sign_change_idx = tf.argmax(
        tf.not_equal(tf.math.sign(ys),
                     tf.math.sign(first_finite_value)) & is_finite,
        axis=0)
    # If the sign never changes, we can't bracket a root.
    bracketing_failed = tf.equal(sign_change_idx, 0)
    # If the objective's sign is zero, we've found an actual root.
    root_found = tf.equal(
        tf.gather(tf.math.sign(ys_transposed),
                  sign_change_idx,
                  batch_dims=batch_ndims,
                  axis=-1),
        0.)
    return _structure_broadcasting_where(
        bracketing_failed,
        # If we didn't detect a sign change, fall back to the trivial interval.
        (dtype_info.min, dtype_info.max),
        # Otherwise, return the points around the sign change, unless we
        # actually evaluated a root, in which case, return the zero-width
        # bracket at that root.
        (tf.gather(
            xs,
            tf.where(bracketing_failed | root_found,
                     sign_change_idx,
                     sign_change_idx - 1)),
         tf.gather(xs, sign_change_idx)))
def test_get_output_spec_oom(self):
  """Checks that `get_output_spec` reports a 1e9 x 1e9 float32 result.

  Only the spec of `_compute_oom`'s output is requested; the assertions are on
  the reported shape and dtype.
  """
  side_length = int(1e9)
  expected_shape = (side_length, side_length)

  output_spec = callable_util.get_output_spec(_compute_oom)

  self.assertEqual(expected_shape, output_spec.shape)
  self.assertEqual(tf.float32, output_spec.dtype)