def generate_data(
    rng=0,
    num_points=10000,
    sig_mean=(-1, 1),
    bup_mean=(2.5, 2),
    bdown_mean=(-2.5, -1.5),
    b_mean=(1, -1),
):
    """Draw four 2-D Gaussian point clouds with identity covariance.

    Parameters
    ----------
    rng : int
        Seed handed to ``PRNGKey``.
    num_points : int
        Number of points drawn for each cloud.
    sig_mean, bup_mean, bdown_mean, b_mean : sequence of 2 numbers
        Mean of the ``sig``, ``bkg_up``, ``bkg_down`` and ``bkg_nom`` cloud.

    Returns
    -------
    tuple
        ``(sig, bkg_nom, bkg_up, bkg_down)``, each of shape
        ``(num_points, 2)``.
    """
    # Local import so the block does not depend on a new file-level import.
    from jax.random import split

    # Every cloud gets its own subkey: drawing all of them from the same key
    # would give four copies of the same noise, merely shifted by the means.
    key_sig, key_up, key_down, key_nom = split(PRNGKey(rng), 4)
    unit_cov = jnp.asarray([[1, 0], [0, 1]])

    def _cloud(key, mean):
        return multivariate_normal(
            key, jnp.asarray(mean), unit_cov, shape=(num_points, )
        )

    sig = _cloud(key_sig, sig_mean)
    bkg_up = _cloud(key_up, bup_mean)
    bkg_down = _cloud(key_down, bdown_mean)
    bkg_nom = _cloud(key_nom, b_mean)
    return sig, bkg_nom, bkg_up, bkg_down
def draw_state(val, key, params):
    """
    Simulate one step of a system that evolves as
                A z_{t-1} + Bk + eps,
    where eps ~ N(0, Q).

    Parameters
    ----------
    val: tuple (int, jnp.array)
        (latent value of system, state value of system).
    key: PRNGKey
    params: PRBPFParamsDiscrete
        Read attributes: ``transition_matrix``, ``A``, ``B``, ``Q``, ``C``
        and ``R``.

    Returns
    -------
    * tuple ``(latent_new, state_new)``: the value carried to the next step
    * tuple ``(latent_new, state_new, obs_new)``: the output of this step
    """
    # Local import so the block does not depend on a new file-level import.
    import jax.numpy as jnp

    latent_old, state_old = val
    # One independent subkey per random draw; the incoming key itself is
    # never consumed directly.
    key_discrete, key_latent, key_obs = random.split(key, 3)

    probabilities = params.transition_matrix[latent_old, :]
    # random.categorical expects (unnormalised) log-probabilities.
    # logit(p) = log(p / (1 - p)) would sample proportionally to p / (1 - p).
    latent_new = random.categorical(key_discrete, jnp.log(probabilities))

    state_mean = params.A @ state_old + params.B[latent_new, :]
    state_new = random.multivariate_normal(key_latent, state_mean, params.Q)
    obs_new = random.multivariate_normal(key_obs, params.C @ state_new, params.R)
    return (latent_new, state_new), (latent_new, state_new, obs_new)
def filter(self, key, init_state, sample_obs, nsamples=2000):
    """
    Estimate the latent-state mean at every step from weighted particles.

    At each step the particles are propagated through ``self.fz`` with noise
    covariance ``self.Q`` and weighted by the Gaussian likelihood of the
    observation around ``self.fx(particle)`` with covariance ``self.R``.
    Particles are not resampled; each estimate uses the weights of the
    current step only.

    Parameters
    ----------
    key: jax.random.PRNGKey
    init_state: array(state_size,)
        Initial state estimate
    sample_obs: array(nsteps, obs_size)
        Observations, one row per step
    nsamples: int
        Number of particles

    Returns
    -------
    * array(nsteps, state_size)
        Weighted particle mean at every step
    """
    nsteps = sample_obs.shape[0]
    keys = split(key, nsteps)
    mu_hist = []
    zt_rvs = None
    for t, key_t in enumerate(keys):
        if t == 0:
            zt_rvs = random.multivariate_normal(key_t, init_state, self.Q, (nsamples, ))
        else:
            zt_rvs = random.multivariate_normal(key_t, self.fz(zt_rvs), self.Q)

        # Observation likelihood p(y_t | z_t) = N(y_t; fx(z_t), R). Working in
        # log space and subtracting the maximum keeps the normalisation
        # finite when every density underflows.
        log_weights = stats.multivariate_normal.logpdf(
            sample_obs[t], self.fx(zt_rvs), self.R)
        weights_t = jnp.exp(log_weights - log_weights.max())
        weights_t = weights_t / weights_t.sum()

        mu_hist.append((zt_rvs * weights_t[:, None]).sum(axis=0))

    if not mu_hist:
        return jnp.zeros((0, ) + jnp.shape(init_state))
    # Stacking avoids both the hard-coded state size and index_update.
    return jnp.stack(mu_hist)
