def testSampleWithSameSeed(self):
    """Compare same-seed samples across Wishart / WishartTriL variants.

    Returns immediately when TensorFlow is executing eagerly. Otherwise a
    single draw from `WishartTriL` is used as the reference: `Wishart` built
    from the full scale must reproduce it, and both distributions with
    `input_output_cholesky=True` must reproduce `chol` of it. For the two
    Cholesky-output distributions, the diagonals of 1000 draws are also
    asserted to be strictly positive.
    """
    if tf.executing_eagerly():
        return

    scale_matrix = make_pd(1., 2)
    dof = 4
    sample_seed = test_util.test_seed()

    def draw_one(dist):
        # Evaluate a single draw taken with the shared seed.
        return self.evaluate(dist.sample(1, seed=sample_seed))

    def assert_positive_diagonal(dist):
        # Every diagonal entry of 1000 seeded draws must exceed zero.
        diagonal = tf.linalg.diag_part(dist.sample(1000, seed=sample_seed))
        np.testing.assert_array_less(0., self.evaluate(diagonal))

    tril_dist = tfd.WishartTriL(
        df=dof,
        scale_tril=chol(scale_matrix),
        input_output_cholesky=False)
    reference = draw_one(tril_dist)
    reference_chol = [chol(reference[0])]

    full_dist = tfd.Wishart(
        df=dof, scale=scale_matrix, input_output_cholesky=False)
    self.assertAllClose(reference, draw_one(full_dist))

    tril_chol_dist = tfd.WishartTriL(
        df=dof,
        scale_tril=chol(scale_matrix),
        input_output_cholesky=True)
    self.assertAllClose(reference_chol, draw_one(tril_chol_dist))
    assert_positive_diagonal(tril_chol_dist)

    full_chol_dist = tfd.Wishart(
        df=dof, scale=scale_matrix, input_output_cholesky=True)
    self.assertAllClose(reference_chol, draw_one(full_chol_dist))
    assert_positive_diagonal(full_chol_dist)
def testEntropy(self):
    """Check `Wishart.entropy()` against two hard-coded reference values.

    The comments next to each expected value record the SciPy expression
    the original author cites as its source.
    """
    two_dim = tfd.Wishart(df=4, scale_tril=chol(make_pd(1., 2)))
    # sp.stats.wishart(df=4, scale=make_pd(1., 2)).entropy()
    expected_two_dim = 6.301387092430769
    self.assertAllClose(expected_two_dim, self.evaluate(two_dim.entropy()))

    one_dim = tfd.Wishart(df=1, scale_tril=[[1.]])
    # sp.stats.wishart(df=1,scale=1).entropy()
    expected_one_dim = 0.78375711047393404
    self.assertAllClose(expected_one_dim, self.evaluate(one_dim.entropy()))
