def step(bij, x, x_event_ndims, increased_dof, **kwargs):  # pylint: disable=missing-docstring
  # Push values (when requested) and event ndims through this bijector.
  if compute_x_values:
    outputs = forward_fn(bij, x, **kwargs)
  else:
    outputs = None
  out_event_ndims = forward_event_ndims_fn(bij, x_event_ndims, **kwargs)

  # Collapse the per-part flags into a single "some upstream bijector has
  # increased the degrees of freedom" flag. The upstream bijector is assumed
  # to have produced a valid LDJ; this one does not (unless its LDJ is 0, in
  # which case it doesn't matter).
  increased_dof = ps.reduce_any(nest.flatten(increased_dof))

  checks = []
  if compute_x_values and self.validate_event_size:
    # The warning is built from the flag *before* this bijector's own
    # contribution is folded in.
    checks.append(
        self._maybe_warn_increased_dof(
            component_name=bij.name,
            increased_dof=increased_dof))
    grew = _event_size(outputs, out_event_ndims) > _event_size(
        x, x_event_ndims)
    increased_dof |= grew

  # Re-expand outputs and the flag to the structure of the output event ndims.
  outputs = nest_util.broadcast_structure(out_event_ndims, outputs)
  increased_dof = nest_util.broadcast_structure(out_event_ndims,
                                                increased_dof)

  record = BijectorWithMetadata(
      bijector=bij,
      x=x,
      x_event_ndims=x_event_ndims,
      kwargs=kwargs,
      assertions=checks,
  )
  bijectors_with_metadata.append(record)
  return outputs, out_event_ndims, increased_dof
def __init__(self,
             bijectors=None,
             validate_args=False,
             parameters=None,
             name=None):
  """Instantiates `JointMap` bijector.

  Args:
    bijectors: Structure of bijector instances to apply in parallel.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    parameters: Locals dict captured by subclass constructor, to be used for
      copy/slice re-instantiation operators.
    name: Python `str`, name given to ops managed by this object. Default:
      E.g., ```
      JointMap([Exp(), Softplus()]).name == "jointmap_of_exp_and_softplus"
      ```.

  Raises:
    ValueError: if `bijectors` is empty (or `None`).
    ValueError: if bijectors have different dtypes (not checked in this
      constructor body; presumably enforced by the parent constructor).
  """
  # Must be the first statement so that `locals()` holds only the arguments.
  parameters = dict(locals()) if parameters is None else parameters

  if not bijectors:
    raise ValueError('`bijectors` must not be empty.')

  if name is None:
    name = ('jointmap_of_' +
            '_and_'.join([b.name for b in nest.flatten(bijectors)]))
    name = name.replace('/', '')
  with tf.name_scope(name) as name:
    # Structured dtypes are based on the non-wrapped input.
    # Keep track of the non-wrapped structure of bijectors to correctly
    # wrap inputs/outputs in _walk methods.
    self._nested_structure = self._no_dependency(
        nest.map_structure(lambda b: None, bijectors))

    super(JointMap, self).__init__(
        bijectors=bijectors,
        validate_args=validate_args,
        parameters=parameters,
        name=name,
        # JointMap and other bijectors that operate independently on
        # parts of structured inputs do not have statically-known
        # `min_event_ndims`. Infer the input/output structures, and fill them
        # with `None`.
        forward_min_event_ndims=nest.map_structure(
            lambda b: nest_util.broadcast_structure(  # pylint: disable=g-long-lambda
                b.forward_min_event_ndims, None),
            bijectors),
        # The inverse (output-side) structure must come from each component's
        # *inverse* min event ndims; a component whose input and output
        # structures differ would otherwise get the wrong inverse structure.
        inverse_min_event_ndims=nest.map_structure(
            lambda b: nest_util.broadcast_structure(  # pylint: disable=g-long-lambda
                b.inverse_min_event_ndims, None),
            bijectors),
    )
def step(bij, y, y_event_ndims, increased_dof=False, **kwargs): # pylint: disable=missing-docstring nonlocal ldj_sum # Compute the LDJ for this step, and add it to the rolling sum. component_ldj = tf.convert_to_tensor( bij.inverse_log_det_jacobian(y, y_event_ndims, **kwargs), dtype_hint=ldj_sum.dtype) if not dtype_util.is_floating(component_ldj.dtype): raise TypeError(('Nested bijector "{}" of Composition "{}" returned ' 'LDJ with a non-floating dtype: {}') .format(bij.name, self.name, component_ldj.dtype)) ldj_sum = _max_precision_sum(ldj_sum, component_ldj) # Transform inputs for the next bijector. x = bij.inverse(y, **kwargs) x_event_ndims = bij.inverse_event_ndims(y_event_ndims, **kwargs) # Check if the inputs to this bijector have increased degrees of freedom # due to some upstream bijector. We assume that the upstream bijector # produced a valid LDJ, but this one does not (unless LDJ is 0, in which # case it doesn't matter). increased_dof = ps.reduce_any(nest.flatten(increased_dof)) if self.validate_event_size: assertions.append(self._maybe_warn_increased_dof( component_name=bij.name, component_ldj=component_ldj, increased_dof=increased_dof)) increased_dof |= (_event_size(x, x_event_ndims) > _event_size(y, y_event_ndims)) increased_dof = nest_util.broadcast_structure(x, increased_dof) return x, x_event_ndims, increased_dof
def from_shape(cls, shape=(), independent_chain_ndims=1, dtype=tf.float32):
  """Builds an empty `RunningPotentialScaleReduction` from metadata alone.

  Args:
    shape: Python `Tuple` or `TensorShape` giving the shape of incoming
      samples. Passing a collection implies that future samples will mimic
      that exact structure. Supplying this is useful when the
      `RunningPotentialScaleReduction` is carried by a `tf.while_loop`, so
      that broadcasting does not change the shape across loop iterations.
    independent_chain_ndims: Integer or Integer type `Tensor` with value
      `>= 1` giving the number of leading dimensions holding independent
      chain results to be tested for convergence. Passing a collection
      implies that future samples will mimic that exact structure.
    dtype: Dtype of incoming samples and of the resulting statistics;
      `tf.float32` by default. Integer dtypes are cast to the corresponding
      float dtype (i.e. `tf.int32` becomes `tf.float32`), since the
      intermediate calculations should perform floating-point division.

  Returns:
    state: `RunningPotentialScaleReduction` representing a stream of no
      inputs.
  """
  # Promote to float dtypes, then lay them out like `independent_chain_ndims`.
  float_dtype = tf.nest.map_structure(_float_dtype_like, dtype)
  dtype_structure = nest_util.broadcast_structure(
      independent_chain_ndims, float_dtype)
  # One empty running variance per entry of `independent_chain_ndims`.
  per_chain_variances = nest.map_structure_up_to(
      independent_chain_ndims,
      RunningVariance.from_shape,
      shape,
      dtype_structure,
      check_types=False)
  return cls(per_chain_variances, independent_chain_ndims)
