def _get_flat_unconstraining_bijector(jd_model):
  """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

  U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then
  have support on R^m.)

  Args:
    jd_model: subclass of `tfd.JointDistribution`
      A JointDistribution for a model.

  Returns:
    A `tfb.Bijector` where the `.forward` method flattens and unconstrains
    points.
  """
  # TODO(b/180396233): This bijector is in general point-dependent.
  event_space_bijector = jd_model.experimental_default_event_space_bijector()
  constrained_event_shape = jd_model.event_shape_tensor()
  flat_bijector = restructure.pack_sequence_as(constrained_event_shape)

  # The reshape/split bijectors below act on the *unconstrained* side of
  # `event_space_bijector` (the chain applies its last element first), so their
  # shapes must be obtained by pulling the event shape back through the
  # constraining bijector before flattening the structure. Using the
  # constrained shapes here is only correct when the event-space bijector
  # happens to preserve every part's event shape.
  unconstrained_shapes = flat_bijector.inverse_event_shape_tensor(
      event_space_bijector.inverse_event_shape_tensor(constrained_event_shape))

  # This reshaping is required as a split can produce a tensor of shape [1]
  # when the distribution event shape is [].
  reshapers = [
      reshape.Reshape(event_shape_out=shape, event_shape_in=[-1])
      for shape in unconstrained_shapes
  ]
  unflatten_parts = joint_map.JointMap(bijectors=reshapers)

  size_splits = [ps.reduce_prod(shape) for shape in unconstrained_shapes]
  split_flat_vector = split.Split(num_or_size_splits=size_splits)

  # Forward of the chain: flat vector -> split -> reshape -> restructure ->
  # constrain. Inverting it yields the flatten-and-unconstrain direction.
  return invert.Invert(chain.Chain([
      event_space_bijector,
      flat_bijector,
      unflatten_parts,
      split_flat_vector,
  ]))
