def test_raises_if_both_z_and_n_are_not_none(self):
    """`_get_samples` raises ValueError when both `z` and `n` are supplied."""
    with self.cached_session():
        standard_normal = normal_lib.Normal(loc=0., scale=1.)
        drawn = standard_normal.sample(seed=42)
        # Pre-drawn samples together with a sample count is rejected with a
        # message matching 'exactly one'.
        with self.assertRaisesRegexp(ValueError, 'exactly one'):
            _get_samples(standard_normal, drawn, 1, None)
def test_returns_z_if_z_provided(self):
    """With `z` given and `n` None, the result has the shape of `z`."""
    with self.cached_session():
        standard_normal = normal_lib.Normal(loc=0., scale=1.)
        provided = standard_normal.sample(10, seed=42)
        result = _get_samples(standard_normal, provided, None, None)
        self.assertEqual((10, ), result.get_shape())
def test_normal_entropy_analytic_form_uses_exact_entropy(self):
    """`entropy_shannon` in analytic form agrees with `dist.entropy()`."""
    with self.test_session():
        dist = normal_lib.Normal(loc=1.11, scale=2.22)
        analytic = entropy.entropy_shannon(
            dist, form=entropy.ELBOForms.analytic_entropy)
        reference = dist.entropy()

        # Both the static shape and the evaluated value must agree.
        self.assertEqual(reference.get_shape(), analytic.get_shape())
        self.assertAllClose(reference.eval(), analytic.eval())
def test_docstring_example_normal(self):
    """Check the Monte Carlo KL estimate between two Normals.

    The estimate built with `monte_carlo_lib.expectation` is compared with
    `kullback_leibler.kl_divergence(p, q)`, both in value and in the
    gradients with respect to `mu_p` and `mu_q`, at 1% relative tolerance.
    """
    with self.cached_session() as sess:
        num_draws = int(1e5)
        mu_p = constant_op.constant(0.)
        mu_q = constant_op.constant(1.)
        p = normal_lib.Normal(loc=mu_p, scale=1.)
        q = normal_lib.Normal(loc=mu_q, scale=2.)
        fully_reparameterized = distribution_lib.FULLY_REPARAMETERIZED

        def log_ratio(x):
            return p.log_prob(x) - q.log_prob(x)

        kl_exact = kullback_leibler.kl_divergence(p, q)
        kl_approx = monte_carlo_lib.expectation(
            f=log_ratio,
            samples=p.sample(num_draws, seed=42),
            log_prob=p.log_prob,
            use_reparametrization=(
                p.reparameterization_type == fully_reparameterized))

        kl_exact_, kl_approx_ = sess.run([kl_exact, kl_approx])

        self.assertEqual(
            True, p.reparameterization_type == fully_reparameterized)
        self.assertAllClose(kl_exact_, kl_approx_, rtol=0.01, atol=0.)

        # Compare gradients. (Not present in `docstring`.)
        def grad_wrt_mu_p(target):
            return gradients_impl.gradients(target, mu_p)[0]

        def grad_wrt_mu_q(target):
            return gradients_impl.gradients(target, mu_q)[0]

        (grad_p_exact_, grad_q_exact_,
         grad_p_approx_, grad_q_approx_) = sess.run([
             grad_wrt_mu_p(kl_exact),
             grad_wrt_mu_q(kl_exact),
             grad_wrt_mu_p(kl_approx),
             grad_wrt_mu_q(kl_approx),
         ])

        self.assertAllClose(
            grad_p_exact_, grad_p_approx_, rtol=0.01, atol=0.)
        self.assertAllClose(
            grad_q_exact_, grad_q_approx_, rtol=0.01, atol=0.)
def test_returns_n_samples_if_n_provided(self):
    """With `z` None and `n` = 10, the result has static shape (10,)."""
    with self.test_session():
        standard_normal = normal_lib.Normal(loc=0., scale=1.)
        result = _get_samples(standard_normal, None, 10, None)
        self.assertEqual((10,), result.get_shape())
