def covariance_matrix(self):
    # TODO: find a better solution to create a diagonal matrix
    new_diag = self.cov_diag[..., np.newaxis] * np.identity(self.loc.shape[-1])
    covariance_matrix = new_diag + np.matmul(
        self.cov_factor, np.swapaxes(self.cov_factor, -1, -2))
    return covariance_matrix
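# Usage sketch (not from the original class; hypothetical shapes): the
# construction above is the low-rank-plus-diagonal form
# cov = diag(cov_diag) + cov_factor @ cov_factor.T.
import numpy as np

cov_factor = np.ones((3, 2))  # shape (..., dim, rank)
cov_diag = np.full(3, 0.5)    # shape (..., dim)
new_diag = cov_diag[..., np.newaxis] * np.identity(3)
cov = new_diag + np.matmul(cov_factor, np.swapaxes(cov_factor, -1, -2))
assert np.allclose(cov, cov_factor @ cov_factor.T + np.diag(cov_diag))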
def inv_vec_transform(y):
    matrix = vec_to_tril_matrix(y)
    if constraint is constraints.positive_definite:
        # fill the upper triangular part
        matrix = matrix + np.swapaxes(matrix, -2, -1) - np.diag(np.diag(matrix))
    return transform.inv(matrix)
def model_1(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    with mask(mask=include_prior):
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                x = numpyro.sample(
                    "x",
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                with numpyro.plate("tones", data_dim, dim=-1):
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                   obs=y)
        return (x, t + 1), None

    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # NB swapaxes: we move time dimension of `sequences` to the front to scan over it
    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))
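# Inference sketch (hypothetical data and args; assumes funsor is installed so
# NumPyro can enumerate the discrete states marked "parallel" above):
from types import SimpleNamespace

import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS

args = SimpleNamespace(hidden_dim=2)
sequences = jnp.zeros((4, 6, 3))  # (num_sequences, max_length, data_dim)
lengths = jnp.full((4,), 6)
mcmc = MCMC(NUTS(model_1), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0), sequences, lengths, args)
mcmc.print_summary()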
def final_fn(state, regularize=False):
    """
    :param state: Current state of the scheme.
    :param bool regularize: Whether to adjust diagonal for numerical stability.
    :return: a triple of estimated covariance, the square root of precision,
        and the inverse of that square root.
    """
    mean, m2, n = state
    # XXX it is not necessary to check for the case n=1
    cov = m2 / (n - 1)
    if regularize:
        # Regularization from Stan
        scaled_cov = (n / (n + 5)) * cov
        shrinkage = 1e-3 * (5 / (n + 5))
        if diagonal:
            cov = scaled_cov + shrinkage
        else:
            cov = scaled_cov + shrinkage * jnp.identity(mean.shape[0])
    if jnp.ndim(cov) == 2:
        # copy the implementation of distributions.util.cholesky_of_inverse here
        tril_inv = jnp.swapaxes(
            jnp.linalg.cholesky(cov[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1)
        identity = jnp.identity(cov.shape[-1])
        cov_inv_sqrt = solve_triangular(tril_inv, identity, lower=True)
    else:
        tril_inv = jnp.sqrt(cov)
        cov_inv_sqrt = jnp.reciprocal(tril_inv)
    return cov, cov_inv_sqrt, tril_inv
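# Numerical sketch of the reversed-Cholesky trick above (a standalone check,
# not part of the scheme): `tril_inv` is lower triangular with
# tril_inv.T @ tril_inv == cov, so its inverse is a square root of the
# precision matrix.
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

cov = jnp.array([[2.0, 0.3], [0.3, 1.0]])
tril_inv = jnp.swapaxes(
    jnp.linalg.cholesky(cov[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1)
cov_inv_sqrt = solve_triangular(tril_inv, jnp.identity(2), lower=True)
assert jnp.allclose(tril_inv.T @ tril_inv, cov, atol=1e-6)
assert jnp.allclose(cov_inv_sqrt @ cov_inv_sqrt.T, jnp.linalg.inv(cov),
                    atol=1e-5)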
def init(self, rng_key, num_warmup, init_params=None, model_args=(),
         model_kwargs={}):
    # non-vectorized
    if rng_key.ndim == 1:
        rng_key, rng_key_init_model = random.split(rng_key)
    # vectorized
    else:
        rng_key, rng_key_init_model = jnp.swapaxes(
            vmap(random.split)(rng_key), 0, 1)
        # we need only a single key for initializing PE / constraints fn
        rng_key_init_model = rng_key_init_model[0]
    init_params = self._init_state(rng_key_init_model, model_args,
                                   model_kwargs, init_params)
    if self._potential_fn and init_params is None:
        raise ValueError('Valid value of `init_params` must be provided with'
                         ' `potential_fn`.')

    # NB: init args is different from HMC
    sa_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
        init_params,
        num_warmup=num_warmup,
        adapt_state_size=self._adapt_state_size,
        dense_mass=self._dense_mass,
        rng_key=rng_key,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )
    if rng_key.ndim == 1:
        init_state = sa_init_fn(init_params, rng_key)
    else:
        init_state = vmap(sa_init_fn)(init_params, rng_key)
        sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
        self._sample_fn = sample_fn
    return init_state
def unroll_reinject(initial_states, initial_token, sequence_length,
                    apply_embedding, apply_rnn, apply_readout):
    """Unrolls an RNN, reinjecting the output back into the RNN."""

    def _step(state, _):
        # Unpack loop state.
        tokens, rnn_state = state

        # Apply embedding, RNN, and readout.
        rnn_inputs = apply_embedding(tokens)
        rnn_state = apply_rnn(rnn_inputs, rnn_state)
        logits = apply_readout(rnn_state)

        # Pack new loop state.
        next_state = (jnp.argmax(logits, axis=-1), rnn_state)

        return next_state, logits

    # Format scan arguments.
    batch_size = initial_states.shape[0]
    batch_inputs = initial_token * jnp.ones(batch_size).astype(jnp.int32)
    dummy_inputs = jnp.zeros((sequence_length, 1))

    # Unroll loop via scan.
    _, outputs = jax.lax.scan(_step, (batch_inputs, initial_states),
                              dummy_inputs)

    return jnp.swapaxes(outputs, 0, 1)
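# Usage sketch with toy stand-in callables (hypothetical shapes; a real RNN
# cell would carry a nontrivial state):
import jax.numpy as jnp

vocab, hidden, batch = 5, 4, 2
emb = jnp.ones((vocab, hidden))
w_out = jnp.ones((hidden, vocab))

logits = unroll_reinject(
    initial_states=jnp.zeros((batch, hidden)),
    initial_token=1,
    sequence_length=3,
    apply_embedding=lambda tokens: emb[tokens],
    apply_rnn=lambda inputs, state: jnp.tanh(inputs + state),
    apply_readout=lambda state: state @ w_out,
)
assert logits.shape == (batch, 3, vocab)  # (batch, sequence_length, vocab)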
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        return np.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, constraints._Real):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=np.ones((n,)) / n, n=constraint.upper_bound,
                           shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return np.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return np.matmul(x, np.swapaxes(x, -2, -1))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
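# Usage sketch (hypothetical call): `constraints.positive` is a
# `_GreaterThan(0.)` instance, so the samples land on the positive reals.
samples = gen_values_within_bounds(constraints.positive, size=(2, 3))
assert samples.shape == (2, 3) and bool(np.all(samples > 0))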
def model_4(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).expand_by(
                [hidden_dim]).to_event(2))
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3))

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
                x = numpyro.sample(
                    "x", dist.Categorical(Vindex(probs_x)[w, x_prev]))
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y[w, x, tones]),
                                   obs=y)
        return (w, x, t + 1), None

    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))
def __call__(self, x):
    # check that the matrix is symmetric
    symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1),
                        axis=-1)
    # check that the smallest eigenvalue is positive
    positive = jnp.linalg.eigh(x)[0][..., 0] > 0
    return symmetric & positive
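# Standalone sketch of the same check (hypothetical input): symmetry is
# tested elementwise and positive definiteness via the smallest eigenvalue.
import jax.numpy as jnp

x = jnp.array([[2.0, 0.5], [0.5, 1.0]])
symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1)
positive = jnp.linalg.eigh(x)[0][..., 0] > 0
assert bool(symmetric & positive)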
def _construct_derivatives(self, x_mp):
    """Builds derivatives of the network outputs wrt model parameters

    The network outputs from the simulations generated with model parameter
    values above and below the fiducial are subtracted from each other and
    divided by the perturbation size in each model parameter value. The axes
    are swapped such that the derivatives with respect to parameters are in
    the last axis.

    .. math::
        \\frac{\\partial{\\bf x}^i}{\\partial\\theta_\\alpha} =
            \\frac{{\\bf x}^i_{\\alpha^+}-{\\bf x}^i_{\\alpha^-}}{
            \\delta\\theta_\\alpha}

    Parameters
    ----------
    x_mp : float(n_d, 2, n_params, n_summaries)
        The outputs of the network of simulations made at perturbed parameter
        values to construct the derivative of the network outputs with
        respect to the model parameters numerically

    Returns
    -------
    float(n_d, n_summaries, n_params):
        The numerical derivatives of the network outputs with respect to the
        model parameters
    """
    return np.swapaxes(x_mp[:, 1] - x_mp[:, 0], 1, 2) / self.δθ
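# Shape sketch (hypothetical sizes): with n_d=10 simulations, n_params=2 and
# n_summaries=3, axis 1 of `x_mp` indexes the below/above-fiducial runs and
# the swap moves parameters to the last axis.
import numpy as np

x_mp = np.zeros((10, 2, 2, 3))    # (n_d, 2, n_params, n_summaries)
delta_theta = np.ones(2)          # stand-in for self.δθ
deriv = np.swapaxes(x_mp[:, 1] - x_mp[:, 0], 1, 2) / delta_theta
assert deriv.shape == (10, 3, 2)  # (n_d, n_summaries, n_params)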
def model_3(capture_history, sex):
    N, T = capture_history.shape
    phi_mean = numpyro.sample("phi_mean",
                              dist.Uniform(0.0, 1.0))  # mean survival probability
    phi_logit_mean = logit(phi_mean)
    # controls temporal variability of survival probability
    phi_sigma = numpyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
            phi_logit_t = numpyro.sample(
                "phi_logit", dist.Normal(phi_logit_mean, phi_sigma))
        phi_t = expit(phi_logit_t)
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z",
                                   dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                mu_y_t = rho * z
                numpyro.sample("y",
                               dist.Bernoulli(dist.util.clamp_probs(mu_y_t)),
                               obs=y)

        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(transition_fn, (first_capture_mask, z),
         jnp.swapaxes(capture_history[:, 1:], 0, 1))
def stable_svd_jvp(primals, tangents):
    """Copied from the JAX source code and slightly tweaked for stability"""
    # Deformation parameter which yields regular SVD JVP rule when set to 0
    eps = 1e-10
    A, = primals
    dA, = tangents
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False, compute_uv=True)

    _T = lambda x: jnp.swapaxes(x, -1, -2)
    _H = lambda x: jnp.conj(_T(x))
    k = s.shape[-1]
    Ut, V = _H(U), _H(Vt)
    s_dim = s[..., None, :]
    dS = jnp.matmul(jnp.matmul(Ut, dA), V)
    ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

    # Deformation by eps avoids getting NaN's when SV's are degenerate
    f = jnp.square(s_dim) - jnp.square(_T(s_dim)) + jnp.eye(k)
    f = f + eps / f  # eps controls stability
    F = 1 / f - jnp.eye(k) / (1 + eps)

    dSS = s_dim * dS
    SdS = _T(s_dim) * dS
    dU = jnp.matmul(U, F * (dSS + _T(dSS)))
    dV = jnp.matmul(V, F * (SdS + _T(SdS)))

    m, n = A.shape[-2], A.shape[-1]
    if m > n:
        dU = dU + jnp.matmul(
            jnp.eye(m) - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim
    if n > m:
        dV = dV + jnp.matmul(
            jnp.eye(n) - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim
    return (U, s, Vt), (dU, ds, _T(dV))
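# One plausible way to register the rule (assumed wiring, not shown in the
# source): wrap the primal SVD in `jax.custom_jvp` and attach the stabilized
# JVP to it.
import jax
import jax.numpy as jnp

@jax.custom_jvp
def stable_svd(a):
    u, s, vt = jnp.linalg.svd(a, full_matrices=False)
    return u, s, vt

stable_svd.defjvp(stable_svd_jvp)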
def infer(xs, a_min=None, a_max=None):
    @partial(
        jax.pmap,
        in_axes=(None, None, 0, 0),
        static_broadcasted_argnums=(0, 1),
        axis_name="device",
    )
    def _infer(model, n_refeed, params, xs):
        def body_fun(inputs, i):
            x = inputs
            y_hat = model.apply(params, x)
            # clip the prediction (not the input) before refeeding
            y_hat = jnp.clip(y_hat, a_min, a_max)
            x = deepx.optimise.refeed(x, y_hat)  # add the new pred to the inputs
            return x, y_hat

        _, ys_hat = jax.lax.scan(body_fun, xs, xs=jnp.arange(n_refeed))
        ys_hat = jnp.swapaxes(jnp.squeeze(ys_hat), 0, 1)
        return ys_hat

    model, hparams, params = load_model("p3aetobr/9336")
    # model, hparams, params = load_model("p3aetobr/9996")
    # model, hparams, params = load_model("p3aetobr/10326")
    # model, hparams, params = load_model("p3aetobr/10976")

    start = time.time()
    if (a_min is None) and (a_max is None):
        ys_hat = deepx.optimise.infer(model, n_refeed, params, xs)
    else:
        ys_hat = _infer(model, n_refeed, params, xs)
    print("Solved forward propagation to {}ms in: {}s".format(
        n_refeed * 5, time.time() - start))

    if ys_hat.shape[0] == 1:
        ys_hat = jnp.swapaxes(ys_hat, -3, -4)[None]
    return ys_hat
def init(self, rng_key, num_warmup, init_params=None, model_args=(),
         model_kwargs={}):
    # non-vectorized
    if rng_key.ndim == 1:
        rng_key, rng_key_init_model = random.split(rng_key)
    # vectorized
    else:
        rng_key, rng_key_init_model = jnp.swapaxes(
            vmap(random.split)(rng_key), 0, 1)
    init_params = self._init_state(rng_key_init_model, model_args,
                                   model_kwargs, init_params)
    if self._potential_fn and init_params is None:
        raise ValueError('Valid value of `init_params` must be provided with'
                         ' `potential_fn`.')

    # change dense_mass to a structural form
    dense_mass = self._dense_mass
    inverse_mass_matrix = self._inverse_mass_matrix
    if self._model is not None:
        z = init_params[0] if isinstance(init_params, ParamInfo) else init_params
        if isinstance(dense_mass, bool):
            # XXX: by default, the order variables are sorted by their names,
            # this is to be compatible with older numpyro versions
            # and to match autoguide scale parameter and jax flatten utils
            dense_mass = [tuple(sorted(z))] if dense_mass else []
        assert isinstance(dense_mass, list)

    hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
        init_params,
        num_warmup=num_warmup,
        step_size=self._step_size,
        inverse_mass_matrix=inverse_mass_matrix,
        adapt_step_size=self._adapt_step_size,
        adapt_mass_matrix=self._adapt_mass_matrix,
        dense_mass=dense_mass,
        target_accept_prob=self._target_accept_prob,
        trajectory_length=self._trajectory_length,
        max_tree_depth=self._max_tree_depth,
        find_heuristic_step_size=self._find_heuristic_step_size,
        forward_mode_differentiation=self._forward_mode_differentiation,
        model_args=model_args,
        model_kwargs=model_kwargs,
        rng_key=rng_key,
    )
    if rng_key.ndim == 1:
        init_state = hmc_init_fn(init_params, rng_key)
    else:
        # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
        # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
        # wa_steps because those variables do not depend on traced args: init_params, rng_key.
        init_state = vmap(hmc_init_fn)(init_params, rng_key)
        sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
        self._sample_fn = sample_fn
    return init_state
def rkhsel_gram(self, X, Y=None, logsp=False):
    """
    X - axis 0 contains observations (of sample sets), axis 1 is input
        dimension, axis 2 are different points per observation (the samples
        of a sample set)
    """
    assert (not logsp)
    if Y is not None:
        assert (len(Y.shape) == 2)
    if len(X.shape) == 2:
        return self.gram(X, Y)
    assert (len(X.shape) == 3)

    X_resh = np.concatenate(np.swapaxes(X, 0, 2),
                            axis=1).T  # np.swapaxes(X, 1, 2).reshape(-1, X.shape[1])
    if Y is None:
        # compute the full gram matrix
        G = self.gram(X_resh)
        # sum up the blockmatrices of shape (X.shape[2], X.shape[2]) that make up G
        G = np.mean(
            np.split(np.mean(np.split(G, X.shape[2], 1), 0), X.shape[2]), 0)
        # return the matrix of RKHS inner products of the mean embedding objects
        return G
    else:
        return np.mean(np.split(self.gram(X_resh, Y), X.shape[2]), 0)
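# Standalone sketch of the block averaging above (hypothetical linear kernel):
# averaging the (n_points x n_points) blocks of the full Gram matrix gives the
# Gram matrix of mean embeddings, one row/column per sample set.
import numpy as np

X = np.random.rand(4, 2, 5)  # (sample sets, input dim, points per set)
X_resh = np.concatenate(np.swapaxes(X, 0, 2), axis=1).T
G_full = X_resh @ X_resh.T   # stand-in for self.gram(X_resh)
G = np.mean(np.split(np.mean(np.split(G_full, X.shape[2], 1), 0), X.shape[2]), 0)
assert G.shape == (4, 4)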
def model_6(sequences, lengths, args, include_prior=False):
    num_sequences, max_length, data_dim = sequences.shape

    with mask(mask=include_prior):
        # Explicitly parameterize the full tensor of transition probabilities, which
        # has hidden_dim cubed entries.
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).expand(
                [args.hidden_dim, args.hidden_dim]).to_event(2),
        )
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        x_prev, x_curr, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
                x_prev, x_curr = x_curr, numpyro.sample(
                    "x", dist.Categorical(probs_x_t))
                with numpyro.plate("tones", data_dim, dim=-1):
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    numpyro.sample("y", dist.Bernoulli(probs_y_t), obs=y)
        return (x_prev, x_curr, t + 1), None

    x_prev = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_curr = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (x_prev, x_curr, 0), jnp.swapaxes(sequences, 0, 1),
         history=2)
def Wishart(key, dof, scale, shape=None):
    if scale is None:
        scale = jnp.eye(shape)
    batch_shape = ()
    if jnp.ndim(scale) > 2:
        batch_shape = scale.shape[:-2]
    p = scale.shape[-1]
    if dof is None:
        dof = p
    if jnp.ndim(dof) > 0:
        raise ValueError("only scalar dof implemented")
    if int(dof) != dof:
        raise ValueError(
            "dof should be integer-like (i.e. int(dof) == dof should return true)"
        )
    dof = int(dof)
    if shape is not None:
        if batch_shape != ():
            assert batch_shape == shape, \
                "Disagreement in batch shape between scale and shape"
        else:
            batch_shape = shape

    mn = jnp.zeros(shape=batch_shape + (p, ))
    mvn_shape = (dof, ) + batch_shape
    mvn = random.multivariate_normal(key, mean=mn, cov=scale, shape=mvn_shape)
    if jnp.ndim(mvn) > 2:
        mvn = jnp.swapaxes(mvn, 0, -2)
    S = jnp.einsum('...ji,...jk', mvn, mvn)
    return S
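# Usage sketch (hypothetical values): one 3x3 Wishart draw with 5 degrees of
# freedom and an identity scale.
import jax.numpy as jnp
from jax import random

S = Wishart(random.PRNGKey(0), dof=5, scale=jnp.eye(3))
assert S.shape == (3, 3)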
def init(self, rng_key, num_warmup, init_params=None, model_args=(),
         model_kwargs={}):
    # non-vectorized
    if rng_key.ndim == 1:
        rng_key, rng_key_init_model = random.split(rng_key)
    # vectorized
    else:
        rng_key, rng_key_init_model = jnp.swapaxes(
            vmap(random.split)(rng_key), 0, 1)
    init_params = self._init_state(rng_key_init_model, model_args,
                                   model_kwargs, init_params)
    if self._potential_fn and init_params is None:
        raise ValueError('Valid value of `init_params` must be provided with'
                         ' `target_log_prob_fn`.')

    if rng_key.ndim == 1:
        init_state = self._init_fn(init_params, rng_key)
    else:
        # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
        # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
        # wa_steps because those variables do not depend on traced args: init_params, rng_key.
        init_state = vmap(self._init_fn)(init_params, rng_key)
        sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
        self._sample_fn = sample_fn
    return init_state
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

    Args:
        query: array of representations
        key: array of representations
        value: array of representations
        mask: attention-mask, gates attention
        dropout: float: dropout rate
        mode: 'eval' or 'train': whether to use dropout
        rng: JAX PRNGKey: subkey for disposable use

    Returns:
        Self attention for q, k, v arrays.
    """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        mask = jax.lax.tie_in(dots, mask)
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - utils.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
def dot_product_attention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

    Args:
        query: array of representations
        key: array of representations
        value: array of representations
        mask: attention-mask, gates attention
        dropout: float: dropout rate - keep probability
        mode: 'eval' or 'train': whether to use dropout
        rng: JAX PRNGKey: subkey for disposable use

    Returns:
        Self attention for q, k, v arrays.
    """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        dots = np.where(mask, dots, -1e9)
    dots = stax.softmax(dots, axis=-1)
    if dropout is not None and mode == 'train':
        keep = random.bernoulli(rng, dropout, dots.shape)
        dots = np.where(keep, dots / dropout, 0)
    out = np.matmul(dots, value)
    return out
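# Usage sketch (hypothetical shapes; `np` here is the module's jax.numpy
# alias): single-head attention in eval mode without a mask.
from jax import random

q = k = v = random.normal(random.PRNGKey(0), (2, 7, 16))  # (batch, len, depth)
out = dot_product_attention(q, k, v, mask=None, dropout=None, mode='eval',
                            rng=None)
assert out.shape == (2, 7, 16)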
def sample(rng, memory, B, D, n_dim, pop_size):
    """Jittable Gaussian Sample Helper."""
    z = jax.random.normal(rng, (n_dim, pop_size))  # ~ N(0, I)
    y = B.dot(jnp.diag(D)).dot(z)                  # ~ N(0, C)
    y = jnp.swapaxes(y, 1, 0)
    x = memory["mean"] + memory["sigma"] * y       # ~ N(m, σ^2 C)
    return x
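# Usage sketch (hypothetical memory dict): with B = I and D = ones the
# covariance is the identity, so samples are mean + sigma * z.
import jax
import jax.numpy as jnp

n_dim, pop_size = 4, 6
memory = {"mean": jnp.zeros(n_dim), "sigma": 1.0}
x = sample(jax.random.PRNGKey(0), memory, jnp.eye(n_dim), jnp.ones(n_dim),
           n_dim, pop_size)
assert x.shape == (pop_size, n_dim)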
def model_2(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    with mask(mask=include_prior):
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2,
                                        data_dim]).to_event(3),
        )

    def transition_fn(carry, y):
        x_prev, y_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                x = numpyro.sample(
                    "x",
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    y = numpyro.sample("y",
                                       dist.Bernoulli(probs_y[x, y_prev, tones]),
                                       obs=y)
        return (x, y, t + 1), None

    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    y_init = jnp.zeros((num_sequences, data_dim), dtype=jnp.int32)
    scan(transition_fn, (x_init, y_init, 0), jnp.swapaxes(sequences, 0, 1))
def materialize_matrix(symmetric_matrix):
    """Returns a materialized symmetric matrix.

    Args:
        symmetric_matrix: the matrix represented by lower-triangular block
            slices.
    """
    block_rows = symmetric_matrix.block_rows
    block_size = block_rows[0].shape[-2]
    num_blocks = len(block_rows)

    # Slice the lower-triangular and diagonal blocks into blocks.
    blocks = [[
        block_row[Ellipsis, i * block_size:(i + 1) * block_size]
        for i in range(k + 1)
    ] for k, block_row in enumerate(block_rows)]

    # Generate the (off-diagonal) upper-triangular blocks.
    off_diags = [[] for _ in range(num_blocks - 1)]
    for k, block_row in enumerate(block_rows[1:]):
        for i in range(k + 1):
            off_diags[i].append(
                jnp.swapaxes(
                    a=block_row[Ellipsis, i * block_size:(i + 1) * block_size],
                    axis1=-1,
                    axis2=-2))

    return jnp.block(
        [row + row_t for row, row_t in zip(blocks[:-1], off_diags)] +
        [blocks[-1]])
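# Usage sketch (hypothetical container): `block_rows` holds the
# lower-triangular block rows of a 2x2-blocked symmetric matrix, here with
# block_size=1 for readability.
from collections import namedtuple

import jax.numpy as jnp

SymmetricMatrix = namedtuple("SymmetricMatrix", ["block_rows"])
sm = SymmetricMatrix(block_rows=[
    jnp.array([[1.0]]),       # row 0: diagonal block
    jnp.array([[2.0, 3.0]]),  # row 1: off-diagonal block, diagonal block
])
expected = jnp.array([[1.0, 2.0], [2.0, 3.0]])
assert jnp.array_equal(materialize_matrix(sm), expected)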
def model_1(capture_history, sex):
    N, T = capture_history.shape
    phi = numpyro.sample("phi", dist.Uniform(0.0, 1.0))  # survival probability
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z",
                                   dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                mu_y_t = rho * z
                numpyro.sample("y",
                               dist.Bernoulli(dist.util.clamp_probs(mu_y_t)),
                               obs=y)

        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(
        transition_fn,
        (first_capture_mask, z),
        jnp.swapaxes(capture_history[:, 1:], 0, 1),
    )
def inv_vec_transform(y):
    matrix = vec_to_tril_matrix(y, diagonal=-1)
    if constraint is constraints.corr_matrix:
        # fill the upper triangular part
        matrix = matrix + np.swapaxes(matrix, -2, -1) + np.identity(matrix.shape[-1])
    return transform.inv(matrix)
def build_t_op(core_tensor, direction, jitted=True):
    """
    Get the transfer operator for a TI-MPS, which acts on an input matrix

    Args:
        core_tensor: MPS core tensor of shape (bond_dim, bond_dim, in_dim)
                     which defines the output transfer operator
        direction:   Either 'left', 'right', or 'both', specifying the
                     direction in which output transfer operator propagates
                     its input. The last case gives a bidirectional t_op
                     which acts on a pair of density operators, described as
                     a matrix with extra batch index of dim 2
        jitted:      Whether we want the output transfer operator function
                     to be passed through Jax's `jit` function
    """
    assert direction in ['left', 'right', 'both']

    if direction == 'left':
        t_op = lambda mat: np.einsum('cai,ab,dbi->cd', core_tensor, mat,
                                     core_tensor)
    elif direction == 'right':
        t_op = lambda mat: np.einsum('aci,ab,bdi->cd', core_tensor, mat,
                                     core_tensor)
    elif direction == 'both':
        core_tensors = np.stack([core_tensor, np.swapaxes(core_tensor, 0, 1)])
        t_op = lambda mat: np.einsum('Baci,Bab,Bbdi->Bcd', core_tensors, mat,
                                     core_tensors)

    return jax.jit(t_op) if jitted else t_op
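# Usage sketch (hypothetical MPS core): a right-moving transfer operator
# applied to the identity matrix.
import jax.numpy as np

bond_dim, in_dim = 3, 2
core = np.ones((bond_dim, bond_dim, in_dim)) / bond_dim
t_op = build_t_op(core, direction='right', jitted=False)
assert t_op(np.eye(bond_dim)).shape == (bond_dim, bond_dim)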
def interp_manygrids(grids, xs, axis=0, return_wnext=True, trim=False):
    # this routine interpolates xs on many grids, defined along
    # the axis in an array grids. (so for axis=0 grids are
    # grids[:,i,j,k] for all i, j, k)
    assert np.all(np.diff(grids, axis=axis) > 0)
    '''
    if trim:
        xs = np.clip(xs[:, None, None],
                     grids.min(axis=axis, keepdims=True),
                     grids.max(axis=axis, keepdims=True))
    '''
    # this requires everything to be sorted
    mat = grids[..., None] < xs[(None, ) * grids.ndim + (slice(None), )]
    ng = grids.shape[axis]
    j = np.clip(np.sum(mat, axis=axis)[None, ...] - 1, 0, ng - 2)
    j = np.swapaxes(j, -1, axis).squeeze(axis=-1)
    grid_j = np.take_along_axis(grids, j, axis=axis)
    grid_jp = np.take_along_axis(grids, j + 1, axis=axis)
    xs_r = xs.reshape((1, ) * (axis - 1) + (xs.size, ) +
                      (1, ) * (grids.ndim - 1 - axis))
    wnext = (xs_r - grid_j) / (grid_jp - grid_j)
    return j, (wnext if return_wnext else 1 - wnext)
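# Usage sketch (axis=0, hypothetical grids): three strictly increasing grids
# of length 5, queried at two points; `j` holds the left bracketing index and
# `wnext` the weight on the right grid point.
import numpy as np

grids = np.linspace(0.0, 1.0, 5)[:, None] * np.array([1.0, 2.0, 3.0])
xs = np.array([0.4, 1.5])
j, wnext = interp_manygrids(grids, xs, axis=0)
assert j.shape == wnext.shape == (xs.size, grids.shape[1])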
def init(self, rng_key, num_warmup, init_params=None, model_args=(),
         model_kwargs={}):
    # non-vectorized
    if rng_key.ndim == 1:
        rng_key, rng_key_init_model = random.split(rng_key)
    # vectorized
    else:
        rng_key, rng_key_init_model = np.swapaxes(
            vmap(random.split)(rng_key), 0, 1)
        # we need only a single key for initializing PE / constraints fn
        rng_key_init_model = rng_key_init_model[0]
    if not self._init_fn:
        self._init_state(rng_key_init_model, model_args, model_kwargs)
    if self._potential_fn and init_params is None:
        raise ValueError('Valid value of `init_params` must be provided with'
                         ' `potential_fn`.')

    # Find valid initial params
    if self._model and not init_params:
        init_params, is_valid = find_valid_initial_params(
            rng_key, self._model,
            init_strategy=self._init_strategy,
            param_as_improper=True,
            model_args=model_args,
            model_kwargs=model_kwargs)
        if not_jax_tracer(is_valid):
            if device_get(~np.all(is_valid)):
                raise RuntimeError("Cannot find valid initial parameters. "
                                   "Please check your model again.")

    hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
        init_params,
        num_warmup=num_warmup,
        step_size=self._step_size,
        adapt_step_size=self._adapt_step_size,
        adapt_mass_matrix=self._adapt_mass_matrix,
        dense_mass=self._dense_mass,
        target_accept_prob=self._target_accept_prob,
        trajectory_length=self._trajectory_length,
        max_tree_depth=self._max_tree_depth,
        rng_key=rng_key,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )
    if rng_key.ndim == 1:
        init_state = hmc_init_fn(init_params, rng_key)
    else:
        # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
        # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
        # wa_steps because those variables do not depend on traced args: init_params, rng_key.
        init_state = vmap(hmc_init_fn)(init_params, rng_key)
        sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
        self._sample_fn = sample_fn
    return init_state
def test_cholesky_update(chol_batch_shape, vec_batch_shape, dim, coef):
    A = random.normal(random.PRNGKey(0), chol_batch_shape + (dim, dim))
    A = A @ jnp.swapaxes(A, -2, -1) + jnp.eye(dim)
    x = random.normal(random.PRNGKey(0), vec_batch_shape + (dim,)) * 0.1
    xxt = x[..., None] @ x[..., None, :]
    expected = jnp.linalg.cholesky(A + coef * xxt)
    actual = cholesky_update(jnp.linalg.cholesky(A), x, coef)
    assert_allclose(actual, expected, atol=1e-4, rtol=1e-4)
def __call__(self, x):
    # check that the matrix is symmetric
    symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1),
                        axis=-1)
    # check that the smallest eigenvalue is positive
    positive = jnp.linalg.eigh(x)[0][..., 0] > 0
    # check that the diagonal is equal to 1
    unit_variance = jnp.all(
        jnp.abs(jnp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1)
    return symmetric & positive & unit_variance