Ejemplo n.º 1
0
 def testLogisticVariance(self):
     """Check `Logistic.variance()` against `stats.logistic.var`."""
     with self.test_session():
         locations = [2.0, 1.5, 1.0]
         shared_scale = 1.5
         # Reference value computed from the same (loc, scale) parameters.
         reference = stats.logistic.var(locations, shared_scale)
         logistic_dist = logistic.Logistic(locations, shared_scale)
         actual = logistic_dist.variance().eval()
         self.assertAllClose(actual, reference)
Ejemplo n.º 2
0
 def testLogisticMean(self):
     """Check `Logistic.mean()` against `stats.logistic.mean`."""
     with self.test_session():
         locations = [2.0, 1.5, 1.0]
         shared_scale = 1.5
         # Reference value computed from the same (loc, scale) parameters.
         reference = stats.logistic.mean(locations, shared_scale)
         logistic_dist = logistic.Logistic(locations, shared_scale)
         actual = logistic_dist.mean().eval()
         self.assertAllClose(actual, reference)
Ejemplo n.º 3
0
 def testReparameterizable(self):
     """Check that `Logistic` reports itself as fully reparameterized."""
     batch_size = 6
     np_loc = np.array([2.0] * batch_size, dtype=np.float32)
     loc = constant_op.constant(np_loc)
     scale = 1.5
     dist = logistic.Logistic(loc, scale)
     # assertEqual performs the same `==` comparison as the previous
     # assertTrue(... == ...) form, but reports both operands on failure
     # instead of the uninformative "False is not true".
     self.assertEqual(dist.reparameterization_type,
                      distribution.FULLY_REPARAMETERIZED)
