def testLogisticVariance(self):
    """Logistic.variance() agrees with scipy.stats.logistic.var."""
    locations = [2.0, 1.5, 1.0]
    spread = 1.5
    with self.test_session():
        reference = stats.logistic.var(locations, spread)
        logistic_dist = logistic.Logistic(locations, spread)
        self.assertAllClose(logistic_dist.variance().eval(), reference)
def testLogisticMean(self):
    """Logistic.mean() agrees with scipy.stats.logistic.mean."""
    locations = [2.0, 1.5, 1.0]
    spread = 1.5
    with self.test_session():
        reference = stats.logistic.mean(locations, spread)
        logistic_dist = logistic.Logistic(locations, spread)
        self.assertAllClose(logistic_dist.mean().eval(), reference)
def testReparameterizable(self):
    """Logistic reports itself as fully reparameterized."""
    batch_size = 6
    np_loc = np.array([2.0] * batch_size, dtype=np.float32)
    loc = constant_op.constant(np_loc)
    scale = 1.5
    dist = logistic.Logistic(loc, scale)
    # assertEqual shows both values on failure; assertTrue(a == b) only
    # reports "False is not true".
    self.assertEqual(dist.reparameterization_type,
                     distribution.FULLY_REPARAMETERIZED)
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="RelaxedBernoulli"):
    """Construct RelaxedBernoulli distributions.

    Args:
      temperature: A 0-D `Tensor`, representing the temperature
        of a set of RelaxedBernoulli distributions. The temperature should be
        positive.
      logits: An N-D `Tensor` representing the log-odds
        of a positive event. Each entry in the `Tensor` parametrizes
        an independent RelaxedBernoulli distribution where the probability of an
        event is sigmoid(logits). Only one of `logits` or `probs` should be
        passed in.
      probs: An N-D `Tensor` representing the probability of a positive event.
        Each entry in the `Tensor` parameterizes an independent
        RelaxedBernoulli distribution. Only one of `logits` or `probs` should
        be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
    # Snapshot the constructor arguments. In CPython (before 3.13) `locals()`
    # inside a function returns the frame's own mapping, which can be refreshed
    # later (e.g. under a tracer/debugger) and then pick up names bound below
    # such as `ns` or `dist`; copying it freezes exactly the arguments.
    parameters = dict(locals())
    with ops.name_scope(name, values=[logits, probs, temperature]) as ns:
        # The positivity assertion is only attached when validation is on.
        with ops.control_dependencies(
                [check_ops.assert_positive(temperature)] if validate_args
                else []):
            self._temperature = array_ops.identity(temperature,
                                                   name="temperature")
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits, probs=probs, validate_args=validate_args)
        # Base distribution Logistic(logits / temperature, 1 / temperature);
        # it is pushed through the sigmoid bijector below.
        dist = logistic.Logistic(self._logits / self._temperature,
                                 1. / self._temperature,
                                 validate_args=validate_args,
                                 allow_nan_stats=allow_nan_stats,
                                 name=ns)

    def inverse_log_det_jacobian_fn(y):
        # The inverse map is logit(y) = log(y) - log(1 - y), whose derivative
        # is 1 / (y * (1 - y)); its log is -log(y) - log(1 - y).
        return -math_ops.log(y) - math_ops.log1p(-y)

    sigmoid_bijector = bijectors.Inline(
        forward_fn=math_ops.sigmoid,
        inverse_fn=(lambda y: math_ops.log(y) - math_ops.log1p(-y)),
        inverse_log_det_jacobian_fn=inverse_log_det_jacobian_fn,
        name="sigmoid")
    super(RelaxedBernoulli, self).__init__(dist, sigmoid_bijector, name=name)
    # Assigned after the base-class constructor so that, if the base class
    # records its own `_parameters` (NOTE(review): presumed for the
    # transformed-distribution base; confirm), this class's arguments win.
    self._parameters = parameters
def testLogisticSample(self):
    """A seeded single draw has batch shape (3,) and reproducible values."""
    locations = [3.0, 4.0, 2.0]
    spread = 1.0
    golden = [6.22460556, 3.79602098, 2.05084133]
    with self.test_session():
        logistic_dist = logistic.Logistic(locations, spread)
        draw = logistic_dist.sample(seed=100)
        self.assertEqual(draw.get_shape(), (3, ))
        self.assertAllClose(draw.eval(), golden)
def testLogisticEntropy(self):
    """Entropy matches scipy for a batch of three identical locations."""
    spread = 1.5
    host_loc = np.full(3, 2.0, dtype=np.float32)
    reference = stats.logistic.entropy(host_loc, spread)
    with self.test_session():
        logistic_dist = logistic.Logistic(constant_op.constant(host_loc), spread)
        self.assertAllClose(logistic_dist.entropy().eval(), reference)
def testDtype(self):
    """Every Logistic output carries the parameters' floating dtype."""
    loc32 = constant_op.constant([0.1, 0.4], dtype=dtypes.float32)
    scale32 = constant_op.constant(1.0, dtype=dtypes.float32)
    dist = logistic.Logistic(loc32, scale32)
    self.assertEqual(dist.dtype, dtypes.float32)
    self.assertEqual(dist.loc.dtype, dist.scale.dtype)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    # Argument-free statistics, checked in the same order as before.
    for stat in (dist.mean, dist.variance, dist.stddev, dist.entropy):
        self.assertEqual(dist.loc.dtype, stat().dtype)
    # Densities evaluated at a single point.
    for density in (dist.prob, dist.log_prob):
        self.assertEqual(dist.loc.dtype, density(0.2).dtype)

    loc64 = constant_op.constant([0.1, 0.4], dtype=dtypes.float64)
    scale64 = constant_op.constant(1.0, dtype=dtypes.float64)
    dist64 = logistic.Logistic(loc64, scale64)
    self.assertEqual(dist64.dtype, dtypes.float64)
    self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
