def testStateParts(self):
  """Checks HMC sample moments when the chain state is a list of two parts.

  The target density is the product of a scalar standard `Normal` over `x` and
  a two-component `Gamma` (wrapped in `Independent` with
  `reinterpreted_batch_ndims=1`) over `y`. The chain is started from one draw
  of each part. Per-part sample means and variances of the chain are then
  compared against the analytic `mean()`/`variance()` of each part.
  """
  # `self.session(graph=...)` is the non-deprecated replacement for
  # `self.test_session(graph=...)`: a fresh session bound to a fresh graph.
  with self.session(graph=ops.Graph()) as sess:
    dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
    dist_y = independent_lib.Independent(
        gamma_lib.Gamma(concentration=self.dtype([1, 2]),
                        rate=self.dtype([0.5, 0.75])),
        reinterpreted_batch_ndims=1)

    def target_log_prob(x, y):
      # Joint log-density of the two independent state parts.
      return dist_x.log_prob(x) + dist_y.log_prob(y)

    x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)]

    samples, _ = hmc.sample_chain(
        num_results=2000,
        target_log_prob_fn=target_log_prob,
        current_state=x0,
        step_size=0.85,
        num_leapfrog_steps=3,
        num_burnin_steps=250,
        seed=49)

    # Moments are reduced over axis 0 of each state part's samples.
    actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples]
    actual_vars = [_reduce_variance(s, axis=0) for s in samples]
    expected_means = [dist_x.mean(), dist_y.mean()]
    expected_vars = [dist_x.variance(), dist_y.variance()]

    [
        actual_means_,
        actual_vars_,
        expected_means_,
        expected_vars_,
    ] = sess.run([
        actual_means,
        actual_vars,
        expected_means,
        expected_vars,
    ])

  # Loose tolerances: these are Monte Carlo estimates from a finite chain.
  self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16)
  self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25)
def testKLScalarToMultivariate(self):
  """KL of `Independent` wrappers equals the summed KL of the wrapped parts.

  Both `Normal`s have batch shape [2]. Wrapping them with
  `reinterpreted_batch_ndims=1` turns that batch axis into an event axis, so
  the wrapped KL should equal the elementwise `Normal` KL summed over the last
  axis.
  """
  base_p = normal_lib.Normal(loc=np.float32([-1., 1]),
                             scale=np.float32([0.1, 0.5]))
  wrapped_p = independent_lib.Independent(distribution=base_p,
                                          reinterpreted_batch_ndims=1)
  base_q = normal_lib.Normal(loc=np.float32([-3., 3]),
                             scale=np.float32([0.3, 0.3]))
  wrapped_q = independent_lib.Independent(distribution=base_q,
                                          reinterpreted_batch_ndims=1)

  elementwise_kl = kullback_leibler.kl_divergence(base_p, base_q)
  wrapped_kl = kullback_leibler.kl_divergence(wrapped_p, wrapped_q)

  summed_kl_value = self.evaluate(math_ops.reduce_sum(elementwise_kl, axis=-1))
  wrapped_kl_value = self.evaluate(wrapped_kl)
  self.assertAllClose(summed_kl_value, wrapped_kl_value)
def testKLIdentity(self):
  """`Independent` with `reinterpreted_batch_ndims=0` leaves KL unchanged.

  With zero reinterpreted dimensions the wrapper is a pass-through, so the KL
  between the wrappers must equal the KL between the underlying `Normal`s.
  """
  base_p = normal_lib.Normal(loc=np.float32([-1., 1]),
                             scale=np.float32([0.1, 0.5]))
  # Pass-through wrapper: no batch dims are moved into the event shape.
  wrapped_p = independent_lib.Independent(distribution=base_p,
                                          reinterpreted_batch_ndims=0)
  base_q = normal_lib.Normal(loc=np.float32([-3., 3]),
                             scale=np.float32([0.3, 0.3]))
  # Pass-through wrapper: no batch dims are moved into the event shape.
  wrapped_q = independent_lib.Independent(distribution=base_q,
                                          reinterpreted_batch_ndims=0)

  base_kl = kullback_leibler.kl_divergence(base_p, base_q)
  wrapped_kl = kullback_leibler.kl_divergence(wrapped_p, wrapped_q)

  base_kl_value = self.evaluate(base_kl)
  wrapped_kl_value = self.evaluate(wrapped_kl)
  self.assertAllClose(base_kl_value, wrapped_kl_value)
