def slice_batch_shape_tensor(base_shape, event_ndims):
  base_shape = ps.convert_to_shape_tensor(base_shape, dtype_hint=np.int32)
  event_ndims = ps.convert_to_shape_tensor(event_ndims, dtype_hint=np.int32)
  base_rank = ps.rank_from_shape(base_shape)
  return base_shape[:(base_rank -
                      # Don't try to slice away more ndims than the parameter
                      # actually has, if that's fewer than `event_ndims` (i.e.,
                      # if it relies on broadcasting).
                      ps.minimum(event_ndims, base_rank))]

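# Example (a sketch; assumes the `ps`/`np` imports used throughout these
# snippets, e.g. `from tensorflow_probability.python.internal import
# prefer_static as ps`):
#
#   slice_batch_shape_tensor(base_shape=[3, 4, 5], event_ndims=2)
#   # ==> [3]   (the trailing 2 event dims are sliced away)
#   slice_batch_shape_tensor(base_shape=[3, 4, 5], event_ndims=5)
#   # ==> []    (`event_ndims` is clipped to the rank, 3, before slicing)
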
def _truncate_shape_tensor(shape, ndims_to_truncate):
  shape = ps.convert_to_shape_tensor(shape, dtype_hint=np.int32)
  ndims_to_truncate = ps.convert_to_shape_tensor(
      ndims_to_truncate, dtype_hint=np.int32)
  base_rank = ps.rank_from_shape(shape)
  return shape[:(base_rank -
                 # Don't try to slice away more ndims than the shape
                 # actually has, if that's fewer than `ndims_to_truncate`
                 # (i.e., if it relies on broadcasting).
                 ps.minimum(ndims_to_truncate, base_rank))]

def rademacher(shape, dtype=tf.float32, seed=None, name=None):
  """Generates `Tensor` consisting of `-1` or `+1`, chosen uniformly at random.

  For more details, see [Rademacher distribution](
  https://en.wikipedia.org/wiki/Rademacher_distribution).

  Args:
    shape: Vector-shaped, `int` `Tensor` representing shape of output.
    dtype: (Optional) TF `dtype` representing `dtype` of output.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'rademacher').

  Returns:
    rademacher: `Tensor` with specified `shape` and `dtype` consisting of `-1`
      or `+1` chosen uniformly at random.
  """
  with tf.name_scope(name or 'rademacher'):
    # Choose the dtype to cause `2 * random_bernoulli - 1` to run in the same
    # memory (host or device) as the downstream cast will want to put it. The
    # convention on GPU is that int32 are in host memory and int64 are in
    # device memory.
    shape = ps.convert_to_shape_tensor(shape)
    generation_dtype = tf.int64 if tf.as_dtype(dtype) != tf.int32 else tf.int32
    random_bernoulli = samplers.uniform(
        shape, minval=0, maxval=2, dtype=generation_dtype, seed=seed)
    return tf.cast(2 * random_bernoulli - 1, dtype)

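# Example (sketch; `some_seed` is a placeholder for a stateless PRNG seed,
# e.g. one produced by `tfp.random.sanitize_seed`):
#
#   signs = rademacher([2, 3], dtype=tf.float32, seed=some_seed)
#   # `signs` is a float32 Tensor of shape [2, 3] whose entries are each
#   # -1. or +1. with probability 1/2.
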
def _random_binomial(shape,
                     counts,
                     probs=None,
                     logits=None,
                     output_dtype=tf.float32,
                     seed=None,
                     name=None):
  """Sample a binomial, CPU specialized to stateless_binomial.

  Args:
    shape: Shape of the full sample output. Trailing dims should match the
      broadcast shape of `counts` with `probs|logits`.
    counts: Batch of total_count.
    probs: Batch of p(success).
    logits: Batch of log-odds of success.
    output_dtype: DType of samples.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Optional name for related ops.

  Returns:
    samples: Samples from binomial distributions.
    runtime_used_for_sampling: One of `implementation_selection._RUNTIME_*`.
  """
  with tf.name_scope(name or 'random_binomial'):
    seed = samplers.sanitize_seed(seed)
    shape = ps.convert_to_shape_tensor(
        shape, dtype_hint=tf.int32, name='shape')
    params = dict(shape=shape, counts=counts, probs=probs, logits=logits,
                  output_dtype=output_dtype, seed=seed, name=name)
    sampler_impl = implementation_selection.implementation_selecting(
        fn_name='binomial',
        default_fn=_random_binomial_noncpu,
        cpu_fn=_random_binomial_cpu)
    return sampler_impl(**params)

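# Example (sketch; the two-element return structure is taken from the
# docstring above, and `some_seed` is a placeholder stateless seed):
#
#   samples, runtime = _random_binomial(
#       shape=[5], counts=10., probs=0.5, seed=some_seed)
#   # `samples` has shape [5]; `runtime` is one of
#   # `implementation_selection._RUNTIME_*`, reporting which kernel ran.
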
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
  # We override `_call_sample_n` rather than `_sample_n` so we can ensure that
  # the result of `self.bijector.forward` is not modified (and thus caching
  # works).
  with self._name_and_control_scope(name):
    sample_shape = ps.convert_to_shape_tensor(
        sample_shape, dtype=tf.int32, name='sample_shape')
    sample_shape, n = self._expand_sample_shape_to_vector(
        sample_shape, 'sample_shape')
    distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

    # First, generate samples from the base distribution.
    x = self.distribution.sample(sample_shape=[n], seed=seed,
                                 **distribution_kwargs)

    # Next, we reshape `x` into its final form. We do this prior to the call
    # to the bijector to ensure that the bijector caching works.
    def reshape_sample_shape(t):
      batch_event_shape = ps.shape(t)[1:]
      final_shape = ps.concat([sample_shape, batch_event_shape], 0)
      return tf.reshape(t, final_shape)
    x = tf.nest.map_structure(reshape_sample_shape, x)

    # Finally, we apply the bijector's forward transformation. For caching to
    # work, it is imperative that this is the last modification to the
    # returned result.
    y = self.bijector.forward(x, **bijector_kwargs)
    y = self._set_sample_static_shape(y, sample_shape)
    return y

def random_von_mises(shape, concentration, dtype=tf.float32, seed=None):
  """Samples from the standardized von Mises distribution.

  The distribution is vonMises(loc=0, concentration=concentration), so the
  mean is zero. The location can then be changed by adding it to the samples.

  The sampling algorithm is rejection sampling with wrapped Cauchy proposal
  [1]. The samples are pathwise differentiable using the approach of [2].

  Args:
    shape: The output sample shape.
    concentration: The concentration parameter of the von Mises distribution.
    dtype: The data type of concentration and the outputs.
    seed: (optional) The random seed.

  Returns:
    Differentiable samples of standardized von Mises.

  References:
    [1] Luc Devroye, "Non-Uniform Random Variate Generation", Springer-Verlag,
        1986; Chapter 9, pp. 473-476.
        http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
        + corrections http://www.nrbook.com/devroye/Devroye_files/errors.pdf
    [2] Michael Figurnov, Shakir Mohamed, Andriy Mnih. "Implicit
        Reparameterization Gradients", 2018.
  """
  shape = ps.convert_to_shape_tensor(shape, dtype_hint=tf.int32, name='shape')
  seed = samplers.sanitize_seed(seed, salt='von_mises')
  concentration = tf.convert_to_tensor(
      concentration, dtype=dtype, name='concentration')
  return _von_mises_sample_with_gradient(shape, concentration, seed)

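# Example (sketch; `some_seed` is a placeholder stateless seed):
#
#   theta = random_von_mises([1000], concentration=2., seed=some_seed)
#   # `theta` holds 1000 angles drawn from vonMises(loc=0, concentration=2.);
#   # gradients with respect to `concentration` flow through the implicit
#   # reparameterization of [2].
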
def random_gamma_with_runtime(
    shape, concentration, rate=None, log_rate=None, seed=None,
    log_space=False):
  """Returns both a sample and the id of the implementation-selected runtime."""
  # This method exists chiefly for testing purposes.
  dtype = dtype_util.common_dtype([concentration, rate, log_rate], tf.float32)
  concentration = tf.convert_to_tensor(concentration, dtype=dtype)
  shape = ps.convert_to_shape_tensor(shape, dtype_hint=tf.int32, name='shape')

  if rate is not None and log_rate is not None:
    raise ValueError('At most one of `rate` and `log_rate` may be specified.')
  if rate is not None:
    rate = tf.convert_to_tensor(rate, dtype=dtype)
  if log_rate is not None:
    log_rate = tf.convert_to_tensor(log_rate, dtype=dtype)
  total_shape = ps.concat(
      [shape,
       ps.broadcast_shape(ps.shape(concentration),
                          _shape_or_scalar(rate, log_rate))],
      axis=0)
  seed = samplers.sanitize_seed(seed, salt='random_gamma')
  return _random_gamma_gradient(
      total_shape, concentration, rate, log_rate, seed, log_space)

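# Example (sketch; `some_seed` is a placeholder stateless seed): the batch
# shape of the draw is the broadcast of `concentration` against `rate` (or
# `log_rate`), appended to `shape`:
#
#   samples, runtime = random_gamma_with_runtime(
#       [10], concentration=tf.ones([3]), rate=2., seed=some_seed)
#   # `samples` has shape [10, 3]; `runtime` identifies which sampler
#   # implementation was selected at run time.
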
def _might_have_excess_ndims(flat_value, flat_core_ndims):
  """Returns `True` unless every value statically has exactly its core ndims."""
  for v, nd in zip(flat_value, flat_core_ndims):
    static_excess_ndims = (
        0 if v is None
        else tf.get_static_value(ps.convert_to_shape_tensor(ps.rank(v) - nd)))
    if static_excess_ndims is None or static_excess_ndims > 0:
      return True
  return False

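# Example (sketch): a rank-2 value with core ndims 1 has one excess
# (sample/batch) dimension, so the check fires; a rank-1 value does not:
#
#   _might_have_excess_ndims([tf.zeros([2, 3])], [1])   # ==> True
#   _might_have_excess_ndims([tf.zeros([3])], [1])      # ==> False
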
def _random_poisson(shape,
                    rates=None,
                    log_rates=None,
                    output_dtype=tf.float32,
                    seed=None,
                    name=None):
  """Sample a Poisson, CPU specialized to stateless_poisson.

  Args:
    shape: Shape of the full sample output. Trailing dims should match the
      broadcast shape of `rates` with `log_rates`.
    rates: Batch of rates for Poisson distribution.
    log_rates: Batch of log rates for Poisson distribution.
    output_dtype: DType of samples.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Optional name for related ops.

  Returns:
    samples: Samples from poisson distributions.
    runtime_used_for_sampling: One of `implementation_selection._RUNTIME_*`.
  """
  with tf.name_scope(name or 'random_poisson'):
    seed = samplers.sanitize_seed(seed)
    shape = ps.convert_to_shape_tensor(
        shape, dtype_hint=tf.int32, name='shape')
    params = dict(shape=shape, rates=rates, log_rates=log_rates,
                  output_dtype=output_dtype, seed=seed, name=name)
    sampler_impl = implementation_selection.implementation_selecting(
        fn_name='poisson',
        default_fn=_random_poisson_noncpu,
        cpu_fn=_random_poisson_cpu)
    return sampler_impl(**params)

def _sample_n(self, n, seed=None):
  loc = tf.convert_to_tensor(self.loc)
  scale = tf.convert_to_tensor(self.scale)
  tailweight = tf.convert_to_tensor(self.tailweight)
  skewness = tf.convert_to_tensor(self.skewness)
  ig_seed, normal_seed = samplers.split_seed(
      seed, salt='normal_inverse_gaussian')
  batch_shape = self._batch_shape_tensor(
      loc=loc, scale=scale, tailweight=tailweight, skewness=skewness)
  w = tailweight * tf.math.exp(
      0.5 * tf.math.log1p(-tf.math.square(skewness / tailweight)))
  w = tf.broadcast_to(w, batch_shape)
  ig_samples = inverse_gaussian.InverseGaussian(
      scale / w, tf.math.square(scale)).sample(n, seed=ig_seed)

  sample_shape = ps.concat([[n], batch_shape], axis=0)
  normal_samples = samplers.normal(
      shape=ps.convert_to_shape_tensor(sample_shape),
      mean=0.,
      stddev=1.,
      dtype=self.dtype,
      seed=normal_seed)
  return (loc + tf.math.sqrt(ig_samples) * (
      skewness * tf.math.sqrt(ig_samples) + normal_samples))

def _validate_block_sizes(block_sizes, bijectors, validate_args):
  """Helper to validate block sizes."""
  block_sizes = ps.convert_to_shape_tensor(
      block_sizes, name='block_sizes', dtype_hint=tf.int32)
  block_sizes_shape = block_sizes.shape
  if tensorshape_util.is_fully_defined(block_sizes_shape):
    if (tensorshape_util.rank(block_sizes_shape) != 1 or
        (tensorshape_util.num_elements(block_sizes_shape) != len(bijectors))):
      raise ValueError(
          '`block_sizes` must be `None`, or a vector of the same length as '
          '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
          'length {}'.format(block_sizes_shape, len(bijectors)))
    return block_sizes
  elif validate_args:
    message = ('`block_sizes` must be `None`, or a vector of the same length '
               'as `bijectors`.')
    with tf.control_dependencies([
        assert_util.assert_equal(
            tf.size(block_sizes), len(bijectors), message=message),
        assert_util.assert_equal(tf.rank(block_sizes), 1)
    ]):
      block_sizes = tf.identity(block_sizes)

  # Set the shape if missing to pass statically known structure to split.
  tensorshape_util.set_shape(block_sizes, [len(bijectors)])
  return block_sizes

def _sample_n(self, n, seed=None):
  seed = samplers.sanitize_seed(seed, salt='gamma')
  return random_gamma(
      shape=ps.convert_to_shape_tensor([n]),
      concentration=tf.convert_to_tensor(self.concentration, self.dtype),
      rate=tf.convert_to_tensor(self.rate, self.dtype),
      seed=seed)

def _squeeze(x, axis):
  """A version of squeeze that works with dynamic axis."""
  x = tf.convert_to_tensor(x, name='x')
  if axis is None:
    return tf.squeeze(x, axis=None)
  axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
  axis = axis + ps.zeros([1], dtype=axis.dtype)  # Make axis at least 1d.
  keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
  return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))

def _squeeze(x, axis):
  """A version of squeeze that works with dynamic axis."""
  x = tf.convert_to_tensor(x, name='x')
  if axis is None:
    return tf.squeeze(x, axis=None)
  axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
  axis = _make_list_or_1d_tensor(axis)  # Ensure at least 1d.
  keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
  return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))

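# Example (sketch; applies to either `_squeeze` variant above): unlike
# `tf.squeeze`, the axis may be a Tensor whose value is unknown until runtime:
#
#   x = tf.ones([2, 1, 3])
#   axis = tf.constant(1)   # e.g. computed dynamically elsewhere
#   _squeeze(x, axis)       # ==> Tensor of shape [2, 3]
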
def _sample_n(self, n, seed=None):
  seed = samplers.sanitize_seed(seed)
  return random_poisson(
      shape=ps.convert_to_shape_tensor([n]),
      rates=(None if self._rate is None else
             tf.convert_to_tensor(self._rate)),
      log_rates=(None if self._log_rate is None else
                 tf.convert_to_tensor(self._log_rate)),
      output_dtype=self.dtype,
      seed=seed)[0]

def __init__(self, num_or_size_splits, axis=-1, validate_args=False,
             name='split'):
  """Creates the bijector.

  Args:
    num_or_size_splits: Either a Python integer indicating the number of
      splits along `axis` or a 1-D integer `Tensor` or Python list containing
      the sizes of each output tensor along `axis`. If a list/`Tensor`, it may
      contain at most one value of `-1`, which indicates a split size that is
      unknown and determined from input.
    axis: A negative integer or scalar `int32` `Tensor`. The dimension along
      which to split. Must be negative to enable the bijector to support
      arbitrary batch dimensions. Defaults to -1 (note that this is different
      from the `tf.split` default of `0`). Must be statically known.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str`, name given to ops managed by this object.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    if isinstance(num_or_size_splits, numbers.Integral):
      self._num_splits = num_or_size_splits
      self._split_sizes = None
    else:
      self._split_sizes = tensor_util.convert_nonref_to_tensor(
          num_or_size_splits, name='num_or_size_splits', dtype=tf.int32,
          as_shape_tensor=True)
      if tensorshape_util.rank(self._split_sizes.shape) != 1:
        raise ValueError(
            '`num_or_size_splits` must be an integer or 1-D `Tensor`.')
      num_splits = tensorshape_util.as_list(self._split_sizes.shape)[0]
      if num_splits is None:
        raise ValueError('If `num_or_size_splits` is a vector of split sizes '
                         'it must have a statically-known number of '
                         'elements.')
      self._num_splits = num_splits

    static_axis = tf.get_static_value(axis)
    if static_axis is None:
      raise ValueError('`axis` must be statically known.')
    if static_axis >= 0:
      raise ValueError('`axis` must be negative. Got {}'.format(axis))

    self._axis = ps.convert_to_shape_tensor(axis, tf.int32)

    super(Split, self).__init__(
        forward_min_event_ndims=-axis,
        inverse_min_event_ndims=[-axis] * self.num_splits,
        is_constant_jacobian=True,
        validate_args=validate_args,
        parameters=parameters,
        name=name)

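# Example (sketch; assumes the constructor above belongs to the `Split`
# bijector, e.g. `tfp.bijectors.Split`):
#
#   split = Split(num_or_size_splits=[2, 3, -1], axis=-1)
#   parts = split.forward(tf.zeros([4, 8]))
#   # ==> list of Tensors with shapes [4, 2], [4, 3], and [4, 3]; the `-1`
#   # entry is resolved from the size-8 input dimension.
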
def _sample_n(self, n, seed=None):
  seed = samplers.sanitize_seed(seed, salt='gamma')
  return random_gamma(
      shape=ps.convert_to_shape_tensor([n]),
      concentration=tf.convert_to_tensor(self.concentration),
      rate=None if self.rate is None else tf.convert_to_tensor(self.rate),
      log_rate=(None if self.log_rate is None else
                tf.convert_to_tensor(self.log_rate)),
      seed=seed)

def _sample_n(self, n, seed=None):
  seed = samplers.sanitize_seed(seed, salt='binomial')
  return _random_binomial(
      shape=ps.convert_to_shape_tensor([n]),
      counts=tf.convert_to_tensor(self._total_count),
      probs=(None if self._probs is None else
             tf.convert_to_tensor(self._probs)),
      logits=(None if self._logits is None else
              tf.convert_to_tensor(self._logits)),
      output_dtype=self.dtype,
      seed=seed)[0]

def normal_generator(shape):
  shape = ps.convert_to_shape_tensor(shape, dtype=np.int32)
  loc = yield trainable_state_util.Parameter(
      init_fn=functools.partial(samplers.normal, shape=shape),
      name='loc')
  bij = tfb.Softplus()
  scale = yield trainable_state_util.Parameter(
      init_fn=lambda seed: bij.forward(samplers.normal(shape, seed=seed)),
      constraining_bijector=bij,
      name='scale')
  return tfd.Normal(loc=loc, scale=scale, validate_args=True)

def random_gamma(shape, concentration, rate, seed=None):
  shape = ps.convert_to_shape_tensor(shape, dtype_hint=tf.int32, name='shape')
  total_shape = ps.concat(
      [shape, ps.broadcast_shape(ps.shape(concentration), ps.shape(rate))],
      axis=0)
  seed = samplers.sanitize_seed(seed, salt='random_gamma')
  return _random_gamma_gradient(total_shape, concentration, rate, seed)

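# Example (sketch; `some_seed` is a placeholder stateless seed): as in
# `random_gamma_with_runtime` above, the draw's full shape is `shape`
# followed by the broadcast batch shape of the parameters:
#
#   g = random_gamma([10], concentration=tf.ones([3]), rate=2.,
#                    seed=some_seed)
#   # total_shape = [10] + broadcast([3], []) = [10, 3]
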
def _dimension(self):
  """Scalar dimension of underlying vector space."""
  with tf.name_scope('dimension'):
    if tf.compat.dimension_value(self._scale.shape[-1]) is None:
      return tf.cast(self._scale.domain_dimension_tensor(),
                     dtype=self._scale.dtype,
                     name='dimension')
    else:
      return ps.convert_to_shape_tensor(
          tf.compat.dimension_value(self._scale.shape[-1]),
          dtype=self._scale.dtype,
          name='dimension')

def sample(self, sample_shape=(), seed=None, name='sample'):  # pylint: disable=unused-argument
  return tf.zeros(
      ps.concat(
          [
              # sample_shape might be a scalar.
              ps.reshape(
                  ps.convert_to_shape_tensor(sample_shape, tf.int32),
                  shape=[-1]),
              self.batch_shape_tensor(),
              self.event_shape_tensor()
          ],
          axis=0))

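# Example (sketch; `dist` stands for an instance of the class owning the
# `sample` method above): with batch shape [2] and event shape [3], the stub
# returns all-zero "samples" of the correctly assembled shape:
#
#   dist.sample(5)        # ==> zeros of shape [5, 2, 3]
#   dist.sample([4, 5])   # ==> zeros of shape [4, 5, 2, 3]
#   # A scalar `sample_shape` like `5` is reshaped to `[5]` before the concat.
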
def _expand_x_fn(tensor):
  # Reshape tensor to tensor.shape + [1] * M.
  extended_shape = ps.concat(
      [
          ps.shape(tensor),
          ps.ones_like(
              ps.convert_to_shape_tensor(
                  ps.shape_slice(y_ref, np.s_[batch_dims + nd:])))
      ],
      axis=0)
  return tf.reshape(tensor, extended_shape)

def _sample_n(self, n, seed=None):
  seed = samplers.sanitize_seed(seed, salt='inverse_gaussian')
  loc = tf.convert_to_tensor(self.loc)
  concentration = tf.convert_to_tensor(self.concentration)
  total_shape = ps.concat(
      [ps.convert_to_shape_tensor([n]),
       self._batch_shape_tensor(loc=loc, concentration=concentration)],
      axis=0)
  return _random_inverse_gaussian_gradient(
      total_shape, loc, concentration, seed)

def _sample_n(self, n, seed=None):
  seed = samplers.sanitize_seed(seed, salt='binomial')
  total_count = tf.convert_to_tensor(self._total_count)
  if self._probs is None:
    probs = self._probs_parameter_no_checks(total_count=total_count)
  else:
    probs = tf.convert_to_tensor(self._probs)
  return _random_binomial(
      shape=ps.convert_to_shape_tensor([n]),
      counts=total_count,
      probs=probs,
      output_dtype=self.dtype,
      seed=seed)[0]

def _num_samples_to_skip(self, call_counter):
  """Calculates how many samples to skip based on the call number."""
  # If `self.num_burnin_steps` is statically known to be 0,
  # `self.num_steps_between_results` will be returned outright.
  num_burnin_steps = ps.convert_to_shape_tensor(
      self.num_burnin_steps, dtype_hint=tf.int32)
  num_burnin_steps_ = tf.get_static_value(num_burnin_steps)
  if num_burnin_steps_ == 0:
    return self.num_steps_between_results
  else:
    return (tf.where(tf.equal(call_counter, 0), num_burnin_steps, 0) +
            tf.convert_to_tensor(self.num_steps_between_results,
                                 dtype_hint=tf.int32))

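# Worked example: with `num_burnin_steps=100` and
# `num_steps_between_results=2`, the first call (`call_counter == 0`) skips
# 100 + 2 = 102 samples; every subsequent call skips only 2.
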
def event_shape_tensor(self, name='event_shape_tensor'):
  """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

  Args:
    name: name to give to the op.

  Returns:
    event_shape: `Tensor`.
  """
  with tf.name_scope(name):
    return ps.convert_to_shape_tensor(self.event_shape, name='event_shape')

def _pad_sample_dims(self, x, event_ndims=None):
  with tf.name_scope('pad_sample_dims'):
    if event_ndims is None:
      event_ndims = self._event_ndims()
    ndims = ps.rank(x)
    # Must do the convert_to_shape_tensor in case `ndims` or `event_ndims`
    # are Tensors and `shape` is an ndarray. Otherwise we get `TypeError:
    # slice indices must be integers or None or have an __index__ method`.
    shape = ps.convert_to_shape_tensor(ps.shape(x))
    d = ndims - event_ndims
    x = tf.reshape(x, shape=ps.concat([shape[:d], [1], shape[d:]], axis=0))
    return x

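# Worked shapes: for `x` of shape [6, 3, 2] with `event_ndims=1`,
# `d = 3 - 1 = 2`, so the reshape inserts a length-1 dimension just before
# the event dims:
#
#   _pad_sample_dims(x, event_ndims=1)   # ==> shape [6, 3, 1, 2]
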
def convert_fn(path, value, dtype, dtype_hint, name=None):
  if not allow_packing and nest.is_nested(value) and any(
      # Treat arrays like Tensors for full parity in JAX backend.
      tf.is_tensor(x) or isinstance(x, np.ndarray)
      for x in nest.flatten(value)):
    raise NotImplementedError(
        ('Cannot convert a structure of tensors to a '
         'single tensor. Saw {} at path {}.').format(value, path))
  if as_shape_tensor:
    return ps.convert_to_shape_tensor(value, dtype, dtype_hint, name=name)
  else:
    return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)

def _make_list_or_1d_tensor(values):
  """Return a list (preferred) or 1d Tensor from values, if values.ndims < 2."""
  values = ps.convert_to_shape_tensor(values, name='values')
  values_ = tf.get_static_value(values)
  # Static didn't work.
  if values_ is None:
    # Cheap way to bring to at least 1d.
    return values + tf.zeros([1], dtype=values.dtype)
  # Static worked!
  if values_.ndim > 1:
    raise ValueError('values had > 1 dim: {}'.format(values_.shape))
  # Cheap way to bring to at least 1d.
  values_ = values_ + np.zeros([1], dtype=values_.dtype)
  return list(values_)

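# Example (sketch): statically known inputs come back as Python lists,
# dynamically shaped ones as 1-D Tensors:
#
#   _make_list_or_1d_tensor(3)         # ==> [3]
#   _make_list_or_1d_tensor([2, 5])    # ==> [2, 5]
#   _make_list_or_1d_tensor(dyn)       # ==> 1-D Tensor, when `dyn` is a
#                                      #     scalar with no static value
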