def __init__(self, mean, variance, targets=None, seed=None):
    """Build the loss from a Normal parameterized by mean and variance.

    Args:
      mean: Tensor whose `shape` has length 2 (asserted).
      variance: Tensor whose `shape` has length 2 (asserted). Its square
        root is used as the scale of the underlying Normal.
      targets: Forwarded unchanged to the parent constructor.
      seed: Forwarded unchanged to the parent constructor.
    """
    assert len(mean.shape) == 2, "Expect 2D mean tensor."
    assert len(variance.shape) == 2, "Expect 2D variance tensor."
    self._mean, self._variance = mean, variance
    # The Normal is parameterized by standard deviation, not variance.
    self._scale = math_ops.sqrt(variance)
    super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(
        normal.Normal(loc=self._mean, scale=self._scale),
        targets=targets,
        seed=seed)
def testNormalNormalKL(self):
    """KL(Normal_a || Normal_b) matches the closed form computed in NumPy."""
    batch_size = 6
    loc_a = np.array([3.0] * batch_size)
    scale_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
    loc_b = np.array([-3.0] * batch_size)
    scale_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

    dist_a = normal_lib.Normal(loc=loc_a, scale=scale_a)
    dist_b = normal_lib.Normal(loc=loc_b, scale=scale_b)

    kl = kullback_leibler.kl_divergence(dist_a, dist_b)
    kl_val = self.evaluate(kl)

    # Reference value computed directly from the parameters.
    kl_expected = ((loc_a - loc_b)**2 / (2 * scale_b**2) + 0.5 * (
        (scale_a**2 / scale_b**2) - 1 - 2 * np.log(scale_a / scale_b)))

    self.assertEqual(kl.get_shape(), (batch_size, ))
    self.assertAllClose(kl_val, kl_expected)
def testKLIdentity(self):
    """KL is unchanged by `Independent` with `reinterpreted_batch_ndims=0`."""
    base_p = normal_lib.Normal(
        loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5]))
    # This is functionally just a wrapper around base_p,
    # and doesn't change any outputs.
    wrapped_p = independent_lib.Independent(
        distribution=base_p, reinterpreted_batch_ndims=0)

    base_q = normal_lib.Normal(
        loc=np.float32([-3., 3]), scale=np.float32([0.3, 0.3]))
    # This is functionally just a wrapper around base_q,
    # and doesn't change any outputs.
    wrapped_q = independent_lib.Independent(
        distribution=base_q, reinterpreted_batch_ndims=0)

    base_kl = kullback_leibler.kl_divergence(base_p, base_q)
    wrapped_kl = kullback_leibler.kl_divergence(wrapped_p, wrapped_q)
    self.assertAllClose(
        self.evaluate(base_kl), self.evaluate(wrapped_kl))
def testNormalVariance(self):
    """`variance()` is scale squared, broadcast against `loc`."""
    # The single scale value will be broadcast to [7, 7, 7].
    dist = normal_lib.Normal(loc=[1., 2., 3.], scale=[7.])
    variance = dist.variance()

    self.assertAllEqual((3, ), variance.get_shape())
    self.assertAllEqual([49., 49, 49], self.evaluate(dist.variance()))
def testNormalStandardDeviation(self):
    """`stddev()` returns the scale, broadcast against `loc`."""
    # The single scale value will be broadcast to [7, 7, 7].
    dist = normal_lib.Normal(loc=[1., 2., 3.], scale=[7.])
    stddev = dist.stddev()

    self.assertAllEqual((3, ), stddev.get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(dist.stddev()))
def testNormalShape(self):
    """Batch shape follows `loc` ([5]); event shape is scalar."""
    mu = constant_op.constant([-3.0] * 5)
    sigma = constant_op.constant(11.0)
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    # The evaluated batch shape is an ndarray, so compare element-wise.
    # Plain assertEqual on an ndarray only works while the array has a
    # single element (its truth value is otherwise ambiguous).
    self.assertAllEqual(self.evaluate(normal.batch_shape_tensor()), [5])
    self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
    self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
    self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))
def testKLScalarToMultivariate(self):
    """KL of `Independent` Normals equals the summed per-component KL."""
    base_p = normal_lib.Normal(
        loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5]))
    joint_p = independent_lib.Independent(
        distribution=base_p, reinterpreted_batch_ndims=1)

    base_q = normal_lib.Normal(
        loc=np.float32([-3., 3]), scale=np.float32([0.3, 0.3]))
    joint_q = independent_lib.Independent(
        distribution=base_q, reinterpreted_batch_ndims=1)

    componentwise_kl = kullback_leibler.kl_divergence(base_p, base_q)
    joint_kl = kullback_leibler.kl_divergence(joint_p, joint_q)

    # Summing over the last axis collapses the reinterpreted dimension.
    summed_kl = math_ops.reduce_sum(componentwise_kl, axis=-1)
    self.assertAllClose(
        self.evaluate(summed_kl), self.evaluate(joint_kl))
