def testKLBatchBroadcast(self):
  batch_shape = [2]
  event_shape = [3]
  mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
  # No batch shape.
  mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
  mvn_a = tfd.MultivariateNormalTriL(
      loc=mu_a,
      scale_tril=np.linalg.cholesky(sigma_a),
      validate_args=True)
  mvn_b = tfd.MultivariateNormalTriL(
      loc=mu_b,
      scale_tril=np.linalg.cholesky(sigma_b),
      validate_args=True)

  kl = tfd.kl_divergence(mvn_a, mvn_b)
  self.assertEqual(batch_shape, kl.shape)

  kl_v = self.evaluate(kl)
  expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], mu_b,
                                        sigma_b)
  expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], mu_b,
                                        sigma_b)
  self.assertAllClose(expected_kl_0, kl_v[0])
  self.assertAllClose(expected_kl_1, kl_v[1])
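# For reference: `_compute_non_batch_kl` is a helper defined elsewhere in this
# test module. Presumably it evaluates the closed-form KL divergence between
# two multivariate normals; a minimal numpy sketch of that formula (not the
# module's actual helper) is:
#
#   KL(N(mu_a, S_a) || N(mu_b, S_b))
#     = 0.5 * [tr(S_b^{-1} S_a) + (mu_b - mu_a)^T S_b^{-1} (mu_b - mu_a)
#              - k + ln(det S_b / det S_a)]
def _compute_non_batch_kl_sketch(mu_a, sigma_a, mu_b, sigma_b):
  """Closed-form MVN KL divergence (numpy sketch)."""
  sigma_b_inv = np.linalg.inv(sigma_b)
  diff = mu_b - mu_a
  k = mu_a.shape[-1]
  return 0.5 * (
      np.trace(sigma_b_inv.dot(sigma_a))
      + diff.dot(sigma_b_inv).dot(diff)
      - k
      + np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a)))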
def test_log_prob_matches_linear_gaussian_ssm(self):
  dim = 2
  batch_shape = [3, 1]
  seed, *model_seeds = samplers.split_seed(test_util.test_seed(), n=6)

  # Sample a random linear Gaussian process.
  prior_loc = self.evaluate(
      tfd.Normal(0., 1.).sample(batch_shape + [dim], seed=model_seeds[0]))
  prior_scale = self.evaluate(
      tfd.InverseGamma(1., 1.).sample(batch_shape + [dim],
                                      seed=model_seeds[1]))
  transition_matrix = self.evaluate(
      tfd.Normal(0., 1.).sample([dim, dim], seed=model_seeds[2]))
  transition_bias = self.evaluate(
      tfd.Normal(0., 1.).sample(batch_shape + [dim], seed=model_seeds[3]))
  transition_scale_tril = self.evaluate(
      tf.linalg.cholesky(
          tfd.WishartTriL(
              df=dim, scale_tril=tf.eye(dim)).sample(seed=model_seeds[4])))

  initial_state_prior = tfd.MultivariateNormalDiag(
      loc=prior_loc, scale_diag=prior_scale, name='initial_state_prior')

  lgssm = tfd.LinearGaussianStateSpaceModel(
      num_timesteps=7,
      transition_matrix=transition_matrix,
      transition_noise=tfd.MultivariateNormalTriL(
          loc=transition_bias, scale_tril=transition_scale_tril),
      # Trivial observation model to pass through the latent state.
      observation_matrix=tf.eye(dim),
      observation_noise=tfd.MultivariateNormalDiag(
          loc=tf.zeros(dim), scale_diag=tf.zeros(dim)),
      initial_state_prior=initial_state_prior)

  markov_chain = tfd.MarkovChain(
      initial_state_prior=initial_state_prior,
      transition_fn=lambda _, x: tfd.MultivariateNormalTriL(  # pylint: disable=g-long-lambda
          loc=tf.linalg.matvec(transition_matrix, x) + transition_bias,
          scale_tril=transition_scale_tril),
      num_steps=7)

  x = markov_chain.sample(5, seed=seed)
  self.assertAllClose(lgssm.log_prob(x),
                      markov_chain.log_prob(x),
                      rtol=1e-5)
def solve_weight_space() -> tfd.Distribution:
  d_by_d = tf.matmul(x_train, x_train, transpose_a=True)
  precis = tf.linalg.diag(1 / weight_var) + noise_precis * d_by_d
  sqprec = jitter_cholesky(precis)
  solv_y = parallel_solve(tf.linalg.triangular_solve, sqprec, proj,
                          lower=True)
  solv_x = parallel_solve(tf.linalg.triangular_solve, sqprec,
                          tf.linalg.matrix_transpose(x_eval), lower=True)

  loc = noise_precis * tf.matmul(solv_x, solv_y, transpose_a=True)
  if self.mean_function is not None:
    loc += self.mean_function(x)

  if full_cov:
    scale_tril = jitter_cholesky(
        tf.matmul(solv_x, solv_x, transpose_a=True))
    return tfd.MultivariateNormalTriL(loc=tf.squeeze(loc, axis=-1),
                                      scale_tril=scale_tril)
  else:
    scale_diag = tf.sqrt(tf.reduce_sum(tf.square(solv_x), axis=-2))
    return tfd.MultivariateNormalDiag(loc=tf.squeeze(loc, axis=-1),
                                      scale_diag=scale_diag)
def solve_function_space() -> tfd.Distribution:
  n_by_n = tf.matmul(x_train, weight_var[..., None, :] * x_train,
                     transpose_b=True)
  precis = (1 / noise_precis) * eyeN + n_by_n
  sqprec = jitter_cholesky(precis)
  solv_x = parallel_solve(tf.linalg.triangular_solve, sqprec, x_train,
                          lower=True)

  w_covar = solv_x * weight_var[..., None, :]  # scratch variable
  w_covar = (tf.linalg.diag(weight_var)
             - tf.matmul(w_covar, w_covar, transpose_a=True))

  loc = noise_precis * tf.matmul(x_eval, tf.matmul(w_covar, proj))
  covar = tf.matmul(x_eval, tf.matmul(w_covar, x_eval, transpose_b=True))

  if full_cov:
    scale_tril = jitter_cholesky(covar)
    return tfd.MultivariateNormalTriL(loc=tf.squeeze(loc, axis=-1),
                                      scale_tril=scale_tril)
  else:
    scale_diag = tf.sqrt(tf.linalg.diag_part(covar))
    return tfd.MultivariateNormalDiag(loc=tf.squeeze(loc, axis=-1),
                                      scale_diag=scale_diag)
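# The weight-space and function-space solves above should yield the same
# posterior; the equivalence rests on the Woodbury identity. A minimal numpy
# sanity check of that identity, under the (unconfirmed) assumptions that
# `proj` is `x_train^T y` and `noise_precis` is the scalar noise precision:
#
#   import numpy as np
#
#   rng = np.random.default_rng(0)
#   n, d, m = 20, 3, 5
#   X = rng.normal(size=(n, d))         # training features
#   Xs = rng.normal(size=(m, d))        # evaluation features
#   y = rng.normal(size=(n, 1))
#   wv = rng.uniform(0.5, 2.0, size=d)  # prior weight variances
#   beta = 4.0                          # noise precision
#
#   # Weight space: A = diag(1/wv) + beta * X^T X; posterior cov of w = A^{-1}.
#   A = np.diag(1. / wv) + beta * X.T @ X
#   mean_w = beta * Xs @ np.linalg.solve(A, X.T @ y)
#   cov_w = Xs @ np.linalg.solve(A, Xs.T)
#
#   # Function space via Woodbury: B = I/beta + X diag(wv) X^T,
#   # A^{-1} = diag(wv) - diag(wv) X^T B^{-1} X diag(wv).
#   B = np.eye(n) / beta + (X * wv) @ X.T
#   W = np.diag(wv) - (X * wv).T @ np.linalg.solve(B, X * wv)
#   mean_f = beta * Xs @ (W @ (X.T @ y))
#   cov_f = Xs @ W @ Xs.T
#
#   assert np.allclose(mean_w, mean_f) and np.allclose(cov_w, cov_f)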
def test_bijector(self):
  """Employ bijector when sampling."""
  dtype = np.float32
  true_mean = dtype([1, 1])
  true_cov = dtype([[1, 0.5], [0.5, 1]])
  target = tfd.MultivariateNormalTriL(
      loc=true_mean, scale_tril=tf.linalg.cholesky(true_cov))

  def logp(x1, x2):
    return target.log_prob([x1, x2])

  def kernel_make_fn(target_log_prob_fn, state):
    inner_kernel = tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=target_log_prob_fn)
    return tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=inner_kernel, bijector=tfp.bijectors.Exp())

  kernel_list = [(0, kernel_make_fn), (1, kernel_make_fn)]
  kernel = GibbsKernel(
      target_log_prob_fn=logp,
      kernel_list=kernel_list)
  samples = tfp.mcmc.sample_chain(
      num_results=20,
      current_state=[dtype(1), dtype(1)],
      kernel=kernel,
      trace_fn=None)
def seasonal_transition_noise(t):
  noise_scale_tril = dist_util.pick_scalar_condition(
      is_last_day_of_season(t),
      drift_scale_tril,
      tf.zeros_like(drift_scale_tril))
  return tfd.MultivariateNormalTriL(
      loc=tf.zeros(num_seasons - 1, dtype=drift_scale.dtype),
      scale_tril=noise_scale_tril)
def run_chain_and_get_estimation_error():
  chain_state = tfp.mcmc.sample_chain(
      num_results=num_steps,
      num_burnin_steps=0,
      current_state=initial_state,
      kernel=tfp.mcmc.NoUTurnSampler(
          tfd.MultivariateNormalTriL(loc=mu, scale_tril=scale_tril).log_prob,
          step_size=np.asarray([sigma1, sigma2]),
          parallel_iterations=1,
          seed=strm()),
      parallel_iterations=1,
      trace_fn=None)
  variance_est = tf.square(chain_state - mu)
  correlation_est = tf.reduce_prod(
      chain_state - mu, axis=-1, keepdims=True) / (sigma1 * sigma2)
  mcmc_samples = tf.concat([chain_state, variance_est, correlation_est],
                           axis=-1)
  expected = tf.reduce_mean(mcmc_samples, axis=[0, 1])
  ess = tf.reduce_sum(tfp.mcmc.effective_sample_size(mcmc_samples), axis=0)
  avg_monte_carlo_standard_error = tf.reduce_mean(
      tf.math.reduce_std(mcmc_samples, axis=0), axis=0) / tf.sqrt(ess)
  scaled_error = (
      tf.abs(expected - true_param) / avg_monte_carlo_standard_error)
  return tfd.Normal(loc=0., scale=1.).survival_function(scaled_error)
def test_multipart_bijector(self):
  seed_stream = test_util.test_seed_stream()

  prior = tfd.JointDistributionSequential([
      tfd.Gamma(1., 1.),
      lambda scale: tfd.Uniform(0., scale),
      lambda concentration: tfd.CholeskyLKJ(4, concentration),
  ], validate_args=True)
  likelihood = lambda corr: tfd.MultivariateNormalTriL(scale_tril=corr)

  obs = self.evaluate(
      likelihood(
          prior.sample(seed=seed_stream())[-1]).sample(seed=seed_stream()))

  bij = prior.experimental_default_event_space_bijector()

  def target_log_prob(scale, conc, corr):
    return prior.log_prob(scale, conc, corr) + likelihood(corr).log_prob(obs)

  kernel = tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob, num_leapfrog_steps=3, step_size=.5)
  kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bij)

  init = self.evaluate(
      tuple(tf.random.uniform(s, -2., 2., seed=seed_stream())
            for s in bij.inverse_event_shape(prior.event_shape)))
  state = bij.forward(init)
  kr = kernel.bootstrap_results(state)
  next_state, next_kr = kernel.one_step(state, kr, seed=seed_stream())
  self.evaluate((state, kr, next_state, next_kr))

  expected = (target_log_prob(*state)
              - bij.inverse_log_det_jacobian(state, [0, 0, 2]))
  actual = kernel._inner_kernel.target_log_prob_fn(*init)  # pylint: disable=protected-access
  self.assertAllClose(expected, actual)
def testLangevinCorrectVolatilityGradient(self):
  """Check that the gradient of the volatility is computed correctly."""
  # Consider the example target distribution as in `testLangevin3DNormal`.
  dtype = np.float32
  true_mean = dtype([1, 2, 7])
  true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
  num_chains = 100
  chol = tf.linalg.cholesky(true_cov)
  target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

  def target_log_prob(x, y):
    # Stack the input tensors together
    z = tf.concat([x, y], axis=-1)
    return target.log_prob(z)

  def volatility_fn(x, y):
    return [1. / (0.5 + 0.1 * tf.abs(x + y)),
            1. / (0.5 + 0.1 * tf.abs(y))]

  # Initial state of the chain
  init_state = [np.ones([num_chains, 2], dtype=dtype),
                np.ones([num_chains, 1], dtype=dtype)]

  strm = test_util.test_seed_stream()
  # Define MALA with constant (unit) volatility
  langevin_unit = tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
      target_log_prob_fn=target_log_prob,
      step_size=0.1,
      seed=strm())
  # Define MALA with volatility being `volatility_fn`
  langevin_general = tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
      target_log_prob_fn=target_log_prob,
      step_size=0.1,
      volatility_fn=volatility_fn,
      seed=strm())

  # Initialize the samplers
  kernel_unit_volatility = langevin_unit.bootstrap_results(init_state)
  kernel_general = langevin_general.bootstrap_results(init_state)

  # For `langevin_unit` the volatility gradient should be zero.
  grad_1, grad_2 = kernel_unit_volatility.accepted_results.grads_volatility
  self.assertAllEqual(self.evaluate(grad_1),
                      np.zeros(shape=init_state[0].shape, dtype=dtype))
  self.assertAllEqual(self.evaluate(grad_2),
                      np.zeros(shape=init_state[1].shape, dtype=dtype))

  # For `langevin_general` the volatility gradient should be around -0.583
  # in the first state part and -0.926 in the second.
  grad_1, grad_2 = kernel_general.accepted_results.grads_volatility
  self.assertAllClose(
      self.evaluate(grad_1),
      -0.583 * np.ones(shape=init_state[0].shape, dtype=dtype),
      atol=0.01, rtol=0.01)
  self.assertAllClose(
      self.evaluate(grad_2),
      -0.926 * np.ones(shape=init_state[1].shape, dtype=dtype),
      atol=0.01, rtol=0.01)
def __init__(
    self,
    ndims=100,
    gamma_shape_parameter=0.5,
    max_eigvalue=None,
    seed=10,
    name='ill_conditioned_gaussian',
    pretty_name='Ill-Conditioned Gaussian',
):
  """Construct the ill-conditioned Gaussian.

  Args:
    ndims: Python `int`. Dimensionality of the Gaussian.
    gamma_shape_parameter: Python `float`. The shape parameter of the
      inverse Gamma distribution. Anything below 2 is likely to yield
      poorly conditioned covariance matrices.
    max_eigvalue: Python `float`. If set, will normalize the eigenvalues
      such that the maximum is this value.
    seed: Seed to use when generating the eigenvalues and the random
      orthogonal matrix.
    name: Python `str` name prefixed to Ops created by this class.
    pretty_name: A Python `str`. The pretty name of this model.
  """
  rng = onp.random.RandomState(seed=seed & (2**32 - 1))
  eigenvalues = 1. / onp.sort(
      rng.gamma(shape=gamma_shape_parameter, scale=1., size=ndims))
  if max_eigvalue is not None:
    eigenvalues *= max_eigvalue / eigenvalues.max()

  q, r = onp.linalg.qr(rng.randn(ndims, ndims))
  q *= onp.sign(onp.diag(r))

  covariance = (q * eigenvalues).dot(q.T)

  gaussian = tfd.MultivariateNormalTriL(
      loc=tf.zeros(ndims),
      scale_tril=tf.linalg.cholesky(
          tf.convert_to_tensor(covariance, dtype=tf.float32)))
  self._eigenvalues = eigenvalues

  sample_transformations = {
      'identity':
          model.Model.SampleTransformation(
              fn=lambda params: params,
              pretty_name='Identity',
              ground_truth_mean=onp.zeros(ndims),
              ground_truth_standard_deviation=onp.sqrt(onp.diag(covariance)),
          )
  }

  self._gaussian = gaussian

  super(IllConditionedGaussian, self).__init__(
      default_event_space_bijector=tfb.Identity(),
      event_shape=gaussian.event_shape,
      dtype=gaussian.dtype,
      name=name,
      pretty_name=pretty_name,
      sample_transformations=sample_transformations,
  )
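# A quick way to see why small shape parameters produce ill conditioning: the
# eigenvalues are reciprocals of Gamma draws, whose spread explodes as the
# shape parameter shrinks. An illustrative numpy check (mirrors the
# constructor's own eigenvalue construction; not part of the model code):
#
#   import numpy as onp
#
#   rng = onp.random.RandomState(10)
#   eig = 1. / onp.sort(rng.gamma(shape=0.5, scale=1., size=100))
#   # Condition number of the resulting covariance; typically many orders
#   # of magnitude for shape 0.5.
#   print(eig.max() / eig.min())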
def test_float64(self):
  """Sample with dtype float64."""
  dtype = np.float64
  true_mean = dtype([1, 1])
  true_cov = dtype([[1, 0.5], [0.5, 1]])
  target = tfd.MultivariateNormalTriL(
      loc=true_mean, scale_tril=tf.linalg.cholesky(true_cov))

  def logp(x1, x2):
    return target.log_prob([x1, x2])

  def kernel_make_fn(target_log_prob_fn, state):
    return tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=target_log_prob_fn)

  kernel_list = [(0, kernel_make_fn), (1, kernel_make_fn)]
  kernel = GibbsKernel(
      target_log_prob_fn=logp,
      kernel_list=kernel_list)
  samples = tfp.mcmc.sample_chain(
      num_results=20,
      current_state=[dtype(1), dtype(1)],
      kernel=kernel,
      trace_fn=None)
def testSampleWithSampleShape(self):
  mu = self._rng.rand(3, 5, 2)
  chol, sigma = self._random_chol(3, 5, 2, 2)
  chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
  chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]

  mvn = tfd.MultivariateNormalTriL(mu, chol, validate_args=True)
  samples_val = self.evaluate(
      mvn.sample((10, 11, 12), seed=tfp_test_util.test_seed()))

  # Check sample shape
  self.assertEqual((10, 11, 12, 3, 5, 2), samples_val.shape)

  # Check sample means
  x = samples_val[:, :, :, 1, 1, :]
  self.assertAllClose(
      x.reshape(10 * 11 * 12, 2).mean(axis=0), mu[1, 1], atol=0.05)

  # Check that log_prob(samples) works
  log_prob_val = self.evaluate(mvn.log_prob(samples_val))
  x_log_pdf = log_prob_val[:, :, :, 1, 1]
  expected_log_pdf = stats.multivariate_normal(
      mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).logpdf(x)
  self.assertAllClose(expected_log_pdf, x_log_pdf)
def test_3d_mvn(self):
  """Sample from 3-variate Gaussian Distribution."""
  dtype = np.float32
  true_mean = dtype([1., 2., 3.])
  true_cov = dtype([[0.36, 0.12, 0.06],
                    [0.12, 0.29, -0.13],
                    [0.06, -0.13, 0.26]])
  target = tfd.MultivariateNormalTriL(
      loc=true_mean, scale_tril=tf.linalg.cholesky(true_cov))
  kernel = AdaptiveRandomWalkMetropolis(
      target_log_prob_fn=target.log_prob,
      initial_covariance=dtype(0.001) * np.eye(3, dtype=dtype))
  samples = tfp.mcmc.sample_chain(
      num_results=2000,
      current_state=dtype([0.1, 0.1, 0.1]),
      kernel=kernel,
      num_burnin_steps=500,
      trace_fn=None)

  sample_mean = tf.math.reduce_mean(samples, axis=0)
  [sample_mean_] = self.evaluate([sample_mean])
  self.assertAllClose(sample_mean_, true_mean, atol=0.1, rtol=0.1)

  sample_cov = tfp.stats.covariance(samples)
  sample_cov_ = self.evaluate(sample_cov)
  self.assertAllClose(sample_cov_, true_cov, atol=0.1, rtol=0.1)
def testFourDimNormal(self):
  """Sampling from a 4-D Multivariate Normal distribution."""
  dtype = np.float32
  true_mean = dtype([0, 4, -8, 2])
  true_cov = np.eye(4, dtype=dtype)
  num_results, tolerance = self._get_mode_dependent_settings()
  num_chains = 10
  # The identity matrix is its own Cholesky factor, so `true_cov` can be
  # passed directly as `scale_tril`.
  target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=true_cov)

  # Initial state of the chain
  init_state = np.ones([num_chains, 4], dtype=dtype)

  # Run Slice Sampler for `num_results` iterations for `num_chains`
  # independent chains:
  states, _ = tfp.mcmc.sample_chain(
      num_results=num_results,
      current_state=init_state,
      kernel=tfp.mcmc.SliceSampler(
          target_log_prob_fn=tf.function(target.log_prob, autograph=False),
          step_size=1.0,
          max_doublings=5,
          seed=test_util.test_seed_stream()),
      num_burnin_steps=100,
      parallel_iterations=1)

  result = tf.reshape(states, [-1, 4])
  sample_mean = tf.reduce_mean(result, axis=0)
  sample_cov = tfp.stats.covariance(result)

  self.assertAllClose(true_mean, sample_mean, atol=tolerance, rtol=tolerance)
  self.assertAllClose(true_cov, sample_cov, atol=tolerance, rtol=tolerance)
def testKLNonBatch(self):
  batch_shape = []
  event_shape = [2]
  mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
  mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
  mvn_a = tfd.MultivariateNormalTriL(
      loc=mu_a,
      scale_tril=np.linalg.cholesky(sigma_a),
      validate_args=True)
  mvn_b = tfd.MultivariateNormalTriL(
      loc=mu_b,
      scale_tril=np.linalg.cholesky(sigma_b),
      validate_args=True)

  kl = tfd.kl_divergence(mvn_a, mvn_b)
  self.assertEqual(batch_shape, kl.shape)

  kl_v = self.evaluate(kl)
  expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
  self.assertAllClose(expected_kl, kl_v)
def _fn(kernel_size, bias_size, dtype=None):
  smallconst = np.log(np.expm1(1.))

  n_weights_block = kernel_size // C
  n_bias_block = bias_size // C
  n_weight_mean_params = n_weights_block
  n_weight_cov_params = (
      tfp.layers.MultivariateNormalTriL.params_size(n_weights_block)
      - n_weights_block)
  n_params_total = C * (n_weight_mean_params + n_weight_cov_params) + bias_size
  # print("{} params in total".format(n_params_total))

  block_param_indices = tf.split(np.arange(n_params_total - bias_size), C)
  split_array = [n_weight_mean_params, n_weight_cov_params]
  split_param_idxs = [
      tf.split(x, split_array, axis=0) for x in block_param_indices]

  model = tf.keras.Sequential([
      tfpl.VariableLayer(n_params_total, dtype=dtype),
      tfpl.DistributionLambda(lambda t: tfd.Blockwise(
          [
              tfd.MultivariateNormalTriL(
                  loc=tf.gather(t, split_param_idxs[c][0], axis=-1),
                  scale_tril=tfp.math.fill_triangular(
                      1e-5 + tf.nn.softplus(
                          smallconst
                          + tf.gather(t, split_param_idxs[c][1], axis=-1))))
              for c in range(C)
          ] + [
              tfd.VectorDeterministic(loc=t[..., -bias_size:])
          ]))
  ])
  return model
def testTwoDimNormalDynamicShape(self):
  """Checks that dynamic batch shapes for the initial state are supported."""
  if tf.executing_eagerly():
    return
  dtype = np.float32
  true_mean = dtype([0, 0])
  true_cov = dtype([[1, 0.5], [0.5, 1]])
  num_results = 200
  num_chains = 75

  # Target distribution is defined through the Cholesky decomposition.
  chol = tf.linalg.cholesky(true_cov)
  target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

  # Assume that the state is passed as a list of 1-d tensors `x` and `y`.
  # Then the target log-density is defined as follows:
  def target_log_prob(x, y):
    # Stack the input tensors together
    z = tf.stack([x, y], axis=-1) - true_mean
    return target.log_prob(z)

  # Initial state of the chain
  init_state = [
      np.ones([num_chains, 1], dtype=dtype),
      np.ones([num_chains, 1], dtype=dtype)
  ]
  placeholder_init_state = [
      tf1.placeholder_with_default(init_state[0], shape=[None, 1]),
      tf1.placeholder_with_default(init_state[1], shape=[None, 1])
  ]

  # Run Slice Sampler for `num_results` iterations for `num_chains`
  # independent chains:
  [x, y], _ = tfp.mcmc.sample_chain(
      num_results=num_results,
      current_state=placeholder_init_state,
      kernel=tfp.mcmc.SliceSampler(
          target_log_prob_fn=target_log_prob,
          step_size=1.0,
          max_doublings=5,
          seed=47),
      num_burnin_steps=200,
      num_steps_between_results=1,
      parallel_iterations=1)

  states = tf.stack([x, y], axis=-1)
  sample_mean = tf.reduce_mean(states, axis=[0, 1])
  z = states - sample_mean
  sample_cov = tf.reduce_mean(
      tf.matmul(z, z, transpose_a=True), axis=[0, 1])
  [sample_mean, sample_cov] = self.evaluate([sample_mean, sample_cov])

  self.assertAllClose(true_mean, np.squeeze(sample_mean), atol=0.1, rtol=0.1)
  self.assertAllClose(true_cov, np.squeeze(sample_cov), atol=0.1, rtol=0.1)
def testVariableLocation(self):
  loc = tf.Variable([1., 1.])
  scale = tf.eye(2)
  d = tfd.MultivariateNormalTriL(loc, scale, validate_args=True)
  self.evaluate(loc.initializer)
  with tf.GradientTape() as tape:
    lp = d.log_prob([0., 0.])
  self.assertIsNotNone(tape.gradient(lp, loc))
def testVariableScale(self):
  loc = tf.constant([1., 1.])
  scale = tf.Variable([[1., 0.], [0., 1.]])
  d = tfd.MultivariateNormalTriL(loc, scale, validate_args=True)
  self.evaluate(scale.initializer)
  with tf.GradientTape() as tape:
    lp = d.log_prob([0., 0.])
  self.assertIsNotNone(tape.gradient(lp, scale))
def new(params, event_size, validate_args=False, name=None):
  """Create the distribution instance from a `params` vector."""
  with tf.name_scope(name, 'MultivariateNormalTriL', [params, event_size]):
    return tfd.MultivariateNormalTriL(
        loc=params[..., :event_size],
        scale_tril=tfb.ScaleTriL(validate_args=validate_args)(
            params[..., event_size:]),
        validate_args=validate_args)
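# Usage sketch (illustrative, not part of the library code): for
# `event_size = d`, the flat `params` vector needs `d` location entries plus
# `d * (d + 1) / 2` unconstrained entries that `ScaleTriL` maps to a valid
# lower-triangular scale. This matches
# `tfp.layers.MultivariateNormalTriL.params_size(d)`:
#
#   event_size = 3
#   params_size = event_size + event_size * (event_size + 1) // 2  # 3 + 6 = 9
#   dist = new(tf.zeros([params_size]), event_size)
#   print(dist.event_shape)  # (3,)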
def _mvn_pair(self, loc, cov_operator):
  """Construct a pair of MVNs."""
  tril = tfd.MultivariateNormalTriL(
      loc=loc, scale_tril=cov_operator.cholesky().to_dense())
  low_rank_update = (
      MultivariateNormalLowRankUpdateLinearOperatorCovariance(
          loc=loc, cov_operator=cov_operator, validate_args=True))
  return MVNPair(low_rank_update=low_rank_update, tril=tril)
def solve_weight_space() -> tfd.Distribution:
  d_by_d = tf.matmul(x, x, transpose_a=True)
  precis = tf.linalg.diag(1 / weight_var) + noise_precis * d_by_d
  sqprec = jitter_cholesky(precis)
  means = noise_precis * parallel_solve(
      tf.linalg.cholesky_solve, sqprec, proj)
  scale_tril = CholeskyToInvCholesky()(sqprec)  # [!] improve me
  return tfd.MultivariateNormalTriL(loc=tf.squeeze(means, -1),
                                    scale_tril=scale_tril)
def ensemble_kalman_filter_log_marginal_likelihood(state,
                                                   observation,
                                                   observation_fn,
                                                   seed=None,
                                                   name=None):
  """Ensemble Kalman Filter Log Marginal Likelihood.

  The [Ensemble Kalman Filter](
  https://en.wikipedia.org/wiki/Ensemble_Kalman_filter) is a Monte Carlo
  version of the traditional Kalman Filter.

  This method estimates the (logarithm of the) marginal likelihood of the
  observation at step `k`, `Y_k`, given observations from steps `1` to
  `k - 1`, `Y_{1:k-1}`. In other words, `Log[p(Y_k | Y_{1:k-1})]`. This
  function's approximation to `p(Y_k | Y_{1:k-1})` is correct under a linear
  Gaussian state space model assumption, as ensemble size --> infinity.

  Args:
    state: Instance of `EnsembleKalmanFilterState` at step `k`, conditioned
      on previous observations `Y_{1:k-1}`. Typically this is the output of
      `ensemble_kalman_filter_predict`.
    observation: `Tensor` representing the observation at step `k`.
    observation_fn: callable returning an instance of
      `tfd.MultivariateNormalLinearOperator` along with extra information
      to be returned in the `EnsembleKalmanFilterState`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name for ops created by this method.
      Default value: `None`
      (i.e., `'ensemble_kalman_filter_log_marginal_likelihood'`).

  Returns:
    log_marginal_likelihood: `Tensor` with same dtype as `state`.
  """
  with tf.name_scope(name or 'ensemble_kalman_filter_log_marginal_likelihood'):
    observation_particles_dist, unused_extra = observation_fn(
        state.step, state.particles, state.extra)

    common_dtype = dtype_util.common_dtype(
        [observation_particles_dist, observation], dtype_hint=tf.float32)
    observation = tf.convert_to_tensor(observation, dtype=common_dtype)

    if not isinstance(observation_particles_dist,
                      distributions.MultivariateNormalLinearOperator):
      raise ValueError('Expected `observation_fn` to return an instance of '
                       '`MultivariateNormalLinearOperator`')

    observation_particles = observation_particles_dist.sample(seed=seed)
    observation_dist = distributions.MultivariateNormalTriL(
        loc=tf.reduce_mean(observation_particles, axis=0),
        scale_tril=tf.linalg.cholesky(_covariance(observation_particles)))

    return observation_dist.log_prob(observation)
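# `_covariance` is a module-level helper not shown here; presumably it
# computes the ensemble sample covariance across the leading (particle) axis.
# A minimal sketch of such a helper; the normalization (`n` vs. `n - 1`) is
# an assumption, not confirmed by this snippet:
#
#   def _covariance_sketch(x):
#     """Sample covariance of `x` over axis 0 (ensemble members)."""
#     dx = x - tf.reduce_mean(x, axis=0)
#     n = tf.cast(tf.shape(x)[0], x.dtype)
#     # [..., d, d] batch of outer products, averaged over the ensemble axis.
#     return tf.einsum('n...i,n...j->...ij', dx, dx) / (n - 1.)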
def testGradientWorksForMultivariateNormalTriL(self):
  # TODO(b/72831017): Remove this once bijector caching is fixed for
  # graph mode.
  if not tf.executing_eagerly():
    self.skipTest('Gradients get None values in graph mode.')
  d = tfd.MultivariateNormalTriL(scale_tril=tf.eye(2))
  x = d.sample(seed=test_util.test_seed())
  fn_result, grads = util.maybe_call_fn_and_grads(d.log_prob, x)
  self.assertAllEqual(False, fn_result is None)
  self.assertAllEqual([False], [g is None for g in grads])
def testEntropy(self):
  mu = self._rng.rand(2)
  chol, sigma = self._random_chol(2, 2)
  mvn = tfd.MultivariateNormalTriL(mu, chol, validate_args=True)
  entropy = mvn.entropy()

  scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
  expected_entropy = scipy_mvn.entropy()
  self.assertEqual(entropy.shape, ())
  self.assertAllClose(expected_entropy, self.evaluate(entropy))
def testSample(self):
  mu = self._rng.rand(2)
  chol, sigma = self._random_chol(2, 2)
  n = tf.constant(100000)
  mvn = tfd.MultivariateNormalTriL(mu, chol, validate_args=True)
  samples = mvn.sample(n, seed=test_util.test_seed())
  sample_values = self.evaluate(samples)

  self.assertEqual(samples.shape, [int(100e3), 2])
  self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2)
  self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=0.06)
def testLangevin3DNormal(self):
  """Sampling from a 3-D Multivariate Normal distribution."""
  dtype = np.float32
  true_mean = dtype([1, 2, 7])
  true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
  num_results = 500
  num_chains = 500

  # Target distribution is defined through the Cholesky decomposition
  chol = tf.linalg.cholesky(true_cov)
  target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

  # Assume that the state is passed as a list of tensors `x` and `y`.
  # Then the target log-density is defined as follows:
  def target_log_prob(x, y):
    # Stack the input tensors together
    z = tf.concat([x, y], axis=-1)
    return target.log_prob(z)

  # Initial state of the chain
  init_state = [
      np.ones([num_chains, 2], dtype=dtype),
      np.ones([num_chains, 1], dtype=dtype)
  ]

  # Run MALA with normal proposal for `num_results` iterations for
  # `num_chains` independent chains:
  states, _ = tfp.mcmc.sample_chain(
      num_results=num_results,
      current_state=init_state,
      kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
          target_log_prob_fn=target_log_prob,
          step_size=.1,
          seed=test_util.test_seed()),
      num_burnin_steps=200,
      num_steps_between_results=1,
      parallel_iterations=1)

  states = tf.concat(states, axis=-1)
  sample_mean = tf.reduce_mean(states, axis=[0, 1])
  x = (states - sample_mean)[..., tf.newaxis]
  sample_cov = tf.reduce_mean(
      tf.matmul(x, tf.transpose(x, perm=[0, 1, 3, 2])), axis=[0, 1])

  sample_mean_, sample_cov_ = self.evaluate([sample_mean, sample_cov])

  self.assertAllClose(np.squeeze(sample_mean_), true_mean,
                      atol=0.1, rtol=0.1)
  self.assertAllClose(np.squeeze(sample_cov_), true_cov,
                      atol=0.1, rtol=0.1)
def testDocstrSliceExample(self):
  x = tf.random.normal([5, 3, 2, 2])
  cov = tf.matmul(x, x, transpose_b=True)
  chol = tf.linalg.cholesky(cov)
  loc = tf.random.normal([4, 1, 3, 1])
  mvn = tfd.MultivariateNormalTriL(loc, chol)
  self.assertAllEqual((4, 5, 3), mvn.batch_shape)
  self.assertAllEqual((2,), mvn.event_shape)

  mvn2 = mvn[:, 3:, ..., ::-1, tf.newaxis]
  self.assertAllEqual((4, 2, 3, 1), mvn2.batch_shape)
  self.assertAllEqual((2,), mvn2.event_shape)
def testEntropyMultidimensional(self):
  mu = self._rng.rand(3, 5, 2)
  chol, sigma = self._random_chol(3, 5, 2, 2)
  mvn = tfd.MultivariateNormalTriL(mu, chol, validate_args=True)
  entropy = mvn.entropy()

  # Scipy doesn't do batches, so test one of them.
  expected_entropy = stats.multivariate_normal(
      mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy()
  self.assertEqual(entropy.shape, (3, 5))
  self.assertAllClose(expected_entropy, self.evaluate(entropy)[1, 1])
def testJacobianDiagonal3DListInput(self):
  """Tests that the diagonal of the Jacobian matrix computes correctly."""
  dtype = np.float32
  true_mean = dtype([0, 0, 0])
  true_cov = dtype([[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]])
  chol = tf.linalg.cholesky(true_cov)
  target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

  # Assume that the state is passed as a list of tensors `x` and `y`.
  # Then the target function is defined as follows:
  def target_fn(x, y):
    # Stack the input tensors together
    z = tf.concat([x, y], axis=-1) - true_mean
    return target.log_prob(z)

  sample_shape = [3, 5]
  state = [
      tf.ones(sample_shape + [2], dtype=dtype),
      tf.ones(sample_shape + [1], dtype=dtype)
  ]
  fn_val, grads = tfp.math.value_and_gradient(target_fn, state)
  grad_fn = lambda *args: tfp.math.value_and_gradient(target_fn, args)[1]

  _, diag_jacobian_shape_passed = tfp.math.diag_jacobian(
      xs=state, ys=grads, fn=grad_fn, sample_shape=ps.shape(fn_val))
  _, diag_jacobian_shape_none = tfp.math.diag_jacobian(
      xs=state, ys=grads, fn=grad_fn)

  true_diag_jacobian_1 = np.zeros(sample_shape + [2])
  true_diag_jacobian_1[..., 0] = -1.05
  true_diag_jacobian_1[..., 1] = -0.52
  true_diag_jacobian_2 = -0.34 * np.ones(sample_shape + [1])

  self.assertAllClose(self.evaluate(diag_jacobian_shape_passed[0]),
                      true_diag_jacobian_1, atol=0.01, rtol=0.01)
  self.assertAllClose(self.evaluate(diag_jacobian_shape_none[0]),
                      true_diag_jacobian_1, atol=0.01, rtol=0.01)
  self.assertAllClose(self.evaluate(diag_jacobian_shape_passed[1]),
                      true_diag_jacobian_2, atol=0.01, rtol=0.01)
  self.assertAllClose(self.evaluate(diag_jacobian_shape_none[1]),
                      true_diag_jacobian_2, atol=0.01, rtol=0.01)
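# Where the reference constants come from: the MVN log-density has gradient
# -Sigma^{-1} (z - mu), so the Jacobian of the gradient is the constant matrix
# -Sigma^{-1}, and the test checks its diagonal. A quick numpy confirmation:
#
#   import numpy as np
#
#   true_cov = np.array([[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]])
#   print(np.diag(-np.linalg.inv(true_cov)))
#   # => approximately [-1.05, -0.52, -0.34], matching the expected values.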