def one_step(self,
             new_chain_state,
             current_reducer_state,
             previous_kernel_results,
             axis=None):
  """Folds a new chain state into `current_reducer_state`.

  Chunking semantics are similar to those of batching and are controlled by
  the `axis` parameter. When chunking is enabled (`axis` is not `None`), all
  elements along the specified `axis` are treated as separate samples. If a
  single scalar value is given for a non-scalar sample structure, that value
  is used for every element of the structure; otherwise an identical
  structure must be provided.

  Args:
    new_chain_state: A (possibly nested) structure of incoming chain state(s)
      with shape and dtype compatible with those used to initialize the
      `current_reducer_state`.
    current_reducer_state: `CovarianceReducerState`s representing the current
      state of the running covariance.
    previous_kernel_results: A (possibly nested) structure of `Tensor`s
      representing internal calculations made in a related
      `TransitionKernel`.
    axis: If chunking is desired, a (possibly nested) structure of integers
      specifying the axis with chunked samples. Use `None` for individual
      samples. Samples are not chunked by default (`axis` is None).

  Returns:
    new_reducer_state: `CovarianceReducerState` with updated running
      statistics. Its `cov_state` field has an identical structure to the
      results of `self.transform_fn`. Each of the individual values in that
      structure subsequently mimics the structure of `current_reducer_state`.
  """
  scope_name = mcmc_util.make_name(
      self.name, 'covariance_reducer', 'one_step')
  with tf.name_scope(scope_name):
    init_structure = current_reducer_state.init_structure
    # Rebuild the stream objects that know how to update each covariance.
    streams = _prepare_args(init_structure, self.event_ndims)

    new_chain_state = tf.nest.map_structure(
        tf.convert_to_tensor, new_chain_state)
    previous_kernel_results = tf.nest.map_structure(
        tf.convert_to_tensor, previous_kernel_results)

    def _apply_transform(fn):
      return fn(new_chain_state, previous_kernel_results)

    transformed = tf.nest.map_structure(_apply_transform, self.transform_fn)

    # A non-nested `axis` applies to every transformed result.
    if not nest.is_nested(axis):
      axis = nest_util.broadcast_structure(transformed, axis)

    def _update_stream(stream, *update_args):
      return stream.update(*update_args)

    updated_cov_state = nest.map_structure_up_to(
        init_structure,
        _update_stream,
        streams,
        current_reducer_state.cov_state,
        transformed,
        axis,
        check_types=False,
    )
    return CovarianceReducerState(init_structure, updated_cov_state)
def update(self, state, new_sample):
  """Folds a new sample into a `RunningPotentialScaleReductionState`.

  Args:
    state: `RunningPotentialScaleReductionState` holding the current running
      statistics.
    new_sample: Incoming `Tensor` sample or (possibly nested) collection of
      `Tensor`s with shape and dtype compatible with those used to form the
      `RunningPotentialScaleReductionState`.

  Returns:
    state: `RunningPotentialScaleReductionState` with updated calculations.
  """
  def _advance_one_group(group_shape, group_dtype, group_var, group_sample):
    """Advances the running variance of a single group of Markov chains."""
    # TODO(axch): chunking could be reasonably added here by accepting and
    # including the chunked axis to the running variance object
    stream = RunningVariance(group_shape, dtype=group_dtype)
    return stream.update(group_var, group_sample)

  # Lay the dtype(s) out to match the structure of `independent_chain_ndims`.
  dtype_structure = nest_util.broadcast_structure(
      self.independent_chain_ndims, self.dtype)
  new_chain_vars = nest.map_structure_up_to(
      self.independent_chain_ndims,
      _advance_one_group,
      self.shape,
      dtype_structure,
      state.chain_var,
      new_sample,
      check_types=False)
  return RunningPotentialScaleReductionState(new_chain_vars)
def test_specifying_distribution_type(self,
                                      event_shape,
                                      base_distribution_cls,
                                      is_stateless=JAX_MODE):
  """Checks that the factored surrogate uses the requested base classes."""
  seeds = samplers.split_seed(
      test_util.test_seed(sampler_type='stateless'), n=2)
  init_seed, sample_seed = seeds

  surrogate_posterior = self._initialize_surrogate(
      'build_factored_surrogate_posterior',
      is_stateless=is_stateless,
      seed=init_seed,
      event_shape=event_shape,
      base_distribution_cls=base_distribution_cls,
      validate_args=True)

  # Gather the component distributions; a non-nested surrogate is its own
  # single component.
  if not tf.nest.is_nested(surrogate_posterior.event_shape):
    ds = [surrogate_posterior]
  else:
    ds, _ = surrogate_posterior.sample_distributions(seed=sample_seed)

  expected_classes = nest_util.broadcast_structure(ds, base_distribution_cls)
  for expected_cls, component in zip(expected_classes, ds):
    # Strip any `Independent` wrappers before checking the base type.
    component = _as_concrete_instance(component)
    while isinstance(component, tfd.Independent):
      component = _as_concrete_instance(component.distribution)
    self.assertIsInstance(component, expected_cls)
def _prepare_args(target, event_ndims):
  """Builds a structure of `RunningCovariance`s from inferred metadata.

  The metadata needed to create a `RunningCovariance` object (`shape`,
  `dtype`, and `event_ndims` of incoming chain states) is inferred from
  `target`, and a structure of `RunningCovariance`s identical to that of
  `target` is returned.

  Args:
    target: A (possibly nested) structure of `Tensor`s or Python `list`s of
      `Tensor`s representing the current state(s) of the Markov chain(s).
      Used to infer the shape and dtype of future samples.
    event_ndims: A (possibly nested) structure of integers giving the number
      of inner-most dimensions that represent the event shape. Must be either
      a singleton or of the same shape as `target`. When `None`, the rank of
      each part of `target` is used.

  Returns:
    cov_streams: Structure of `sample_stats.RunningCovariance` matching the
      shape of `target`.
  """
  shapes = tf.nest.map_structure(lambda part: part.shape, target)
  dtypes = tf.nest.map_structure(lambda part: part.dtype, target)

  if event_ndims is None:
    event_ndims = tf.nest.map_structure(ps.rank, target)
  elif not nest.is_nested(event_ndims):
    # A single value applies to every part of `target`.
    event_ndims = nest_util.broadcast_structure(target, event_ndims)

  return nest.map_structure_up_to(
      target,
      sample_stats.RunningCovariance,
      shapes,
      event_ndims,
      dtypes,
      check_types=False,
  )
def step_broadcast(step_size):
  """Maps nested step sizes through `step_bijector`; passes others through."""
  # Only apply the bijector to nested step sizes or non-scalar batches.
  if not tf.nest.is_nested(step_size):
    return step_size
  structured_step_size = nest_util.broadcast_structure(
      pinned_model.event_shape_tensor(), step_size)
  return step_bijector(structured_step_size)
