def test_scalar_distributions(self):
    """Concatenate three scalar-event distributions along batch axis 1.

    Builds a Normal, a Logistic and an Exponential whose parameters have
    shapes `self.batch_dim_1`, `self.batch_dim_2` and `self.batch_dim_3`.
    They are stored on `self.dist1`..`self.dist3`. The test checks that their
    `BatchConcat` reports batch shape [2, 6, 4] and draws samples of that
    same shape.
    """
    def _filled(make, batch_dim):
        # Create a tensor of the test dtype, then hand it to the fixture's
        # `maybe_static` together with `self.is_static`.
        return self.maybe_static(make(batch_dim, dtype=self.dtype),
                                 self.is_static)

    self.dist1 = tfd.Normal(
        loc=_filled(tf.zeros, self.batch_dim_1),
        scale=_filled(tf.ones, self.batch_dim_1))
    self.dist2 = tfd.Logistic(
        loc=_filled(tf.zeros, self.batch_dim_2),
        scale=_filled(tf.ones, self.batch_dim_2))
    self.dist3 = tfd.Exponential(rate=_filled(tf.ones, self.batch_dim_3))

    concatenated = batch_concat.BatchConcat(
        distributions=[self.dist1, self.dist2, self.dist3],
        axis=1,
        validate_args=False)
    expected_shape = [2, 6, 4]
    self.assertAllEqual(
        self.evaluate(concatenated.batch_shape_tensor()), expected_shape)

    # Scalar events: a single draw has exactly the batch shape.
    draws = concatenated.sample(seed=test_util.test_seed())
    self.assertAllEqual(self.evaluate(tf.shape(draws)), expected_shape)
def test_batch_concat_of_concat(self):
    """A `BatchConcat` can itself be an input to another `BatchConcat`.

    Two fixtures from `get_distributions` are concatenated along axis 0. The
    test checks the outer distribution's batch shape, and the shape of
    `log_prob` evaluated on a batch of 32 zero-valued samples.
    """
    concat_dist_1 = self.get_distributions()
    concat_dist_2 = self.get_distributions()
    concat_concat = batch_concat.BatchConcat(
        [concat_dist_1, concat_dist_2], axis=0)
    self.assertAllEqual(
        self.evaluate(concat_concat.batch_shape_tensor()), [4, 6, 4])

    # Build the sample in the distributions' own dtype: the components made by
    # `get_distributions` are created with `self.dtype`, so a hard-coded
    # float32 tensor could mismatch when the suite runs with another dtype.
    # The trailing 2 presumably matches the components' event size -- confirm
    # against `event_dim_*`.
    x_sample = tf.zeros([32, 4, 6, 4, 2], dtype=self.dtype)
    self.assertAllEqual(
        self.evaluate(tf.shape(concat_concat.log_prob(x_sample))),
        [32, 4, 6, 4])
def get_distributions(self, validate_args=False):
    """Build the `BatchConcat` fixture over three vector-event distributions.

    Creates the following, stores them on `self.dist1`..`self.dist3`, and
    concatenates them along `self.axis`:
      * `dist1`: MultivariateNormalDiag with zero `loc` and unit `scale_diag`,
        both of shape `self.batch_dim_1 + self.event_dim_1`.
      * `dist2`: OneHotCategorical with all-zero (uniform) logits of shape
        `self.batch_dim_2 + self.event_dim_2`, sampling in `self.dtype`.
      * `dist3`: Dirichlet with an all-zero concentration of shape
        `self.batch_dim_3 + self.event_dim_3`.

    Args:
      validate_args: Forwarded to the `BatchConcat` only. The component
        distributions are built with their own defaults.

    Returns:
      The `batch_concat.BatchConcat` over `[self.dist1, self.dist2, self.dist3]`.
    """
    is_static = self.is_static

    mvn_shape = self.batch_dim_1 + self.event_dim_1
    mvn_loc = self.maybe_static(
        tf.zeros(mvn_shape, dtype=self.dtype), is_static)
    mvn_scale = self.maybe_static(
        tf.ones(mvn_shape, dtype=self.dtype), is_static)
    self.dist1 = tfd.MultivariateNormalDiag(loc=mvn_loc, scale_diag=mvn_scale)

    # No dtype is given for the logits, so they use tf.zeros' default (float32).
    # `dtype=self.dtype` sets the dtype of the one-hot samples.
    category_logits = self.maybe_static(
        tf.zeros(self.batch_dim_2 + self.event_dim_2), is_static)
    self.dist2 = tfd.OneHotCategorical(logits=category_logits, dtype=self.dtype)

    # NOTE(review): Dirichlet concentration must be positive, so all zeros is
    # outside its parameter support. The tests visible alongside this fixture
    # only check shapes. Confirm no other test relies on these values before
    # switching to e.g. tf.ones.
    dirichlet_concentration = self.maybe_static(
        tf.zeros(self.batch_dim_3 + self.event_dim_3, dtype=self.dtype),
        is_static)
    self.dist3 = tfd.Dirichlet(dirichlet_concentration)

    return batch_concat.BatchConcat(
        distributions=[self.dist1, self.dist2, self.dist3],
        axis=self.axis,
        validate_args=validate_args)