def testLogisticLogCDF(self):
    """log_cdf matches scipy.stats.logistic.logcdf elementwise."""
    spread = 1.5
    host_loc = np.full(6, 2.0, dtype=np.float32)
    points = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
    reference = stats.logistic.logcdf(points, host_loc, spread)
    with self.test_session():
        logistic_dist = logistic.Logistic(constant_op.constant(host_loc), spread)
        result = logistic_dist.log_cdf(points)
        self.assertEqual(result.get_shape(), (6, ))
        self.assertAllClose(result.eval(), reference)
def testLogisticSurvivalFunction(self):
    """survival_function matches scipy.stats.logistic.sf elementwise."""
    spread = 1.5
    host_loc = np.full(6, 2.0, dtype=np.float32)
    points = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
    reference = stats.logistic.sf(points, host_loc, spread)
    with self.test_session():
        logistic_dist = logistic.Logistic(constant_op.constant(host_loc), spread)
        result = logistic_dist.survival_function(points)
        self.assertEqual(result.get_shape(), (6, ))
        self.assertAllClose(result.eval(), reference)
def testLogisticLogProb(self):
    """log_prob and prob agree with scipy's logistic log-density."""
    spread = 1.5
    host_loc = np.full(6, 2.0, dtype=np.float32)
    points = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
    reference_log = stats.logistic.logpdf(points, host_loc, spread)
    with self.test_session():
        logistic_dist = logistic.Logistic(constant_op.constant(host_loc), spread)
        # log_prob is checked first, then prob against exp(log-density).
        checks = ((logistic_dist.log_prob, reference_log),
                  (logistic_dist.prob, np.exp(reference_log)))
        for evaluate, reference in checks:
            result = evaluate(points)
            self.assertEqual(result.get_shape(), (6, ))
            self.assertAllClose(result.eval(), reference)