def normal_conjugates_known_scale_posterior(prior, scale, s, n):
    """Posterior Normal distribution with conjugate prior on the mean.

    The model assumes `n` observations (with sum `s`) drawn from a Normal
    with unknown mean `loc` (described by the Normal `prior`) and known
    variance `scale**2`. The "known scale posterior" is the distribution of
    the unknown `loc`.

    Given a prior Normal with parameters `loc0` and `scale0`, the known
    `scale` of the observation distribution, and the sufficient statistics
    `s` (sum of observations) and `n` (number of observations), the
    returned posterior is Normal with parameters `(loc', scale'**2)`:

    ```
    mu ~ N(mu', sigma'**2)
    sigma'**2 = 1/(1/sigma0**2 + n/sigma**2),
    mu' = (mu0/sigma0**2 + s/sigma**2) * sigma'**2.
    ```

    Distribution parameters from `prior`, as well as `scale`, `s`, and `n`,
    will broadcast in the case of multidimensional sets of parameters.

    Args:
      prior: `Normal` object of type `dtype`:
        the prior distribution having parameters `(loc0, scale0)`.
      scale: tensor of type `dtype`, taking values `scale > 0`.
        The known stddev parameter(s).
      s: Tensor of type `dtype`. The sum(s) of observations.
      n: Tensor of type `int`. The number(s) of observations.

    Returns:
      A new Normal posterior distribution object for the unknown observation
      mean `loc`.

    Raises:
      TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
        Normal object.
    """
    if not isinstance(prior, normal.Normal):
        raise TypeError("Expected prior to be an instance of type Normal")

    if s.dtype != prior.dtype:
        raise TypeError(
            "Observation sum s.dtype does not match prior dtype: %s vs. %s"
            % (s.dtype, prior.dtype))

    # Cast the count so it can be combined with the floating-point terms.
    n = math_ops.cast(n, prior.dtype)
    prior_variance = math_ops.square(prior.scale)
    observation_variance = math_ops.square(scale)

    # Precisions add; the posterior variance is the reciprocal of the sum.
    posterior_precision = 1 / prior_variance + n / observation_variance
    posterior_variance = 1.0 / posterior_precision

    posterior_loc = (
        prior.loc / prior_variance + s / observation_variance
    ) * posterior_variance
    posterior_scale = math_ops.sqrt(posterior_variance)
    return normal.Normal(loc=posterior_loc, scale=posterior_scale)
def testSampleAndLogProbUnivariateShapes(self):
    """Samples and log-probs of a scalar mixture keep the sample shape."""
    with self.test_session():
        mixing = categorical_lib.Categorical(probs=[0.3, 0.7])
        components = normal_lib.Normal(loc=[-1., 1], scale=[0.1, 0.5])
        mixture = mixture_same_family_lib.MixtureSameFamily(
            mixture_distribution=mixing,
            components_distribution=components)

        draws = mixture.sample([4, 5], seed=42)
        draws_log_prob = mixture.log_prob(draws)

        self.assertEqual([4, 5], draws.shape)
        self.assertEqual([4, 5], draws_log_prob.shape)
def _fn(dtype, shape, name, trainable, add_variable_fn):
    """Creates multivariate `Deterministic` or `Normal` distribution.

    All arguments are forwarded to `loc_scale_fn_`, which yields a
    `(loc, scale)` pair. A `None` scale selects `Deterministic`; otherwise
    a `Normal` is built. The result is wrapped in `Independent`.
    """
    loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn)
    base = (
        deterministic_lib.Deterministic(loc=loc)
        if scale is None
        else normal_lib.Normal(loc=loc, scale=scale))
    # The length of the batch-shape vector is the number of batch dims, so
    # every batch dimension is reinterpreted as an event dimension.
    num_batch_dims = array_ops.shape(base.batch_shape_tensor())[0]
    return independent_lib.Independent(
        base, reinterpreted_batch_ndims=num_batch_dims)