def _fn(dtype, shape, name, trainable, add_variable_fn):
  """Creates multivariate `Deterministic` or `Normal` distribution.

  `loc_scale_fn_` comes from the enclosing scope (not visible here) and is
  called with this function's arguments to obtain `(loc, scale)`.

  Returns:
    An `Independent` wrapping either `Deterministic(loc)` (when `scale` is
    `None`) or `Normal(loc, scale)`. All of the base distribution's batch
    dimensions are reinterpreted as event dimensions.
  """
  loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn)
  # No scale means a point mass at `loc`; otherwise use a Normal.
  base = (deterministic_lib.Deterministic(loc=loc) if scale is None
          else normal_lib.Normal(loc=loc, scale=scale))
  # Rank of the batch shape, computed as a tensor: length of the 1-D
  # `batch_shape_tensor()`.
  batch_rank = array_ops.shape(base.batch_shape_tensor())[0]
  return independent_lib.Independent(
      base, reinterpreted_batch_ndims=batch_rank)
def testKLMultivariateToMultivariate(self):
  """KL of `Independent(MVNDiag)` equals MVN KL summed over two batch axes.

  Each `MultivariateNormalDiag` broadcasts to a (1, 1, 2) batch of
  3-dimensional MVNs. Wrapping with `reinterpreted_batch_ndims=2` folds the
  last two batch axes into the event, so the wrapped KL should equal the MVN
  KL reduced over axes [-1, -2].
  """
  # (1, 1, 2) batch of MVNDiag
  first_mvn = mvn_diag_lib.MultivariateNormalDiag(
      loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]),
      scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]]))
  first_wrapped = independent_lib.Independent(distribution=first_mvn,
                                              reinterpreted_batch_ndims=2)

  # (1, 1, 2) batch of MVNDiag
  second_mvn = mvn_diag_lib.MultivariateNormalDiag(
      loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]),
      scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]]))
  second_wrapped = independent_lib.Independent(distribution=second_mvn,
                                               reinterpreted_batch_ndims=2)

  mvn_kl = kullback_leibler.kl_divergence(first_mvn, second_mvn)
  wrapped_kl = kullback_leibler.kl_divergence(first_wrapped, second_wrapped)

  reduced_mvn_kl = self.evaluate(math_ops.reduce_sum(mvn_kl, axis=[-1, -2]))
  self.assertAllClose(reduced_mvn_kl, self.evaluate(wrapped_kl))
def testSampleConsistentStats(self):
  """Monte Carlo statistics of `Independent(MVNDiag)` match analytic ones.

  Draws 10000 samples and compares the sample mean, variance, standard
  deviation and entropy estimate (negative mean log-prob) against the
  distribution's `mean()`, `variance()`, `stddev()` and `entropy()`. It also
  checks that `mode()` equals `loc`.
  """
  loc = np.float32([[-1., 1], [1, -1]])
  scale = np.float32([1., 0.5])
  num_samples = 10000
  with self.cached_session() as sess:
    mvn = mvn_diag_lib.MultivariateNormalDiag(
        loc=loc, scale_identity_multiplier=scale)
    ind = independent_lib.Independent(distribution=mvn,
                                      reinterpreted_batch_ndims=1)
    draws = ind.sample(num_samples, seed=42)

    # Empirical statistics over the sample axis (axis 0).
    mean_hat = math_ops.reduce_mean(draws, axis=0)
    var_hat = math_ops.reduce_mean(
        math_ops.squared_difference(draws, mean_hat), axis=0)
    std_hat = math_ops.sqrt(var_hat)
    entropy_hat = -math_ops.reduce_mean(ind.log_prob(draws), axis=0)

    # Fetch everything in a single run so all statistics share one draw.
    results = sess.run({
        "sample_mean": mean_hat,
        "sample_var": var_hat,
        "sample_std": std_hat,
        "sample_entropy": entropy_hat,
        "actual_mean": ind.mean(),
        "actual_var": ind.variance(),
        "actual_std": ind.stddev(),
        "actual_entropy": ind.entropy(),
        "actual_mode": ind.mode(),
    })

  self.assertAllCloseAccordingToType(
      results["sample_mean"], results["actual_mean"], rtol=0.02)
  self.assertAllCloseAccordingToType(
      results["sample_var"], results["actual_var"], rtol=0.04)
  self.assertAllCloseAccordingToType(
      results["sample_std"], results["actual_std"], rtol=0.02)
  self.assertAllCloseAccordingToType(
      results["sample_entropy"], results["actual_entropy"], rtol=0.01)
  self.assertAllCloseAccordingToType(loc, results["actual_mode"], rtol=1e-6)
