def testProbabilityValidateArgs(self):
    """Exercise range validation of `probs` in `get_logits_and_probs`.

    In-range probabilities are evaluated with validation on. Each
    out-of-range input is expected to raise an op error when
    `validate_args=True`, and is then evaluated again with
    `validate_args=False` outside of any expected-error context.
    """
    in_range_probs = [0.01, 0.2, 0.5, 0.7, .99]
    # First component is below 0.
    negative_probs = [-1, 0.2, 0.5, 0.3, .2]
    # First component is above 1.
    oversized_probs = [2, 0.2, 0.5, 0.3, .2]

    invalid_cases = (
        (negative_probs, "Condition x >= 0"),
        (oversized_probs, "probs has components greater than 1"),
    )

    with self.test_session():
        _, prob = distribution_util.get_logits_and_probs(
            probs=in_range_probs, validate_args=True)
        prob.eval()

        for bad_probs, expected_message in invalid_cases:
            with self.assertRaisesOpError(expected_message):
                _, prob = distribution_util.get_logits_and_probs(
                    probs=bad_probs, validate_args=True)
                prob.eval()

            # Same input with validation disabled.
            _, prob = distribution_util.get_logits_and_probs(
                probs=bad_probs, validate_args=False)
            prob.eval()
def testImproperArguments(self):
    """Expect `ValueError` when neither or both of logits/probs are given."""
    improper_kwargs = (
        dict(logits=None, probs=None),
        dict(logits=[0.1], probs=[0.1]),
    )
    with self.test_session():
        for kwargs in improper_kwargs:
            with self.assertRaises(ValueError):
                distribution_util.get_logits_and_probs(**kwargs)
def testLogitsMultidimShape(self):
    """Expect errors for float16 logits with `2**11 + 1` classes.

    Covers a logits tensor built with a known shape (expected to raise
    `ValueError` from the call itself) and a shape-less placeholder
    (expected to raise an op error when evaluated).
    """
    # NOTE(review): presumably one more class than float16 can index
    # exactly -- confirm against the check in `get_logits_and_probs`.
    num_classes = int(2**11 + 1)

    with self.test_session():
        with self.assertRaises(ValueError):
            static_logits = array_ops.ones([num_classes], dtype=np.float16)
            distribution_util.get_logits_and_probs(
                logits=static_logits,
                multidimensional=True,
                validate_args=True)

        with self.assertRaisesOpError(
                "Number of classes exceeds `dtype` precision"):
            fed_logits = array_ops.placeholder(dtype=dtypes.float16)
            logit, _ = distribution_util.get_logits_and_probs(
                logits=fed_logits,
                multidimensional=True,
                validate_args=True)
            logit.eval(feed_dict={fed_logits: np.ones([num_classes])})
