def testInvalidP(self):
    """Checks how `probs` values outside and inside [0, 1] are handled.

    With `validate_args=True`, evaluating `probs` is expected to raise an op
    error for values above 1 and for negative values. Values 0.0, 0.5 and 1.0
    (built without validation) are expected to evaluate to themselves.
    """
    temperature = 1.0

    # Each entry pairs out-of-range probabilities with the error text that
    # evaluating `probs` is expected to produce.
    rejected = (
        ([1.01, 2.], "probs has components greater than 1"),
        ([-0.01, -3.], "Condition x >= 0"),
    )
    for bad_probs, expected_message in rejected:
      for p in bad_probs:
        with self.cached_session():
          with self.assertRaisesOpError(expected_message):
            dist = relaxed_bernoulli.RelaxedBernoulli(
                temperature, probs=p, validate_args=True)
            dist.probs.eval()

    for p in (0.0, 0.5, 1.0):
      with self.cached_session():
        dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=p)
        self.assertEqual(p, dist.probs.eval())
  def testLogits(self):
    """Checks `logits` and `probs` agree under either parameterization."""
    temperature = 2.0

    # Built from logits: logits come back as given, probs are expit(logits).
    given_logits = [-42., 42.]
    dist = relaxed_bernoulli.RelaxedBernoulli(
        temperature, logits=given_logits)
    with self.cached_session():
      self.assertAllClose(given_logits, dist.logits.eval())

    with self.cached_session():
      expected_probs = scipy.special.expit(given_logits)
      self.assertAllClose(expected_probs, dist.probs.eval())

    # Built from probs: logits are logit(probs).
    given_probs = [0.01, 0.99, 0.42]
    dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=given_probs)
    with self.cached_session():
      expected_logits = scipy.special.logit(given_probs)
      self.assertAllClose(expected_logits, dist.logits.eval())
  def testDtype(self):
    """Checks the distribution dtype follows its float32/float64 parameters."""
    # float32 parameters: dtype, samples, prob and log_prob are all checked.
    temperature32 = constant_op.constant(1.0, dtype=dtypes.float32)
    probs32 = constant_op.constant([0.1, 0.4], dtype=dtypes.float32)
    dist32 = relaxed_bernoulli.RelaxedBernoulli(temperature32, probs=probs32)
    self.assertEqual(dist32.dtype, dtypes.float32)
    self.assertEqual(dist32.dtype, dist32.sample(5).dtype)
    self.assertEqual(dist32.probs.dtype, dist32.prob([0.0]).dtype)
    self.assertEqual(dist32.probs.dtype, dist32.log_prob([0.0]).dtype)

    # float64 parameters: only the distribution and sample dtypes are checked.
    temperature64 = constant_op.constant(1.0, dtype=dtypes.float64)
    probs64 = constant_op.constant([0.1, 0.4], dtype=dtypes.float64)
    dist64 = relaxed_bernoulli.RelaxedBernoulli(temperature64, probs=probs64)
    self.assertEqual(dist64.dtype, dtypes.float64)
    self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
 def testP(self):
   """Checks the `probs` parameter is stored as given (it is not the pdf)."""
   given_probs = [0.1, 0.4]
   dist = relaxed_bernoulli.RelaxedBernoulli(1.0, probs=given_probs)
   with self.cached_session():
     evaluated_probs = dist.probs.eval()
     self.assertAllClose(given_probs, evaluated_probs)