def one_step(self,
             new_chain_state,
             current_reducer_state,
             previous_kernel_results=None,
             axis=None):
  """Folds a new chain state into `current_reducer_state`.

  Chunking semantics are controlled by the `axis` parameter. When chunking is
  enabled (`axis` is not `None`), all elements along the specified `axis` are
  treated as separate samples. If a single scalar value is given for a
  non-scalar sample structure, that value is used for every element of the
  structure; otherwise an identical structure must be provided.

  Args:
    new_chain_state: A (possibly nested) structure of incoming chain state(s)
      with shape and dtype compatible with those used to initialize the
      `current_reducer_state`.
    current_reducer_state: `ExpectationsReducerState` representing the
      current reducer state.
    previous_kernel_results: A (possibly nested) structure of `Tensor`s
      representing internal calculations made in a related
      `TransitionKernel`. Left untouched when `None`.
    axis: If chunking is desired, a (possibly nested) structure of integers
      specifying the axis with chunked samples. Use `None` for individual
      samples. Samples are not chunked by default (`axis` is None).

  Returns:
    new_reducer_state: `ExpectationsReducerState` with updated running
      statistics. It tracks a running total and the number of processed
      samples.
  """
  scope_name = mcmc_util.make_name(
      self.name, 'expectations_reducer', 'one_step')
  with tf.name_scope(scope_name):
    new_chain_state = tf.nest.map_structure(
        tf.convert_to_tensor, new_chain_state)
    if previous_kernel_results is not None:
      previous_kernel_results = tf.nest.map_structure(
          tf.convert_to_tensor,
          previous_kernel_results,
          expand_composites=True)

    def _apply_transform(fn):
      return fn(new_chain_state, previous_kernel_results)

    transformed = tf.nest.map_structure(_apply_transform, self.transform_fn)

    # A non-nested `axis` applies to every transformed result.
    if not nest.is_nested(axis):
      axis = nest_util.broadcast_structure(transformed, axis)

    def _fold_in(result, running_state, chunk_axis):
      return running_state.update(result, axis=chunk_axis)

    updated_expectations = nest.map_structure(
        _fold_in,
        transformed,
        current_reducer_state.expectation_state,
        axis,
        check_types=False)
    return ExpectationsReducerState(updated_expectations)
def _canonicalize_event_ndims(target, event_ndims):
  """Returns `event_ndims` shaped parallel to `target`, repeating as needed."""
  # Canonicalization exists only so that different Tensors in the target
  # structure may carry different event_ndims. Without that possibility,
  # event_ndims could stay a plain integer (or None) and no structure would be
  # needed.
  if nest.is_nested(event_ndims):
    # Already a structure: hand it back untouched.
    return event_ndims
  return nest_util.broadcast_structure(target, event_ndims)
def _get_bijectors_with_metadata(self, x, event_ndims, forward=True,
                                 **kwargs):
  """Trace bijectors + metadata forward/backward."""
  bijectors_with_metadata = []

  # Select the transformation direction once, up front.
  if forward:
    def forward_fn(bij, *args, **kwargs):
      return bij.forward(*args, **kwargs)

    def forward_event_ndims_fn(bij, *args, **kwargs):
      return bij.forward_event_ndims(*args, **kwargs)

    walk_forward_fn = self._call_walk_forward
  else:
    def forward_fn(bij, *args, **kwargs):
      return bij.inverse(*args, **kwargs)

    def forward_event_ndims_fn(bij, *args, **kwargs):
      return bij.inverse_event_ndims(*args, **kwargs)

    walk_forward_fn = self._call_walk_inverse

  def step(bij, x, x_event_ndims, increased_dof, **kwargs):  # pylint: disable=missing-docstring
    # Push values and event ndims through this bijector.
    outputs = forward_fn(bij, x, **kwargs)
    out_event_ndims = forward_event_ndims_fn(bij, x_event_ndims, **kwargs)

    # Collapse the per-part flags: did an upstream bijector increase the
    # degrees of freedom? The upstream bijector is assumed to have produced a
    # valid LDJ; this one does not (unless its LDJ is 0, in which case it
    # doesn't matter).
    increased_dof = ps.reduce_any(nest.flatten(increased_dof))
    checks = []
    if self.validate_event_size:
      checks.append(
          self._maybe_warn_increased_dof(
              component_name=bij.name,
              increased_dof=increased_dof))
      grew = _event_size(outputs, out_event_ndims) > _event_size(
          x, x_event_ndims)
      increased_dof |= grew
    increased_dof = nest_util.broadcast_structure(outputs, increased_dof)

    record = BijectorWithMetadata(
        bijector=bij,
        x=x,
        x_event_ndims=x_event_ndims,
        kwargs=kwargs,
        assertions=checks,
    )
    bijectors_with_metadata.append(record)
    return outputs, out_event_ndims, increased_dof

  # No part has increased degrees of freedom before the walk starts.
  initial_dof_flags = nest_util.broadcast_structure(event_ndims, False)
  walk_forward_fn(step, x, event_ndims, initial_dof_flags, **kwargs)
  return bijectors_with_metadata
def initialize(self):
  """Creates an empty `RunningPotentialScaleReductionState`.

  Returns:
    state: `RunningPotentialScaleReductionState` representing a stream of no
      inputs.
  """
  chain_ndims = self.independent_chain_ndims
  # Lay the dtype(s) out to match the structure of `independent_chain_ndims`.
  dtype_structure = nest_util.broadcast_structure(chain_ndims, self.dtype)
  empty_chain_vars = nest.map_structure_up_to(
      chain_ndims,
      RunningVariance.from_shape,
      self.shape,
      dtype_structure,
      check_types=False)
  return RunningPotentialScaleReductionState(empty_chain_vars)
def _forward_log_det_jacobian(self, x, event_ndims, **kwargs):
  """Sums the forward LDJ of every bijector visited by the forward walk."""
  # Running total of the LDJ, and the assertions gathered along the way.
  ldj_sum = tf.zeros([], dtype=tf.float32)
  assertions = []

  def step(bij, x, x_event_ndims, increased_dof, **kwargs):  # pylint: disable=missing-docstring
    nonlocal ldj_sum

    # Forward LDJ contributed by this bijector, added to the running total.
    raw_ldj = bij.forward_log_det_jacobian(x, x_event_ndims, **kwargs)
    component_ldj = tf.convert_to_tensor(raw_ldj, dtype_hint=ldj_sum.dtype)
    if not dtype_util.is_floating(component_ldj.dtype):
      message = ('Nested bijector "{}" of Composition "{}" returned '
                 'LDJ with a non-floating dtype: {}')
      raise TypeError(
          message.format(bij.name, self.name, component_ldj.dtype))
    ldj_sum = _max_precision_sum(ldj_sum, component_ldj)

    # Values and event ndims that feed the next bijector.
    outputs = bij.forward(x, **kwargs)
    out_event_ndims = bij.forward_event_ndims(x_event_ndims, **kwargs)

    # Collapse the per-part flags: did an upstream bijector increase the
    # degrees of freedom? The upstream bijector is assumed to have produced a
    # valid LDJ; this one does not (unless its LDJ is 0, in which case it
    # doesn't matter).
    increased_dof = ps.reduce_any(nest.flatten(increased_dof))
    if self.validate_event_size:
      warning = self._maybe_warn_increased_dof(
          component_name=bij.name,
          component_ldj=component_ldj,
          increased_dof=increased_dof)
      assertions.append(warning)
      grew = _event_size(outputs, out_event_ndims) > _event_size(
          x, x_event_ndims)
      increased_dof |= grew

    increased_dof = nest_util.broadcast_structure(outputs, increased_dof)
    return outputs, out_event_ndims, increased_dof

  # No part has increased degrees of freedom before the walk starts.
  initial_dof_flags = nest_util.broadcast_structure(event_ndims, False)
  self._call_walk_forward(step, x, event_ndims, initial_dof_flags, **kwargs)

  # Entries that are `None` are dropped before being used as dependencies.
  active_assertions = [a for a in assertions if a is not None]
  with tf.control_dependencies(active_assertions):
    return tf.identity(ldj_sum, name='fldj')
