def testSampleConsistentLogProb(self):
    """Checks mixture samples against log_prob inside balls around each mean.

    Builds a two-component mixture of 2-D diagonal normals with weights
    [0.3, 0.7]. It then calls `run_test_sample_consistent_log_prob` (defined
    elsewhere, presumably on a test mixin — TODO confirm its exact semantics)
    once per component mean.
    """
    with self.test_session() as sess:
        gm = mixture_same_family_lib.MixtureSameFamily(
            mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
            components_distribution=mvn_diag_lib.MultivariateNormalDiag(
                loc=[[-1., 1], [1, -1]],
                scale_identity_multiplier=[1., 0.5]))
        # Balls of the same (unit) radius, first centered at component0's
        # mean and then at component1's mean.
        for center in ([-1., 1], [1., -1]):
            self.run_test_sample_consistent_log_prob(
                sess, gm, radius=1., center=center, rtol=0.02)
def testSampleAndLogProbBatchMultivariateShapes(self):
    """Checks sample/log_prob shapes for a batched mixture of 2-D normals.

    The component `loc` has shape [2, 2, 2]. Drawing a [4, 5] sample is
    expected to yield shape [4, 5, 2, 2] (sample dims, one batch dim of 2,
    event size 2). log_prob of that sample is expected to yield [4, 5, 2].
    """
    with self.test_session():
        mixture_weights = categorical_lib.Categorical(probs=[0.3, 0.7])
        components = mvn_diag_lib.MultivariateNormalDiag(
            loc=[[[-1., 1], [1, -1]],
                 [[0., 1], [1, 0]]],
            scale_identity_multiplier=[1., 0.5])
        mixture = mixture_same_family_lib.MixtureSameFamily(
            mixture_distribution=mixture_weights,
            components_distribution=components)

        samples = mixture.sample([4, 5], seed=42)
        sample_log_probs = mixture.log_prob(samples)

        # Only static shapes are checked; nothing is evaluated in the session.
        self.assertEqual([4, 5, 2, 2], samples.shape)
        self.assertEqual([4, 5, 2], sample_log_probs.shape)
def testSampleAndLogProbShapesBroadcastMix(self):
    """Checks shapes and sample values for a Bernoulli mixture.

    `mix_probs` has shape [2] (no batch dims), while the Bernoulli component
    probs have shape [2, 2]. The assertions expect sample and log_prob shapes
    of [4, 5, 2] for a [4, 5] draw, and every sampled value to be 0 or 1.
    """
    mix_probs = np.float32([.3, .7])
    bern_probs = np.float32([[.4, .6], [.25, .75]])
    with self.test_session():
        bm = mixture_same_family_lib.MixtureSameFamily(
            mixture_distribution=categorical_lib.Categorical(probs=mix_probs),
            components_distribution=bernoulli_lib.Bernoulli(probs=bern_probs))
        x = bm.sample([4, 5], seed=42)
        log_prob_x = bm.log_prob(x)
        x_ = x.eval()
        self.assertEqual([4, 5, 2], x.shape)
        self.assertEqual([4, 5, 2], log_prob_x.shape)
        # Every sample must be a valid Bernoulli outcome. Use the builtin
        # `bool`: the `np.bool` alias was deprecated in NumPy 1.20 and removed
        # in NumPy 1.24, where it raises AttributeError.
        self.assertAllEqual(
            np.ones_like(x_, dtype=bool),
            np.logical_or(x_ == 0., x_ == 1.))
def test_pad_mixture_dimensions_mixture_same_family(self):
    """Checks pad_mixture_dimensions on a MixtureSameFamily distribution.

    A [2, 2] input is expected to come back with shape [2, 2, 1] and the
    same values in the same (flattened) order.
    """
    with self.test_session() as sess:
        weights = categorical.Categorical(probs=[0.3, 0.7])
        normals = mvn_diag.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5])
        gm = mixture_same_family.MixtureSameFamily(
            mixture_distribution=weights, components_distribution=normals)

        values = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
        event_ndims = gm.event_shape.ndims
        padded = distribution_util.pad_mixture_dimensions(
            values, gm, gm.mixture_distribution, event_ndims)
        values_out, padded_out = sess.run([values, padded])

    # Host-side checks on the fetched numpy arrays.
    self.assertAllEqual(padded_out.shape, [2, 2, 1])
    self.assertAllEqual(values_out.reshape([-1]), padded_out.reshape([-1]))
def testLogCdf(self):
    """Checks MixtureSameFamily.log_cdf against a direct logsumexp formula.

    For mixture weights w_k and component CDFs F_k, the mixture CDF is
    sum_k w_k * F_k(x), so log_cdf(x) = logsumexp_k(log w_k + log F_k(x)).
    """
    with self.test_session() as sess:
        gm = mixture_same_family_lib.MixtureSameFamily(
            mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
            components_distribution=normal_lib.Normal(
                loc=[-1., 1], scale=[0.1, 0.5]))
        x = gm.sample(10, seed=42)
        actual_log_cdf = gm.log_cdf(x)
        # Normalize the logits into log mixture weights. Raw logits equal
        # log w_k only when they are already normalized (presumably the case
        # here, since the Categorical is built from probs summing to 1), so
        # normalizing keeps the expected value correct for any mixture.
        logits = gm.mixture_distribution.logits
        log_mix_weights = logits - math_ops.reduce_logsumexp(logits)
        # x[..., newaxis] lets each sample broadcast against the components;
        # axis=1 then reduces over the component dimension.
        expected_log_cdf = math_ops.reduce_logsumexp(
            (log_mix_weights +
             gm.components_distribution.log_cdf(x[..., array_ops.newaxis])),
            axis=1)
        actual_log_cdf_, expected_log_cdf_ = sess.run([
            actual_log_cdf, expected_log_cdf])
        self.assertAllClose(actual_log_cdf_, expected_log_cdf_,
                            rtol=1e-6, atol=0.0)