Beispiel #5
0
 def testZeroTemperature(self):
     """Expects construction to fail for temperature 0 with validate_args.

     Building the distribution with `validate_args=True` is expected to raise
     `InvalidArgumentError` whose message matches "x > 0 did not hold".
     """
     zero_temperature = constant_op.constant(0.0)
     probs = constant_op.constant([0.1, 0.4])
     expected_error = errors_impl.InvalidArgumentError
     with self.assertRaisesWithPredicateMatch(expected_error,
                                              "x > 0 did not hold"):
         # Only the construction matters here; the instance is not used.
         relaxed_bernoulli.RelaxedBernoulli(
             zero_temperature, probs=probs, validate_args=True)
 def testZeroTemperature(self):
   """Expects evaluating a sample to fail for temperature 0 with validate_args.

   The distribution and the sample op are built first; `InvalidArgumentError`
   is only expected once the sample is evaluated.
   """
   zero_temperature = constant_op.constant(0.0)
   probs = constant_op.constant([0.1, 0.4])
   dist = relaxed_bernoulli.RelaxedBernoulli(
       zero_temperature, probs=probs, validate_args=True)
   with self.cached_session():
     draw = dist.sample()
     with self.assertRaises(errors_impl.InvalidArgumentError):
       draw.eval()
 def testShapes(self):
   """Checks batch shape follows `probs` and event shape is scalar."""
   temperature = 1.0
   candidate_shapes = ([], [1], [2, 3, 4])
   with self.cached_session():
     for batch_shape in candidate_shapes:
       probs = np.random.random(batch_shape).astype(np.float32)
       dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=probs)

       # Static and dynamic batch shapes both match the shape of `probs`.
       self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
       self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())

       # Static and dynamic event shapes are both empty.
       self.assertAllEqual([], dist.event_shape.as_list())
       self.assertAllEqual([], dist.event_shape_tensor().eval())
 def testLogProb(self):
   """Compares `log_prob` against the closed-form log-density.

   The reference value follows the analytical density from Maddison et al.
   2016, evaluated in float64 at a few points strictly inside (0, 1).
   """
   with self.cached_session():
     t = np.array(1.0, dtype=np.float64)
     p = np.array(0.1, dtype=np.float64)  # P(x=1)
     dist = relaxed_bernoulli.RelaxedBernoulli(t, probs=p)
     xs = np.array([0.1, 0.3, 0.5, 0.9], dtype=np.float64)

     # alpha is the odds p / (1 - p).
     alpha = np.array(p/(1-p), dtype=np.float64)
     log_scale = np.log(t) + np.log(alpha)
     log_kernel = (-t-1)*(np.log(xs)+np.log(1-xs))
     log_denominator = 2*np.log(alpha*np.power(xs, -t) + np.power(1-xs, -t))
     expected_log_pdf = log_scale + log_kernel - log_denominator

     actual_log_pdf = dist.log_prob(xs).eval()
     self.assertAllClose(expected_log_pdf, actual_log_pdf)
  def testSampleN(self):
    """Checks thresholded samples still approximate the Bernoulli mean.

    Draws many samples at a low temperature, checks they are float32 values
    in [0, 1], and compares the fraction of samples at or above 0.5 with
    `probs`.
    """
    with self.cached_session():
      num_draws = 10000
      probs = [0.2, 0.6, 0.5]
      dist = relaxed_bernoulli.RelaxedBernoulli(1e-2, probs=probs)

      draws = dist.sample(num_draws)
      self.assertEqual(draws.dtype, dtypes.float32)

      draw_values = draws.eval()
      self.assertTrue(np.all(draw_values >= 0))
      self.assertTrue(np.all(draw_values <= 1))

      # Quantize at 0.5 and compare the per-component hit rate with probs.
      rounded_up = draw_values >= 0.5
      frac_ones_like = np.sum(rounded_up, axis=0)/num_draws
      self.assertAllClose(probs, frac_ones_like, atol=1e-2)
 def testBoundaryConditions(self):
   """Expects `log_prob` to be NaN at 0.0 and 1.0 when `probs` is 1.0."""
   with self.cached_session():
     dist = relaxed_bernoulli.RelaxedBernoulli(1e-2, probs=1.0)

     log_prob_at_zero = dist.log_prob(0.0).eval()
     self.assertAllClose(np.nan, log_prob_at_zero)

     log_prob_at_one = dist.log_prob(1.0).eval()
     self.assertAllClose([np.nan], [log_prob_at_one])
Beispiel #11
0
 def testContinuous(self):
     """Checks the distribution reports itself as continuous."""
     dist = relaxed_bernoulli.RelaxedBernoulli(1.0, probs=[0.1, 0.4])
     self.assertTrue(dist.is_continuous)