Example #1
0
def _get_flat_unconstraining_bijector(jd_model):
    """Build a bijector that flattens and unconstrains points of a joint model.

    Loosely speaking, the goal is to go from a joint distribution supported on

      U_1 x U_2 x ... x U_n, with U_j a subset of R^{n_j}

    to a model supported on R^N, with N = sum(n_j). ("Loosely" refers to base
    measures: a distribution may be supported on an m-dimensional subset of
    R^n, in which case its default transform may have support on R^m.)

    Args:
      jd_model: Instance of a `tfd.JointDistribution` subclass describing the
        model.

    Returns:
      A `tfb.Bijector` whose `.forward` method flattens and unconstrains
      points.
    """
    # TODO(b/180396233): This bijector is in general point-dependent.
    constrainer = jd_model.experimental_default_event_space_bijector()
    structurer = restructure.pack_sequence_as(jd_model.event_shape_tensor())

    # One shape per model part, as seen on the flat (list) side of
    # `structurer`.
    part_shapes = structurer.inverse_event_shape_tensor(
        jd_model.event_shape_tensor())

    # Splitting can produce a tensor of shape [1] for a part whose event shape
    # is [], so every part is reshaped from a vector to its own event shape.
    reshaper = joint_map.JointMap(bijectors=[
        reshape.Reshape(event_shape_out=shape, event_shape_in=[-1])
        for shape in part_shapes
    ])

    # Each part occupies as many entries of the flat vector as it has
    # elements.
    splitter = split.Split(
        num_or_size_splits=[ps.reduce_prod(shape) for shape in part_shapes])

    # The chain is built in the "unflatten and constrain" direction and then
    # inverted, so that `.forward` is the flatten-and-unconstrain direction.
    stages = [constrainer, structurer, reshaper, splitter]
    return invert.Invert(chain.Chain(stages))
