def test_scalar_distributions(self):
    """Concatenate three scalar distributions along axis 1.

    A Normal, a Logistic and an Exponential are built with parameter batch
    shapes ``self.batch_dim_1``, ``self.batch_dim_2`` and
    ``self.batch_dim_3`` (stored on ``self.dist1``..``self.dist3``). Their
    ``BatchConcat`` on ``axis=1`` must report batch shape ``[2, 6, 4]``, and a
    single sample must have that same shape.
    """
    def param(fill, batch_dim):
        # Every parameter goes through self.maybe_static together with
        # self.is_static (presumably toggling static vs. dynamic shapes).
        return self.maybe_static(fill(batch_dim, dtype=self.dtype),
                                 self.is_static)

    self.dist1 = tfd.Normal(
        loc=param(tf.zeros, self.batch_dim_1),
        scale=param(tf.ones, self.batch_dim_1))
    self.dist2 = tfd.Logistic(
        loc=param(tf.zeros, self.batch_dim_2),
        scale=param(tf.ones, self.batch_dim_2))
    self.dist3 = tfd.Exponential(rate=param(tf.ones, self.batch_dim_3))

    expected_shape = [2, 6, 4]
    concat_dist = batch_concat.BatchConcat(
        distributions=[self.dist1, self.dist2, self.dist3],
        axis=1,
        validate_args=False)
    self.assertAllEqual(
        self.evaluate(concat_dist.batch_shape_tensor()), expected_shape)

    # The test expects one draw to have exactly the batch shape.
    draws = concat_dist.sample(seed=test_util.test_seed())
    self.assertAllEqual(self.evaluate(tf.shape(draws)), expected_shape)
Example #2
0
 def test_batch_concat_of_concat(self):
     """A BatchConcat whose components are themselves BatchConcats.

     Two concatenations from ``self.get_distributions()`` are concatenated
     again on axis 0. The nested distribution must report batch shape
     ``[4, 6, 4]``, and ``log_prob`` over a leading sample dimension of 32
     must return shape ``[32, 4, 6, 4]``.
     """
     concat_dist_1 = self.get_distributions()
     concat_dist_2 = self.get_distributions()
     concat_concat = batch_concat.BatchConcat(
         [concat_dist_1, concat_dist_2], axis=0)
     self.assertAllEqual(self.evaluate(concat_concat.batch_shape_tensor()),
                         [4, 6, 4])
     # Match the dtype of the component parameters (all built with
     # self.dtype). A default-float32 input could be rejected by the
     # components' log_prob when self.dtype is not float32.
     # NOTE(review): the trailing dim 2 assumes every component has event
     # shape [2]; confirm against event_dim_1/2/3.
     x_sample = tf.zeros([32, 4, 6, 4, 2], dtype=self.dtype)
     self.assertAllEqual(
         self.evaluate(tf.shape(concat_concat.log_prob(x_sample))),
         [32, 4, 6, 4])
Example #3
0
    def get_distributions(self, validate_args=False):
        """Build a BatchConcat over three distributions with event dimensions.

        Creates a MultivariateNormalDiag, a OneHotCategorical and a Dirichlet
        and stores them on ``self.dist1``, ``self.dist2`` and ``self.dist3``.
        Each has parameters of shape ``batch_dim_i + event_dim_i``. The three
        are then concatenated along ``self.axis``.

        Args:
          validate_args: Forwarded to the ``BatchConcat`` constructor only;
            the component distributions are built without it.

        Returns:
          The ``batch_concat.BatchConcat`` over ``[dist1, dist2, dist3]``.
        """
        def wrap(tensor):
            return self.maybe_static(tensor, self.is_static)

        shape_1 = self.batch_dim_1 + self.event_dim_1
        self.dist1 = tfd.MultivariateNormalDiag(
            loc=wrap(tf.zeros(shape_1, dtype=self.dtype)),
            scale_diag=wrap(tf.ones(shape_1, dtype=self.dtype)))

        # NOTE(review): the logits use tf.zeros' default dtype rather than
        # self.dtype. Only the sample dtype is set to self.dtype; confirm
        # this is intentional.
        self.dist2 = tfd.OneHotCategorical(
            logits=wrap(tf.zeros(self.batch_dim_2 + self.event_dim_2)),
            dtype=self.dtype)

        # NOTE(review): an all-zero concentration is not a positive
        # concentration. It is presumably tolerated only because
        # validate_args is not passed to the Dirichlet itself.
        self.dist3 = tfd.Dirichlet(
            concentration=wrap(
                tf.zeros(self.batch_dim_3 + self.event_dim_3,
                         dtype=self.dtype)))

        return batch_concat.BatchConcat(
            distributions=[self.dist1, self.dist2, self.dist3],
            axis=self.axis,
            validate_args=validate_args)