def _random_binomial(
    shape,
    counts,
    probs=None,
    logits=None,
    output_dtype=tf.float32,
    seed=None,
    name=None):
  """Sample a binomial, CPU specialized to stateless_binomial.

  Args:
    shape: Shape of the full sample output. Trailing dims should match the
      broadcast shape of `counts` with `probs|logits`.
    counts: Batch of total_count.
    probs: Batch of p(success).
    logits: Batch of log-odds(success).
    output_dtype: DType of samples.
    seed: int or Tensor seed.
    name: Optional name for related ops.

  Returns:
    samples: Samples from binomial distributions.
    runtime_used_for_sampling: One of `implementation_selection._RUNTIME_*`.
  """
  with tf.name_scope(name or 'random_binomial'):
    seed = samplers.sanitize_seed(seed)
    shape = tf.convert_to_tensor(shape, dtype_hint=tf.int32, name='shape')
    params = dict(shape=shape, counts=counts, probs=probs, logits=logits,
                  output_dtype=output_dtype, seed=seed, name=name)
    sampler_impl = implementation_selection.implementation_selecting(
        fn_name='binomial',
        default_fn=_random_binomial_noncpu,
        cpu_fn=_random_binomial_cpu)
    return sampler_impl(**params)
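# A minimal usage sketch (illustrative, not from the library): the wrapper
# built by `implementation_selecting` returns a pair, so callers unpack both
# the draws and the runtime enum. Assumes `_random_binomial` and TensorFlow
# are importable as above.
counts = tf.constant([10., 20.])
probs = tf.constant([0.3, 0.7])
samples, runtime = _random_binomial(
    shape=[5, 2], counts=counts, probs=probs, seed=42)
# `samples` has shape [5, 2]; `runtime` is one of
# `implementation_selection._RUNTIME_*`, which tests can assert against to
# check that the CPU specialization was (or was not) selected.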
def _random_poisson(
    shape,
    rates=None,
    log_rates=None,
    output_dtype=tf.float32,
    seed=None,
    name=None):
  """Sample a poisson, CPU specialized to stateless_poisson.

  Args:
    shape: Shape of the full sample output. Trailing dims should match the
      broadcast shape of `rates` with `log_rates`.
    rates: Batch of rates for Poisson distribution.
    log_rates: Batch of log rates for Poisson distribution.
    output_dtype: DType of samples.
    seed: int or Tensor seed.
    name: Optional name for related ops.

  Returns:
    samples: Samples from poisson distributions.
    runtime_used_for_sampling: One of `implementation_selection._RUNTIME_*`.
  """
  with tf.name_scope(name or 'random_poisson'):
    seed = samplers.sanitize_seed(seed)
    shape = ps.convert_to_shape_tensor(shape, dtype_hint=tf.int32, name='shape')
    params = dict(shape=shape, rates=rates, log_rates=log_rates,
                  output_dtype=output_dtype, seed=seed, name=name)
    sampler_impl = implementation_selection.implementation_selecting(
        fn_name='poisson',
        default_fn=_random_poisson_noncpu,
        cpu_fn=_random_poisson_cpu)
    return sampler_impl(**params)
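# Hedged sketch (not from the library): `rates` and `log_rates` are two
# parameterizations of the same sampler (supply one, not both), so with a
# shared stateless seed the two calls below should produce draws that agree
# up to the float rounding of exp(log(rates)).
rates = tf.constant([1., 10., 100.])
samples_a, _ = _random_poisson(shape=[1000, 3], rates=rates, seed=(1, 2))
samples_b, _ = _random_poisson(
    shape=[1000, 3], log_rates=tf.math.log(rates), seed=(1, 2))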
def _random_gamma_no_gradient(
    shape, concentration, rate, log_rate, seed, log_space):
  """Sample a gamma, CPU specialized to stateless_gamma.

  Args:
    shape: Sample shape.
    concentration: Concentration of gamma distribution.
    rate: Rate parameter of gamma distribution.
    log_rate: Log-rate parameter of gamma distribution.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    log_space: If `True`, draw log-of-gamma samples.

  Returns:
    samples: Samples from gamma distributions.
    runtime_used_for_sampling: One of `implementation_selection._RUNTIME_*`.
  """
  seed = samplers.sanitize_seed(seed)
  sampler_impl = implementation_selection.implementation_selecting(
      fn_name='gamma',
      default_fn=_random_gamma_noncpu,
      cpu_fn=_random_gamma_cpu)
  return sampler_impl(
      shape=shape, concentration=concentration, rate=rate, log_rate=log_rate,
      seed=seed, log_space=log_space)
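# Hedged sketch of the `log_space` flag (illustrative, not from the library):
# with `log_space=True` the sampler returns log(X) for X ~ Gamma(conc, rate),
# which stays finite even when `concentration` is small enough that X itself
# would underflow to 0 in float32.
log_samples, _ = _random_gamma_no_gradient(
    shape=[4], concentration=tf.constant(0.01), rate=tf.constant(1.),
    log_rate=None, seed=(0, 1), log_space=True)
samples = tf.exp(log_samples)  # linear-space draws, where representable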
def cumsum(x):

  def _xla_friendly(x):
    return tf.math.cumsum(x)

  def _xla_hostile(x):
    return tf.while_loop(
        cond=lambda x, _: tf.size(x) > 0,
        body=lambda x, cx: (
            x[:-1],  # pylint: disable=g-long-lambda
            cx + tf.pad(x, [[tf.size(cx) - tf.size(x), 0]])),
        loop_vars=[x, tf.zeros_like(x)],
        shape_invariants=[tf.TensorShape([None]), x.shape])[1]

  return implementation_selection.implementation_selecting(
      'cumsum', default_fn=_xla_friendly, cpu_fn=_xla_hostile)(x=x)
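# A self-contained check (not from the library) that the two branches above
# agree: each iteration adds the surviving prefix of `x`, left-padded to full
# length, into the accumulator `cx`, then drops the prefix's last element.
# That shape-changing loop state is what makes the loop hostile to XLA's
# static-shape compilation.
import tensorflow as tf

x = tf.constant([1., 2., 3., 4.])
friendly = tf.math.cumsum(x)
hostile = tf.while_loop(
    cond=lambda x, _: tf.size(x) > 0,
    body=lambda x, cx: (
        x[:-1],  # pylint: disable=g-long-lambda
        cx + tf.pad(x, [[tf.size(cx) - tf.size(x), 0]])),
    loop_vars=[x, tf.zeros_like(x)],
    shape_invariants=[tf.TensorShape([None]), x.shape])[1]
tf.debugging.assert_near(friendly, hostile)  # both equal [1., 3., 6., 10.]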
def _random_gamma_no_gradient(shape, concentration, rate, seed):
  """Sample a gamma, CPU specialized to stateless_gamma.

  Args:
    shape: Sample shape.
    concentration: Concentration of gamma distribution.
    rate: Rate parameter of gamma distribution.
    seed: int or Tensor seed.

  Returns:
    samples: Samples from gamma distributions.
    runtime_used_for_sampling: One of `implementation_selection._RUNTIME_*`.
  """
  sampler_impl = implementation_selection.implementation_selecting(
      fn_name='gamma',
      default_fn=_random_gamma_noncpu,
      cpu_fn=_random_gamma_cpu)
  return sampler_impl(
      shape=shape, concentration=concentration, rate=rate, seed=seed)
@tf.custom_gradient  # `shape` and `seed` are closed over from the caller.
def sample(concentration, rate):
  sampler_impl = implementation_selection.implementation_selecting(
      fn_name='gamma',
      default_fn=_random_gamma_noncpu,
      cpu_fn=_random_gamma_cpu)
  samples = sampler_impl(
      shape=shape, concentration=concentration, rate=rate, seed=seed)

  # Ignore any gradient contributions that come from the implementation enum.
  def grad(dy, _):
    """The gradient of the gamma samples w.r.t. alpha and beta."""
    # d(sample)/d(concentration) via the implicit reparameterization op;
    # `samples[0]` holds the draws, `samples[1]` the runtime enum.
    partial_alpha = tf.raw_ops.RandomGammaGrad(
        alpha=concentration, sample=samples[0] * rate) / rate
    # d(sample)/d(rate) from the scaling property; the incoming `dy` is
    # applied once, in the reductions below.
    partial_beta = -samples[0] / rate
    # These will need to be shifted by the extra dimensions added from
    # `sample_shape`.
    grad_a = tf.math.reduce_sum(
        dy * partial_alpha, axis=tf.range(tf.size(shape)))
    grad_b = tf.math.reduce_sum(
        dy * partial_beta, axis=tf.range(tf.size(shape)))
    if (tensorshape_util.is_fully_defined(concentration.shape) and
        tensorshape_util.is_fully_defined(rate.shape) and
        concentration.shape == rate.shape):
      return [grad_a, grad_b]

    ra, rb = tf.raw_ops.BroadcastGradientArgs(
        s0=tf.shape(concentration), s1=tf.shape(rate))
    grad_a = tf.reshape(
        tf.math.reduce_sum(grad_a, axis=ra, keepdims=True),
        tf.shape(concentration))
    grad_b = tf.reshape(
        tf.math.reduce_sum(grad_b, axis=rb, keepdims=True),
        tf.shape(rate))
    return [grad_a, grad_b]

  return samples, grad
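# A hedged numeric check (not from the library) of the `partial_beta` term
# above: if Y ~ Gamma(alpha, 1), then X = Y / beta ~ Gamma(alpha, beta), so
# dX/dbeta = -Y / beta**2 = -X / beta regardless of how Y was drawn. This is
# the identity the rate gradient relies on.
import tensorflow as tf

alpha = tf.constant(2.5)
beta = tf.Variable(3.0)
y = tf.random.stateless_gamma([4], seed=[1, 2], alpha=alpha)  # rate-1 draws
with tf.GradientTape() as tape:
  x = y / beta  # rescale rate-1 samples into Gamma(alpha, beta) samples
auto = tape.gradient(x, beta)  # autodiff sums dX/dbeta over the 4 draws
manual = tf.reduce_sum(-x / beta)
tf.debugging.assert_near(auto, manual)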