Example #1
 def covariance_matrix(self):
     # TODO: find a better solution to create a diagonal matrix
     new_diag = self.cov_diag[..., np.newaxis] * np.identity(
         self.loc.shape[-1])
     covariance_matrix = new_diag + np.matmul(
         self.cov_factor, np.swapaxes(self.cov_factor, -1, -2))
     return covariance_matrix
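The TODO above asks for a cleaner way to build a batched diagonal matrix. A minimal standalone sketch (plain NumPy, hypothetical sizes) comparing the broadcast-against-identity trick with an in-place integer-indexing alternative:

import numpy as np

# hypothetical batch of diagonals: batch shape (3,), event size 4
cov_diag = np.random.rand(3, 4)

# construction used above: broadcast the diagonal against an identity matrix
diag_a = cov_diag[..., np.newaxis] * np.identity(4)

# one possible alternative: write the diagonal in place with integer indexing
diag_b = np.zeros((3, 4, 4))
idx = np.arange(4)
diag_b[..., idx, idx] = cov_diag

assert np.allclose(diag_a, diag_b)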
Example #2
 def inv_vec_transform(y):
     matrix = vec_to_tril_matrix(y)
     if constraint is constraints.positive_definite:
         # fill the upper triangular part
         matrix = matrix + np.swapaxes(matrix, -2, -1) - np.diag(
             np.diag(matrix))
     return transform.inv(matrix)
Example #3
def model_1(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    with mask(mask=include_prior):
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                x = numpyro.sample(
                    "x",
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                with numpyro.plate("tones", data_dim, dim=-1):
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                   obs=y)
        return (x, t + 1), None

    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # NB swapaxes: we move time dimension of `sequences` to the front to scan over it
    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))
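As the NB comment says, `scan` consumes the leading axis of its input, so the time axis has to be moved to the front. A minimal shape-only sketch (hypothetical sizes):

import jax.numpy as jnp

# hypothetical sizes: 8 sequences, 20 time steps, 51 tones
sequences = jnp.zeros((8, 20, 51))
time_major = jnp.swapaxes(sequences, 0, 1)   # scan iterates over axis 0
assert time_major.shape == (20, 8, 51)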
Example #4
 def final_fn(state, regularize=False):
     """
     :param state: Current state of the scheme.
     :param bool regularize: Whether to adjust diagonal for numerical stability.
     :return: a triple of estimated covariance, the square root of precision, and
         the inverse of that square root.
     """
     mean, m2, n = state
     # XXX it is not necessary to check for the case n=1
     cov = m2 / (n - 1)
     if regularize:
         # Regularization from Stan
         scaled_cov = (n / (n + 5)) * cov
         shrinkage = 1e-3 * (5 / (n + 5))
         if diagonal:
             cov = scaled_cov + shrinkage
         else:
             cov = scaled_cov + shrinkage * jnp.identity(mean.shape[0])
     if jnp.ndim(cov) == 2:
         # copy the implementation of distributions.util.cholesky_of_inverse here
         tril_inv = jnp.swapaxes(
             jnp.linalg.cholesky(cov[..., ::-1, ::-1])[..., ::-1, ::-1], -2,
             -1)
         identity = jnp.identity(cov.shape[-1])
         cov_inv_sqrt = solve_triangular(tril_inv, identity, lower=True)
     else:
         tril_inv = jnp.sqrt(cov)
         cov_inv_sqrt = jnp.reciprocal(tril_inv)
     return cov, cov_inv_sqrt, tril_inv
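The flipped-Cholesky block above (the `jnp.ndim(cov) == 2` branch) reproduces `cholesky_of_inverse`: `cov_inv_sqrt` ends up being the lower Cholesky factor of the precision matrix. A minimal standalone check, with an illustrative 2x2 covariance:

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

cov = jnp.array([[2.0, 0.3], [0.3, 1.0]])
# flip, factor, flip back, transpose: a lower-triangular matrix whose inverse
# is the lower Cholesky factor of inv(cov)
tril_inv = jnp.swapaxes(
    jnp.linalg.cholesky(cov[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1)
cov_inv_sqrt = solve_triangular(tril_inv, jnp.identity(2), lower=True)
assert jnp.allclose(cov_inv_sqrt @ cov_inv_sqrt.T, jnp.linalg.inv(cov), atol=1e-5)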
Example #5
File: sa.py Project: gully/numpyro
    def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}):
        # non-vectorized
        if rng_key.ndim == 1:
            rng_key, rng_key_init_model = random.split(rng_key)
        # vectorized
        else:
            rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
            # we need only a single key for initializing PE / constraints fn
            rng_key_init_model = rng_key_init_model[0]
        init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
        if self._potential_fn and init_params is None:
            raise ValueError('Valid value of `init_params` must be provided with'
                             ' `potential_fn`.')

        # NB: init args is different from HMC
        sa_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
            init_params,
            num_warmup=num_warmup,
            adapt_state_size=self._adapt_state_size,
            dense_mass=self._dense_mass,
            rng_key=rng_key,
            model_args=model_args,
            model_kwargs=model_kwargs,
        )
        if rng_key.ndim == 1:
            init_state = sa_init_fn(init_params, rng_key)
        else:
            init_state = vmap(sa_init_fn)(init_params, rng_key)
            sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
            self._sample_fn = sample_fn
        return init_state
Example #6
def unroll_reinject(initial_states, initial_token, sequence_length,
                    apply_embedding, apply_rnn, apply_readout):
  """Unrolls an RNN, reinjecting the output back into the RNN."""

  def _step(state, _):

    # Unpack loop state.
    tokens, rnn_state = state

    # Apply embedding, RNN, and readout.
    rnn_inputs = apply_embedding(tokens)
    rnn_state = apply_rnn(rnn_inputs, rnn_state)
    logits = apply_readout(rnn_state)

    # Pack new loop state
    next_state = (jnp.argmax(logits, axis=-1), rnn_state)

    return next_state, logits

  # Format scan arguments.
  batch_size = initial_states.shape[0]
  batch_inputs = initial_token * jnp.ones(batch_size).astype(jnp.int32)
  dummy_inputs = jnp.zeros((sequence_length, 1))

  # Unroll loop via scan.
  _, outputs = jax.lax.scan(_step, (batch_inputs, initial_states), dummy_inputs)

  return jnp.swapaxes(outputs, 0, 1)
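A minimal usage sketch (the embedding, RNN, and readout below are hypothetical stand-ins, chosen only to make the shapes concrete); the final swapaxes turns the time-major scan output into a batch-major array of logits:

import jax
import jax.numpy as jnp

vocab_size, hidden_size, batch_size, seq_len = 10, 8, 3, 5
emb = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, hidden_size))

apply_embedding = lambda tokens: emb[tokens]                 # table lookup
apply_rnn = lambda inputs, state: jnp.tanh(inputs + state)   # stateless toy "RNN"
apply_readout = lambda state: state @ emb.T                  # tied linear readout

initial_states = jnp.zeros((batch_size, hidden_size))
logits = unroll_reinject(initial_states, 0, seq_len,
                         apply_embedding, apply_rnn, apply_readout)
assert logits.shape == (batch_size, seq_len, vocab_size)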
Example #7
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        return np.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, constraints._Real):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=np.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return np.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return np.matmul(x, np.swapaxes(x, -2, -1))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #8
def model_4(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).expand_by(
                [hidden_dim]).to_event(2))
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3))

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
                x = numpyro.sample(
                    "x", dist.Categorical(Vindex(probs_x)[w, x_prev]))
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y[w, x, tones]),
                                   obs=y)
        return (w, x, t + 1), None

    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))
Example #9
 def __call__(self, x):
     # check for symmetric
     symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1),
                         axis=-1)
     # check for the smallest eigenvalue is positive
     positive = jnp.linalg.eigh(x)[0][..., 0] > 0
     return symmetric & positive
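The same two checks, written as a free function for a quick standalone sanity check (illustrative 2x2 inputs; this mirrors, but is not, the constraint class above):

import jax.numpy as jnp

def is_positive_definite(x):
    # symmetric and smallest eigenvalue positive, as in the __call__ above
    symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1)
    positive = jnp.linalg.eigh(x)[0][..., 0] > 0
    return symmetric & positive

assert is_positive_definite(jnp.array([[2.0, 0.5], [0.5, 1.0]]))
assert not is_positive_definite(jnp.array([[1.0, 2.0], [2.0, 1.0]]))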
Example #10
    def _construct_derivatives(self, x_mp):
        """Builds derivatives of the network outputs wrt model parameters

        The network outputs from the simulations generated with model parameter
        values above and below the fiducial are subtracted from each other and
        divided by the perturbation size in each model parameter value. The
        axes are swapped such that the derivatives with respect to parameters
        are in the last axis.

        .. math::
            \\frac{\\partial{\\bf x}^i}{\\partial\\theta_\\alpha} =
            \\frac{{\\bf x}^i_{\\alpha^+}-{\\bf x}^i_{\\alpha^-}}{
            \\delta\\theta_\\alpha}

        Parameters
        ----------
        x_mp : float(n_d, 2, n_params, n_summaries)
            The outputs of the network of simulations made at perturbed
            parameter values to construct the derivative of the network outputs
            with respect to the model parameters numerically

        Returns
        -------
        float(n_d, n_summaries, n_params):
            The numerical derivatives of the network outputs with respect to the
            model parameters
        """
        return np.swapaxes(x_mp[:, 1] - x_mp[:, 0], 1, 2) / self.δθ
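A minimal sketch (hypothetical sizes, plain NumPy) of the shape bookkeeping in the return statement above: the outputs above (index 1) and below (index 0) the fiducial are differenced and the parameter axis is swapped to the end before dividing by the perturbation size:

import numpy as np

n_d, n_params, n_summaries = 5, 2, 3
delta_theta = np.array([0.1, 0.2])   # stands in for self.δθ

# network outputs above (index 1 of axis 1) and below (index 0) the fiducial
x_mp = np.random.rand(n_d, 2, n_params, n_summaries)

derivatives = np.swapaxes(x_mp[:, 1] - x_mp[:, 0], 1, 2) / delta_theta
assert derivatives.shape == (n_d, n_summaries, n_params)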
Example #11
def model_3(capture_history, sex):
    N, T = capture_history.shape
    phi_mean = numpyro.sample("phi_mean", dist.Uniform(0.0, 1.0))  # mean survival probability
    phi_logit_mean = logit(phi_mean)
    # controls temporal variability of survival probability
    phi_sigma = numpyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
            phi_logit_t = numpyro.sample("phi_logit", dist.Normal(phi_logit_mean, phi_sigma))
        phi_t = expit(phi_logit_t)
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                mu_y_t = rho * z
                numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y)

        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(transition_fn, (first_capture_mask, z), jnp.swapaxes(capture_history[:, 1:], 0, 1))
Example #12
def stable_svd_jvp(primals, tangents):
    """Copied from the JAX source code and slightly tweaked for stability"""
    # Deformation parameter which yields regular SVD JVP rule when set to 0
    eps = 1e-10
    A, = primals
    dA, = tangents
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False, compute_uv=True)

    _T = lambda x: jnp.swapaxes(x, -1, -2)
    _H = lambda x: jnp.conj(_T(x))
    k = s.shape[-1]
    Ut, V = _H(U), _H(Vt)
    s_dim = s[..., None, :]
    dS = jnp.matmul(jnp.matmul(Ut, dA), V)
    ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

    # Deformation by eps avoids getting NaN's when SV's are degenerate
    f = jnp.square(s_dim) - jnp.square(_T(s_dim)) + jnp.eye(k)
    f = f + eps / f  # eps controls stability
    F = 1 / f - jnp.eye(k) / (1 + eps)

    dSS = s_dim * dS
    SdS = _T(s_dim) * dS
    dU = jnp.matmul(U, F * (dSS + _T(dSS)))
    dV = jnp.matmul(V, F * (SdS + _T(SdS)))

    m, n = A.shape[-2], A.shape[-1]
    if m > n:
        dU = dU + jnp.matmul(
            jnp.eye(m) - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim
    if n > m:
        dV = dV + jnp.matmul(
            jnp.eye(n) - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim
    return (U, s, Vt), (dU, ds, _T(dV))
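On its own, `stable_svd_jvp` is just a function; to take effect it has to be registered as the JVP rule of a primal SVD. A minimal sketch of how that wiring might look (assuming the definition above is in scope; the `stable_svd` wrapper is hypothetical):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def stable_svd(A):
    # ordinary reduced SVD for the primal computation
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
    return U, s, Vt

stable_svd.defjvp(stable_svd_jvp)

# forward-mode differentiation now routes through the regularized rule,
# even when singular values are (nearly) degenerate
A = 0.5 * jnp.ones((3, 3))
(U, s, Vt), (dU, ds, dVt) = jax.jvp(stable_svd, (A,), (jnp.eye(3),))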
Example #13
def infer(xs, a_min=None, a_max=None):
    @partial(
        jax.pmap,
        in_axes=(None, None, 0, 0),
        static_broadcasted_argnums=(0, 1),
        axis_name="device",
    )
    def _infer(model, n_refeed, params, xs):
        def body_fun(inputs, i):
            x = inputs
            y_hat = model.apply(params, x)
            y_hat = jnp.clip(y_hat, a_min, a_max)
            x = deepx.optimise.refeed(x,
                                      y_hat)  #  add the new pred to the inputs
            return x, y_hat

        _, ys_hat = jax.lax.scan(body_fun, xs, xs=jnp.arange(n_refeed))
        ys_hat = jnp.swapaxes(jnp.squeeze(ys_hat), 0, 1)
        return ys_hat

    model, hparams, params = load_model("p3aetobr/9336")
    #     model, hparams, params = load_model("p3aetobr/9996")
    #     model, hparams, params = load_model("p3aetobr/10326")
    #     model, hparams, params = load_model("p3aetobr/10976")
    start = time.time()
    if (a_min is None) and (a_max is None):
        ys_hat = deepx.optimise.infer(model, n_refeed, params, xs)
    else:
        ys_hat = _infer(model, n_refeed, params, xs)
    print("Solved forward propagation to {}ms in: {}s".format(
        n_refeed * 5,
        time.time() - start))
    if ys_hat.shape[0] == 1:
        ys_hat = jnp.swapaxes(ys_hat, -3, -4)[None]
    return ys_hat
Example #14
    def init(self,
             rng_key,
             num_warmup,
             init_params=None,
             model_args=(),
             model_kwargs={}):
        # non-vectorized
        if rng_key.ndim == 1:
            rng_key, rng_key_init_model = random.split(rng_key)
        # vectorized
        else:
            rng_key, rng_key_init_model = jnp.swapaxes(
                vmap(random.split)(rng_key), 0, 1)
        init_params = self._init_state(rng_key_init_model, model_args,
                                       model_kwargs, init_params)
        if self._potential_fn and init_params is None:
            raise ValueError(
                'Valid value of `init_params` must be provided with'
                ' `potential_fn`.')

        # change dense_mass to a structural form
        dense_mass = self._dense_mass
        inverse_mass_matrix = self._inverse_mass_matrix
        if self._model is not None:
            z = init_params[0] if isinstance(init_params,
                                             ParamInfo) else init_params
            if isinstance(dense_mass, bool):
                # XXX: by default, the order variables are sorted by their names,
                # this is to be compatible with older numpyro versions
                # and to match autoguide scale parameter and jax flatten utils
                dense_mass = [tuple(sorted(z))] if dense_mass else []
            assert isinstance(dense_mass, list)

        hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
            init_params,
            num_warmup=num_warmup,
            step_size=self._step_size,
            inverse_mass_matrix=inverse_mass_matrix,
            adapt_step_size=self._adapt_step_size,
            adapt_mass_matrix=self._adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=self._target_accept_prob,
            trajectory_length=self._trajectory_length,
            max_tree_depth=self._max_tree_depth,
            find_heuristic_step_size=self._find_heuristic_step_size,
            forward_mode_differentiation=self._forward_mode_differentiation,
            model_args=model_args,
            model_kwargs=model_kwargs,
            rng_key=rng_key,
        )
        if rng_key.ndim == 1:
            init_state = hmc_init_fn(init_params, rng_key)
        else:
            # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
            # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
            # wa_steps because those variables do not depend on traced args: init_params, rng_key.
            init_state = vmap(hmc_init_fn)(init_params, rng_key)
            sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
            self._sample_fn = sample_fn
        return init_state
Example #15
    def rkhsel_gram(self, X, Y=None, logsp=False):
        """
        X - axis 0 contains observations (of sample sets), axis 1 is input dimension, axis 2 are different points per observation (the samples of a sample set)
        """

        assert (not logsp)
        if Y is not None:
            assert (len(Y.shape) == 2)
        if len(X.shape) == 2:
            return self.gram(X, Y)
        assert (len(X.shape) == 3)

        X_resh = np.concatenate(np.swapaxes(
            X, 0, 2), axis=1).T  #np.swapaxes(X, 1,2).reshape(-1, X.shape[1])
        if Y is None:
            # compute the full gram matrix
            G = self.gram(X_resh)
            # sum up the blockmatrices of shape (X.shape[2], X.shape[2]) that make up G
            G = np.mean(
                np.split(np.mean(np.split(G, X.shape[2], 1), 0), X.shape[2]),
                0)

            # return the matrix of RKHS inner products of the mean embedding objects
            return G
        else:
            return np.mean(np.split(self.gram(X_resh, Y), X.shape[2]), 0)
Example #16
def model_6(sequences, lengths, args, include_prior=False):
    num_sequences, max_length, data_dim = sequences.shape

    with mask(mask=include_prior):
        # Explicitly parameterize the full tensor of transition probabilities, which
        # has hidden_dim cubed entries.
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).expand(
                [args.hidden_dim, args.hidden_dim]).to_event(2),
        )

        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        x_prev, x_curr, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
                x_prev, x_curr = x_curr, numpyro.sample(
                    "x", dist.Categorical(probs_x_t))
                with numpyro.plate("tones", data_dim, dim=-1):
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    numpyro.sample("y", dist.Bernoulli(probs_y_t), obs=y)
        return (x_prev, x_curr, t + 1), None

    x_prev = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_curr = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (x_prev, x_curr, 0),
         jnp.swapaxes(sequences, 0, 1),
         history=2)
Example #17
def Wishart(key, dof, scale, shape=None):
    if scale is None:
        scale = jnp.eye(shape)
    batch_shape = ()
    if jnp.ndim(scale) > 2:
        batch_shape = scale.shape[:-2]
    p = scale.shape[-1]

    if dof is None:
        dof = p
    if jnp.ndim(dof) > 0:
        raise ValueError("only scalar dof implemented")
    if int(dof) != dof:
        raise ValueError(
            "dof should be integer-like (i.e. int(dof) == dof should return true)"
        )
    else:
        dof = int(dof)

    if shape is not None:
        if batch_shape != ():
            assert batch_shape == shape, "Disagreement in batch shape between scale and shape"
        else:
            batch_shape = shape

    mn = jnp.zeros(shape=batch_shape + (p, ))
    mvn_shape = (dof, ) + batch_shape

    mvn = random.multivariate_normal(key, mean=mn, cov=scale, shape=mvn_shape)
    if jnp.ndim(mvn) > 2:
        mvn = jnp.swapaxes(mvn, 0, -2)

    S = jnp.einsum('...ji,...jk', mvn, mvn)

    return S
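A minimal usage sketch (illustrative arguments) checking that the helper returns a batch of symmetric positive semi-definite matrices:

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
# 5 degrees of freedom, 3x3 identity scale, batch of 4 draws
S = Wishart(key, dof=5, scale=jnp.eye(3), shape=(4,))
assert S.shape == (4, 3, 3)
assert jnp.allclose(S, jnp.swapaxes(S, -2, -1))   # symmetric
assert jnp.all(jnp.linalg.eigvalsh(S) > -1e-5)    # PSD up to round-off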
Example #18
    def init(self,
             rng_key,
             num_warmup,
             init_params=None,
             model_args=(),
             model_kwargs={}):
        # non-vectorized
        if rng_key.ndim == 1:
            rng_key, rng_key_init_model = random.split(rng_key)
        # vectorized
        else:
            rng_key, rng_key_init_model = jnp.swapaxes(
                vmap(random.split)(rng_key), 0, 1)
        init_params = self._init_state(rng_key_init_model, model_args,
                                       model_kwargs, init_params)
        if self._potential_fn and init_params is None:
            raise ValueError(
                'Valid value of `init_params` must be provided with'
                ' `target_log_prob_fn`.')

        if rng_key.ndim == 1:
            init_state = self._init_fn(init_params, rng_key)
        else:
            # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
            # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
            # wa_steps because those variables do not depend on traced args: init_params, rng_key.
            init_state = vmap(self._init_fn)(init_params, rng_key)
            sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
            self._sample_fn = sample_fn
        return init_state
Example #19
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        mask = jax.lax.tie_in(dots, mask)
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - utils.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
Example #20
def dot_product_attention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate - keep probability
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    dots = np.where(mask, dots, -1e9)
  dots = stax.softmax(dots, axis=-1)
  if dropout is not None and mode == 'train':
    keep = random.bernoulli(rng, dropout, dots.shape)
    dots = np.where(keep, dots / dropout, 0)
  out = np.matmul(dots, value)
  return out
Example #21
def sample(rng, memory, B, D, n_dim, pop_size):
    """ Jittable Gaussian Sample Helper. """
    z = jax.random.normal(rng, (n_dim, pop_size))  # ~ N(0, I)
    y = B.dot(jnp.diag(D)).dot(z)  # ~ N(0, C)
    y = jnp.swapaxes(y, 1, 0)
    x = memory["mean"] + memory["sigma"] * y  # ~ N(m, σ^2 C)
    return x
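A minimal usage sketch (illustrative values for the evolution-strategy state; the `memory` dict with "mean" and "sigma" entries matches what the helper reads):

import jax
import jax.numpy as jnp

n_dim, pop_size = 4, 6
memory = {"mean": jnp.zeros(n_dim), "sigma": 0.5}
B = jnp.eye(n_dim)     # eigenvectors of the covariance C
D = jnp.ones(n_dim)    # square roots of the eigenvalues of C
x = sample(jax.random.PRNGKey(0), memory, B, D, n_dim, pop_size)
assert x.shape == (pop_size, n_dim)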
Example #22
def model_2(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    with mask(mask=include_prior):
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).to_event(1))

        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2,
                                        data_dim]).to_event(3),
        )

    def transition_fn(carry, y):
        x_prev, y_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                x = numpyro.sample(
                    "x",
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    y = numpyro.sample("y",
                                       dist.Bernoulli(probs_y[x, y_prev,
                                                              tones]),
                                       obs=y)
        return (x, y, t + 1), None

    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    y_init = jnp.zeros((num_sequences, data_dim), dtype=jnp.int32)
    scan(transition_fn, (x_init, y_init, 0), jnp.swapaxes(sequences, 0, 1))
Example #23
def materialize_matrix(symmetric_matrix):
    """Returns a materialized symmetric matrix.

  Args:
    symmetric_matrix: the matrix represented by lower-triangular block slices.
  """
    block_rows = symmetric_matrix.block_rows
    block_size = block_rows[0].shape[-2]
    num_blocks = len(block_rows)

    # Slice the lower-triangular and diagonal blocks into blocks.
    blocks = [[
        block_row[Ellipsis, i * block_size:(i + 1) * block_size]
        for i in range(k + 1)
    ] for k, block_row in enumerate(block_rows)]

    # Generate the (off-diagonal) upper-triangular blocks.
    off_diags = [[] for _ in range(num_blocks - 1)]
    for k, block_row in enumerate(block_rows[1:]):
        for i in range(k + 1):
            off_diags[i].append(
                jnp.swapaxes(a=block_row[Ellipsis,
                                         i * block_size:(i + 1) * block_size],
                             axis1=-1,
                             axis2=-2))

    return jnp.block(
        [row + row_t
         for row, row_t in zip(blocks[:-1], off_diags)] + [blocks[-1]])
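A minimal usage sketch materializing a 2x2-block symmetric matrix from its lower-triangular block rows (the SymmetricMatrix namedtuple below is a hypothetical stand-in for whatever block-row container the real code uses):

import collections
import jax.numpy as jnp

SymmetricMatrix = collections.namedtuple("SymmetricMatrix", ["block_rows"])

a = jnp.array([[2.0, 0.0], [0.0, 2.0]])   # diagonal block (0, 0)
b = jnp.array([[1.0, 0.5], [0.5, 1.0]])   # off-diagonal block (1, 0)
c = jnp.array([[3.0, 0.0], [0.0, 3.0]])   # diagonal block (1, 1)

# block row k holds the k-th lower-triangular row, blocks concatenated along columns
sym = SymmetricMatrix(block_rows=[a, jnp.concatenate([b, c], axis=-1)])
full = materialize_matrix(sym)
assert full.shape == (4, 4)
assert jnp.allclose(full, full.T)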
Example #24
def model_1(capture_history, sex):
    N, T = capture_history.shape
    phi = numpyro.sample("phi", dist.Uniform(0.0, 1.0))  # survival probability
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                mu_y_t = rho * z
                numpyro.sample(
                    "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                )

        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(
        transition_fn,
        (first_capture_mask, z),
        jnp.swapaxes(capture_history[:, 1:], 0, 1),
    )
Example #25
 def inv_vec_transform(y):
     matrix = vec_to_tril_matrix(y, diagonal=-1)
     if constraint is constraints.corr_matrix:
         # fill the upper triangular part
         matrix = matrix + np.swapaxes(
             matrix, -2, -1) + np.identity(matrix.shape[-1])
     return transform.inv(matrix)
Example #26
def build_t_op(core_tensor, direction, jitted=True):
    """
    Get the transfer operator for a TI-MPS, which acts on an input matrix

    Args:
        core_tensor: MPS core tensor of shape (bond_dim, bond_dim, in_dim)
                     which defines the output transfer operator
        direction:   Either 'left', 'right', or 'both', specifying the 
                     direction in which output transfer operator propagates
                     its input. The last case gives a bidirectional t_op
                     which acts on a pair of density operators, described 
                     as a matrix with extra batch index of dim 2
        jitted:      Whether we want the output transfer operator function 
                     to be passed through Jax's `jit` function
    """
    assert direction in ['left', 'right', 'both']

    if direction == 'left':
        t_op = lambda mat: np.einsum('cai,ab,dbi->cd', 
                                     core_tensor, mat, core_tensor)
    elif direction == 'right':
        t_op = lambda mat: np.einsum('aci,ab,bdi->cd', 
                                     core_tensor, mat, core_tensor)
    elif direction == 'both':
        core_tensors = np.stack([core_tensor, 
                                 np.swapaxes(core_tensor, 0, 1)])
        t_op = lambda mat: np.einsum('Baci,Bab,Bbdi->Bcd', 
                                     core_tensors, mat, core_tensors)

    return jax.jit(t_op) if jitted else t_op
Example #27
def interp_manygrids(grids, xs, axis=0, return_wnext=True, trim=False):
    # This routine interpolates xs on many grids, defined along `axis` of the
    # array grids (so for axis=0 the individual grids are grids[:, i, j, k]
    # for all i, j, k).

    assert np.all(np.diff(grids, axis=axis) > 0)
    '''
    if trim: xs = np.clip(xs[:,None,None],
                          grids.min(axis=axis,keepdims=True),
                          grids.max(axis=axis,keepdims=True))
    '''

    # this requires everything to be sorted
    mat = grids[..., None] < xs[(None, ) * grids.ndim + (slice(None), )]
    ng = grids.shape[axis]
    j = np.clip(np.sum(mat, axis=axis)[None, ...] - 1, 0, ng - 2)
    j = np.swapaxes(j, -1, axis).squeeze(axis=-1)
    grid_j = np.take_along_axis(grids, j, axis=axis)
    grid_jp = np.take_along_axis(grids, j + 1, axis=axis)

    xs_r = xs.reshape((1, ) * (axis - 1) + (xs.size, ) + (1, ) *
                      (grids.ndim - 1 - axis))

    wnext = (xs_r - grid_j) / (grid_jp - grid_j)
    return j, (wnext if return_wnext else 1 - wnext)
Example #28
File: mcmc.py Project: cnheider/numpyro
    def init(self,
             rng_key,
             num_warmup,
             init_params=None,
             model_args=(),
             model_kwargs={}):
        # non-vectorized
        if rng_key.ndim == 1:
            rng_key, rng_key_init_model = random.split(rng_key)
        # vectorized
        else:
            rng_key, rng_key_init_model = np.swapaxes(
                vmap(random.split)(rng_key), 0, 1)
            # we need only a single key for initializing PE / constraints fn
            rng_key_init_model = rng_key_init_model[0]
        if not self._init_fn:
            self._init_state(rng_key_init_model, model_args, model_kwargs)
        if self._potential_fn and init_params is None:
            raise ValueError(
                'Valid value of `init_params` must be provided with'
                ' `potential_fn`.')
        # Find valid initial params
        if self._model and not init_params:
            init_params, is_valid = find_valid_initial_params(
                rng_key,
                self._model,
                init_strategy=self._init_strategy,
                param_as_improper=True,
                model_args=model_args,
                model_kwargs=model_kwargs)
            if not_jax_tracer(is_valid):
                if device_get(~np.all(is_valid)):
                    raise RuntimeError("Cannot find valid initial parameters. "
                                       "Please check your model again.")

        hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
            init_params,
            num_warmup=num_warmup,
            step_size=self._step_size,
            adapt_step_size=self._adapt_step_size,
            adapt_mass_matrix=self._adapt_mass_matrix,
            dense_mass=self._dense_mass,
            target_accept_prob=self._target_accept_prob,
            trajectory_length=self._trajectory_length,
            max_tree_depth=self._max_tree_depth,
            rng_key=rng_key,
            model_args=model_args,
            model_kwargs=model_kwargs,
        )
        if rng_key.ndim == 1:
            init_state = hmc_init_fn(init_params, rng_key)
        else:
            # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
            # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
            # wa_steps because those variables do not depend on traced args: init_params, rng_key.
            init_state = vmap(hmc_init_fn)(init_params, rng_key)
            sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
            self._sample_fn = sample_fn
        return init_state
Example #29
def test_cholesky_update(chol_batch_shape, vec_batch_shape, dim, coef):
    A = random.normal(random.PRNGKey(0), chol_batch_shape + (dim, dim))
    A = A @ jnp.swapaxes(A, -2, -1) + jnp.eye(dim)
    x = random.normal(random.PRNGKey(0), vec_batch_shape + (dim,)) * 0.1
    xxt = x[..., None] @ x[..., None, :]
    expected = jnp.linalg.cholesky(A + coef * xxt)
    actual = cholesky_update(jnp.linalg.cholesky(A), x, coef)
    assert_allclose(actual, expected, atol=1e-4, rtol=1e-4)
Example #30
 def __call__(self, x):
     # check for symmetric
     symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1)
     # check for the smallest eigenvalue is positive
     positive = jnp.linalg.eigh(x)[0][..., 0] > 0
     # check for diagonal equal to 1
     unit_variance = jnp.all(jnp.abs(jnp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1)
     return symmetric & positive & unit_variance