Ejemplo n.º 4
0
  def __init__(self,
               temperature,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="RelaxedBernoulli"):
    """Construct RelaxedBernoulli distributions.

    The distribution is assembled from a `Logistic` base distribution, built
    from `logits / temperature` and `1 / temperature`, which is then passed
    through a sigmoid bijector.

    Args:
      temperature: A 0-D `Tensor`, representing the temperature
        of a set of RelaxedBernoulli distributions. The temperature should be
        positive.
      logits: An N-D `Tensor` representing the log-odds
        of a positive event. Each entry in the `Tensor` parameterizes
        an independent RelaxedBernoulli distribution where the probability of an
        event is sigmoid(logits). Only one of `logits` or `probs` should be
        passed in.
      probs: An N-D `Tensor` representing the probability of a positive event.
        Each entry in the `Tensor` parameterizes an independent Bernoulli
        distribution. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
    # Must stay the first statement: captures the constructor arguments (and
    # `self`) before any other local name is bound.
    parameters = locals()
    with ops.name_scope(name, values=[logits, probs, temperature]) as ns:
      # With validate_args set, the identity op below is gated on the
      # positivity assertion; otherwise no control dependency is added.
      with ops.control_dependencies([check_ops.assert_positive(temperature)]
                                    if validate_args else []):
        self._temperature = array_ops.identity(temperature, name="temperature")

      # Derive both parameterizations from whichever of logits/probs was given.
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits, probs=probs, validate_args=validate_args)
      # Base distribution: Logistic(logits / temperature, 1 / temperature),
      # created under the name scope opened above.
      dist = logistic.Logistic(self._logits / self._temperature,
                               1. / self._temperature,
                               validate_args=validate_args,
                               allow_nan_stats=allow_nan_stats,
                               name=ns)
      # NOTE(review): this is assigned before super().__init__ runs; confirm
      # the base-class initializer does not overwrite self._parameters.
      self._parameters = parameters

    def inverse_log_det_jacobian_fn(y):
      # Log-derivative of the logit inverse below:
      # d/dy [log(y) - log(1 - y)] = 1 / (y * (1 - y)).
      return -math_ops.log(y) - math_ops.log1p(-y)

    # Sigmoid forward map; its inverse is the logit, log(y) - log(1 - y).
    sigmoid_bijector = bijectors.Inline(
        forward_fn=math_ops.sigmoid,
        inverse_fn=(lambda y: math_ops.log(y) - math_ops.log1p(-y)),
        inverse_log_det_jacobian_fn=inverse_log_det_jacobian_fn,
        name="sigmoid")
    super(RelaxedBernoulli, self).__init__(dist, sigmoid_bijector, name=name)
Ejemplo n.º 5
0
 def testLogisticSample(self):
     """Check the static shape and seeded values of one `Logistic` draw."""
     with self.test_session():
         locations = [3.0, 4.0, 2.0]
         unit_scale = 1.0
         logistic_dist = logistic.Logistic(locations, unit_scale)
         draw = logistic_dist.sample(seed=100)
         self.assertEqual(draw.get_shape(), (3, ))
         # Values pinned for seed=100.
         pinned_values = [6.22460556, 3.79602098, 2.05084133]
         self.assertAllClose(draw.eval(), pinned_values)
Ejemplo n.º 6
0
 def testLogisticEntropy(self):
     """Check `Logistic.entropy()` against `stats.logistic.entropy`."""
     with self.test_session():
         num_members = 3
         loc_values = np.array([2.0] * num_members, dtype=np.float32)
         scale = 1.5
         # Reference is computed from the raw NumPy locations, not the tensor.
         reference = stats.logistic.entropy(loc_values, scale)
         logistic_dist = logistic.Logistic(
             constant_op.constant(loc_values), scale)
         actual = logistic_dist.entropy().eval()
         self.assertAllClose(actual, reference)
Ejemplo n.º 7
0
    def testDtype(self):
        """Check that `Logistic` outputs carry the dtype of its parameters."""
        f32_loc = constant_op.constant([0.1, 0.4], dtype=dtypes.float32)
        f32_scale = constant_op.constant(1.0, dtype=dtypes.float32)
        dist32 = logistic.Logistic(f32_loc, f32_scale)
        self.assertEqual(dist32.dtype, dtypes.float32)
        self.assertEqual(dist32.loc.dtype, dist32.scale.dtype)
        self.assertEqual(dist32.dtype, dist32.sample(5).dtype)
        self.assertEqual(dist32.dtype, dist32.mode().dtype)
        # Each statistic is built and checked one at a time.
        for statistic in (dist32.mean, dist32.variance, dist32.stddev,
                          dist32.entropy):
            self.assertEqual(dist32.loc.dtype, statistic().dtype)
        # Densities are evaluated at an arbitrary point; only dtype matters.
        for density in (dist32.prob, dist32.log_prob):
            self.assertEqual(dist32.loc.dtype, density(0.2).dtype)

        f64_loc = constant_op.constant([0.1, 0.4], dtype=dtypes.float64)
        f64_scale = constant_op.constant(1.0, dtype=dtypes.float64)
        dist64 = logistic.Logistic(f64_loc, f64_scale)
        self.assertEqual(dist64.dtype, dtypes.float64)
        self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
Ejemplo n.º 8
0
    def testLogisticLogCDF(self):
        """Check `Logistic.log_cdf` shape and values vs `stats.logistic`."""
        with self.test_session():
            num_members = 6
            loc_values = np.array([2.0] * num_members, dtype=np.float32)
            scale = 1.5
            points = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

            logistic_dist = logistic.Logistic(
                constant_op.constant(loc_values), scale)
            actual = logistic_dist.log_cdf(points)
            # Reference is computed from the raw NumPy locations.
            reference = stats.logistic.logcdf(points, loc_values, scale)

            self.assertEqual(actual.get_shape(), (6, ))
            self.assertAllClose(actual.eval(), reference)
Ejemplo n.º 9
0
    def testLogisticSurvivalFunction(self):
        """Check `Logistic.survival_function` vs `stats.logistic.sf`."""
        with self.test_session():
            num_members = 6
            loc_values = np.array([2.0] * num_members, dtype=np.float32)
            scale = 1.5
            points = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

            logistic_dist = logistic.Logistic(
                constant_op.constant(loc_values), scale)
            actual = logistic_dist.survival_function(points)
            # Reference is computed from the raw NumPy locations.
            reference = stats.logistic.sf(points, loc_values, scale)

            self.assertEqual(actual.get_shape(), (6, ))
            self.assertAllClose(actual.eval(), reference)
Ejemplo n.º 10
0
    def testLogisticLogProb(self):
        """Check `log_prob` and `prob` shapes and values vs `stats.logistic`."""
        with self.test_session():
            num_members = 6
            loc_values = np.array([2.0] * num_members, dtype=np.float32)
            scale = 1.5
            points = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
            logistic_dist = logistic.Logistic(
                constant_op.constant(loc_values), scale)
            # Reference log-density from the raw NumPy locations.
            reference_log_pdf = stats.logistic.logpdf(
                points, loc_values, scale)

            actual_log_pdf = logistic_dist.log_prob(points)
            self.assertEqual(actual_log_pdf.get_shape(), (6, ))
            self.assertAllClose(actual_log_pdf.eval(), reference_log_pdf)

            # The density must equal the exponentiated reference log-density.
            actual_pdf = logistic_dist.prob(points)
            self.assertEqual(actual_pdf.get_shape(), (6, ))
            self.assertAllClose(actual_pdf.eval(), np.exp(reference_log_pdf))