def testGetLogitsAndProbProbabilityValidateArgs(self):
  # Exercises `validate_args` for element-wise probabilities: in-range values
  # evaluate cleanly, out-of-range values raise only when validation is on.
  in_range = [0.01, 0.2, 0.5, 0.7, .99]
  below_zero = [-1, 0.2, 0.5, 0.3, .2]  # First component is less than 0.
  above_one = [2, 0.2, 0.5, 0.3, .2]  # First component is greater than 1.

  # (invalid input, op-error message expected when validation is enabled)
  invalid_cases = (
      (below_zero, "Condition x >= 0"),
      (above_one, "p has components greater than 1"),
  )

  with self.test_session():
    _, prob = distribution_util.get_logits_and_prob(
        p=in_range, validate_args=True)
    prob.eval()

    for bad_p, expected_message in invalid_cases:
      with self.assertRaisesOpError(expected_message):
        _, prob = distribution_util.get_logits_and_prob(
            p=bad_p, validate_args=True)
        prob.eval()

      # The same input must evaluate without error once validation is off.
      _, prob = distribution_util.get_logits_and_prob(
          p=bad_p, validate_args=False)
      prob.eval()
def testGetLogitsAndProbProbabilityValidateArgs(self):
  # Checks that `validate_args=True` rejects probabilities outside [0, 1]
  # and that the same inputs are accepted when validation is disabled.

  def evaluate_prob(probabilities, validate):
    # Builds the probability tensor and forces evaluation so that any
    # runtime assertion attached to it actually fires.
    _, prob = distribution_util.get_logits_and_prob(
        p=probabilities, validate_args=validate)
    prob.eval()

  valid = [0.01, 0.2, 0.5, 0.7, .99]
  has_negative = [-1, 0.2, 0.5, 0.3, .2]  # Component less than 0.
  has_too_large = [2, 0.2, 0.5, 0.3, .2]  # Component greater than 1.

  with self.test_session():
    evaluate_prob(valid, True)

    with self.assertRaisesOpError("Condition x >= 0"):
      evaluate_prob(has_negative, True)
    evaluate_prob(has_negative, False)

    with self.assertRaisesOpError("p has components greater than 1"):
      evaluate_prob(has_too_large, True)
    evaluate_prob(has_too_large, False)
def testGetLogitsAndProbImproperArguments(self):
  # `get_logits_and_prob` must raise ValueError when neither or both of
  # `logits` and `p` are supplied.
  improper_kwargs = (
      {"logits": None, "p": None},    # Neither argument given.
      {"logits": [0.1], "p": [0.1]},  # Both arguments given.
  )
  with self.test_session():
    for kwargs in improper_kwargs:
      with self.assertRaises(ValueError):
        distribution_util.get_logits_and_prob(**kwargs)