def test_specifying_distribution_type(self, event_shape,
                                      base_distribution_cls):
  """Checks that the factored surrogate uses the requested base classes."""
  build_surrogate = tfp.experimental.vi.build_factored_surrogate_posterior
  surrogate_posterior = build_surrogate(
      event_shape=event_shape,
      base_distribution_cls=base_distribution_cls,
      validate_args=True)

  # Gather the component distributions; a non-nested surrogate is its own
  # single component.
  if not tf.nest.is_nested(surrogate_posterior.event_shape):
    ds = [surrogate_posterior]
  else:
    ds, _ = surrogate_posterior.sample_distributions()

  expected_classes = nest_util.broadcast_structure(ds, base_distribution_cls)
  for expected_cls, component in zip(expected_classes, ds):
    # Strip any `Independent` wrappers before checking the base type.
    while isinstance(component, tfd.Independent):
      component = component.distribution
    self.assertIsInstance(component, expected_cls)
def initialize(self):
  """Creates an empty `RunningPotentialScaleReductionState`.

  Returns:
    state: `RunningPotentialScaleReductionState` representing a stream of no
      inputs.
  """
  def _empty_state_for_group(group_shape, group_dtype):
    """Builds the initial running-variance state for one group of chains."""
    return RunningVariance(group_shape, dtype=group_dtype).initialize()

  # Lay the dtype(s) out to match the structure of `independent_chain_ndims`.
  dtype_structure = nest_util.broadcast_structure(
      self.independent_chain_ndims, self.dtype)
  empty_chain_vars = nest.map_structure_up_to(
      self.independent_chain_ndims,
      _empty_state_for_group,
      self.shape,
      dtype_structure,
      check_types=False)
  return RunningPotentialScaleReductionState(empty_chain_vars)
def _get_bijectors_with_metadata(self,
                                 x=None,
                                 event_ndims=None,
                                 forward=True,
                                 pack_as_original_structure=False,
                                 **kwargs):
  """Trace bijectors + metadata forward/backward."""
  bijectors_with_metadata = []
  # Values are only pushed through the bijectors when `x` was supplied.
  compute_x_values = x is not None
  if event_ndims is None:
    # NOTE(review): the forward min event ndims are used as the default even
    # when `forward=False` -- confirm that inverse walks always pass
    # `event_ndims` explicitly.
    event_ndims = self.forward_min_event_ndims

  # Select the transformation direction once, up front.
  if forward:
    def forward_fn(bij, *args, **kwargs):
      return bij.forward(*args, **kwargs)

    def forward_event_ndims_fn(bij, *args, **kwargs):
      return bij.forward_event_ndims(*args, **kwargs)

    walk_forward_fn = self._call_walk_forward
  else:
    def forward_fn(bij, *args, **kwargs):
      return bij.inverse(*args, **kwargs)

    def forward_event_ndims_fn(bij, *args, **kwargs):
      return bij.inverse_event_ndims(*args, **kwargs)

    walk_forward_fn = self._call_walk_inverse

  def step(bij, x, x_event_ndims, increased_dof, **kwargs):  # pylint: disable=missing-docstring
    # Push values (when requested) and event ndims through this bijector.
    if compute_x_values:
      outputs = forward_fn(bij, x, **kwargs)
    else:
      outputs = None
    out_event_ndims = forward_event_ndims_fn(bij, x_event_ndims, **kwargs)

    # Collapse the per-part flags: did an upstream bijector increase the
    # degrees of freedom? The upstream bijector is assumed to have produced a
    # valid LDJ; this one does not (unless its LDJ is 0, in which case it
    # doesn't matter).
    increased_dof = ps.reduce_any(nest.flatten(increased_dof))
    checks = []
    if compute_x_values and self.validate_event_size:
      checks.append(
          self._maybe_warn_increased_dof(
              component_name=bij.name,
              increased_dof=increased_dof))
      grew = _event_size(outputs, out_event_ndims) > _event_size(
          x, x_event_ndims)
      increased_dof |= grew

    # Re-expand outputs and the flag to the structure of the output ndims.
    outputs = nest_util.broadcast_structure(out_event_ndims, outputs)
    increased_dof = nest_util.broadcast_structure(out_event_ndims,
                                                  increased_dof)

    record = BijectorWithMetadata(
        bijector=bij,
        x=x,
        x_event_ndims=x_event_ndims,
        kwargs=kwargs,
        assertions=checks,
    )
    bijectors_with_metadata.append(record)
    return outputs, out_event_ndims, increased_dof

  # Lay `x` (possibly `None`) and the initial flags out like `event_ndims`.
  x = nest_util.broadcast_structure(event_ndims, x)
  initial_dof_flags = nest_util.broadcast_structure(event_ndims, False)
  walk_forward_fn(step, x, event_ndims, initial_dof_flags, **kwargs)

  if pack_as_original_structure:
    # Re-pack the flat trace into the structure of `self.bijectors`, matching
    # each record to its bijector by object identity.
    records_by_id = dict(
        (id(record.bijector), record) for record in bijectors_with_metadata)
    bijectors_with_metadata = tf.nest.map_structure(
        lambda b: records_by_id[id(b)], self.bijectors)

  return bijectors_with_metadata