def testExplicitVariationalAndPrior(self):
    """`vi.elbo` with an explicit prior equals log-likelihood minus KL."""
    with self.test_session() as sess:
        _, _, variational, _, log_likelihood = mini_vae()
        explicit_prior = normal.Normal(loc=3., scale=2.)

        actual = vi.elbo(
            log_likelihood,
            variational_with_prior={variational: explicit_prior})
        kl_term = kullback_leibler.kl_divergence(
            variational.distribution, explicit_prior)
        expected = log_likelihood - kl_term

        sess.run(variables.global_variables_initializer())
        # Fetch both in one run call so they are computed together.
        expected_val, actual_val = sess.run([expected, actual])
        self.assertAllEqual(expected_val, actual_val)
def testConstructionAndValue(self):
    """StochasticTensors default to `SampleValue` and register themselves.

    Covers three ways of selecting the value type (default, explicit
    `dist_value_type`, and the `st.value_type` context manager), checks the
    collection contents, and smoke-tests that sampled values have matching
    shapes but different contents.
    """
    with self.test_session() as sess:
        loc = [0.0, 0.1, 0.2]
        prior_scale = constant_op.constant([1.1, 1.2, 1.3])
        likelihood_scale = constant_op.constant([0.1, 0.2, 0.3])

        prior_default = st.StochasticTensor(
            normal.Normal(loc=loc, scale=prior_scale))
        self.assertTrue(
            isinstance(prior_default.value_type, st.SampleValue))

        prior_0 = st.StochasticTensor(
            normal.Normal(loc=loc, scale=prior_scale),
            dist_value_type=st.SampleValue())
        self.assertTrue(isinstance(prior_0.value_type, st.SampleValue))

        with st.value_type(st.SampleValue()):
            prior = st.StochasticTensor(
                normal.Normal(loc=loc, scale=prior_scale))
            self.assertTrue(isinstance(prior.value_type, st.SampleValue))

            likelihood = st.StochasticTensor(
                normal.Normal(loc=prior, scale=likelihood_scale))
            self.assertTrue(
                isinstance(likelihood.value_type, st.SampleValue))

        registered = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
        self.assertEqual(
            registered, [prior_default, prior_0, prior, likelihood])

        # Also works: tf.convert_to_tensor(prior)
        prior_default_t = array_ops.identity(prior_default)
        prior_0_t = array_ops.identity(prior_0)
        prior_t = array_ops.identity(prior)
        likelihood_t = array_ops.identity(likelihood)

        # Mostly a smoke test for now...
        prior_0_val, prior_val, prior_default_val, _ = sess.run(
            [prior_0_t, prior_t, prior_default_t, likelihood_t])

        self.assertEqual(prior_0_val.shape, prior_val.shape)
        self.assertEqual(prior_default_val.shape, prior_val.shape)
        # These are different random samples from the same distribution,
        # so the values should differ.
        self.assertGreater(np.abs(prior_0_val - prior_val).sum(), 1e-6)
        self.assertGreater(
            np.abs(prior_default_val - prior_val).sum(), 1e-6)