def __init__(self,
             n,
             logits=None,
             p=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Binomial"):
  """Initialize a batch of Binomial distributions.

  `logits` and `p` are resolved into a consistent pair through
  `distribution_util.get_logits_and_prob`; `n` is wrapped in an identity op
  that carries the optional validation assertions.

  Args:
    n: Non-negative floating point tensor with shape broadcastable to
      `[N1,..., Nm]` with `m >= 0` and the same dtype as `p` or `logits`.
      Defines this as a batch of `N1 x ... x Nm` different Binomial
      distributions. Its components should be equal to integer values.
    logits: Floating point tensor representing the log-odds of a positive
      event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and the
      same dtype as `n`. Each entry represents logits for the probability of
      success for independent Binomial distributions.
    p: Positive floating point tensor with shape broadcastable to
      `[N1,..., Nm]` `m >= 0`, `p in [0, 1]`. Each entry represents the
      probability of success for independent Binomial distributions.
    validate_args: `Boolean`, default `False`. Whether to assert valid
      values for parameters `n`, `p`, and `x` in `prob` and `log_prob`.
      If `False` and inputs are invalid, correct behavior is not guaranteed.
    allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: The name to prefix Ops created by this distribution class.

  Examples:

  ```python
  # Define 1-batch of a binomial distribution.
  dist = Binomial(n=2., p=.9)

  # Define a 2-batch.
  dist = Binomial(n=[4., 5], p=[.1, .3])
  ```
  """
  self._logits, self._p = distribution_util.get_logits_and_prob(
      name=name, logits=logits, p=p, validate_args=validate_args)

  with ops.name_scope(name, values=[n]) as ns:
    # Assertions are only built when validation is requested.
    if validate_args:
      n_assertions = [
          check_ops.assert_non_negative(
              n, message="n has negative components."),
          distribution_util.assert_integer_form(
              n, message="n has non-integer components."),
      ]
    else:
      n_assertions = []

    with ops.control_dependencies(n_assertions):
      self._n = array_ops.identity(n, name="n")

      distribution_parameters = {
          "n": self._n,
          "p": self._p,
          "logits": self._logits,
      }
      super(Binomial, self).__init__(
          dtype=self._p.dtype,
          parameters=distribution_parameters,
          is_continuous=False,
          is_reparameterized=False,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=ns)
def testGetLogitsAndProbProbabilityMultidimensional(self):
  # Feeds row-wise probabilities with `multidimensional=True` and checks
  # both returned tensors.
  # NOTE(review): the expected logits here are log-odds (`special.logit`);
  # confirm this matches the helper's convention for multidimensional input.
  probabilities = np.array(
      [[0.3, 0.4, 0.3],
       [0.1, 0.5, 0.4]], dtype=np.float32)
  expected_logits = special.logit(probabilities)

  with self.test_session():
    logits_tensor, p_tensor = distribution_util.get_logits_and_prob(
        p=probabilities, multidimensional=True, validate_args=True)

    self.assertAllClose(expected_logits, logits_tensor.eval())
    self.assertAllClose(probabilities, p_tensor.eval())
def testGetLogitsAndProbProbability(self):
  # Given element-wise probabilities, the helper must return their log-odds
  # as logits and hand the probabilities back unchanged.
  probabilities = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
  expected_logits = special.logit(probabilities)

  with self.test_session():
    logits_tensor, p_tensor = distribution_util.get_logits_and_prob(
        p=probabilities, validate_args=True)

    self.assertAllClose(expected_logits, logits_tensor.eval())
    self.assertAllClose(probabilities, p_tensor.eval())
def testGetLogitsAndProbProbabilityMultidimensional(self):
  # With `multidimensional=True` the expected logits are the element-wise
  # natural log of the probabilities, and the probabilities come back as-is.
  probabilities = np.array(
      [[0.3, 0.4, 0.3],
       [0.1, 0.5, 0.4]], dtype=np.float32)
  expected_logits = np.log(probabilities)

  with self.test_session():
    logits_tensor, p_tensor = distribution_util.get_logits_and_prob(
        p=probabilities, multidimensional=True, validate_args=True)

    self.assertAllClose(expected_logits, logits_tensor.eval())
    self.assertAllClose(probabilities, p_tensor.eval())
def testGetLogitsAndProbLogitsMultidimensional(self):
  # Starts from logits equal to log(p) for a vector p that already sums to
  # one; the helper must recover p and return the logits unchanged.
  expected_p = np.array([0.2, 0.3, 0.5], dtype=np.float32)
  input_logits = np.log(expected_p)

  with self.test_session():
    logits_tensor, p_tensor = distribution_util.get_logits_and_prob(
        logits=input_logits, multidimensional=True, validate_args=True)

    self.assertAllClose(p_tensor.eval(), expected_p)
    self.assertAllClose(logits_tensor.eval(), input_logits)
def __init__(self,
             logits=None,
             p=None,
             dtype=dtypes.int32,
             validate_args=False,
             allow_nan_stats=True,
             name="Bernoulli"):
  """Construct Bernoulli distributions.

  Args:
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent Bernoulli
      distribution where the probability of an event is sigmoid(logits).
      Only one of `logits` or `p` should be passed in.
    p: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution. Only one of `logits` or `p` should be passed in.
    dtype: dtype for samples.
    validate_args: `Boolean`, default `False`. Whether to validate that
      `0 <= p <= 1`. If `validate_args` is `False`, and the inputs are
      invalid, methods like `log_pmf` may return `NaN` values.
    allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: A name for this distribution.

  Raises:
    ValueError: If p and logits are passed, or if neither are passed.
  """
  # Must stay the first statement: it snapshots exactly the constructor
  # arguments, before any other local name is bound.
  parameters = locals()
  del parameters["self"]

  with ops.name_scope(name) as scope:
    logits_and_prob = distribution_util.get_logits_and_prob(
        logits=logits, p=p, validate_args=validate_args)
    self._logits, self._p = logits_and_prob
    # Probability of the complementary (zero) outcome.
    with ops.name_scope("q"):
      self._q = 1. - self._p

  super(Bernoulli, self).__init__(
      dtype=dtype,
      is_continuous=False,
      is_reparameterized=False,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=[self._p, self._q, self._logits],
      name=scope)
def __init__(self, logits=None, p=None, dtype=dtypes.int32,
             validate_args=False, allow_nan_stats=True, name="Bernoulli"):
  """Construct Bernoulli distributions.

  The `(logits, p)` pair is derived from whichever one was supplied, and
  `q = 1 - p` is stored alongside them.

  Args:
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent Bernoulli
      distribution where the probability of an event is sigmoid(logits).
      Only one of `logits` or `p` should be passed in.
    p: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution. Only one of `logits` or `p` should be passed in.
    dtype: dtype for samples.
    validate_args: `Boolean`, default `False`. Whether to validate that
      `0 <= p <= 1`. If `validate_args` is `False`, and the inputs are
      invalid, methods like `log_pmf` may return `NaN` values.
    allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: A name for this distribution.

  Raises:
    ValueError: If p and logits are passed, or if neither are passed.
  """
  # Capture the call arguments first so no later local leaks into the dict.
  parameters = locals()
  parameters.pop("self")

  with ops.name_scope(name) as ns:
    self._logits, self._p = distribution_util.get_logits_and_prob(
        logits=logits, p=p, validate_args=validate_args)
    with ops.name_scope("q"):
      self._q = 1. - self._p

  parent_tensors = [self._p, self._q, self._logits]
  base_kwargs = dict(
      dtype=dtype,
      is_continuous=False,
      is_reparameterized=False,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=parent_tensors,
      name=ns)
  super(Bernoulli, self).__init__(**base_kwargs)
def __init__(self,
             logits=None,
             p=None,
             dtype=dtypes.int32,
             validate_args=True,
             allow_nan_stats=False,
             name="Bernoulli"):
  """Construct Bernoulli distributions.

  Args:
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent Bernoulli
      distribution where the probability of an event is sigmoid(logits).
    p: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution.
    dtype: dtype for samples.
    validate_args: Whether to assert that `0 <= p <= 1`. If not
      validate_args, `log_pmf` may return nans.
    allow_nan_stats: Boolean, default `False`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: A name for this distribution.

  Raises:
    ValueError: If p and logits are passed, or if neither are passed.
  """
  logits_and_prob = distribution_util.get_logits_and_prob(
      name=name, logits=logits, p=p, validate_args=validate_args)
  self._logits, self._p = logits_and_prob

  # q is the probability of the complementary outcome, created under
  # the `<name>/q` scope.
  with ops.name_scope(name):
    with ops.name_scope("q"):
      self._q = 1. - self._p

  distribution_parameters = {
      "p": self._p,
      "q": self._q,
      "logits": self._logits,
  }
  super(Bernoulli, self).__init__(
      dtype=dtype,
      parameters=distribution_parameters,
      is_continuous=False,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      name=name)
def __init__(self, logits=None, p=None, dtype=dtypes.int32,
             validate_args=True, allow_nan_stats=False, name="Bernoulli"):
  """Construct Bernoulli distributions.

  Stores the resolved `logits`, `p` and the complement `q = 1 - p`, and
  registers all three as the distribution's parameters.

  Args:
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent Bernoulli
      distribution where the probability of an event is sigmoid(logits).
    p: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution.
    dtype: dtype for samples.
    validate_args: Whether to assert that `0 <= p <= 1`. If not
      validate_args, `log_pmf` may return nans.
    allow_nan_stats: Boolean, default `False`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: A name for this distribution.

  Raises:
    ValueError: If p and logits are passed, or if neither are passed.
  """
  self._logits, self._p = distribution_util.get_logits_and_prob(
      name=name, logits=logits, p=p, validate_args=validate_args)

  with ops.name_scope(name), ops.name_scope("q"):
    self._q = 1. - self._p

  base_kwargs = dict(
      dtype=dtype,
      parameters={"p": self._p, "q": self._q, "logits": self._logits},
      is_continuous=False,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      name=name)
  super(Bernoulli, self).__init__(**base_kwargs)
def testGetLogitsAndProbProbabilityValidateArgsMultidimensional(self):
  # Exercises `validate_args` for row-wise (multidimensional) probabilities:
  # valid rows must pass validation, invalid rows must raise only when
  # validation is enabled.
  p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
  # Component less than 0. Still sums to 1.
  p2 = np.array([[-.3, 0.4, 0.9], [0.1, 0.5, 0.4]], dtype=np.float32)
  # Component greater than 1. Does not sum to 1.
  p3 = np.array([[1.3, 0.0, 0.0], [0.1, 0.5, 0.4]], dtype=np.float32)
  # Does not sum to 1.
  p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32)

  # (invalid input, regex for the op error expected under validation)
  invalid_cases = (
      (p2, "Condition x >= 0"),
      (p3, "(p has components greater than 1|p does not sum to 1)"),
      (p4, "p does not sum to 1"),
  )

  with self.test_session():
    # Validation is switched on for the valid input as well; without it this
    # case would not check that well-formed rows survive the assertions.
    _, prob = distribution_util.get_logits_and_prob(
        p=p, multidimensional=True, validate_args=True)
    prob.eval()

    for bad_p, expected_message in invalid_cases:
      with self.assertRaisesOpError(expected_message):
        _, prob = distribution_util.get_logits_and_prob(
            p=bad_p, multidimensional=True, validate_args=True)
        prob.eval()

      # With validation off the same input must evaluate without error.
      _, prob = distribution_util.get_logits_and_prob(
          p=bad_p, multidimensional=True, validate_args=False)
      prob.eval()