def effective_sample_size(states,
                          filter_threshold=0.,
                          filter_beyond_lag=None,
                          filter_beyond_positive_pairs=False,
                          cross_chain_dims=None,
                          validate_args=False,
                          name=None):
  """Estimate a lower bound on effective sample size for each independent chain.

  Roughly speaking, "effective sample size" (ESS) is the size of an iid sample
  with the same variance as `states`.

  More precisely, given a stationary sequence of possibly correlated random
  variables `X_1, X_2, ..., X_N`, identically distributed, ESS is the
  number such that

  ```
  Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.
  ```

  If the sequence is uncorrelated, `ESS = N`. If the sequence is positively
  auto-correlated, `ESS` will be less than `N`. If there are negative
  correlations, then `ESS` can exceed `N`.

  Some math shows that, with `R_k` the auto-correlation sequence,
  `R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}`, we have

  ```
  ESS(N) =  N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]
  ```

  This function estimates the above by first estimating the auto-correlation.
  Since `R_k` must be estimated using only `N - k` samples, it becomes
  progressively noisier for larger `k`. For this reason, the summation over
  `R_k` should be truncated at some number `filter_beyond_lag < N`. This
  function provides two methods to perform this truncation.

  * `filter_threshold` -- since many MCMC methods generate chains where
    `R_k > 0`, a reasonable criterion is to truncate at the first index where
    the estimated auto-correlation becomes negative. This method does not
    estimate the `ESS` of super-efficient chains (where `ESS > N`) correctly.

  * `filter_beyond_positive_pairs` -- reversible MCMC chains produce
    an auto-correlation sequence with the property that pairwise sums of the
    elements of that sequence are positive [Geyer][1], i.e.
    `R_{2k} + R_{2k + 1} > 0` for `k in {0, ..., N/2}`. Deviations are only
    possible due to noise. This method truncates the auto-correlation sequence
    where the pairwise sums become non-positive.

  The arguments `filter_beyond_lag`, `filter_threshold` and
  `filter_beyond_positive_pairs` are filters intended to remove noisy tail
  terms from `R_k`. You can combine `filter_beyond_lag` with
  `filter_threshold` or `filter_beyond_positive_pairs`. E.g., combining
  `filter_beyond_lag` and `filter_beyond_positive_pairs` means that terms are
  removed if they were to be filtered under the `filter_beyond_lag` OR
  `filter_beyond_positive_pairs` criteria.

  This function can also compute cross-chain ESS following
  [Vehtari et al. (2019)][2] by specifying the `cross_chain_dims` argument.
  Cross-chain ESS takes into account the cross-chain variance to reduce the ESS
  in cases where the chains are not mixing well. In general, this will be a
  smaller number than computing the ESS for individual chains and then summing
  them. In an extreme case where the chains have fallen into K non-mixing modes,
  this function will return ESS ~ K. Even when chains are mixing well it is
  still preferable to compute cross-chain ESS via this method because it will
  reduce the noise in the estimate of `R_k`, reducing the need for truncation.

  Args:
    states: `Tensor` or Python structure of `Tensor` objects. Dimension zero
      should index identically distributed states.
    filter_threshold: `Tensor` or Python structure of `Tensor` objects. Must
      broadcast with `states`. The sequence of auto-correlations is truncated
      after the first appearance of a term less than `filter_threshold`.
      Setting to `None` means we use no threshold filter. Since `|R_k| <= 1`,
      setting to any number less than `-1` has the same effect. Ignored if
      `filter_beyond_positive_pairs` is `True`.
    filter_beyond_lag: `Tensor` or Python structure of `Tensor` objects. Must
      be `int`-like and scalar valued. The sequence of auto-correlations is
      truncated to this length. Setting to `None` means we do not filter based
      on the size of lags.
    filter_beyond_positive_pairs: Python boolean. If `True`, only consider the
      initial auto-correlation sequence where the pairwise sums are positive.
    cross_chain_dims: An integer `Tensor` or a structure of integer `Tensors`
      corresponding to each state component. If a list of `states` is provided,
      then this argument should also be a list of the same length. Which
      dimensions of `states` to treat as independent chains that ESS will be
      summed over. If `None`, no summation is performed. Note this requires at
      least 2 chains.
    validate_args: Whether to add runtime checks of argument validity. If False,
      and arguments are incorrect, correct behavior is not guaranteed.
    name: `String` name to prepend to created ops.

  Returns:
    ess: `Tensor` structure parallel to `states`. The effective sample size of
      each component of `states`. If `cross_chain_dims` is None, the shape will
      be `states.shape[1:]`. Otherwise, the shape is
      `tf.reduce_mean(states, cross_chain_dims).shape[1:]`.

  Raises:
    ValueError: If `states` and `filter_threshold` or `states` and
      `filter_beyond_lag` are both structures of different shapes.
    ValueError: If `cross_chain_dims` is not `None` and there are less than 2
      chains.

  #### Examples

  We use ESS to estimate standard error.

  ```
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])

  # Get 1000 states from one chain.
  states = tfp.mcmc.sample_chain(
      num_burnin_steps=200,
      num_results=1000,
      current_state=tf.constant([0., 0.]),
      trace_fn=None,
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target.log_prob,
        step_size=0.05,
        num_leapfrog_steps=20))
  print(states.shape)
  ==> (1000, 2)

  ess = effective_sample_size(states, filter_beyond_positive_pairs=True)
  print(ess.shape)
  ==> (2,)

  mean, variance = tf.nn.moments(states, axes=0)
  standard_error = tf.sqrt(variance / ess)
  ```

  #### References

  [1]: Charles J. Geyer, Practical Markov chain Monte Carlo (with discussion).
       Statistical Science, 7:473-511, 1992.

  [2]: Aki Vehtari, Andrew Gelman, Daniel Simpson, Bob Carpenter, Paul-Christian
       Burkner. Rank-normalization, folding, and localization: An improved R-hat
       for assessing convergence of MCMC, 2019. Retrieved from
       http://arxiv.org/abs/1903.08008
  """
  # Give every per-component option the same structure as `states`.
  if cross_chain_dims is None:
    cross_chain_dims = nest_util.broadcast_structure(states, None)
  filter_beyond_lag = nest_util.broadcast_structure(states, filter_beyond_lag)
  filter_threshold = nest_util.broadcast_structure(states, filter_threshold)
  filter_beyond_positive_pairs = nest_util.broadcast_structure(
      states, filter_beyond_positive_pairs)

  # Each state component is handled independently by the single-state helper.
  per_component_ess = (
      lambda *component_args: _effective_sample_size_single_state(  # pylint: disable=g-long-lambda
          *component_args, validate_args=validate_args))

  scope_name = 'effective_sample_size' if name is None else name
  with tf.name_scope(scope_name):
    return nest.map_structure_up_to(
        states,
        per_component_ess,
        states,
        filter_beyond_lag,
        filter_threshold,
        filter_beyond_positive_pairs,
        cross_chain_dims)