def testKLRaises(self):
  """`kl_divergence` rejects `Independent` pairs with mismatched events.

  Case 1: event shape [2] vs. scalar event -> `ValueError`.
  Case 2: an event built from a reinterpreted batch dim vs. one that is the
  base distribution's own event -> `NotImplementedError`.
  """

  def make_vector_normal():
    # Normal with batch shape [2], reinterpreted as an event of shape [2].
    return independent_lib.Independent(
        distribution=normal_lib.Normal(loc=np.float32([-1., 1]),
                                       scale=np.float32([0.1, 0.5])),
        reinterpreted_batch_ndims=1)

  scalar_normal = independent_lib.Independent(
      distribution=normal_lib.Normal(loc=np.float32(-1),
                                     scale=np.float32(0.5)),
      reinterpreted_batch_ndims=0)
  with self.assertRaisesRegexp(ValueError, "Event shapes do not match"):
    kullback_leibler.kl_divergence(make_vector_normal(), scalar_normal)

  vector_normal = make_vector_normal()
  wrapped_mvn = independent_lib.Independent(
      distribution=mvn_diag_lib.MultivariateNormalDiag(
          loc=np.float32([-1., 1]), scale_diag=np.float32([0.1, 0.5])),
      reinterpreted_batch_ndims=0)
  with self.assertRaisesRegexp(NotImplementedError, "different event shapes"):
    kullback_leibler.kl_divergence(vector_normal, wrapped_mvn)
def _testMnistLike(self, static_shape):
  """Shared body for MNIST-shaped `Independent(Bernoulli)` shape/log-prob tests.

  Builds a `Bernoulli` over logits of shape [10, 28, 28, 1] fed through a
  placeholder. The `Bernoulli` is wrapped in `Independent` with the default
  `reinterpreted_batch_ndims`; the assertions expect batch shape [10] and
  event shape [28, 28, 1]. The test then draws samples of shape [4, 5] and
  checks shapes and log-probabilities.

  Args:
    static_shape: Python `bool`. If `True`, the placeholder is given the full
      static shape of the logits, and the static `batch_shape`, `event_shape`
      and `.shape` properties are asserted. If `False`, the placeholder shape
      is unknown and the dynamic shape tensors fetched from the session are
      asserted instead.
  """
  sample_shape = [4, 5]
  batch_shape = [10]
  image_shape = [28, 28, 1]
  # `3 * u - 1` maps the [0, 1) draws into [-1, 2). This assumes `self._rng`
  # is a NumPy RandomState, which is set up outside this block.
  logits = 3 * self._rng.random_sample(batch_shape + image_shape).astype(
      np.float32) - 1

  def expected_log_prob(x, logits):
    # Bernoulli log-pmf in logit form, x * l - log(1 + e^l), summed over the
    # three trailing image axes.
    return (x * logits - np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1)

  # `cached_session()` replaces the deprecated `test_session()`; without a
  # graph argument the two are equivalent.
  with self.cached_session() as sess:
    logits_ph = array_ops.placeholder(
        dtypes.float32, shape=logits.shape if static_shape else None)
    ind = independent_lib.Independent(
        distribution=bernoulli_lib.Bernoulli(logits=logits_ph))
    x = ind.sample(sample_shape)
    log_prob_x = ind.log_prob(x)
    [
        x_,
        actual_log_prob_x,
        ind_batch_shape,
        ind_event_shape,
        x_shape,
        log_prob_x_shape,
    ] = sess.run([
        x,
        log_prob_x,
        ind.batch_shape_tensor(),
        ind.event_shape_tensor(),
        array_ops.shape(x),
        array_ops.shape(log_prob_x),
    ], feed_dict={logits_ph: logits})

    if static_shape:
      # Check the statically inferred shapes instead of the fetched ones.
      ind_batch_shape = ind.batch_shape
      ind_event_shape = ind.event_shape
      x_shape = x.shape
      log_prob_x_shape = log_prob_x.shape

    self.assertAllEqual(batch_shape, ind_batch_shape)
    self.assertAllEqual(image_shape, ind_event_shape)
    self.assertAllEqual(sample_shape + batch_shape + image_shape, x_shape)
    self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape)
    self.assertAllClose(expected_log_prob(x_, logits),
                        actual_log_prob_x,
                        rtol=1e-6, atol=0.)
