def __init__(self,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="HalfNormal"):
    """Build HalfNormal distribution(s) parameterized by `scale`.

    Args:
      scale: Floating point tensor with the scale(s) of the distribution(s).
        Expected to be strictly positive; positivity is only asserted when
        `validate_args` is `True`.
      validate_args: Python `bool`, default `False`. When `True` a
        positivity assertion on `scale` is attached as a control dependency.
        When `False` no check is added and invalid inputs may silently
        produce incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. Forwarded unchanged to
        the base class, which decides whether undefined statistics are
        reported as `NaN` or raise.
      name: Python `str` used as the name scope for ops created here.
    """
    # Captured before anything else runs so the recorded arguments are the
    # caller's originals (presumably read from this frame; keep it first).
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[scale]) as scope:
        # The assert op is only built when validation was requested.
        if validate_args:
            assertions = [check_ops.assert_positive(scale)]
        else:
            assertions = []
        with ops.control_dependencies(assertions):
            self._scale = array_ops.identity(scale, name="scale")
    super(HalfNormal, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._scale],
        name=scope)
def __init__(self,
             logits=None,
             probs=None,
             dtype=dtypes.int32,
             validate_args=False,
             allow_nan_stats=True,
             name="OneHotCategorical"):
    """Set up OneHotCategorical distribution(s) from `logits` or `probs`.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, of class log-probabilities. The
        last dimension indexes classes; leading dimensions index a batch of
        independent distributions. Pass only one of `logits` or `probs`.
      probs: An N-D `Tensor`, `N >= 1`, of class probabilities laid out like
        `logits`. Pass only one of `logits` or `probs`.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. Forwarded to
        `distribution_util.get_logits_and_probs` and to the base class.
      allow_nan_stats: Python `bool`, default `True`. Forwarded unchanged to
        the base class.
      name: Python `str` used as the name scope for ops created here.
    """
    # Captured first so the recorded arguments are the caller's originals.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[logits, probs]) as scope:
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            name=scope,
            logits=logits,
            probs=probs,
            validate_args=validate_args,
            multidimensional=True)

        # `with_rank_at_least(1)` rejects statically-known scalar logits.
        static_shape = self._logits.get_shape().with_rank_at_least(1)
        static_rank = static_shape.ndims
        if static_rank is None:
            # Rank unknown at graph-construction time: compute it with ops.
            with ops.name_scope(name="batch_rank"):
                self._batch_rank = array_ops.rank(self._logits) - 1
        else:
            # Rank known statically: store it as a constant tensor.
            self._batch_rank = ops.convert_to_tensor(
                static_rank - 1, dtype=dtypes.int32, name="batch_rank")

        with ops.name_scope(name="event_size"):
            # Number of classes is the size of the last logits dimension.
            self._event_size = array_ops.shape(self._logits)[-1]

    super(OneHotCategorical, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=scope)
def __init__(self,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="HalfNormal"):
    """Create HalfNormal distribution(s) with the given `scale`.

    Args:
      scale: Floating point tensor; the scale(s) of the distribution(s).
        Should contain only positive values. This is checked with
        `check_ops.assert_positive` only when `validate_args` is `True`.
      validate_args: Python `bool`, default `False`. Enables the positivity
        assertion on `scale`; when `False` invalid inputs may silently yield
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. Passed through to the
        base class constructor.
      name: Python `str` name prefixed to ops created by this class.
    """
    # Keep this call ahead of everything else so it records the arguments
    # exactly as the caller supplied them.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[scale]) as scope_name:
        positivity_checks = []
        if validate_args:
            positivity_checks.append(check_ops.assert_positive(scale))
        # `identity` runs only after the (optional) assertion has passed.
        with ops.control_dependencies(positivity_checks):
            self._scale = array_ops.identity(scale, name="scale")
    base_kwargs = dict(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._scale],
        name=scope_name)
    super(HalfNormal, self).__init__(**base_kwargs)