def bcR(self, rng=None, aspect_ratio=1.0):
    """bcR creates a random boundary condition for a rectangular domain
    defined by aspect_ratio.

    The boundary is a random 3rd order polynomial.
    rng variable allows to reproduce the results; when it is None the fixed
    key ``random.PRNGKey(1)`` is used.
    The current boundary conditions are not periodic.

    Returns the boundary array multiplied by ``(n + 1)**2``.
    """
    if rng is None:
        rng = random.PRNGKey(1)
    x = self.bcmesh
    n = self.n
    n_y = equations.num_row(n, aspect_ratio)
    y = np.linspace(0, 1, num=n_y)

    # 16 coefficients: four per edge, highest power first.
    coeffs = random.multivariate_normal(rng, np.zeros(16), np.diag(np.ones(16)))

    def cubic(c, t):
        return c[0] * t**3 + c[1] * t**2 + c[2] * t + c[3]

    left = cubic(coeffs[0:4], y)
    right = cubic(coeffs[4:8], y)
    lower = cubic(coeffs[8:12], x)
    upper = cubic(coeffs[12:16], x)

    shape = 2 * x.shape
    source = onp.zeros(shape)
    # The order of these writes matters: later assignments overwrite the
    # corner cells that two edges share.
    source[0, :] = upper
    source[n_y - 1, :] = lower
    source[0:n_y, -1] = right
    source[0:n_y, 0] = left
    return source * (n + 1)**2
def bcL(self, rng=None):
    """bcL creates a random boundary condition for a L-shaped domain.

    The boundary is a random 3rd order polynomial of sine functions.
    rng variable allows to reproduce the results; when it is None the fixed
    key ``random.PRNGKey(1)`` is used.
    Sine functions are chosen so that the boundary is periodic and does not
    have discontinuities.

    Returns the boundary array multiplied by ``(n + 1)**2``.
    """
    if rng is None:
        rng = random.PRNGKey(1)
    n = self.n
    x = onp.sin(self.bcmesh * np.pi)
    n_y = (np.floor((n + 1) / 2) - 1).astype(int)

    # 16 coefficients are still drawn so the random stream is unchanged; the
    # constant term of each edge (indices 3, 7, 11, 15) is deliberately unused.
    coeffs = random.multivariate_normal(rng, np.zeros(16), np.diag(np.ones(16)))

    def cubic_no_offset(c):
        return c[0] * x**3 + c[1] * x**2 + c[2] * x

    left = cubic_no_offset(coeffs[0:3])
    right = cubic_no_offset(coeffs[4:7])
    lower = cubic_no_offset(coeffs[8:11])
    upper = cubic_no_offset(coeffs[12:15])

    shape = 2 * x.shape
    source = onp.zeros(shape)
    # The order of these writes matters: later assignments overwrite the
    # cells that two boundary segments share.
    source[0, :] = upper
    source[n_y - 1, n_y - 1:] = lower[:n - n_y + 1]
    source[n_y - 1:, n_y - 1] = right[:n - n_y + 1]
    source[:, 0] = left
    source[-1, :n_y - 1] = right[n:n - n_y:-1]
    source[:n_y - 1, -1] = lower[n:n - n_y:-1]
    return source * (n + 1)**2
def __sample_step(self, input_vals, obs):
    """Advance the system by one step and emit a noisy observation.

    Parameters
    ----------
    input_vals: tuple (PRNGKey, array)
        Random key and current state.
    obs: iterable
        Extra arguments, unpacked into ``self.fx`` and ``self.R`` after the
        new state.

    Returns
    -------
    * tuple ``(next_key, new_state)``: the value carried to the next step
    * tuple ``(new_state, new_obs)``: the output of this step
    """
    prev_key, prev_state = input_vals
    subkeys = random.split(prev_key, 3)
    key_system, key_obs, next_key = subkeys

    # State transition: Gaussian around fz(state) with covariance Q(state).
    state_mean = self.fz(prev_state)
    state_cov = self.Q(prev_state)
    new_state = random.multivariate_normal(key_system, state_mean, state_cov)

    # Observation: Gaussian around fx(new_state, *obs).
    obs_mean = self.fx(new_state, *obs)
    obs_cov = self.R(new_state, *obs)
    new_obs = random.multivariate_normal(key_obs, obs_mean, obs_cov)

    carry = (next_key, new_state)
    emitted = (new_state, new_obs)
    return carry, emitted