Example #2
0
def build_split_flow_surrogate_posterior(event_shape,
                                         trainable_bijector,
                                         constraining_bijector=None,
                                         base_distribution=normal.Normal,
                                         batch_shape=(),
                                         dtype=tf.float32,
                                         validate_args=False,
                                         name=None):
    """Builds a joint surrogate posterior by splitting a normalizing flow.

    A flat base distribution is pushed through `trainable_bijector`, split
    into one vector per posterior component, reshaped and restructured to
    match `event_shape`, and finally (optionally) constrained.

    Args:
      event_shape: (Nested) event shape of the surrogate posterior.
      trainable_bijector: A trainable `tfb.Bijector` instance acting on
        `Tensor`s (not structures), e.g. `tfb.MaskedAutoregressiveFlow` or
        `tfb.RealNVP`. It transforms the base distribution before the result
        is split.
      constraining_bijector: `tfb.Bijector` instance, or nested structure of
        `tfb.Bijector` instances, mapping (nested) values in R^n to the
        support of the posterior. `None` entries inside a nested structure are
        replaced by identity bijectors. (This can be the
        `experimental_default_event_space_bijector` of the distribution over
        the prior latent variables.)
        Default value: `None` (i.e., the posterior is over R^n).
      base_distribution: A `tfd.Distribution` subclass parameterized by `loc`
        and `scale`. It is instantiated with `loc=0.` (of shape `batch_shape`)
        and `scale=1.`.
        Default value: `tfd.Normal`.
      batch_shape: The `batch_shape` of the output distribution.
        Default value: `()`.
      dtype: The `dtype` of the surrogate posterior.
        Default value: `tf.float32`.
      validate_args: Python `bool`. Whether to validate input with asserts.
        This imposes a runtime cost. If `validate_args` is `False`, and the
        inputs are invalid, correct behavior is not guaranteed.
        Default value: `False`.
      name: Python `str` name prefixed to ops created by this function.
        Default value: `None` (i.e., 'build_split_flow_surrogate_posterior').

    Returns:
      surrogate_distribution: Trainable `tfd.TransformedDistribution` with
        event shape equal to `event_shape`.

    ### Examples
    ```python

    # Train a normalizing flow on the Eight Schools model [1].

    treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
    treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
    model = tfd.JointDistributionNamed({
        'avg_effect':
            tfd.Normal(loc=0., scale=10., name='avg_effect'),
        'log_stddev':
            tfd.Normal(loc=5., scale=1., name='log_stddev'),
        'school_effects':
            lambda log_stddev, avg_effect: (
                tfd.Independent(
                    tfd.Normal(
                        loc=avg_effect[..., None] * tf.ones(8),
                        scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
                        name='school_effects'),
                    reinterpreted_batch_ndims=1)),
        'treatment_effects': lambda school_effects: tfd.Independent(
            tfd.Normal(loc=school_effects, scale=treatment_stddevs),
            reinterpreted_batch_ndims=1)
    })

    # Pin the observed values in the model.
    target_model = model.experimental_pin(treatment_effects=treatment_effects)

    # Create a Masked Autoregressive Flow bijector.
    net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
    maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)

    # Build and fit the surrogate posterior.
    surrogate_posterior = (
        tfp.experimental.vi.build_split_flow_surrogate_posterior(
            event_shape=target_model.event_shape_tensor(),
            trainable_bijector=maf,
            constraining_bijector=(
                target_model.experimental_default_event_space_bijector())))

    losses = tfp.vi.fit_surrogate_posterior(
        target_model.unnormalized_log_prob,
        surrogate_posterior,
        num_steps=100,
        optimizer=tf.optimizers.Adam(0.1),
        sample_size=10)
    ```

    #### References

    [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
        Donald Rubin. Bayesian Data Analysis, Third Edition.
        Chapman and Hall/CRC, 2013.

    """

    def _identity_if_none(bij):
        # Missing entries in a nested constraining structure mean "leave this
        # component unconstrained".
        if bij is None:
            return identity.Identity()
        return bij

    with tf.name_scope(name or 'build_split_flow_surrogate_posterior'):

        # Convert each component's shape to a shape tensor, treating every
        # shape as a leaf of the structure.
        target_event_shape = nest.map_structure_up_to(
            _get_event_shape_shallow_structure(event_shape),
            ps.convert_to_shape_tensor,
            event_shape)

        # Wrap a nested structure of bijectors into a single joint bijector.
        constrainer = constraining_bijector
        if nest.is_nested(constrainer):
            constrainer = joint_map.JointMap(
                nest.map_structure(_identity_if_none, constrainer),
                validate_args=validate_args)

        # Event shape on the unconstrained side of the constraining bijector.
        if constrainer is not None:
            latent_shape = constrainer.inverse_event_shape_tensor(
                target_event_shape)
        else:
            latent_shape = target_event_shape

        part_shapes = nest.flatten(latent_shape)
        part_sizes = nest.map_structure(tf.reduce_prod, part_shapes)
        total_size = tf.reduce_sum(part_sizes)

        # Flat base distribution: `total_size` draws with loc=0 and scale=1.
        flat_base = sample.Sample(
            base_distribution(tf.zeros(batch_shape, dtype=dtype), scale=1.),
            [total_size])

        # Once `trainable_bijector` has transformed the base samples, cut the
        # flat vector into one vector per component.
        splitter = split.Split(part_sizes, validate_args=validate_args)

        # Give each component vector its (unconstrained) posterior event
        # shape.
        reshaper = joint_map.JointMap(
            nest.map_structure(reshape.Reshape, latent_shape),
            validate_args=validate_args)

        # Arrange the flat list of components into the posterior's structure.
        unflattener = restructure.Restructure(
            nest.pack_sequence_as(latent_shape, range(len(part_shapes))))

        # The constraining bijector, when present, goes at the head of the
        # chain.
        stages = [reshaper, unflattener, splitter, trainable_bijector]
        if constrainer is not None:
            stages.insert(0, constrainer)
        flow = chain.Chain(stages, validate_args=validate_args)

        return transformed_distribution.TransformedDistribution(
            flat_base, bijector=flow, validate_args=validate_args)