def __init__(self,
             distribution_fn,
             sample0=None,
             num_steps=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Autoregressive"):
    """Construct an `Autoregressive` distribution.

    Args:
      distribution_fn: Python `callable` that builds a distribution-like
        instance from a `Tensor` (e.g., `sample0`). It must respect the
        "autoregressive property": there exists a permutation of the event
        such that each coordinate is a diffeomorphic function of only the
        preceding coordinates.
      sample0: Initial input to `distribution_fn`. It is used here to build
        the distribution whose `dtype`, `reparameterization_type` and graph
        parents this instance adopts. When `None`, `distribution_fn` is
        called with no arguments.
      num_steps: Number of times `distribution_fn` is composed from samples.
        When `None`, defaults to the number of event elements of the
        distribution built from `sample0`.
      validate_args: Python `bool`, default `False`. Forwarded to the base
        class.
      allow_nan_stats: Python `bool`, default `True`. Forwarded to the base
        class.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "Autoregressive".

    Raises:
      ValueError: if `num_steps` is `None` and the initial distribution's
        `event_shape.num_elements()` is also `None`.
      ValueError: if `num_steps < 1`.
    """
    # Captured first: `num_steps` may be rebound below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name) as scope:
        self._distribution_fn = distribution_fn
        self._sample0 = sample0
        if sample0 is None:
            self._distribution0 = distribution_fn()
        else:
            self._distribution0 = distribution_fn(sample0)

    steps = num_steps
    if steps is None:
        # Fall back to one step per event element; this needs a fully
        # defined static event shape.
        steps = self._distribution0.event_shape.num_elements()
        if steps is None:
            raise ValueError("distribution_fn must generate a distribution "
                             "with fully known `event_shape`.")
    if steps < 1:
        raise ValueError("num_steps ({}) must be at least 1.".format(steps))
    self._num_steps = steps

    initial = self._distribution0
    super(Autoregressive, self).__init__(
        dtype=initial.dtype,
        reparameterization_type=initial.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=initial._graph_parents,  # pylint: disable=protected-access
        name=scope)
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape
        broadcastable to `[N1,..., Nm]`, `m >= 0`. Its components should be
        integer-valued; this is only checked when `validate_args` is `True`.
      logits: Floating point tensor of unnormalized log-probabilities with
        shape broadcastable to `[N1,..., Nm, K]`, `m >= 0`, and the same
        dtype as `total_count`. Pass only one of `logits` or `probs`.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]`, `m >= 0`, and the same dtype as `total_count`.
        The last dimension should sum to `1`. Pass only one of `logits` or
        `probs`.
      validate_args: Python `bool`, default `False`. When `True`,
        `total_count` is wrapped in a non-negative-integer-form check and
        validation is requested from `get_logits_and_probs`.
      allow_nan_stats: Python `bool`, default `True`. Forwarded to the base
        class.
      name: Python `str` name prefixed to ops created by this class.
    """
    # Captured first so the recorded arguments are the caller's originals.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[total_count, logits, probs]) as scope:
        count = ops.convert_to_tensor(total_count, name="total_count")
        if validate_args:
            count = distribution_util.embed_check_nonnegative_integer_form(
                count)
        self._total_count = count

        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits,
            probs=probs,
            multidimensional=True,
            validate_args=validate_args,
            name=scope)

        # Trailing axis added so the counts broadcast against the class axis.
        expanded_count = self._total_count[..., array_ops.newaxis]
        self._mean_val = expanded_count * self._probs

    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count, self._logits, self._probs],
        name=scope)
def __init__(self,
             total_count,
             concentration,
             validate_args=False,
             allow_nan_stats=True,
             name="DirichletMultinomial"):
    """Initialize a batch of DirichletMultinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with the same dtype as
        `concentration` and shape broadcastable to `[N1,..., Nm]`,
        `m >= 0`. Its components should be integer-valued; this is only
        checked when `validate_args` is `True`.
      concentration: Positive floating point tensor with the same dtype as
        `total_count` and shape broadcastable to `[N1,..., Nm, K]`,
        `m >= 0`. The last dimension indexes the `K` classes.
      validate_args: Python `bool`, default `False`. Enables the
        `total_count` check and is handed to
        `_maybe_assert_valid_concentration`.
      allow_nan_stats: Python `bool`, default `True`. Forwarded to the base
        class.
      name: Python `str` name prefixed to ops created by this class.
    """
    # Captured first so the recorded arguments are the caller's originals.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[total_count, concentration]) as scope:
        # Why broadcasting is safe here:
        # * Broadcasting prepends size-1 dimensions, while the class
        #   (distribution) axis is the last one and batch axes lead. The
        #   class axis therefore can never be created implicitly and must be
        #   given explicitly by the caller.
        # * Every computation involving counts eventually broadcasts them
        #   against `concentration`.
        count = ops.convert_to_tensor(total_count, name="total_count")
        if validate_args:
            count = distribution_util.embed_check_nonnegative_integer_form(
                count)
        self._total_count = count

        concentration_tensor = ops.convert_to_tensor(
            concentration, name="concentration")
        self._concentration = self._maybe_assert_valid_concentration(
            concentration_tensor, validate_args)
        # Sum over the last (class) axis.
        self._total_concentration = math_ops.reduce_sum(
            self._concentration, -1)

    super(DirichletMultinomial, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._total_count, self._concentration],
        name=scope)
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="NegativeBinomial"):
    """Construct NegativeBinomial distributions.

    Args:
      total_count: Floating-point `Tensor` with shape broadcastable to
        `[B1,..., Bb]`, `b >= 0`, and the same dtype as `probs` or `logits`.
        It represents the number of failures at which the Bernoulli trials
        stop; non-integer values are accepted. When `validate_args` is
        `True` it is asserted to be strictly positive.
      logits: Floating-point `Tensor` with shape broadcastable to
        `[B1, ..., Bb]`, `b >= 0`. Each entry is the logit of the success
        probability of an independent distribution. Specify only one of
        `logits` or `probs`.
      probs: Positive floating-point `Tensor` with shape broadcastable to
        `[B1, ..., Bb]`, `b >= 0`. Each entry is a success probability in
        the open interval `(0, 1)`. Specify only one of `logits` or `probs`.
      validate_args: Python `bool`, default `False`. Enables the positivity
        assertion on `total_count` and is forwarded to
        `get_logits_and_probs`.
      allow_nan_stats: Python `bool`, default `True`. Forwarded to the base
        class.
      name: Python `str` name prefixed to ops created by this class.
    """
    # Captured first so the recorded arguments are the caller's originals.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[total_count, logits, probs]) as scope:
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits,
            probs=probs,
            validate_args=validate_args,
            name=scope)
        # Build the assertion only when validation was requested.
        if validate_args:
            checks = [check_ops.assert_positive(total_count)]
        else:
            checks = []
        with ops.control_dependencies(checks):
            self._total_count = array_ops.identity(total_count)
    super(NegativeBinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count, self._probs, self._logits],
        name=scope)
def __init__(self,
             rate=None,
             log_rate=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Poisson"):
    """Initialize a batch of Poisson distributions.

    Args:
      rate: Floating point tensor, the rate parameter. `rate` must be
        positive; this is only asserted when `validate_args` is `True`.
        Must specify exactly one of `rate` and `log_rate`.
      log_rate: Floating point tensor, the log of the rate parameter.
        Must specify exactly one of `rate` and `log_rate`.
      validate_args: Python `bool`, default `False`. When `True` a
        positivity assertion is attached to `rate` (only on the `rate`
        code path). When `False` invalid inputs may silently render
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. Forwarded to the base
        class.
      name: Python `str` name prefixed to ops created by this class.

    Raises:
      ValueError: if none or both of `rate`, `log_rate` are specified.
      TypeError: if `rate` is not a float-type.
      TypeError: if `log_rate` is not a float-type.
    """
    # Captured first: `rate` / `log_rate` are rebound to tensors below.
    parameters = distribution_util.parent_frame_arguments()

    # Fail fast, before a name scope is opened for an instance that cannot
    # be built.
    if (rate is None) == (log_rate is None):
        raise ValueError("Must specify exactly one of `rate` and `log_rate`.")

    # Both possible inputs are listed so the scope sees `log_rate` when that
    # is the only tensor supplied; the unused one is simply `None`.
    with ops.name_scope(name, values=[rate, log_rate]) as name:
        if log_rate is None:
            rate = ops.convert_to_tensor(rate, name="rate")
            if not rate.dtype.is_floating:
                raise TypeError("rate.dtype ({}) is a not a float-type.".format(
                    rate.dtype.name))
            with ops.control_dependencies(
                    [check_ops.assert_positive(rate)] if validate_args else []):
                self._rate = array_ops.identity(rate, name="rate")
                self._log_rate = math_ops.log(rate, name="log_rate")
        else:
            log_rate = ops.convert_to_tensor(log_rate, name="log_rate")
            if not log_rate.dtype.is_floating:
                raise TypeError("log_rate.dtype ({}) is a not a float-type.".format(
                    log_rate.dtype.name))
            self._rate = math_ops.exp(log_rate, name="rate")
            # `log_rate` is already a tensor here; no second conversion.
            self._log_rate = log_rate
    super(Poisson, self).__init__(
        dtype=self._rate.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._rate],
        name=name)
def __init__(self,
             distribution,
             batch_shape,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
    """Construct BatchReshape distribution.

    Args:
      distribution: The base distribution instance to reshape. Its `dtype`,
        `reparameterization_type` and graph parents are reused by this
        instance.
      batch_shape: Positive `int`-like vector-shaped `Tensor` giving the new
        shape of the batch dimensions. Up to one dimension may be `-1`,
        meaning the remainder of the batch size.
      validate_args: Python `bool`, default `False`. Handed to
        `calculate_reshape` and to the base class.
      allow_nan_stats: Python `bool`, default `True`. Forwarded to the base
        class.
      name: The name to give ops created by the initializer.
        Default value: `"BatchReshape" + distribution.name`.

    Raises:
      ValueError: if `batch_shape` is not a vector.
      ValueError: if `batch_shape` has non-positive elements.
      ValueError: if `batch_shape` size is not the same as a
        `distribution.batch_shape` size.
    """
    # Captured first: `name` may be replaced by the default just below.
    parameters = distribution_util.parent_frame_arguments()
    if not name:
        name = "BatchReshape" + distribution.name
    with ops.name_scope(name, values=[batch_shape]) as scope:
        # The unexpanded batch shape may contain up to one dimension of -1.
        requested_shape = ops.convert_to_tensor(
            batch_shape, dtype=dtypes.int32, name="batch_shape")
        self._batch_shape_unexpanded = requested_shape
        validate_init_args_statically(distribution, requested_shape)

        # Resolve the requested shape against the base distribution's batch
        # shape; also yields a static shape and assertions kept for later.
        resolved_shape, static_shape, assertions = calculate_reshape(
            distribution.batch_shape_tensor(), requested_shape, validate_args)

        self._distribution = distribution
        self._batch_shape_ = resolved_shape
        self._batch_shape_static = static_shape
        self._runtime_assertions = assertions

        parents = [requested_shape] + distribution._graph_parents  # pylint: disable=protected-access
        super(BatchReshape, self).__init__(
            dtype=distribution.dtype,
            reparameterization_type=distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=parents,
            name=scope)
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="NegativeBinomial"):
    """Build NegativeBinomial distribution(s).

    Args:
      total_count: Floating-point `Tensor`, broadcastable to
        `[B1,..., Bb]` with `b >= 0`, sharing its dtype with `probs` or
        `logits`. Interpreted as the number of failures at which sampling
        stops; it need not be an integer. Asserted strictly positive only
        when `validate_args` is `True`.
      logits: Floating-point `Tensor`, broadcastable to `[B1, ..., Bb]`.
        Entries are logits of the per-distribution success probability.
        Give only one of `logits` or `probs`.
      probs: Positive floating-point `Tensor`, broadcastable to
        `[B1, ..., Bb]`. Entries are success probabilities in the open
        interval `(0, 1)`. Give only one of `logits` or `probs`.
      validate_args: Python `bool`, default `False`. Turns on the
        `total_count` positivity assertion and is passed to
        `get_logits_and_probs`.
      allow_nan_stats: Python `bool`, default `True`. Passed to the base
        class.
      name: Python `str` name prefixed to ops created by this class.
    """
    # Recorded up front, before any local state is created.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[total_count, logits, probs]) as ns:
        resolved = distribution_util.get_logits_and_probs(
            logits=logits, probs=probs, validate_args=validate_args, name=ns)
        self._logits, self._probs = resolved

        guard_ops = []
        if validate_args:
            guard_ops.append(check_ops.assert_positive(total_count))
        # `identity` ties the stored tensor to the optional assertion.
        with ops.control_dependencies(guard_ops):
            self._total_count = array_ops.identity(total_count)

    parents = [self._total_count, self._probs, self._logits]
    super(NegativeBinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=parents,
        name=ns)
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="RelaxedBernoulli"):
    """Construct RelaxedBernoulli distributions.

    The instance is assembled as a `Logistic` base distribution with
    location `logits / temperature` and scale `1 / temperature`, passed
    through a `Sigmoid` bijector.

    Args:
      temperature: A 0-D `Tensor`, the temperature of the distribution(s).
        Should be positive; asserted only when `validate_args` is `True`.
      logits: An N-D `Tensor` of log-odds of a positive event. Each entry
        parameterizes an independent distribution. Pass only one of
        `logits` or `probs`.
      probs: An N-D `Tensor` of probabilities of a positive event. Pass only
        one of `logits` or `probs`.
      validate_args: Python `bool`, default `False`. Enables the temperature
        assertion and is forwarded to `get_logits_and_probs`, the `Logistic`
        base, the `Sigmoid` bijector and the parent constructor.
      allow_nan_stats: Python `bool`, default `True`. Forwarded to the
        `Logistic` base distribution.
      name: Python `str` name prefixed to ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
    # Captured first so the recorded arguments are the caller's originals.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[logits, probs, temperature]) as scope:
        if validate_args:
            temperature_checks = [check_ops.assert_positive(temperature)]
        else:
            temperature_checks = []
        with ops.control_dependencies(temperature_checks):
            self._temperature = array_ops.identity(
                temperature, name="temperature")

        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits, probs=probs, validate_args=validate_args)

        # Arguments are evaluated in the same order as the nested call they
        # replace: the Logistic base first, then the bijector.
        base_distribution = logistic.Logistic(
            self._logits / self._temperature,
            1. / self._temperature,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=scope + "/Logistic")
        squash = Sigmoid(validate_args=validate_args)
        super(RelaxedBernoulli, self).__init__(
            distribution=base_distribution,
            bijector=squash,
            validate_args=validate_args,
            name=scope)
    # Overwrite whatever the parent constructor stored so the reported
    # parameters are this constructor's own arguments.
    self._parameters = parameters
def __init__(self,
             distribution,
             batch_shape,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
    """Wrap `distribution` so that it presents a new batch shape.

    Args:
      distribution: The base distribution instance to reshape. Typically an
        instance of `Distribution`.
      batch_shape: Positive `int`-like vector-shaped `Tensor` with the new
        batch dimensions. At most one entry may be `-1`, standing for the
        remainder of the batch size.
      validate_args: Python `bool`, default `False`. Passed to
        `calculate_reshape` and to the base class constructor.
      allow_nan_stats: Python `bool`, default `True`. Passed to the base
        class constructor.
      name: The name to give ops created by the initializer.
        Default value: `"BatchReshape" + distribution.name`.

    Raises:
      ValueError: if `batch_shape` is not a vector.
      ValueError: if `batch_shape` has non-positive elements.
      ValueError: if `batch_shape` size is not the same as a
        `distribution.batch_shape` size.
    """
    # Must precede the defaulting of `name` so the caller's value is kept.
    parameters = distribution_util.parent_frame_arguments()
    scope_seed = name if name else "BatchReshape" + distribution.name
    with ops.name_scope(scope_seed, values=[batch_shape]) as scope_name:
        # The unexpanded batch shape may contain up to one dimension of -1.
        self._batch_shape_unexpanded = ops.convert_to_tensor(
            batch_shape, dtype=dtypes.int32, name="batch_shape")
        validate_init_args_statically(
            distribution, self._batch_shape_unexpanded)

        reshape_result = calculate_reshape(
            distribution.batch_shape_tensor(),
            self._batch_shape_unexpanded,
            validate_args)
        # Unpacks to (expanded shape, static shape, runtime assertions).
        expanded, expanded_static, checks = reshape_result

        self._distribution = distribution
        self._batch_shape_ = expanded
        self._batch_shape_static = expanded_static
        self._runtime_assertions = checks

        super(BatchReshape, self).__init__(
            dtype=distribution.dtype,
            reparameterization_type=distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=(
                [self._batch_shape_unexpanded]
                + distribution._graph_parents),  # pylint: disable=protected-access
            name=scope_name)
def __init__(self,
             df,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="StudentT"):
    """Construct Student's t distributions.

    The distributions have degrees of freedom `df`, location `loc` and
    scale `scale`. The three parameters must be broadcast-compatible
    (e.g. `df + loc + scale` is a valid operation).

    Args:
      df: Floating-point `Tensor`. The degrees of freedom of the
        distribution(s). Must contain only positive values; asserted only
        when `validate_args` is `True`.
      loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
      scale: Floating-point `Tensor`. The scaling factor(s) of the
        distribution(s). Not technically the standard deviation, but closer
        in meaning to a standard deviation than to a variance.
      validate_args: Python `bool`, default `False`. Enables the positivity
        assertion on `df`.
      allow_nan_stats: Python `bool`, default `True`. Forwarded to the base
        class.
      name: Python `str` name prefixed to ops created by this class.

    Raises:
      An error from `check_ops.assert_same_float_dtype` if `df`, `loc` and
      `scale` do not share one floating-point dtype (exact exception type is
      determined by that helper).
    """
    # Captured first so the recorded arguments are the caller's originals.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[df, loc, scale]) as scope:
        if validate_args:
            df_checks = [check_ops.assert_positive(df)]
        else:
            df_checks = []
        # All three parameters are materialized under the same (optional)
        # assertion, so any use of them triggers the `df` check.
        with ops.control_dependencies(df_checks):
            self._df = array_ops.identity(df, name="df")
            self._loc = array_ops.identity(loc, name="loc")
            self._scale = array_ops.identity(scale, name="scale")
            stored = (self._df, self._loc, self._scale)
            check_ops.assert_same_float_dtype(stored)
    super(StudentT, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._df, self._loc, self._scale],
        name=scope)
def __init__(self,
             df,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="StudentT"):
    """Create Student's t distribution(s) from `df`, `loc` and `scale`.

    `df`, `loc` and `scale` must have shapes that broadcast together
    (e.g. `df + loc + scale` is a valid operation).

    Args:
      df: Floating-point `Tensor` of degrees of freedom. Only positive
        values are valid; a positivity assertion is added when
        `validate_args` is `True`.
      loc: Floating-point `Tensor` with the mean(s) of the distribution(s).
      scale: Floating-point `Tensor` with the scaling factor(s). It behaves
        more like a standard deviation than a variance, though it is not
        technically the standard deviation of this distribution.
      validate_args: Python `bool`, default `False`. When `False` no check
        on `df` is added and invalid inputs may silently give incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. Passed to the base
        class constructor.
      name: Python `str` name prefixed to ops created by this class.

    Raises:
      Whatever `check_ops.assert_same_float_dtype` raises when `df`, `loc`
      and `scale` are not all of one floating-point dtype.
    """
    # Recorded before any other work in this constructor.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[df, loc, scale]) as ns:
        preconditions = []
        if validate_args:
            preconditions.append(check_ops.assert_positive(df))
        with ops.control_dependencies(preconditions):
            # Creation order (df, loc, scale) is kept deliberately.
            self._df = array_ops.identity(df, name="df")
            self._loc = array_ops.identity(loc, name="loc")
            self._scale = array_ops.identity(scale, name="scale")
            check_ops.assert_same_float_dtype(
                (self._df, self._loc, self._scale))
    base_kwargs = dict(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._df, self._loc, self._scale],
        name=ns)
    super(StudentT, self).__init__(**base_kwargs)
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Binomial"):
    """Set up a batch of Binomial distributions.

    Args:
      total_count: Non-negative floating point tensor whose shape broadcasts
        to `[N1,..., Nm]` with `m >= 0`, and whose dtype matches `probs` or
        `logits`. Defines a batch of `N1 x ... x Nm` Binomial distributions.
        Its components should be equal to integer values.
      logits: Floating point tensor of log-odds of a positive event, with
        shape broadcastable to `[N1,..., Nm]`, `m >= 0`, and the same dtype
        as `total_count`. Each entry gives the logits of the success
        probability of an independent Binomial distribution. Only one of
        `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]`, `m >= 0`, and values in `[0, 1]`. Each entry gives
        the success probability of an independent Binomial distribution.
        Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity at a possible
        runtime cost; when `False`, invalid inputs may silently produce
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
        total_count_tensor = ops.convert_to_tensor(
            total_count, name="total_count")
        self._total_count = self._maybe_assert_valid_total_count(
            total_count_tensor, validate_args)
        # Whichever of logits/probs was supplied, both forms are stored.
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits,
            probs=probs,
            validate_args=validate_args,
            name=name)
    super(Binomial, self).__init__(
        name=name,
        dtype=self._probs.dtype,
        parameters=parameters,
        graph_parents=[self._total_count, self._logits, self._probs],
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="RelaxedBernoulli"):
    """Build RelaxedBernoulli distributions.

    The distribution is assembled as a `Sigmoid` bijector applied to a
    `Logistic` base with location `logits / temperature` and scale
    `1 / temperature`.

    Args:
      temperature: A 0-D `Tensor` giving the temperature of a set of
        RelaxedBernoulli distributions. Should be positive; this is
        asserted only when `validate_args` is `True`.
      logits: An N-D `Tensor` of log-odds of a positive event. Each entry
        parametrizes an independent RelaxedBernoulli distribution whose
        event probability is sigmoid(logits). Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor` of probabilities of a positive event. Each
        entry parameterizes an independent Bernoulli distribution. Only one
        of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity at a possible
        runtime cost; when `False`, invalid inputs may silently produce
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[logits, probs, temperature]) as name:
        temperature_checks = (
            [check_ops.assert_positive(temperature)] if validate_args else [])
        with ops.control_dependencies(temperature_checks):
            self._temperature = array_ops.identity(
                temperature, name="temperature")
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits, probs=probs, validate_args=validate_args)
        # Base distribution that the Sigmoid bijector squashes.
        base_logistic = logistic.Logistic(
            self._logits / self._temperature,
            1. / self._temperature,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name + "/Logistic")
        super(RelaxedBernoulli, self).__init__(
            distribution=base_logistic,
            bijector=Sigmoid(validate_args=validate_args),
            validate_args=validate_args,
            name=name)
    # Replace whatever the parent recorded with this constructor's arguments.
    self._parameters = parameters