def sample(self, key, n_samples=1, sample_intial_state=False):
    """
    Simulate a run of n_sample independent stochastic
    linear dynamical systems

    Parameters
    ----------
    key: jax.random.PRNGKey
        Seed of initial random states
    n_samples: int
        Number of independent linear systems with shared dynamics (optional)
    sample_intial_state: bool
        Whether to draw the initial state from N(mu0, Sigma0) or to start
        every system at the specified ``mu0``

    Returns
    -------
    * array(n_samples, timesteps, state_size):
        Simulation of Latent states
    * array(n_samples, timesteps, observation_size):
        Simulation of observed states

    The leading axis is dropped when ``n_samples == 1``.
    """
    key_z1, key_system_noise, key_obs_noise = random.split(key, 3)
    if not sample_intial_state:
        state_t = self.mu0 * jnp.ones((n_samples, self.state_size))
    else:
        state_t = random.multivariate_normal(key_z1, self.mu0, self.Sigma0, (n_samples,))

    # Generate all future noise terms
    zeros_state = jnp.zeros(self.state_size)
    zeros_obs = jnp.zeros(self.observation_size)
    noise_shape = (n_samples, self.timesteps)
    system_noise = random.multivariate_normal(key_system_noise, zeros_state, self.Q, noise_shape)
    obs_noise = random.multivariate_normal(key_obs_noise, zeros_obs, self.R, noise_shape)

    def observe(state, t):
        return jnp.einsum("ij,sj->si", self.C, state) + obs_noise[:, t, :]

    # Collect the per-step arrays and stack once at the end: this removes the
    # dependency on index_update / jax.ops.index, which current JAX releases
    # no longer provide, and avoids copying the history at every step.
    states = [state_t]
    observations = [observe(state_t, 0)]
    for t in range(1, self.timesteps):
        state_t = jnp.einsum("ij,sj->si", self.A, state_t) + system_noise[:, t, :]
        states.append(state_t)
        observations.append(observe(state_t, t))

    state_hist = jnp.stack(states, axis=1)
    obs_hist = jnp.stack(observations, axis=1)

    if n_samples == 1:
        state_hist = state_hist[0, ...]
        obs_hist = obs_hist[0, ...]
    return state_hist, obs_hist
def sample(self, key, x0, T, nsamples, dt=0.01, noisy=False):
    """
    Simulate the system and return equally-spaced samples of it.

    The state is integrated with ``self._rk2`` up to (at least) time T, then
    every ``jump_size``-th point is kept. The kept states are mapped through
    ``self.fx`` and perturbed with observation noise of covariance ``self.R``.

    Parameters
    ----------
    key: jax.random.PRNGKey
        Initial seed
    x0: array(state_size)
        Initial state of simulation
    T: float
        Final time of integration
    nsamples: int
        Number of observations to take from the total integration
    dt: float
        integration step size
    noisy: bool
        Whether to (naively) add noise to the state space

    Returns
    -------
    * array(nsamples, state_size)
        State-space values
    * array(nsamples, obs_size)
        Observed-space values
    * int
        Number of observations skipped between one
        datapoint and the next
    """
    total_steps = ceil(T / dt)
    jump_size = ceil(total_steps / nsamples)
    # Extend the integration so that striding by jump_size yields nsamples points.
    missing = nsamples - ceil(total_steps / jump_size)
    total_steps = total_steps + missing * jump_size

    key_state, key_obs = random.split(key)
    zero_state = jnp.zeros(self.state_size)
    zero_obs = jnp.zeros(self.obs_size)
    state_noise = random.multivariate_normal(key_state, zero_state, self.Q, (total_steps, ))
    obs_noise = random.multivariate_normal(key_obs, zero_obs, self.R, (total_steps, ))

    trajectory = self._rk2(x0, self.fz, total_steps, dt)
    if noisy:
        trajectory = trajectory + jnp.sqrt(dt) * state_noise

    sample_state = trajectory[::jump_size]
    noiseless_obs = jnp.apply_along_axis(self.fx, 1, sample_state)
    sample_obs = noiseless_obs + obs_noise[:len(sample_state)]
    return sample_state, sample_obs, jump_size