def __init__(self,
             temperature,
             logits=None,
             p=None,
             validate_args=False,
             allow_nan_stats=True,
             name="RelaxedBernoulli"):
  """Construct RelaxedBernoulli distributions.

  The distribution is assembled from a `logistic._Logistic` base, built
  from `logits / temperature` and `1 / temperature`, combined with an
  inline sigmoid bijector.

  Args:
    temperature: An 0-D `Tensor`, representing the temperature of a set of
      RelaxedBernoulli distributions. The temperature should be positive.
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent
      RelaxedBernoulli distribution where the probability of an event is
      sigmoid(logits). Only one of `logits` or `p` should be passed in.
    p: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution. Only one of `logits` or `p` should be passed in.
    validate_args: `Boolean`, default `False`. Whether to validate that
      `0 <= p <= 1`. If `validate_args` is `False`, and the inputs are
      invalid, methods like `log_pmf` may return `NaN` values. When `True`,
      `temperature` is additionally asserted to be positive.
    allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: A name for this distribution.

  Raises:
    ValueError: If p and logits are passed, or if neither are passed.
  """
  # NOTE(review): `parameters` is captured here but is not read again in
  # this constructor.
  parameters = locals()
  parameters.pop("self")

  with ops.name_scope(name, values=[logits, p, temperature]) as ns:
    if validate_args:
      temperature_checks = [check_ops.assert_positive(temperature)]
    else:
      temperature_checks = []
    with ops.control_dependencies(temperature_checks):
      self._temperature = array_ops.identity(temperature, name="temperature")

    self._logits, self._p = distribution_util.get_logits_and_prob(
        logits=logits, p=p, validate_args=validate_args)
    with ops.name_scope("q"):
      self._q = 1. - self._p

    scaled_logits = self._logits / self._temperature
    inverse_temperature = 1. / self._temperature
    base_logistic = logistic._Logistic(
        scaled_logits,
        inverse_temperature,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=ns)

  def logit_fn(y):
    # Inverse of the sigmoid: log(y / (1 - y)).
    return math_ops.log(y) - math_ops.log(1 - y)

  def inverse_log_det_jacobian_fn(y):
    # -(log(y) + log(1 - y)), summed over the last axis.
    log_terms = math_ops.log(y) + math_ops.log(1 - y)
    return -math_ops.reduce_sum(log_terms, reduction_indices=-1)

  sigmoid_bijector = bijector.Inline(
      forward_fn=math_ops.sigmoid,
      inverse_fn=logit_fn,
      inverse_log_det_jacobian_fn=inverse_log_det_jacobian_fn,
      name="sigmoid")

  super(_RelaxedBernoulli, self).__init__(
      base_logistic, sigmoid_bijector, name=name)