def __init__(self,
             distribution,
             reinterpreted_batch_ndims=None,
             validate_args=False,
             name=None):
    """Build an `Independent` distribution around `distribution`.

    Args:
      distribution: The base distribution instance to transform. Typically
        an instance of `Distribution`.
      reinterpreted_batch_ndims: Scalar, integer number of rightmost batch
        dims to treat as event dims. When `None`, all but the first batch
        axis (batch axis 0) are transferred to event dimensions (analogous
        to `tf.layers.flatten`).
      validate_args: Python `bool`. Whether to validate input with asserts.
        If `False` and the inputs are invalid, correct behavior is not
        guaranteed.
      name: The name for ops managed by the distribution.
        Default value: `Independent + distribution.name`.

    Raises:
      ValueError: if `reinterpreted_batch_ndims` exceeds
        `distribution.batch_ndims`
    """
    # Recorded before `name` and `reinterpreted_batch_ndims` are rebound.
    parameters = distribution_util.parent_frame_arguments()
    if not name:
        name = "Independent" + distribution.name
    self._distribution = distribution
    with ops.name_scope(name) as name:
        if reinterpreted_batch_ndims is None:
            reinterpreted_batch_ndims = (
                self._get_default_reinterpreted_batch_ndims(distribution))
        reinterpreted_batch_ndims = ops.convert_to_tensor(
            reinterpreted_batch_ndims,
            dtype=dtypes.int32,
            name="reinterpreted_batch_ndims")
        static_ndims = tensor_util.constant_value(reinterpreted_batch_ndims)
        self._static_reinterpreted_batch_ndims = static_ndims
        # Keep the statically known value when there is one, else the tensor.
        self._reinterpreted_batch_ndims = (
            reinterpreted_batch_ndims if static_ndims is None
            else static_ndims)
        parents = (
            [reinterpreted_batch_ndims] +
            distribution._graph_parents)  # pylint: disable=protected-access
        super(Independent, self).__init__(
            name=name,
            dtype=self._distribution.dtype,
            parameters=parameters,
            graph_parents=parents,
            reparameterization_type=(
                self._distribution.reparameterization_type),
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats)
        self._runtime_assertions = self._make_runtime_assertions(
            distribution, reinterpreted_batch_ndims, validate_args)