def init_near_unconstrained_zero(
    model=None, constraining_bijector=None, event_shapes=None,
    event_shape_tensors=None, batch_shapes=None, batch_shape_tensors=None,
    dtypes=None):
  """Builds a Stan-style initialization Distribution for a Markov chain.

  Every latent is drawn independently from `Uniform(-2, 2)` in unconstrained
  space; the draws are then pushed through `constraining_bijector`, so samples
  of the returned distribution live in constrained space and can be handed to
  `sample_chain` or another MCMC driver as an initial state.

  Callers either pass a `model` (from which every missing piece is read), or
  pass `constraining_bijector`, `event_shapes` and `dtypes` explicitly. When
  both are given, the explicit arguments take precedence over the model.

  Args:
    model: Optional `Distribution` (typically a `JointDistribution`). For each
      of the other arguments that is `None`, the value is read from the model:
      `experimental_default_event_space_bijector()`, `event_shape`,
      `event_shape_tensor()`, `dtype`, `batch_shape` and
      `batch_shape_tensor()` (the two batch values are broadcast over the
      structure of `dtypes`).
    constraining_bijector: A (typically multipart) `Bijector` mapping
      unconstrained to constrained space. Anything that is not a
      `tfb.Bijector` instance is treated as a structure of bijectors and
      wrapped in `tfb.JointMap`.
    event_shapes: Structure of event shapes of the desired samples. It is
      passed to `constraining_bijector.inverse_event_shape`, i.e. it is
      specified on the constrained side of the bijector.
    event_shape_tensors: Structure of shape tensors, passed to
      `constraining_bijector.inverse_event_shape_tensor`. Required (when no
      `model` is given) if any of `event_shapes` is not fully defined.
    batch_shapes: Structure of batch shapes of the desired samples. Without a
      `model`, defaults to the scalar batch shape `[]`, broadcast over the
      structure of `dtypes`.
    batch_shape_tensors: Structure of batch shape tensors. Required (when no
      `model` is given) if any of `batch_shapes` is not fully defined;
      otherwise derived from `batch_shapes`.
    dtypes: Structure of dtypes of the desired samples, passed to
      `constraining_bijector.inverse_dtype`.

  Returns:
    init_dist: A `TransformedDistribution` whose samples are initial states in
      constrained space.

  Raises:
    ValueError: If `model` is `None` and any of `constraining_bijector`,
      `event_shapes` or `dtypes` is missing.
    ValueError: If `model` is `None`, some event shape is not fully defined
      and `event_shape_tensors` is `None`.
    ValueError: If `model` is `None`, some batch shape is not fully defined
      and `batch_shape_tensors` is `None`.

  #### Example

  Initialize 100 chains from the unconstrained -2, 2 distribution
  for a model expressed as a `JointDistributionCoroutine`:

  ```python
  @tfp.distributions.JointDistributionCoroutine
  def model():
    ...

  init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)
  states = tfp.mcmc.sample_chain(
    current_state=init_dist.sample(100, seed=[4, 8]),
    ...)
  ```
  """
  if model is None:
    # No model to fall back on: the three essential pieces must be explicit.
    essentials = (constraining_bijector, event_shapes, dtypes)
    if any(piece is None for piece in essentials):
      raise ValueError(
          'Must pass either a Distribution (typically a JointDistribution), '
          'or a bijector, a structure of event shapes, and a '
          'structure of dtypes')

    event_shapes_are_static = all(
        tensorshape_util.is_fully_defined(shape)
        for shape in tf.nest.flatten(event_shapes))
    if event_shape_tensors is None and not event_shapes_are_static:
      raise ValueError('Must specify `event_shape_tensors` when `event_shapes` '
                       f'are not fully-defined: {event_shapes}')

    if batch_shapes is None:
      batch_shapes = tf.TensorShape([])
    batch_shapes = nest_util.broadcast_structure(dtypes, batch_shapes)
    batch_shapes_are_static = all(
        tensorshape_util.is_fully_defined(shape)
        for shape in tf.nest.flatten(batch_shapes))
    if batch_shape_tensors is None:
      if not batch_shapes_are_static:
        raise ValueError(
            'Must specify `batch_shape_tensors` when `batch_shapes` are not '
            f'fully-defined: {batch_shapes}')
      batch_shape_tensors = tf.nest.map_structure(
          tf.convert_to_tensor, batch_shapes)
  else:
    # A model was supplied: it fills in whatever the caller left unset.
    # `dtypes` must be resolved before the batch shapes, which are broadcast
    # over its structure.
    if constraining_bijector is None:
      constraining_bijector = model.experimental_default_event_space_bijector()
    if event_shapes is None:
      event_shapes = model.event_shape
    if event_shape_tensors is None:
      event_shape_tensors = model.event_shape_tensor()
    if dtypes is None:
      dtypes = model.dtype
    if batch_shapes is None:
      batch_shapes = nest_util.broadcast_structure(dtypes, model.batch_shape)
    if batch_shape_tensors is None:
      batch_shape_tensors = nest_util.broadcast_structure(
          dtypes, model.batch_shape_tensor())

  # A bare structure of bijectors means "apply these in parallel".
  if not isinstance(constraining_bijector, tfb.Bijector):
    constraining_bijector = tfb.JointMap(constraining_bijector)

  def build_term(event_shape, event_shape_tensor, batch_shape,
                 batch_shape_tensor, dtype):
    """Returns the Uniform(-2, 2) initializer for a single state part."""
    if tensorshape_util.is_fully_defined(event_shape):
      sample_shape = event_shape
    else:
      sample_shape = event_shape_tensor
    term = tfd.Sample(
        tfd.Uniform(low=tf.constant(-2., dtype=dtype),
                    high=tf.constant(2., dtype=dtype)),
        sample_shape=sample_shape)
    if not tensorshape_util.is_fully_defined(batch_shape):
      # Static batch shape unknown: always broadcast to the dynamic one.
      return tfd.BatchBroadcast(term, batch_shape_tensor)
    if tensorshape_util.as_list(batch_shape):
      # Static, non-scalar batch shape.
      return tfd.BatchBroadcast(term, batch_shape)
    # Scalar batch: nothing to broadcast.
    return term

  unconstrained_shapes = constraining_bijector.inverse_event_shape(
      event_shapes)
  if event_shape_tensors is None:
    unconstrained_shape_tensors = tf.nest.map_structure(
        lambda _: None, unconstrained_shapes)
  else:
    unconstrained_shape_tensors = (
        constraining_bijector.inverse_event_shape_tensor(event_shape_tensors))
  unconstrained_dtypes = constraining_bijector.inverse_dtype(dtypes)

  terms = tf.nest.map_structure(
      build_term, unconstrained_shapes, unconstrained_shape_tensors,
      batch_shapes, batch_shape_tensors, unconstrained_dtypes)
  # Sample the parts as a flat sequence, then restore the original nesting.
  flat_joint = tfd.JointDistributionSequential(tf.nest.flatten(terms))
  unconstrained = tfb.pack_sequence_as(unconstrained_shapes)(flat_joint)
  return tfd.TransformedDistribution(
      unconstrained, bijector=constraining_bijector)
def _infer_min_event_ndims(bijectors):
  """Computes `min_event_ndims` for a sequence of bijectors.

  Starting from the first bijector whose `min_event_ndims` is statically
  known, event ndims are propagated through the bijectors after it using the
  inverse-side walker and through the bijectors before it (in reverse order)
  using the forward-side walker. While walking, the largest amount by which
  any component falls short of its own minimum is recorded and finally added
  to both results.

  Args:
    bijectors: Sequence of bijector instances.

  Returns:
    A pair of event-ndims structures. If no bijector has static
    `min_event_ndims`, these are `None`-filled structures shaped like
    `bijectors[-1].forward_min_event_ndims` and
    `bijectors[0].inverse_min_event_ndims` respectively.
  """
  # Locate the first bijector with statically-known min_event_ndims.
  for anchor, candidate in enumerate(bijectors):
    if candidate.has_static_min_event_ndims:
      break
  else:
    # Nothing is statically known: give up and return the tail structures
    # filled with `None`.
    return (
        nest_util.broadcast_structure(
            bijectors[-1].forward_min_event_ndims, None),
        nest_util.broadcast_structure(
            bijectors[0].inverse_min_event_ndims, None))

  # Running maximum of "min_event_ndims - ndims" seen during the walks.
  offset = 0

  def propagate_static(source_event_ndims, source_min_event_ndims,
                       target_min_event_ndims):
    """Maps event ndims across a static bijector, updating `offset`."""
    nonlocal offset
    reduce_ndims = bijector_lib.ldj_reduction_ndims(
        source_event_ndims, source_min_event_ndims)
    # A negative value means the input had too few dims; remember by how much.
    offset = ps.maximum(offset, -reduce_ndims)
    return nest.map_structure(
        lambda nd: reduce_ndims + nd, target_min_event_ndims)

  def record_negative_ndims(event_ndims):
    """Updates `offset` when any of `event_ndims` is negative."""
    nonlocal offset
    deficit = -ps.reduce_min(nest.flatten(event_ndims))
    offset = ps.maximum(offset, deficit)
    return event_ndims

  # The two walkers below wrap `Bijector.inverse_event_ndims` and
  # `Bijector.forward_event_ndims`, recursing into `Composition` bijectors
  # when static min_event_ndims is not available.

  def walk_inverse(bij, event_ndims):
    event_ndims = nest_util.coerce_structure(
        bij.inverse_min_event_ndims, event_ndims)
    if bij.has_static_min_event_ndims:
      return propagate_static(
          source_event_ndims=event_ndims,
          source_min_event_ndims=bij.inverse_min_event_ndims,
          target_min_event_ndims=bij.forward_min_event_ndims)
    if isinstance(bij, composition.Composition):
      return bij._call_walk_inverse(walk_inverse, event_ndims)  # pylint: disable=protected-access
    return record_negative_ndims(bij.inverse_event_ndims(event_ndims))

  def walk_forward(bij, event_ndims):
    event_ndims = nest_util.coerce_structure(
        bij.forward_min_event_ndims, event_ndims)
    if bij.has_static_min_event_ndims:
      return propagate_static(
          source_event_ndims=event_ndims,
          source_min_event_ndims=bij.forward_min_event_ndims,
          target_min_event_ndims=bij.inverse_min_event_ndims)
    if isinstance(bij, composition.Composition):
      return bij._call_walk_forward(walk_forward, event_ndims)  # pylint: disable=protected-access
    return record_negative_ndims(bij.forward_event_ndims(event_ndims))

  # Both walks start from the anchor's statically-known inverse-side ndims.
  seed_ndims = bijectors[anchor].inverse_min_event_ndims

  forward_side_ndims = seed_ndims
  for bij in bijectors[anchor:]:
    forward_side_ndims = walk_inverse(bij, forward_side_ndims)

  inverse_side_ndims = seed_ndims
  for bij in reversed(bijectors[:anchor]):
    inverse_side_ndims = walk_forward(bij, inverse_side_ndims)

  def apply_offset(nd):
    # Shift so every nested component meets its own min_event_ndims.
    return offset + nd

  return (nest.map_structure(apply_offset, forward_side_ndims),
          nest.map_structure(apply_offset, inverse_side_ndims))
