Exemple #1
0
 def testDTypes(self):
     """Sampling and log_prob must build and execute for float16/32/64.

     For every supported floating dtype, random logits are drawn in that
     dtype, a RelaxedOneHotCategorical is built from them, and the
     log-density of one of its own samples is actually evaluated (not just
     added to the graph), so missing kernels or dtype mismatches surface.
     """
     with self.cached_session():
         for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
             logits = random_ops.random_uniform(shape=[3, 3], dtype=dtype)
             dist = relaxed_onehot_categorical.RelaxedOneHotCategorical(
                 temperature=0.5, logits=logits)
             sample = dist.sample()
             log_prob = dist.log_prob(sample)
             # Both outputs should carry the dtype of the logits they came from.
             self.assertEqual(dtype, sample.dtype)
             self.assertEqual(dtype, log_prob.dtype)
             # Evaluating log_prob also runs the sampling op it depends on.
             # NOTE(review): values are deliberately not checked for
             # finiteness; low-precision samples may presumably underflow.
             log_prob.eval()
Exemple #2
0
 def testLogits(self):
     """The wrapped base distribution keeps the user-supplied logits as-is."""
     logits = [2.0, 3.0, -4.0]
     dist = relaxed_onehot_categorical.RelaxedOneHotCategorical(1.0, logits)
     with self.cached_session():
         # `_distribution` is the ExpRelaxed base distribution; its logits
         # (not probabilities) must equal the input, with static shape [3].
         base_logits = dist._distribution.logits
         self.assertAllClose(logits, base_logits.eval())
         self.assertAllEqual([3], base_logits.get_shape())
Exemple #3
0
def make_relaxed_categorical(batch_shape, num_classes, dtype=dtypes.float32):
    """Build a RelaxedOneHotCategorical with random logits and temperatures.

    Args:
      batch_shape: Iterable giving the batch shape of the distribution. It is
        materialized once, so one-shot iterables are accepted.
      num_classes: Size of the trailing (event) dimension of the logits.
      dtype: Floating dtype used for the logits, the temperatures and the
        distribution itself.

    Returns:
      A ``RelaxedOneHotCategorical`` whose logits have shape
      ``batch_shape + [num_classes]`` with values drawn from ``[-60, -40)``,
      and whose temperatures have shape ``batch_shape`` with values drawn
      from ``[0.1, 10)``.
    """
    batch_shape = list(batch_shape)
    # Logits are shifted well below zero, presumably to exercise numerical
    # stability with large negative logits.
    logits = random_ops.random_uniform(
        batch_shape + [num_classes], -10, 10, dtype=dtype) - 50.
    # Temperatures use the same dtype as the logits so all parameters of the
    # distribution agree in precision (they used to be fixed to float32).
    temperatures = random_ops.random_uniform(batch_shape,
                                             0.1,
                                             10,
                                             dtype=dtype)
    return relaxed_onehot_categorical.RelaxedOneHotCategorical(temperatures,
                                                               logits,
                                                               dtype=dtype)
Exemple #4
0
 def testSample(self):
     """sample() and sample(5) return the expected shapes for several logits."""
     temperature = 1.4
     with self.cached_session():
         # Each case: (logits, shape of a single draw).
         shape_cases = (
             # One distribution over three categories.
             ([.3, .1, .4], [3]),
             # A batch of two distributions.
             ([[2.0, 3.0, -4.0], [.3, .1, .4]], [2, 3]),
             # A (4, 1) batch of distributions with random float32 logits.
             (np.random.uniform(size=(4, 1, 3)).astype(np.float32),
              [4, 1, 3]),
         )
         for case_logits, draw_shape in shape_cases:
             dist = relaxed_onehot_categorical.RelaxedOneHotCategorical(
                 temperature, case_logits)
             self.assertAllEqual(draw_shape, dist.sample().eval().shape)
             # Asking for n draws prepends n to the shape.
             self.assertAllEqual([5] + draw_shape,
                                 dist.sample(5).eval().shape)
Exemple #5
0
    def testPdf(self):
        """prob() agrees with the closed-form RelaxedOneHotCategorical density."""

        def analytical_pdf(x, temperature, logits):
            # Closed-form density, evaluated row by row with t = temperature,
            # p = softmax(logits), k = number of classes:
            #   (k-1)! * t^(k-1) * prod_i(p_i / x_i^(t+1))
            #     / (sum_i p_i / x_i^t)^k
            # Returns a column (one value per row of x). `gamma` comes from
            # module scope (presumably scipy.special.gamma).
            temp_col = np.reshape(temperature, (-1, 1))
            if x.ndim == 1:
                x = x[np.newaxis, :]
            num_classes = logits.shape[1]
            exp_logits = np.exp(logits)
            probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
            norm_const = gamma(num_classes) * temp_col ** (num_classes - 1)
            weighted_sum = np.sum(probs / x ** temp_col,
                                  axis=1,
                                  keepdims=True)
            weighted_prod = np.prod(probs / x ** (temp_col + 1),
                                    axis=1,
                                    keepdims=True)
            return norm_const * weighted_sum ** (-num_classes) * weighted_prod

        with self.cached_session():
            # Each case: (temperature(s), logits).
            pdf_cases = (
                # Single distribution with a scalar temperature.
                (.4, np.array([[.3, .1, .4]]).astype(np.float32)),
                # Batch of two distributions, one temperature per row.
                (np.array([0.4, 2.3]).astype(np.float32),
                 np.array([[.3, .1, .4], [.6, -.1, 2.]]).astype(np.float32)),
            )
            for temps, case_logits in pdf_cases:
                dist = relaxed_onehot_categorical.RelaxedOneHotCategorical(
                    temps, case_logits)
                x = dist.sample().eval()
                pdf = dist.prob(x).eval()
                expected_pdf = analytical_pdf(x, temps, case_logits)
                self.assertAllClose(expected_pdf.flatten(), pdf, rtol=1e-4)