def testSampleAndLogProbUnivariate(self):
  """`Independent(Normal)` with batch [2] -> event [2]: shapes and log-prob.

  The log-prob of each sample must equal the SciPy normal log-pdf summed over
  the reinterpreted axis.
  """
  loc = np.float32([-1., 1])
  scale = np.float32([0.1, 0.5])
  with self.cached_session() as sess:
    base = normal_lib.Normal(loc=loc, scale=scale)
    ind = independent_lib.Independent(distribution=base,
                                      reinterpreted_batch_ndims=1)
    draws = ind.sample([4, 5], seed=42)
    draws_log_prob = ind.log_prob(draws)
    draws_value, log_prob_value = sess.run([draws, draws_log_prob])

  # Static shapes: the single batch axis has become the event axis.
  self.assertEqual([], ind.batch_shape)
  self.assertEqual([2], ind.event_shape)
  self.assertEqual([4, 5, 2], draws.shape)
  self.assertEqual([4, 5], draws_log_prob.shape)

  reference = stats.norm(loc, scale).logpdf(draws_value).sum(-1)
  self.assertAllCloseAccordingToType(reference, log_prob_value)
def testSampleAndLogProbMultivariate(self):
  """`Independent(MVNDiag)` with batch [2], event [2] -> event [2, 2].

  Each batch member i is an isotropic MVN with multiplier `scale[i]`. The
  reference log-prob therefore broadcasts `scale[:, None]` against `loc` and
  sums the SciPy normal log-pdf over both trailing axes.
  """
  loc = np.float32([[-1., 1], [1, -1]])
  scale = np.float32([1., 0.5])
  with self.cached_session() as sess:
    mvn = mvn_diag_lib.MultivariateNormalDiag(
        loc=loc, scale_identity_multiplier=scale)
    ind = independent_lib.Independent(distribution=mvn,
                                      reinterpreted_batch_ndims=1)
    draws = ind.sample([4, 5], seed=42)
    draws_log_prob = ind.log_prob(draws)
    draws_value, log_prob_value = sess.run([draws, draws_log_prob])

  # Static shapes: batch [2] + MVN event [2] collapse into event [2, 2].
  self.assertEqual([], ind.batch_shape)
  self.assertEqual([2, 2], ind.event_shape)
  self.assertEqual([4, 5, 2, 2], draws.shape)
  self.assertEqual([4, 5], draws_log_prob.shape)

  per_row_scale = scale[:, None]
  reference = stats.norm(loc, per_row_scale).logpdf(
      draws_value).sum(-1).sum(-1)
  self.assertAllCloseAccordingToType(reference, log_prob_value)
def _fn(samples):
  """Builds a zero-mean `Normal` whose scale is derived from `samples`.

  The scale is `exp(affine_bijector.forward(samples))`, which is always
  positive. `affine_bijector` comes from the enclosing scope (not visible
  here). The last batch dimension of the `Normal` is reinterpreted as an event
  dimension via `Independent`.
  """
  log_scale = affine_bijector.forward(samples)
  base = normal_lib.Normal(loc=0.,
                           scale=math_ops.exp(log_scale),
                           validate_args=True)
  return independent_lib.Independent(base, reinterpreted_batch_ndims=1)