def test_multiple_chains(self, make_kernel, target_accept_rate):
    num_chains = 16
    num_samples = 4000
    sample_key, chain_key, init_key = random.split(self._seed, 3)
    unconstrained_log_prob = self._make_unconstrained_log_prob()
    initial_states = jax.vmap(self._initialize_state)(
        random.split(init_key, num_chains))
    kernel = make_kernel(unconstrained_log_prob)
    sample_chain = jax.jit(
        jax.vmap(
            harvest.harvest(kernels.sample_chain(kernel, num_samples),
                            tag=kernels.MCMC_METRICS)))
    true_samples = self.model.sample(sample_shape=4096, seed=sample_key)
    samples, metrics = sample_chain({}, random.split(chain_key, num_chains),
                                    initial_states)
    samples = tf.nest.map_structure(
        lambda s, shape: s.reshape([num_chains * num_samples] + list(shape)),
        samples, self.model.event_shape)
    onp.testing.assert_allclose(true_samples.mean(axis=0),
                                samples.mean(axis=0),
                                rtol=0.1, atol=0.1)
    onp.testing.assert_allclose(np.cov(true_samples.T),
                                np.cov(samples.T),
                                rtol=0.1, atol=0.1)
    onp.testing.assert_allclose(target_accept_rate,
                                metrics['kernel']['accept_prob'].mean(),
                                atol=1e-2, rtol=1e-2)

def test_get_proposal_loc_and_scale(dense_mass):
    N = 10
    dim = 3
    samples = random.normal(random.PRNGKey(0), (N, dim))
    loc = np.mean(samples[:-1], 0)
    if dense_mass:
        scale = np.linalg.cholesky(
            np.cov(samples[:-1], rowvar=False, bias=True))
    else:
        scale = np.std(samples[:-1], 0)
    actual_loc, actual_scale = _get_proposal_loc_and_scale(
        samples[:-1], loc, scale, samples[-1])
    expected_loc, expected_scale = [], []
    for i in range(N - 1):
        samples_i = onp.delete(samples, i, axis=0)
        expected_loc.append(np.mean(samples_i, 0))
        if dense_mass:
            expected_scale.append(
                np.linalg.cholesky(np.cov(samples_i, rowvar=False, bias=True)))
        else:
            expected_scale.append(np.std(samples_i, 0))
    expected_loc = np.stack(expected_loc)
    expected_scale = np.stack(expected_scale)
    assert_allclose(actual_loc, expected_loc, rtol=1e-4)
    assert_allclose(actual_scale, expected_scale, atol=1e-6, rtol=0.05)

def test_single_chain(self, make_kernel, target_accept_rate):
    num_samples = 20000
    sample_key, chain_key, init_key = random.split(self._seed, 3)
    unconstrained_log_prob = self._make_unconstrained_log_prob()
    initial_state = self._initialize_state(init_key)
    kernel = make_kernel(unconstrained_log_prob)
    sample_chain = jax.jit(
        harvest.harvest(kernels.sample_chain(kernel, num_samples),
                        tag=kernels.MCMC_METRICS))
    true_samples = self.model.sample(sample_shape=4096, seed=sample_key)
    samples, metrics = sample_chain({}, chain_key, initial_state)
    onp.testing.assert_allclose(true_samples.mean(axis=0),
                                samples.mean(axis=0),
                                rtol=0.5, atol=0.1)
    onp.testing.assert_allclose(np.cov(true_samples.T),
                                np.cov(samples.T),
                                rtol=0.5, atol=0.1)
    onp.testing.assert_allclose(target_accept_rate,
                                metrics['kernel']['accept_prob'].mean(),
                                atol=1e-2, rtol=1e-2)

def statistics(net_params: List[jnp.ndarray], deq_params: List[jnp.ndarray],
               rng: random.PRNGKey):
    # Split pseudo-random number key.
    rng, rng_sample, rng_xobs, rng_kl = random.split(rng, 4)
    # Compute comparison statistics.
    _, xsph, _ = ode_forward(rng_sample, net_params, 10000, 4)
    xobs = rejection_sampling(rng_xobs, len(xsph), 4, embedded_sphere_density)
    mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
    # KL(q || p) and ESS from importance weights of the target under the
    # approximation.
    approx = importance_density(rng_kl, net_params, deq_params, 10000, xsph)
    target = embedded_sphere_density(xsph)
    w = target / approx
    Z = jnp.nanmean(w)
    log_approx = jnp.log(approx)
    log_target = jnp.log(target)
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100 * ess / len(w)
    del w, Z, log_approx, approx, log_target, target, xsph
    # KL(p || q) from importance weights of the approximation under the target.
    approx = importance_density(rng_kl, net_params, deq_params, 10000, xobs)
    log_approx = jnp.log(approx)
    target = embedded_sphere_density(xobs)
    w = approx / target
    Z = jnp.nanmean(w)
    log_target = jnp.log(target)
    klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)
    del w, Z, log_approx, approx, log_target, target
    method = 'deqode ({})'.format('ELBO' if args.elbo_loss else 'KL')
    print(
        r'{} - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
        .format(method, mean_mse, cov_mse, klqp, klpq, ress))

def main():
    # Set pseudo-random number generator keys.
    rng = random.PRNGKey(args.seed)
    rng, rng_net = random.split(rng, 2)
    rng, rng_sample, rng_xobs, rng_basis = random.split(rng, 4)
    rng, rng_fwd, rng_rev = random.split(rng, 3)
    rng, rng_kl = random.split(rng, 2)
    # Initialize the parameters of the ambient vector field network.
    _, params = net_init(rng_net, (-1, 4))
    opt_state = opt_init(params)
    for it in range(args.num_steps):
        opt_state, kl = step(opt_state, it, args.num_samples)
        print('iter.: {} - kl: {:.4f}'.format(it, kl))
    params = get_params(opt_state)
    count = lambda x: jnp.prod(jnp.array(x.shape))
    num_params = jnp.array(
        tree_util.tree_map(count, tree_util.tree_flatten(params)[0])).sum()
    print('number of parameters: {}'.format(num_params))
    # Compute comparison statistics.
    xsph, log_approx = manifold_ode_log_prob(params, rng_sample, 10000)
    xobs = rejection_sampling(rng_xobs, len(xsph), 3, embedded_sphere_density)
    mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
    # KL(q || p) and ESS from importance weights on model samples.
    approx = jnp.exp(log_approx)
    target = embedded_sphere_density(xsph)
    w = target / approx
    Z = jnp.nanmean(w)
    log_approx = jnp.log(approx)
    log_target = jnp.log(target)
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100 * ess / len(w)
    del w, Z, log_approx, approx, log_target, target
    # KL(p || q) evaluated on samples from the target density.
    log_approx = manifold_reverse_ode_log_prob(params, rng_kl, xobs)
    approx = jnp.exp(log_approx)
    target = embedded_sphere_density(xobs)
    w = approx / target
    Z = jnp.nanmean(w)
    log_target = jnp.log(target)
    klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)
    del w, Z, log_approx, approx, log_target, target
    print(
        r'manode - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
        .format(mean_mse, cov_mse, klqp, klpq, ress))

def _test_cov(self, sample: cdict):
    if sample.value.ndim == 3:
        val = jnp.concatenate(sample.value)
    else:
        val = sample.value
    samp_cov = jnp.cov(val.T)
    npt.assert_array_almost_equal(samp_cov, self.scenario_cov, decimal=0.5)

def cylinder_to_gaussian_sample(key, raydir, t0, t1, radius, padding=1,
                                num_samples=1000000):
    # Sample uniformly from a cube that surrounds the entire conical frustum.
    z_max = max(t0, t1)
    samples = random.uniform(key, [num_samples, 3],
                             minval=jnp.min(raydir) * z_max - padding,
                             maxval=jnp.max(raydir) * z_max + padding)
    # Grab only the points within the cylinder.
    raydir_magsq = jnp.sum(raydir**2, -1, keepdims=True)
    proj = (raydir * (samples @ raydir)[:, None]) / raydir_magsq
    dist = samples @ raydir
    mask = (dist >= raydir_magsq * t0) & (dist <= raydir_magsq * t1) & (
        jnp.sum((proj - samples)**2, -1) < radius**2)
    samples = samples[mask, :]
    # Compute their mean and covariance.
    mean = jnp.mean(samples, 0)
    cov = jnp.cov(samples.T, bias=False)
    return mean, cov

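# --- Added usage sketch (not from the original source; the inputs are
# hypothetical): estimating the Gaussian moments of a thin cylinder along a
# unit ray with the Monte Carlo routine above.
import jax.numpy as jnp
from jax import random

raydir = jnp.array([0.0, 0.0, 1.0])  # unit ray along z
mc_mean, mc_cov = cylinder_to_gaussian_sample(random.PRNGKey(0), raydir,
                                              t0=1.0, t1=2.0, radius=0.1)
# mc_mean should lie near 1.5 * raydir; mc_cov is the 3x3 sample covariance.
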
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
        aweights=None):
    # Unwrap JaxArray containers before delegating to jnp.cov.
    if isinstance(m, JaxArray):
        m = m.value
    if isinstance(y, JaxArray):
        y = y.value
    if isinstance(fweights, JaxArray):
        fweights = fweights.value
    if isinstance(aweights, JaxArray):
        aweights = aweights.value
    return JaxArray(jnp.cov(m, y=y, rowvar=rowvar, bias=bias, ddof=ddof,
                            fweights=fweights, aweights=aweights))

def estimate_cov(self, samples: jnp.ndarray):
    """Estimates the covariance matrix using shrinkage."""
    n, d = samples.shape
    rho = self.rho_fn(samples)
    shrinkage_rho = 1 / (1 + (n - 1) * rho)
    sample_cov = jnp.cov(samples, rowvar=False)
    return shrinkage_rho * jnp.eye(d) + (1 - shrinkage_rho) * sample_cov

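# --- Added standalone sketch (not from the original source) of the shrinkage
# rule in estimate_cov above, assuming a fixed scalar rho in place of
# self.rho_fn: the estimate interpolates between the identity matrix and the
# unbiased sample covariance.
import jax.numpy as jnp
from jax import random

samples = random.normal(random.PRNGKey(0), (50, 3))
n, d = samples.shape
rho = 0.1  # hypothetical output of rho_fn(samples)
shrinkage_rho = 1 / (1 + (n - 1) * rho)  # ~0.17 here, so mostly sample cov
cov_est = (shrinkage_rho * jnp.eye(d)
           + (1 - shrinkage_rho) * jnp.cov(samples, rowvar=False))
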
def sample_kernel(sa_state, model_args=(), model_kwargs=None):
    pe_fn = potential_fn
    if potential_fn_gen:
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
    zs, pes, loc, scale = sa_state.adapt_state
    # we recompute loc/scale after each iteration to avoid precision loss
    # XXX: consider to expose a setting to do this job periodically
    # to save some computations
    loc = jnp.mean(zs, 0)
    if scale.ndim == 2:
        cov = jnp.cov(zs, rowvar=False, bias=True)
        if cov.shape == ():  # JAX returns scalar for 1D input
            cov = cov.reshape((1, 1))
        cholesky = jnp.linalg.cholesky(cov)
        scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
    else:
        scale = jnp.std(zs, 0)

    rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(
        sa_state.rng_key, 4)
    _, unravel_fn = ravel_pytree(sa_state.z)
    z = loc + _sample_proposal(scale, rng_key_z)
    pe = pe_fn(unravel_fn(z))
    pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
    diverging = (pe - sa_state.potential_energy) > max_delta_energy

    # NB: all terms having the pattern *s will have shape N x ...
    # and all terms having the pattern *s_ will have shape (N + 1) x ...
    locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
    zs_ = jnp.concatenate([zs, z[None, :]])
    pes_ = jnp.concatenate([pes, pe[None]])
    locs_ = jnp.concatenate([locs, loc[None, :]])
    scales_ = jnp.concatenate([scales, scale[None, ...]])
    if scale.ndim == 2:  # dense_mass
        log_weights_ = dist.MultivariateNormal(
            locs_, scale_tril=scales_).log_prob(zs_) + pes_
    else:
        log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
    # mask invalid values (nan, +inf) by -inf
    log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
    # get rejecting index
    j = random.categorical(rng_key_reject, log_weights_)
    zs = _numpy_delete(zs_, j)
    pes = _numpy_delete(pes_, j)
    loc = locs_[j]
    scale = scales_[j]
    adapt_state = SAAdaptState(zs, pes, loc, scale)

    # NB: weights[-1] / sum(weights) is the probability of rejecting the
    # new sample `z`.
    accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
    itr = sa_state.i + 1
    n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
    mean_accept_prob = sa_state.mean_accept_prob + (
        accept_prob - sa_state.mean_accept_prob) / n

    # XXX: we make a modification of SA sampler in [1]
    # in [1], each MCMC state contains N points `zs`
    # here we do resampling to pick randomly a point from those N points
    k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
    z = unravel_fn(zs[k])
    pe = pes[k]
    return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging,
                   adapt_state, rng_key)

def parametric(subposteriors, diagonal=False):
    """
    Merges subposteriors following the (embarrassingly parallel) parametric
    Monte Carlo algorithm.

    **References:**

    1. *Asymptotically Exact, Embarrassingly Parallel MCMC*,
       Willie Neiswanger, Chong Wang, Eric Xing

    :param list subposteriors: a list in which each element is a collection
        of samples.
    :param bool diagonal: whether to compute weights using variance or
        covariance, defaults to `False` (using covariance).
    :return: the estimated mean and variance/covariance parameters of the
        joined posterior
    """
    joined_subposteriors = tree_multimap(lambda *args: np.stack(args),
                                         *subposteriors)
    joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(
        joined_subposteriors)

    submeans = np.mean(joined_subposteriors, axis=1)
    if diagonal:
        # NB: ddof=1 gives the unbiased per-dimension variance estimate
        weights = vmap(lambda x: 1 / np.var(x, ddof=1, axis=0))(
            joined_subposteriors)
        var = 1 / np.sum(weights, axis=0)
        normalized_weights = var * weights

        # comparing to consensus implementation, we compute weighted mean here
        mean = np.einsum('ij,ij->j', normalized_weights, submeans)
        return mean, var
    else:
        weights = vmap(lambda x: np.linalg.inv(np.cov(x.T)))(
            joined_subposteriors)
        cov = np.linalg.inv(np.sum(weights, axis=0))
        normalized_weights = np.matmul(cov, weights)

        # comparing to consensus implementation, we compute weighted mean here
        mean = np.einsum('ijk,ik->j', normalized_weights, submeans)
        return mean, cov

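# --- Added usage sketch (not from the original source; assumes each
# subposterior is a pytree of sample arrays sharing a leading sample axis, as
# the ravel_pytree/vmap calls in parametric() above require). Merging four
# synthetic Gaussian subposteriors:
from jax import random

keys = random.split(random.PRNGKey(0), 4)
subposteriors = [{'theta': random.normal(k, (1000, 3))} for k in keys]
mean, cov = parametric(subposteriors)                  # product-of-Gaussians merge
mean_diag, var_diag = parametric(subposteriors, diagonal=True)
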
def compute_cov(
    x0_key: PrngKey,
    dyn_sys: DynamicalSystem,
    num_samples: int,
    num_warmup_steps: int,
) -> Array:
    """
    Computes the covariance matrix for states sampled from a dynamical system
    after an initial warmup integration. Yields the unbiased covariance
    estimate, i.e., with `(N - 1)` normalization.

    Args:
        x0_key: random number key.
        dyn_sys: dynamical system.
        num_samples: number of independent states to sample from the
            dynamical system.
        num_warmup_steps: number of warmup steps before sampling.

    Returns:
        Spatial covariance matrix.
    """
    X = generate_data(
        x0_key,
        dyn_sys,
        num_samples,
        num_warmup_steps,
    )
    grid_size = dyn_sys.grid_size
    num_vars = reduce(mul, dyn_sys.state_dim)
    C = jnp.cov(X.reshape(-1, num_vars), rowvar=False)
    return C

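# --- Added sanity check (not from the original source) for the docstring's
# normalization claim: jnp.cov with rowvar=False uses the unbiased (N - 1)
# divisor by default.
import jax.numpy as jnp

X = jnp.arange(12.0).reshape(4, 3)  # 4 states, 3 spatial variables
Xc = X - X.mean(axis=0)
assert jnp.allclose(jnp.cov(X, rowvar=False), Xc.T @ Xc / (X.shape[0] - 1))
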
def test_mvn(shape=(1000, 5)):
    key = jr.PRNGKey(time.time_ns())
    data = 5 * jr.normal(key, shape=shape)
    mvn = dists.MultivariateNormalFullCovariance.fit(data)
    assert np.allclose(data.mean(axis=0), mvn.loc, atol=1e-6)
    assert np.allclose(np.cov(data, rowvar=False, bias=True),
                       mvn.covariance(), atol=1e-6)

def consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None):
    """
    Merges subposteriors following the consensus Monte Carlo algorithm.

    **References:**

    1. *Bayes and big data: The consensus Monte Carlo algorithm*,
       Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi,
       Hugh A. Chipman, Edward I. George, Robert E. McCulloch

    :param list subposteriors: a list in which each element is a collection
        of samples.
    :param int num_draws: number of draws from the merged posterior.
    :param bool diagonal: whether to compute weights using variance or
        covariance, defaults to `False` (using covariance).
    :param jax.random.PRNGKey rng_key: source of the randomness,
        defaults to `jax.random.PRNGKey(0)`.
    :return: if `num_draws` is None, merges subposteriors without resampling;
        otherwise, returns a collection of `num_draws` samples with the same
        data structure as each subposterior.
    """
    # stack subposteriors
    joined_subposteriors = tree_multimap(lambda *args: jnp.stack(args),
                                         *subposteriors)
    # shape of joined_subposteriors: n_subs x n_samples x sample_shape
    joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(
        joined_subposteriors
    )

    if num_draws is not None:
        rng_key = random.PRNGKey(0) if rng_key is None else rng_key
        # randomly gets num_draws from subposteriors
        n_subs = len(subposteriors)
        n_samples = tree_flatten(subposteriors[0])[0][0].shape[0]
        # shape of draw_idxs: n_subs x num_draws x sample_shape
        draw_idxs = random.randint(
            rng_key, shape=(n_subs, num_draws), minval=0, maxval=n_samples
        )
        joined_subposteriors = vmap(lambda x, idx: x[idx])(
            joined_subposteriors, draw_idxs
        )

    if diagonal:
        # compute weights for each subposterior (ref: Section 3.1 of [1])
        weights = vmap(lambda x: 1 / jnp.var(x, ddof=1, axis=0))(
            joined_subposteriors)
        normalized_weights = weights / jnp.sum(weights, axis=0)
        # get weighted samples
        samples_flat = jnp.einsum(
            "ij,ikj->kj", normalized_weights, joined_subposteriors
        )
    else:
        weights = vmap(lambda x: jnp.linalg.inv(jnp.cov(x.T)))(
            joined_subposteriors)
        normalized_weights = jnp.matmul(
            jnp.linalg.inv(jnp.sum(weights, axis=0)), weights
        )
        samples_flat = jnp.einsum(
            "ijk,ilk->lj", normalized_weights, joined_subposteriors
        )

    # unravel_fn acts on 1 sample of a subposterior
    _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0]))
    return vmap(lambda x: unravel_fn(x))(samples_flat)

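# --- Added usage sketch (same synthetic-subposterior assumption as the
# parametric example above; not from the original source). With num_draws
# set, consensus() resamples and returns draws with the pytree structure of a
# single subposterior:
from jax import random

keys = random.split(random.PRNGKey(0), 4)
subposteriors = [{'theta': random.normal(k, (1000, 3))} for k in keys]
merged = consensus(subposteriors, num_draws=500)
assert merged['theta'].shape == (500, 3)
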
def normalization_factor(data, bw):
    data_covariance = jnp.cov(data[:, jnp.newaxis], rowvar=False, bias=False)
    covariance = data_covariance * bw**2
    stdev = jnp.sqrt(covariance)
    return stdev

def init_kernel(init_params,
                num_warmup,
                adapt_state_size=None,
                inverse_mass_matrix=None,
                dense_mass=False,
                model_args=(),
                model_kwargs=None,
                rng_key=random.PRNGKey(0)):
    nonlocal wa_steps
    wa_steps = num_warmup
    pe_fn = potential_fn
    if potential_fn_gen:
        if pe_fn is not None:
            raise ValueError(
                'Only one of `potential_fn` or `potential_fn_gen` must be '
                'provided.')
        else:
            kwargs = {} if model_kwargs is None else model_kwargs
            pe_fn = potential_fn_gen(*model_args, **kwargs)
    rng_key_sa, rng_key_zs, rng_key_z = random.split(rng_key, 3)
    z = init_params
    z_flat, unravel_fn = ravel_pytree(z)
    if inverse_mass_matrix is None:
        inverse_mass_matrix = jnp.identity(
            z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-1])
    inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) \
        if dense_mass else jnp.sqrt(inverse_mass_matrix)
    if adapt_state_size is None:
        # XXX: heuristic choice
        adapt_state_size = 2 * z_flat.shape[-1]
    else:
        assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
    # NB: mean is init_params
    zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs,
                                   (adapt_state_size, ))
    # compute potential energies
    pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
    if dense_mass:
        cov = jnp.cov(zs, rowvar=False, bias=True)
        if cov.shape == ():  # JAX returns scalar for 1D input
            cov = cov.reshape((1, 1))
        cholesky = jnp.linalg.cholesky(cov)
        # if cholesky is NaN, we use the scale from `sample_proposal` here
        inv_mass_matrix_sqrt = jnp.where(jnp.any(jnp.isnan(cholesky)),
                                         inv_mass_matrix_sqrt, cholesky)
    else:
        inv_mass_matrix_sqrt = jnp.std(zs, 0)
    adapt_state = SAAdaptState(zs, pes, jnp.mean(zs, 0), inv_mass_matrix_sqrt)
    k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
    z = unravel_fn(zs[k])
    pe = pes[k]
    sa_state = SAState(jnp.array(0), z, pe, jnp.zeros(()), jnp.zeros(()),
                       jnp.array(False), adapt_state, rng_key_sa)
    return device_put(sa_state)

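# --- Added usage sketch, assuming the NumPyro context these SA helpers
# appear to come from (SAState/SAAdaptState and wa_steps match
# numpyro.infer.sa). Under that assumption, the public entry point is the SA
# kernel driven by MCMC rather than init_kernel/sample_kernel directly:
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, SA

def toy_model():
    numpyro.sample('x', dist.Normal(0., 1.))

mcmc = MCMC(SA(toy_model, dense_mass=True), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0))
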
def test_single_chain(self, make_kernel, target_accept_rate):
    num_samples = 20000
    sample_key, chain_key, init_key = random.split(self._seed, 3)
    unconstrained_log_prob = self._make_unconstrained_log_prob()
    initial_state = self._initialize_state(init_key)
    kernel = make_kernel(unconstrained_log_prob)
    sample_chain = jax.jit(kernels.sample_chain(kernel, num_samples))
    true_samples = self.model.sample(sample_shape=4096, seed=sample_key)
    samples = sample_chain(chain_key, initial_state)
    onp.testing.assert_allclose(true_samples.mean(axis=0),
                                samples.mean(axis=0),
                                rtol=0.5, atol=0.1)
    onp.testing.assert_allclose(np.cov(true_samples.T),
                                np.cov(samples.T),
                                rtol=0.5, atol=0.1)

def test_adaptive_diag_stepsize(self):
    sampler = RandomWalkABC()
    sampler.tuning.target = 0.1
    sample = run(self.scenario, sampler, n=self.n,
                 random_key=PRNGKey(0),
                 correction=RMMetropolisDiagStepsize())
    self._test_mean(sample.value)
    self._test_cov(sample.value)
    npt.assert_almost_equal(sample.alpha.mean(), sampler.tuning.target,
                            decimal=1)
    npt.assert_array_almost_equal(
        sample.stepsize[-1],
        jnp.diag(jnp.cov(sample.value, rowvar=False))
        / self.scenario.dim * 2.38 ** 2,
        decimal=1)

def get_test_locations(samples, J=10, key=random.PRNGKey(0)):
    _, dim = jnp.shape(samples)
    gauss_mean = jnp.mean(samples, axis=0)
    gauss_cov = jnp.cov(samples.T) + 1e-10 * jnp.eye(dim)
    gauss_chol = jnp.linalg.cholesky(gauss_cov)
    # gauss_chol = jnp.diag(jnp.diag(jnp.linalg.cholesky(gauss_cov)))  # diagonal version
    # NB: draw each location with its own key `k`; using the outer `key`
    # inside the lambda would make all J locations identical.
    batch_get_samples = vmap(
        lambda k: jnp.dot(gauss_chol, random.normal(k, shape=(dim,))) + gauss_mean)
    # gauss_chol = jnp.std(samples, axis=0)  # diagonal cholesky
    # batch_get_samples = vmap(lambda k: gauss_chol * random.normal(k, shape=(dim,)) + gauss_mean)
    V = batch_get_samples(random.split(key, J))
    return V

def metric(samp_arr, log_weight=None):
    if isinstance(samp_arr, mocat.cdict):
        samp_arr = samp_arr.value
    if samp_arr.ndim == 3:
        samp_arr = samp_arr[..., 0].T
    if log_weight is not None:
        samp_arr = samp_arr[random.categorical(random.PRNGKey(0), log_weight,
                                               shape=(len(samp_arr), ))]
    mean = samp_arr.mean(0)
    cov = jnp.cov(samp_arr.T, ddof=1)
    return 0.5 * (jnp.trace(prec_post @ cov)
                  + (mean_post - mean).T @ prec_post @ (mean_post - mean)
                  - len_t
                  + jnp.log(jnp.linalg.det(cov_post) / jnp.linalg.det(cov)))

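# --- Added note (not from the original source): the value returned by
# metric() is the closed-form KL divergence between the Gaussian fitted to
# the samples and the target Gaussian,
#   KL(N(mean, cov) || N(mean_post, cov_post))
#     = 0.5 * (tr(P C) + (m_p - m)^T P (m_p - m) - d + ln(det C_p / det C)),
# with P = prec_post = cov_post^{-1} and d = len_t. A zero-KL sanity check
# under hypothetical target moments:
import jax.numpy as jnp

mean_post = jnp.zeros(2)
cov_post = jnp.eye(2)
prec_post = jnp.linalg.inv(cov_post)
len_t = 2
kl_same = 0.5 * (jnp.trace(prec_post @ cov_post) - len_t
                 + jnp.log(jnp.linalg.det(cov_post) / jnp.linalg.det(cov_post)))
assert jnp.isclose(kl_same, 0.0)
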
def test_sample(self):
    np.random.seed(0)
    n, d = 1000, 5
    batch_size = 10
    num_samples = 200
    parallel_chains = 1
    obj = create_random_least_squares(
        num_objectives=1,
        batch_size=batch_size,
        n_features=(d - 1),
        n_samples=(n, ),
        lam=1e-3,
    )[0]
    opt = obj.solve()
    q_obj = Quadratic.from_least_squares(obj)
    posterior_cov = jnp.linalg.pinv(q_obj.A)
    posterior_cov /= jnp.trace(posterior_cov)

    # Approximate sampling from the posterior.
    prng_key = random.PRNGKey(0)
    sampler = IASG(
        avg_steps=100,
        burnin_steps=100,
        learning_rate=1.0,
        discard_steps=100,
    )
    prng_key, subkey = random.split(prng_key)
    init_state = random.normal(subkey, shape=(d, ))
    samples = sampler.sample(
        objective=obj,
        prng_key=prng_key,
        init_state=init_state,
        num_samples=num_samples,
        parallel_chains=parallel_chains,
    )
    self.assertEqual(samples.shape[0], num_samples)
    sample_mean = jnp.mean(samples, axis=0)
    sample_cov = jnp.cov(samples, rowvar=False)
    sample_cov /= jnp.trace(sample_cov)
    sample_cov_fro_err = jnp.linalg.norm(sample_cov - posterior_cov, "fro")
    np.testing.assert_allclose(sample_mean, opt, rtol=1e-1, atol=1e-1)
    np.testing.assert_allclose(sample_cov_fro_err, 0.0, rtol=1e-1, atol=1e-1)

def em(X, k, T, key):
    n, _ = X.shape
    # Initialize centroids using k-means++ scheme
    centroids = kmeans_pp_init(X, k, key)
    # [k, d, d]
    covs = jnp.array([jnp.cov(X, rowvar=False)] * k)
    # centroid mixture weights, [k]
    log_weights = -jnp.ones(k) * jnp.log(n)

    def update_centroids_body(unused_t, state):
        centroids, covs, log_weights, _ = state
        # E step
        # [n, k]
        log_ps = vmap(jscipy.stats.multivariate_normal.logpdf,
                      in_axes=(None, 0, 0))(X, centroids, covs)
        log_ps = log_ps.T
        # [n, 1]
        log_zs = jscipy.special.logsumexp(
            log_ps + log_weights[jnp.newaxis, :], axis=1, keepdims=True)
        # [n, k]
        log_mem_weights = log_ps + log_weights[jnp.newaxis, :] - log_zs

        # M step
        # [k]
        log_ns = jscipy.special.logsumexp(log_mem_weights, axis=0)
        # [k]
        log_weights = log_ns - jnp.log(n)
        # Compute new centroids
        # [k, d]
        centroids = jnp.sum(
            (X[:, jnp.newaxis, :] * jnp.exp(log_mem_weights)[:, :, jnp.newaxis])
            / jnp.exp(log_ns[jnp.newaxis, :, jnp.newaxis]),
            axis=0)
        # [n, k, d]
        centered_x = X[:, jnp.newaxis, :] - centroids[jnp.newaxis, :, :]
        # [n, k, d, d]
        outers = jnp.einsum('...i,...j->...ij', centered_x, centered_x)
        weighted_outers = outers * jnp.exp(
            log_mem_weights[Ellipsis, jnp.newaxis, jnp.newaxis])
        covs = jnp.sum(weighted_outers, axis=0) / jnp.exp(
            log_ns[:, jnp.newaxis, jnp.newaxis])
        return (centroids, covs, log_weights, log_mem_weights)

    out_centroids, out_covs, _, log_mem_weights = jax.lax.fori_loop(
        0, T, update_centroids_body,
        (centroids, covs, log_weights, jnp.zeros([n, k])))
    return out_centroids, out_covs, log_mem_weights

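# --- Added usage sketch (not from the original source; hypothetical data,
# and it relies on the kmeans_pp_init helper referenced but not shown above):
import jax.numpy as jnp
from jax import random

X = random.normal(random.PRNGKey(0), (500, 2))
centroids, covs, log_mem_weights = em(X, k=3, T=20, key=random.PRNGKey(1))
# centroids: [3, 2], covs: [3, 2, 2], log_mem_weights: [500, 3]
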
def __init__(self, dataset, bw_method=None, weights=None):
    _check_arraylike("gaussian_kde", dataset)
    dataset = jnp.atleast_2d(dataset)
    if jnp.issubdtype(lax.dtype(dataset), jnp.complexfloating):
        raise NotImplementedError("gaussian_kde does not support complex data")
    if not dataset.size > 1:
        raise ValueError("`dataset` input should have multiple elements.")

    d, n = dataset.shape
    if weights is not None:
        _check_arraylike("gaussian_kde", weights)
        dataset, weights = _promote_dtypes_inexact(dataset, weights)
        weights = jnp.atleast_1d(weights)
        weights /= jnp.sum(weights)
        if weights.ndim != 1:
            raise ValueError("`weights` input should be one-dimensional.")
        if len(weights) != n:
            raise ValueError("`weights` input should be of length n")
    else:
        dataset, = _promote_dtypes_inexact(dataset)
        weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)

    self._setattr("dataset", dataset)
    self._setattr("weights", weights)
    neff = self._setattr("neff", 1 / jnp.sum(weights**2))

    bw_method = "scott" if bw_method is None else bw_method
    if bw_method == "scott":
        factor = jnp.power(neff, -1. / (d + 4))
    elif bw_method == "silverman":
        factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
    elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
        factor = bw_method
    elif callable(bw_method):
        factor = bw_method(self)
    else:
        raise ValueError(
            "`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
        )

    data_covariance = jnp.atleast_2d(
        jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
    data_inv_cov = jnp.linalg.inv(data_covariance)
    covariance = data_covariance * factor**2
    inv_cov = data_inv_cov / factor**2
    self._setattr("covariance", covariance)
    self._setattr("inv_cov", inv_cov)

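# --- Added illustration (not from the original source) of the bandwidth
# rules implemented above: Scott's and Silverman's rules pick a scalar factor
# from the effective sample size neff and dimension d, and the kernel
# covariance is the (weighted) data covariance scaled by factor**2.
import jax.numpy as jnp

d, neff = 2, 100.0
scott_factor = jnp.power(neff, -1. / (d + 4))
silverman_factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
# covariance = data_covariance * factor**2, as in the constructor above.
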
def main(args):
    N, D_X, D_H = args.num_data, 3, args.num_hidden
    X, Y, X_test = get_data(N=N, D_X=D_X)

    # do inference
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    samples = run_inference(model, args, rng_key, X, Y, D_H)

    # predict Y_test at inputs X_test
    vmap_args = (samples,
                 random.split(rng_key_predict,
                              args.num_samples * args.num_chains))
    predictions = vmap(lambda samples, rng_key: predict(
        model, rng_key, samples, X_test, D_H))(*vmap_args)
    predictions = predictions[..., 0]

    # compute mean prediction and confidence interval around median
    mean_prediction = np.mean(predictions, axis=0)
    percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)

    # make plots
    fig, ax = plt.subplots(1, 1)
    # plot training data
    ax.plot(X[:, 1], Y[:, 0], 'kx')
    # plot 90% confidence level of predictions
    ax.fill_between(X_test[:, 1], percentiles[0, :], percentiles[1, :],
                    color='lightblue')
    # plot mean prediction
    ax.plot(X_test[:, 1], mean_prediction, 'blue', ls='solid', lw=2.0)
    ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")
    plt.savefig('bnn_plot.pdf')
    plt.tight_layout()

    # plot the eigenvalue spectrum of the posterior-sample covariance
    pars_list = []
    for keys in samples.keys():
        items = samples[keys]
        pars_list.append(items.reshape(args.num_samples, -1))
    pars = np.hstack(pars_list)
    eigs = np.real(np.linalg.eig(np.cov(pars.T))[0])
    fig, ax = plt.subplots()
    plt.plot(eigs)
    plt.savefig('bnn_eigs.pdf')

def test_linear_regression(in_dim=5, out_dim=2, shape=(1000, )):
    key = jr.PRNGKey(time.time_ns())
    key1, key2 = jr.split(key, 2)
    covariates = jr.normal(key1, shape + (in_dim, ))
    covariates = np.column_stack([covariates, np.ones(shape)])
    data = jr.normal(key2, shape + (out_dim, ))
    lr = regrs.GaussianLinearRegression.fit(
        dict(data=data, covariates=covariates))

    # compare to least squares fit. note that the covariance matrix is only
    # the covariance of the residuals if we fit the intercept term
    what = np.linalg.lstsq(covariates, data)[0].T
    assert np.allclose(lr.weights, what)
    resid = data - covariates @ what.T
    assert np.allclose(lr.covariance_matrix,
                       np.cov(resid, rowvar=False, bias=True),
                       atol=1e-6)

def test_prior_sample_matches_analytical(self, input_dim, length_scale, seed,
                                         amplitude):
    output_dim = 2
    # The current test only handles output_dim == 1
    # assert output_dim == 1
    num_basis = 4096
    num_train_points = 5
    num_function_samples = 5000
    length_scales = jnp.ones((input_dim, )) * length_scale
    # FIXME(ethan): the following should pass
    # length_scales = jnp.ones((output_dim, input_dim)) * length_scale
    num_inducing_points = 16
    kernel = tfk.FeatureScaled(
        tfk.ExponentiatedQuadratic(amplitude=amplitude),
        scale_diag=length_scales**2  # Use FeatureScaled for RBF ARD
    )
    key = jax.random.PRNGKey(seed)
    model = gp.SparseGaussianProcess(input_dimension=input_dim,
                                     output_dimension=output_dim,
                                     kernel=kernel,
                                     key=key,
                                     num_basis=num_basis,
                                     num_samples=num_function_samples,
                                     num_inducing=num_inducing_points)
    x = jax.random.uniform(key, (num_train_points, input_dim))

    # Generate samples at x
    f = model.prior(x)
    self.assertEqual(f.shape,
                     (num_function_samples, num_train_points, output_dim))
    # f = jnp.squeeze(model.prior(x), -1)  # assuming output dimension is 1

    analytical_mean = jnp.zeros((num_train_points, output_dim))
    analytical_cov = kernel.matrix(x, x)
    sample_mean = jnp.mean(f, axis=0)
    cov_fn = lambda x: jnp.cov(x, rowvar=False, bias=True)
    sample_cov = jax.vmap(cov_fn, in_axes=2)(f)
    self.assertLessEqual(
        jnp.max(jnp.abs(sample_cov - analytical_cov) / analytical_cov), 0.2)

def test_gaussian_subposterior(method, diagonal):
    D = 10
    n_samples = 10000
    n_draws = 9000
    n_subs = 8
    mean = np.arange(D)
    cov = np.ones((D, D)) * 0.9 + np.identity(D) * 0.1
    subcov = n_subs * cov  # subposterior's covariance
    subposteriors = list(dist.MultivariateNormal(mean, subcov).sample(
        random.PRNGKey(1), (n_subs, n_samples)))

    draws = method(subposteriors, n_draws, diagonal=diagonal)
    assert draws.shape == (n_draws, D)
    assert_allclose(np.mean(draws, axis=0), mean, atol=0.03)
    if diagonal:
        assert_allclose(np.var(draws, axis=0), np.diag(cov), atol=0.05)
    else:
        assert_allclose(np.cov(draws.T), cov, atol=0.05)

def debug_mvee():
    import pylab as plt

    # Check that normalized Gaussian samples are uniform in angle.
    n = random.normal(random.PRNGKey(0), (10000, 2))
    n = n / jnp.linalg.norm(n, axis=1, keepdims=True)
    angle = jnp.arctan2(n[:, 1], n[:, 0])
    plt.hist(angle, bins=100)
    plt.show()

    N = 120
    D = 2
    points = random.uniform(random.PRNGKey(0), (N, D))

    from jax import disable_jit
    with disable_jit():
        center, radii, rotation = minimum_volume_enclosing_ellipsoid(
            points, 0.01)
    plt.hist(jnp.linalg.norm(
        (rotation.T @ (points.T - center[:, None])) / radii[:, None], axis=0))
    plt.show()
    print(center, radii, rotation)
    plt.scatter(points[:, 0], points[:, 1])
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    ellipsis = center[:, None] + rotation @ jnp.stack(
        [radii[0] * jnp.cos(theta), radii[1] * jnp.sin(theta)], axis=0)
    plt.plot(ellipsis[0, :], ellipsis[1, :])
    for i in range(1000):
        y = sample_ellipsoid(random.PRNGKey(i), center, radii, rotation)
        plt.scatter(y[0], y[1])

    # Compare against a covariance-based ellipsoid estimate.
    C = jnp.linalg.pinv(jnp.cov(points, rowvar=False, bias=True))
    p = (N - D - 1) / N

    def q(p):
        return p + p**2 / (4. * (D - 1))

    C = C / q(p)
    c = jnp.mean(points, axis=0)
    W, Q, Vh = jnp.linalg.svd(C)
    radii = jnp.reciprocal(jnp.sqrt(Q))
    rotation = Vh.conj().T
    ellipsis = c[:, None] + rotation @ jnp.stack(
        [radii[0] * jnp.cos(theta), radii[1] * jnp.sin(theta)], axis=0)
    plt.plot(ellipsis[0, :], ellipsis[1, :])
    plt.show()

def test_sample(self):
    np.random.seed(0)
    obj = Quadratic(A=jnp.asarray([[2.0, 1.0], [1.0, 5.0]]),
                    b=jnp.asarray([1.0, 2.0]))
    obj_opt_state = obj.solve()
    posterior_cov = jnp.linalg.pinv(obj.A)

    num_samples = 10000
    prng_key = random.PRNGKey(0)
    sampler = EQS()
    samples = sampler.sample(objective=obj, prng_key=prng_key,
                             num_samples=num_samples)
    self.assertEqual(samples.shape[0], num_samples)
    sample_mean = jnp.mean(samples, axis=0)
    sample_cov = jnp.cov(samples, rowvar=False)
    np.testing.assert_allclose(sample_mean, obj_opt_state, rtol=1e-2)
    np.testing.assert_allclose(sample_cov, posterior_cov, rtol=1e-1)

# Direct estimation on the torus.
bij_params, trace = train(rng_train, bij_params, bij_fns, args.num_steps,
                          args.lr, 100)

# Sample from the learned distribution.
num_samples = 100000
num_dims = 2
xamb = random.normal(rng_xamb, [num_samples, num_dims])
xamb = forward(bij_params, bij_fns, xamb)
xtor = jnp.mod(xamb, 2.0 * jnp.pi)
lp = induced_torus_log_density(bij_params, bij_fns, xtor)
xobs = rejection_sampling(rng_xobs, len(xtor), torus_density, args.beta)

# Compute comparison statistics.
mean_mse = jnp.square(jnp.linalg.norm(xtor.mean(0) - xobs.mean(0)))
cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xtor.T) - jnp.cov(xobs.T)))
approx = jnp.exp(lp)
target = torus_density(xtor)
w = target / approx
Z = jnp.nanmean(w)
log_approx = jnp.log(approx)
log_target = jnp.log(target)
klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
ress = 100 * ess / len(w)
del w, Z, log_approx, log_target
log_approx = induced_torus_log_density(bij_params, bij_fns, xobs)
approx = jnp.exp(log_approx)
target = torus_density(xobs)
log_target = jnp.log(target)
w = approx / target