    def testSmartConstructors(self):
        x = collections.OrderedDict([('a', [1, 2, 3]), ('b', [4, 5, 6]),
                                     ('c', 7.), ('d', 8.), ('e', 9.)])
        bij = tfb.tree_flatten(x)

        bij_2 = tfb.pack_sequence_as(x)

        # Check that `tree_flatten` flattens the structure into the list of
        # its leaves, and that `pack_sequence_as` restores the structure.
        self.assertAllEqualNested(bij.forward(x),
                                  [1, 2, 3, 4, 5, 6, 7., 8., 9.],
                                  check_types=True)
        self.assertAllEqualNested(bij_2.forward(bij.forward(x)),
                                  x,
                                  check_types=True)
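
# A minimal usage sketch of the two smart constructors above (illustrative,
# not part of the test): `tree_flatten(x)` builds a `Restructure` bijector
# that flattens `x`-shaped structures into the list of their leaves, and
# `pack_sequence_as(x)` builds its inverse.
#
#   import tensorflow_probability as tfp
#   tfb = tfp.bijectors
#
#   x = {'a': 1., 'b': [2., 3.]}
#   flat = tfb.tree_flatten(x).forward(x)             # [1., 2., 3.]
#   repacked = tfb.pack_sequence_as(x).forward(flat)  # == x
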
def init_near_unconstrained_zero(
    model=None, constraining_bijector=None, event_shapes=None,
    event_shape_tensors=None, batch_shapes=None, batch_shape_tensors=None,
    dtypes=None):
  """Returns an initialization Distribution for starting a Markov chain.

  This initialization scheme follows Stan: we sample every latent
  independently, uniformly from -2 to 2 in its unconstrained space,
  and then transform into constrained space to construct an initial
  state that can be passed to `sample_chain` or other MCMC drivers.
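
  For instance, for a single positive-constrained latent the scheme amounts
  to the following (a minimal sketch; `Gamma` is just an illustrative choice
  of constrained distribution):

  ```python
  tfd = tfp.distributions
  b = tfd.Gamma(2., 2.).experimental_default_event_space_bijector()
  init = tfd.TransformedDistribution(
      tfd.Uniform(low=-2., high=2.), bijector=b)
  init.sample()  # A positive scalar: `b` maps the reals onto the positives.
  ```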

  The argument signature is arranged to let the user pass either a
  `JointDistribution` describing their model, if it's in that form, or
  the essential information necessary for the sampling, namely a
  bijector (from unconstrained to constrained space) and the desired
  shape and dtype of each sample (specified in constrained space).  Both
  calling conventions are shown in the examples below.

  Note: The batch shape of the supplied model (or the explicit
  `batch_shapes`/`batch_shape_tensors` arguments) is broadcast to every
  part of the state, so each part of the sampled initial state carries
  that batch shape.

  Args:
    model: A `Distribution` (typically a `JointDistribution`) giving the
      model to be initialized.  If supplied, it is queried for
      its default event space bijector, its event shape, and its dtype.
      If not supplied, those three elements must be supplied instead.
    constraining_bijector: A (typically multipart) `Bijector` giving
      the mapping from unconstrained to constrained space.  If
      supplied together with a `model`, acts as an override.  A nested
      structure of `Bijector`s is accepted, and interpreted as
      applying in parallel to a corresponding structure of state parts
      (see `JointMap` for details).
    event_shapes: A structure of shapes giving the (unconstrained)
      event space shape of the desired samples.  Must be an acceptable
      input to `constraining_bijector.inverse_event_shape`.  If
      supplied together with `model`, acts as an override.
    event_shape_tensors: A structure of tensors giving the (unconstrained)
      event space shape of the desired samples.  Must be an acceptable
      input to `constraining_bijector.inverse_event_shape_tensor`.  If
      supplied together with `model`, acts as an override. Required if any of
      `event_shapes` are not fully-defined.
    batch_shapes: A structure of shapes giving the batch shape of the desired
      samples.  If supplied together with `model`, acts as an override.  If
      unspecified, we assume scalar batch `[]`.
    batch_shape_tensors: A structure of tensors giving the batch shape of the
      desired samples.  If supplied together with `model`, acts as an override.
      Required if any of `batch_shapes` are not fully-defined.
    dtypes: A structure of dtypes giving the (unconstrained) dtypes of
      the desired samples.  Must be an acceptable input to
      `constraining_bijector.inverse_dtype`.  If supplied together
      with `model`, acts as an override.

  Returns:
    init_dist: A `Distribution` representing the initialization
      distribution, in constrained space.  Samples from this
      `Distribution` are valid initial states for a Markov chain
      targeting the model.

  #### Examples

  Initialize 100 chains from the uniform(-2, 2) distribution in
  unconstrained space, for a model expressed as a
  `JointDistributionCoroutine`:

  ```python
  @tfp.distributions.JointDistributionCoroutine
  def model():
    ...

  init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)
  states = tfp.mcmc.sample_chain(
    current_state=init_dist.sample(100, seed=[4, 8]),
    ...)
  ```
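
  Alternatively, supply the essential pieces directly instead of a `model`
  (a sketch; the particular bijector, shapes, and dtypes here are
  illustrative):

  ```python
  tfb = tfp.bijectors
  init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(
      constraining_bijector={'scale': tfb.Softplus()},
      event_shapes={'scale': tf.TensorShape([])},
      dtypes={'scale': tf.float32})
  ```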

  """
  # Canonicalize arguments into the parts we need: the constraining
  # bijector and the event shapes, batch shapes, and dtypes (plus shape
  # tensors where static shapes are not fully defined).
  if model is not None:
    # Got a Distribution model; treat other arguments as overrides if
    # present.
    if constraining_bijector is None:
      constraining_bijector = model.experimental_default_event_space_bijector()
    if event_shapes is None:
      event_shapes = model.event_shape
    if event_shape_tensors is None:
      event_shape_tensors = model.event_shape_tensor()
    if dtypes is None:
      dtypes = model.dtype
    if batch_shapes is None:
      batch_shapes = nest_util.broadcast_structure(dtypes, model.batch_shape)
    if batch_shape_tensors is None:
      batch_shape_tensors = nest_util.broadcast_structure(
          dtypes, model.batch_shape_tensor())

  else:
    if constraining_bijector is None or event_shapes is None or dtypes is None:
      msg = ('Must pass either a Distribution (typically a JointDistribution), '
             'or a bijector, a structure of event shapes, and a '
             'structure of dtypes')
      raise ValueError(msg)
    event_shapes_fully_defined = all(tensorshape_util.is_fully_defined(s)
                                     for s in tf.nest.flatten(event_shapes))
    if not event_shapes_fully_defined and event_shape_tensors is None:
      raise ValueError('Must specify `event_shape_tensors` when `event_shapes` '
                       f'are not fully-defined: {event_shapes}')
    if batch_shapes is None:
      batch_shapes = tf.TensorShape([])
    batch_shapes = nest_util.broadcast_structure(dtypes, batch_shapes)
    batch_shapes_fully_defined = all(tensorshape_util.is_fully_defined(s)
                                     for s in tf.nest.flatten(batch_shapes))
    if batch_shape_tensors is None:
      if not batch_shapes_fully_defined:
        raise ValueError(
            'Must specify `batch_shape_tensors` when `batch_shapes` are not '
            f'fully-defined: {batch_shapes}')
      batch_shape_tensors = tf.nest.map_structure(
          tf.convert_to_tensor, batch_shapes)

  # Interpret a structure of Bijectors as the joint multipart bijector.
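  # E.g., a dict {'a': tfb.Exp(), 'b': tfb.Softplus()} (illustrative) becomes
  # tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Softplus()}), which applies each
  # bijector to the corresponding state part.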
  if not isinstance(constraining_bijector, tfb.Bijector):
    constraining_bijector = tfb.JointMap(constraining_bijector)

  # Build the unconstrained initialization: one Uniform(-2, 2) term per
  # state part.
  def one_term(event_shape, event_shape_tensor, batch_shape, batch_shape_tensor,
               dtype):
    if not tensorshape_util.is_fully_defined(event_shape):
      event_shape = event_shape_tensor
    result = tfd.Sample(
        tfd.Uniform(low=tf.constant(-2., dtype=dtype),
                    high=tf.constant(2., dtype=dtype)),
        sample_shape=event_shape)
    if not tensorshape_util.is_fully_defined(batch_shape):
      batch_shape = batch_shape_tensor
      needs_bcast = True
    else:  # Only batch broadcast when batch ndims > 0.
      needs_bcast = bool(tensorshape_util.as_list(batch_shape))
    if needs_bcast:
      result = tfd.BatchBroadcast(result, batch_shape)
    return result

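  # Map the constrained event shapes and dtypes through the bijector's
  # inverse to get their unconstrained-space counterparts.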
  inv_shapes = constraining_bijector.inverse_event_shape(event_shapes)
  if event_shape_tensors is not None:
    inv_shape_tensors = constraining_bijector.inverse_event_shape_tensor(
        event_shape_tensors)
  else:
    inv_shape_tensors = tf.nest.map_structure(lambda _: None, inv_shapes)
  inv_dtypes = constraining_bijector.inverse_dtype(dtypes)
  terms = tf.nest.map_structure(
      one_term, inv_shapes, inv_shape_tensors, batch_shapes,
      batch_shape_tensors, inv_dtypes)
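  # Repack the flat list of per-part distributions into the user's container
  # structure, then transform into constrained space.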
  unconstrained = tfb.pack_sequence_as(inv_shapes)(
      tfd.JointDistributionSequential(tf.nest.flatten(terms)))
  return tfd.TransformedDistribution(
      unconstrained, bijector=constraining_bijector)