def __init__(self,
             rate,
             validate_args=False,
             allow_nan_stats=True,
             name="ExponentialWithSoftplusRate"):
    """Build an Exponential whose rate is `softplus(rate)`.

    Args:
      rate: Value passed through `nn.softplus` before being handed to the
        parent constructor as its `rate`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[rate]) as name:
        softplus_rate = nn.softplus(rate, name="softplus_rate")
        super(ExponentialWithSoftplusRate, self).__init__(
            rate=softplus_rate,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (pre-softplus) arguments.
    self._parameters = parameters
def __init__(self,
             rate,
             validate_args=False,
             allow_nan_stats=True,
             name="ExponentialWithSoftplusRate"):
    """Build an Exponential whose rate is `softplus(rate)`.

    Args:
      rate: Value passed through `nn.softplus` before being handed to the
        parent constructor as its `rate`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[rate]) as name:
        softplus_rate = nn.softplus(rate, name="softplus_rate")
        super(ExponentialWithSoftplusRate, self).__init__(
            rate=softplus_rate,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (pre-softplus) arguments.
    self._parameters = parameters
def __init__(self,
             distribution,
             reinterpreted_batch_ndims=None,
             validate_args=False,
             name=None):
    """Build an `Independent` distribution around `distribution`.

    Args:
      distribution: The base distribution instance to transform. Typically
        an instance of `Distribution`.
      reinterpreted_batch_ndims: Scalar, integer number of rightmost batch
        dims to treat as event dims. When `None`, all but the first batch
        axis (batch axis 0) are transferred to event dimensions (analogous
        to `tf.layers.flatten`).
      validate_args: Python `bool`. Whether to validate input with asserts.
        If `False` and the inputs are invalid, correct behavior is not
        guaranteed.
      name: The name for ops managed by the distribution.
        Default value: `Independent + distribution.name`.

    Raises:
      ValueError: if `reinterpreted_batch_ndims` exceeds
        `distribution.batch_ndims`
    """
    # Recorded before `name` and `reinterpreted_batch_ndims` are rebound.
    parameters = distribution_util.parent_frame_arguments()
    if not name:
        name = "Independent" + distribution.name
    self._distribution = distribution
    with ops.name_scope(name) as name:
        if reinterpreted_batch_ndims is None:
            reinterpreted_batch_ndims = (
                self._get_default_reinterpreted_batch_ndims(distribution))
        reinterpreted_batch_ndims = ops.convert_to_tensor(
            reinterpreted_batch_ndims,
            dtype=dtypes.int32,
            name="reinterpreted_batch_ndims")
        static_ndims = tensor_util.constant_value(reinterpreted_batch_ndims)
        self._static_reinterpreted_batch_ndims = static_ndims
        # Keep the statically known value when there is one, else the tensor.
        self._reinterpreted_batch_ndims = (
            reinterpreted_batch_ndims if static_ndims is None
            else static_ndims)
        parents = (
            [reinterpreted_batch_ndims] +
            distribution._graph_parents)  # pylint: disable=protected-access
        super(Independent, self).__init__(
            name=name,
            dtype=self._distribution.dtype,
            parameters=parameters,
            graph_parents=parents,
            reparameterization_type=(
                self._distribution.reparameterization_type),
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats)
        self._runtime_assertions = self._make_runtime_assertions(
            distribution, reinterpreted_batch_ndims, validate_args)
def __init__(self,
             concentration,
             rate,
             validate_args=False,
             allow_nan_stats=True,
             name="InverseGamma"):
    """Build InverseGamma distributions from `concentration` and `rate`.

    `concentration` and `rate` must broadcast against each other (for
    example, `concentration + rate` has to be a valid operation).

    Args:
      concentration: Floating point tensor, the concentration params of the
        distribution(s). Must contain only positive values.
      rate: Floating point tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity (here: positivity
        of both parameters) at a possible runtime cost; when `False`,
        invalid inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `concentration` and `rate` are different dtypes.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[concentration, rate]) as name:
        positivity_checks = []
        if validate_args:
            positivity_checks = [
                check_ops.assert_positive(concentration),
                check_ops.assert_positive(rate),
            ]
        with ops.control_dependencies(positivity_checks):
            self._concentration = array_ops.identity(
                concentration, name="concentration")
            self._rate = array_ops.identity(rate, name="rate")
            check_ops.assert_same_float_dtype(
                [self._concentration, self._rate])
    super(InverseGamma, self).__init__(
        name=name,
        dtype=self._concentration.dtype,
        parameters=parameters,
        graph_parents=[self._concentration, self._rate],
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
def __init__(self,
             concentration1=None,
             concentration0=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Beta"):
    """Set up a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has the same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity at a possible
        runtime cost; when `False`, invalid inputs may silently produce
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
        alpha = ops.convert_to_tensor(concentration1, name="concentration1")
        self._concentration1 = self._maybe_assert_valid_concentration(
            alpha, validate_args)
        beta = ops.convert_to_tensor(concentration0, name="concentration0")
        self._concentration0 = self._maybe_assert_valid_concentration(
            beta, validate_args)
        check_ops.assert_same_float_dtype(
            [self._concentration1, self._concentration0])
        # The sum is kept as an attribute and listed as a graph parent.
        self._total_concentration = (
            self._concentration1 + self._concentration0)
    super(Beta, self).__init__(
        name=name,
        dtype=self._total_concentration.dtype,
        parameters=parameters,
        graph_parents=[
            self._concentration1,
            self._concentration0,
            self._total_concentration,
        ],
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
def __init__(self,
             concentration,
             rate,
             validate_args=False,
             allow_nan_stats=True,
             name="InverseGamma"):
    """Build InverseGamma distributions from `concentration` and `rate`.

    `concentration` and `rate` must broadcast against each other (for
    example, `concentration + rate` has to be a valid operation).

    Args:
      concentration: Floating point tensor, the concentration params of the
        distribution(s). Must contain only positive values.
      rate: Floating point tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity (here: positivity
        of both parameters) at a possible runtime cost; when `False`,
        invalid inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `concentration` and `rate` are different dtypes.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[concentration, rate]) as name:
        positivity_checks = []
        if validate_args:
            positivity_checks = [
                check_ops.assert_positive(concentration),
                check_ops.assert_positive(rate),
            ]
        with ops.control_dependencies(positivity_checks):
            self._concentration = array_ops.identity(
                concentration, name="concentration")
            self._rate = array_ops.identity(rate, name="rate")
            check_ops.assert_same_float_dtype(
                [self._concentration, self._rate])
    super(InverseGamma, self).__init__(
        name=name,
        dtype=self._concentration.dtype,
        parameters=parameters,
        graph_parents=[self._concentration, self._rate],
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="NormalWithSoftplusScale"):
    """Build a Normal whose scale is `softplus(scale)`.

    Args:
      loc: Forwarded unchanged to the parent constructor as `loc`.
      scale: Value passed through `nn.softplus` before being handed to the
        parent constructor as its `scale`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[scale]) as name:
        softplus_scale = nn.softplus(scale, name="softplus_scale")
        super(NormalWithSoftplusScale, self).__init__(
            loc=loc,
            scale=softplus_scale,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (pre-softplus) arguments.
    self._parameters = parameters
def __init__(self,
             df,
             validate_args=False,
             allow_nan_stats=True,
             name="Chi2WithAbsDf"):
    """Build a Chi2 whose degrees of freedom are `floor(abs(df))`.

    Args:
      df: Value whose absolute value is floored and then handed to the
        parent constructor as its `df`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[df]) as name:
        abs_df = math_ops.abs(df, name="abs_df")
        floor_abs_df = math_ops.floor(abs_df, name="floor_abs_df")
        super(Chi2WithAbsDf, self).__init__(
            df=floor_abs_df,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (untransformed) arguments.
    self._parameters = parameters
def __init__(self,
             df,
             validate_args=False,
             allow_nan_stats=True,
             name="Chi2WithAbsDf"):
    """Build a Chi2 whose degrees of freedom are `floor(abs(df))`.

    Args:
      df: Value whose absolute value is floored and then handed to the
        parent constructor as its `df`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[df]) as name:
        abs_df = math_ops.abs(df, name="abs_df")
        floor_abs_df = math_ops.floor(abs_df, name="floor_abs_df")
        super(Chi2WithAbsDf, self).__init__(
            df=floor_abs_df,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (untransformed) arguments.
    self._parameters = parameters
def __init__(self,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Geometric"):
    """Build Geometric distributions.

    Args:
      logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where
        `b >= 0` is the number of batch dimensions. Each entry gives the
        logits of the success probability of an independent Geometric
        distribution and must be in the range `(-inf, inf]`. Only one of
        `logits` or `probs` should be specified.
      probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]`
        where `b >= 0` is the number of batch dimensions. Each entry gives
        the success probability of an independent Geometric distribution
        and must be in the range `(0, 1]`. Only one of `logits` or `probs`
        should be specified.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity (including
        positivity of the resolved `probs`) at a possible runtime cost;
        when `False`, invalid inputs may silently produce incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[logits, probs]) as name:
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits, probs, validate_args=validate_args, name=name)
        # Extra check on top of get_logits_and_probs: probs must be > 0.
        probs_checks = (
            [check_ops.assert_positive(self._probs)] if validate_args else [])
        with ops.control_dependencies(probs_checks):
            self._probs = array_ops.identity(self._probs, name="probs")
    super(Geometric, self).__init__(
        name=name,
        dtype=self._probs.dtype,
        parameters=parameters,
        graph_parents=[self._probs, self._logits],
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
def __init__(self,
             concentration,
             rate,
             validate_args=False,
             allow_nan_stats=True,
             name="InverseGammaWithSoftplusConcentrationRate"):
    """Build an InverseGamma with softplus-transformed parameters.

    Args:
      concentration: Value passed through `nn.softplus` before being handed
        to the parent constructor as its `concentration`.
      rate: Value passed through `nn.softplus` before being handed to the
        parent constructor as its `rate`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[concentration, rate]) as name:
        softplus_concentration = nn.softplus(
            concentration, name="softplus_concentration")
        softplus_rate = nn.softplus(rate, name="softplus_rate")
        super(InverseGammaWithSoftplusConcentrationRate, self).__init__(
            concentration=softplus_concentration,
            rate=softplus_rate,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (pre-softplus) arguments.
    self._parameters = parameters
def __init__(self,
             concentration,
             rate,
             validate_args=False,
             allow_nan_stats=True,
             name="GammaWithSoftplusConcentrationRate"):
    """Build a Gamma with softplus-transformed parameters.

    Args:
      concentration: Value passed through `nn.softplus` before being handed
        to the parent constructor as its `concentration`.
      rate: Value passed through `nn.softplus` before being handed to the
        parent constructor as its `rate`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[concentration, rate]) as name:
        softplus_concentration = nn.softplus(
            concentration, name="softplus_concentration")
        softplus_rate = nn.softplus(rate, name="softplus_rate")
        super(GammaWithSoftplusConcentrationRate, self).__init__(
            concentration=softplus_concentration,
            rate=softplus_rate,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (pre-softplus) arguments.
    self._parameters = parameters
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="NormalWithSoftplusScale"):
    """Build a Normal whose scale is `softplus(scale)`.

    Args:
      loc: Forwarded unchanged to the parent constructor as `loc`.
      scale: Value passed through `nn.softplus` before being handed to the
        parent constructor as its `scale`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[scale]) as name:
        softplus_scale = nn.softplus(scale, name="softplus_scale")
        super(NormalWithSoftplusScale, self).__init__(
            loc=loc,
            scale=softplus_scale,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (pre-softplus) arguments.
    self._parameters = parameters
def __init__(self,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Geometric"):
    """Build Geometric distributions.

    Args:
      logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where
        `b >= 0` is the number of batch dimensions. Each entry gives the
        logits of the success probability of an independent Geometric
        distribution and must be in the range `(-inf, inf]`. Only one of
        `logits` or `probs` should be specified.
      probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]`
        where `b >= 0` is the number of batch dimensions. Each entry gives
        the success probability of an independent Geometric distribution
        and must be in the range `(0, 1]`. Only one of `logits` or `probs`
        should be specified.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity (including
        positivity of the resolved `probs`) at a possible runtime cost;
        when `False`, invalid inputs may silently produce incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[logits, probs]) as name:
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits, probs, validate_args=validate_args, name=name)
        # Extra check on top of get_logits_and_probs: probs must be > 0.
        probs_checks = (
            [check_ops.assert_positive(self._probs)] if validate_args else [])
        with ops.control_dependencies(probs_checks):
            self._probs = array_ops.identity(self._probs, name="probs")
    super(Geometric, self).__init__(
        name=name,
        dtype=self._probs.dtype,
        parameters=parameters,
        graph_parents=[self._probs, self._logits],
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
def __init__(self,
             logits=None,
             probs=None,
             dtype=dtypes.int32,
             validate_args=False,
             allow_nan_stats=True,
             name="Bernoulli"):
    """Build Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` of log-odds of a `1` event. Each entry
        parametrizes an independent Bernoulli distribution whose event
        probability is sigmoid(logits). Only one of `logits` or `probs`
        should be passed in.
      probs: An N-D `Tensor` of probabilities of a `1` event. Each entry
        parameterizes an independent Bernoulli distribution. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity at a possible
        runtime cost; when `False`, invalid inputs may silently produce
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither is
        passed.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name) as name:
        # Whichever of logits/probs was supplied, both forms are stored.
        resolved = distribution_util.get_logits_and_probs(
            logits=logits,
            probs=probs,
            validate_args=validate_args,
            name=name)
        self._logits, self._probs = resolved
    super(Bernoulli, self).__init__(
        name=name,
        dtype=dtype,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
def __init__(self,
             df,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="StudentTWithAbsDfSoftplusScale"):
    """Build a StudentT with `floor(abs(df))` and `softplus(scale)`.

    Args:
      df: Value whose absolute value is floored and then handed to the
        parent constructor as its `df`.
      loc: Forwarded unchanged to the parent constructor as `loc`.
      scale: Value passed through `nn.softplus` before being handed to the
        parent constructor as its `scale`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[df, scale]) as name:
        floor_abs_df = math_ops.floor(math_ops.abs(df))
        softplus_scale = nn.softplus(scale, name="softplus_scale")
        super(StudentTWithAbsDfSoftplusScale, self).__init__(
            df=floor_abs_df,
            loc=loc,
            scale=softplus_scale,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (untransformed) arguments.
    self._parameters = parameters
def __init__(self,
             df,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="StudentTWithAbsDfSoftplusScale"):
    """Build a StudentT with `floor(abs(df))` and `softplus(scale)`.

    Args:
      df: Value whose absolute value is floored and then handed to the
        parent constructor as its `df`.
      loc: Forwarded unchanged to the parent constructor as `loc`.
      scale: Value passed through `nn.softplus` before being handed to the
        parent constructor as its `scale`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[df, scale]) as name:
        floor_abs_df = math_ops.floor(math_ops.abs(df))
        softplus_scale = nn.softplus(scale, name="softplus_scale")
        super(StudentTWithAbsDfSoftplusScale, self).__init__(
            df=floor_abs_df,
            loc=loc,
            scale=softplus_scale,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (untransformed) arguments.
    self._parameters = parameters
def __init__(self,
             concentration1,
             concentration0,
             validate_args=False,
             allow_nan_stats=True,
             name="BetaWithSoftplusConcentration"):
    """Build a Beta with softplus-transformed concentrations.

    Args:
      concentration1: Value passed through `nn.softplus` before being handed
        to the parent constructor as its `concentration1`.
      concentration0: Value passed through `nn.softplus` before being handed
        to the parent constructor as its `concentration0`.
      validate_args: Python `bool`, default `False`; forwarded to the parent
        constructor.
      allow_nan_stats: Python `bool`, default `True`; forwarded to the
        parent constructor.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
        softplus_concentration1 = nn.softplus(
            concentration1, name="softplus_concentration1")
        softplus_concentration0 = nn.softplus(
            concentration0, name="softplus_concentration0")
        super(BetaWithSoftplusConcentration, self).__init__(
            concentration1=softplus_concentration1,
            concentration0=softplus_concentration0,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own (pre-softplus) arguments.
    self._parameters = parameters
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="Laplace"):
    """Build a Laplace distribution from `loc` and `scale`.

    `loc` and `scale` must broadcast against each other (for example,
    `loc / scale` has to be a valid operation).

    Args:
      loc: Floating point tensor which characterizes the location (center)
        of the distribution.
      scale: Positive floating point tensor which characterizes the spread
        of the distribution. Positivity is asserted only when
        `validate_args` is `True`.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity at a possible
        runtime cost; when `False`, invalid inputs may silently produce
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc` and `scale` are of different dtype.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[loc, scale]) as name:
        scale_checks = (
            [check_ops.assert_positive(scale)] if validate_args else [])
        with ops.control_dependencies(scale_checks):
            self._loc = array_ops.identity(loc, name="loc")
            self._scale = array_ops.identity(scale, name="scale")
            check_ops.assert_same_float_dtype([self._loc, self._scale])
        super(Laplace, self).__init__(
            name=name,
            dtype=self._loc.dtype,
            parameters=parameters,
            graph_parents=[self._loc, self._scale],
            reparameterization_type=distribution.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats)
def __init__(self,
             low=0.,
             high=1.,
             validate_args=False,
             allow_nan_stats=True,
             name="Uniform"):
    """Set up a batch of Uniform distributions.

    Args:
      low: Floating point tensor, lower boundary of the output interval.
        Must have `low < high`.
      high: Floating point tensor, upper boundary of the output interval.
        Must have `low < high`.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity (here:
        `low < high`) at a possible runtime cost; when `False`, invalid
        inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      InvalidArgumentError: if `low >= high`; the `assert_less` check is
        only attached when `validate_args` is `True`.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[low, high]) as name:
        bounds_checks = []
        if validate_args:
            bounds_checks.append(
                check_ops.assert_less(
                    low,
                    high,
                    message="uniform not defined when low >= high."))
        with ops.control_dependencies(bounds_checks):
            self._low = array_ops.identity(low, name="low")
            self._high = array_ops.identity(high, name="high")
            check_ops.assert_same_float_dtype([self._low, self._high])
    super(Uniform, self).__init__(
        name=name,
        dtype=self._low.dtype,
        parameters=parameters,
        graph_parents=[self._low, self._high],
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="Laplace"):
    """Build a Laplace distribution from `loc` and `scale`.

    `loc` and `scale` must broadcast against each other (for example,
    `loc / scale` has to be a valid operation).

    Args:
      loc: Floating point tensor which characterizes the location (center)
        of the distribution.
      scale: Positive floating point tensor which characterizes the spread
        of the distribution. Positivity is asserted only when
        `validate_args` is `True`.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity at a possible
        runtime cost; when `False`, invalid inputs may silently produce
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc` and `scale` are of different dtype.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[loc, scale]) as name:
        scale_checks = (
            [check_ops.assert_positive(scale)] if validate_args else [])
        with ops.control_dependencies(scale_checks):
            self._loc = array_ops.identity(loc, name="loc")
            self._scale = array_ops.identity(scale, name="scale")
            check_ops.assert_same_float_dtype([self._loc, self._scale])
        super(Laplace, self).__init__(
            name=name,
            dtype=self._loc.dtype,
            parameters=parameters,
            graph_parents=[self._loc, self._scale],
            reparameterization_type=distribution.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats)
def __init__(self,
             concentration,
             validate_args=False,
             allow_nan_stats=True,
             name="Dirichlet"):
    """Set up a batch of Dirichlet distributions.

    Args:
      concentration: Positive floating-point `Tensor` indicating mean number
        of class occurrences; aka "alpha". Implies `self.dtype`,
        `self.batch_shape` and `self.event_shape`: if
        `concentration.shape = [N1, N2, ..., Nm, k]` then
        `batch_shape = [N1, N2, ..., Nm]` and `event_shape = [k]`.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity at a possible
        runtime cost; when `False`, invalid inputs may silently produce
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[concentration]) as name:
        concentration_tensor = ops.convert_to_tensor(
            concentration, name="concentration")
        self._concentration = self._maybe_assert_valid_concentration(
            concentration_tensor, validate_args)
        # Sum over the last axis; kept as an attribute and graph parent.
        self._total_concentration = math_ops.reduce_sum(
            self._concentration, -1)
    super(Dirichlet, self).__init__(
        name=name,
        dtype=self._concentration.dtype,
        parameters=parameters,
        graph_parents=[self._concentration, self._total_concentration],
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
def __init__(self,
             df,
             validate_args=False,
             allow_nan_stats=True,
             name="Chi2"):
    """Build Chi2 distributions with parameter `df`.

    The distribution is expressed through the parent constructor with
    `concentration = 0.5 * df` and a constant `rate` of 0.5.

    Args:
      df: Floating point tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values; this is
        asserted only when `validate_args` is `True`.
      validate_args: Python `bool`, default `False`. When `True`,
        distribution parameters are checked for validity at a possible
        runtime cost; when `False`, invalid inputs may silently produce
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`; when
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Recorded before `name` is rebound by the name scope below.
    parameters = distribution_util.parent_frame_arguments()
    # All Chi2 statistics are defined for valid parameters, which is not
    # true of the parent Gamma class in general. `allow_nan_stats` is
    # nevertheless forwarded to the parent exactly as given.
    with ops.name_scope(name, values=[df]) as name:
        df_checks = [check_ops.assert_positive(df)] if validate_args else []
        with ops.control_dependencies(df_checks):
            self._df = array_ops.identity(df, name="df")
        half_df = 0.5 * self._df
        half = constant_op.constant(0.5, dtype=self._df.dtype)
        super(Chi2, self).__init__(
            concentration=half_df,
            rate=half,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Report this constructor's own arguments, not the parent's.
    self._parameters = parameters
def __init__(self,
             rate,
             validate_args=False,
             allow_nan_stats=True,
             name="Exponential"):
  """Construct Exponential distribution with parameter `rate`.

  The distribution is set up through the parent constructor with a scalar
  concentration of one and the given `rate`.

  Args:
    rate: Floating point tensor, equivalent to `1 / mean`. Must contain only
      positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Must run first: it snapshots the constructor arguments as passed in.
  parameters = distribution_util.parent_frame_arguments()
  # Even though all statistics of the Exponential are defined for valid
  # inputs, this is not true in the parent class "Gamma". `allow_nan_stats`
  # is forwarded unchanged, so the parent may add asserts that are
  # unnecessary for this distribution.
  with ops.name_scope(name, values=[rate]) as name:
    self._rate = ops.convert_to_tensor(rate, name="rate")
  super(Exponential, self).__init__(
      concentration=array_ops.ones([], dtype=self._rate.dtype),
      rate=self._rate,
      allow_nan_stats=allow_nan_stats,
      validate_args=validate_args,
      name=name)
  # While the Gamma distribution is not reparameterizable, the exponential
  # distribution is. Store the enum value (not a bare `True`) so that
  # comparisons against `distribution.FULLY_REPARAMETERIZED` succeed, as they
  # do for the other distributions in this file.
  self._reparameterization_type = distribution.FULLY_REPARAMETERIZED
  # The parent recorded its own view of the arguments; replace it with the
  # arguments this constructor was called with.
  self._parameters = parameters
  self._graph_parents += [self._rate]
def __init__(self,
             rate,
             validate_args=False,
             allow_nan_stats=True,
             name="Exponential"):
  """Construct Exponential distribution with parameter `rate`.

  The parent constructor is invoked with a scalar concentration of one and
  the given `rate`.

  Args:
    rate: Floating point tensor, equivalent to `1 / mean`. Must contain only
      positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Capture the caller-supplied arguments before anything is rebound.
  parameters = distribution_util.parent_frame_arguments()
  # Even though all statistics of the Exponential are defined for valid
  # inputs, this is not true in the parent class "Gamma". `allow_nan_stats`
  # is passed through as given, so the parent may still add asserts that are
  # unnecessary here.
  with ops.name_scope(name, values=[rate]) as name:
    self._rate = ops.convert_to_tensor(rate, name="rate")
  super(Exponential, self).__init__(
      concentration=array_ops.ones([], dtype=self._rate.dtype),
      rate=self._rate,
      allow_nan_stats=allow_nan_stats,
      validate_args=validate_args,
      name=name)
  # While the Gamma distribution is not reparameterizable, the exponential
  # distribution is. Use the reparameterization enum rather than the boolean
  # `True`, which never compares equal to
  # `distribution.FULLY_REPARAMETERIZED`.
  self._reparameterization_type = distribution.FULLY_REPARAMETERIZED
  # Overwrite the parent's bookkeeping with this constructor's own arguments.
  self._parameters = parameters
  self._graph_parents += [self._rate]
def __init__(self,
             df,
             validate_args=False,
             allow_nan_stats=True,
             name="Chi2"):
  """Create Chi2 distribution(s) from the degrees of freedom `df`.

  Args:
    df: Floating point tensor, the degrees of freedom of the
      distribution(s). Only positive values are valid.
    validate_args: Python `bool`, default `False`. With `True`, parameters
      are checked for validity even though this may degrade runtime
      performance. With `False`, invalid inputs may silently yield incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. With `True`, statistics
      (e.g., mean, mode, variance) use "`NaN`" to signal an undefined result.
      With `False`, an exception is raised if one or more of the statistic's
      batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Capture the caller-supplied arguments before anything is rebound.
  parameters = distribution_util.parent_frame_arguments()
  # Even though every chi2 statistic is defined for valid parameters, the
  # parent class "gamma" cannot assume so. `allow_nan_stats` is still passed
  # through as given, which can lead to unnecessary asserts in the parent.
  with ops.name_scope(name, values=[df]) as name:
    positivity_checks = []
    if validate_args:
      positivity_checks.append(check_ops.assert_positive(df))
    with ops.control_dependencies(positivity_checks):
      self._df = array_ops.identity(df, name="df")

    # Parent parameterization: concentration = df / 2, rate = 1 / 2.
    # The concentration op is created before the rate constant.
    concentration = 0.5 * self._df
    rate = constant_op.constant(0.5, dtype=self._df.dtype)
    super(Chi2, self).__init__(
        concentration=concentration,
        rate=rate,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
  self._parameters = parameters
def __init__(self,
             distribution_fn,
             sample0=None,
             num_steps=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Autoregressive"):
  """Build an `Autoregressive` distribution.

  Args:
    distribution_fn: Python `callable` that constructs a
      `tf.distributions.Distribution`-like instance from a `Tensor` (e.g.,
      `sample0`). It must respect the "autoregressive property", i.e., there
      exists a permutation of the event such that each coordinate is a
      diffeomorphic function of only preceding coordinates.
    sample0: Initial input to `distribution_fn`. It is used in `__init__` to
      build the distribution that in turn specifies this distribution's
      properties, e.g., `event_shape`, `batch_shape`, `dtype`. If
      unspecified, `distribution_fn` is called without arguments and so
      should be default constructable.
    num_steps: Number of times `distribution_fn` is composed from samples,
      e.g., `num_steps=2` implies
      `distribution_fn(distribution_fn(sample0).sample(n)).sample()`. When
      `None`, the number of elements of the initial distribution's
      `event_shape` is used.
    validate_args: Python `bool`. Whether to validate input with asserts. If
      `validate_args` is `False` and the inputs are invalid, correct behavior
      is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `True`, statistics
      (e.g., mean, mode, variance) report "`NaN`" where the result is
      undefined. If `False`, an exception is raised when one or more of the
      statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "Autoregressive".

  Raises:
    ValueError: if `num_steps` and
      `distribution_fn(sample0).event_shape.num_elements()` are both `None`.
    ValueError: if `num_steps < 1`.
  """
  # Capture the caller-supplied arguments before anything is rebound.
  parameters = distribution_util.parent_frame_arguments()
  with ops.name_scope(name) as name:
    self._distribution_fn = distribution_fn
    self._sample0 = sample0
    if sample0 is None:
      self._distribution0 = distribution_fn()
    else:
      self._distribution0 = distribution_fn(sample0)
    base = self._distribution0

    if num_steps is None:
      # Fall back to one step per event element; this needs a static shape.
      num_steps = base.event_shape.num_elements()
      if num_steps is None:
        raise ValueError("distribution_fn must generate a distribution "
                         "with fully known `event_shape`.")
    if num_steps < 1:
      raise ValueError(
          "num_steps ({}) must be at least 1.".format(num_steps))
    self._num_steps = num_steps

    # dtype, reparameterization type and graph parents are all taken from
    # the initial distribution.
    super(Autoregressive, self).__init__(
        dtype=base.dtype,
        reparameterization_type=base.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=base._graph_parents,  # pylint: disable=protected-access
        name=name)
def __init__(self,
             mixture_distribution,
             components_distribution,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
  """Construct a `MixtureSameFamily` distribution.

  Shape and dtype compatibility between the two distributions is checked
  statically where the shapes are known. Where they are not, and
  `validate_args` is `True`, equivalent runtime assertions are collected in
  `self._runtime_assertions`.

  Args:
    mixture_distribution: `tf.distributions.Categorical`-like instance.
      Manages the probability of selecting components. The number of
      categories must match the rightmost batch dimension of the
      `components_distribution`. Must have either scalar `batch_shape` or
      `batch_shape` matching `components_distribution.batch_shape[:-1]`.
    components_distribution: `tf.distributions.Distribution`-like instance.
      Right-most batch dimension indexes components.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: `if not mixture_distribution.dtype.is_integer`.
    ValueError: if mixture_distribution does not have scalar `event_shape`.
    ValueError: if `mixture_distribution.batch_shape` and
      `components_distribution.batch_shape[:-1]` are both fully defined and
      the former is neither scalar nor equal to the latter.
    ValueError: if `mixture_distribution` categories does not equal
      `components_distribution` rightmost batch shape.
  """
  # Must run first: it snapshots the constructor arguments as passed in.
  parameters = distribution_util.parent_frame_arguments()
  with ops.name_scope(name) as name:
    self._mixture_distribution = mixture_distribution
    self._components_distribution = components_distribution
    self._runtime_assertions = []

    # Event rank of the components: the static length of the event-shape
    # vector when known, otherwise a dynamic scalar tensor.
    s = components_distribution.event_shape_tensor()
    self._event_ndims = (s.shape[0].value
                         if s.shape.with_rank_at_least(1)[0].value is not None
                         else array_ops.shape(s)[0])

    if not mixture_distribution.dtype.is_integer:
      raise ValueError(
          "`mixture_distribution.dtype` ({}) is not over integers".format(
              mixture_distribution.dtype.name))

    # The mixture must be over scalar events. Fail now if the static rank is
    # known and non-zero; otherwise defer to a runtime check when validating.
    if (mixture_distribution.event_shape.ndims is not None
        and mixture_distribution.event_shape.ndims != 0):
      raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          control_flow_ops.assert_has_rank(
              mixture_distribution.event_shape_tensor(), 0,
              message="`mixture_distribution` must have scalar `event_dim`s"),
      ]

    # Batch-shape compatibility: the mixture batch shape must be scalar or
    # equal to the components' batch shape without its last dimension.
    mdbs = mixture_distribution.batch_shape
    cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1]
    if mdbs.is_fully_defined() and cdbs.is_fully_defined():
      if mdbs.ndims != 0 and mdbs != cdbs:
        raise ValueError(
            "`mixture_distribution.batch_shape` (`{}`) is not "
            "compatible with `components_distribution.batch_shape` "
            "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
    elif validate_args:
      # Static shapes are not fully known: rebind `mdbs`/`cdbs` to the
      # dynamic shape tensors and compare at runtime. A scalar mixture batch
      # is substituted by `cdbs` so that it trivially passes.
      mdbs = mixture_distribution.batch_shape_tensor()
      cdbs = components_distribution.batch_shape_tensor()[:-1]
      self._runtime_assertions += [
          control_flow_ops.assert_equal(
              distribution_util.pick_vector(
                  mixture_distribution.is_scalar_batch(), cdbs, mdbs),
              cdbs,
              message=(
                  "`mixture_distribution.batch_shape` is not "
                  "compatible with `components_distribution.batch_shape`"))]

    # Number of components: last dimension of the mixture logits (`km`) must
    # equal the last batch dimension of the components (`kc`).
    km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value
    kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value
    if km is not None and kc is not None and km != kc:
      raise ValueError("`mixture_distribution components` ({}) does not "
                       "equal `components_distribution.batch_shape[-1]` "
                       "({})".format(km, kc))
    elif validate_args:
      # Note: this branch is taken whenever `validate_args` is set and no
      # static mismatch was found, so `km` is rebound to a dynamic tensor
      # even if it was known statically.
      km = array_ops.shape(mixture_distribution.logits)[-1]
      kc = components_distribution.batch_shape_tensor()[-1]
      self._runtime_assertions += [
          control_flow_ops.assert_equal(
              km, kc,
              message=("`mixture_distribution components` does not equal "
                       "`components_distribution.batch_shape[-1:]`")),
      ]
    elif km is None:
      # Not validating, but the static size is unknown: use the dynamic one.
      km = array_ops.shape(mixture_distribution.logits)[-1]

    # Either a Python int (static) or a scalar tensor (dynamic).
    self._num_components = km

    # dtype comes from the components; graph parents are the concatenation
    # (a new list) of both distributions' parents.
    super(MixtureSameFamily, self).__init__(
        dtype=self._components_distribution.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=(
            self._mixture_distribution._graph_parents  # pylint: disable=protected-access
            + self._components_distribution._graph_parents),  # pylint: disable=protected-access
        name=name)
def __init__(self,
             distribution,
             bijector=None,
             batch_shape=None,
             event_shape=None,
             validate_args=False,
             name=None):
  """Construct a Transformed Distribution.

  `dtype`, `reparameterization_type` and `allow_nan_stats` are taken from
  `distribution`; there is no separate `allow_nan_stats` argument.

  Args:
    distribution: The base distribution instance to transform. Typically an
      instance of `Distribution`.
    bijector: The object responsible for calculating the transformation.
      Typically an instance of `Bijector`. `None` means `Identity()`.
    batch_shape: `integer` vector `Tensor` which overrides `distribution`
      `batch_shape`; valid only if `distribution.is_scalar_batch()`.
    event_shape: `integer` vector `Tensor` which overrides `distribution`
      `event_shape`; valid only if `distribution.is_scalar_event()`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class. Default:
      `bijector.name + distribution.name`.
  """
  # Must run first: it snapshots the constructor arguments as passed in,
  # before `name` and `bijector` are rebound below.
  parameters = distribution_util.parent_frame_arguments()
  # Default name: bijector name (empty if no bijector) + distribution name.
  name = name or (("" if bijector is None else bijector.name) +
                  distribution.name)
  with ops.name_scope(name, values=[event_shape, batch_shape]) as name:
    # For convenience we define some handy constants.
    self._zero = constant_op.constant(0, dtype=dtypes.int32, name="zero")
    self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty")

    if bijector is None:
      bijector = identity_bijector.Identity(validate_args=validate_args)

    # We will keep track of a static and dynamic version of
    # self._is_{batch,event}_override. This way we can do more prior to graph
    # execution, including possibly raising Python exceptions.

    self._override_batch_shape = self._maybe_validate_shape_override(
        batch_shape, distribution.is_scalar_batch(), validate_args,
        "batch_shape")
    # Dynamic flag: the override shape is non-empty (its ndims != 0).
    # NOTE(review): `_logical_not`, `_logical_equal` and `_ndims_from_shape`
    # are module helpers not shown here; presumably they work on both static
    # values and tensors -- confirm.
    self._is_batch_override = _logical_not(_logical_equal(
        _ndims_from_shape(self._override_batch_shape), self._zero))
    # Static flag: True unless the override is statically known to be empty.
    self._is_maybe_batch_override = bool(
        tensor_util.constant_value(self._override_batch_shape) is None or
        tensor_util.constant_value(self._override_batch_shape).size != 0)

    # Same bookkeeping for the event-shape override.
    self._override_event_shape = self._maybe_validate_shape_override(
        event_shape, distribution.is_scalar_event(), validate_args,
        "event_shape")
    self._is_event_override = _logical_not(_logical_equal(
        _ndims_from_shape(self._override_event_shape), self._zero))
    self._is_maybe_event_override = bool(
        tensor_util.constant_value(self._override_event_shape) is None or
        tensor_util.constant_value(self._override_event_shape).size != 0)

    # To convert a scalar distribution into a multivariate distribution we
    # will draw dims from the sample dims, which are otherwise iid. This is
    # easy to do except in the case that the base distribution has batch dims
    # and we're overriding event shape. When that case happens the event dims
    # will incorrectly be to the left of the batch dims. In this case we'll
    # cyclically permute left the new dims.
    self._needs_rotation = _logical_and(
        self._is_event_override,
        _logical_not(self._is_batch_override),
        _logical_not(distribution.is_scalar_batch()))
    override_event_ndims = _ndims_from_shape(self._override_event_shape)
    # Number of dims to rotate: the override event rank if rotation is
    # needed, otherwise 0.
    self._rotate_ndims = _pick_scalar_condition(
        self._needs_rotation, override_event_ndims, 0)
    # We'll be reducing the head dims (if at all), i.e., this will be []
    # if we don't need to reduce.
    self._reduce_event_indices = math_ops.range(
        self._rotate_ndims - override_event_ndims, self._rotate_ndims)

  self._distribution = distribution
  self._bijector = bijector
  super(TransformedDistribution, self).__init__(
      dtype=self._distribution.dtype,
      reparameterization_type=self._distribution.reparameterization_type,
      validate_args=validate_args,
      allow_nan_stats=self._distribution.allow_nan_stats,
      parameters=parameters,
      # We let TransformedDistribution access _graph_parents since this class
      # is more like a baseclass than derived.
      graph_parents=(distribution._graph_parents +  # pylint: disable=protected-access
                     bijector.graph_parents),
      name=name)
def __init__(self,
             cat,
             components,
             validate_args=False,
             allow_nan_stats=True,
             use_static_graph=False,
             name="Mixture"):
  """Initialize a Mixture distribution.

  A `Mixture` is defined by a `Categorical` (`cat`, representing the
  mixture probabilities) and a list of `Distribution` objects
  all having matching dtype, batch shape, event shape, and continuity
  properties (the components).

  The `num_classes` of `cat` must be possible to infer at graph construction
  time and match `len(components)`.

  Args:
    cat: A `Categorical` distribution instance, representing the
      probabilities of `distributions`.
    components: A list or tuple of `Distribution` instances.
      Each instance must have the same type, be defined on the same domain,
      and have matching `event_shape` and `batch_shape`.
    validate_args: Python `bool`, default `False`. If `True`, raise a runtime
      error if batch or event ranks are inconsistent between cat and any of
      the distributions. This is only checked if the ranks cannot be
      determined statically at graph construction time.
    allow_nan_stats: Boolean, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    use_static_graph: Calls to `sample` will not rely on dynamic tensor
      indexing, allowing for some static graph compilation optimizations,
      but at the expense of sampling all underlying distributions in the
      mixture. (Possibly useful when running on TPUs).
      Default value: `False` (i.e., use dynamic indexing).
    name: A name for this distribution (optional).

  Raises:
    TypeError: If cat is not a `Categorical`, or `components` is not
      a list or tuple, or the elements of `components` are not
      instances of `Distribution`, or do not have matching `dtype`.
    ValueError: If `components` is an empty list or tuple, or its
      elements do not have a statically known event rank.
      If `cat.num_classes` cannot be inferred at graph creation time,
      or the constant value of `cat.num_classes` is not equal to
      `len(components)`, or all `components` and `cat` do not have
      matching static batch shapes, or all components do not
      have matching static event shapes.
  """
  # Must run first: it snapshots the constructor arguments as passed in.
  parameters = distribution_util.parent_frame_arguments()
  # The offending value is wrapped in a 1-tuple for `%`: a bare tuple operand
  # would be unpacked as the format arguments and produce the wrong message.
  if not isinstance(cat, categorical.Categorical):
    raise TypeError("cat must be a Categorical distribution, but saw: %s" %
                    (cat,))
  # The emptiness check deliberately precedes the type check, so falsy
  # inputs such as `None` raise ValueError.
  if not components:
    raise ValueError("components must be a non-empty list or tuple")
  if not isinstance(components, (list, tuple)):
    raise TypeError("components must be a list or tuple, but saw: %s" %
                    (components,))
  if not all(isinstance(c, distribution.Distribution) for c in components):
    raise TypeError(
        "all entries in components must be Distribution instances"
        " but saw: %s" % (components,))

  dtype = components[0].dtype
  if not all(d.dtype == dtype for d in components):
    raise TypeError("All components must have the same dtype, but saw "
                    "dtypes: %s" % [(d.name, d.dtype) for d in components])

  # Merge static shapes across all components. `merge_with` is relied upon
  # to raise on statically incompatible shapes (see Raises above).
  static_event_shape = components[0].event_shape
  static_batch_shape = cat.batch_shape
  for d in components:
    static_event_shape = static_event_shape.merge_with(d.event_shape)
    static_batch_shape = static_batch_shape.merge_with(d.batch_shape)
  if static_event_shape.ndims is None:
    raise ValueError(
        "Expected to know rank(event_shape) from components, but "
        "none of the components provide a static number of ndims")

  # Ensure that all batch and event ndims are consistent.
  with ops.name_scope(name, values=[cat.logits]) as name:
    num_components = cat.event_size
    static_num_components = tensor_util.constant_value(num_components)
    if static_num_components is None:
      raise ValueError(
          "Could not infer number of classes from cat and unable "
          "to compare this value to the number of components passed in.")
    # Possibly convert from numpy 0-D array.
    static_num_components = int(static_num_components)
    if static_num_components != len(components):
      raise ValueError("cat.num_classes != len(components): %d vs. %d" %
                       (static_num_components, len(components)))

    cat_batch_shape = cat.batch_shape_tensor()
    cat_batch_rank = array_ops.size(cat_batch_shape)
    if validate_args:
      batch_shapes = [d.batch_shape_tensor() for d in components]
      batch_ranks = [array_ops.size(bs) for bs in batch_shapes]
      check_message = ("components[%d] batch shape must match cat "
                       "batch shape")
      # Rank assertions for every component first, then shape assertions.
      self._assertions = [
          check_ops.assert_equal(
              cat_batch_rank, rank, message=check_message % di)
          for di, rank in enumerate(batch_ranks)
      ]
      self._assertions += [
          check_ops.assert_equal(
              cat_batch_shape, shape, message=check_message % di)
          for di, shape in enumerate(batch_shapes)
      ]
    else:
      self._assertions = []

    self._cat = cat
    self._components = list(components)
    self._num_components = static_num_components
    self._static_event_shape = static_event_shape
    self._static_batch_shape = static_batch_shape
    # `static_num_components` is always a Python int at this point, so no
    # extra "known statically" check is needed for `use_static_graph`.
    self._use_static_graph = use_static_graph

  # We let the Mixture distribution access _graph_parents since its arguably
  # more like a baseclass.
  # Copy the list first: extending `cat`'s own `_graph_parents` in place
  # would add the components' parents to `cat` itself.
  graph_parents = list(self._cat._graph_parents)  # pylint: disable=protected-access
  for c in self._components:
    graph_parents += c._graph_parents  # pylint: disable=protected-access

  super(Mixture, self).__init__(
      dtype=dtype,
      reparameterization_type=distribution.NOT_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=graph_parents,
      name=name)
def __init__(self,
             loc=None,
             covariance_matrix=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalFullCovariance"):
  """Build a Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape of the `loc` and
  `covariance_matrix` arguments.

  The `event_shape` is the last dimension of the matrix implied by
  `covariance_matrix`. The last dimension of `loc` (if provided) must
  broadcast with it.

  A non-batch `covariance_matrix` is a `k x k` symmetric positive definite
  matrix, i.e. it is (real) symmetric with all eigenvalues strictly
  positive.

  Any additional leading dimensions index batches.

  Args:
    loc: Floating-point `Tensor`. `None` means `loc` is implicitly `0`. When
      given, may have shape `[B1, ..., Bb, k]` where `b >= 0` and `k` is the
      event size.
    covariance_matrix: Floating-point, symmetric positive definite `Tensor`
      with the same `dtype` as `loc`. Its strict upper triangle is ignored,
      so a non-symmetric `covariance_matrix` raises no error (unless
      `validate_args is True`). It has shape `[B1, ..., Bb, k, k]` where
      `b >= 0` and `k` is the event size.
    validate_args: Python `bool`, default `False`. If `True`, distribution
      parameters are checked for validity, possibly at the cost of runtime
      performance. If `False`, invalid inputs may silently produce incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. If `True`, statistics
      (e.g., mean, mode, variance) report "`NaN`" where the result is
      undefined. If `False`, an exception is raised when one or more of the
      statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if neither `loc` nor `covariance_matrix` are specified.
  """
  # Capture the caller-supplied arguments before `covariance_matrix` is
  # rebound below.
  parameters = distribution_util.parent_frame_arguments()

  # Turn the covariance_matrix into a scale_tril and defer to MVNTriL.
  with ops.name_scope(name) as name:
    with ops.name_scope("init", values=[loc, covariance_matrix]):
      scale_tril = None
      if covariance_matrix is not None:
        covariance_matrix = ops.convert_to_tensor(
            covariance_matrix, name="covariance_matrix")
        if validate_args:
          # cholesky() ignores the upper triangular part, so symmetry has to
          # be asserted separately.
          symmetry_check = check_ops.assert_near(
              covariance_matrix,
              array_ops.matrix_transpose(covariance_matrix),
              message="Matrix was not symmetric")
          covariance_matrix = control_flow_ops.with_dependencies(
              [symmetry_check], covariance_matrix)
        # Non-singularity needs no check here:
        # LinearOperatorLowerTriangular has an assert_non_singular method
        # that is called by the Bijector.
        scale_tril = linalg_ops.cholesky(covariance_matrix)
    super(MultivariateNormalFullCovariance, self).__init__(
        loc=loc,
        scale_tril=scale_tril,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
  self._parameters = parameters
def __init__(
    self,
    temperature,
    logits=None,
    probs=None,
    dtype=None,
    validate_args=False,
    allow_nan_stats=True,
    name="ExpRelaxedOneHotCategorical"):
  """Set up ExpRelaxedOneHotCategorical from class log-probabilities.

  Args:
    temperature: A 0-D `Tensor` holding the temperature of a set of
      ExpRelaxedCategorical distributions. It should be positive.
    logits: An N-D `Tensor`, `N >= 1`, holding the log probabilities of a
      set of ExpRelaxedCategorical distributions. The first `N - 1`
      dimensions index into a batch of independent distributions and the
      last dimension is a vector of logits for each class. Pass only one of
      `logits` or `probs`.
    probs: An N-D `Tensor`, `N >= 1`, holding the probabilities of a set of
      ExpRelaxedCategorical distributions. The first `N - 1` dimensions
      index into a batch of independent distributions and the last dimension
      is a vector of probabilities for each class. Pass only one of `logits`
      or `probs`.
    dtype: The type of the event samples (default: inferred from
      logits/probs).
    validate_args: Python `bool`, default `False`. If `True`, distribution
      parameters are checked for validity, possibly at the cost of runtime
      performance. If `False`, invalid inputs may silently produce incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. If `True`, statistics
      (e.g., mean, mode, variance) report "`NaN`" where the result is
      undefined. If `False`, an exception is raised when one or more of the
      statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Capture the caller-supplied arguments before `dtype`/`temperature` are
  # rebound below.
  parameters = distribution_util.parent_frame_arguments()
  with ops.name_scope(name, values=[logits, probs, temperature]) as name:
    self._logits, self._probs = distribution_util.get_logits_and_probs(
        name=name,
        logits=logits,
        probs=probs,
        validate_args=validate_args,
        multidimensional=True)

    if dtype is None:
      dtype = self._logits.dtype
      # The temperature is cast to the inferred dtype only when arguments
      # are not being validated.
      if not validate_args:
        temperature = math_ops.cast(temperature, dtype)

    if validate_args:
      positivity_checks = [check_ops.assert_positive(temperature)]
    else:
      positivity_checks = []
    with ops.control_dependencies(positivity_checks):
      self._temperature = array_ops.identity(temperature, name="temperature")
      self._temperature_2d = array_ops.reshape(
          temperature, [-1, 1], name="temperature_2d")

    # Batch rank is the logits rank minus the trailing class dimension; use
    # the static rank when it is known, else compute it in the graph.
    static_logits_shape = self._logits.get_shape().with_rank_at_least(1)
    static_rank = static_logits_shape.ndims
    if static_rank is None:
      with ops.name_scope(name="batch_rank"):
        self._batch_rank = array_ops.rank(self._logits) - 1
    else:
      self._batch_rank = ops.convert_to_tensor(
          static_rank - 1, dtype=dtypes.int32, name="batch_rank")

    # Number of classes: size of the last logits dimension.
    with ops.name_scope(name="event_size"):
      self._event_size = array_ops.shape(self._logits)[-1]

  parents = [self._logits, self._probs, self._temperature]
  super(ExpRelaxedOneHotCategorical, self).__init__(
      dtype=dtype,
      reparameterization_type=distribution.FULLY_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=parents,
      name=name)
def __init__(self,
             loc=None,
             scale_tril=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalTriL"):
  """Build a Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape of the `loc` and `scale`
  arguments.

  The `event_shape` is the last dimension of the matrix implied by `scale`.
  The last dimension of `loc` (if provided) must broadcast with it.

  Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix
  is:

  ```none
  scale = scale_tril
  ```

  where `scale_tril` is a lower-triangular `k x k` matrix with non-zero
  diagonal, i.e., `tf.diag_part(scale_tril) != 0`.

  Any additional leading dimensions index batches.

  Args:
    loc: Floating-point `Tensor`. `None` means `loc` is implicitly `0`. When
      given, may have shape `[B1, ..., Bb, k]` where `b >= 0` and `k` is the
      event size.
    scale_tril: Floating-point, lower-triangular `Tensor` with non-zero
      diagonal elements. It has shape `[B1, ..., Bb, k, k]` where `b >= 0`
      and `k` is the event size. When `None`, an identity scale sized from
      the last dimension of `loc` is used.
    validate_args: Python `bool`, default `False`. If `True`, distribution
      parameters are checked for validity, possibly at the cost of runtime
      performance. If `False`, invalid inputs may silently produce incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. If `True`, statistics
      (e.g., mean, mode, variance) report "`NaN`" where the result is
      undefined. If `False`, an exception is raised when one or more of the
      statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if neither `loc` nor `scale_tril` are specified.
  """
  # Capture the caller-supplied arguments before `loc`/`scale_tril` are
  # rebound below.
  parameters = distribution_util.parent_frame_arguments()
  if loc is None and scale_tril is None:
    raise ValueError("Must specify one or both of `loc`, `scale_tril`.")
  with ops.name_scope(name) as name:
    with ops.name_scope("init", values=[loc, scale_tril]):
      # `None` inputs are left as `None`; everything else becomes a tensor.
      if loc is not None:
        loc = ops.convert_to_tensor(loc, name="loc")
      if scale_tril is None:
        # `loc` is guaranteed non-None here by the guard above.
        scale = linalg.LinearOperatorIdentity(
            num_rows=distribution_util.dimension_size(loc, -1),
            dtype=loc.dtype,
            is_self_adjoint=True,
            is_positive_definite=True,
            assert_proper_shapes=validate_args)
      else:
        scale_tril = ops.convert_to_tensor(scale_tril, name="scale_tril")
        # Non-singularity of scale_tril needs no check here:
        # LinearOperatorLowerTriangular has an assert_non_singular method
        # that is called by the Bijector.
        scale = linalg.LinearOperatorLowerTriangular(
            scale_tril,
            is_non_singular=True,
            is_self_adjoint=False,
            is_positive_definite=False)
  super(MultivariateNormalTriL, self).__init__(
      loc=loc,
      scale=scale,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      name=name)
  self._parameters = parameters
def __init__(self,
             distribution,
             low=None,
             high=None,
             validate_args=False,
             name="QuantizedDistribution"):
  """Construct a Quantized Distribution representing `Y = ceiling(X)`.

  Some properties are inherited from the distribution defining `X`. Example:
  `allow_nan_stats` is determined for this `QuantizedDistribution` by reading
  the `distribution`.

  Args:
    distribution: The base distribution class to transform. Typically an
      instance of `Distribution`.
    low: `Tensor` with same `dtype` as this distribution and shape
      able to be added to samples. Should be a whole number. Default `None`.
      If provided, base distribution's `prob` should be defined at
      `low`.
    high: `Tensor` with same `dtype` as this distribution and shape
      able to be added to samples. Should be a whole number. Default `None`.
      If provided, base distribution's `prob` should be defined at
      `high - 1`.
      `high` must be strictly greater than `low`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    Whatever `check_ops.assert_same_float_dtype` raises when `distribution`,
    `low` and `high` do not share a single floating dtype.
    NOTE(review): earlier documentation listed `TypeError` (for a `dist_cls`
    argument that this constructor does not take) and `NotImplementedError`
    (base distribution without `cdf`); neither is raised directly in this
    constructor -- confirm where those checks live.
  """
  # Must run first: it snapshots the constructor arguments as passed in,
  # before `low`/`high` are rebound below.
  parameters = distribution_util.parent_frame_arguments()
  values = list(distribution.parameters.values()) + [low, high]
  with ops.name_scope(name, values=values) as name:
    self._dist = distribution

    if low is not None:
      low = ops.convert_to_tensor(low, name="low")
    if high is not None:
      high = ops.convert_to_tensor(high, name="high")
    check_ops.assert_same_float_dtype(
        tensors=[self.distribution, low, high])

    # We let QuantizedDistribution access _graph_parents since this class is
    # more like a baseclass.
    # Copy the list: the `+=` below must not append `low`/`high` to the base
    # distribution's own `_graph_parents`.
    graph_parents = list(self._dist._graph_parents)  # pylint: disable=protected-access

    checks = []
    if validate_args and low is not None and high is not None:
      message = "low must be strictly less than high."
      checks.append(check_ops.assert_less(low, high, message=message))
    self._validate_args = validate_args  # self._check_integer uses this.

    with ops.control_dependencies(checks if validate_args else []):
      if low is not None:
        self._low = self._check_integer(low)
        graph_parents += [self._low]
      else:
        self._low = None
      if high is not None:
        self._high = self._check_integer(high)
        graph_parents += [self._high]
      else:
        self._high = None

  super(QuantizedDistribution, self).__init__(
      dtype=self._dist.dtype,
      reparameterization_type=distributions.NOT_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=self._dist.allow_nan_stats,
      parameters=parameters,
      graph_parents=graph_parents,
      name=name)
def __init__(self,
             distribution,
             low=None,
             high=None,
             validate_args=False,
             name="QuantizedDistribution"):
    """Construct a Quantized Distribution representing `Y = ceiling(X)`.

    Some properties are inherited from the distribution defining `X`. Example:
    `allow_nan_stats` is determined for this `QuantizedDistribution` by reading
    the `distribution`.

    Args:
      distribution: The base distribution class to transform. Typically an
        instance of `Distribution`.
      low: `Tensor` with same `dtype` as this distribution and shape
        able to be added to samples. Should be a whole number. Default `None`.
        If provided, base distribution's `prob` should be defined at `low`.
      high: `Tensor` with same `dtype` as this distribution and shape
        able to be added to samples. Should be a whole number. Default `None`.
        If provided, base distribution's `prob` should be defined at
        `high - 1`. `high` must be strictly greater than `low`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: If `distribution` is not a subclass of `Distribution` or
        continuous.
      NotImplementedError: If the base distribution does not implement `cdf`.
    """
    # NOTE(review): the exceptions listed under "Raises" are not raised
    # explicitly in this constructor body -- verify where they come from.
    parameters = distribution_util.parent_frame_arguments()
    values = list(distribution.parameters.values()) + [low, high]
    with ops.name_scope(name, values=values) as name:
        self._dist = distribution

        if low is not None:
            low = ops.convert_to_tensor(low, name="low")
        if high is not None:
            high = ops.convert_to_tensor(high, name="high")
        check_ops.assert_same_float_dtype(
            tensors=[self.distribution, low, high])

        # We let QuantizedDistribution access _graph_parents since this class
        # is more like a baseclass.  The list is copied so that adding the
        # `low`/`high` tensors below does not mutate the wrapped
        # distribution's own `_graph_parents` list.
        graph_parents = list(self._dist._graph_parents)  # pylint: disable=protected-access

        # Only build the ordering assertion when validation was requested and
        # both bounds exist; otherwise no control dependencies are added.
        checks = []
        if validate_args and low is not None and high is not None:
            message = "low must be strictly less than high."
            checks.append(check_ops.assert_less(low, high, message=message))

        self._validate_args = validate_args  # self._check_integer uses this.
        with ops.control_dependencies(checks):
            if low is not None:
                self._low = self._check_integer(low)
                graph_parents.append(self._low)
            else:
                self._low = None
            if high is not None:
                self._high = self._check_integer(high)
                graph_parents.append(self._high)
            else:
                self._high = None

    super(QuantizedDistribution, self).__init__(
        dtype=self._dist.dtype,
        reparameterization_type=distributions.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=self._dist.allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_parents,
        name=name)