def slice_batch_shape_tensor(base_shape, event_ndims):
  """Returns `base_shape` with up to `event_ndims` trailing dims sliced off."""
  base_shape = ps.convert_to_shape_tensor(base_shape, dtype_hint=np.int32)
  event_ndims = ps.convert_to_shape_tensor(event_ndims, dtype_hint=np.int32)
  base_rank = ps.rank_from_shape(base_shape)
  return base_shape[:(base_rank -
                      # Don't try to slice away more ndims than the parameter
                      # actually has, if that's fewer than `event_ndims` (i.e.,
                      # if it relies on broadcasting).
                      ps.minimum(event_ndims, base_rank))]
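A quick usage sketch of the helper above (a sketch, assuming the `np` and `ps` names it relies on are numpy and TFP's internal `prefer_static` module):

import numpy as np
from tensorflow_probability.python.internal import prefer_static as ps

# A parameter of shape [4, 3, 2] whose trailing 2 dims are event dims:
slice_batch_shape_tensor([4, 3, 2], event_ndims=2)  # ==> [4]
# A parameter that relies on broadcasting (fewer dims than `event_ndims`)
# yields an empty batch shape rather than an error:
slice_batch_shape_tensor([2], event_ndims=3)  # ==> []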
Example #2
def _truncate_shape_tensor(shape, ndims_to_truncate):
    shape = ps.convert_to_shape_tensor(shape, dtype_hint=np.int32)
    ndims_to_truncate = ps.convert_to_shape_tensor(ndims_to_truncate,
                                                   dtype_hint=np.int32)
    base_rank = ps.rank_from_shape(shape)
    return shape[:(
        base_rank -
        # Don't try to slice away more ndims than the parameter
        # actually has, if that's fewer than `ndims_to_truncate` (i.e.,
        # if it relies on broadcasting).
        ps.minimum(ndims_to_truncate, base_rank))]
Example #3
def rademacher(shape, dtype=tf.float32, seed=None, name=None):
    """Generates `Tensor` consisting of `-1` or `+1`, chosen uniformly at random.

  For more details, see [Rademacher distribution](
  https://en.wikipedia.org/wiki/Rademacher_distribution).

  Args:
    shape: Vector-shaped, `int` `Tensor` representing shape of output.
    dtype: (Optional) TF `dtype` representing `dtype` of output.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'random_rademacher').

  Returns:
    rademacher: `Tensor` with specified `shape` and `dtype` consisting of `-1`
      or `+1` chosen uniformly-at-random.
  """
    with tf.name_scope(name or 'rademacher'):
        # Choose the dtype to cause `2 * random_bernoulli - 1` to run in the same
        # memory (host or device) as the downstream cast will want to put it.  The
        # convention on GPU is that int32 are in host memory and int64 are in device
        # memory.
        shape = ps.convert_to_shape_tensor(shape)
        generation_dtype = tf.int64 if tf.as_dtype(
            dtype) != tf.int32 else tf.int32
        random_bernoulli = samplers.uniform(shape,
                                            minval=0,
                                            maxval=2,
                                            dtype=generation_dtype,
                                            seed=seed)
        return tf.cast(2 * random_bernoulli - 1, dtype)
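A minimal usage sketch, assuming this function is the one exposed publicly as `tfp.random.rademacher` (stateless seeds are given as an `(int, int)` pair):

import tensorflow as tf
import tensorflow_probability as tfp

samples = tfp.random.rademacher([3, 4], dtype=tf.float32, seed=(0, 42))
# Each entry is -1. or +1., each with probability 1/2.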
Example #4
def _random_binomial(
    shape,
    counts,
    probs,
    output_dtype=tf.float32,
    seed=None,
    name=None):
  """Sample a binomial, CPU specialized to stateless_binomial.

  Args:
    shape: Shape of the full sample output. Trailing dims should match the
      broadcast shape of `counts` with `probs`.
    counts: Batch of total_count.
    probs: Batch of p(success).
    output_dtype: DType of samples.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Optional name for related ops.

  Returns:
    samples: Samples from binomial distributions.
    runtime_used_for_sampling: One of `implementation_selection._RUNTIME_*`.
  """
  with tf.name_scope(name or 'random_binomial'):
    seed = samplers.sanitize_seed(seed)
    shape = ps.convert_to_shape_tensor(shape, dtype_hint=tf.int32, name='shape')
    params = dict(shape=shape, counts=counts, probs=probs,
                  output_dtype=output_dtype, seed=seed, name=name)
    sampler_impl = implementation_selection.implementation_selecting(
        fn_name='binomial',
        default_fn=_random_binomial_noncpu,
        cpu_fn=_random_binomial_cpu)
    return sampler_impl(**params)
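The sampler above is internal. A hedged sketch of the public route, assuming `tfd.Binomial` dispatches to a CPU-specialized implementation such as this one where available:

import tensorflow_probability as tfp
tfd = tfp.distributions

samples = tfd.Binomial(total_count=10., probs=0.3).sample(5, seed=(0, 7))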
Example #5
    def _call_sample_n(self, sample_shape, seed, name, **kwargs):
        # We override `_call_sample_n` rather than `_sample_n` so we can ensure that
        # the result of `self.bijector.forward` is not modified (and thus caching
        # works).
        with self._name_and_control_scope(name):
            sample_shape = ps.convert_to_shape_tensor(sample_shape,
                                                      dtype=tf.int32,
                                                      name='sample_shape')
            sample_shape, n = self._expand_sample_shape_to_vector(
                sample_shape, 'sample_shape')

            distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(
                kwargs)

            # First, generate samples from the base distribution.
            x = self.distribution.sample(sample_shape=[n],
                                         seed=seed,
                                         **distribution_kwargs)

            # Next, we reshape `x` into its final form. We do this prior to the call
            # to the bijector to ensure that the bijector caching works.
            def reshape_sample_shape(t):
                batch_event_shape = ps.shape(t)[1:]
                final_shape = ps.concat([sample_shape, batch_event_shape], 0)
                return tf.reshape(t, final_shape)

            x = tf.nest.map_structure(reshape_sample_shape, x)

            # Finally, we apply the bijector's forward transformation. For caching to
            # work, it is imperative that this is the last modification to the
            # returned result.
            y = self.bijector.forward(x, **bijector_kwargs)
            y = self._set_sample_static_shape(y, sample_shape)

            return y
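The caching contract above is what makes the common sample-then-score pattern cheap. A small sketch with the public `TransformedDistribution` (illustrative values and seeds):

import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

dist = tfd.TransformedDistribution(tfd.Normal(0., 1.), tfb.Exp())
y = dist.sample([2, 3], seed=(0, 1))
# Because `y` is the unmodified output of `bijector.forward`, the bijector's
# cache can recover the pre-transform `x`, so `inverse` need not be re-run:
lp = dist.log_prob(y)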
Example #6
def random_von_mises(shape, concentration, dtype=tf.float32, seed=None):
  """Samples from the standardized von Mises distribution.

  The distribution is vonMises(loc=0, concentration=concentration), so the mean
  is zero.
  The location can then be changed by adding it to the samples.

  The sampling algorithm is rejection sampling with wrapped Cauchy proposal [1].
  The samples are pathwise differentiable using the approach of [2].

  Args:
    shape: The output sample shape.
    concentration: The concentration parameter of the von Mises distribution.
    dtype: The data type of concentration and the outputs.
    seed: (optional) The random seed.

  Returns:
    Differentiable samples of standardized von Mises.

  References:
    [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag,
    1986; Chapter 9, p. 473-476.
    http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
    + corrections http://www.nrbook.com/devroye/Devroye_files/errors.pdf
    [2] Michael Figurnov, Shakir Mohamed, Andriy Mnih. "Implicit
    Reparameterization Gradients", 2018.
  """
  shape = ps.convert_to_shape_tensor(shape, dtype_hint=tf.int32, name='shape')
  seed = samplers.sanitize_seed(seed, salt='von_mises')
  concentration = tf.convert_to_tensor(
      concentration, dtype=dtype, name='concentration')

  return _von_mises_sample_with_gradient(shape, concentration, seed)
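A short sketch of the advertised pathwise differentiability, assuming the public `tfd.VonMises` samples via this routine:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

concentration = tf.constant(2.)
with tf.GradientTape() as tape:
  tape.watch(concentration)
  s = tfd.VonMises(loc=0., concentration=concentration).sample(
      100, seed=(1, 2))
# Nonzero, thanks to implicit reparameterization gradients [2]:
grad = tape.gradient(tf.reduce_mean(s), concentration)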
Example #7
def random_gamma_with_runtime(shape,
                              concentration,
                              rate=None,
                              log_rate=None,
                              seed=None,
                              log_space=False):
    """Returns both a sample and the id of the implementation-selected runtime."""
    # This method exists chiefly for testing purposes.
    dtype = dtype_util.common_dtype([concentration, rate, log_rate],
                                    tf.float32)
    concentration = tf.convert_to_tensor(concentration, dtype=dtype)
    shape = ps.convert_to_shape_tensor(shape,
                                       dtype_hint=tf.int32,
                                       name='shape')

    if rate is not None and log_rate is not None:
        raise ValueError(
            'At most one of `rate` and `log_rate` may be specified.')
    if rate is not None:
        rate = tf.convert_to_tensor(rate, dtype=dtype)
    if log_rate is not None:
        log_rate = tf.convert_to_tensor(log_rate, dtype=dtype)
    total_shape = ps.concat(
        [shape,
         ps.broadcast_shape(ps.shape(concentration),
                            _shape_or_scalar(rate, log_rate))],
        axis=0)
    seed = samplers.sanitize_seed(seed, salt='random_gamma')
    return _random_gamma_gradient(total_shape, concentration, rate, log_rate,
                                  seed, log_space)
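A minimal sketch of the same `rate`/`log_rate` exclusivity at the public level, assuming `tfd.Gamma` accepts the same mutually exclusive pair:

import tensorflow_probability as tfp
tfd = tfp.distributions

g = tfd.Gamma(concentration=2., log_rate=0.)  # rate = exp(0.) = 1.
samples = g.sample([3], seed=(0, 3))
# Passing both `rate` and `log_rate` raises a ValueError, as above.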
Example #8
def _might_have_excess_ndims(flat_value, flat_core_ndims):
    for v, nd in zip(flat_value, flat_core_ndims):
        static_excess_ndims = (0 if v is None else tf.get_static_value(
            ps.convert_to_shape_tensor(ps.rank(v) - nd)))
        if static_excess_ndims is None or static_excess_ndims > 0:
            return True
    return False
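A worked illustration (hypothetical tensors): `True` means at least one entry has, or might have, dims beyond its core dims.

import tensorflow as tf

_might_have_excess_ndims([tf.zeros([5, 2, 3])], [2])  # ==> True (1 extra dim)
_might_have_excess_ndims([tf.zeros([2, 3])], [2])     # ==> False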
Example #9
def _random_poisson(
    shape,
    rates=None,
    log_rates=None,
    output_dtype=tf.float32,
    seed=None,
    name=None):
  """Sample a poisson, CPU specialized to stateless_poisson.

  Args:
    shape: Shape of the full sample output. Trailing dims should match the
      broadcast shape of `rates` or `log_rates`.
    rates: Batch of rates for Poisson distribution.
    log_rates: Batch of log rates for Poisson distribution.
    output_dtype: DType of samples.
    seed: int or Tensor seed.
    name: Optional name for related ops.

  Returns:
    samples: Samples from poisson distributions.
    runtime_used_for_sampling: One of `implementation_selection._RUNTIME_*`.
  """
  with tf.name_scope(name or 'random_poisson'):
    seed = samplers.sanitize_seed(seed)
    shape = ps.convert_to_shape_tensor(shape, dtype_hint=tf.int32, name='shape')
    params = dict(shape=shape, rates=rates, log_rates=log_rates,
                  output_dtype=output_dtype, seed=seed, name=name)
    sampler_impl = implementation_selection.implementation_selecting(
        fn_name='poisson',
        default_fn=_random_poisson_noncpu,
        cpu_fn=_random_poisson_cpu)
    return sampler_impl(**params)
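As with the binomial above, a hedged sketch of the public route, assuming `tfd.Poisson` (which likewise accepts either `rate` or `log_rate`) dispatches to this sampler on CPU:

import tensorflow_probability as tfp
tfd = tfp.distributions

samples = tfd.Poisson(log_rate=1.5).sample(4, seed=(2, 3))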
Example #10
    def _sample_n(self, n, seed=None):
        loc = tf.convert_to_tensor(self.loc)
        scale = tf.convert_to_tensor(self.scale)
        tailweight = tf.convert_to_tensor(self.tailweight)
        skewness = tf.convert_to_tensor(self.skewness)
        ig_seed, normal_seed = samplers.split_seed(
            seed, salt='normal_inverse_gaussian')
        batch_shape = self._batch_shape_tensor(loc=loc,
                                               scale=scale,
                                               tailweight=tailweight,
                                               skewness=skewness)
        w = tailweight * tf.math.exp(
            0.5 * tf.math.log1p(-tf.math.square(skewness / tailweight)))
        w = tf.broadcast_to(w, batch_shape)
        ig_samples = inverse_gaussian.InverseGaussian(
            scale / w, tf.math.square(scale)).sample(n, seed=ig_seed)

        sample_shape = ps.concat([[n], batch_shape], axis=0)
        normal_samples = samplers.normal(
            shape=ps.convert_to_shape_tensor(sample_shape),
            mean=0.,
            stddev=1.,
            dtype=self.dtype,
            seed=normal_seed)
        return (loc + tf.math.sqrt(ig_samples) *
                (skewness * tf.math.sqrt(ig_samples) + normal_samples))
Example #11
def _validate_block_sizes(block_sizes, bijectors, validate_args):
  """Helper to validate block sizes."""
  block_sizes = ps.convert_to_shape_tensor(
      block_sizes, name='block_sizes', dtype_hint=tf.int32)
  block_sizes_shape = block_sizes.shape
  if tensorshape_util.is_fully_defined(block_sizes_shape):
    if (tensorshape_util.rank(block_sizes_shape) != 1 or
        (tensorshape_util.num_elements(block_sizes_shape) != len(bijectors))):
      raise ValueError(
          '`block_sizes` must be `None`, or a vector of the same length as '
          '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
          'length {}'.format(block_sizes_shape, len(bijectors)))
    return block_sizes

  elif validate_args:
    message = ('`block_sizes` must be `None`, or a vector of the same length '
               'as `bijectors`.')
    with tf.control_dependencies([
        assert_util.assert_equal(
            tf.size(block_sizes), len(bijectors), message=message),
        assert_util.assert_equal(tf.rank(block_sizes), 1)
    ]):
      block_sizes = tf.identity(block_sizes)

  # Set the shape if missing to pass statically known structure to split.
  tensorshape_util.set_shape(block_sizes, [len(bijectors)])
  return block_sizes
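A usage sketch, assuming this is the check behind the `block_sizes` argument of the public `tfb.Blockwise` bijector:

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Blockwise([tfb.Exp(), tfb.Softplus()], block_sizes=[2, 3])
y = bij.forward(tf.zeros([5]))  # first 2 entries via Exp, last 3 via Softplus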
Example #12
    def _sample_n(self, n, seed=None):
        seed = samplers.sanitize_seed(seed, salt='gamma')

        return random_gamma(shape=ps.convert_to_shape_tensor([n]),
                            concentration=tf.convert_to_tensor(
                                self.concentration, self.dtype),
                            rate=tf.convert_to_tensor(self.rate, self.dtype),
                            seed=seed)
Example #13
def _squeeze(x, axis):
    """A version of squeeze that works with dynamic axis."""
    x = tf.convert_to_tensor(x, name='x')
    if axis is None:
        return tf.squeeze(x, axis=None)
    axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
    axis = axis + ps.zeros([1], dtype=axis.dtype)  # Make axis at least 1d.
    keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
    return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
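A quick illustration of why the helper exists: plain `tf.squeeze` needs a statically known `axis`, whereas here `axis` may be a runtime Tensor.

import tensorflow as tf

x = tf.ones([2, 1, 3, 1])
_squeeze(x, axis=[1, 3]).shape  # ==> [2, 3]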
Example #14
def _squeeze(x, axis):
    """A version of squeeze that works with dynamic axis."""
    x = tf.convert_to_tensor(x, name='x')
    if axis is None:
        return tf.squeeze(x, axis=None)
    axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
    axis = _make_list_or_1d_tensor(axis)  # Ensure at least 1d.
    keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
    return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
Example #15
  def _sample_n(self, n, seed=None):
    seed = samplers.sanitize_seed(seed)
    return random_poisson(shape=ps.convert_to_shape_tensor([n]),
                          rates=(None if self._rate is None else
                                 tf.convert_to_tensor(self._rate)),
                          log_rates=(None if self._log_rate is None else
                                     tf.convert_to_tensor(self._log_rate)),
                          output_dtype=self.dtype,
                          seed=seed)[0]
Example #16
  def __init__(
      self, num_or_size_splits, axis=-1, validate_args=False, name='split'):
    """Creates the bijector.

    Args:
      num_or_size_splits: Either a Python integer indicating the number of
        splits along `axis` or a 1-D integer `Tensor` or Python list containing
        the sizes of each output tensor along `axis`. If a list/`Tensor`, it may
        contain at most one value of `-1`, which indicates a split size that is
        unknown and determined from input.
      axis: A negative integer or scalar `int32` `Tensor`. The dimension along
        which to split. Must be negative to enable the bijector to support
        arbitrary batch dimensions. Defaults to -1 (note that this is different
        from the `tf.split` default of `0`). Must be statically known.
      validate_args: Python `bool` indicating whether arguments should
        be checked for correctness.
      name: Python `str`, name given to ops managed by this object.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:

      if isinstance(num_or_size_splits, numbers.Integral):
        self._num_splits = num_or_size_splits
        self._split_sizes = None
      else:
        self._split_sizes = tensor_util.convert_nonref_to_tensor(
            num_or_size_splits,
            name='num_or_size_splits',
            dtype=tf.int32,
            as_shape_tensor=True)

        if tensorshape_util.rank(self._split_sizes.shape) != 1:
          raise ValueError(
              '`num_or_size_splits` must be an integer or 1-D `Tensor`.')

        num_splits = tensorshape_util.as_list(self._split_sizes.shape)[0]
        if num_splits is None:
          raise ValueError('If `num_or_size_splits` is a vector of split sizes '
                           'it must have a statically-known number of '
                           'elements.')
        self._num_splits = num_splits

      static_axis = tf.get_static_value(axis)
      if static_axis is None:
        raise ValueError('`axis` must be statically known.')
      if static_axis >= 0:
        raise ValueError('`axis` must be negative. Got {}'.format(axis))

      self._axis = ps.convert_to_shape_tensor(axis, tf.int32)

      super(Split, self).__init__(
          forward_min_event_ndims=-axis,
          inverse_min_event_ndims=[-axis] * self.num_splits,
          is_constant_jacobian=True,
          validate_args=validate_args,
          parameters=parameters,
          name=name)
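A usage sketch of the public `tfb.Split` bijector configured as documented above (the `-1` entry is inferred from the input):

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

split = tfb.Split(num_or_size_splits=[2, -1, 3], axis=-1)
parts = split.forward(tf.zeros([4, 10]))
# ==> three Tensors with shapes [4, 2], [4, 5], and [4, 3].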
Example #17
  def _sample_n(self, n, seed=None):
    seed = samplers.sanitize_seed(seed, salt='gamma')

    return random_gamma(
        shape=ps.convert_to_shape_tensor([n]),
        concentration=tf.convert_to_tensor(self.concentration),
        rate=None if self.rate is None else tf.convert_to_tensor(self.rate),
        log_rate=(None if self.log_rate is None else
                  tf.convert_to_tensor(self.log_rate)),
        seed=seed)
Example #18
  def _sample_n(self, n, seed=None):
    seed = samplers.sanitize_seed(seed, salt='binomial')
    return _random_binomial(shape=ps.convert_to_shape_tensor([n]),
                            counts=tf.convert_to_tensor(self._total_count),
                            probs=(None if self._probs is None else
                                   tf.convert_to_tensor(self._probs)),
                            logits=(None if self._logits is None else
                                    tf.convert_to_tensor(self._logits)),
                            output_dtype=self.dtype,
                            seed=seed)[0]
Example #19
def normal_generator(shape):
  shape = ps.convert_to_shape_tensor(shape, dtype=np.int32)
  loc = yield trainable_state_util.Parameter(
      init_fn=functools.partial(samplers.normal, shape=shape),
      name='loc')
  bij = tfb.Softplus()
  scale = yield trainable_state_util.Parameter(
      init_fn=lambda seed: bij.forward(samplers.normal(shape, seed=seed)),
      constraining_bijector=bij,
      name='scale')
  return tfd.Normal(loc=loc, scale=scale, validate_args=True)
Example #20
def random_gamma(shape, concentration, rate, seed=None):
    shape = ps.convert_to_shape_tensor(shape,
                                       dtype_hint=tf.int32,
                                       name='shape')

    total_shape = ps.concat(
        [shape,
         ps.broadcast_shape(ps.shape(concentration), ps.shape(rate))],
        axis=0)
    seed = samplers.sanitize_seed(seed, salt='random_gamma')
    return _random_gamma_gradient(total_shape, concentration, rate, seed)
Example #21
  def _dimension(self):
    """Scalar dimension of underlying vector space."""
    with tf.name_scope('dimension'):
      if tf.compat.dimension_value(self._scale.shape[-1]) is None:
        return tf.cast(self._scale.domain_dimension_tensor(),
                       dtype=self._scale.dtype,
                       name='dimension')
      else:
        return ps.convert_to_shape_tensor(
            tf.compat.dimension_value(self._scale.shape[-1]),
            dtype=self._scale.dtype,
            name='dimension')
Example #22
  def sample(self, sample_shape=(), seed=None, name='sample'):  # pylint: disable=unused-argument
    return tf.zeros(
        ps.concat(
            [
                # sample_shape might be a scalar
                ps.reshape(ps.convert_to_shape_tensor(sample_shape, tf.int32),
                           shape=[-1]),
                self.batch_shape_tensor(),
                self.event_shape_tensor()
            ],
            axis=0))
Example #23
  def _expand_x_fn(tensor):
    # Reshape `tensor` to tensor.shape + [1] * M.
    extended_shape = ps.concat(
        [
            ps.shape(tensor),
            ps.ones_like(
                ps.convert_to_shape_tensor(
                    ps.shape_slice(y_ref, np.s_[batch_dims + nd:])))
        ],
        axis=0)
    return tf.reshape(tensor, extended_shape)
Example #24
    def _sample_n(self, n, seed=None):
        seed = samplers.sanitize_seed(seed, salt='inverse_gaussian')

        loc = tf.convert_to_tensor(self.loc)
        concentration = tf.convert_to_tensor(self.concentration)
        total_shape = ps.concat(
            [ps.convert_to_shape_tensor([n]),
             self._batch_shape_tensor(loc=loc, concentration=concentration)],
            axis=0)
        return _random_inverse_gaussian_gradient(total_shape, loc,
                                                 concentration, seed)
Example #25
    def _sample_n(self, n, seed=None):
        seed = samplers.sanitize_seed(seed, salt='binomial')
        total_count = tf.convert_to_tensor(self._total_count)
        if self._probs is None:
            probs = self._probs_parameter_no_checks(total_count=total_count)
        else:
            probs = tf.convert_to_tensor(self._probs)

        return _random_binomial(shape=ps.convert_to_shape_tensor([n]),
                                counts=total_count,
                                probs=probs,
                                output_dtype=self.dtype,
                                seed=seed)[0]
Example #26
  def _num_samples_to_skip(self, call_counter):
    """Calculates how many samples to skip based on the call number."""
    # If `self.num_burnin_steps` is statically known to be 0,
    # `self.num_steps_between_results` will be returned outright.
    num_burnin_steps = ps.convert_to_shape_tensor(self.num_burnin_steps,
                                                  dtype_hint=tf.int32)
    num_burnin_steps_ = tf.get_static_value(num_burnin_steps)
    if num_burnin_steps_ == 0:
      return self.num_steps_between_results
    else:
      return (tf.where(tf.equal(call_counter, 0), num_burnin_steps, 0) +
              tf.convert_to_tensor(self.num_steps_between_results,
                                   dtype_hint=tf.int32))
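Worked example with hypothetical settings: with `num_burnin_steps=100` and `num_steps_between_results=2`, the first call (`call_counter == 0`) skips 100 + 2 = 102 samples and every later call skips only 2. If `num_burnin_steps` is statically known to be 0, the `tf.where` is skipped entirely and `num_steps_between_results` is returned as-is.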
Example #27
    def event_shape_tensor(self, name='event_shape_tensor'):
        """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

        Args:
          name: Name to give to the op.

        Returns:
          event_shape: `Tensor`.
        """
        with tf.name_scope(name):
            return ps.convert_to_shape_tensor(self.event_shape,
                                              name='event_shape')
Example #28
  def _pad_sample_dims(self, x, event_ndims=None):
    with tf.name_scope('pad_sample_dims'):
      if event_ndims is None:
        event_ndims = self._event_ndims()
      ndims = ps.rank(x)
      # Must call `convert_to_shape_tensor` in case `ndims` or `event_ndims`
      # are Tensors and `shape` is an ndarray. Otherwise we get `TypeError:
      # slice indices must be integers or None or have an __index__ method`.
      shape = ps.convert_to_shape_tensor(ps.shape(x))
      d = ndims - event_ndims
      x = tf.reshape(
          x, shape=ps.concat([shape[:d], [1], shape[d:]], axis=0))
      return x
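A shape-level illustration (hypothetical values, assuming `self._event_ndims()` returns 2): a draw `x` of shape [6, 3, 2] has one sample dim and two event dims, so `d = 3 - 2 = 1` and `self._pad_sample_dims(x)` has shape [6, 1, 3, 2], i.e. a size-1 batch dim is inserted between the sample and event dims.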
Example #29
  def convert_fn(path, value, dtype, dtype_hint, name=None):
    if not allow_packing and nest.is_nested(value) and any(
        # Treat arrays like Tensors for full parity in JAX backend.
        tf.is_tensor(x) or isinstance(x, np.ndarray)
        for x in nest.flatten(value)):
      raise NotImplementedError(
          ('Cannot convert a structure of tensors to a '
           'single tensor. Saw {} at path {}.').format(value, path))
    if as_shape_tensor:
      return ps.convert_to_shape_tensor(value, dtype, dtype_hint, name=name)
    else:
      return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)
Example #30
def _make_list_or_1d_tensor(values):
    """Return a list (preferred) or 1d Tensor from values, if values.ndims < 2."""
    values = ps.convert_to_shape_tensor(values, name='values')
    values_ = tf.get_static_value(values)

    # Static didn't work.
    if values_ is None:
        # Cheap way to bring to at least 1d.
        return values + tf.zeros([1], dtype=values.dtype)

    # Static worked!
    if values_.ndim > 1:
        raise ValueError('values had > 1 dim: {}'.format(values_.shape))
    # Cheap way to bring to at least 1d.
    values_ = values_ + np.zeros([1], dtype=values_.dtype)
    return list(values_)
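Behavior sketch: static inputs come back as Python lists; only values whose contents are unknown until runtime fall through to the 1-d Tensor branch.

import tensorflow as tf

_make_list_or_1d_tensor(3)       # ==> [3]     (scalar promoted to a 1-d list)
_make_list_or_1d_tensor([4, 5])  # ==> [4, 5]  (list passed through)
# _make_list_or_1d_tensor([[1, 2]])  # would raise: 'values had > 1 dim'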