def __init__(self,
             n,
             logits=None,
             p=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Multinomial"):
  """Initialize a batch of Multinomial distributions.

  Besides `n`, `p` and `logits`, the constructor precomputes the mean
  (`expand_dims(n, -1) * p`) and a `broadcast_shape` tensor obtained by
  summing that mean over its last axis.

  Args:
    n: Non-negative floating point tensor with shape broadcastable to
      `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
      `N1 x ... x Nm` different Multinomial distributions. Its components
      should be equal to integer values.
    logits: Floating point tensor representing the log-odds of a positive
      event with shape broadcastable to `[N1,..., Nm, k], m >= 0`, and the
      same dtype as `n`. Defines this as a batch of `N1 x ... x Nm`
      different `k` class Multinomial distributions.
    p: Positive floating point tensor with shape broadcastable to
      `[N1,..., Nm, k]` `m >= 0` and same dtype as `n`. Defines this as a
      batch of `N1 x ... x Nm` different `k` class Multinomial
      distributions. `p`'s components in the last portion of its shape
      should sum up to 1.
    validate_args: `Boolean`, default `False`. Whether to assert valid
      values for parameters `n` and `p`, and `x` in `prob` and `log_prob`.
      If `False`, correct behavior is not guaranteed.
    allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: The name to prefix Ops created by this distribution class.

  Examples:

  ```python
  # Define 1-batch of 2-class multinomial distribution,
  # also known as a Binomial distribution.
  dist = Multinomial(n=2., p=[.1, .9])

  # Define a 2-batch of 3-class distributions.
  dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
  ```
  """
  self._logits, self._p = distribution_util.get_logits_and_prob(
      name=name, logits=logits, p=p, validate_args=validate_args,
      multidimensional=True)

  with ops.name_scope(name, values=[n, self._p]) as ns:
    with ops.control_dependencies([
        check_ops.assert_non_negative(
            n, message="n has negative components."),
        distribution_util.assert_integer_form(
            n, message="n has non-integer components.")
    ] if validate_args else []):
      self._n = array_ops.identity(n, name="convert_n")
      # Per-class mean: n (with a trailing unit axis) times p.
      self._mean_val = array_ops.expand_dims(n, -1) * self._p
      self._broadcast_shape = math_ops.reduce_sum(
          self._mean_val, reduction_indices=[-1], keep_dims=False)

      super(Multinomial, self).__init__(
          dtype=self._p.dtype,
          parameters={
              "p": self._p,
              "n": self._n,
              # Store the computed mean tensor. The previous `self._mean`
              # is never assigned in this constructor, so it did not refer
              # to the value computed above.
              "mean": self._mean_val,
              "logits": self._logits,
              "broadcast_shape": self._broadcast_shape,
          },
          is_continuous=False,
          is_reparameterized=False,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=ns)
def __init__(self,
             logits=None,
             p=None,
             dtype=dtypes.int32,
             validate_args=False,
             allow_nan_stats=True,
             name="OneHotCategorical"):
  """Initialize OneHotCategorical distributions using class log-probabilities.

  In addition to resolving `logits` and `p`, the constructor derives three
  int32 tensors from the logits' shape: the batch rank, the number of
  classes (size of the last axis) and the batch shape (all axes but the
  last). Each uses the static shape when known and a runtime op otherwise.

  Args:
    logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
      of a set of Categorical distributions. The first `N - 1` dimensions
      index into a batch of independent distributions and the last
      dimension represents a vector of logits for each class. Only one of
      `logits` or `p` should be passed in.
    p: An N-D `Tensor`, `N >= 1`, representing the probabilities of a set
      of Categorical distributions. The first `N - 1` dimensions index into
      a batch of independent distributions and the last dimension
      represents a vector of probabilities for each class. Only one of
      `logits` or `p` should be passed in.
    dtype: The type of the event samples (default: int32).
    validate_args: `Boolean`, default `False`. Forwarded to
      `distribution_util.get_logits_and_prob` and to the base class
      constructor.
    allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: A name for this distribution (optional).
  """
  # Must be the first statement so only the constructor arguments are
  # captured.
  parameters = locals()
  parameters.pop("self")

  # `p` is listed next to `logits` because either one may be the tensor the
  # caller supplied.
  with ops.name_scope(name, values=[logits, p]) as ns:
    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args,
        multidimensional=True)

    logits_shape_static = self._logits.get_shape().with_rank_at_least(1)

    # Batch rank: number of axes preceding the class axis.
    if logits_shape_static.ndims is not None:
      self._batch_rank = ops.convert_to_tensor(
          logits_shape_static.ndims - 1,
          dtype=dtypes.int32,
          name="batch_rank")
    else:
      with ops.name_scope(name="batch_rank"):
        self._batch_rank = array_ops.rank(self._logits) - 1

    logits_shape = array_ops.shape(self._logits, name="logits_shape")

    # Number of classes: size of the last axis.
    if logits_shape_static[-1].value is not None:
      self._num_classes = ops.convert_to_tensor(
          logits_shape_static[-1].value,
          dtype=dtypes.int32,
          name="num_classes")
    else:
      self._num_classes = array_ops.gather(
          logits_shape, self._batch_rank, name="num_classes")

    # Batch shape: every axis except the last.
    if logits_shape_static[:-1].is_fully_defined():
      self._batch_shape_val = constant_op.constant(
          logits_shape_static[:-1].as_list(),
          dtype=dtypes.int32,
          name="batch_shape")
    else:
      with ops.name_scope(name="batch_shape"):
        self._batch_shape_val = logits_shape[:-1]

  super(_OneHotCategorical, self).__init__(
      dtype=dtype,
      is_continuous=False,
      reparameterization_type=distribution.NOT_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=[self._logits, self._num_classes],
      name=ns)