def testValidateArgs(self):
    """Exercise argument validation with placeholder-backed parameters.

    Three failure cases are checked (df too small, a scale that cannot be
    Cholesky-decomposed, and a non-square `scale_tril`), followed by two
    constructions with `validate_args=False` that must evaluate without
    raising.
    """
    x = make_pd(1., 3)
    chol_scale = chol(x)
    df_deferred = tf1.placeholder_with_default(input=2., shape=None)
    chol_scale_deferred = tf1.placeholder_with_default(
        input=np.float32(chol_scale), shape=chol_scale.shape)

    # In eager mode, these checks are done statically and hence
    # ValueError is returned on object construction.
    error_type = tf.errors.InvalidArgumentError
    if tf.executing_eagerly():
        error_type = ValueError

    # Check expensive, deferred assertions.
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
    # alias, which no longer exists in Python 3.12+.
    with self.assertRaisesRegex(error_type, "cannot be less than"):
        chol_w = tfd.Wishart(
            df=df_deferred,
            scale_tril=chol_scale_deferred,
            validate_args=True)
        self.evaluate(chol_w.log_prob(np.asarray(x, dtype=np.float32)))

    with self.assertRaisesOpError(
            "Cholesky decomposition was not successful"):
        df_deferred = tf1.placeholder_with_default(input=2., shape=None)
        chol_scale_deferred = tf1.placeholder_with_default(
            input=np.ones([3, 3], dtype=np.float32), shape=[3, 3])
        chol_w = tfd.Wishart(df=df_deferred, scale=chol_scale_deferred)
        # np.ones((3, 3)) is not positive definite.
        self.evaluate(chol_w.log_prob(np.asarray(x, dtype=np.float32)))

    with self.assertRaisesOpError("scale_tril must be square"):
        chol_w = tfd.Wishart(
            df=4.,
            scale_tril=np.array(
                [[2., 3., 4.], [1., 2., 3.]], dtype=np.float32),
            validate_args=True)
        self.evaluate(chol_w.scale())

    # Ensure no assertions.
    df_deferred = tf1.placeholder_with_default(input=4., shape=None)
    chol_scale_deferred = tf1.placeholder_with_default(
        input=np.float32(chol_scale), shape=chol_scale.shape)
    chol_w = tfd.Wishart(
        df=df_deferred,
        scale_tril=chol_scale_deferred,
        validate_args=False)
    self.evaluate(chol_w.log_prob(np.asarray(x, dtype=np.float32)))

    chol_scale_deferred = tf1.placeholder_with_default(
        input=np.ones([3, 3], dtype=np.float32), shape=[3, 3])
    chol_w = tfd.Wishart(
        df=df_deferred,
        scale_tril=chol_scale_deferred,
        validate_args=False)
    # Bogus log_prob, but since we have no checks running... c'est la vie.
    self.evaluate(chol_w.log_prob(np.asarray(x, dtype=np.float32)))
def testMeanLogDetAndLogNormalizingConstant(self):
    """Check `entropy()` against a formula built from its ingredients.

    The alternative value combines `log_normalization()`,
    `mean_log_det()`, `df` and `dimension`; it must match `entropy()` for
    both distributions constructed below.
    """

    def entropy_alt(dist):
        # Same three terms, in the same order, as one evaluated expression.
        log_norm = dist.log_normalization()
        log_det_term = (
            0.5 * (dist.df - dist.dimension - 1.) * dist.mean_log_det())
        dim_term = 0.5 * dist.df * dist.dimension
        return self.evaluate(log_norm - log_det_term + dim_term)

    two_dim = tfd.Wishart(df=4, scale_tril=chol(make_pd(1., 2)))
    self.assertAllClose(
        self.evaluate(two_dim.entropy()), entropy_alt(two_dim))

    one_dim = tfd.Wishart(df=5, scale_tril=[[1.]])
    self.assertAllClose(
        self.evaluate(one_dim.entropy()), entropy_alt(one_dim))
def testStaticAsserts(self):
    """Check constructor errors that fire even with `validate_args=False`.

    Both cases pass concrete (non-placeholder) arguments, so the checks
    are expected to raise at construction time.
    """
    x = make_pd(1., 3)
    chol_scale = chol(x)

    # Still has these assertions because they're resolvable at graph
    # construction.
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
    # alias, which no longer exists in Python 3.12+.

    # df < rank
    with self.assertRaisesRegex(ValueError, "cannot be less than"):
        tfd.Wishart(df=2, scale_tril=chol_scale, validate_args=False)

    # non-float dtype
    with self.assertRaisesRegex(TypeError, "."):
        tfd.Wishart(
            df=4,
            scale_tril=np.asarray(chol_scale, dtype=np.int32),
            validate_args=False)
