def testPmapMapVmapCombinations(self):
    # https://github.com/google/jax/issues/2822
    def vv(x, y):
        """Vector-vector multiply"""
        return np.dot(x, y)

    def matrix_vector(x, y, parallel=True):
        """Matrix vector multiply. First batch it and then row by row"""
        fv = lambda z: lax.map(lambda j: vv(j, y), z)
        if parallel:
            # split leading axis in two
            new_x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
            # apply map
            new_res = pmap(fv)(new_x)
            # reshape back out
            res = new_res.reshape(x.shape[0], *new_res.shape[2:])
        else:
            res = fv(x)
        return res

    x = random.normal(random.PRNGKey(1), (80, 5))
    y = random.normal(random.PRNGKey(1), (10, 5))

    result1 = vmap(lambda b: matrix_vector(x, b, True))(y)       # vmap + pmap
    result2 = lax.map(lambda b: matrix_vector(x, b, False), y)   # map + map
    result3 = lax.map(lambda b: matrix_vector(x, b, True), y)    # map + pmap
    result4 = np.stack([matrix_vector(x, b, False) for b in y])  # none + map

    self.assertAllClose(result1, result2, check_dtypes=False, atol=1e-3, rtol=1e-3)
    self.assertAllClose(result1, result3, check_dtypes=False, atol=1e-3, rtol=1e-3)
    self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3)
def calc_bpd_loop(self, model_fn, *, x_start, rng):
    """Calculate variational bound (loop over all timesteps and sum)."""
    batch_size = x_start.shape[0]
    noise_shape = x_start.shape + (self.num_pixel_vals,)

    def map_fn(map_val):
        t, cur_rng = map_val
        # Calculate VB term at the current timestep
        t = jnp.full((batch_size,), t)
        vb, _ = self.vb_terms_bpd(
            model_fn=model_fn, x_start=x_start, t=t,
            x_t=self.q_sample(
                x_start=x_start, t=t,
                noise=jax.random.uniform(cur_rng, noise_shape)))
        del cur_rng
        assert vb.shape == (batch_size,)
        return vb

    vbterms_tb = lax.map(
        map_fn, (jnp.arange(self.num_timesteps),
                 jax.random.split(rng, self.num_timesteps)))
    vbterms_bt = vbterms_tb.T
    assert vbterms_bt.shape == (batch_size, self.num_timesteps)

    prior_b = self.prior_bpd(x_start=x_start)
    total_b = vbterms_tb.sum(axis=0) + prior_b
    assert prior_b.shape == total_b.shape == (batch_size,)

    return {
        'total': total_b,
        'vbterms': vbterms_bt,
        'prior': prior_b,
    }
def rvcoref(t, T0, P, e, omegaA, K, i):
    """Unit-free radial velocity curve without systemic velocity; K and the inclination i are passed separately.

    Args:
        t: Time in your time unit
        T0: Time of periastron passage in your time unit
        P: orbital period in your time unit
        e: eccentricity
        omegaA: argument of periastron
        K: RV semi-amplitude/sin i in your velocity unit
        i: inclination

    Returns:
        radial velocity curve in your velocity unit
    """
    n = 2*jnp.pi/P
    M = n*(t-T0)
    Ea = map(lambda x: getE.getE(x, e), M)
    cosE = jnp.cos(Ea)
    cosf = (-cosE + e)/(-1 + cosE*e)
    sinf = jnp.sqrt((-1 + cosE*cosE)*(-1 + e*e))/(-1 + cosE*e)
    sinf = jnp.where(Ea < jnp.pi, -sinf, sinf)
    cosfpo = cosf*jnp.cos(omegaA) - sinf*jnp.sin(omegaA)
    face = 1.0/jnp.sqrt(1.0-e*e)
    Ksini = K*jnp.sin(i)
    model = Ksini*face*(cosfpo+e*jnp.cos(omegaA))
    return model
def test_multiple_twists(dim=1, num_twists=3, integration_range=(-20, 20), num_point_ops=20000):
    """
    for several random twists of a base gaussian, assert an interconsistent logpdf
    """
    from jax.lax import map

    for i in tqdm.trange(10):
        # define the params
        x = np.random.randn(dim)
        mean = np.random.randn(dim)
        cov = np.abs(np.random.randn(dim)) + np.ones(dim)*10
        As, bs = np.abs(np.random.randn(num_twists, dim)), np.random.randn(num_twists, dim)

        # get the untwisted logZ, logpdf, and the twisting value
        untwisted_logZ = Normal_logZ(mean, cov)
        untwisted_logpdf = unnormalized_Normal_logp(x, mean, cov)
        log_twist = log_psi_twist(x, As, bs)

        # manually compute the twisting value
        manual_unnorm_logp_twist = untwisted_logpdf + log_twist  # add the twist value to the untwisted unnormalized logp

        # do the parameter twist
        twisted_mu, twisted_cov, log_normalizer = do_param_twist(mean, cov, As, bs, True)
        twisted_logZ = Normal_logZ(twisted_mu, twisted_cov)
        twisted_logpdf = (unnormalized_Normal_logp(x, twisted_mu, twisted_cov)
                          + log_normalizer - twisted_logZ + untwisted_logZ)

        # make sure the manual and automatic one are equal
        assert np.isclose(manual_unnorm_logp_twist, twisted_logpdf)

        # now, we have to do the hard part and try to compute the twisted normalization constant...
        unnorm_logp = lambda x: jnp.exp(unnormalized_Normal_logp(x, mean, cov) + log_psi_twist(x, As, bs))
        vals = map(unnorm_logp,
                   jnp.linspace(integration_range[0], integration_range[1], num_point_ops)[..., jnp.newaxis])
        dx = (integration_range[1] - integration_range[0]) / num_point_ops
        man_logZ = jnp.log(dx*vals.sum())
        assert np.isclose(man_logZ, log_normalizer + untwisted_logZ, atol=1e-2)
def gaussian_cl_covariance(ell, probes, cl_signal, cl_noise, f_sky=0.25):
    """
    Computes a Gaussian covariance for the angular cls of the provided probes

    return_cls: (returns covariance)
    """
    ell = np.atleast_1d(ell)
    n_ell = len(ell)

    # Adding noise to auto-spectra
    cl_obs = cl_signal + cl_noise
    n_cls = cl_obs.shape[0]

    # Normalization of covariance
    norm = (2 * ell + 1) * np.gradient(ell) * f_sky

    # Retrieve ordering for blocks of the covariance matrix
    cov_blocks = np.array(_get_cov_blocks_ordering(probes))

    def get_cov_block(inds):
        a, b, c, d = inds
        cov = (cl_obs[a] * cl_obs[b] + cl_obs[c] * cl_obs[d]) / norm
        return cov * np.eye(n_ell)

    cov_mat = lax.map(get_cov_block, cov_blocks)

    # Reshape covariance matrix into proper matrix
    cov_mat = cov_mat.reshape((n_cls, n_cls, n_ell, n_ell))
    cov_mat = cov_mat.transpose(axes=(0, 2, 1, 3)).reshape(
        (n_ell * n_cls, n_ell * n_cls))
    return cov_mat
def distributed_matrix_vector(x, y):
    """Matrix vector multiply. First batch it and then row by row"""
    fv = lambda z: lax.map(lambda j: vv(j, y), z)
    res = pmap(fv)(x.reshape((jax.device_count(), -1) + tuple(x.shape[1:])))
    res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
    return res
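# Usage sketch (an illustrative assumption, not part of the original source):
# `distributed_matrix_vector` closes over a `vv` helper defined alongside it, and the
# leading dimension of `x` must be divisible by jax.device_count() for the pmap reshape.
import jax
import jax.numpy as np  # these snippets use `np` as the jax.numpy alias
from jax import random


def vv(x, y):
    """Vector-vector multiply, as in the test above."""
    return np.dot(x, y)


x = random.normal(random.PRNGKey(0), (8 * jax.device_count(), 5))
y = random.normal(random.PRNGKey(1), (5,))
res = distributed_matrix_vector(x, y)  # rows split across devices, then mapped row by row
assert res.shape == (x.shape[0],)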
def rejection_stitch_proposal_all(ssm_scenario: StateSpaceModel,
                                  x0_all: jnp.ndarray,
                                  t: float,
                                  x1_all: jnp.ndarray,
                                  tplus1: float,
                                  x1_log_weight: jnp.ndarray,
                                  bound_inflation: float,
                                  not_yet_accepted_arr: jnp.ndarray,
                                  x1_all_sampled_inds: jnp.ndarray,
                                  bound: float,
                                  random_keys: jnp.ndarray,
                                  rejection_iter: int,
                                  num_transition_evals: int) \
        -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray, int, int]:
    n = len(x1_all)
    mapped_tup = map(lambda i: rejection_stitch_proposal_single_cond(not_yet_accepted_arr[i],
                                                                     x1_all_sampled_inds[i],
                                                                     ssm_scenario,
                                                                     x0_all[i],
                                                                     t,
                                                                     x1_all,
                                                                     tplus1,
                                                                     x1_log_weight,
                                                                     bound,
                                                                     random_keys[i]),
                     jnp.arange(n))
    x1_all_sampled_inds, dens_evals, not_yet_accepted_arr_new, random_keys = mapped_tup

    # Check if we need to start again
    max_dens = jnp.max(dens_evals)
    reset_bound = max_dens > bound
    bound = jnp.where(reset_bound, max_dens * bound_inflation, bound)
    not_yet_accepted_arr_new = jnp.where(reset_bound,
                                         jnp.ones(n, dtype='bool'),
                                         not_yet_accepted_arr_new)

    return not_yet_accepted_arr_new, x1_all_sampled_inds, bound, random_keys, rejection_iter + 1, \
        num_transition_evals + not_yet_accepted_arr.sum()
def integrand(a):
    # Step 1: retrieve the associated comoving distance
    chi = bkgrd.radial_comoving_distance(cosmo, a)

    # Step 2: get the power spectrum for this combination of chi and a
    k = (ell + 0.5) / np.clip(chi, 1.0)

    # pk should have shape [na]
    pk = power.nonlinear_matter_power(cosmo, k, a, transfer_fn, nonlinear_fn)

    # Compute the kernels for all probes
    kernels = np.vstack([p.kernel(cosmo, a2z(a), ell) for p in probes])

    # Define an ordering for the blocks of the signal vector
    cl_index = np.array(_get_cl_ordering(probes))

    # Compute all combinations of tracers
    def combine_kernels(inds):
        return kernels[inds[0]] * kernels[inds[1]]

    # Now kernels has shape [ncls, na]
    kernels = lax.map(combine_kernels, cl_index)

    result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi**2, 1.0)

    # We transpose the result just to make sure that na is first
    return result.T
def _predict_spatial_batched(self, elec_params, x, y):
    """A batched version of _predict_spatial_jax

    Parameters
    ----------
    elec_params : np.array with shape (batch_size, n_electrodes, 3)
        The 3 columns are freq, amp, pdur for each electrode
    x, y : np.array with shape (n_electrodes)
        x and y coordinates of electrodes

    Returns
    -------
    resp : np.array() representing the resulting percepts, shape (batch_size, :, 1)
    """
    bright_effects = self.bright_model(elec_params[:, :, 0],
                                       elec_params[:, :, 1],
                                       elec_params[:, :, 2])
    size_effects = self.size_model(elec_params[:, :, 0],
                                   elec_params[:, :, 1],
                                   elec_params[:, :, 2])
    streak_effects = self.streak_model(elec_params[:, :, 0],
                                       elec_params[:, :, 1],
                                       elec_params[:, :, 2])
    eparams = jnp.stack([bright_effects, size_effects, streak_effects], axis=2)

    def predict_one(e_params):
        return self.biphasic_axon_map_jax(e_params, x, y,
                                          self.axon_contrib,
                                          self.rho, self.thresh_percept)

    resps = lax.map(predict_one, eparams)
    return resps
def rvf(t, T0, P, e, omegaA, Ksini, Vsys):
    """Unit-free radial velocity curve for SB1.

    Args:
        t: Time in your time unit
        T0: Time of periastron passage in your time unit
        P: orbital period in your time unit
        e: eccentricity
        omegaA: argument of periastron
        Ksini: RV semi-amplitude in your velocity unit
        Vsys: systemic velocity in your velocity unit

    Returns:
        radial velocity curve in your velocity unit
    """
    n = 2*jnp.pi/P
    M = n*(t-T0)
    Ea = map(lambda x: getE.getE(x, e), M)
    cosE = jnp.cos(Ea)
    cosf = (-cosE + e)/(-1 + cosE*e)
    sinf = jnp.sqrt((-1 + cosE*cosE)*(-1 + e*e))/(-1 + cosE*e)
    sinf = jnp.where(Ea < jnp.pi, -sinf, sinf)
    cosfpo = cosf*jnp.cos(omegaA) - sinf*jnp.sin(omegaA)
    face = 1.0/jnp.sqrt(1.0-e*e)
    model = Ksini*face*(cosfpo+e*jnp.cos(omegaA)) + Vsys
    return model
def _unpack_and_constrain(self, latent_sample, params):
    def unpack_single_latent(latent):
        unpacked_samples = self._unpack_latent(latent)
        # XXX: we need to add param here to be able to replay model
        unpacked_samples.update({
            k: v for k, v in params.items()
            if k in self.prototype_trace and self.prototype_trace[k]["type"] == "param"
        })
        samples = self._postprocess_fn(unpacked_samples)
        # filter out param sites
        return {
            k: v for k, v in samples.items()
            if k in self.prototype_trace and self.prototype_trace[k]["type"] != "param"
        }

    sample_shape = jnp.shape(latent_sample)[:-1]
    if sample_shape:
        latent_sample = jnp.reshape(latent_sample, (-1, jnp.shape(latent_sample)[-1]))
        unpacked_samples = lax.map(unpack_single_latent, latent_sample)
        return tree_map(
            lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
            unpacked_samples,
        )
    else:
        return unpack_single_latent(latent_sample)
def safe_map(f, xs):
    """Equivalent to jax.lax.map, but handles empty arrays too."""
    if xs.shape[0] == 0:
        return xs
    return l.map(f, xs)
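# Usage sketch (illustrative only; safe_map is restated inline so the example is
# self-contained, and `l` matches the jax.lax alias used above):
import jax.numpy as jnp
import jax.lax as l


def safe_map(f, xs):
    return xs if xs.shape[0] == 0 else l.map(f, xs)


print(safe_map(lambda x: x ** 2, jnp.arange(5.0)))          # [ 0.  1.  4.  9. 16.]
print(safe_map(lambda x: x ** 2, jnp.zeros((0, 3))).shape)  # (0, 3) -- l.map is never called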
def laxmap_postprocess_fn(states, args, kwargs):
    if self.postprocess_fn is None:
        body_fn = self.sampler.postprocess_fn(args, kwargs)
    else:
        body_fn = self.postprocess_fn
    if self.chain_method == "vectorized" and self.num_chains > 1:
        body_fn = vmap(body_fn)
    return lax.map(body_fn, states)
def gaussians_twist(Xp, potential, dt, potential_params, A_fn, b_fn, A_params, b_params, get_log_normalizer):
    """
    Conduct a Gaussian twist to return a twisted mean, covariance vector, normalizing constant, As, and bs.

    Arguments
        Xp : jnp.array(Dx)
            previous x values
        potential : function
            potential function
        dt : float
            time increment
        potential_params : jnp.array(R)
            arg 1 for potential function
        A_fn : function
            A twisting function
        b_fn : function
            b twisting function
        A_params : jnp.array(U, V)
            U vector of parameters to twist; each is of dim V
        b_params : jnp.array(U, W)
            U vector of parameters to twist; each is of dim W
        get_log_normalizer : bool
            whether to compute the normalization constant of the forward kernel

    Returns
        twisted_mu : jnp.array(Dx)
            twisted mean
        twisted_cov : jnp.array(Dx)
            twisted covariance vector
        logZ : float
            forward kernel log normalizer
        As : jnp.array(Q, Dx)
            array of A twisting vectors
        bs : jnp.array(Q, Dx)
            array of b twisting vectors
    """
    partial_A_fn, partial_b_fn = partial(A_fn, Xp), partial(b_fn, Xp)
    As = map(partial_A_fn, A_params)
    bs = map(partial_b_fn, b_params)
    start_mu, start_cov = base_fk_params(Xp, potential, dt, potential_params)
    twisted_mu, twisted_cov, logZ = do_param_twist(start_mu, start_cov, As, bs, get_log_normalizer)
    return twisted_mu, twisted_cov, logZ, (As, bs)
def init_kernel(init_params,
                num_warmup,
                adapt_state_size=None,
                inverse_mass_matrix=None,
                dense_mass=False,
                model_args=(),
                model_kwargs=None,
                rng_key=random.PRNGKey(0)):
    nonlocal wa_steps
    wa_steps = num_warmup
    pe_fn = potential_fn
    if potential_fn_gen:
        if pe_fn is not None:
            raise ValueError(
                'Only one of `potential_fn` or `potential_fn_gen` must be provided.')
        else:
            kwargs = {} if model_kwargs is None else model_kwargs
            pe_fn = potential_fn_gen(*model_args, **kwargs)

    rng_key_sa, rng_key_zs, rng_key_z = random.split(rng_key, 3)
    z = init_params
    z_flat, unravel_fn = ravel_pytree(z)
    if inverse_mass_matrix is None:
        inverse_mass_matrix = jnp.identity(z_flat.shape[-1]) if dense_mass \
            else jnp.ones(z_flat.shape[-1])
    inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) if dense_mass \
        else jnp.sqrt(inverse_mass_matrix)
    if adapt_state_size is None:
        # XXX: heuristic choice
        adapt_state_size = 2 * z_flat.shape[-1]
    else:
        assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
    # NB: mean is init_params
    zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
    # compute potential energies
    pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
    if dense_mass:
        cov = jnp.cov(zs, rowvar=False, bias=True)
        if cov.shape == ():  # JAX returns scalar for 1D input
            cov = cov.reshape((1, 1))
        cholesky = jnp.linalg.cholesky(cov)
        # if cholesky is NaN, we use the scale from `sample_proposal` here
        inv_mass_matrix_sqrt = jnp.where(jnp.any(jnp.isnan(cholesky)),
                                         inv_mass_matrix_sqrt,
                                         cholesky)
    else:
        inv_mass_matrix_sqrt = jnp.std(zs, 0)
    adapt_state = SAAdaptState(zs, pes, jnp.mean(zs, 0), inv_mass_matrix_sqrt)
    k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
    z = unravel_fn(zs[k])
    pe = pes[k]
    sa_state = SAState(jnp.array(0), z, pe, jnp.zeros(()), jnp.zeros(()),
                       jnp.array(False), adapt_state, rng_key_sa)
    return device_put(sa_state)
def _binomial(key, p, n, shape):
    shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
    # reshape to map over axis 0
    p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
    n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
    key = random.split(key, jnp.size(p))
    if jax.default_backend() == "cpu":
        ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
    else:
        ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
    return jnp.reshape(ret, shape)
def _binomial(key, p, n, shape):
    shape = shape or lax.broadcast_shapes(np.shape(p), np.shape(n))
    # reshape to map over axis 0
    p = np.reshape(np.broadcast_to(p, shape), -1)
    n = np.reshape(np.broadcast_to(n, shape), -1)
    key = random.split(key, np.size(p))
    if xla_bridge.get_backend().platform == 'cpu':
        ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
    else:
        ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
    return np.reshape(ret, shape)
def _poisson(key, rate, shape, dtype):
    # Ref: https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
    shape = shape or np.shape(rate)
    rate = lax.convert_element_type(rate, canonicalize_dtype(np.float64))
    rate = np.broadcast_to(rate, shape)
    rng_keys = random.split(key, np.size(rate))
    if xla_bridge.get_backend().platform == 'cpu':
        k = lax.map(_poisson_one, (rng_keys, np.reshape(rate, -1)))
    else:
        k = vmap(_poisson_one)((rng_keys, np.reshape(rate, -1)))
    k = lax.convert_element_type(k, dtype)
    return np.reshape(k, shape)
def _constrain(self, latent_samples):
    name = list(latent_samples)[0]
    sample_shape = jnp.shape(latent_samples[name])[
        :jnp.ndim(latent_samples[name]) - jnp.ndim(self._init_locs[name])]
    if sample_shape:
        flatten_samples = tree_map(
            lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[len(sample_shape):]),
            latent_samples)
        constrained_samples = lax.map(self._postprocess_fn, flatten_samples)
        return tree_map(
            lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
            constrained_samples)
    else:
        return self._postprocess_fn(latent_samples)
def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
    """
    Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
    of `xs`. This uses jax.vmap over smaller chunks of the batch dimensions
    to keep memory usage constant.

    :param callable fn: The function to map over.
    :param xs: JAX pytree (e.g. an array, a list/tuple/dict of arrays,...)
    :param int batch_ndims: The number of leading dimensions of `xs`
        to apply `fn` element-wise over them.
    :param int chunk_size: Size of each chunk of `xs`.
        Defaults to the size of batch dimensions.
    :returns: output of `fn(xs)`.
    """
    flatten_xs = tree_flatten(xs)[0]
    batch_shape = np.shape(flatten_xs[0])[:batch_ndims]
    for x in flatten_xs[1:]:
        assert np.shape(x)[:batch_ndims] == batch_shape

    # we'll do map(vmap(fn), xs) and make xs.shape = (num_chunks, chunk_size, ...)
    num_chunks = batch_size = int(np.prod(batch_shape))
    prepend_shape = (batch_size,) if batch_size > 1 else ()
    xs = tree_map(
        lambda x: jnp.reshape(x, prepend_shape + jnp.shape(x)[batch_ndims:]), xs)
    # XXX: probably for the default behavior with chunk_size=None,
    # it is better to catch OOM error and reduce chunk_size by half until OOM disappears.
    chunk_size = batch_size if chunk_size is None else min(batch_size, chunk_size)
    if chunk_size > 1:
        pad = chunk_size - (batch_size % chunk_size)
        xs = tree_map(
            lambda x: jnp.pad(x, ((0, pad),) + ((0, 0),) * (np.ndim(x) - 1)), xs)
        num_chunks = batch_size // chunk_size + int(pad > 0)
        prepend_shape = (-1,) if num_chunks > 1 else ()
        xs = tree_map(
            lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
            xs,
        )
        fn = vmap(fn)

    ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    ys = tree_map(
        lambda y: jnp.reshape(
            y,
            (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:])[:batch_size],
        ys,
    )
    return tree_map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]), ys)
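# Usage sketch (illustrative only; assumes the imports soft_vmap relies on are in scope:
# numpy as np, jax.numpy as jnp, and vmap, lax, tree_flatten, tree_map from JAX):
import jax.numpy as jnp

xs = jnp.arange(100.0).reshape(25, 4)
# Rows are processed 10 at a time: vmap within each chunk, lax.map across the 3 chunks.
ys = soft_vmap(lambda row: jnp.sum(row ** 2), xs, batch_ndims=1, chunk_size=10)
assert ys.shape == (25,)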
def matrix_vector(x, y, parallel=True):
    """Matrix vector multiply. First batch it and then row by row"""
    fv = lambda z: lax.map(lambda j: vv(j, y), z)
    if parallel:
        # split leading axis in two
        new_x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
        # apply map
        new_res = pmap(fv)(new_x)
        # reshape back out
        res = new_res.reshape(x.shape[0], *new_res.shape[2:])
    else:
        res = fv(x)
    return res
def gaussian_cl_covariance(ell, probes, cl_signal, cl_noise, f_sky=0.25, sparse=True):
    """
    Computes a Gaussian covariance for the angular cls of the provided probes

    Set sparse True to return a sparse matrix representation that uses a factor
    of n_ell less memory and is compatible with the linear algebra operations
    in :mod:`jax_cosmo.sparse`.

    return_cls: (returns covariance)
    """
    ell = np.atleast_1d(ell)
    n_ell = len(ell)
    one = 1.0 if sparse else np.eye(n_ell)

    # Adding noise to auto-spectra
    cl_obs = cl_signal + cl_noise
    n_cls = cl_obs.shape[0]

    # Normalization of covariance
    norm = (2 * ell + 1) * np.gradient(ell) * f_sky

    # Retrieve ordering for blocks of the covariance matrix
    cov_blocks = np.array(_get_cov_blocks_ordering(probes))

    def get_cov_block(inds):
        a, b, c, d = inds
        cov = (cl_obs[a] * cl_obs[b] + cl_obs[c] * cl_obs[d]) / norm
        return cov * one

    # Return a sparse representation of the matrix containing only the diagonals
    # for each of the n_cls x n_cls blocks of size n_ell x n_ell.
    # We could compress this further using the symmetry of the blocks, but
    # it is easier to invert this matrix with this redundancy included.
    cov_mat = lax.map(get_cov_block, cov_blocks)

    # Reshape covariance matrix into proper matrix
    if sparse:
        cov_mat = cov_mat.reshape((n_cls, n_cls, n_ell))
    else:
        cov_mat = cov_mat.reshape((n_cls, n_cls, n_ell, n_ell))
        cov_mat = cov_mat.transpose(axes=(0, 2, 1, 3)).reshape(
            (n_ell * n_cls, n_ell * n_cls))
    return cov_mat
def _single_chain_mcmc(self, init, args, kwargs, collect_fields):
    rng_key, init_state, init_params = init
    if init_state is None:
        init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
                                       model_args=args, model_kwargs=kwargs)
    if self.postprocess_fn is None:
        postprocess_fn = self.sampler.postprocess_fn(args, kwargs)
    else:
        postprocess_fn = self.postprocess_fn
    diagnostics = lambda x: self.sampler.get_diagnostics_str(x[0]) if rng_key.ndim == 1 else ''  # noqa: E731
    init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
    lower_idx = self._collection_params["lower"]
    upper_idx = self._collection_params["upper"]
    phase = self._collection_params["phase"]
    collect_vals = fori_collect(
        lower_idx, upper_idx,
        self._get_cached_fn(),
        init_val,
        transform=_collect_fn(collect_fields),
        progbar=self.progress_bar,
        return_last_val=True,
        collection_size=self._collection_params["collection_size"],
        progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
        diagnostics_fn=diagnostics)
    states, last_val = collect_vals
    # Get first argument of type `HMCState`
    last_state = last_val[0]
    if len(collect_fields) == 1:
        states = (states,)
    states = dict(zip(collect_fields, states))
    # Apply constraints if number of samples is non-zero
    site_values = tree_flatten(states[self._sample_field])[0]
    # XXX: lax.map still works if some arrays have 0 size
    # so we only need to filter out the case site_value.shape[0] == 0
    # (which happens when lower_idx==upper_idx)
    if len(site_values) > 0 and jnp.shape(site_values[0])[0] > 0:
        if self.chain_method == "vectorized" and self.num_chains > 1:
            postprocess_fn = vmap(postprocess_fn)
        states[self._sample_field] = lax.map(postprocess_fn, states[self._sample_field])
    return states, last_state
def _unpack_and_constrain(self, latent_sample, params):
    def unpack_single_latent(latent):
        unpacked_samples = self._unpack_latent(latent)
        # add param sites in model
        unpacked_samples.update({k: v for k, v in params.items()
                                 if k in self.prototype_trace
                                 and self.prototype_trace[k]['type'] == 'param'})
        return self._postprocess_fn(unpacked_samples)

    sample_shape = jnp.shape(latent_sample)[:-1]
    if sample_shape:
        latent_sample = jnp.reshape(latent_sample, (-1, jnp.shape(latent_sample)[-1]))
        unpacked_samples = lax.map(unpack_single_latent, latent_sample)
        return tree_map(lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                        unpacked_samples)
    else:
        return unpack_single_latent(latent_sample)
def noise_cl(ell, probes):
    """
    Computes noise contributions to auto-spectra
    """
    n_ell = len(ell)
    # Concatenate noise power for each tracer
    noise = np.concatenate([p.noise() for p in probes])
    # Define an ordering for the blocks of the signal vector
    cl_index = np.array(_get_cl_ordering(probes))

    # Only include a noise contribution for the auto-spectra
    def get_noise_cl(inds):
        i, j = inds
        delta = 1.0 - np.clip(np.abs(i - j), 0.0, 1.0)
        return noise[i] * delta * np.ones(n_ell)

    return lax.map(get_noise_cl, cl_index)
def rejection_stitching(ssm_scenario: StateSpaceModel,
                        x0_all: jnp.ndarray,
                        t: float,
                        x1_all: jnp.ndarray,
                        tplus1: float,
                        x1_log_weight: jnp.ndarray,
                        random_key: jnp.ndarray,
                        maximum_rejections: int,
                        init_bound_param: float,
                        bound_inflation: float) -> Tuple[jnp.ndarray, int]:
    rejection_initial_keys = random.split(random_key, 3)
    n = len(x1_all)

    # Prerun to initiate bound
    x1_initial_inds = random.categorical(rejection_initial_keys[0], x1_log_weight, shape=(n,))
    initial_cond_dens = jnp.exp(-vmap(ssm_scenario.transition_potential, (0, None, 0, None))(
        x0_all, t, x1_all[x1_initial_inds], tplus1))
    max_cond_dens = jnp.max(initial_cond_dens)
    initial_bound = jnp.where(max_cond_dens > init_bound_param,
                              max_cond_dens * bound_inflation,
                              init_bound_param)
    initial_not_yet_accepted_arr = random.uniform(rejection_initial_keys[1], (n,)) \
        > initial_cond_dens / initial_bound

    out_tup = while_loop(lambda tup: jnp.logical_and(tup[0].sum() > 0, tup[-2] < maximum_rejections),
                         lambda tup: rejection_stitch_proposal_all(ssm_scenario, x0_all, t, x1_all, tplus1,
                                                                   x1_log_weight, bound_inflation, *tup),
                         (initial_not_yet_accepted_arr,
                          x1_initial_inds,
                          initial_bound,
                          random.split(rejection_initial_keys[2], n),
                          1,
                          n))
    not_yet_accepted_arr, x1_final_inds, final_bound, random_keys, rej_attempted, num_transition_evals = out_tup

    x1_final_inds = map(lambda i: full_stitch_single_cond(not_yet_accepted_arr[i],
                                                          x1_final_inds[i],
                                                          ssm_scenario,
                                                          x0_all[i],
                                                          t,
                                                          x1_all,
                                                          tplus1,
                                                          x1_log_weight,
                                                          random_keys[i]),
                        jnp.arange(n))

    num_transition_evals = num_transition_evals + len(x1_all) * not_yet_accepted_arr.sum()

    return x1_final_inds, num_transition_evals
def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('z',)):
    if init_state is None:
        init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
                                       model_args=args, model_kwargs=kwargs)
    if self.constrain_fn is None:
        self.constrain_fn = self.sampler.constrain_fn(args, kwargs)
    diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None  # noqa: E731
    init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
    lower_idx = self._collection_params["lower"]
    upper_idx = self._collection_params["upper"]
    collect_vals = fori_collect(
        lower_idx, upper_idx,
        self._get_cached_fn(),
        init_val,
        transform=_collect_fn(collect_fields),
        progbar=self.progress_bar,
        return_last_val=True,
        collection_size=self._collection_params["collection_size"],
        progbar_desc=functools.partial(get_progbar_desc_str, lower_idx),
        diagnostics_fn=diagnostics)
    states, last_val = collect_vals
    # Get first argument of type `HMCState`
    last_state = last_val[0]
    if len(collect_fields) == 1:
        states = (states,)
    states = dict(zip(collect_fields, states))
    # Apply constraints if number of samples is non-zero
    site_values = tree_flatten(states['z'])[0]
    if len(site_values) > 0 and site_values[0].size > 0:
        states['z'] = lax.map(self.constrain_fn, states['z'])
    return states, last_state
def _predictive(rng_key, model, posterior_samples, num_samples, return_sites=None,
                parallel=True, model_args=(), model_kwargs={}):
    rng_keys = random.split(rng_key, num_samples)

    def single_prediction(val):
        rng_key, samples = val
        model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
            *model_args, **model_kwargs)
        if return_sites is not None:
            if return_sites == '':
                sites = {k for k, site in model_trace.items() if site['type'] != 'plate'}
            else:
                sites = return_sites
        else:
            sites = {k for k, site in model_trace.items()
                     if (site['type'] == 'sample' and k not in samples)
                     or (site['type'] == 'deterministic')}
        return {name: site['value'] for name, site in model_trace.items() if name in sites}

    if parallel:
        return vmap(single_prediction)((rng_keys, posterior_samples))
    else:
        return lax.map(single_prediction, (rng_keys, posterior_samples))
def testMap(self):
    f = lambda x: x**2
    xs = np.arange(10)
    expected = xs**2
    actual = lax.map(f, xs)
    self.assertAllClose(actual, expected, check_dtypes=True)
def find_valid_initial_params(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    enum=False,
    model_args=(),
    model_kwargs=None,
    prototype_params=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial
    valid unconstrained value for all the parameters. This function also returns
    the corresponding potential energy, the gradients, and an `is_valid` flag to say
    whether the initial parameters are valid. Parameter values are considered valid
    if the values and the gradients for the log density have finite values.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
    :param bool enum: whether to enumerate over discrete latent sites.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict prototype_params: an optional prototype parameters, which is used
        to define the shape for initial parameters.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. Defaults to False.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: tuple of `init_params_info` and `is_valid`, where `init_params_info` is the
        tuple containing the initial params, their potential energy, and their gradients.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    init_strategy = (
        init_strategy if isinstance(init_strategy, partial) else init_strategy()
    )
    # handle those init strategies differently to save computation
    if init_strategy.func is init_to_uniform:
        radius = init_strategy.keywords.get("radius")
        init_values = {}
    elif init_strategy.func is _init_to_unconstrained_value:
        radius = 2
        init_values = init_strategy.keywords.get("values")
    else:
        radius = None

    def cond_fn(state):
        i, _, _, is_valid = state
        return (i < 100) & (~is_valid)

    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model, subkey), substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                if (
                    v["type"] == "sample"
                    and not v["is_observed"]
                    and not v["fn"].is_discrete
                ):
                    constrained_values[k] = v["value"]
                    inv_transforms[k] = biject_to(v["fn"].support)
            params = transform_fn(
                inv_transforms,
                {k: v for k, v in constrained_values.items()},
                invert=True,
            )
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(
                        subkey, jnp.shape(v), minval=-radius, maxval=radius
                    )
                    key, subkey = random.split(key)

        potential_fn = partial(
            potential_energy, model, model_args, model_kwargs, enum=enum
        )
        if validate_grad:
            if forward_mode_differentiation:
                pe = potential_fn(params)
                z_grad = jacfwd(potential_fn)(params)
            else:
                pe, z_grad = value_and_grad(potential_fn)(params)
            z_grad_flat = ravel_pytree(z_grad)[0]
            is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        else:
            pe = potential_fn(params)
            is_valid = jnp.isfinite(pe)
            z_grad = None

        return i + 1, key, (params, pe, z_grad), is_valid

    def _find_valid_params(rng_key, exit_early=False):
        init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
        if exit_early and not_jax_tracer(rng_key):
            # Early return if valid params found. This is only helpful for single chain,
            # where we can avoid compiling body_fn in while_loop.
            _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
            if not_jax_tracer(is_valid):
                if device_get(is_valid):
                    return (init_params, pe, z_grad), is_valid

        # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
        # even if the init_state is a valid result
        _, _, (init_params, pe, z_grad), is_valid = while_loop(cond_fn, body_fn, init_state)
        return (init_params, pe, z_grad), is_valid

    # Handle possible vectorization
    if rng_key.ndim == 1:
        (init_params, pe, z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
    else:
        (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)
    return (init_params, pe, z_grad), is_valid