def __init__(self,
             output_structure,
             input_structure=None,
             name='restructure'):
  """Creates a `Restructure` bijector.

  Args:
    output_structure: A tf.nest-compatible structure of tokens describing
      the output of `forward` (equivalently, the input of `inverse`).
    input_structure: A tf.nest-compatible structure of tokens describing
      the input to `forward`. If unspecified, a default structure is
      inferred from `output_structure`. The default structure expects a
      `list` if tokens are integers, or a `dict` if the tokens are strings.
    name: Name of this bijector.

  Raises:
    ValueError: If tokens are duplicated, if a required default structure
      cannot be inferred, or if the tokens of an explicitly provided
      `input_structure` do not match those of `output_structure`.
  """
  # Must stay the first statement so that only the constructor arguments are
  # captured.
  parameters = dict(locals())

  # Get the flat set of tokens, making sure they're unique.
  output_tokens = unique_token_set(output_structure)

  # Create a default input_structure when it isn't provided.
  if input_structure is None:
    # If all tokens are strings, assume input is a dict.
    if all(isinstance(tok, six.string_types) for tok in output_tokens):
      input_structure = {token: token for token in output_tokens}
    # If tokens are contiguous 0-based ints, return a list.
    elif (all(isinstance(tok, six.integer_types) for tok in output_tokens)
          and output_tokens == set(range(len(output_tokens)))):
      input_structure = list(range(len(output_tokens)))
    # Otherwise, we cannot infer a default structure.
    else:
      raise ValueError(
          ('Tokens in output_structure must be all strings or '
           'contiguous 0-based indices when input_structure '
           'is not specified. Saw: {}').format(output_tokens))

  # If input_structure _is_ provided, make sure tokens are unique
  # and that they match the output_structure tokens.
  else:
    # The tokens must come from `input_structure` itself; reading them from
    # `output_structure` would make the comparison below trivially true and
    # skip the validation of the user-supplied structure.
    input_tokens = unique_token_set(input_structure)
    if input_tokens != output_tokens:
      raise ValueError(
          ('The `input_structure` tokens must match the '
           '`output_structure` tokens exactly. Missing from '
           '`input_structure`: {}. Missing from '
           '`output_structure`: {}.').format(
               output_tokens - input_tokens,
               input_tokens - output_tokens))

  self._input_structure = self._no_dependency(input_structure)
  self._output_structure = self._no_dependency(output_structure)

  super(Restructure, self).__init__(
      forward_min_event_ndims=nest_util.broadcast_structure(
          self._input_structure, 0),
      inverse_min_event_ndims=nest_util.broadcast_structure(
          self._output_structure, 0),
      is_constant_jacobian=True,
      validate_args=False,
      parameters=parameters,
      name=name)
def testBroadcastStructure(self, from_structure, to_structure, expected):
  """Checks `broadcast_structure(to_structure, from_structure)` vs `expected`."""
  broadcasted = nest_util.broadcast_structure(to_structure, from_structure)
  self.assertAllEqual(expected, broadcasted)