def testSample(self):
    """Check first and second sample moments against mean() / variance().

    Draws 10000 samples with a hard-coded seed and compares the empirical
    mean and variance with the distribution's own values at `rtol=0.05`.
    """
    num_samples = 10000
    dist = tfd.Wishart(
        df=4.,
        scale_tril=chol(make_pd(1., 3)),
        input_output_cholesky=False)
    draws = dist.sample(
        num_samples, seed=tfp_test_util.test_seed(hardcoded_seed=42))
    self.assertAllEqual((num_samples, 3, 3), draws.shape)

    # First moment.
    mean_tensor = tf.reduce_mean(input_tensor=draws, axis=[0])
    moment1_estimate = self.evaluate(mean_tensor)
    self.assertAllClose(
        self.evaluate(dist.mean()), moment1_estimate, rtol=0.05)

    # The Variance estimate uses the squares rather than outer-products
    # because Wishart.Variance is the diagonal of the Wishart covariance
    # matrix.
    second_moment = tf.reduce_mean(input_tensor=tf.square(draws), axis=[0])
    variance_estimate = self.evaluate(
        second_moment - tf.square(moment1_estimate))
    self.assertAllClose(
        self.evaluate(dist.variance()), variance_estimate, rtol=0.05)
def testCopyExtraArgs(self):
    """Check that `copy()` preserves `parameters` for two distributions."""
    # Note: we cannot easily test all distributions since each requires
    # different initialization arguments. We therefore spot test a few.
    spot_checks = []
    spot_checks.append(tfd.Normal(loc=1., scale=2., validate_args=True))
    spot_checks.append(
        tfd.Wishart(df=2, scale=[[1., 2], [2, 5]], validate_args=True))
    for dist in spot_checks:
        self.assertEqual(dist.parameters, dist.copy().parameters)
def _correlated_mvn_nuts(self, dim, step_size, num_steps):
    """Build a NoUTurnSampler kernel whose target is a correlated MVN.

    The Cholesky precision factor of the MVN is a single sample from a
    Wishart with identity scale and `input_output_cholesky=True`.

    Args:
      dim: Used as the Wishart `df`, the size of its identity scale, and
        the length of the zero `loc` vector.
      step_size: Wrapped in a one-element list and passed as the kernel's
        `step_size`.
      num_steps: Passed to the kernel as `num_trajectories_per_step`.

    Returns:
      The constructed `tfp.experimental.mcmc.NoUTurnSampler`.
    """
    # The correlated MVN example is taken from the NUTS paper
    # https://arxiv.org/pdf/1111.4246.pdf.
    # This implementation in terms of MVNCholPrecisionTril follows
    # tfp/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb

    class MVNCholPrecisionTriL(tfd.TransformedDistribution):
        """MVN from loc and (Cholesky) precision matrix."""

        def __init__(self, loc, chol_precision_tril, name=None):
            base = tfd.Independent(
                tfd.Normal(loc=tf.zeros_like(loc), scale=tf.ones_like(loc)),
                reinterpreted_batch_ndims=1)
            shift_then_unscale = tfb.Chain([
                tfb.Affine(shift=loc),
                tfb.Invert(
                    tfb.Affine(scale_tril=chol_precision_tril, adjoint=True)),
            ])
            super(MVNCholPrecisionTriL, self).__init__(
                distribution=base, bijector=shift_then_unscale, name=name)

    seed_stream = tfp_test_util.test_seed_stream()

    precision_prior = tfd.Wishart(
        df=dim, scale=tf.eye(dim), input_output_cholesky=True)
    chol_precision = precision_prior.sample(seed=seed_stream())

    target = MVNCholPrecisionTriL(
        loc=tf.zeros(dim), chol_precision_tril=chol_precision)

    return tfp.experimental.mcmc.NoUTurnSampler(
        target.log_prob,
        step_size=[step_size],
        num_trajectories_per_step=num_steps,
        use_auto_batching=True,
        stackless=False,
        max_tree_depth=7,
        seed=seed_stream())