def testDenseLocalReparameterization(self):
    """Compare `DenseLocalReparameterization` with a hand-built reference.

    The reference output samples a Normal whose loc is `inputs @ loc` and
    whose scale is `sqrt(inputs**2 @ scale**2)`, then adds the bias sample.
    Outputs, bias, both divergence values, and the arguments recorded by
    the divergence mocks are all checked.
    """
    batch_size, in_size, out_size = 2, 3, 4
    with self.test_session() as sess:
        (kernel_posterior, kernel_prior, kernel_divergence,
         bias_posterior, bias_prior, bias_divergence,
         layer, inputs, outputs, kl_penalty) = self._testDenseSetUp(
             prob_layers_lib.DenseLocalReparameterization,
             batch_size, in_size, out_size)

        affine_loc = math_ops.matmul(inputs, kernel_posterior.result_loc)
        affine_scale = math_ops.matmul(
            inputs**2., kernel_posterior.result_scale**2)**0.5
        affine_dist = normal_lib.Normal(loc=affine_loc, scale=affine_scale)
        affine_sample = affine_dist.sample(seed=42)
        expected_outputs = affine_sample + bias_posterior.result_sample

        fetches = [
            expected_outputs, outputs,
            kernel_divergence.result, kl_penalty[0],
            bias_posterior.result_sample, layer.bias_posterior_tensor,
            bias_divergence.result, kl_penalty[1],
        ]
        (expected_outputs_, actual_outputs_,
         expected_kernel_divergence_, actual_kernel_divergence_,
         expected_bias_, actual_bias_,
         expected_bias_divergence_, actual_bias_divergence_) = sess.run(
             fetches)

        self.assertAllClose(
            expected_bias_, actual_bias_, rtol=1e-6, atol=0.)
        self.assertAllClose(
            expected_outputs_, actual_outputs_, rtol=1e-6, atol=0.)
        self.assertAllClose(
            expected_kernel_divergence_, actual_kernel_divergence_,
            rtol=1e-6, atol=0.)
        self.assertAllClose(
            expected_bias_divergence_, actual_bias_divergence_,
            rtol=1e-6, atol=0.)

        # The kernel divergence is recorded with `None` as its third
        # argument, the bias divergence with the bias sample.
        self.assertAllEqual(
            [[kernel_posterior.distribution,
              kernel_prior.distribution,
              None]],
            kernel_divergence.args)
        self.assertAllEqual(
            [[bias_posterior.distribution,
              bias_prior.distribution,
              bias_posterior.result_sample]],
            bias_divergence.args)
def testNormalFullyReparameterized(self):
    """Gradients of samples w.r.t. both `loc` and `scale` are defined."""
    loc = constant_op.constant(4.0)
    scale = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
        # Constants are not tracked automatically, so watch them explicitly.
        tape.watch(loc)
        tape.watch(scale)
        dist = normal_lib.Normal(loc=loc, scale=scale)
        draws = dist.sample(100)
    loc_grad, scale_grad = tape.gradient(draws, [loc, scale])
    self.assertIsNotNone(loc_grad)
    self.assertIsNotNone(scale_grad)
def _testParamShapes(self, sample_shape, expected):
    """Check `Normal.param_shapes` and the resulting sample shape.

    Args:
      sample_shape: Passed to `normal_lib.Normal.param_shapes`.
      expected: Shape that both parameter shapes, and a sample drawn from
        a Normal built with them, must evaluate to.
    """
    shapes = normal_lib.Normal.param_shapes(sample_shape)
    loc_shape = shapes["loc"]
    scale_shape = shapes["scale"]
    self.assertAllEqual(expected, self.evaluate(loc_shape))
    self.assertAllEqual(expected, self.evaluate(scale_shape))

    # Build a distribution from parameters of those shapes and sample it.
    loc = array_ops.zeros(loc_shape)
    scale = array_ops.ones(scale_shape)
    draw_shape = array_ops.shape(normal_lib.Normal(loc, scale).sample())
    self.assertAllEqual(expected, self.evaluate(draw_shape))
def testConstructionAndValue(self):
    """An ObservedStochasticTensor yields its observation and is registered."""
    with self.test_session() as sess:
        loc = [0.0, 0.1, 0.2]
        scale = constant_op.constant([1.1, 1.2, 1.3])
        observation = array_ops.zeros((2, 3))
        observed = st.ObservedStochasticTensor(
            normal.Normal(loc=loc, scale=scale), value=observation)

        observation_val, observed_val = sess.run(
            [observation, observed.value()])
        self.assertAllEqual(observation_val, observed_val)

        registered = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
        self.assertEqual(registered, [observed])
def testNormalMeanAndMode(self):
    """`mean()` and `mode()` both return `loc`, broadcast against `scale`."""
    # The single loc value will be broadcast to [7, 7, 7].
    dist = normal_lib.Normal(loc=[7.], scale=[11., 12., 13.])

    self.assertAllEqual((3, ), dist.mean().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(dist.mean()))

    self.assertAllEqual((3, ), dist.mode().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(dist.mode()))
