Example #1
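Note: this older snippet is assumed to follow the early NumPyro convention import jax.numpy as np; a near-identical jnp version of this model appears as Example #6 below.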
def model(X, Y, hypers):
    S, P, N = hypers['expected_sparsity'], X.shape[1], X.shape[0]

    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers['alpha3']))
    phi = sigma * (S / np.sqrt(N)) / (P - S)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

    msq = numpyro.sample("msq",
                         dist.InverseGamma(hypers['alpha1'], hypers['beta1']))
    xisq = numpyro.sample("xisq",
                          dist.InverseGamma(hypers['alpha2'], hypers['beta2']))

    eta2 = np.square(eta1) * np.sqrt(xisq) / msq

    lam = numpyro.sample("lambda", dist.HalfCauchy(np.ones(P)))
    kappa = np.sqrt(msq) * lam / np.sqrt(msq + np.square(eta1 * lam))

    # sample observation noise
    var_obs = numpyro.sample(
        "var_obs", dist.InverseGamma(hypers['alpha_obs'], hypers['beta_obs']))

    # compute kernel
    kX = kappa * X
    k = kernel(kX, kX, eta1, eta2, hypers['c']) + var_obs * np.eye(N)
    assert k.shape == (N, N)

    # sample Y according to the standard gaussian process formula
    numpyro.sample("Y",
                   dist.MultivariateNormal(loc=np.zeros(X.shape[0]),
                                           covariance_matrix=k),
                   obs=Y)
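The kernel helper called above is not shown. For reference, a minimal sketch of the quadratic kernel that NumPyro's sparse_regression example pairs with this model; the exact form, including the jitter default, should be treated as an assumption:

import jax.numpy as jnp

def kernel(X, Z, eta1, eta2, c, jitter=1.0e-4):
    # quadratic kernel capturing linear effects and pairwise interactions
    eta1sq, eta2sq = jnp.square(eta1), jnp.square(eta2)
    k1 = 0.5 * eta2sq * jnp.square(1.0 + jnp.dot(X, Z.T))
    k2 = -0.5 * eta2sq * jnp.dot(jnp.square(X), jnp.square(Z).T)
    k3 = (eta1sq - eta2sq) * jnp.dot(X, Z.T)
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        k4 = k4 + jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4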
Example #2
File: utils.py Project: dimarkov/pybefit
def fulldyn_mixture_model(beliefs, y, mask):
    M, T, N, _ = beliefs[0].shape

    c0 = beliefs[-1]

    tau = .5
    with npyro.plate('N', N):
        weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))

        assert weights.shape == (N, M)

        mu = npyro.sample('mu', dist.Normal(5., 5.))

        lam12 = npyro.sample('lam12',
                             dist.HalfCauchy(1.).expand([2]).to_event(1))
        lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

        _lam34 = jnp.expand_dims(lam34, -1)
        lam0 = npyro.deterministic(
            'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

        eta = npyro.sample('eta', dist.Beta(1, 10))

        scale = npyro.sample('scale', dist.HalfNormal(1.))
        theta = npyro.sample('theta', dist.HalfCauchy(5.))
        rho = jnp.exp(-theta)
        sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale

    x0 = jnp.zeros(N)

    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams', lam_prev +
            nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(
            mask[t][..., None])
        with npyro.plate('subjects', N):
            y = npyro.sample(
                'y', dist.MixtureSameFamily(mixing_dist, component_dist))
            noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
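Note: the bare scan used here (and in Examples #11, #19, and #22) is assumed to be numpyro.contrib.control_flow.scan, which, unlike jax.lax.scan, supports sample sites inside the transition function.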
Example #3
def singlevariate_kf(T=None, T_forecast=15, obs=None):
    """Define Kalman Filter in a single variate fashion.

    Parameters
    ----------
    T:  int
    T_forecast: int
        Times to forecast ahead.
    obs: np.array
        observed variable (infected, deaths...)

    """
    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    T = len(obs) if T is None else T
    beta = numpyro.sample(name="beta", fn=dist.Normal(loc=0.0, scale=1))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=0.1))
    noises = numpyro.sample(
        "noises", fn=dist.Normal(0, 1.0), sample_shape=(T + T_forecast - 2,)
    )
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(name="z_1", fn=dist.Normal(loc=0, scale=0.1))

    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast - 2)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    # Sample the observed y (y_obs) and predicted y (y_pred)
    numpyro.sample(
        name="y_obs", fn=dist.Normal(loc=z_collection[:T], scale=sigma), obs=obs
    )
    numpyro.sample(
        name="y_pred", fn=dist.Normal(loc=z_collection[T:], scale=sigma), obs=None
    )
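The transition function f handed to lax.scan is defined elsewhere. A minimal sketch compatible with the carry (beta, z_prev, tau), assuming random-walk-with-drift dynamics (an assumption; the source does not show f):

def f(carry, noise_t):
    beta, z_prev, tau = carry
    # latent random walk with drift beta and innovation scale tau
    z_t = beta + z_prev + tau * noise_t
    return (beta, z_t, tau), z_t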
Example #4
def model(T, T_forecast, X, obs=None):
    """ 
    Define priors over delta, tau, noises, sigma, z_prev, eta
    """
    tau = numpyro.sample("tau", dist.HalfCauchy(10.0))
    noises = numpyro.sample(
        "noises",
        dist.Normal(jnp.zeros(X.shape[0]), 5.0 * jnp.ones(X.shape[0])),
    )
    sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=5.0))
    delta = numpyro.sample("delta", dist.Normal(0, 5.0))

    z_prev = numpyro.sample("z_prev", dist.Normal(0, 3.0))

    eta = numpyro.sample(
        "eta",
        dist.Normal(jnp.zeros(X.shape[1]), 5.0 * jnp.ones(X.shape[1])),
    )
    """ Propagate the dynamics forward using jax.lax.scan
    """
    carry = (delta, eta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(
        f=f,
        init=carry,
        xs=(X, noises),
    )
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)
    """ Sample the observed_y (y_obs) and predicted_y (y_pred) - note that you don't need a pyro.plate!
    """
    numpyro.sample("y_obs", dist.Normal(z_collection[:T], sigma), obs=obs)
    numpyro.sample("y_pred", dist.Normal(z_collection[T:], sigma), obs=None)

    return z_collection
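Here lax.scan consumes (X, noises) pairs, so the externally defined f must unpack an exogenous regressor along with the noise. A minimal sketch consistent with the carry, assuming linear latent dynamics with regressors (again an assumption, since f is not shown):

import jax.numpy as jnp

def f(carry, xs_t):
    delta, eta, z_prev, tau = carry
    x_t, noise_t = xs_t
    # drift plus regression on exogenous covariates, with scaled noise
    z_t = delta + z_prev + jnp.dot(x_t, eta) + tau * noise_t
    return (delta, eta, z_t, tau), z_t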
Example #5
def twoh_c_kf(T=None, T_forecast=15, obs=None):
    """Define Kalman Filter with two hidden variates."""
    T = len(obs) if T is None else T
    
    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    #W = numpyro.sample(name="W", fn=dist.Normal(loc=jnp.zeros((2,4)), scale=jnp.ones((2,4))))
    beta = numpyro.sample(name="beta", fn=dist.Normal(loc=jnp.array([0.,0.]), scale=jnp.ones(2)))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=.1))
    z_prev = numpyro.sample(name="z_1", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))
    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.))
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega) # lower cholesky factor of the covariance matrix
    noises = numpyro.sample("noises", fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower), sample_shape=(T+T_forecast,))
    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, noises, T+T_forecast)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    c = numpyro.sample(name="c", fn=dist.Normal(loc=jnp.array([[0.], [0.]]), scale=jnp.ones((2,1))))
    obs_mean = jnp.dot(z_collection[:T,:], c).squeeze()
    pred_mean = jnp.dot(z_collection[T:,:], c).squeeze()

    # Sample the observed y (y_obs)
    numpyro.sample(name="y_obs", fn=dist.Normal(loc=obs_mean, scale=sigma), obs=obs)
    numpyro.sample(name="y_pred", fn=dist.Normal(loc=pred_mean, scale=sigma), obs=None)
Example #6
def model(X, Y, hypers):
    S, P, N = hypers["expected_sparsity"], X.shape[1], X.shape[0]

    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers["alpha3"]))
    phi = sigma * (S / jnp.sqrt(N)) / (P - S)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

    msq = numpyro.sample("msq",
                         dist.InverseGamma(hypers["alpha1"], hypers["beta1"]))
    xisq = numpyro.sample("xisq",
                          dist.InverseGamma(hypers["alpha2"], hypers["beta2"]))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq

    lam = numpyro.sample("lambda", dist.HalfCauchy(jnp.ones(P)))
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    # compute kernel
    kX = kappa * X
    k = kernel(kX, kX, eta1, eta2, hypers["c"]) + sigma**2 * jnp.eye(N)
    assert k.shape == (N, N)

    # sample Y according to the standard gaussian process formula
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]),
                                covariance_matrix=k),
        obs=Y,
    )
Example #7
def GP(X, y):

    X = numpyro.deterministic("X", X)

    # Set informative priors on kernel hyperparameters.
    η = numpyro.sample("variance", dist.HalfCauchy(scale=5.0))
    ℓ = numpyro.sample("length_scale", dist.Gamma(2.0, 1.0))
    σ = numpyro.sample("obs_noise", dist.HalfCauchy(scale=5.0))

    # Compute kernel
    K = rbf_kernel(X, X, η, ℓ)
    K = add_to_diagonal(K, σ)
    K = add_to_diagonal(K, wandb.config.jitter)
    # cholesky decomposition
    Lff = numpyro.deterministic("Lff", cholesky(K, lower=True))

    # Sample y according to the standard gaussian process formula
    return numpyro.sample(
        "y",
        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]),
                                scale_tril=Lff).expand_by(
                                    y.shape[:-1])  # for multioutput scenarios
        .to_event(y.ndim - 1),
        obs=y,
    )
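rbf_kernel and add_to_diagonal are project helpers that are not shown, and wandb.config.jitter is read from the project's Weights & Biases config. Minimal sketches consistent with the call sites above, assuming 2-D inputs and a squared-exponential form:

import jax.numpy as jnp

def rbf_kernel(X, Z, variance, length_scale):
    # squared-exponential (RBF) kernel matrix between the rows of X and Z
    sq_dists = jnp.sum((X[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dists / length_scale ** 2)

def add_to_diagonal(K, value):
    # add a scalar to the diagonal of a square matrix
    return K + value * jnp.eye(K.shape[0])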
Example #8
def seir_model(initial_seir_state,
               num_days,
               infection_data,
               recovery_data,
               step_size=0.1):
    if infection_data is not None:
        assert num_days == infection_data.shape[0]
    if recovery_data is not None:
        assert num_days == recovery_data.shape[0]
    beta = numpyro.sample('beta', dist.HalfCauchy(scale=1000.))
    gamma = numpyro.sample('gamma', dist.HalfCauchy(scale=1000.))
    sigma = numpyro.sample('sigma', dist.HalfCauchy(scale=1000.))
    mu = numpyro.sample('mu', dist.HalfCauchy(scale=1000.))
    nu = np.array(0.0)  # No vaccine yet

    def seir_update(day, seir_state):
        s = seir_state[..., 0]
        e = seir_state[..., 1]
        i = seir_state[..., 2]
        r = seir_state[..., 3]
        n = s + e + i + r
        s_upd = mu * (n - s) - beta * (s * i / n) - nu * s
        e_upd = beta * (s * i / n) - (mu + sigma) * e
        i_upd = sigma * e - (mu + gamma) * i
        r_upd = gamma * i - mu * r + nu * s
        return np.stack((s_upd, e_upd, i_upd, r_upd), axis=-1)

    num_steps = int(num_days / step_size)
    sim = runge_kutta_4(seir_update, initial_seir_state, step_size, num_steps)
    sim = np.reshape(sim, (num_days, int(1 / step_size), 4))[:, -1, :] + 1e-3
    with numpyro.plate('data', num_days):
        numpyro.sample('infections',
                       dist.Poisson(sim[:, 2]),
                       obs=infection_data)
        numpyro.sample('recovery', dist.Poisson(sim[:, 3]), obs=recovery_data)
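runge_kutta_4 is defined elsewhere (and np inside this model is assumed to be jax.numpy, as in Example #1). A minimal fixed-step RK4 sketch consistent with the call site; passing the step time as seir_update's unused day argument is an assumption:

import jax
import jax.numpy as jnp

def runge_kutta_4(f, y0, step_size, num_steps):
    # classic fixed-step RK4, returning the state after every step
    def step(y, t):
        k1 = f(t, y)
        k2 = f(t + step_size / 2, y + step_size * k1 / 2)
        k3 = f(t + step_size / 2, y + step_size * k2 / 2)
        k4 = f(t + step_size, y + step_size * k3)
        y_next = y + (step_size / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, jnp.arange(num_steps) * step_size)
    return ys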
Example #9
def multivariate_kf(T=None, T_forecast=15, obs=None):
    """Define Kalman Filter in a multivariate fashion.

    The "time-series" are correlated. To define these relationships in
    a efficient manner, the covarianze matrix of h_t (or, equivalently, the
    noises) is drown from a Cholesky decomposed matrix.

    Parameters
    ----------
    T: int
        Number of observed time steps (defaults to len(obs)).
    T_forecast: int
        Number of time steps to forecast ahead.
    obs: np.array
        Observed variable (infected, deaths, ...).

    """
    T = len(obs) if T is None else T
    beta = numpyro.sample(
        name="beta", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(
        name="z_1", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )
    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)), L_Omega
    )  # lower cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast - 1,),
    )

    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast - 1)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    # Sample the observed y (y_obs) and predicted y (y_pred)
    numpyro.sample(
        name="y_obs1",
        fn=dist.Normal(loc=z_collection[:T, 0], scale=sigma),
        obs=obs[:, 0],
    )
    numpyro.sample(
        name="y_pred1", fn=dist.Normal(loc=z_collection[T:, 0], scale=sigma), obs=None
    )
    numpyro.sample(
        name="y_obs2",
        fn=dist.Normal(loc=z_collection[:T, 1], scale=sigma),
        obs=obs[:, 1],
    )
    numpyro.sample(
        name="y_pred2", fn=dist.Normal(loc=z_collection[T:, 1], scale=sigma), obs=None
    )
Example #10
def multih_kf(T=None, T_forecast=15, hidden=4, obs=None):
    """Define Kalman Filter: multiple hidden variables; just one time series.

    Parameters
    ----------
    T: int
        Number of observed time steps (defaults to len(obs)).
    T_forecast: int
        Number of time steps to forecast ahead.
    hidden: int
        number of variables in the latent space
    obs: np.array
        observed variable (infected, deaths...)

    """
    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    T = len(obs) if T is None else T
    beta = numpyro.sample(
        name="beta", fn=dist.Normal(loc=jnp.zeros(hidden), scale=jnp.ones(hidden))
    )
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(
        name="z_1", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )
    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)), L_Omega
    )  # lower cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast - 1,),
    )

    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast - 1)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    # Sample the observed y (y_obs) and predicted y (y_pred)
    numpyro.sample(
        name="y_obs",
        fn=dist.Normal(loc=z_collection[:T, :].sum(axis=1), scale=sigma),
        obs=obs[:, 0],
    )
    numpyro.sample(
        name="y_pred", fn=dist.Normal(loc=z_collection[T:, :].sum(axis=1), scale=sigma), obs=None
    )
Example #11
File: utils.py Project: dimarkov/pybefit
def fulldyn_single_model(beliefs, y, mask):
    T, _ = beliefs[0].shape

    c0 = beliefs[-1]

    mu = npyro.sample('mu', dist.Normal(5., 5.))

    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

    _lam34 = jnp.expand_dims(lam34, -1)
    lam0 = npyro.deterministic(
        'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

    eta = npyro.sample('eta', dist.Beta(1, 10))

    scale = npyro.sample('scale', dist.HalfNormal(1.))
    theta = npyro.sample('theta', dist.HalfCauchy(5.))
    rho = jnp.exp(-theta)
    sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale

    x0 = jnp.zeros(1)

    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams', lam_prev +
            nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
Example #12
def pacs_model(priors):
    pointing_matrices = [([p.amat_row, p.amat_col], p.amat_data)
                         for p in priors]
    flux_lower = np.asarray([p.prior_flux_lower for p in priors]).T
    flux_upper = np.asarray([p.prior_flux_upper for p in priors]).T

    bkg_mu = np.asarray([p.bkg[0] for p in priors]).T
    bkg_sig = np.asarray([p.bkg[1] for p in priors]).T

    with numpyro.plate('bands', len(priors)):
        sigma_conf = numpyro.sample('sigma_conf', dist.HalfCauchy(0.5))  # HalfCauchy takes a single scale argument
        bkg = numpyro.sample('bkg', dist.Normal(bkg_mu, bkg_sig))

        with numpyro.plate('nsrc', priors[0].nsrc):
            src_f = numpyro.sample('src_f',
                                   dist.Uniform(flux_lower, flux_upper))
    db_hat_psw = sp_matmul(pointing_matrices[0], src_f[:, 0][:, None],
                           priors[0].snpix).reshape(-1) + bkg[0]
    db_hat_pmw = sp_matmul(pointing_matrices[1], src_f[:, 1][:, None],
                           priors[1].snpix).reshape(-1) + bkg[1]

    sigma_tot_psw = jnp.sqrt(
        jnp.power(priors[0].snim, 2) + jnp.power(sigma_conf[0], 2))
    sigma_tot_pmw = jnp.sqrt(
        jnp.power(priors[1].snim, 2) + jnp.power(sigma_conf[1], 2))

    with numpyro.plate('psw_pixels', priors[0].sim.size):  # as ind_psw:
        numpyro.sample("obs_psw",
                       dist.Normal(db_hat_psw, sigma_tot_psw),
                       obs=priors[0].sim)
    with numpyro.plate('pmw_pixels', priors[1].sim.size):  # as ind_pmw:
        numpyro.sample("obs_pmw",
                       dist.Normal(db_hat_pmw, sigma_tot_pmw),
                       obs=priors[1].sim)
Example #13
def schools_model():
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    theta = numpyro.sample("theta",
                           dist.Normal(mu, tau),
                           sample_shape=(data["J"], ))
    numpyro.sample("obs", dist.Normal(theta, data["sigma"]), obs=data["y"])
Example #15
def eight_schools(J, sigma):
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", J):
        theta = numpyro.sample("theta", dist.Normal(mu, tau))
        # n (the number of score draws per school) is assumed to be defined
        # in the enclosing scope
        numpyro.sample("scores",
                       dist.Normal(theta, sigma),
                       sample_shape=(n, ))
Example #16
def partially_pooled_with_logit(at_bats: jnp.ndarray, hits: Optional[jnp.ndarray] = None) -> None:

    loc = numpyro.sample("loc", dist.Normal(-1, 1))
    scale = numpyro.sample("scale", dist.HalfCauchy(1))
    num_players = at_bats.shape[0]
    with numpyro.plate("num_players", num_players):
        alpha = numpyro.sample("alpha", dist.Normal(loc, scale))
        numpyro.sample("obs", dist.Binomial(at_bats, logits=alpha), obs=hits)
Example #17
    def model(X, Y):
        N, P = X.shape

        sigma = numpyro.sample("sigma", dist.HalfCauchy(1.0))
        beta = numpyro.sample("beta", dist.Normal(jnp.zeros(P), jnp.ones(P)))
        mean = jnp.sum(beta * X, axis=-1)

        numpyro.sample("obs", dist.Normal(mean, sigma), obs=Y)
Example #18
    def numpyro_model(X, y):

        if inference == "map" or "vi_mf" or "vi_full":
            # Set priors on hyperparameters.
            η = numpyro.sample("variance", dist.HalfCauchy(scale=5.0))
            ℓ = numpyro.sample(
                "length_scale", dist.Gamma(2.0, 1.0), sample_shape=(n_features,)
            )
            σ = numpyro.sample("obs_noise", dist.HalfCauchy(scale=5.0))
        elif inference == "mll":

            # set params and constraints on hyperparams
            η = numpyro.param(
                "variance", init_value=1.0, constraint=dist.constraints.positive
            )
            ℓ = numpyro.param(
                "length_scale",
                init_value=jnp.ones(n_features),
                constraint=dist.constraints.positive,
            )
            σ = numpyro.param(
                "obs_noise", init_value=0.01, constraint=dist.constraints.positive
            )
        else:
            raise ValueError(f"Unrecognized inference scheme: {inference}")

        x_u = numpyro.param("x_u", init_value=X_u_init)

        # Kernel Function
        rbf_kernel = RBF(variance=η, length_scale=ℓ)

        # GP Model
        gp_model = SGPVFE(
            X=X,
            X_u=x_u,
            y=y,
            mean=zero_mean,
            kernel=rbf_kernel,
            obs_noise=σ,
            jitter=jitter,
        )

        # Sample y according SGP
        return gp_model.to_numpyro(y=y)
Example #19
def sgt(y: jnp.ndarray, seasonality: int, future: int = 0) -> None:

    cauchy_sd = jnp.max(y) / 150

    nu = numpyro.sample("nu", dist.Uniform(2, 20))
    powx = numpyro.sample("powx", dist.Uniform(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
    offset_sigma = numpyro.sample(
        "offset_sigma",
        dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd))

    coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
    pow_trend = 1.5 * pow_trend_beta - 0.5
    pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

    level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
    s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
    init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))

    num_lim = y.shape[0]

    def transition_fn(
        carry: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], t: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:

        level, s, moving_sum = carry
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        exp_val = jnp.clip(exp_val, 0.0)
        y_t = jnp.where(t >= num_lim, exp_val, y[t])

        moving_sum = moving_sum + y[t] - jnp.where(t >= seasonality,
                                                   y[t - seasonality], 0.0)
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality,
                            y_t - season)
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, 0.0)

        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        new_s = jnp.where(t >= num_lim, s[0], new_s)
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        omega = sigma * exp_val**powx + offset_sigma
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))

        return (level, s, moving_sum), y_

    level_init = y[0]
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(transition_fn, (level_init, s_init, moving_sum),
                     jnp.arange(1, num_lim + future))

    numpyro.deterministic("y_forecast", ys)
Example #20
def model_bernoulli_likelihood(X, Y):
    D_X = X.shape[1]

    # sample from horseshoe prior
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(D_X)))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    # note that this reparameterization (i.e. coordinate transformation) improves
    # posterior geometry and makes NUTS sampling more efficient
    unscaled_betas = numpyro.sample("unscaled_betas",
                                    dist.Normal(0.0, jnp.ones(D_X)))
    scaled_betas = numpyro.deterministic("betas",
                                         tau * lambdas * unscaled_betas)

    # compute mean function using linear coefficients
    mean_function = jnp.dot(X, scaled_betas)

    # observe data
    numpyro.sample("Y", dist.Bernoulli(logits=mean_function), obs=Y)
Example #21
def model_w_c(T, T_forecast, x, obs=None):
    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    W = numpyro.sample(name="W",
                       fn=dist.Normal(loc=jnp.zeros((2, 4)),
                                      scale=jnp.ones((2, 4))))
    beta = numpyro.sample(name="beta",
                          fn=dist.Normal(loc=jnp.array([0.0, 0.0]),
                                         scale=jnp.ones(2)))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(name="z_1",
                            fn=dist.Normal(loc=jnp.zeros(2),
                                           scale=jnp.ones(2)))
    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)),
        L_Omega)  # lower cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast, ),
    )
    # Propagate the dynamics forward using jax.lax.scan
    carry = (W, beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, (x, noises), T + T_forecast)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    c = numpyro.sample(name="c",
                       fn=dist.Normal(loc=jnp.array([[0.0], [0.0]]),
                                      scale=jnp.ones((2, 1))))
    obs_mean = jnp.dot(z_collection[:T, :], c).squeeze()
    pred_mean = jnp.dot(z_collection[T:, :], c).squeeze()

    # Sample the observed y (y_obs)
    numpyro.sample(name="y_obs",
                   fn=dist.Normal(loc=obs_mean, scale=sigma),
                   obs=obs)
    numpyro.sample(name="y_pred",
                   fn=dist.Normal(loc=pred_mean, scale=sigma),
                   obs=None)
Example #22
File: utils.py Project: dimarkov/pybefit
def prefdyn_single_model(beliefs, y, mask):
    T, _ = beliefs[0].shape

    c0 = beliefs[-1]

    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

    _lam34 = jnp.expand_dims(lam34, -1)
    lam0 = npyro.deterministic(
        'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

    eta = npyro.sample('eta', dist.Beta(1, 10))

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))

    def transition_fn(carry, t):
        lam_prev = carry

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams', lam_prev +
            nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))

        return lam_next, None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, lam_start, jnp.arange(T))
Example #23
def model_normal_likelihood(X, Y):
    D_X = X.shape[1]

    # sample from horseshoe prior
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(D_X)))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    # note that in practice for a normal likelihood we would probably want to
    # integrate out the coefficients (as is done for example in sparse_regression.py).
    # however, this trick wouldn't be applicable to other likelihoods
    # (e.g. bernoulli, see below) so we don't make use of it here.
    unscaled_betas = numpyro.sample("unscaled_betas",
                                    dist.Normal(0.0, jnp.ones(D_X)))
    scaled_betas = numpyro.deterministic("betas",
                                         tau * lambdas * unscaled_betas)

    # compute mean function using linear coefficients
    mean_function = jnp.dot(X, scaled_betas)

    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)

    # observe data
    numpyro.sample("Y", dist.Normal(mean_function, sigma_obs), obs=Y)
Example #24
File: baseball.py Project: ucals/numpyro
def partially_pooled_with_logit(at_bats, hits=None):
    r"""
    Number of hits has a Binomial distribution with a logit link function.
    The logits $\alpha$ for each player are normally distributed with the
    mean and scale parameters sharing a common prior.

    :param (jnp.DeviceArray) at_bats: Number of at bats for each player.
    :param (jnp.DeviceArray) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    loc = numpyro.sample("loc", dist.Normal(-1, 1))
    scale = numpyro.sample("scale", dist.HalfCauchy(1))
    num_players = at_bats.shape[0]
    with numpyro.plate("num_players", num_players):
        alpha = numpyro.sample("alpha", dist.Normal(loc, scale))
        return numpyro.sample("obs", dist.Binomial(at_bats, logits=alpha), obs=hits)
Example #25
def model_noncentered(num: int,
                      sigma: np.ndarray,
                      y: Optional[np.ndarray] = None) -> None:

    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("num", num):
        with numpyro.handlers.reparam(config={"theta": TransformReparam()}):
            theta = numpyro.sample(
                "theta",
                dist.TransformedDistribution(
                    dist.Normal(0.0, 1.0),
                    dist.transforms.AffineTransform(mu, tau)),
            )

        numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
Example #26
def partially_pooled_with_logit(at_bats, hits=None):
    r"""
    Number of hits has a Binomial distribution with a logit link function.
    The logits $\alpha$ for each player are normally distributed with the
    mean and scale parameters sharing a common prior.

    :param (np.DeviceArray) at_bats: Number of at bats for each player.
    :param (np.DeviceArray) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    num_players = at_bats.shape[0]
    loc = sample("loc", dist.Normal(np.array([-1.]), np.array([1.])))
    scale = sample("scale", dist.HalfCauchy(np.array([1.])))
    shape = np.shape(loc)[:np.ndim(loc) - 1] + (num_players,)
    alpha = sample("alpha", dist.Normal(np.broadcast_to(loc, shape),
                                        np.broadcast_to(scale, shape)))
    return sample("obs", dist.Binomial(at_bats, logits=alpha), obs=hits)
Example #27
File: hsgp.py Project: pyro-ppl/numpyro
def yearday_effect(day_of_year):
    slab_df = 50  # 100 in original case study
    slab_scale = 2
    scale_global = 0.1
    tau = sample("tau", dist.HalfNormal(
        2 * scale_global))  # Original uses half-t with 100df
    c_aux = sample("c_aux", dist.InverseGamma(0.5 * slab_df, 0.5 * slab_df))
    c = slab_scale * jnp.sqrt(c_aux)

    # Jan 1st:  Day 0
    # Feb 29th: Day 59
    # Dec 31st: Day 365
    with plate("plate_day_of_year", 366):
        lam = sample("lam", dist.HalfCauchy(scale=1))
        lam_tilde = jnp.sqrt(c) * lam / jnp.sqrt(c + (tau * lam)**2)
        beta = sample("beta", dist.Normal(loc=0, scale=tau * lam_tilde))

    return beta[day_of_year]
Example #28
def ar_k(n_coefs, obs=None, X=None):

    beta = numpyro.sample(name="beta",
                          sample_shape=(n_coefs, ),
                          fn=dist.TransformedDistribution(
                              dist.Normal(loc=0., scale=1),
                              transforms=dist.transforms.AffineTransform(
                                  loc=0,
                                  scale=1,
                                  domain=dist.constraints.interval(-1, 1))))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=1))

    z_init = numpyro.sample(name='z_init',
                            fn=dist.Normal(0, 1),
                            sample_shape=(n_coefs, ))

    obs_init = z_init[:n_coefs - 1]

    (beta, obs_last), zs_exp = scan_fn(n_coefs, beta, obs_init, obs)
    Z_exp = np.concatenate((z_init, zs_exp), axis=0)

    Z = numpyro.sample(name="Z", fn=dist.Normal(loc=Z_exp, scale=tau), obs=obs)

    return Z_exp, obs_last
Example #29
def model(y_obs):
    mu = numpyro.sample('mu', dist.Normal(0., 1.))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)
Example #30
def horseshoe_model(
    y_vals,
    gid,
    cid,
    N,  # array of number of y_vals in each gene
    slab_df=1,
    slab_scale=1,
    expected_large_covar_num=5,  # prior on the number of conditions expected to affect a given gene's expression
    condition_intercept=False):

    gene_count = gid.max() + 1
    condition_count = cid.max() + 1

    # separate regularizing prior on intercept for each gene
    a_prior = dist.Normal(10., 10.)
    a = numpyro.sample("alpha", a_prior, sample_shape=(gene_count, ))

    # implement Finnish horseshoe
    half_slab_df = slab_df / 2
    variance = y_vals.var()
    slab_scale2 = slab_scale**2
    hs_shape = (gene_count, condition_count)

    # set up "local" horseshoe priors for each gene and condition
    beta_tilde = numpyro.sample(
        'beta_tilde', dist.Normal(0., 1.), sample_shape=hs_shape
    )  # beta_tilde contains betas for all hs parameters
    lambd = numpyro.sample(
        'lambd', dist.HalfCauchy(1.),
        sample_shape=hs_shape)  # lambd contains lambda for each hs covariate
    # Set up global hyperpriors: each gene gets its own hyperprior so that
    # regularization of large effects keeps the sampling from wandering too far from 0.
    tau_tilde = numpyro.sample('tau_tilde',
                               dist.HalfCauchy(1.),
                               sample_shape=(gene_count, 1))
    c2_tilde = numpyro.sample('c2_tilde',
                              dist.InverseGamma(half_slab_df, half_slab_df),
                              sample_shape=(gene_count, 1))

    bC = finnish_horseshoe(
        M=hs_shape[1],  # total number of conditions
        m0=expected_large_covar_num,  # number of conditions we expect to affect expression of a given gene
        N=N,  # number of observations for the gene
        var=variance,
        half_slab_df=half_slab_df,
        slab_scale2=slab_scale2,
        tau_tilde=tau_tilde,
        c2_tilde=c2_tilde,
        lambd=lambd,
        beta_tilde=beta_tilde)
    numpyro.sample("b_condition", dist.Delta(bC), obs=bC)

    if condition_intercept:
        a_C_prior = dist.Normal(0., 1.)
        a_C = numpyro.sample('a_condition',
                             a_C_prior,
                             sample_shape=(condition_count, ))

        mu = a[gid] + a_C[cid] + bC[gid, cid]

    else:
        # calculate implied log2(signal) for each gene/condition
        #   by adding each gene's intercept (a) to each of that gene's
        #   condition effects (bC).
        mu = a[gid] + bC[gid, cid]

    sig_prior = dist.Exponential(1.)
    sigma = numpyro.sample('sigma', sig_prior)
    return numpyro.sample('obs', dist.Normal(mu, sigma), obs=y_vals)
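finnish_horseshoe is not shown. A sketch following the standard regularized ("Finnish") horseshoe construction, with shapes inferred from the call site; the broadcasting details are assumptions, and half_slab_df is accepted but unused here because it only parameterizes the c2_tilde prior above:

import jax.numpy as jnp

def finnish_horseshoe(M, m0, N, var, half_slab_df, slab_scale2,
                      tau_tilde, c2_tilde, lambd, beta_tilde):
    # global scale: the expected number of large effects m0 sets tau0
    tau0 = (m0 / (M - m0)) * jnp.sqrt(var) / jnp.sqrt(N).reshape(-1, 1)
    tau = tau0 * tau_tilde        # per-gene global scale, shape (genes, 1)
    c2 = slab_scale2 * c2_tilde   # slab width, shape (genes, 1)
    # regularized local scales: horseshoe near zero, finite slab for large effects
    lambd_tilde = jnp.sqrt(c2 * lambd**2 / (c2 + tau**2 * lambd**2))
    return tau * lambd_tilde * beta_tilde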