def testMeanBroadcast(self):
    """Check `mean()` equals `df * scale` for a batch of two distributions."""
    full_scales = [make_pd(1., 2), make_pd(1., 2)]
    tril_scales = np.float32([chol(matrix) for matrix in full_scales])
    full_scales = np.float32(full_scales)
    dof = np.array([4., 3.], dtype=np.float32)

    dist = tfd.Wishart(df=dof, scale_tril=tril_scales)

    # Append two unit axes so each df scales its own matrix.
    expected_mean = dof[..., None, None] * full_scales
    self.assertAllEqual(expected_mean, self.evaluate(dist.mean()))
def testVarianceBroadcast(self):
    """Check batched `variance()` against the `wishart_var` helper."""
    full_scales = [make_pd(1., 2), make_pd(1., 2)]
    tril_scales = np.float32([chol(matrix) for matrix in full_scales])
    full_scales = np.float32(full_scales)
    dof = np.array([4., 3.], dtype=np.float32)

    dist = tfd.Wishart(df=dof, scale_tril=tril_scales)

    expected_variance = wishart_var(dof, full_scales)
    self.assertAllEqual(expected_variance, self.evaluate(dist.variance()))
def testSampleMultipleTimes(self):
    """Check that two identically seeded distributions sample alike.

    The graph-level seed is reset before each distribution is built, and
    the same seed is also passed to `sample`.
    """
    dof = 4.
    num_draws = 100
    shared_seed = tfp_test_util.test_seed()

    def seeded_draws(dist_name):
        # Reset the graph-level seed, then build and sample a new Wishart.
        tf1.set_random_seed(shared_seed)
        dist = tfd.Wishart(
            df=dof,
            scale_tril=chol(make_pd(1., 3)),
            input_output_cholesky=False,
            name=dist_name)
        return self.evaluate(dist.sample(num_draws, seed=shared_seed))

    first = seeded_draws("wishart1")
    second = seeded_draws("wishart2")
    self.assertAllClose(first, second)
def testEventShape(self):
    """Check the event shape is [2, 2] for four ways of giving scale_tril.

    The first two distributions take concrete arrays and are checked
    through both `event_shape` and `event_shape_tensor()`. The last two
    take placeholders and are checked through `event_shape_tensor()` only.
    """
    tril = chol(make_pd(1., 2))
    expected = [2, 2]

    def check(dist, include_static=True):
        if include_static:
            self.assertAllEqual(expected, dist.event_shape)
        self.assertAllEqual(
            expected, self.evaluate(dist.event_shape_tensor()))

    # Concrete scale_tril, unbatched and batched.
    check(tfd.Wishart(df=4, scale_tril=tril))
    check(tfd.Wishart(df=[4., 4], scale_tril=np.array([tril, tril])))

    # Placeholder with a known static shape.
    deferred = tf1.placeholder_with_default(input=tril, shape=tril.shape)
    check(tfd.Wishart(df=4, scale_tril=deferred), include_static=False)

    # Placeholder with no static shape at all.
    deferred = tf1.placeholder_with_default(
        input=np.array([tril, tril]), shape=None)
    check(tfd.Wishart(df=4, scale_tril=deferred), include_static=False)
def testAssertsVariableScale(self):
    """Check validation errors when parameters are `tf.Variable`s.

    First a `Wishart` with a variable `scale` must raise some op error
    when its entropy is evaluated. Then a `WishartTriL` with a variable
    `scale_tril` of unknown static shape and `df=3` must raise an error
    mentioning 'cannot be less than' when sampled.
    """
    scale_var = tf.Variable([[2., 1.], [3., 3.]])
    self.evaluate(scale_var.initializer)
    with self.assertRaisesOpError(''):
        dist = tfd.Wishart(df=4, scale=scale_var, validate_args=True)
        self.evaluate(dist.entropy())

    tril_value = chol(make_pd(3., 4.)).astype(np.float32)
    tril_var = tf.Variable(tril_value, shape=tf.TensorShape(None))
    self.evaluate(tril_var.initializer)
    with self.assertRaisesOpError('cannot be less than'):
        dist = tfd.WishartTriL(
            df=3, scale_tril=tril_var, validate_args=True)
        self.evaluate(dist.sample())