Example #3
0
  def __init__(self,
               bijectors,
               block_sizes=None,
               validate_args=False,
               maybe_changes_size=True,
               name=None):
    """Creates the bijector.

    The result is assembled from three stages: a split of the input vector
    into blocks, a joint map applying one bijector per block, and a
    concatenation (an inverted split) of the transformed blocks.

    Args:
      bijectors: A non-empty list of bijectors.
      block_sizes: A 1-D integer `Tensor` with each element signifying the
        length of the block of the input vector to pass to the corresponding
        bijector. The length of `block_sizes` must be equal to the length of
        `bijectors`. If left as None, a vector of 1's is used.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      maybe_changes_size: Python `bool` indicating that this bijector might
        change the event size. If this is known to be false and set
        appropriately, then this will lead to improved static shape inference
        when the block sizes are not statically known.
      name: Python `str`, name given to ops managed by this object. Default:
        'blockwise_of_' followed by the bijector names joined with '_and_'
        (with any '/' removed), e.g.
        `Blockwise([Exp(), Softplus()]).name ==
        'blockwise_of_exp_and_softplus'`.

    Raises:
      ValueError: If any bijector is multi-part (nested
        `forward_min_event_ndims` or `inverse_min_event_ndims`).
      ValueError: If any bijector with a Python `int`
        `forward_min_event_ndims` changes rank, or has
        `forward_min_event_ndims` > 1.
      ValueError: Block-size validation is delegated to
        `_validate_block_sizes`, which is not shown here; presumably it
        raises when `block_sizes` is not a vector or its length differs from
        that of `bijectors` -- confirm against that helper.

    NOTE(review): earlier documentation also promised a `ValueError` for an
    empty `bijectors` list; no explicit check for that is visible in this
    constructor -- confirm that a callee raises it.
    """
    # Must stay the first statement so that only the constructor arguments
    # (and `self`) are captured.
    parameters = dict(locals())
    if not name:
      joined_names = '_and_'.join(b.name for b in bijectors)
      name = ('blockwise_of_' + joined_names).replace('/', '')

    with tf.name_scope(name) as name:
      for part in bijectors:
        fwd_ndims = part.forward_min_event_ndims
        if (nest.is_nested(fwd_ndims)
            or nest.is_nested(part.inverse_min_event_ndims)):
          raise ValueError('Bijectors must all be single-part.')
        if not isinstance(fwd_ndims, int):
          # The remaining checks only apply to Python-int ranks.
          continue
        if fwd_ndims != part.inverse_min_event_ndims:
          raise ValueError('Rank-changing bijectors are not supported.')
        if fwd_ndims > 1:
          raise ValueError('Only scalar and vector event-shape '
                           'bijectors are supported at this time.')

      joint = joint_map.JointMap(list(bijectors), name='jointmap')

      # Default to one input element per bijector.
      if block_sizes is None:
        block_sizes = np.ones(len(bijectors), dtype=np.int32)
      else:
        block_sizes = _validate_block_sizes(
            block_sizes, bijectors, validate_args)
      splitter = split.Split(
          block_sizes, name='split', validate_args=validate_args)

      if maybe_changes_size:
        # Push each input block size through its bijector to obtain the
        # output block sizes.
        per_bijector_sizes = ps.split(block_sizes, len(bijectors))
        output_sizes = _validate_block_sizes(
            ps.concat(
                joint.forward_event_shape_tensor(per_bijector_sizes),
                axis=0),
            bijectors, validate_args)
        # Only a statically-known "all equal" rules out a size change; an
        # unknown static value keeps the flag set.
        sizes_match = tf.get_static_value(
            ps.reduce_all(block_sizes == output_sizes))
        maybe_changes_size = not sizes_match

      # Concatenation is the inverse of a split; use the output block sizes
      # when they may differ from the input ones.
      if maybe_changes_size:
        concat_target = split.Split(output_sizes, name='isplit')
      else:
        concat_target = splitter
      concatenator = invert.Invert(concat_target, name='concat')

      self._maybe_changes_size = maybe_changes_size
      super(Blockwise, self).__init__(
          bijectors=[concatenator, joint, splitter],
          validate_args=validate_args,
          parameters=parameters,
          name=name)