def test_normal_integral_mean_and_var_correctly_estimated(self):
    """Importance sampling with Halton points recovers mean and stddev."""
    n = 1000
    # This test is almost identical to the similarly named test in
    # monte_carlo_test.py. The only difference is that we use the Halton
    # samples instead of the random samples to evaluate the expectations.
    # MC with pseudo random numbers converges at the rate of 1/ Sqrt(N)
    # (N=number of samples). For QMC in low dimensions, the expected
    # convergence rate is ~ 1/N. Hence we should only need 1e3 samples as
    # compared to the 1e6 samples used in the pseudo-random monte carlo.
    with self.test_session():
        float64 = dtypes.float64
        mu_p = array_ops.constant([-1.0, 1.0], dtype=float64)
        mu_q = array_ops.constant([0.0, 0.0], dtype=float64)
        sigma_p = array_ops.constant([0.5, 0.5], dtype=float64)
        sigma_q = array_ops.constant([1.0, 1.0], dtype=float64)
        p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
        q = normal_lib.Normal(loc=mu_q, scale=sigma_q)

        # Map the Halton points through q's quantile function to obtain
        # the points at which the integrands are evaluated.
        uniform_points = halton.sample(2, num_samples=n, dtype=float64)
        q_points = q.quantile(uniform_points)

        # Compute E_p[X].
        first_moment = mc.expectation_importance_sampler(
            f=lambda x: x,
            log_p=p.log_prob,
            sampling_dist_q=q,
            z=q_points,
            seed=42)

        # Compute E_p[X^2].
        second_moment = mc.expectation_importance_sampler(
            f=math_ops.square,
            log_p=p.log_prob,
            sampling_dist_q=q,
            z=q_points,
            seed=42)

        estimated_stddev = math_ops.sqrt(
            second_moment - math_ops.square(first_moment))

        # Keep the tolerance levels the same as in monte_carlo_test.py.
        self.assertEqual(p.batch_shape, first_moment.get_shape())
        self.assertAllClose(
            p.mean().eval(), first_moment.eval(), rtol=0.01)
        self.assertAllClose(
            p.stddev().eval(), estimated_stddev.eval(), rtol=0.02)