def testSampleBroadcasts(self):
    """Check sample shape is sample_shape + batch_shape + event shape.

    Both the static shape of the sample tensor and the shape of its
    evaluated value are compared against the expected shape.
    """
    dims = 2
    batch_shape = [2, 3]
    sample_shape = [2, 1]
    event_shape = [dims, dims]

    base_scales = np.float32([
        [[1., 0.5], [0.5, 1.]],
        [[0.5, 0.25], [0.25, 0.75]],
    ])
    # Three copies of the two matrices, reshaped to the batch shape.
    stacked = np.concatenate([base_scales] * 3, axis=0)
    scale = np.reshape(stacked, batch_shape + event_shape)

    wishart = tfd.Wishart(df=5, scale=scale)
    draws = wishart.sample(sample_shape, seed=tfp_test_util.test_seed())
    draws_value = self.evaluate(draws)

    expected_shape = sample_shape + batch_shape + event_shape
    self.assertAllEqual(expected_shape, draws.shape)
    self.assertAllEqual(expected_shape, draws_value.shape)
def testLogProbBroadcastOverDfInsideMixture(self):
    """Check a mixture over Wishart components batched on df is scalar.

    The components share one scale and differ only in `df`. The mixture's
    batch shape and the shape of its `log_prob` (static and evaluated)
    must all be empty.
    """
    dims = 2
    scale = np.float32([
        [0.5, 0.25],
        [0.25, 0.75],
    ])
    dof = np.arange(3., 8., dtype=np.float32)

    components = tfd.Wishart(df=dof, scale=scale)
    uniform_weights = tfd.Categorical(logits=tf.zeros(dof.shape))
    mixture = tfd.MixtureSameFamily(
        components_distribution=components,
        mixture_distribution=uniform_weights)

    # Evaluation point: product of a random matrix with its transpose.
    noise = np.random.randn(dims, dims)
    point = np.matmul(noise, noise.T)

    log_prob = mixture.log_prob(point)
    log_prob_value = self.evaluate(log_prob)

    self.assertAllEqual([], mixture.batch_shape)
    self.assertAllEqual([], log_prob.shape)
    self.assertAllEqual([], log_prob_value.shape)
def make_wishart(self, dims, new_batch_shape, old_batch_shape):
    """Build a batched Wishart and a BatchReshape wrapper around it.

    Args:
      dims: Event dimension; the scale is reshaped to
        `old_batch_shape + [dims, dims]`.
      new_batch_shape: Batch shape handed to `tfd.BatchReshape`.
      old_batch_shape: List giving the batch shape of the Wishart.

    Returns:
      A `(wishart, reshape_wishart)` tuple.
    """
    static = self.is_static_shape

    # The target batch shape is a constant in static mode, otherwise a
    # placeholder with no static shape.
    if static:
        new_batch_shape_ph = tf.constant(np.int32(new_batch_shape))
    else:
        new_batch_shape_ph = tf1.placeholder_with_default(
            np.int32(new_batch_shape), shape=None)

    base_scales = self.dtype([
        [[1., 0.5], [0.5, 1.]],
        [[0.5, 0.25], [0.25, 0.75]],
    ])
    doubled = np.concatenate([base_scales] * 2, axis=0)
    scale = np.reshape(doubled, old_batch_shape + [dims, dims])

    scale_static_shape = scale.shape if static else None
    scale_ph = tf1.placeholder_with_default(scale, shape=scale_static_shape)

    wishart = tfd.Wishart(df=5, scale=scale_ph)
    reshape_wishart = tfd.BatchReshape(
        distribution=wishart,
        batch_shape=new_batch_shape_ph,
        validate_args=True)
    return wishart, reshape_wishart