def __init__(
        self,
        logits=None,
        probs=None,
        dtype=dtypes.int32,
        validate_args=False,
        allow_nan_stats=True,
        name="OneHotCategorical"):
    """Initialize OneHotCategorical distributions from class logits or probs.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last
        dimension represents a vector of logits for each class. Only one of
        `logits` or `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities of a
        set of Categorical distributions. The first `N - 1` dimensions index
        into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must run before any other local is bound so that only the
    # constructor arguments are captured.
    parameters = dict(locals())
    with ops.name_scope(name, values=[logits, probs]) as name:
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            name=name,
            logits=logits,
            probs=probs,
            validate_args=validate_args,
            multidimensional=True)

        static_rank = self._logits.get_shape().with_rank_at_least(1).ndims
        if static_rank is None:
            # Rank is not known statically; compute it with graph ops.
            with ops.name_scope(name="batch_rank"):
                self._batch_rank = array_ops.rank(self._logits) - 1
        else:
            self._batch_rank = ops.convert_to_tensor(
                static_rank - 1, dtype=dtypes.int32, name="batch_rank")

        with ops.name_scope(name="event_size"):
            # The size of the last axis of the logits.
            self._event_size = array_ops.shape(self._logits)[-1]

    super(OneHotCategorical, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=name)
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape
        broadcastable to `[N1,..., Nm]` with `m >= 0`. Defines this as a
        batch of `N1 x ... x Nm` different Multinomial distributions. Its
        components should be equal to integer values.
      logits: Floating point tensor representing unnormalized
        log-probabilities of a positive event with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0`, and the same dtype as `total_count`.
        Defines this as a batch of `N1 x ... x Nm` different `K` class
        Multinomial distributions. Only one of `logits` or `probs` should
        be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0` and same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. `probs`'s components in the last portion of its
        shape should sum to `1`. Only one of `logits` or `probs` should be
        passed in.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Copy the argument mapping. A bare `locals()` can hand back the
    # frame's own namespace dict, which a later refresh (another
    # `locals()` call, a tracer) would extend with `parameters` itself.
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]):
        self._total_count = ops.convert_to_tensor(
            total_count, name="total_count")
        if validate_args:
            self._total_count = (
                distribution_util.embed_check_nonnegative_integer_form(
                    self._total_count))
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits,
            probs=probs,
            multidimensional=True,
            validate_args=validate_args,
            name=name)
        # A trailing axis is added to `total_count` so it broadcasts
        # against the last axis of `probs`.
        self._mean_val = (
            self._total_count[..., array_ops.newaxis] * self._probs)
    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count, self._logits, self._probs],
        name=name)
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             concentrations=None,
             validate_args=False,
             allow_nan_stats=True,
             name="BetaBinomial"):
    """Initialize a batch of BetaBinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape
        broadcastable to `[N1,..., Nm]` with `m >= 0` and the same dtype as
        `probs` or `logits`. Defines this as a batch of `N1 x ... x Nm`
        different Binomial distributions. Its components should be equal to
        integer values.
      logits: Floating point tensor representing the log-odds of a positive
        event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and the
        same dtype as `total_count`. Each entry represents logits for the
        probability of success for independent Binomial distributions. Only
        one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`, `probs in [0, 1]`. Each entry represents
        the probability of success for independent Binomial distributions.
        Only one of `logits` or `probs` should be passed in.
      concentrations: Positive floating point tensor with shape
        broadcastable to `[N1,..., Nm]` `m >= 0`.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must run before any other local is bound so that only the
    # constructor arguments are captured.
    parameters = dict(locals())
    scope_values = [total_count, logits, probs, concentrations]
    with ops.name_scope(name, values=scope_values) as name:
        count_tensor = ops.convert_to_tensor(total_count, name="total_count")
        self._total_count = self._maybe_assert_valid_total_count(
            count_tensor, validate_args)
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits,
            probs=probs,
            validate_args=validate_args,
            name=name)
        # NOTE(review): `concentrations` is stored exactly as passed (no
        # tensor conversion) and is listed in `graph_parents` below, even
        # when it is the default `None` -- confirm the base class accepts
        # that.
        self._concentrations = concentrations
    graph_inputs = [
        self._total_count,
        self._logits,
        self._probs,
        self._concentrations,
    ]
    super(distributions_BetaBinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_inputs,
        name=name)
def testProbability(self):
    """Compare outputs for `probs` input against `_logit(p)` and `p`."""
    expected_probs = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)

    with self.test_session():
        logits_t, probs_t = distribution_util.get_logits_and_probs(
            probs=expected_probs, validate_args=True)

        expected_logits = _logit(expected_probs)
        self.assertAllClose(expected_logits, logits_t.eval())
        self.assertAllClose(expected_probs, probs_t.eval())
def testProbabilityMultidimensional(self):
    """With `multidimensional=True`, expect logits equal to `log(probs)`."""
    expected_probs = np.array(
        [[0.3, 0.4, 0.3],
         [0.1, 0.5, 0.4]],
        dtype=np.float32)

    with self.test_session():
        logits_t, probs_t = distribution_util.get_logits_and_probs(
            probs=expected_probs, multidimensional=True, validate_args=True)

        expected_logits = np.log(expected_probs)
        self.assertAllClose(expected_logits, logits_t.eval())
        self.assertAllClose(expected_probs, probs_t.eval())
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="NegativeBinomial"):
    """Construct NegativeBinomial distributions.

    Args:
      total_count: Non-negative floating-point `Tensor` with shape
        broadcastable to `[B1,..., Bb]` with `b >= 0` and the same dtype as
        `probs` or `logits`. Defines this as a batch of `B1 x ... x Bb`
        different Negative Binomial distributions. In practice, this
        represents the number of negative Bernoulli trials to stop at (the
        `total_count` of failures), but this is still a valid distribution
        when `total_count` is a non-integer. When `validate_args` is
        `True`, it is asserted to be strictly positive.
      logits: Floating-point `Tensor` with shape broadcastable to
        `[B1, ..., Bb]` where `b >= 0` indicates the number of batch
        dimensions. Each entry represents logits for the probability of
        success for independent Negative Binomial distributions and must be
        in the open interval `(-inf, inf)`. Only one of `logits` or `probs`
        should be specified.
      probs: Positive floating-point `Tensor` with shape broadcastable to
        `[B1, ..., Bb]` where `b >= 0` indicates the number of batch
        dimensions. Each entry represents the probability of success for
        independent Negative Binomial distributions and must be in the open
        interval `(0, 1)`. Only one of `logits` or `probs` should be
        specified.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Copy the argument mapping. A bare `locals()` can hand back the
    # frame's own namespace dict, which a later refresh (another
    # `locals()` call, a tracer) would extend with `parameters` itself.
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]):
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits, probs, validate_args=validate_args, name=name)
        # The positivity assertion is attached only when validation is
        # requested; otherwise the identity op has no extra dependencies.
        with ops.control_dependencies(
                [check_ops.assert_positive(total_count)]
                if validate_args else []):
            self._total_count = array_ops.identity(total_count)
    super(NegativeBinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count, self._probs, self._logits],
        name=name)
def testLogits(self):
    """Compare outputs for `logits` input against `p` and `_logit(p)`."""
    expected_probs = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
    expected_logits = _logit(expected_probs)

    with self.test_session():
        logits_t, probs_t = distribution_util.get_logits_and_probs(
            logits=expected_logits, validate_args=True)

        # Relative tolerance only; absolute tolerance is disabled.
        tolerances = dict(rtol=1e-5, atol=0.)
        self.assertAllClose(expected_probs, probs_t.eval(), **tolerances)
        self.assertAllClose(expected_logits, logits_t.eval(), **tolerances)
def testLogitsMultidimensional(self):
    """With `multidimensional=True`, expect `log(p)` logits to yield `p`."""
    expected_probs = np.array([0.2, 0.3, 0.5], dtype=np.float32)
    expected_logits = np.log(expected_probs)

    with self.test_session():
        logits_t, probs_t = distribution_util.get_logits_and_probs(
            logits=expected_logits, multidimensional=True, validate_args=True)

        self.assertAllClose(probs_t.eval(), expected_probs)
        self.assertAllClose(logits_t.eval(), expected_logits)
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="NegativeBinomial"):
    """Construct NegativeBinomial distributions.

    Args:
      total_count: Non-negative floating-point `Tensor` with shape
        broadcastable to `[B1,..., Bb]` with `b >= 0` and the same dtype as
        `probs` or `logits`. Defines this as a batch of `B1 x ... x Bb`
        different Negative Binomial distributions. In practice, this
        represents the number of negative Bernoulli trials to stop at (the
        `total_count` of failures), but this is still a valid distribution
        when `total_count` is a non-integer. When `validate_args` is
        `True`, it is asserted to be strictly positive.
      logits: Floating-point `Tensor` with shape broadcastable to
        `[B1, ..., Bb]` where `b >= 0` indicates the number of batch
        dimensions. Each entry represents logits for the probability of
        success for independent Negative Binomial distributions and must be
        in the open interval `(-inf, inf)`. Only one of `logits` or `probs`
        should be specified.
      probs: Positive floating-point `Tensor` with shape broadcastable to
        `[B1, ..., Bb]` where `b >= 0` indicates the number of batch
        dimensions. Each entry represents the probability of success for
        independent Negative Binomial distributions and must be in the open
        interval `(0, 1)`. Only one of `logits` or `probs` should be
        specified.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must run before any other local is bound so that only the
    # constructor arguments are captured.
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits, probs, validate_args=validate_args, name=name)

        # The positivity assertion is attached only when validation is
        # requested.
        count_assertions = []
        if validate_args:
            count_assertions.append(check_ops.assert_positive(total_count))
        with ops.control_dependencies(count_assertions):
            self._total_count = array_ops.identity(total_count)

    graph_inputs = [self._total_count, self._probs, self._logits]
    super(NegativeBinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_inputs,
        name=name)
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="RelaxedBernoulli"):
    """Construct RelaxedBernoulli distributions.

    Args:
      temperature: An 0-D `Tensor`, representing the temperature of a set
        of RelaxedBernoulli distributions. The temperature should be
        positive.
      logits: An N-D `Tensor` representing the log-odds of a positive
        event. Each entry in the `Tensor` parametrizes an independent
        RelaxedBernoulli distribution where the probability of an event is
        sigmoid(logits). Only one of `logits` or `probs` should be passed
        in.
      probs: An N-D `Tensor` representing the probability of a positive
        event. Each entry in the `Tensor` parameterizes an independent
        Bernoulli distribution. Only one of `logits` or `probs` should be
        passed in.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
    # Must run before any other local is bound so that only the
    # constructor arguments are captured.
    parameters = dict(locals())
    with ops.name_scope(name, values=[logits, probs, temperature]) as name:
        # The positivity assertion is attached only when validation is
        # requested.
        temperature_assertions = []
        if validate_args:
            temperature_assertions.append(
                check_ops.assert_positive(temperature))
        with ops.control_dependencies(temperature_assertions):
            self._temperature = array_ops.identity(
                temperature, name="temperature")

        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits, probs=probs, validate_args=validate_args)

        # Base distribution: a Logistic with both arguments scaled by the
        # inverse temperature; the Sigmoid bijector is applied on top.
        base_logistic = logistic.Logistic(
            self._logits / self._temperature,
            1. / self._temperature,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name + "/Logistic")
        super(RelaxedBernoulli, self).__init__(
            distribution=base_logistic,
            bijector=Sigmoid(validate_args=validate_args),
            validate_args=validate_args,
            name=name)
    # Assigned after the base constructor has run, so this is the value
    # that remains on the instance.
    self._parameters = parameters
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Binomial"):
    """Initialize a batch of Binomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape
        broadcastable to `[N1,..., Nm]` with `m >= 0` and the same dtype as
        `probs` or `logits`. Defines this as a batch of `N1 x ... x Nm`
        different Binomial distributions. Its components should be equal to
        integer values.
      logits: Floating point tensor representing the log-odds of a positive
        event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and the
        same dtype as `total_count`. Each entry represents logits for the
        probability of success for independent Binomial distributions. Only
        one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`, `probs in [0, 1]`. Each entry represents
        the probability of success for independent Binomial distributions.
        Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Copy the argument mapping. A bare `locals()` can hand back the
    # frame's own namespace dict, which a later refresh (another
    # `locals()` call, a tracer) would extend with `parameters` itself.
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]):
        self._total_count = self._maybe_assert_valid_total_count(
            ops.convert_to_tensor(total_count, name="total_count"),
            validate_args)
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits,
            probs=probs,
            validate_args=validate_args,
            name=name)
    super(Binomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count, self._logits, self._probs],
        name=name)
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="RelaxedBernoulli"):
    """Construct RelaxedBernoulli distributions.

    Args:
      temperature: An 0-D `Tensor`, representing the temperature of a set
        of RelaxedBernoulli distributions. The temperature should be
        positive.
      logits: An N-D `Tensor` representing the log-odds of a positive
        event. Each entry in the `Tensor` parametrizes an independent
        RelaxedBernoulli distribution where the probability of an event is
        sigmoid(logits). Only one of `logits` or `probs` should be passed
        in.
      probs: An N-D `Tensor` representing the probability of a positive
        event. Each entry in the `Tensor` parameterizes an independent
        Bernoulli distribution. Only one of `logits` or `probs` should be
        passed in.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
    # Must run before any other local is bound so that only the
    # constructor arguments are captured.
    parameters = dict(locals())
    with tf.name_scope(name, values=[logits, probs, temperature]) as name:
        # The positivity assertion is attached only when validation is
        # requested.
        temperature_assertions = []
        if validate_args:
            temperature_assertions.append(tf.assert_positive(temperature))
        with tf.control_dependencies(temperature_assertions):
            self._temperature = tf.identity(temperature, name="temperature")

        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits, probs=probs, validate_args=validate_args)

        # Base distribution: a Logistic with both arguments scaled by the
        # inverse temperature; the Sigmoid bijector is applied on top.
        base_logistic = Logistic(
            self._logits / self._temperature,
            1. / self._temperature,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name + "/Logistic")
        super(RelaxedBernoulli, self).__init__(
            distribution=base_logistic,
            bijector=Sigmoid(validate_args=validate_args),
            validate_args=validate_args,
            name=name)
    # Assigned after the base constructor has run, so this is the value
    # that remains on the instance.
    self._parameters = parameters
def __init__(self,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Geometric"):
    """Construct Geometric distributions.

    Args:
      logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where
        `b >= 0` indicates the number of batch dimensions. Each entry
        represents logits for the probability of success for independent
        Geometric distributions and must be in the range `(-inf, inf]`.
        Only one of `logits` or `probs` should be specified.
      probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]`
        where `b >= 0` indicates the number of batch dimensions. Each entry
        represents the probability of success for independent Geometric
        distributions and must be in the range `(0, 1]`. Only one of
        `logits` or `probs` should be specified.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Copy the argument mapping. A bare `locals()` can hand back the
    # frame's own namespace dict, which a later refresh (another
    # `locals()` call, a tracer) would extend with `parameters` itself.
    parameters = dict(locals())
    with ops.name_scope(name, values=[logits, probs]):
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits, probs, validate_args=validate_args, name=name)

        # When validating, the stored probs carry an extra assertion that
        # they are strictly positive.
        with ops.control_dependencies(
                [check_ops.assert_positive(self._probs)]
                if validate_args else []):
            self._probs = array_ops.identity(self._probs, name="probs")

    super(Geometric, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._probs, self._logits],
        name=name)
def __init__(self,
             logits=None,
             probs=None,
             dtype=dtypes.int32,
             validate_args=False,
             allow_nan_stats=True,
             name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` representing the log-odds of a `1` event.
        Each entry in the `Tensor` parametrizes an independent Bernoulli
        distribution where the probability of an event is sigmoid(logits).
        Only one of `logits` or `probs` should be passed in.
      probs: An N-D `Tensor` representing the probability of a `1` event.
        Each entry in the `Tensor` parameterizes an independent Bernoulli
        distribution. Only one of `logits` or `probs` should be passed in.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If p and logits are passed, or if neither are passed.
    """
    # Kept as the very first statement, as in the original ordering.
    # NOTE(review): presumably captures this constructor's arguments --
    # confirm against `distribution_util.parent_frame_arguments`.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name) as name:
        logits_and_probs = distribution_util.get_logits_and_probs(
            logits=logits,
            probs=probs,
            validate_args=validate_args,
            name=name)
        self._logits, self._probs = logits_and_probs
    graph_inputs = [self._logits, self._probs]
    super(Bernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_inputs,
        name=name)
def __init__(self,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Geometric"):
    """Construct Geometric distributions.

    Args:
      logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where
        `b >= 0` indicates the number of batch dimensions. Each entry
        represents logits for the probability of success for independent
        Geometric distributions and must be in the range `(-inf, inf]`.
        Only one of `logits` or `probs` should be specified.
      probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]`
        where `b >= 0` indicates the number of batch dimensions. Each entry
        represents the probability of success for independent Geometric
        distributions and must be in the range `(0, 1]`. Only one of
        `logits` or `probs` should be specified.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must run before any other local is bound so that only the
    # constructor arguments are captured.
    parameters = dict(locals())
    with ops.name_scope(name, values=[logits, probs]) as name:
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits, probs, validate_args=validate_args, name=name)

        # When validating, the stored probs carry an extra assertion that
        # they are strictly positive.
        probs_assertions = []
        if validate_args:
            probs_assertions.append(check_ops.assert_positive(self._probs))
        with ops.control_dependencies(probs_assertions):
            self._probs = array_ops.identity(self._probs, name="probs")

    graph_inputs = [self._probs, self._logits]
    super(Geometric, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_inputs,
        name=name)
def __init__(self,
             logits=None,
             probs=None,
             dtype=dtypes.int32,
             validate_args=False,
             allow_nan_stats=True,
             name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` representing the log-odds of a `1` event.
        Each entry in the `Tensor` parametrizes an independent Bernoulli
        distribution where the probability of an event is sigmoid(logits).
        Only one of `logits` or `probs` should be passed in.
      probs: An N-D `Tensor` representing the probability of a `1` event.
        Each entry in the `Tensor` parameterizes an independent Bernoulli
        distribution. Only one of `logits` or `probs` should be passed in.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If p and logits are passed, or if neither are passed.
    """
    # Must run before any other local is bound so that only the
    # constructor arguments are captured.
    parameters = dict(locals())
    with ops.name_scope(name) as name:
        logits_and_probs = distribution_util.get_logits_and_probs(
            logits=logits,
            probs=probs,
            validate_args=validate_args,
            name=name)
        self._logits, self._probs = logits_and_probs
    graph_inputs = [self._logits, self._probs]
    super(Bernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_inputs,
        name=name)
def testProbabilityValidateArgsMultidimensional(self):
    """Exercise `probs` validation when `multidimensional=True`.

    A valid input is evaluated first. Each invalid input is expected to
    raise an op error when `validate_args=True`, and is then evaluated
    again with `validate_args=False` outside of any expected-error context.
    """
    valid_probs = np.array(
        [[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
    # Component less than 0. Still sums to 1.
    negative_probs = np.array(
        [[-.3, 0.4, 0.9], [0.1, 0.5, 0.4]], dtype=np.float32)
    # Component greater than 1. Does not sum to 1.
    oversized_probs = np.array(
        [[1.3, 0.0, 0.0], [0.1, 0.5, 0.4]], dtype=np.float32)
    # Does not sum to 1.
    unnormalized_probs = np.array(
        [[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32)

    invalid_cases = (
        (negative_probs, "Condition x >= 0"),
        (oversized_probs,
         "(probs has components greater than 1|probs does not sum to 1)"),
        (unnormalized_probs, "probs does not sum to 1"),
    )

    with self.test_session():
        _, prob = distribution_util.get_logits_and_probs(
            probs=valid_probs, multidimensional=True)
        prob.eval()

        for bad_probs, expected_message in invalid_cases:
            with self.assertRaisesOpError(expected_message):
                _, prob = distribution_util.get_logits_and_probs(
                    probs=bad_probs, multidimensional=True, validate_args=True)
                prob.eval()

            # Same input with validation disabled.
            _, prob = distribution_util.get_logits_and_probs(
                probs=bad_probs, multidimensional=True, validate_args=False)
            prob.eval()