def testKLRaises(self):
    """KL between `Independent`s with mismatched event shapes raises."""

    def make_vector_event_normal():
        # Normal with batch shape [2], reinterpreted as one vector event.
        return independent_lib.Independent(
            distribution=normal_lib.Normal(
                loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
            reinterpreted_batch_ndims=1)

    # Case 1: vector event versus scalar event.
    ind1 = make_vector_event_normal()
    ind2 = independent_lib.Independent(
        distribution=normal_lib.Normal(
            loc=np.float32(-1), scale=np.float32(0.5)),
        reinterpreted_batch_ndims=0)
    with self.assertRaisesRegexp(ValueError, "Event shapes do not match"):
        kullback_leibler.kl_divergence(ind1, ind2)

    # Case 2: the second operand wraps a MultivariateNormalDiag.
    ind1 = make_vector_event_normal()
    ind2 = independent_lib.Independent(
        distribution=mvn_diag_lib.MultivariateNormalDiag(
            loc=np.float32([-1., 1]), scale_diag=np.float32([0.1, 0.5])),
        reinterpreted_batch_ndims=0)
    with self.assertRaisesRegexp(
        NotImplementedError, "different event shapes"):
        kullback_leibler.kl_divergence(ind1, ind2)
def test_works_correctly(self):
    """`mc.expectation` of the identity recovers `loc` and unit gradients.

    Both the reparameterized and the score-function estimators are checked
    for value, for finite gradients, and for gradient accuracy (the
    score-function gradient only on the entries that converge).
    """
    with self.cached_session() as sess:
        x = constant_op.constant([-1e6, -100, -10, -1, 1, 10, 100, 1e6])
        p = normal_lib.Normal(loc=x, scale=1.)

        # We use the prefix "efx" to mean "E_p[f(X)]".
        f = lambda u: u
        efx_true = x
        samples = p.sample(int(1e5), seed=1)
        efx_reparam = mc.expectation(f, samples, p.log_prob)
        efx_score = mc.expectation(
            f, samples, p.log_prob, use_reparametrization=False)

        [
            efx_true_,
            efx_reparam_,
            efx_score_,
            efx_true_grad_,
            efx_reparam_grad_,
            efx_score_grad_,
        ] = sess.run([
            efx_true,
            efx_reparam,
            efx_score,
            gradients_impl.gradients(efx_true, x)[0],
            gradients_impl.gradients(efx_reparam, x)[0],
            gradients_impl.gradients(efx_score, x)[0],
        ])

        self.assertAllEqual(np.ones_like(efx_true_grad_), efx_true_grad_)

        self.assertAllClose(efx_true_, efx_reparam_, rtol=0.005, atol=0.)
        self.assertAllClose(efx_true_, efx_score_, rtol=0.005, atol=0.)

        # Use the builtin `bool`: the `np.bool` alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        all_true = np.ones_like(efx_true_grad_, dtype=bool)
        self.assertAllEqual(all_true, np.isfinite(efx_reparam_grad_))
        self.assertAllEqual(all_true, np.isfinite(efx_score_grad_))

        self.assertAllClose(
            efx_true_grad_, efx_reparam_grad_, rtol=0.03, atol=0.)
        # Variance is too high to be meaningful, so we'll only check those
        # which converge.
        self.assertAllClose(
            efx_true_grad_[2:-2], efx_score_grad_[2:-2],
            rtol=0.05, atol=0.)
def testNormalShapeWithPlaceholders(self):
    """Static batch shape is unknown; the dynamic one comes from the feed."""
    loc_ph = array_ops.placeholder(dtype=dtypes.float32)
    scale_ph = array_ops.placeholder(dtype=dtypes.float32)
    dist = normal_lib.Normal(loc=loc_ph, scale=scale_ph)

    with self.test_session() as sess:
        # get_batch_shape should return an "<unknown>" tensor.
        self.assertEqual(dist.batch_shape, tensor_shape.TensorShape(None))
        self.assertEqual(dist.event_shape, ())
        self.assertAllEqual(dist.event_shape_tensor().eval(), [])

        # A scalar loc fed with a length-2 scale gives batch shape [2].
        feed = {loc_ph: 5.0, scale_ph: [1.0, 2.0]}
        dynamic_batch_shape = sess.run(
            dist.batch_shape_tensor(), feed_dict=feed)
        self.assertAllEqual(dynamic_batch_shape, [2])
def testSurrogateLoss(self):
    """Surrogate loss with the default and with a baseline score function."""
    with self.test_session():
        mu = [[3.0, -4.0, 5.0], [6.0, -7.0, 8.0]]
        sigma = constant_op.constant(1.0)

        # With default
        with st.value_type(st.MeanValue(stop_gradient=True)):
            default_dt = st.StochasticTensor(
                normal.Normal(loc=mu, scale=sigma))
        default_loss = default_dt.loss([constant_op.constant(2.0)])
        self.assertTrue(default_loss is not None)
        expected_default = default_dt.distribution.log_prob(mu).eval() * 2.0
        self.assertAllClose(expected_default, default_loss.eval())

        # With passed-in loss_fn.
        baseline_fn = sge.get_score_function_with_constant_baseline(
            baseline=constant_op.constant(8.0))
        baseline_dt = st.StochasticTensor(
            normal.Normal(loc=mu, scale=sigma),
            dist_value_type=st.MeanValue(stop_gradient=True),
            loss_fn=baseline_fn)
        baseline_loss = baseline_dt.loss([constant_op.constant(2.0)])
        self.assertTrue(baseline_loss is not None)
        # The constant baseline (8.0) is subtracted from the sample loss.
        expected_baseline = (
            baseline_dt.distribution.log_prob(mu) * (2.0 - 8.0)).eval()
        self.assertAllClose(expected_baseline, baseline_loss.eval())
def testSampleValueScalar(self):
    """`SampleValue(k)` prepends a sample dimension of size `k`.

    With no argument the value keeps the distribution's (2, 3) shape; with
    1 and 2 the value shapes are (1, 2, 3) and (2, 2, 3) respectively.
    """
    with self.test_session() as sess:
        mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]]
        sigma = constant_op.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]])

        # Default SampleValue(): no leading sample dimension.
        with st.value_type(st.SampleValue()):
            unshaped = st.StochasticTensor(
                normal.Normal(loc=mu, scale=sigma))

        unshaped_value = unshaped.value()
        self.assertEqual(unshaped_value.get_shape(), (2, 3))

        unshaped_value_val = sess.run([unshaped_value])[0]
        self.assertEqual(unshaped_value_val.shape, (2, 3))

        # SampleValue(1): a leading dimension of size one.
        with st.value_type(st.SampleValue(1)):
            single = st.StochasticTensor(
                normal.Normal(loc=mu, scale=sigma))
            self.assertTrue(isinstance(single.value_type, st.SampleValue))

        single_value = single.value()
        self.assertEqual(single_value.get_shape(), (1, 2, 3))

        single_value_val = sess.run([single_value])[0]
        self.assertEqual(single_value_val.shape, (1, 2, 3))

        # SampleValue(2): a leading dimension of size two.
        with st.value_type(st.SampleValue(2)):
            double = st.StochasticTensor(
                normal.Normal(loc=mu, scale=sigma))

        double_value = double.value()
        self.assertEqual(double_value.get_shape(), (2, 2, 3))

        double_value_val = sess.run([double_value])[0]
        self.assertEqual(double_value_val.shape, (2, 2, 3))