def driven_Langevin_move(x, potential, dt, A_function, b_function,
                         potential_parameter, A_parameter, b_parameter, key):
    r"""
    driven Langevin Algorithm; a driven langevin propagator is
    N(x_t; \Theta_t^i(x_{t-1})*(f_t-dt*b_t^i)(x_{t-1}), dt*\Theta_t^i(x_{t-1}))
    where
        \Theta_t^i(x_{t-1}) = (I_d + 2*dt*A_t^i(x_{t-1}))^{-1},
        f_t(x_{t-1}) = x_{t-1} + 0.5*dt*\nabla \pi_t(x_{t-1})

    arguments:
        x : jnp.array(N)
            current position (or latent variable)
        potential : function
            potential function (takes args x and parameters)
        dt : float
            incremental time
        A_function : function
            covariance driver function
        b_function : function
            mean driver function
        potential_parameter : jnp.array
            parameters passed to potential
        A_parameter : jnp.array()
            second argument of A function
        b_parameter : jnp.array()
            second argument of b function
        key : PRNGKey
            randomization key passed to random.multivariate_normal

    returns
        out : jnp.array(N)
            multivariate gaussian proposal
    """
    # The docstring is a raw string so that "\nabla" and "\Theta" are kept
    # literally instead of being read as escape sequences.
    _, b, f, theta = driven_Langevin_parameters(
        x, potential, dt, A_function, b_function,
        potential_parameter, A_parameter, b_parameter)
    mu, cov = driven_mu_cov(b, f, theta, dt)
    return random.multivariate_normal(key, mu, cov)
def Wishart(key, dof, scale, shape=None):
    """Draw Wishart-distributed matrices as sums of Gaussian outer products.

    ``dof`` zero-mean Gaussian vectors with covariance ``scale`` are drawn
    and the sum of their outer products is returned.

    Parameters
    ----------
    key : PRNGKey
    dof : integer-like scalar or None
        Degrees of freedom; ``None`` means ``scale.shape[-1]``.
    scale : array(..., p, p) or None
        Scale matrix, optionally with leading batch dimensions.
    shape : tuple, optional
        Batch shape of the result. When ``scale`` is batched it must equal
        the batch shape of ``scale``.

    Returns
    -------
    array(batch_shape + (p, p))

    Raises
    ------
    ValueError
        If ``dof`` is not a scalar or not integer-like.
    """
    if scale is None:
        # NOTE(review): here `shape` is used as a matrix size, but below it is
        # treated as a batch-shape tuple - confirm the intended meaning.
        scale = jnp.eye(shape)

    batch_shape = ()
    if jnp.ndim(scale) > 2:
        batch_shape = scale.shape[:-2]
    p = scale.shape[-1]

    if dof is None:
        dof = p
    if jnp.ndim(dof) > 0:
        raise ValueError("only scalar dof implemented")
    # `~` on a Python bool is bitwise (~True == -2, ~False == -1, both
    # truthy), so compare directly instead of negating with `~`.
    if int(dof) != dof:
        raise ValueError(
            "dof should be integer-like (i.e. int(dof) == dof should return true)"
        )
    dof = int(dof)

    if shape is not None:
        if batch_shape != ():
            assert batch_shape == shape, "Disagreement in batch shape between scale and shape"
        else:
            batch_shape = shape

    mn = jnp.zeros(shape=batch_shape + (p, ))
    mvn_shape = (dof, ) + batch_shape
    mvn = random.multivariate_normal(key, mean=mn, cov=scale, shape=mvn_shape)
    if jnp.ndim(mvn) > 2:
        # Move the dof axis next to the event axis while keeping the batch
        # axes in order (swapaxes would reverse batches of rank >= 2).
        mvn = jnp.moveaxis(mvn, 0, -2)
    S = jnp.einsum('...ji,...jk', mvn, mvn)
    return S
def sample(self, key, shape=None):
    """
    Sample from the skewnormal distribution.

    A (k + 1)-dimensional zero-mean Gaussian with covariance ``self.omega``
    is drawn; the sign of its last coordinate decides whether the first k
    coordinates are kept or negated, and the result is scaled by
    ``self.scale`` and shifted by ``self.loc``.

    Arguments:
        - key: a PRNGKey used as the random key.
        - shape: (Optional) a tuple of nonnegative integers specifying the
            result batch shape; that is, the prefix of the result shape
            excluding the last axis. The default (None) is ``(1,)``.

    Returns an array of shape ``shape + (k,)``.
    """
    if shape is None:
        shape = (1, )
    X = random.multivariate_normal(
        key=key,
        shape=shape,
        mean=jnp.zeros(self.k + 1),
        cov=self.omega,
    )
    # Ellipsis indexing keeps this valid for batch shapes of any rank,
    # not only for a single leading axis.
    X0 = X[..., -1:]
    X = X[..., :-1]
    Z = jnp.where(X0 > 0, X, -X)
    return self.loc + self.scale * Z