def _factored_surrogate_posterior(  # pylint: disable=dangerous-default-value
    event_shape=None,
    bijector=None,
    batch_shape=(),
    base_distribution_cls=normal.Normal,
    initial_parameters={'scale': 1e-2},
    dtype=tf.float32,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior that factors over model variables.

  By default, one independent trainable Normal distribution is created per
  variable and, if a bijector is provided, transformed to match the support of
  that variable. This encodes very strong assumptions about the posterior:
  that it is approximately normal (or transformed normal), and that all model
  variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. A nested
      structure is coerced to the structure of `event_shape`; `None` entries
      are replaced by `Identity` and the whole structure is wrapped in a
      `JointMap`. When no bijector is given the posterior is supported on the
      unconstrained reals.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    base_distribution_cls: Subclass of `tfd.Distribution` that is instantiated
      and optionally transformed by the bijector to define the component
      distributions. May optionally be a structure of such subclasses
      matching `event_shape`.
      Default value: `tfd.Normal`.
    initial_parameters: Optional `str : Tensor` dictionary specifying initial
      values for some or all of the base distribution's trainable parameters,
      or a Python `callable` with signature
      `value = parameter_init_fn(parameter_name, shape, dtype, seed,
      constraining_bijector)`, passed to `tfp.experimental.util.make_trainable`.
      May optionally be a structure matching `event_shape` of such
      dictionaries and/or callables; otherwise the single value is shared by
      every event part. Dictionary entries that do not correspond to
      parameter names are ignored.
      Default value: `{'scale': 1e-2}` (ignored when `base_distribution` does
        not have a `scale` parameter).
    dtype: Optional float `dtype` for trainable parameters. May optionally be
      a structure of such `dtype`s matching `event_shape`.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs
      are invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  Returns:
    The generator's return value is the surrogate posterior: the joint
    distribution over unconstrained values, wrapped in a
    `TransformedDistribution` when a bijector was given.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `concentration` and `rate` parameters, and that both are constrained to be
  positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify initial parameters when we build the surrogate posterior. Note that
  these parameterize the distribution(s) over unconstrained values,
  so we need to transform our desired constrained locations using the inverse
  of the constraining bijector(s).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector={'concentration': tfb.Softplus(),   # Concentration is positive.
              'rate': tfb.Softplus()},           # Rate is positive.
    initial_parameters={
      'concentration': {'loc': tfb.Softplus().inverse(0.4), 'scale': 1e-2},
      'rate': {'loc': tfb.Softplus().inverse(0.2), 'scale': 1e-2}})
  ```

  """
  with tf.name_scope(name or 'build_factored_surrogate_posterior'):
    # Turn every event shape into an int32 Tensor, treating each shape as a
    # leaf of the structure.
    shape_structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        shape_structure,
        lambda shape: tf.convert_to_tensor(shape, dtype=tf.int32),
        event_shape)

    # A structure of bijectors becomes a single JointMap; a lone bijector (or
    # `None`) is used as given.
    event_space_bijector = bijector
    if nest.is_nested(bijector):
      per_part_bijectors = nest.map_structure(
          lambda b: identity.Identity() if b is None else b,
          nest_util.coerce_structure(event_shape, bijector))
      event_space_bijector = joint_map.JointMap(
          per_part_bijectors, validate_args=validate_args)

    if event_space_bijector is not None:
      unconstrained_event_shape = (
          event_space_bijector.inverse_event_shape_tensor(event_shape))
    else:
      unconstrained_event_shape = event_shape

    def prepend_batch_shape(part_event_shape):
      return ps.concat([batch_shape, part_event_shape], axis=0)

    unconstrained_batch_and_event_shape = tf.nest.map_structure(
        prepend_batch_shape, unconstrained_event_shape)

    base_distribution_cls = nest_util.broadcast_structure(
        event_shape, base_distribution_cls)
    parameter_dtype = nest_util.broadcast_structure(event_shape, dtype)

    try:
      # Do we already have initial parameters for each event part?
      nest.assert_shallow_structure(event_shape, initial_parameters)
    except (ValueError, TypeError):
      # No: share the single value across all event parts. This is done by
      # hand instead of with `nest_util.broadcast_structure` because the
      # initial parameters can themselves be structures (dicts).
      shared_initial_parameters = initial_parameters
      initial_parameters = nest.map_structure(
          lambda _: shared_initial_parameters, event_shape)

    component_distributions = yield from (
        nest_util.map_structure_coroutine(
            trainable._make_trainable,  # pylint: disable=protected-access
            cls=base_distribution_cls,
            initial_parameters=initial_parameters,
            batch_and_event_shape=unconstrained_batch_and_event_shape,
            parameter_dtype=parameter_dtype,
            _up_to=event_shape))

    unconstrained_posterior = (
        joint_distribution_util.independent_joint_distribution_from_structure(
            component_distributions,
            batch_ndims=ps.rank_from_shape(batch_shape),
            validate_args=validate_args))

    if event_space_bijector is None:
      return unconstrained_posterior
    return transformed_distribution.TransformedDistribution(
        unconstrained_posterior, event_space_bijector)
def step_broadcast(step_size):
  """Applies `step_bijector` to `step_size` broadcast over the event parts.

  `step_size` is broadcast to the structure of
  `pinned_model.event_shape_tensor()` before being handed to `step_bijector`.
  """
  event_structure = pinned_model.event_shape_tensor()
  per_part_step_size = nest_util.broadcast_structure(
      event_structure, step_size)
  return step_bijector(per_part_step_size)
def __init__(self,
             output_structure,
             input_structure=None,
             name='restructure'):
  """Converts between nested structures of Tensor.

  This is useful when constructing non-trivial chains of multipart
  bijectors. It partitions inputs into different logical "blocks", which
  may be fed as arguments to downstream multipart bijectors.

  Example Usage:
  ```python
  # What restructure does:
  restructure = Restructure({
      'foo': [0, 1],
      'bar': [3, 2],
      'baz': [4, 5, 6]
  })
  # Note that x is a *python-list* of tensors.
  # To permute elements of an individual Tensor, see `tfb.Permute`.
  x = [1, 2, 4, 8, 16, 32, 64]
  assert restructure.forward(x) == {
      'foo': [1, 2],
      'bar': [8, 4],
      'baz': [16, 32, 64]
  }

  # Where restructure is useful:
  complex_bijector = Chain([
      # Apply different transformations to each block.
      JointMap({
          'foo': ScaleMatVecLinearOperator(...),  # Operates on the full block
          'bar': ScaleMatVecLinearOperator(...),  # Operates on the full block
          'baz': [Exp(), Scale(10.), Shift(-1.)]  # Different bijectors for each
      }),
      # Group the tensor into logical blocks.
      Restructure({
          'foo': [0, 1],
          'bar': [3, 2],
          'baz': [4, 5, 6],
      }),
      # Split an input tensor into 7 chunks.
      Split([2, 4, 6, 8, 10, 12, 14])
  ])
  ```

  Args:
    output_structure: A tf.nest-compatible structure of tokens describing
      the output of `forward` (equivalently, the input of `inverse`).
    input_structure: A tf.nest-compatible structure of tokens describing
      the input to `forward`. If unspecified, a default structure is
      inferred from `output_structure`. The default structure expects a
      `list` if tokens are integers, or a `dict` if the tokens are strings.
    name: Name of this bijector.

  Raises:
    ValueError: If tokens are duplicated, if a required default structure
      cannot be inferred, or if the tokens of an explicitly provided
      `input_structure` do not match those of `output_structure`.
  """
  # Must stay the first statement so that only the constructor arguments are
  # captured.
  parameters = dict(locals())

  # Get the flat set of tokens, making sure they're unique.
  output_tokens = unique_token_set(output_structure)

  # Create a default input_structure when it isn't provided.
  if input_structure is None:
    # If all tokens are strings, assume input is a dict.
    if all(isinstance(tok, six.string_types) for tok in output_tokens):
      input_structure = {token: token for token in output_tokens}
    # If tokens are contiguous 0-based ints, return a list.
    elif (all(isinstance(tok, six.integer_types) for tok in output_tokens)
          and output_tokens == set(range(len(output_tokens)))):
      input_structure = list(range(len(output_tokens)))
    # Otherwise, we cannot infer a default structure.
    else:
      raise ValueError(
          ('Tokens in output_structure must be all strings or '
           'contiguous 0-based indices when input_structure '
           'is not specified. Saw: {}').format(output_tokens))

  # If input_structure _is_ provided, make sure tokens are unique
  # and that they match the output_structure tokens.
  else:
    # The tokens must come from `input_structure` itself; reading them from
    # `output_structure` would make the comparison below trivially true and
    # skip the validation of the user-supplied structure.
    input_tokens = unique_token_set(input_structure)
    if input_tokens != output_tokens:
      raise ValueError(
          ('The `input_structure` tokens must match the '
           '`output_structure` tokens exactly. Missing from '
           '`input_structure`: {}. Missing from '
           '`output_structure`: {}.').format(
               output_tokens - input_tokens,
               input_tokens - output_tokens))

  self._input_structure = self._no_dependency(input_structure)
  self._output_structure = self._no_dependency(output_structure)

  # Event ndims are not statically known for a pure restructuring, so both
  # min_event_ndims structures are filled with `None`.
  super(Restructure, self).__init__(
      forward_min_event_ndims=nest_util.broadcast_structure(
          self._input_structure, None),
      inverse_min_event_ndims=nest_util.broadcast_structure(
          self._output_structure, None),
      is_constant_jacobian=True,
      validate_args=False,
      parameters=parameters,
      name=name)