def testLogProbBroadcastsX(self):
    """Check `log_prob` broadcasts one matrix over a [2, 3] batch.

    The log-probability of a single matrix must have the batch shape and
    match the result of passing that matrix explicitly tiled over the
    batch.
    """
    dims = 2
    batch_shape = [2, 3]

    base_scales = np.float32([
        [[1., 0.5], [0.5, 1.]],
        [[0.5, 0.25], [0.25, 0.75]],
    ])
    stacked = np.concatenate([base_scales] * 3, axis=0)
    scale = np.reshape(stacked, batch_shape + [dims, dims])
    wishart = tfd.Wishart(df=5, scale=scale)

    # Evaluation point: product of a random matrix with its transpose.
    noise = np.random.randn(dims, dims)
    point = np.matmul(noise, noise.T)
    tiled_point = point * np.ones([2, 3, 1, 1])

    implicit = wishart.log_prob(point)
    explicit = wishart.log_prob(tiled_point)
    implicit_value, explicit_value = self.evaluate([implicit, explicit])

    self.assertAllEqual(batch_shape, implicit.shape)
    self.assertAllEqual(batch_shape, implicit_value.shape)
    self.assertAllEqual(batch_shape, explicit.shape)
    self.assertAllEqual(batch_shape, explicit_value.shape)
    self.assertAllClose(explicit_value, implicit_value)
def testProb(self):
    """Check `log_prob` / `prob` against hard-coded SciPy reference values.

    Covers a batched `WishartTriL` (one df and scale per batch member) and
    unbatched `WishartTriL` / `Wishart` instances evaluated on inputs of
    several sample shapes, with `input_output_cholesky` both off (full
    matrices as input) and on (Cholesky factors as input).
    """
    # Generate some positive definite (pd) matrices and their Cholesky
    # factorizations.
    x = np.array(
        [make_pd(1., 2), make_pd(2., 2), make_pd(3., 2), make_pd(4., 2)])
    chol_x = np.array([chol(x[0]), chol(x[1]), chol(x[2]), chol(x[3])])

    # Since Wishart wasn't added to SciPy until 0.16, we'll spot check some
    # pdfs with hard-coded results from upstream SciPy.
    log_prob_df_seq = np.array([
        # math.log(stats.wishart.pdf(x[0], df=2+0, scale=x[0]))
        -3.5310242469692907,
        # math.log(stats.wishart.pdf(x[1], df=2+1, scale=x[1]))
        -7.689907330328961,
        # math.log(stats.wishart.pdf(x[2], df=2+2, scale=x[2]))
        -10.815845159537895,
        # math.log(stats.wishart.pdf(x[3], df=2+3, scale=x[3]))
        -13.640549882916691,
    ])

    # This test checks that batches don't interfere with correctness.
    w = tfd.WishartTriL(
        df=[2, 3, 4, 5], scale_tril=chol_x, input_output_cholesky=True)
    self.assertAllClose(log_prob_df_seq, self.evaluate(w.log_prob(chol_x)))

    # Now we test various constructions of Wishart with different sample
    # shape.
    log_prob = np.array([
        # math.log(stats.wishart.pdf(x[0], df=4, scale=x[0]))
        -4.224171427529236,
        # math.log(stats.wishart.pdf(x[1], df=4, scale=x[0]))
        -6.3378770664093453,
        # math.log(stats.wishart.pdf(x[2], df=4, scale=x[0]))
        -12.026946850193017,
        # math.log(stats.wishart.pdf(x[3], df=4, scale=x[0]))
        -20.951582705289454,
    ])

    def check_sample_shapes(dist, inputs):
        # `inputs` holds the four matrices in whichever representation
        # `dist` expects (full matrices or Cholesky factors).
        dimension = (dist.scale.domain_dimension_tensor()
                     if 'WishartTriL' in dist.name else dist.dimension)
        self.assertAllEqual(
            (2, 2), self.evaluate(dist.event_shape_tensor()))
        self.assertEqual(2, self.evaluate(dimension))
        self.assertAllClose(
            log_prob[0], self.evaluate(dist.log_prob(inputs[0])))
        self.assertAllClose(
            log_prob[0:2], self.evaluate(dist.log_prob(inputs[0:2])))
        self.assertAllClose(
            np.reshape(log_prob, (2, 2)),
            self.evaluate(dist.log_prob(np.reshape(inputs, (2, 2, 2, 2)))))
        self.assertAllClose(
            np.reshape(np.exp(log_prob), (2, 2)),
            self.evaluate(dist.prob(np.reshape(inputs, (2, 2, 2, 2)))))
        # The static-shape check uses the same representation as the
        # value checks above; the Cholesky case previously passed the
        # full matrices `x` here by mistake.
        self.assertAllEqual(
            (2, 2), dist.log_prob(np.reshape(inputs, (2, 2, 2, 2))).shape)

    # Full-matrix inputs.
    for w in (
            tfd.WishartTriL(
                df=4, scale_tril=chol_x[0], input_output_cholesky=False),
            tfd.Wishart(df=4, scale=x[0], input_output_cholesky=False),
    ):
        check_sample_shapes(w, x)

    # Cholesky-factor inputs.
    for w in (
            tfd.WishartTriL(
                df=4, scale_tril=chol_x[0], input_output_cholesky=True),
            tfd.Wishart(df=4, scale=x[0], input_output_cholesky=True),
    ):
        check_sample_shapes(w, chol_x)