def bootstrap_step(state, y):
    """Run one propagate / weight / resample step on a set of particles.

    Reads ``transition_matrix``, ``nparticles``, ``A``, ``B``, ``Q`` and
    ``C`` from the enclosing scope.

    Parameters
    ----------
    state: tuple (latent_t, state_t, key)
        Discrete latents, continuous states and random key of the particles.
    y: array
        Observation at the current step.

    Returns
    -------
    * tuple ``(latent_t, state_t, key_next)``: the value carried forward
    * tuple ``(mu_t, latent_t, state_t)``: mean of the resampled states and
      the resampled particles themselves
    """
    latent_t, state_t, key = state
    key_latent, key_state, key_reindex, key_next = random.split(key, 4)

    # Discrete states
    latent_t = random.categorical(key_latent, jnp.log(transition_matrix[latent_t]), shape=(nparticles, ))
    # Continuous states
    state_mean = jnp.einsum("nm,sm->sn", A, state_t) + B[latent_t]
    state_t = random.multivariate_normal(key_state, mean=state_mean, cov=Q)

    # Compute log-weights directly: taking log(pdf) underflows to -inf for
    # particles far from the observation.
    log_weights_t = multivariate_normal.logpdf(y, mean=state_t, cov=C)
    indices_t = random.categorical(key_reindex, log_weights_t, shape=(nparticles, ))

    # Reindex; after resampling every particle carries the same weight, so
    # the estimate is the plain mean.
    state_t = state_t[indices_t, ...]
    latent_t = latent_t[indices_t, ...]
    mu_t = state_t.mean(axis=0)

    return (latent_t, state_t, key_next), (mu_t, latent_t, state_t)
def test_1D_ULA_propagator(key=random.PRNGKey(0), num_runs=1000):
    """
    Draw `num_runs` particles from the Gaussian returned by
    `get_nondefault_potential_initializer(1)`, propagate them with ULA for
    100 steps with dt=0.01, and assert that the particle mean stays near 0
    and the particle standard deviation stays near sqrt(2), both within the
    0.2 tolerance handed to `checker_function`.
    """
    dt = 1e-2
    n_steps = 100
    potential_parameter = jnp.array([0.])

    key, genkey = random.split(key)
    potential, (mu, cov), dG = get_nondefault_potential_initializer(1)
    particles = random.multivariate_normal(key=genkey, mean=mu, cov=cov, shape=[num_runs])

    # Positions and keys are batched; potential, dt and its parameter are shared.
    batch_ula_move = vmap(ULA_move, in_axes=(0, None, None, 0, None))

    for _ in tqdm.trange(n_steps):
        key, ula_keygen = random.split(key, 2)
        ula_keys = random.split(ula_keygen, num_runs)
        particles = batch_ula_move(particles, potential, dt, ula_keys, potential_parameter)

    ula_mean = particles.mean()
    ula_std = particles.std()

    assert checker_function(ula_mean, 0.2)
    assert checker_function(ula_std - jnp.sqrt(2), 0.2)
def rvs(self, prng_key: DeviceArray, n_samples: int) -> DeviceArray:
    """Draw `n_samples` Gaussian samples with mean `self._mean` and covariance `self._cov`."""
    batch_shape = (n_samples, )
    draws = random.multivariate_normal(prng_key, self._mean, self._cov, batch_shape)
    return draws
