def _default_event_space_bijector(self):
  """Return the default event-space bijector: a `Chain` of `Shift` and `Exp`.

  The chain is built from the list `[Shift(shift=self.loc), Exp()]`, and
  `self.validate_args` is forwarded to the chain and to both components.

  NOTE(review): TFP's `Chain` conventionally applies its list right-to-left
  (i.e. `Exp` first, then `Shift`) -- confirm against the `Chain` docs.

  Returns:
    A `chain_bijector.Chain` instance.
  """
  check_args = self.validate_args
  components = [
      shift_bijector.Shift(shift=self.loc, validate_args=check_args),
      exp_bijector.Exp(validate_args=check_args),
  ]
  return chain_bijector.Chain(components, validate_args=check_args)
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name='LogNormal'):
  """Construct a log-normal distribution.

  The LogNormal distribution models positive-valued random variables whose
  logarithm is normally distributed with mean `loc` and standard deviation
  `scale`. It is constructed as the exponential transformation of a Normal
  distribution.

  Args:
    loc: Floating-point `Tensor`; the means of the underlying Normal
      distribution(s).
    scale: Floating-point `Tensor`; the stddevs of the underlying Normal
      distribution(s).
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed. The flag is passed both
      to the underlying Normal and to the parent constructor.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic. The flag
      is passed to the underlying Normal.
    name: The name to give Ops created by the initializer.
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    # Hand the validation / NaN-stat flags to the base distribution. Without
    # this, `allow_nan_stats` is accepted and documented but never used by
    # this constructor, and `validate_args` never reaches the Normal that
    # actually owns `loc` and `scale`.
    base_normal = normal.Normal(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
    super(LogNormal, self).__init__(
        distribution=base_normal,
        bijector=exp_bijector.Exp(),
        validate_args=validate_args,
        parameters=parameters,
        name=name)
def __init__(self, nchan, dtype=tf.float32, validate_args=False, name=None):
  """Create the variables and the inner bijector for this layer.

  Three pieces of state are created, in this order: a non-trainable boolean
  variable `_initialized` (starting as `False`), a variable `_m` initialised
  to zeros, and a `TransformedVariable` `_s` initialised to ones and
  constrained through `Exp`. The inner bijector `_bijector` is the inverse
  of `Chain([Scale(_s), Shift(_m)])`.

  Args:
    nchan: Shape argument handed to `tf.zeros` / `tf.ones` when creating
      `_m` and `_s` (presumably the number of channels -- confirm with
      callers).
    dtype: dtype of `_m` and `_s`. Defaults to `tf.float32`.
    validate_args: Forwarded to the parent constructor.
    name: Name forwarded to the parent constructor; falls back to
      'ActivationNormalization' when falsy.
  """
  # Capture the constructor arguments before any other local is bound.
  parameters = dict(locals())

  self._initialized = tf.Variable(False, trainable=False)
  self._m = tf.Variable(tf.zeros(nchan, dtype))
  self._s = TransformedVariable(tf.ones(nchan, dtype), exp.Exp())

  inner_chain = chain.Chain([
      scale.Scale(self._s),
      shift.Shift(self._m),
  ])
  self._bijector = invert.Invert(inner_chain)

  resolved_name = name or 'ActivationNormalization'
  super(ActivationNormalization, self).__init__(
      validate_args=validate_args,
      forward_min_event_ndims=1,
      parameters=parameters,
      name=resolved_name)
def __init__(
    self,
    temperature,
    logits=None,
    probs=None,
    validate_args=False,
    allow_nan_stats=True,
    name='RelaxedOneHotCategorical'):
  """Initialize RelaxedOneHotCategorical using class log-probabilities.

  The distribution is built by wrapping an `ExpRelaxedOneHotCategorical`
  in an `Exp` bijector and handing both to the parent constructor.

  Args:
    temperature: An 0-D `Tensor`, representing the temperature of a set of
      RelaxedOneHotCategorical distributions. The temperature should be
      positive.
    logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities of
      a set of RelaxedOneHotCategorical distributions. The first `N - 1`
      dimensions index into a batch of independent distributions and the
      last dimension represents a vector of logits for each class. Only one
      of `logits` or `probs` should be passed in.
    probs: An N-D `Tensor`, `N >= 1`, representing the probabilities of a
      set of RelaxedOneHotCategorical distributions. The first `N - 1`
      dimensions index into a batch of independent distributions and the
      last dimension represents a vector of probabilities for each class.
      Only one of `logits` or `probs` should be passed in.
    validate_args: Python `bool`, default `False`. Forwarded to the
      underlying `ExpRelaxedOneHotCategorical` and to the parent
      constructor.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic. Forwarded
      to the underlying `ExpRelaxedOneHotCategorical`.
    name: A name for this distribution (optional).
  """
  # Capture the constructor arguments before any other local is bound.
  parameters = dict(locals())

  base_distribution = ExpRelaxedOneHotCategorical(
      temperature,
      logits=logits,
      probs=probs,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats)
  exp_transform = exp_bijector.Exp()

  super(RelaxedOneHotCategorical, self).__init__(
      base_distribution,
      exp_transform,
      validate_args=validate_args,
      parameters=parameters,
      name=name)
def _default_event_space_bijector(self):
  """Return an `Exp` bijector as the default event-space bijector.

  The bijector is created with this object's `validate_args` setting.

  Returns:
    An `exp_bijector.Exp` instance.
  """
  check_args = self.validate_args
  return exp_bijector.Exp(validate_args=check_args)
def quadrature_scheme_lognormal_quantiles(
    loc, scale, quadrature_size, validate_args=False, name=None):
  """Use LogNormal quantiles to form quadrature on positive-reals.

  A LogNormal is built as `Exp(Normal(loc, scale))`. Its quantiles are
  evaluated at `quadrature_size + 1` evenly spaced probabilities strictly
  inside (0, 1); the grid is made of the midpoints of consecutive quantiles
  and every grid point receives the same weight `1 / quadrature_size`.

  Args:
    loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
      the LogNormal prior.
    scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
      the LogNormal prior.
    quadrature_size: Python `int` scalar representing the number of
      quadrature points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    grid: (Batch of) length-`quadrature_size` vectors representing the
      `log_rate` parameters of a `Poisson`.
    probs: Length-`quadrature_size` vector holding the weight associated
      with each `grid` value (constant, `1 / quadrature_size`).
  """
  with tf.name_scope(name, "quadrature_scheme_lognormal_quantiles",
                     [loc, scale]):
    # LogNormal expressed as the Exp-transform of a Normal.
    lognormal = transformed_distribution.TransformedDistribution(
        distribution=normal.Normal(loc=loc, scale=scale),
        bijector=exp_bijector.Exp(),
        validate_args=validate_args)

    # Use the static batch rank when known, otherwise compute it dynamically.
    num_batch_dims = lognormal.batch_shape.ndims
    if num_batch_dims is None:
      num_batch_dims = tf.shape(lognormal.batch_shape_tensor())[0]

    # `quadrature_size + 3` evenly spaced points on [0, 1] with both
    # endpoints dropped leave `quadrature_size + 1` interior probabilities;
    # 0 and 1 themselves are omitted since they might lead to Inf/NaN.
    start = tf.zeros([], dtype=lognormal.dtype)
    levels = tf.linspace(start, 1., quadrature_size + 3)[1:-1]

    # Append one singleton axis per batch dim so the probabilities broadcast
    # across the batch.
    broadcast_shape = tf.concat(
        [[-1], tf.ones([num_batch_dims], dtype=tf.int32)], axis=0)
    levels = tf.reshape(levels, shape=broadcast_shape)
    quantiles = lognormal.quantile(levels)

    # Cyclically permute left by one: the leading (quantile) axis moves to
    # the end.
    rotate_left = tf.concat([tf.range(1, 1 + num_batch_dims), [0]], axis=0)
    quantiles = tf.transpose(quantiles, rotate_left)

    # Grid points are the midpoints of consecutive quantiles.
    lower = quantiles[..., :-1]
    upper = quantiles[..., 1:]
    grid = (lower + upper) / 2.
    # Set shape hints.
    grid.set_shape(lognormal.batch_shape.concatenate([quadrature_size]))

    # By construction probs is constant, i.e., `1 / quadrature_size`. This is
    # important, because non-constant probs leads to non-reparameterizable
    # samples.
    uniform_weight = 1. / tf.cast(quadrature_size, lognormal.dtype)
    probs = tf.fill(dims=[quadrature_size], value=uniform_weight)

    return grid, probs