def setUp(self):
    """Set the fixtures for the mutual information penalty loss.

    NOTE(review): the `_expected_*`, `_kwargs` and `_batch_size` attributes
    are presumably consumed by shared test logic defined outside this
    block -- confirm against the base class.
    """
    super(MutualInformationPenaltyTest, self).setUp()
    self._penalty_fn = tfgan_losses.mutual_information_penalty

    # Inputs: one categorical and one normal predicted distribution.
    self._structured_generator_inputs = [1.0, 2.0]
    self._predicted_distributions = [
        categorical.Categorical(logits=[1.0, 2.0]),
        normal.Normal([0.0], [1.0]),
    ]
    self._kwargs = {
        'structured_generator_inputs': self._structured_generator_inputs,
        'predicted_distributions': self._predicted_distributions,
    }

    # Expected properties of the resulting loss.
    self._expected_dtype = dtypes.float32
    self._expected_loss = 1.61610
    self._expected_op_name = 'mutual_information_loss/mul'
    self._batch_size = 2
def testBroadcastWithBatchParamsAndBiggerEvent(self):
    """Categorical broadcasts events against parameters like Normal does.

    For prob, log_prob, cdf and log_cdf, the shape of the Categorical
    result must equal the shape of the Normal result when the parameters
    have one batch dimension and the event has two.
    """
    ## The parameters have a single batch dimension, and the event has two.

    # param shape is [3 x 4], where 4 is the number of bins (non-batch dim).
    cat_params_py = [[0.2, 0.15, 0.35, 0.3],
                     [0.1, 0.05, 0.68, 0.17],
                     [0.1, 0.05, 0.68, 0.17]]

    # event shape = [5, 3], both are "batch" dimensions.
    disc_event_py = [[0, 1, 2],
                     [1, 2, 3],
                     [0, 0, 0],
                     [1, 1, 1],
                     [2, 1, 0]]

    # shape is [3]
    normal_params_py = [-10.0, 120.0, 50.0]

    # shape is [5, 3]
    real_event_py = [[-1.0, 0.0, 1.0],
                     [100.0, 101, -50],
                     [90, 90, 90],
                     [-4, -400, 20.0],
                     [0.0, 0.0, 0.0]]

    cat_params_tf = array_ops.constant(cat_params_py)
    disc_event_tf = array_ops.constant(disc_event_py)
    cat = categorical.Categorical(probs=cat_params_tf)

    normal_params_tf = array_ops.constant(normal_params_py)
    real_event_tf = array_ops.constant(real_event_py)
    norm = normal.Normal(loc=normal_params_tf, scale=1.0)

    # Check that normal and categorical have the same broadcasting behaviour.
    to_run = {
        "cat_prob": cat.prob(disc_event_tf),
        "cat_log_prob": cat.log_prob(disc_event_tf),
        "cat_cdf": cat.cdf(disc_event_tf),
        "cat_log_cdf": cat.log_cdf(disc_event_tf),
        "norm_prob": norm.prob(real_event_tf),
        "norm_log_prob": norm.log_prob(real_event_tf),
        "norm_cdf": norm.cdf(real_event_tf),
        "norm_log_cdf": norm.log_cdf(real_event_tf),
    }

    # The session object itself is never used (evaluation goes through
    # `self.evaluate`), so it is not bound to a name.
    with self.cached_session():
        run_result = self.evaluate(to_run)

    for quantity in ("prob", "log_prob", "cdf", "log_cdf"):
        self.assertAllEqual(run_result["cat_" + quantity].shape,
                            run_result["norm_" + quantity].shape)