def __filter_step(self, state, obs_t):
    """Run one propagate / weight / resample step on the particle set.

    Parameters
    ----------
    state: tuple (array(nsamples, state_size), PRNGKey)
        Current particles and random key.
    obs_t: array
        Observation at the current step.

    Returns
    -------
    * tuple ``(particles, key_next)``: the value carried to the next step
    * array: mean of the resampled particles
    """
    particles, key = state
    n_particles = self.nsamples
    key_move, key_reindex, key_next = random.split(key, 3)

    # 1. Move every particle through the dynamic model.
    particles = random.multivariate_normal(key_move, self.fz(particles), self.Q)

    # 2. Unnormalised weights: likelihood of the observation around fx(particle).
    predicted_obs = self.fx(particles)
    likelihood = stats.multivariate_normal.pdf(obs_t, predicted_obs, self.R)

    # 3. Resample particle indices according to those weights.
    particle_ids = jnp.arange(n_particles)
    picked = random.choice(key_reindex, particle_ids, p=likelihood, shape=(n_particles, ))
    particles = particles[picked, ...]

    # 4. After resampling all particles weigh the same.
    uniform_weights = jnp.ones(n_particles) / n_particles
    mu_t = jnp.einsum("im,i->m", particles, uniform_weights)

    return (particles, key_next), mu_t
def testMultivariateNormalCovariance(self): # test code based on https://github.com/google/jax/issues/1869 N = 100000 cov = jnp.array([[0.19, 0.00, -0.13, 0.00], [0.00, 0.29, 0.00, -0.23], [-0.13, 0.00, 0.39, 0.00], [0.00, -0.23, 0.00, 0.49]]) mean = jnp.zeros(4) out_np = np.random.RandomState(0).multivariate_normal(mean, cov, N) key = random.PRNGKey(0) out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N, )) var_np = out_np.var(axis=0) var_jnp = out_jnp.var(axis=0) self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2, check_dtypes=False) var_np = np.cov(out_np, rowvar=False) var_jnp = np.cov(out_jnp, rowvar=False) self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2, check_dtypes=False)
def test_integrated_pos_enc(self):
    """Check `mip.integrated_pos_enc` against a Monte Carlo average.

    For five random Gaussians, and for both the full and the diagonal
    covariance variants, the encoding is compared with the mean of
    `mip.pos_enc` over samples drawn from the Gaussian.
    """
    num_dims = 2  # The number of input dimensions.
    min_deg, max_deg = 0, 4
    num_samples = 100000
    rng = random.PRNGKey(0)

    for _ in range(5):
        # Generate a coordinate's mean and covariance matrix.
        key, rng = random.split(rng)
        mean = random.normal(key, (2, ))
        key, rng = random.split(rng)
        half_cov = jax.random.normal(key, [num_dims] * 2)
        cov = half_cov @ half_cov.T

        for diag in (False, True):
            # Generate an IPE.
            cov_arg = jnp.diag(cov) if diag else cov
            enc = mip.integrated_pos_enc((mean, cov_arg), min_deg, max_deg, diag)

            # Draw samples, encode them, and take their mean.
            key, rng = random.split(rng)
            samples = random.multivariate_normal(key, mean, cov, [num_samples])
            enc_samples = mip.pos_enc(
                samples, min_deg, max_deg, append_identity=False)
            enc_gt = jnp.mean(enc_samples, 0)

            self.assertAllClose(enc, enc_gt, rtol=1e-2, atol=1e-2)
def sample(self, rng_key, sample_shape=()):
    """Draw Gaussian samples with batch shape `sample_shape + self.batch_shape`."""
    return random.multivariate_normal(
        rng_key,
        mean=self.mu,
        cov=self.covariance_matrix,
        shape=sample_shape + self.batch_shape,
    )
def plot_mlp_prediction(key, xobs, yobs, xtest, fw, w, Sw, ax, n_samples=100):
    """Plot curves for sampled weights, their mean and the observations on `ax`.

    `n_samples` weight vectors are drawn from a Gaussian with mean `w` and
    covariance `Sw`, and `fw` evaluates them at `xtest`.
    """
    sampled_weights = multivariate_normal(key, w, Sw, (n_samples,))
    predictions = fw(sampled_weights, xtest[:, None])

    # One faint grey line per sampled curve.
    for curve in predictions:
        ax.plot(xtest, curve, c="tab:gray", alpha=0.07)

    # Mean over the sampled curves.
    mean_curve = predictions.mean(axis=0)
    ax.plot(xtest, mean_curve)

    ax.scatter(xobs, yobs, s=14, c="none", edgecolor="black", label="observations", alpha=0.5)
    ax.set_xlim(xobs.min(), xobs.max())