def build_split_flow_surrogate_posterior(
    event_shape,
    trainable_bijector,
    constraining_bijector=None,
    base_distribution=normal.Normal,
    batch_shape=(),
    dtype=tf.float32,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior by splitting a normalizing flow.

  A flat base distribution is pushed through `trainable_bijector`, the
  resulting vector is split into one piece per posterior component, each piece
  is reshaped and placed into the requested structure, and finally (if given)
  `constraining_bijector` maps the components onto their support.

  Args:
    event_shape: (Nested) event shape of the surrogate posterior.
    trainable_bijector: A trainable `tfb.Bijector` instance that operates on
      `Tensor`s (not structures), e.g. `tfb.MaskedAutoregressiveFlow` or
      `tfb.RealNVP`. This bijector transforms the base distribution before it
      is split.
    constraining_bijector: `tfb.Bijector` instance, or nested structure of
      `tfb.Bijector` instances, that maps (nested) values in R^n to the
      support of the posterior. (This can be the
      `experimental_default_event_space_bijector` of the distribution over
      the prior latent variables.) `None` entries inside a nested structure
      are treated as identity bijectors.
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution for the transformed surrogate has
      `loc=0.` and `scale=1.`.
      Default value: `tfd.Normal`.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_split_flow_surrogate_posterior').

  Returns:
    surrogate_distribution: Trainable `tfd.TransformedDistribution` with event
      shape equal to `event_shape`.

  ### Examples

  ```python

  # Train a normalizing flow on the Eight Schools model [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
  model = tfd.JointDistributionNamed({
      'avg_effect':
          tfd.Normal(loc=0., scale=10., name='avg_effect'),
      'log_stddev':
          tfd.Normal(loc=5., scale=1., name='log_stddev'),
      'school_effects':
          lambda log_stddev, avg_effect: (
              tfd.Independent(
                  tfd.Normal(
                      loc=avg_effect[..., None] * tf.ones(8),
                      scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
                      name='school_effects'),
                  reinterpreted_batch_ndims=1)),
      'treatment_effects': lambda school_effects: tfd.Independent(
          tfd.Normal(loc=school_effects, scale=treatment_stddevs),
          reinterpreted_batch_ndims=1)
  })

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Create a Masked Autoregressive Flow bijector.
  net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
  maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)

  # Build and fit the surrogate posterior.
  surrogate_posterior = (
      tfp.experimental.vi.build_split_flow_surrogate_posterior(
          event_shape=target_model.event_shape_tensor(),
          trainable_bijector=maf,
          constraining_bijector=(
              target_model.experimental_default_event_space_bijector())))

  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
  with tf.name_scope(name or 'build_split_flow_surrogate_posterior'):
    # Turn every event shape (as delimited by the shallow structure) into a
    # shape tensor.
    structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        structure, ps.convert_to_shape_tensor, event_shape)

    # A nested structure of bijectors is wrapped in a single `JointMap`, with
    # missing (`None`) entries replaced by the identity.
    if nest.is_nested(constraining_bijector):

      def _identity_if_missing(bij):
        if bij is None:
          return identity.Identity()
        return bij

      constraining_bijector = joint_map.JointMap(
          nest.map_structure(_identity_if_missing, constraining_bijector),
          validate_args=validate_args)

    is_constrained = constraining_bijector is not None
    if is_constrained:
      unconstrained_event_shape = (
          constraining_bijector.inverse_event_shape_tensor(event_shape))
    else:
      unconstrained_event_shape = event_shape

    # One size per component; their sum is the length of the flat base event.
    component_shapes = nest.flatten(unconstrained_event_shape)
    component_sizes = [tf.reduce_prod(shape) for shape in component_shapes]
    total_size = tf.reduce_sum(component_sizes)

    standard_base = base_distribution(
        tf.zeros(batch_shape, dtype=dtype), scale=1.)
    flat_base = sample.Sample(standard_base, [total_size])

    # Cuts the flow's output vector into one vector per component.
    split_bijector = split.Split(component_sizes, validate_args=validate_args)

    # Gives each component vector its unconstrained event shape.
    event_reshape = joint_map.JointMap(
        nest.map_structure(reshape.Reshape, unconstrained_event_shape),
        validate_args=validate_args)

    # Arranges the flat list of components into the posterior's structure.
    component_indices = nest.pack_sequence_as(
        unconstrained_event_shape, range(len(component_shapes)))
    event_unflatten = restructure.Restructure(component_indices)

    # `Chain` applies its last element first, so the trainable flow runs on the
    # flat base sample and the constraining bijector (if any) runs last.
    flow_steps = [
        event_reshape, event_unflatten, split_bijector, trainable_bijector
    ]
    if is_constrained:
      flow_steps.insert(0, constraining_bijector)

    return transformed_distribution.TransformedDistribution(
        flat_base,
        bijector=chain.Chain(flow_steps, validate_args=validate_args),
        validate_args=validate_args)
def __init__(self,
             bijectors,
             block_sizes=None,
             validate_args=False,
             maybe_changes_size=True,
             name=None):
  """Creates the bijector.

  Args:
    bijectors: A non-empty list of bijectors.
    block_sizes: A 1-D integer `Tensor` with each element signifying the
      length of the block of the input vector to pass to the corresponding
      bijector. The length of `block_sizes` must be equal to the length of
      `bijectors`. If left as None, a vector of 1's is used.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    maybe_changes_size: Python `bool` indicating that this bijector might
      change the event size. If this is known to be false and set
      appropriately, then this will lead to improved static shape inference
      when the block sizes are not statically known.
    name: Python `str`, name given to ops managed by this object. Default:
      E.g., `Blockwise([Exp(), Softplus()]).name ==
      'blockwise_of_exp_and_softplus'`.

  Raises:
    ValueError: If `bijectors` list is empty.
    ValueError: If any bijector is multipart, changes the event rank, or has
      `forward_min_event_ndims` > 1 (only scalar and vector event shapes are
      supported).
    ValueError: If size of `block_sizes` does not equal to the length of
      bijectors or is not a vector (validation is delegated to
      `_validate_block_sizes`).
  """
  # Must stay the first statement so that only the constructor arguments are
  # captured.
  parameters = dict(locals())

  # Enforce the documented non-empty contract up front rather than relying on
  # whatever a downstream component does with an empty list.
  if not bijectors:
    raise ValueError('`bijectors` must not be empty.')

  if not name:
    name = 'blockwise_of_' + '_and_'.join([b.name for b in bijectors])
    name = name.replace('/', '')

  with tf.name_scope(name) as name:
    for b in bijectors:
      if (nest.is_nested(b.forward_min_event_ndims)
          or nest.is_nested(b.inverse_min_event_ndims)):
        raise ValueError('Bijectors must all be single-part.')
      elif isinstance(b.forward_min_event_ndims, int):
        if b.forward_min_event_ndims != b.inverse_min_event_ndims:
          raise ValueError('Rank-changing bijectors are not supported.')
        elif b.forward_min_event_ndims > 1:
          raise ValueError('Only scalar and vector event-shape '
                           'bijectors are supported at this time.')

    # Applies each bijector to its own block.
    b_joint = joint_map.JointMap(list(bijectors), name='jointmap')

    block_sizes = (
        np.ones(len(bijectors), dtype=np.int32)
        if block_sizes is None else
        _validate_block_sizes(block_sizes, bijectors, validate_args))
    # Cuts the input vector into blocks.
    b_split = split.Split(
        block_sizes, name='split', validate_args=validate_args)

    if maybe_changes_size:
      # Output block sizes: push each input block size through its bijector.
      i_block_sizes = _validate_block_sizes(
          ps.concat(b_joint.forward_event_shape_tensor(
              ps.split(block_sizes, len(bijectors))), axis=0),
          bijectors, validate_args)
      # `get_static_value` yields `None` when equality cannot be decided
      # statically, in which case we conservatively keep the flag `True`.
      maybe_changes_size = not tf.get_static_value(
          ps.reduce_all(block_sizes == i_block_sizes))
    # Concatenation is the inverse of a split; reuse `b_split` when the block
    # sizes are known to be unchanged.
    b_concat = invert.Invert(
        (split.Split(i_block_sizes, name='isplit')
         if maybe_changes_size else b_split),
        name='concat')

    self._maybe_changes_size = maybe_changes_size
    super(Blockwise, self).__init__(
        bijectors=[b_concat, b_joint, b_split],
        validate_args=validate_args,
        parameters=parameters,
        name=name)