Example #1
  def __init__(self, model):
    """Constructs the adapter.

    Args:
      model: An Inference Gym model.

    Raises:
      TypeError: If `model` has more than one unique Tensor dtype.
    """
    self._model = model
    dtypes = set(
        tf.nest.flatten(tf.nest.map_structure(tf.as_dtype, self._model.dtype)))
    if len(dtypes) > 1:
      raise TypeError('Model must have only one Tensor dtype, saw: {}'.format(
          self._model.dtype))
    dtype = dtypes.pop()

    # TODO(siege): Make this work with multi-part default_event_bijector.
    def _make_reshaped_bijector(b, s):
      return tfb.Chain([
          tfb.Reshape(event_shape_in=s, event_shape_out=[ps.reduce_prod(s)]),
          b,
          tfb.Reshape(event_shape_out=b.inverse_event_shape(s)),
      ])

    reshaped_bijector = tf.nest.map_structure(
        _make_reshaped_bijector, self._model.default_event_space_bijector,
        self._model.event_shape)

    bijector = tfb.Blockwise(
        bijectors=tf.nest.flatten(reshaped_bijector),
        block_sizes=tf.nest.flatten(
            tf.nest.map_structure(
                lambda b, s: ps.reduce_prod(b.inverse_event_shape(s)),  # pylint: disable=g-long-lambda
                self._model.default_event_space_bijector,
                self._model.event_shape)))

    event_sizes = tf.nest.map_structure(
        lambda b, s: ps.reduce_prod(b.inverse_event_shape(s)),
        self._model.default_event_space_bijector, self._model.event_shape)
    event_shape = tf.TensorShape([sum(tf.nest.flatten(event_sizes))])

    sample_transformations = collections.OrderedDict()

    def make_flattened_transform(transform):
      # We yank this out to avoid capturing the loop variable.
      return transform._replace(
          fn=lambda x: transform(self._split_and_reshape_event(x)))

    for key, transform in self._model.sample_transformations.items():
      sample_transformations[key] = make_flattened_transform(transform)

    super(VectorModel, self).__init__(
        default_event_space_bijector=bijector,
        event_shape=event_shape,
        dtype=dtype,
        name='vector_' + self._model.name,
        pretty_name=str(self._model),
        sample_transformations=sample_transformations,
    )
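
Below is a minimal sketch of the flattening idea used above, with only public `tfb` APIs; the two event parts, their shapes, and the Softplus constraint are illustrative, not taken from an actual Inference Gym model.

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

def make_reshaped(b, shape):
  # Flat vector -> structured shape -> constraining bijector -> flat vector.
  size = int(tf.reduce_prod(shape))
  return tfb.Chain([
      tfb.Reshape(event_shape_in=shape, event_shape_out=[size]),
      b,
      tfb.Reshape(event_shape_in=[size], event_shape_out=shape),
  ])

# Two toy parts: a Softplus-constrained [2, 3] part and an unconstrained [4] part.
vector_bijector = tfb.Blockwise(
    bijectors=[make_reshaped(tfb.Softplus(), [2, 3]),
               make_reshaped(tfb.Identity(), [4])],
    block_sizes=[6, 4])
constrained = vector_bijector.forward(tf.zeros([10]))  # a single length-10 vector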
Example #2
def make_conditional_linear_gaussian(y_event_shape,
                                     x,
                                     x_event_ndims,
                                     variables=None):
    """Build trainable distribution `p(y | x)` conditioned on an input Tensor `x`.

  The distribution is independent Gaussian with mean linearly transformed
  from `x`:
  `y ~ N(loc=matvec(matrix, x) + loc, scale_diag=scale)`

  Args:
    y_event_shape: int `Tensor` event shape.
    x: `Tensor` input to condition on.
    x_event_ndims: int number of dimensions in `x`'s `event_shape`.
    variables: Optional `LinearGaussianVariables` instance, or `None`.
      Default value: `None`.

  Returns:
    dist: Instance of `tfd.Distribution` representing the conditional
      distribution `p(y | x)`.
    variables: Instance of `LinearGaussianVariables` used to parameterize
      `dist`. If a `variables` arg was passed, it is returned unmodified;
      otherwise new variables are created.
  """
    x_shape = ps.shape(x)
    x_ndims = ps.rank_from_shape(x_shape)
    y_event_ndims = ps.rank_from_shape(y_event_shape)
    batch_shape, x_event_shape = (x_shape[:x_ndims - x_event_ndims],
                                  x_shape[x_ndims - x_event_ndims:])

    x_event_size = ps.reduce_prod(x_event_shape)
    y_event_size = ps.reduce_prod(y_event_shape)

    x_flat_shape = ps.concat([batch_shape, [x_event_size]], axis=0)
    y_flat_shape = ps.concat([batch_shape, [y_event_size]], axis=0)
    y_full_shape = ps.concat([batch_shape, y_event_shape], axis=0)

    if variables is None:
        variables = LinearGaussianVariables(
            matrix=tf.Variable(tf.random.normal(ps.concat(
                [batch_shape, [y_event_size, x_event_size]], axis=0),
                                                dtype=x.dtype),
                               name='matrix'),
            loc=tf.Variable(tf.random.normal(y_flat_shape, dtype=x.dtype),
                            name='loc'),
            scale=tfp_util.TransformedVariable(tf.ones(y_full_shape,
                                                       dtype=x.dtype),
                                               bijector=tfb.Softplus(),
                                               name='scale'))

    flat_x = tf.reshape(x, x_flat_shape)
    dist = tfd.Normal(loc=tf.reshape(
        tf.linalg.matvec(variables.matrix, flat_x) + variables.loc,
        y_full_shape),
                      scale=variables.scale)
    if y_event_ndims != 0:
        dist = tfd.Independent(dist, reinterpreted_batch_ndims=y_event_ndims)
    dist._also_track = variables  # pylint: disable=protected-access
    return dist, variables
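
As a rough usage sketch (public TFP APIs only; the shapes and variable names below are illustrative and not part of the helper above), the same conditional Gaussian can be written directly as:

import tensorflow as tf
import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

x = tf.random.normal([8, 4])                     # batch of 8 inputs, event size 4
matrix = tf.Variable(tf.random.normal([3, 4]))   # maps event size 4 -> 3
loc = tf.Variable(tf.zeros([3]))
scale = tfp.util.TransformedVariable(tf.ones([3]), bijector=tfb.Softplus())

# p(y | x) = N(matvec(matrix, x) + loc, scale), independent over the event.
dist = tfd.Independent(
    tfd.Normal(loc=tf.linalg.matvec(matrix, x) + loc, scale=scale),
    reinterpreted_batch_ndims=1)
# dist.batch_shape == [8], dist.event_shape == [3]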
Example #3
 def _make_reshaped_bijector(b, s):
   return tfb.Chain([
       tfb.Reshape(event_shape_in=s, event_shape_out=[ps.reduce_prod(s)]),
       b,
       tfb.Reshape(
           event_shape_in=[ps.reduce_prod(b.inverse_event_shape(s))],
           event_shape_out=b.inverse_event_shape(s)),
   ])
Example #4
def pairwise_square_distance_matrix(x1, x2, feature_ndims):
    """Returns pairwise square distance between x1 and x2.

  Given `x1` and `x2`, Tensors with shape `[..., N, D1, ... Dk]` and
  `[..., M, D1, ... Dk]`, compute the pairwise distance matrix `a_ij` of shape
  `[..., N, M]`, where each entry `a_ij` is the square of the euclidean norm of
  `x1[..., i, ...] - x2[..., j, ...]`.

  The approach uses the following identity (shown here for k = 1):
  ```none
    a_ij = sum_d (x1[i, d] - x2[j, d]) ** 2 =
    sum_d x1[i, d] ** 2 + x2[j, d] ** 2 - 2 * x1[i, d] * x2[j, d]
  ```

  The latter term can be written as a matmul between `x1` and `x2`.
  This reduces the memory from the naive approach of computing the
  squared difference of `x1` and `x2` by a factor of `(prod_k D_k) ** 2`.
  This is at the cost of the computation being more numerically unstable.

  Args:
    x1: Floating point `Tensor` with shape `B1 + [N] + [D1, ..., Dk]`,
      where `B1` is a (possibly empty) batch shape.
    x2: Floating point `Tensor` with shape `B2 + [M] + [D1, ..., Dk]`,
      where `B2` is a (possibly empty) batch shape that broadcasts
      with `B1`.
    feature_ndims: The number of dimensions to consider for the euclidean
      norm. This is `k` from above.
  Returns:
    `Tensor` of shape `[..., N, M]` representing the pairwise square
    distance matrix.
  """
    row_norm_x1 = sum_rightmost_ndims_preserving_shape(
        tf.square(x1), feature_ndims)[..., tf.newaxis]
    row_norm_x2 = sum_rightmost_ndims_preserving_shape(
        tf.square(x2), feature_ndims)[..., tf.newaxis, :]

    x1 = tf.reshape(
        x1,
        ps.concat([
            ps.shape(x1)[:-feature_ndims],
            [ps.reduce_prod(ps.shape(x1)[-feature_ndims:])]
        ],
                  axis=0))
    x2 = tf.reshape(
        x2,
        ps.concat([
            ps.shape(x2)[:-feature_ndims],
            [ps.reduce_prod(ps.shape(x2)[-feature_ndims:])]
        ],
                  axis=0))
    pairwise_sq = row_norm_x1 + row_norm_x2 - 2 * tf.linalg.matmul(
        x1, x2, transpose_b=True)
    pairwise_sq = tf.clip_by_value(pairwise_sq, 0., np.inf)
    return pairwise_sq
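
A quick NumPy check of the identity above (the shapes here are arbitrary test values):

import numpy as np

rng = np.random.default_rng(0)
x1 = rng.normal(size=(5, 4))    # N = 5 points with D = 4 features
x2 = rng.normal(size=(7, 4))    # M = 7 points

naive = ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(-1)                   # [5, 7]
fast = (x1**2).sum(-1)[:, None] + (x2**2).sum(-1)[None, :] - 2 * x1 @ x2.T
assert np.allclose(naive, fast)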
Example #5
 def _augment_sample_shape(self, sample_shape):
     # Suppose we have:
     #   - sample shape of `[n]`,
     #   - underlying distribution batch shape of `[2, 1]`,
     #   - final broadcast batch shape of `[4, 2, 3]`.
     # Then we must draw `sample_shape + [12]` samples, where
     # `12 == n_batch // underlying_n_batch`.
     batch_shape = self.batch_shape_tensor()
     n_batch = ps.reduce_prod(batch_shape)
     underlying_batch_shape = self.distribution.batch_shape_tensor()
     underlying_n_batch = ps.reduce_prod(underlying_batch_shape)
     return ps.concat(
         [sample_shape, [ps.maximum(0, n_batch // underlying_n_batch)]],
         axis=0)
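
The arithmetic in the comment, spelled out with the hypothetical shapes it mentions:

import numpy as np

sample_shape = [7]                                 # `[n]` with n == 7
underlying_batch_shape = [2, 1]
broadcast_batch_shape = [4, 2, 3]
extra = max(0, np.prod(broadcast_batch_shape) // np.prod(underlying_batch_shape))
augmented = sample_shape + [int(extra)]            # [7, 12]: 24 // 2 == 12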
Example #6
  def _split_and_reshape_event(self, x):
    event_tensors = self._distribution.event_shape_tensor()
    splits = [
        ps.maximum(1, ps.reduce_prod(s))
        for s in tf.nest.flatten(event_tensors)
    ]
    x = tf.nest.pack_sequence_as(event_tensors, tf.split(x, splits, axis=-1))

    def _reshape_part(part, dtype, event_shape):
      part = tf.cast(part, dtype)
      static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
      if static_rank == 1:
        return part
      new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
      return tf.reshape(part, ps.cast(new_shape, tf.int32))

    if all(
        tensorshape_util.is_fully_defined(s)
        for s in tf.nest.flatten(self._distribution.event_shape)):
      x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                                self._distribution.event_shape)
    else:
      x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                                self._distribution.event_shape_tensor())
    return x
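
A minimal sketch of the split-then-reshape step, using a hypothetical two-part event (part `a` with event shape [2, 2], part `b` with event shape [3]):

import tensorflow as tf

flat = tf.range(7.)                                # flat event vector of size 4 + 3
a_flat, b_flat = tf.split(flat, [4, 3], axis=-1)
a = tf.reshape(a_flat, [2, 2])
b = b_flat                                         # rank-1 parts need no reshape
# a.shape == (2, 2), b.shape == (3,)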
Example #7
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         # Avoid computing intermediates needed to construct the assertions.
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self._batch_shape_unexpanded):
         implicit_dim_mask = ps.equal(self._batch_shape_unexpanded, -1)
         assertions.append(
             assert_util.assert_rank(self._batch_shape_unexpanded,
                                     1,
                                     message='New shape must be a vector.'))
         assertions.append(
             assert_util.assert_less_equal(
                 tf.math.count_nonzero(implicit_dim_mask, dtype=tf.int32),
                 1,
                 message='At most one dimension can be unknown.'))
         assertions.append(
             assert_util.assert_non_negative(
                 self._batch_shape_unexpanded + 1,
                 message='Shape elements must be >=-1.'))
         # Check that the old and new shapes are the same size.
         expanded_new_shape, original_size = self._calculate_new_shape()
         new_size = ps.reduce_prod(expanded_new_shape)
         assertions.append(
             assert_util.assert_equal(new_size,
                                      tf.cast(original_size,
                                              new_size.dtype),
                                      message='Shape sizes do not match.'))
     return assertions
Example #8
def _axis_size(x, axis=None):
    """Get number of elements of `x` in `axis`, as type `x.dtype`."""
    if axis is None:
        return prefer_static.cast(prefer_static.size(x), x.dtype)
    return prefer_static.cast(
        prefer_static.reduce_prod(
            prefer_static.gather(prefer_static.shape(x), axis)), x.dtype)
Example #9
 def _sample_direction_part(state_part, part_seed):
     state_part_shape = ps.shape(state_part)
     batch_shape = state_part_shape[:batch_rank]
     dimension = ps.reduce_prod(state_part_shape[batch_rank:])
     return ps.reshape(
         random_ops.spherical_uniform(shape=batch_shape,
                                      dimension=dimension,
                                      dtype=state_part.dtype,
                                      seed=part_seed), state_part_shape)
Example #10
def iid_sample(sample_fn, sample_shape):
  """Lift a sampling function to one that draws multiple iid samples.

  Args:
    sample_fn: Python `callable` that returns a (possibly nested) structure of
      `Tensor`s. May optionally take a `seed` named arg: if so, any `int`
      seeds (for stateful samplers) are passed through directly, while any
      pair-of-`int` seeds (for stateless samplers) are split into independent
      seeds for each sample.
    sample_shape: `int` `Tensor` shape of iid samples to draw.
  Returns:
    iid_sample_fn: Python `callable` taking the same arguments as `sample_fn`
      and returning iid samples. Each returned `Tensor` will have shape
      `concat([sample_shape, shape_of_original_returned_tensor])`.
  """
  sample_shape = distribution_util.expand_to_vector(
      ps.cast(sample_shape, np.int32), tensor_name='sample_shape')
  n = ps.cast(ps.reduce_prod(sample_shape), dtype=np.int32)

  def unflatten(x):
    unflattened_shape = ps.cast(
        ps.concat([sample_shape, ps.shape(x)[1:]], axis=0),
        dtype=np.int32)
    return tf.reshape(x, unflattened_shape)

  def iid_sample_fn(*args, **kwargs):
    """Draws iid samples from `fn`."""

    with tf.name_scope('iid_sample_fn'):

      seed = kwargs.pop('seed', None)
      if samplers.is_stateful_seed(seed):
        kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())
        def pfor_loop_body(_):
          with tf.name_scope('iid_sample_fn_stateful_body'):
            return sample_fn(*args, **kwargs)
      else:
        # If a stateless seed arg is passed, split it into `n` different
        # stateless seeds, so that we don't just get a bunch of copies of the
        # same sample.
        if not JAX_MODE:
          warnings.warn(
              'Saw Tensor seed {}, implying stateless sampling. Autovectorized '
              'functions that use stateless sampling may be quite slow because '
              'the current implementation falls back to an explicit loop. This '
              'will be fixed in the future. For now, you will likely see '
              'better performance from stateful sampling, which you can invoke '
              'by passing a Python `int` seed.'.format(seed))
        seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')
        def pfor_loop_body(i):
          with tf.name_scope('iid_sample_fn_stateless_body'):
            return sample_fn(*args, seed=tf.gather(seed, i), **kwargs)

      draws = parallel_for.pfor(pfor_loop_body, n)
      return tf.nest.map_structure(unflatten, draws, expand_composites=True)

  return iid_sample_fn
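
A minimal sketch of the stateless branch above, using only public APIs (`tfp.random.split_seed` and a toy stateless sampler; `toy_sample_fn` and the seed values are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

def toy_sample_fn(seed):
  return tf.random.stateless_normal([3], seed=seed)

sample_shape = [2, 3]
n = int(tf.reduce_prod(sample_shape))
seeds = tfp.random.split_seed([0, 42], n=n, salt='iid_sample_sketch')
draws = tf.stack([toy_sample_fn(seed=s) for s in seeds], axis=0)   # [6, 3]
iid = tf.reshape(draws, sample_shape + [3])                        # [2, 3, 3]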
Example #11
def iid_sample(sample_fn, sample_shape):
    """Lift a sampling function to one that draws multiple iid samples.

  Args:
    sample_fn: Python `callable` that returns a (possibly nested) structure of
      `Tensor`s. May optionally take a `seed` named arg: if so, any `int`
      seeds (for stateful samplers) are passed through directly, while any
      pair-of-`int` seeds (for stateless samplers) are split into independent
      seeds for each sample.
    sample_shape: `int` `Tensor` shape of iid samples to draw.
  Returns:
    iid_sample_fn: Python `callable` taking the same arguments as `sample_fn`
      and returning iid samples. Each returned `Tensor` will have shape
      `concat([sample_shape, shape_of_original_returned_tensor])`.
  """
    sample_shape = distribution_util.expand_to_vector(
        prefer_static.cast(sample_shape, np.int32), tensor_name='sample_shape')
    n = prefer_static.cast(prefer_static.reduce_prod(sample_shape),
                           dtype=np.int32)

    def unflatten(x):
        unflattened_shape = prefer_static.cast(prefer_static.concat(
            [sample_shape, prefer_static.shape(x)[1:]], axis=0),
                                               dtype=np.int32)
        return tf.reshape(x, unflattened_shape)

    def iid_sample_fn(*args, **kwargs):
        """Draws iid samples from `fn`."""

        pfor_loop_body = lambda _: sample_fn(*args, **kwargs)

        seed = kwargs.pop('seed', None)
        try:  # Assume that `seed` is a valid stateful seed (Python `int`).
            kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())
            pfor_loop_body = lambda _: sample_fn(*args, **kwargs)
        except TypeError as e:
            # If a stateless seed arg is passed, split it into `n` different stateless
            # seeds, so that we don't just get a bunch of copies of the same sample.
            if TENSOR_SEED_MSG_PREFIX not in str(e):
                raise
            warnings.warn(
                'Saw non-`int` seed {}, implying stateless sampling. '
                'Autovectorized functions that use stateless sampling '
                'may be quite slow because the current implementation '
                'falls back to an explicit loop. This will be fixed in the '
                'future. For now, you will likely see better performance '
                'from stateful sampling, which you can invoke by passing a '
                'traditional Python `int` seed.'.format(seed))
            seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')
            pfor_loop_body = (
                lambda i: sample_fn(*args, seed=tf.gather(seed, i), **kwargs))

        draws = parallel_for.pfor(pfor_loop_body, n)
        return tf.nest.map_structure(unflatten, draws, expand_composites=True)

    return iid_sample_fn
Example #12
    def _sample_n(self, n, seed=None):
        batch_shape = self.batch_shape_tensor()
        batch_rank = ps.rank_from_shape(batch_shape)
        n_batch = ps.reduce_prod(batch_shape)

        underlying_batch_shape = self.distribution.batch_shape_tensor()
        underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
        underlying_n_batch = ps.reduce_prod(underlying_batch_shape)

        # Left pad underlying shape with any necessary ones.
        underlying_bcast_shp = ps.concat([
            ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                    dtype=underlying_batch_shape.dtype), underlying_batch_shape
        ],
                                         axis=0)

        # Determine how many underlying samples to produce.
        n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
        samps = self.distribution.sample([n, n_bcast_samples], seed=seed)

        is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)

        event_shape = self.event_shape_tensor()
        event_rank = ps.rank_from_shape(event_shape)
        shp = ps.concat([[n],
                         ps.where(is_dim_bcast, batch_shape, 1),
                         underlying_bcast_shp, event_shape],
                        axis=0)
        # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp.
        samps = tf.reshape(samps, shp)
        # Interleave broadcast and underlying axis indices for transpose.
        interleaved_batch_axes = ps.reshape(
            ps.stack([ps.range(batch_rank),
                      ps.range(batch_rank) + batch_rank],
                     axis=-1), [-1]) + 1

        event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
        perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
        samps = tf.transpose(samps, perm=perm)
        # Finally, reshape to the fully-broadcast batch shape.
        return tf.reshape(samps,
                          ps.concat([[n], batch_shape, event_shape], axis=0))
Example #13
def vectorize_over_batch_dims(
    fn, elems, event_shape, batch_shape, vectorized_map=True, fn_output_signature=None
):
    flat_batch_shape = tf.expand_dims(ps.reduce_prod(batch_shape), 0)
    flat_structure = reshape_structure(elems, event_shape, flat_batch_shape)
    if vectorized_map:
        result = tf.vectorized_map(fn, flat_structure, fallback_to_while_loop=False)
    else:
        assert fn_output_signature is not None
        result = tf.map_fn(fn, flat_structure, fn_output_signature=fn_output_signature)
    new_event_shape = tf.nest.map_structure(lambda elem: tf.shape(elem)[1:], result)
    return reshape_structure(result, new_event_shape, batch_shape)
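
The flatten-map-restore pattern above, shown on a standalone example (the shapes and the mapped function are illustrative):

import tensorflow as tf

x = tf.random.normal([4, 5, 3])                    # batch [4, 5], event [3]
flat = tf.reshape(x, [-1, 3])                      # collapse batch dims -> [20, 3]
result = tf.vectorized_map(lambda e: tf.reduce_sum(e ** 2), flat)  # [20]
result = tf.reshape(result, [4, 5])                # restore the batch shape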
Example #14
            def update_running_variance():
                diags = [
                    variance_part.variance()
                    for variance_part in variance_parts
                ]
                new_state_parts = tf.nest.flatten(new_state)
                new_variance_parts = []
                for variance_part, diag, state_part in zip(
                        variance_parts, diags, new_state_parts):
                    # Compute new variance for each variance part, accounting for partial
                    # batching of the variance calculation across chains (ie, some, all,
                    # or none of the chains may share the estimated mass matrix).
                    #
                    # For example, say
                    #
                    # state_part has shape       [2, 3, 4] + [5, 6]  (batch + event)
                    # variance_part has shape          [4] + [5, 6]
                    # log_prob has shape         [2, 3, 4]
                    #
                    # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
                    # matrices, each being shared across a [2, 3]-batch of chains. Note
                    # this division is inferred from the shapes of the state part, the
                    # log_prob, and the user-provided initial running variances.
                    #
                    # Until RunningVariance supports rank > 1 chunking, we need to flatten
                    # the states that go into updating the variance estimates. In the
                    # above example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
                    # fed to `RunningVariance.update(state_part, axis=0)`, recording
                    # 6 new observations in the running variance calculation.
                    # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
                    # the resulting momentum distribution will have batch shape of
                    # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
                    state_rank = ps.rank(state_part)
                    variance_rank = ps.rank(diag)
                    num_reduce_dims = state_rank - variance_rank

                    state_part_shape = ps.shape(state_part)
                    # This reshape adds a 1 when reduce_dims==0, and collapses all the
                    # lead dimensions to a single one otherwise.
                    reshaped_state = ps.reshape(
                        state_part,
                        ps.concat([[
                            ps.reduce_prod(state_part_shape[:num_reduce_dims])
                        ], state_part_shape[num_reduce_dims:]],
                                  axis=0))

                    # The `axis=0` here removes the leading dimension we got from the
                    # reshape above, so the new_variance_parts have the correct shape
                    # again.
                    new_variance_parts.append(
                        variance_part.update(reshaped_state, axis=0))
                return new_variance_parts
Example #15
 def _calculate_new_shape(self):
     # Try to get the old shape statically if available.
     original_shape = self._distribution.batch_shape
     if not tensorshape_util.is_fully_defined(original_shape):
         original_shape = self._distribution.batch_shape_tensor()
     # This is not a check for falseness, it's a check for exactly that shape.
     if original_shape == ():  # pylint: disable=g-explicit-bool-comparison
         # Force the size to be an integer, not a float, when the shape contains no
         # dtype information.
         original_size = 1
     else:
         original_size = ps.reduce_prod(original_shape)
     original_size = ps.cast(original_size, tf.int32)
     # Compute the new shape, filling in the `-1` dimension if present.
     new_shape = self._batch_shape_unexpanded
     implicit_dim_mask = ps.equal(new_shape, -1)
     size_implicit_dim = (original_size //
                          ps.maximum(1, -ps.reduce_prod(new_shape)))
     expanded_new_shape = ps.where(  # Assumes exactly one `-1`.
         implicit_dim_mask, size_implicit_dim, new_shape)
     # Return the original size on the side because one caller would otherwise
     # have to recompute it.
     return expanded_new_shape, original_size
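
A worked example of filling in the implicit `-1` dimension (hypothetical sizes):

import numpy as np

original_size = 24                                  # e.g. old batch shape [4, 6]
new_shape = np.array([-1, 3, 2])
size_implicit_dim = original_size // max(1, -np.prod(new_shape))   # 24 // 6 == 4
expanded_new_shape = np.where(new_shape == -1, size_implicit_dim, new_shape)
# expanded_new_shape == [4, 3, 2]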
Example #16
    def resample(log_weights, current_state, particle_info, seed=None):
      """Resample particles based on importance weights."""
      with tf.name_scope('resample_particles'):
        seed = SeedStream(seed, salt='resample_particles')
        resampling_indexes = tf.random.categorical(
            [log_weights], ps.reduce_prod(*ps.shape(log_weights)), seed=seed())
        next_state = tf.nest.map_structure(
            lambda x: tf.reshape(tf.gather(x, resampling_indexes), ps.shape(x)),
            current_state)
        next_particle_info = tf.nest.map_structure(
            lambda x: tf.reshape(tf.gather(x, resampling_indexes), ps.shape(x)),
            particle_info)

        return next_state, next_particle_info
Example #17
def _split_and_reshape_event(x, model):
  """Splits and reshapes a flat event `x` to match the structure of `model`."""
  splits = [
      ps.maximum(1, ps.reduce_prod(s))
      for s in tf.nest.flatten(model.event_shape)
  ]
  x = tf.nest.pack_sequence_as(model.event_shape, tf.split(x, splits, axis=-1))

  def _reshape_part(part, dtype, event_shape):
    part = tf.cast(part, dtype)
    new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
    return tf.reshape(part, ps.cast(new_shape, tf.int32))

  x = tf.nest.map_structure(_reshape_part, x, model.dtype, model.event_shape)
  return x
Example #18
    def _split_and_reshape_event(self, x):
        splits = [
            ps.maximum(1, ps.reduce_prod(s))
            for s in tf.nest.flatten(self._model.event_shape)
        ]
        x = tf.nest.pack_sequence_as(self._model.event_shape,
                                     tf.split(x, splits, axis=-1))

        def _reshape_part(part, dtype, event_shape):
            part = tf.cast(part, dtype)
            new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
            return tf.reshape(part, ps.cast(new_shape, tf.int32))

        x = tf.nest.map_structure(_reshape_part, x, self._model.dtype,
                                  self._model.event_shape)
        return x
Example #19
def make_momentum_distribution(state_parts,
                               batch_shape,
                               running_variance_parts=None):
    """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    state_parts: List of `Tensor`.
    batch_shape: Batch shape.
    running_variance_parts: Optional, list of `Tensor`
       outputs of `tfp.experimental.stats.RunningVariance.variance()`. Defaults
       to ones with the same shape as state_parts.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
    and `.log_prob` of the sample will have rank equal to `batch_ndims`.
  """
    if running_variance_parts is None:
        running_variance_parts = tf.nest.map_structure(tf.ones_like,
                                                       state_parts)
    distributions = []
    batch_ndims = ps.rank_from_shape(batch_shape)
    for variance_part, state_part in zip(running_variance_parts, state_parts):
        event_shape = state_part.shape[batch_ndims:]
        if not tensorshape_util.is_fully_defined(event_shape):
            event_shape = ps.shape(state_part,
                                   name='state_part_shp')[batch_ndims:]
        variance_tiled = tf.broadcast_to(
            variance_part, ps.concat([batch_shape, event_shape], axis=0))
        nevt = ps.cast(ps.reduce_prod(event_shape), tf.int32)
        variance_flattened = tf.reshape(
            variance_tiled, ps.concat([batch_shape, [nevt]], axis=0))

        distribution = _CompositeTransformedDistribution(
            bijector=_CompositeReshape(event_shape_out=event_shape,
                                       name='reshape_mvnpfl'),
            distribution=(
                _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                    precision_factor=_CompositeLinearOperatorDiag(
                        tf.math.sqrt(variance_flattened)),
                    precision=_CompositeLinearOperatorDiag(variance_flattened),
                    name='momentum')))
        distributions.append(distribution)
    return maybe_make_list_and_batch_broadcast(
        _CompositeJointDistributionSequential(distributions), batch_shape)
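
A minimal public-API sketch of the construction above: a diagonal momentum distribution whose precision equals a (flattened) running-variance estimate, reshaped back to the state part's event shape. The variance values and the [2, 3] event shape are illustrative.

import tensorflow as tf
import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

variance_flattened = tf.constant([0.5, 2.0, 1.0, 1.0, 4.0, 0.25])  # event size 6
momentum = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(
        # precision == variance  =>  scale_diag == 1 / sqrt(variance)
        scale_diag=tf.math.rsqrt(variance_flattened)),
    bijector=tfb.Reshape(event_shape_out=[2, 3]))
# momentum.event_shape == [2, 3]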
Example #20
def _make_momentum_distribution(running_variance_parts, state_parts,
                                batch_ndims):
    """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    running_variance_parts: List of `Tensor`, outputs of
      `tfp.experimental.stats.RunningVariance.variance()`.
    state_parts: List of `Tensor`.
    batch_ndims: Scalar, for leading batch dimensions.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
    and `.log_prob` of the sample will have rank equal to `batch_ndims`.
  """
    distributions = []
    for variance_part, state_part in zip(running_variance_parts, state_parts):
        running_variance_rank = ps.rank(variance_part)
        state_rank = ps.rank(state_part)
        event_shape = ps.shape(state_part)[batch_ndims:]
        nevt = ps.reduce_prod(event_shape)
        # Pad dimensions and tile by multiplying by tf.ones to add a batch shape
        ones = tf.ones(
            ps.shape(state_part)[:-(state_rank - running_variance_rank)],
            dtype=variance_part.dtype)
        ones = bu.left_justified_expand_dims_like(ones, state_part)
        variance_tiled = ones * variance_part
        variance_flattened = tf.reshape(
            variance_tiled,
            ps.concat([ps.shape(variance_tiled)[:batch_ndims], [nevt]],
                      axis=0))

        distributions.append(
            _CompositeTransformedDistribution(
                bijector=_CompositeReshape(event_shape_out=event_shape,
                                           event_shape_in=[nevt]),
                distribution=(
                    _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                        precision_factor=_CompositeLinearOperatorDiag(
                            tf.math.sqrt(variance_flattened)),
                        precision=_CompositeLinearOperatorDiag(
                            variance_flattened)))))
    return _CompositeJointDistributionSequential(distributions)
Example #21
def _compute_fans_from_shape(shape, batch_ndims=0):
  """Extracts `fan_in, fan_out` from specified shape `Tensor`."""
  # Ensure shape is a vector of length >=2.
  num_pad = prefer_static.maximum(0, 2 - prefer_static.size(shape))
  shape = prefer_static.pad(
      shape, paddings=[[0, num_pad]], constant_values=1)
  (
      batch_shape,  # pylint: disable=unused-variable
      extra_shape,
      fan_in,
      fan_out,
  ) = prefer_static.split(shape, [batch_ndims, -1, 1, 1])
  # The following logic is primarily intended for convolutional layers which
  # have spatial semantics in addition to input/output channels.
  receptive_field_size = prefer_static.reduce_prod(extra_shape)
  fan_in = fan_in[0] * receptive_field_size
  fan_out = fan_out[0] * receptive_field_size
  return fan_in, fan_out
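
A worked example for a hypothetical conv kernel of shape [3, 3, 64, 128] with batch_ndims=0:

import numpy as np

shape = np.array([3, 3, 64, 128])
receptive_field_size = np.prod(shape[:-2])          # 3 * 3 == 9
fan_in = shape[-2] * receptive_field_size           # 64 * 9 == 576
fan_out = shape[-1] * receptive_field_size          # 128 * 9 == 1152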
Example #22
def _check_at_least_two_chains(accept_prob, reduce_chain_axis_names,
                               validate_args, message):
    """Checks that the number of chains is at least 2."""
    # Number of total chains is local batch size * distributed axis size
    local_axis_size = ps.size(accept_prob)
    distributed_axis_size = int(
        ps.reduce_prod([
            distribute_lib.get_axis_size(a) for a in reduce_chain_axis_names
        ]))
    num_chains = local_axis_size * distributed_axis_size
    num_chains_ = tf.get_static_value(num_chains)
    if num_chains_ is not None:
        if num_chains_ < 2:
            raise ValueError('{} Got: {}'.format(message, num_chains_))
    elif validate_args:
        with tf.control_dependencies(
            [assert_util.assert_greater_equal(num_chains, 2, message)]):
            accept_prob = tf.identity(accept_prob)
    return accept_prob
Example #23
    def preprocess_state(init_state):
      """Initial preprocessing at Stage 0."""
      dimension = ps.reduce_sum([
          ps.reduce_prod(ps.shape(x)[1:]) for x in init_state])
      likelihood_log_prob = likelihood_log_prob_fn(*init_state)

      # Default to the optimal for normal distributed targets.
      # TODO(b/152412213): Revisit this default parameter.
      scale_start = (
          tf.constant(2.38 ** 2, dtype=likelihood_log_prob.dtype) /
          tf.constant(dimension, dtype=likelihood_log_prob.dtype))
      # TODO(b/152412213): Enable batch of batches style by using non-scalar
      # inverse_temperature
      inverse_temperature = tf.zeros([], dtype=likelihood_log_prob.dtype)
      scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(scale_start, 1.)
      kernel = make_kernel_fn(
          _make_tempered_target_log_prob_fn(
              prior_log_prob_fn,
              likelihood_log_prob_fn,
              inverse_temperature),
          init_state,
          scalings,
          seed=seed_stream())
      pkr = kernel.bootstrap_results(current_state)
      _, kernel_target_log_prob = gather_mh_like_result(pkr)

      particle_info = ParticleInfo(
          log_accept_prob=ps.zeros_like(likelihood_log_prob),
          log_scalings=tf.math.log(scalings),
          tempered_log_prob=kernel_target_log_prob,
          likelihood_log_prob=likelihood_log_prob,
      )

      return SMCResults(
          num_steps=tf.convert_to_tensor(
              max_num_steps, dtype=tf.int32, name='num_steps'),
          inverse_temperature=inverse_temperature,
          log_marginal_likelihood=tf.constant(
              0., dtype=likelihood_log_prob.dtype),
          particle_info=particle_info
      )
Example #24
    def sample(self, sample_shape=(), seed=None, name=None):
        with tf.name_scope(name or 'sample'):
            # Grab the required number of values from the provided tensors.
            sample_shape = dist_util.expand_to_vector(sample_shape)
            n = ps.cast(ps.reduce_prod(sample_shape), dtype=tf.int32)

            # Check that we're not trying to draw too many samples.
            assertions = []
            will_overflow_ = tf.get_static_value(n > self.max_num_samples)
            if will_overflow_:
                raise ValueError(
                    'Trying to draw {} samples from a '
                    '`DeterministicEmpirical` instance for which only {} '
                    'samples were provided.'.format(
                        tf.get_static_value(n),
                        tf.get_static_value(self.max_num_samples)))
            elif (will_overflow_ is None  # Couldn't determine statically.
                  and self.validate_args):
                assertions.append(
                    tf.debugging.assert_less_equal(
                        n,
                        self.max_num_samples,
                        message='Number of samples to draw '
                        'from a `DeterministicEmpirical` instance must not exceed the '
                        'number provided at construction.'))

            # Extract the appropriate number of sampled values.
            with tf.control_dependencies(assertions):
                sampled = tf.nest.map_structure(lambda x: x[:n, ...],
                                                self.values_with_sample_dim)

            # Reshape the values to the appropriate sample shape.
            return tf.nest.map_structure(
                lambda x: tf.reshape(
                    x,  # pylint: disable=g-long-lambda
                    ps.concat([
                        ps.cast(sample_shape, tf.int32),
                        ps.cast(ps.shape(x)[1:], tf.int32)
                    ],
                              axis=0)),
                sampled)
Example #25
def _kl_sample(a, b, name='kl_sample'):
    """Batched KL divergence `KL(a || b)` for Sample distributions.

  We can leverage the fact that:

  ```
  KL(Sample(a) || Sample(b)) = sum(KL(a || b))
  ```

  where the sum is over the `sample_shape` dims.

  Args:
    a: Instance of `Sample` distribution.
    b: Instance of `Sample` distribution.
    name: (optional) name to use for created ops.
      Default value: `"kl_sample"`.

  Returns:
    kldiv: Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the `sample_shape` of `a` and `b` don't match.
  """
    assertions = []
    a_ss = tf.get_static_value(a.sample_shape)
    b_ss = tf.get_static_value(b.sample_shape)
    msg = '`a.sample_shape` must be identical to `b.sample_shape`.'
    if a_ss is not None and b_ss is not None:
        if not np.array_equal(a_ss, b_ss):
            raise ValueError(msg)
    elif a.validate_args or b.validate_args:
        assertions.append(
            assert_util.assert_equal(a.sample_shape,
                                     b.sample_shape,
                                     message=msg))
    with tf.control_dependencies(assertions):
        kl = kullback_leibler.kl_divergence(a.distribution,
                                            b.distribution,
                                            name=name)
        n = ps.reduce_prod(a.sample_shape)
        return tf.cast(x=n, dtype=kl.dtype) * kl
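
A quick numeric check of the identity, using public `tfd` APIs (the Normal parameters are arbitrary):

import tensorflow_probability as tfp

tfd = tfp.distributions

a, b = tfd.Normal(loc=0., scale=1.), tfd.Normal(loc=1., scale=2.)
kl_base = tfd.kl_divergence(a, b)
kl_sample = tfd.kl_divergence(tfd.Sample(a, sample_shape=[3]),
                              tfd.Sample(b, sample_shape=[3]))
# kl_sample == 3 * kl_base (up to float round-off).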
Example #26
def _get_flat_unconstraining_bijector(jd_model):
    """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

  U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then
  have support on R^m. See [1] for details.)

  Args:
    jd_model: subclass of `tfd.JointDistribution` A JointDistribution for a
      model.

  Returns:
    A `tfb.Bijector` where the `.forward` method flattens and unconstrains
    points.
  """
    # TODO(b/180396233): This bijector is in general point-dependent.
    to_chain = [jd_model.experimental_default_event_space_bijector()]
    flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())
    to_chain.append(flat_bijector)

    unconstrained_shapes = flat_bijector.inverse_event_shape_tensor(
        jd_model.event_shape_tensor())

    # This reshaping is required because split can produce a tensor of shape [1]
    # when the distribution event shape is [].
    reshapers = [
        reshape.Reshape(event_shape_out=x, event_shape_in=[-1])
        for x in unconstrained_shapes
    ]
    to_chain.append(joint_map.JointMap(bijectors=reshapers))

    size_splits = [ps.reduce_prod(x) for x in unconstrained_shapes]
    to_chain.append(split.Split(num_or_size_splits=size_splits))

    return invert.Invert(chain.Chain(to_chain))
Example #27
    def _log_prob(self, x):
        assertions = []
        message = 'Input must have at least one dimension.'
        if tensorshape_util.rank(x.shape) is not None:
            if tensorshape_util.rank(x.shape) == 0:
                raise ValueError(message)
        elif self.validate_args:
            assertions.append(
                assert_util.assert_rank_at_least(x, 1, message=message))
        with tf.control_dependencies(assertions):
            event_tensors = self._distribution.event_shape_tensor()
            splits = [
                ps.maximum(1, ps.reduce_prod(s))
                for s in tf.nest.flatten(event_tensors)
            ]
            x = tf.nest.pack_sequence_as(event_tensors,
                                         tf.split(x, splits, axis=-1))

            def _reshape_part(part, dtype, event_shape):
                part = tf.cast(part, dtype)
                static_rank = tf.get_static_value(
                    ps.rank_from_shape(event_shape))
                if static_rank == 1:
                    return part
                new_shape = ps.concat([ps.shape(part)[:-1], event_shape],
                                      axis=-1)
                return tf.reshape(part, ps.cast(new_shape, tf.int32))

            if all(
                    tensorshape_util.is_fully_defined(s)
                    for s in tf.nest.flatten(self._distribution.event_shape)):
                x = tf.nest.map_structure(_reshape_part, x,
                                          self._distribution.dtype,
                                          self._distribution.event_shape)
            else:
                x = tf.nest.map_structure(
                    _reshape_part, x, self._distribution.dtype,
                    self._distribution.event_shape_tensor())

            return self._distribution.log_prob(x)
Example #28
def _make_flatten_unflatten_fns(batch_shape):
    """Builds functions for flattening and unflattening batch dimensions."""
    batch_shape = tuple(batch_shape)
    batch_rank = len(batch_shape)
    ndims = ps.cast(ps.reduce_prod(batch_shape), tf.int32)

    def flatten_fn(x):
        x_shape = tuple(x.shape)
        if x_shape[:batch_rank] != batch_shape:
            raise ValueError(
                'Expected batch-shape=%s; received array of shape=%s' %
                (batch_shape, x_shape))
        flat_shape = (ndims, ) + x_shape[batch_rank:]
        return tf.reshape(x, flat_shape)

    def unflatten_fn(x):
        x_shape = tuple(x.shape)
        if x_shape[0] != ndims:
            raise ValueError('Expected batch-size=%d; received shape=%s' %
                             (ndims, x_shape))
        return tf.reshape(x, batch_shape + x_shape[1:])

    return flatten_fn, unflatten_fn
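
Hypothetical usage, assuming the helper above and its `ps` (prefer_static) import are available:

import tensorflow as tf

flatten_fn, unflatten_fn = _make_flatten_unflatten_fns(batch_shape=[2, 3])
x = tf.zeros([2, 3, 5])
flat = flatten_fn(x)             # shape [6, 5]
restored = unflatten_fn(flat)    # shape [2, 3, 5]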
Example #29
def _make_vector_event_space_bijector(model):
  """Creates a vector bijector that constrains like the structured model."""

  # TODO(siege): Make this work with multi-part default_event_bijector.
  def _make_reshaped_bijector(b, s):
    return tfb.Chain([
        tfb.Reshape(event_shape_in=s, event_shape_out=[ps.reduce_prod(s)]),
        b,
        tfb.Reshape(
            event_shape_in=[ps.reduce_prod(b.inverse_event_shape(s))],
            event_shape_out=b.inverse_event_shape(s)),
    ])

  reshaped_bijector = tf.nest.map_structure(_make_reshaped_bijector,
                                            model.default_event_space_bijector,
                                            model.event_shape)

  return tfb.Blockwise(
      bijectors=tf.nest.flatten(reshaped_bijector),
      block_sizes=tf.nest.flatten(
          tf.nest.map_structure(
              lambda b, s: ps.reduce_prod(b.inverse_event_shape(s)),  # pylint: disable=g-long-lambda
              model.default_event_space_bijector,
              model.event_shape)))
Example #30
    def _log_prob(self, x):
        if self.input_output_cholesky:
            x_sqrt = x
        else:
            # Complexity: O(nbk**3)
            x_sqrt = tf.linalg.cholesky(x)

        df = tf.convert_to_tensor(self.df)
        batch_shape = self._batch_shape_tensor(df)
        event_shape = self._event_shape_tensor()
        dimension = self._dimension()
        x_ndims = ps.rank(x_sqrt)
        num_singleton_axes_to_prepend = (
            ps.maximum(ps.size(batch_shape) + 2, x_ndims) - x_ndims)
        x_with_prepended_singletons_shape = ps.concat([
            ps.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
            ps.shape(x_sqrt)
        ], 0)
        x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
        ndims = ps.rank(x_sqrt)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - ps.size(batch_shape) - 2
        sample_shape = ps.shape(x_sqrt)[:sample_ndims]

        # We need to be able to pre-multiply each matrix by its corresponding
        # batch scale matrix. Since a Distribution Tensor supports multiple
        # samples per batch, this means we need to reshape the input matrix `x`
        # so that the first b dimensions are batch dimensions and the last two
        # are of shape [dimension, dimensions*number_of_samples]. Doing these
        # gymnastics allows us to do a batch_solve.
        #
        # After we're done with sqrt_solve (the batch operation) we need to undo
        # this reshaping so what we're left with is a Tensor partitionable by
        # sample, batch, event dimensions.

        # Complexity: O(nbk**2) since transpose must access every element.
        scale_sqrt_inv_x_sqrt = x_sqrt
        perm = ps.concat(
            [ps.range(sample_ndims, ndims),
             ps.range(0, sample_ndims)], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)
        last_dim_size = (
            ps.cast(dimension, dtype=tf.int32) *
            ps.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
        shape = ps.concat([
            x_with_prepended_singletons_shape[sample_ndims:-2],
            [ps.cast(dimension, dtype=tf.int32), last_dim_size]
        ],
                          axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
        # this step has complexity O(nbk^3).
        scale_sqrt_inv_x_sqrt = self._scale.solve(scale_sqrt_inv_x_sqrt)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = ps.concat(
            [ps.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
            axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
        perm = ps.concat([
            ps.range(ndims - sample_ndims, ndims),
            ps.range(0, ndims - sample_ndims)
        ], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}**2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk**2)
        trace_scale_inv_x = tf.reduce_sum(tf.square(scale_sqrt_inv_x_sqrt),
                                          axis=[-2, -1])

        # Complexity: O(nbk)
        half_log_det_x = tf.reduce_sum(tf.math.log(
            tf.linalg.diag_part(x_sqrt)),
                                       axis=[-1])

        # Complexity: O(nbk**2)
        log_prob = ((df - dimension - 1.) * half_log_det_x -
                    0.5 * trace_scale_inv_x -
                    self._log_normalization(df=df, scale=self._scale))

        # Set shape hints.
        # Try to merge what we know from the input x with what we know from the
        # parameters of this distribution.
        if tensorshape_util.rank(
                x.shape) is not None and tensorshape_util.rank(
                    self.batch_shape) is not None:
            tensorshape_util.set_shape(
                log_prob,
                tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

        return log_prob