def prior_sample(self, X=None, num_samps=1, key=0, jitter=1e-12):
    """Draw zero-mean Gaussian samples with covariance `self.kernel(X, X)`.

    Parameters
    ----------
    X : array, optional
        Inputs, one row per point; defaults to `self.X`.
    num_samps : int
        Number of samples to draw.
    key : int
        Seed handed to `PRNGKey`.
    jitter : float
        Value added to the diagonal of the covariance before sampling. The
        default keeps the previous hard-coded value, which is below the
        resolution of single-precision entries of order one; pass something
        larger (e.g. 1e-6) when the covariance is close to singular.

    Returns
    -------
    array(N, num_samps)
        One sample per column.
    """
    if X is None:
        X = self.X
    N = X.shape[0]
    m = np.zeros(N)
    K = self.kernel(X, X) + jitter * np.eye(N)
    s = multivariate_normal(PRNGKey(key), m, K, shape=[num_samps])
    return s.T
def sample(self, n_samples, key=None):
    """Draw `n_samples` Gaussian points and map each through `self.bananify`.

    When `key` is None a fresh key is split off `self.threadkey`, which is
    replaced in the process.
    """
    draw_key = key
    if draw_key is None:
        self.threadkey, draw_key = random.split(self.threadkey)

    gaussian = random.multivariate_normal(
        draw_key, self.gauss_mean, self.gauss_cov, shape=(n_samples, ))
    points = gaussian.reshape((n_samples, self.d))
    return vmap(self.bananify)(points)
def sample(self, rng_key, sample_shape=()):
    """Draw Gaussian samples whose leading axes are `sample_shape + self.batch_shape`."""
    # Only the batch part is passed as `shape`:
    # random.multivariate_normal appends the event shape on its own.
    full_batch_shape = sample_shape + self.batch_shape
    return random.multivariate_normal(
        rng_key, self.mu, self.covariance_matrix, full_batch_shape
    )
def sample(self, n_samples, key=None):
    """Draw `n_samples` Gaussian points and map each through `self.squiggle`.

    Mutates `self.threadkey` if `key` is None.
    """
    if key is None:
        next_thread, key = random.split(self.threadkey)
        self.threadkey = next_thread

    batch = (n_samples, )
    gaussian = random.multivariate_normal(key, self.mean, self.gauss_cov, shape=batch)
    squiggle_all = vmap(self.squiggle)
    return squiggle_all(gaussian.reshape((n_samples, self.d)))
def sample(self, key, x0, nsteps):
    """
    Sample discrete elements of a nonlinear system

    The first state is ``x0`` and the first observation is the noiseless
    ``self.fx(x0)``; every later step applies ``self.fz`` / ``self.fx`` and
    adds Gaussian noise with covariance ``self.Q`` / ``self.R``.

    Parameters
    ----------
    key: jax.random.PRNGKey
    x0: array(state_size)
        Initial state of simulation
    nsteps: int
        Total number of steps to sample from the system

    Returns
    -------
    * array(nsteps, state_size)
        State-space values
    * array(nsteps, obs_size)
        Observed-space values
    """
    # Three-way split kept so the noise streams match earlier versions.
    _, key_system_noise, key_obs_noise = random.split(key, 3)

    state_noise = random.multivariate_normal(
        key_system_noise, jnp.zeros((self.state_size, )), self.Q, (nsteps, ))
    obs_noise = random.multivariate_normal(
        key_obs_noise, jnp.zeros((self.obs_size, )), self.R, (nsteps, ))

    state_t = x0.copy()
    # Collect the steps and stack once at the end: this removes the
    # dependency on index_update, which current JAX releases no longer
    # provide, and avoids copying the history at every step.
    states = [state_t]
    observations = [self.fx(state_t)]
    for t in range(1, nsteps):
        state_t = self.fz(state_t) + state_noise[t]
        states.append(state_t)
        observations.append(self.fx(state_t) + obs_noise[t])

    # The slice only matters for nsteps == 0, where it yields empty histories.
    state_hist = jnp.stack(states)[:nsteps]
    obs_hist = jnp.stack(observations)[:nsteps]
    return state_hist, obs_hist
def sample(self, n_samples, key=None):
    """Draw `n_samples` Gaussian points as an array of shape `(n_samples, self.d)`.

    Mutates `self.threadkey` if `key` is None.
    """
    use_thread_key = key is None
    if use_thread_key:
        self.threadkey, key = random.split(self.threadkey)

    draws = random.multivariate_normal(
        key, self.mean, self.cov, shape=(n_samples, ))
    return draws.reshape((n_samples, self.d))