def __init__(self,
             temperature,
             logits=None,
             p=None,
             dtype=dtypes.float32,
             validate_args=False,
             allow_nan_stats=True,
             name="ExpRelaxedOneHotCategorical"):
  """Initialize ExpRelaxedOneHotCategorical using class log-probabilities.

  In addition to resolving `logits` and `p`, the constructor derives three
  int32 tensors from the logits' shape: the batch rank, the number of
  classes (size of the last axis) and the batch shape (all axes but the
  last). Each uses the static shape when known and a runtime op otherwise.

  Args:
    temperature: An 0-D `Tensor`, representing the temperature of a set of
      ExpRelaxedCategorical distributions. The temperature should be
      positive.
    logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
      of a set of ExpRelaxedCategorical distributions. The first `N - 1`
      dimensions index into a batch of independent distributions and the
      last dimension represents a vector of logits for each class. Only one
      of `logits` or `p` should be passed in.
    p: An N-D `Tensor`, `N >= 1`, representing the probabilities of a set
      of ExpRelaxedCategorical distributions. The first `N - 1` dimensions
      index into a batch of independent distributions and the last
      dimension represents a vector of probabilities for each class. Only
      one of `logits` or `p` should be passed in.
    dtype: The type of the event samples (default: float32).
    validate_args: `Boolean`, default `False`. When `True`, `temperature`
      is asserted to be positive. The flag is also forwarded to
      `distribution_util.get_logits_and_prob` and to the base class
      constructor.
    allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: A name for this distribution (optional).
  """
  # Must be the first statement so only the constructor arguments are
  # captured.
  parameters = locals()
  parameters.pop("self")

  with ops.name_scope(name, values=[logits, p, temperature]) as ns:
    with ops.control_dependencies(
        [check_ops.assert_positive(temperature)] if validate_args else []):
      self._temperature = array_ops.identity(temperature, name="temperature")

    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args,
        multidimensional=True)

    logits_shape_static = self._logits.get_shape().with_rank_at_least(1)

    # Batch rank: number of axes preceding the class axis.
    if logits_shape_static.ndims is not None:
      self._batch_rank = ops.convert_to_tensor(
          logits_shape_static.ndims - 1,
          dtype=dtypes.int32,
          name="batch_rank")
    else:
      with ops.name_scope(name="batch_rank"):
        self._batch_rank = array_ops.rank(self._logits) - 1

    logits_shape = array_ops.shape(self._logits, name="logits_shape")

    # Number of classes: size of the last axis.
    if logits_shape_static[-1].value is not None:
      self._num_classes = ops.convert_to_tensor(
          logits_shape_static[-1].value,
          dtype=dtypes.int32,
          name="num_classes")
    else:
      self._num_classes = array_ops.gather(
          logits_shape, self._batch_rank, name="num_classes")

    # Batch shape: every axis except the last.
    if logits_shape_static[:-1].is_fully_defined():
      self._batch_shape_val = constant_op.constant(
          logits_shape_static[:-1].as_list(),
          dtype=dtypes.int32,
          name="batch_shape")
    else:
      with ops.name_scope(name="batch_shape"):
        self._batch_shape_val = logits_shape[:-1]

  super(_ExpRelaxedOneHotCategorical, self).__init__(
      dtype=dtype,
      is_continuous=True,
      is_reparameterized=True,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=[self._logits, self._temperature, self._num_classes],
      name=ns)