def testStd(self):
    """Check `stddev()` equals the square root of `wishart_var`."""
    scale_matrix = make_pd(1., 2)
    dof = 4
    dist = tfd.Wishart(df=dof, scale_tril=chol(scale_matrix))
    expected_stddev = np.sqrt(wishart_var(dof, scale_matrix))
    self.assertAllEqual(expected_stddev, self.evaluate(dist.stddev()))
def testVariance(self):
    """Check `variance()` equals the `wishart_var` helper's result."""
    scale_matrix = make_pd(1., 2)
    dof = 4
    dist = tfd.Wishart(df=dof, scale_tril=chol(scale_matrix))
    expected_variance = wishart_var(dof, scale_matrix)
    self.assertAllEqual(expected_variance, self.evaluate(dist.variance()))
def testSamplingEmptyDist(self):
    """Check that sampling from a `[:0]` slice of a Wishart evaluates."""
    dist = tfd.Wishart(df=[1], scale_tril=[[1.]], validate_args=True)
    empty_slice = dist[:0]
    self.evaluate(empty_slice.sample())
def testMean(self):
    """Check `mean()` equals `df * scale`."""
    scale_matrix = make_pd(1., 2)
    dof = 4
    dist = tfd.Wishart(df=dof, scale_tril=chol(scale_matrix))
    expected_mean = dof * scale_matrix
    self.assertAllEqual(expected_mean, self.evaluate(dist.mean()))
def testLogProbEmptyDist(self):
    """Check that `log_prob` on a `[:0]` slice of a Wishart evaluates."""
    dist = tfd.Wishart(df=[1], scale_tril=[[1.]], validate_args=True)
    empty_slice = dist[:0]
    self.evaluate(empty_slice.log_prob([[1.]]))
def testMode(self):
    """Check `mode()` equals `(df - 2 - 1) * scale` for a 2x2 scale."""
    scale_matrix = make_pd(1., 2)
    dof = 4
    dist = tfd.Wishart(df=dof, scale_tril=chol(scale_matrix))
    expected_mode = (dof - 2. - 1.) * scale_matrix
    self.assertAllEqual(expected_mode, self.evaluate(dist.mode()))
def testStaticAssertNonFlatDfDoesntRaise(self):
    """Check that a nested-list `df` is accepted with validation enabled."""
    # Check we don't get ValueError: The truth value of an array with more than
    # one element is ambiguous. Use a.any() or a.all()
    nested_df = [[2., 2]]
    tfd.Wishart(df=nested_df, scale=make_pd(1., 2), validate_args=True)