def sample(rng, params, num_samples=1):
    """Draw `num_samples` points from a mixture of Gaussians.

    The components come from `means`, `covariances` and `weights` in the
    enclosing scope; `params` is not read. One candidate point per component
    is drawn for every sample and a categorical draw picks which one is kept.
    """
    per_cluster = []
    for mean, cov in zip(means, covariances):
        rng, cluster_rng = random.split(rng)
        per_cluster.append(
            random.multivariate_normal(cluster_rng, mean, cov, (num_samples, )))

    # Components are stacked along the last axis.
    candidates = np.dstack(per_cluster)

    # NOTE(review): random.categorical treats its second argument as logits;
    # if `weights` holds probabilities this presumably needs a log - confirm.
    chosen = random.categorical(rng, weights, shape=(num_samples, 1, 1))
    picked = np.take_along_axis(candidates, chosen, -1)
    return np.squeeze(picked)
def experiment(key, K, cliques0, cliques1, n):
    """Simulate `n` Gaussian draws with covariance inv(K) and compare two fits.

    `itpropscaling` is run on the sample second-moment matrix for `cliques0`
    and for `cliques1`; the function returns exp(-w / 2), where w is n times
    the difference of the two log-determinants.
    """
    p = K.shape[0]
    covariance = jnp.linalg.inv(K)
    X = random.multivariate_normal(key, jnp.zeros(p), covariance, (n, ))

    # Second-moment matrix of the sample (the mean is zero by construction).
    Shat = (X.T @ X) / n

    Khat0 = itpropscaling(cliques0, Shat)
    Khat1 = itpropscaling(cliques1, Shat)

    logdet0 = jnp.linalg.slogdet(Khat0)[1]
    logdet1 = jnp.linalg.slogdet(Khat1)[1]
    w = n * (logdet1 - logdet0)
    return jnp.exp(-w / 2)
def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size, shape): r = np.random.RandomState(0) key = random.PRNGKey(0) eff_batch_size = mean_batch_size \ if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size mean = r.randn(*(mean_batch_size + (dim,))) cov_factor = r.randn(*(cov_batch_size + (dim, dim))) cov = np.einsum('...ij,...kj->...ik', cov_factor, cov_factor) cov += 1e-3 * np.eye(dim) shape = shape + eff_batch_size samples = random.multivariate_normal(key, mean, cov, shape=shape) assert samples.shape == shape + (dim,)
def _generate_user_preferences(self):
    '''
    generate_user_preferences creates a NxK list of bernoulli probabilities
    for how likely a user will purchase an item from each bin
    (each row sums to 1).

    The result is stored in self.user_prefs, after which
    self._generate_user_context() is called.
    '''
    # Independent subkeys for the two random draws; self.rng_key itself is
    # left untouched.
    key_prefs, key_shuffle = random.split(self.rng_key)

    # Create one column that is largely preferred and peter out over the rest of the bins.
    mean_prefs = np.array([0.9] + [0.1 / (self.num_bins - 1)] * (self.num_bins - 1))
    user_prefs = random.multivariate_normal(
        key_prefs, mean_prefs, 0.005 * np.eye(self.num_bins),
        shape=(self.num_users,))

    # Randomly permute the columns of each row. Arg-sorting i.i.d. uniform
    # scores gives a uniformly random permutation per user without
    # enumerating all num_bins! permutations.
    shuffle_scores = random.uniform(key_shuffle, (self.num_users, self.num_bins))
    column_order = np.argsort(shuffle_scores, axis=1)
    user_prefs = np.take_along_axis(user_prefs, column_order, axis=1)

    # Normalize user preferences
    magnitudes = abs(user_prefs)
    user_prefs = magnitudes * (1 / np.sum(magnitudes, axis=1))[:, np.newaxis]

    self.user_prefs = user_prefs  # Set class values
    self._generate_user_context()  # Generate the various context variables that will be provided to the algorithm
def sample(
    self,
    objective: Union[Quadratic, LeastSquares],
    prng_key: jnp.ndarray,
    num_samples: int = 1,
    **_unused_kwargs,
) -> jnp.ndarray:
    """Generates exact samples from a quadratic posterior (Gaussian).

    A `LeastSquares` objective is first converted with
    `Quadratic.from_least_squares`. The Gaussian mean is `objective.solve()`
    and its covariance is the pseudo-inverse of `objective.A`.
    """
    quadratic = (
        Quadratic.from_least_squares(objective)
        if isinstance(objective, LeastSquares)
        else objective
    )
    state_mean = quadratic.solve()
    state_cov = jnp.linalg.pinv(quadratic.A)
    return random.multivariate_normal(
        prng_key, state_mean, state_cov, shape=(num_samples, ))