def __init__(self, temperature, logits=None, p=None, validate_args=False,
             allow_nan_stats=True, name="RelaxedBernoulli"):
  """Construct RelaxedBernoulli distributions.

  Internally this wraps a `logistic._Logistic` distribution, constructed
  from `logits / temperature` and `1 / temperature`, with an inline
  bijector whose forward function is the sigmoid.

  Args:
    temperature: An 0-D `Tensor`, representing the temperature of a set of
      RelaxedBernoulli distributions. The temperature should be positive.
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent
      RelaxedBernoulli distribution where the probability of an event is
      sigmoid(logits). Only one of `logits` or `p` should be passed in.
    p: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution. Only one of `logits` or `p` should be passed in.
    validate_args: `Boolean`, default `False`. Whether to validate that
      `0 <= p <= 1`. If `validate_args` is `False`, and the inputs are
      invalid, methods like `log_pmf` may return `NaN` values. When `True`,
      `temperature` is additionally asserted to be positive.
    allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: A name for this distribution.

  Raises:
    ValueError: If p and logits are passed, or if neither are passed.
  """
  # NOTE(review): this snapshot of the arguments is not used later in the
  # constructor.
  parameters = locals()
  del parameters["self"]

  with ops.name_scope(name, values=[logits, p, temperature]) as scope:
    positivity_checks = (
        [check_ops.assert_positive(temperature)] if validate_args else [])
    with ops.control_dependencies(positivity_checks):
      self._temperature = array_ops.identity(temperature, name="temperature")

    logits_and_prob = distribution_util.get_logits_and_prob(
        logits=logits, p=p, validate_args=validate_args)
    self._logits, self._p = logits_and_prob
    with ops.name_scope("q"):
      self._q = 1. - self._p

    logistic_kwargs = dict(
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=scope)
    dist = logistic._Logistic(
        self._logits / self._temperature,
        1. / self._temperature,
        **logistic_kwargs)

  def sigmoid_inverse(y):
    # log(y) - log(1 - y), i.e. the logit of y.
    return math_ops.log(y) - math_ops.log(1 - y)

  def sigmoid_inverse_log_det_jacobian(y):
    # Negated sum over the last axis of log(y) + log(1 - y).
    return -math_ops.reduce_sum(
        math_ops.log(y) + math_ops.log(1 - y), reduction_indices=-1)

  sigmoidbijector = bijector.Inline(
      forward_fn=math_ops.sigmoid,
      inverse_fn=sigmoid_inverse,
      inverse_log_det_jacobian_fn=sigmoid_inverse_log_det_jacobian,
      name="sigmoid")

  super(_RelaxedBernoulli, self).__init__(dist, sigmoidbijector, name=name)