def model(X, Y, hypers):
    """Sparse kernel-interaction Gaussian-process regression model.

    Places regularized, horseshoe-style per-feature scales (``kappa``) on the
    columns of ``X`` and conditions ``Y`` on a zero-mean multivariate normal
    whose covariance is ``kernel(...)`` plus isotropic noise ``var_obs``.

    Args:
        X: design matrix; rows are observations (``N = X.shape[0]``) and
            columns are features (``P = X.shape[1]``).
        Y: observed targets, modelled as a single draw from an
            ``N``-dimensional multivariate normal.
        hypers: dict with keys ``expected_sparsity``, ``alpha1``, ``beta1``,
            ``alpha2``, ``beta2``, ``alpha3``, ``alpha_obs``, ``beta_obs`` and
            ``c`` (the last is forwarded unchanged to ``kernel``).
    """
    sparsity = hypers['expected_sparsity']
    n_feat = X.shape[1]
    n_obs = X.shape[0]

    # Global scale, tied to the expected number of relevant features.
    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers['alpha3']))
    global_scale = sigma * (sparsity / np.sqrt(n_obs)) / (n_feat - sparsity)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(global_scale))

    # msq regularizes the local scales below; xisq feeds the interaction scale.
    msq = numpyro.sample(
        "msq", dist.InverseGamma(hypers['alpha1'], hypers['beta1']))
    xisq = numpyro.sample(
        "xisq", dist.InverseGamma(hypers['alpha2'], hypers['beta2']))
    eta2 = np.square(eta1) * np.sqrt(xisq) / msq

    # Per-feature local scales, softly capped by sqrt(msq).
    lam = numpyro.sample("lambda", dist.HalfCauchy(np.ones(n_feat)))
    kappa = np.sqrt(msq) * lam / np.sqrt(msq + np.square(eta1 * lam))

    # Observation-noise variance added to the kernel diagonal.
    var_obs = numpyro.sample(
        "var_obs", dist.InverseGamma(hypers['alpha_obs'], hypers['beta_obs']))

    scaled_X = kappa * X
    cov = kernel(scaled_X, scaled_X, eta1, eta2, hypers['c']) \
        + var_obs * np.eye(n_obs)
    assert cov.shape == (n_obs, n_obs)

    # GP marginal likelihood of the observed targets.
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(loc=np.zeros(n_obs), covariance_matrix=cov),
        obs=Y)
Example #2
0
def model(X, Y, hypers):
    """Sparse kernel-interaction Gaussian-process regression model.

    Same structure as a regularized-horseshoe GP: per-feature scales
    ``kappa`` rescale the columns of ``X``, and ``Y`` is conditioned on a
    zero-mean multivariate normal with covariance ``kernel(...)`` plus
    ``sigma**2`` on the diagonal (``sigma`` also sets the global scale).

    Args:
        X: design matrix; rows are observations (``N = X.shape[0]``) and
            columns are features (``P = X.shape[1]``).
        Y: observed targets, one ``N``-dimensional multivariate-normal draw.
        hypers: dict with keys ``expected_sparsity``, ``alpha1``, ``beta1``,
            ``alpha2``, ``beta2``, ``alpha3`` and ``c`` (forwarded to
            ``kernel``).
    """
    sparsity = hypers["expected_sparsity"]
    n_feat = X.shape[1]
    n_obs = X.shape[0]

    # sigma is reused below as the observation-noise standard deviation.
    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers["alpha3"]))
    global_scale = sigma * (sparsity / jnp.sqrt(n_obs)) / (n_feat - sparsity)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(global_scale))

    msq = numpyro.sample(
        "msq", dist.InverseGamma(hypers["alpha1"], hypers["beta1"]))
    xisq = numpyro.sample(
        "xisq", dist.InverseGamma(hypers["alpha2"], hypers["beta2"]))
    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq

    # Per-feature local scales, softly capped by sqrt(msq).
    lam = numpyro.sample("lambda", dist.HalfCauchy(jnp.ones(n_feat)))
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    scaled_X = kappa * X
    cov = kernel(scaled_X, scaled_X, eta1, eta2, hypers["c"]) \
        + sigma**2 * jnp.eye(n_obs)
    assert cov.shape == (n_obs, n_obs)

    # GP marginal likelihood of the observed targets.
    likelihood = dist.MultivariateNormal(loc=jnp.zeros(n_obs),
                                         covariance_matrix=cov)
    numpyro.sample("Y", likelihood, obs=Y)
Example #3
0
def nondynamic_mixture_model(beliefs, y, mask):
    """Static mixture model over ``M`` components, scanned over ``T`` steps.

    Args:
        beliefs: sequence whose first element has shape ``(M, T, _)``; at
            step ``t`` the slices ``beliefs[0][:, t]`` and
            ``beliefs[1][:, t]`` are passed to ``logits``.
        y: observations, bound to the per-step ``'y'`` site via
            ``condition``.
        mask: indexed by step; ``mask[t]`` masks the step-``t`` likelihood.
    """
    n_comp, n_steps, _ = beliefs[0].shape

    # Sparse-ish Dirichlet(0.5) prior on the mixture weights.
    concentration = .5 * jnp.ones(n_comp)
    weights = npyro.sample('weights', dist.Dirichlet(concentration))
    assert weights.shape == (n_comp, )

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 5.))
    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))

    # Spread the 3-simplex p over 4 outcomes; the result still sums to 1.
    first, middle, last = p[..., :1], p[..., 1:2], p[..., 2:]
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([first / 2, first / 2 + middle, last / 2, last / 2],
                        -1))
    log_p0 = jnp.log(p0)

    # These broadcasts do not depend on t, so build them once.
    gamma_col = jnp.expand_dims(gamma, -1)
    log_p0_row = jnp.expand_dims(log_p0, -2)

    def step(carry, t):
        step_logits = logits((beliefs[0][:, t], beliefs[1][:, t]),
                             gamma_col, log_p0_row)
        npyro.sample(
            'y',
            dist.MixtureSameFamily(
                dist.CategoricalProbs(weights),
                dist.CategoricalLogits(step_logits).mask(mask[t])))
        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(step, None, jnp.arange(n_steps))
Example #4
0
def nondyn_single_model(beliefs, y, mask):
    """Single-component choice model with static parameters over ``T`` steps.

    Args:
        beliefs: sequence whose first element has shape ``(T, _)``; at step
            ``t`` the rows ``beliefs[0][t]`` and ``beliefs[1][t]`` are passed
            to ``logits``.
        y: observations, bound to the per-step ``'y'`` site via
            ``condition``.
        mask: indexed by step; ``mask[t]`` masks the step-``t`` likelihood.
    """
    n_steps, _ = beliefs[0].shape

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 3.))
    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))

    # Spread the 3-simplex p over 4 outcomes; the result still sums to 1.
    first, middle, last = p[..., :1], p[..., 1:2], p[..., 2:]
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([first / 2, first / 2 + middle, last / 2, last / 2],
                        -1))
    log_p0 = jnp.log(p0)

    # Step-invariant broadcasts.
    gamma_col = jnp.expand_dims(gamma, -1)
    log_p0_row = jnp.expand_dims(log_p0, -2)

    def step(carry, t):
        step_logits = logits((beliefs[0][t], beliefs[1][t]), gamma_col,
                             log_p0_row)
        npyro.sample('y', dist.CategoricalLogits(step_logits).mask(mask[t]))
        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(step, None, jnp.arange(n_steps))
def model(z=None) -> None:
    """Two-dimensional Normal model with unknown mean and scale.

    Args:
        z: optional observations; its leading dimension is the batch size
            (1 when ``z`` is omitted).
    """
    batch_size = 1 if z is None else z.shape[0]

    event_shape = (2, )
    mu = sample('mu', dists.Normal().expand_by(event_shape).to_event(1))
    sigma = sample('sigma',
                   dists.InverseGamma(1.).expand_by(event_shape).to_event(1))

    # Full-batch plate: total size and subsample size are the same.
    with plate('batch', batch_size, batch_size):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
    def guide(z=None, num_obs_total=None) -> None:
        """Guide with a learnable location for ``mu`` and a fixed ``sigma``.

        ``z`` and ``num_obs_total`` are accepted so the signature matches the
        model's; neither affects the guide (the previous body computed a
        batch size and total count from them that was never used).
        """
        # NOTE(review): `d` is not defined in this block; presumably an
        # event size from an enclosing scope — confirm.
        mu_param = param('mu_param', 0.)
        sample('mu', dists.Normal(mu_param, 1.).expand_by((d, )).to_event(1))
        # sigma has no variational parameters: it is drawn from its prior.
        sample('sigma', dists.InverseGamma(1.).expand_by((d, )).to_event(1))
Example #7
0
    def model(z = None, num_obs_total = None) -> None:
        """``d``-dimensional Normal model evaluated on a (mini-)batch.

        Args:
            z: optional observations; leading dimension is the batch size
                (1 when omitted).
            num_obs_total: total data-set size used as the plate size;
                defaults to the batch size.
        """
        batch_size = 1 if z is None else z.shape[0]
        if num_obs_total is None:
            num_obs_total = batch_size

        # NOTE(review): `args` and `d` are not defined in this block;
        # presumably taken from an enclosing scope — confirm.
        mu_prior = dists.Normal(args.prior_mu).expand_by((d,)).to_event(1)
        mu = sample('mu', mu_prior)
        sigma_prior = dists.InverseGamma(1.).expand_by((d,)).to_event(1)
        sigma = sample('sigma', sigma_prior)

        # Plate of total size num_obs_total, subsampled to batch_size.
        with plate('batch', num_obs_total, batch_size):
            sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
def model(x_first=None, x_second=None, num_obs_total=None) -> None:
    """Scalar Normal model shared by two observed sites.

    Args:
        x_first: optional observations for ``'x_first'``; its leading
            dimension defines the batch size (1 when omitted).
        x_second: optional observations for ``'x_second'``.
        num_obs_total: total data-set size used as the plate size;
            defaults to the batch size.
    """
    batch_size = 1 if x_first is None else x_first.shape[0]
    if num_obs_total is None:
        num_obs_total = batch_size

    mu = sample('mu', dists.Normal())
    sigma = sample('sigma', dists.InverseGamma(1.))

    # Both sites share mu/sigma and live in the same plate.
    with plate('batch', num_obs_total, batch_size):
        for site, observed in (('x_first', x_first), ('x_second', x_second)):
            sample(site, dists.Normal(mu, sigma), obs=observed)
Example #9
0
def model(z=None, z2=None, num_obs_total=None) -> None:
    """Two-dimensional Normal model on a (mini-)batch of ``z``.

    Args:
        z: optional observations; leading dimension is the batch size
            (1 when omitted).
        z2: optional companion array. It is only validated: when both ``z``
            and ``z2`` are given, their leading dimensions must match.
        num_obs_total: total data-set size used as the plate size;
            defaults to the batch size.
    """
    batch_size = 1
    if z is not None:
        batch_size = z.shape[0]
        # z2 defaults to None, so only compare shapes when it is supplied
        # (previously model(z) crashed with AttributeError on None.shape).
        assert z2 is None or z2.shape[0] == batch_size
    if num_obs_total is None:
        num_obs_total = batch_size

    mu = sample('mu', dists.Normal().expand_by((2, )).to_event(1))
    sigma = sample('sigma',
                   dists.InverseGamma(1.).expand_by((2, )).to_event(1))
    with plate('batch', num_obs_total, batch_size):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
Example #10
0
def gammadyn_mixture_model(beliefs, y, mask):
    """Mixture model whose ``gamma`` drifts over time via an AR(1) latent.

    At each step ``gamma_dyn = softplus(mu + x)``, where ``x`` starts at 0
    and evolves as ``x_next ~ Normal(rho * x, sigma)``.

    Args:
        beliefs: sequence whose first element has shape ``(M, T, _)``; at
            step ``t`` the slices ``beliefs[0][:, t]`` and
            ``beliefs[1][:, t]`` are passed to ``logits``.
        y: observations, bound to the per-step ``'y'`` site via
            ``condition``.
        mask: indexed by step; ``mask[t]`` masks the step-``t`` likelihood.
    """
    M, T, _ = beliefs[0].shape

    tau = .5
    weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))

    assert weights.shape == (M, )

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))
    # mu = softplus^{-1}(gamma), so gamma_dyn == gamma while x == 0.
    # Evaluated as gamma + log(1 - exp(-gamma)) rather than
    # log(exp(gamma) - 1): exp(gamma) overflows to inf for large draws from
    # the heavy-tailed InverseGamma prior, and exp(gamma) - 1 cancels
    # catastrophically for small gamma.
    mu = gamma + jnp.log(-jnp.expm1(-gamma))

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    # Spread the 3-simplex p over 4 outcomes; the result still sums to 1.
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2, p[..., :1] / 2 + p[..., 1:2], p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    scale = npyro.sample('scale', dist.Gamma(1., 1.))
    rho = npyro.sample('rho', dist.Beta(1., 2.))
    # NOTE(review): appears to be the exact unit-step discretization of an
    # Ornstein-Uhlenbeck process with rate -log(rho) and diffusion `scale`
    # — confirm intent.
    sigma = jnp.sqrt(-(1 - rho**2) / (2 * jnp.log(rho))) * scale

    U = jnp.log(p0)

    def transition_fn(carry, t):
        x_prev = carry

        gamma_dyn = npyro.deterministic('gamma_dyn', nn.softplus(mu + x_prev))

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma_dyn, -1), jnp.expand_dims(U, -2))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs).mask(mask[t])
        npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))

        # x_next = rho * x_prev + sigma * eps, sampled through a standard
        # normal base (TransformReparam) for a non-centred parameterization.
        with npyro.handlers.reparam(
                config={"x_next": npyro.infer.reparam.TransformReparam()}):
            affine = dist.transforms.AffineTransform(rho * x_prev, sigma)
            x_next = npyro.sample(
                'x_next',
                dist.TransformedDistribution(dist.Normal(0., 1.), affine))

        return x_next, None

    x0 = jnp.zeros(1)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, x0, jnp.arange(T))
Example #11
0
def guide(k, obs=None, num_obs_total=None, d=None):
    """Variational guide for a ``k``-component Gaussian mixture.

    Learns Dirichlet concentrations (through ``alpha_log``) for the weights
    and the component means' locations. ``sigs`` is pinned to ones by
    observing it.

    Args:
        k: number of mixture components.
        obs: optional 2-D data; its second dimension defines ``d``.
        num_obs_total: required when ``obs`` is omitted.
        d: data dimension; required when ``obs`` is omitted.

    Returns:
        Tuple ``(pis, mus, sigs)`` of the sampled values.
    """
    if obs is None:
        assert num_obs_total is not None
        assert d is not None
    else:
        assert jnp.ndim(obs) == 2
        _, d = jnp.shape(obs)

    # Concentrations are parameterized in log space to keep them positive.
    pis = sample('pis', dist.Dirichlet(jnp.exp(param('alpha_log',
                                                     jnp.zeros(k)))))
    mus = sample('mus', dist.Normal(param('mus_loc', jnp.zeros((k, d))), 1.))
    sigs = sample('sigs', dist.InverseGamma(1., 1.), obs=jnp.ones_like(mus))
    return pis, mus, sigs
Example #12
0
def model(k, obs=None, num_obs_total=None, d=None):
    """``k``-component Gaussian mixture model with conjugate-style priors.

    Args:
        k: number of mixture components.
        obs: optional 2-D data of shape ``(batch_size, d)``.
        num_obs_total: total data-set size; required when ``obs`` is
            omitted (it then also serves as the batch size), defaults to the
            batch size otherwise.
        d: data dimension; required when ``obs`` is omitted.

    Returns:
        The value of the ``'obs'`` sample site.
    """
    if obs is None:
        assert num_obs_total is not None
        batch_size = num_obs_total
        assert d is not None
    else:
        assert jnp.ndim(obs) == 2
        batch_size, d = jnp.shape(obs)
    if num_obs_total is None:
        num_obs_total = batch_size

    pis = sample('pis', dist.Dirichlet(jnp.ones(k)))
    mus = sample('mus', dist.Normal(jnp.zeros((k, d)), 10.))
    # One variance draw per mean entry.
    sigs = sample('sigs', dist.InverseGamma(1., 1.), sample_shape=jnp.shape(mus))

    mixture = GaussianMixture(mus, sigs, pis)
    with plate('batch', num_obs_total, batch_size):
        return sample('obs', mixture, obs=obs, sample_shape=(batch_size,))
Example #13
0
def yearday_effect(day_of_year):
    """Day-of-year effect under a regularized (Finnish) horseshoe prior.

    One coefficient is drawn for each of 366 day slots (Jan 1st is day 0,
    Feb 29th is day 59, Dec 31st is day 365), and the coefficients at
    ``day_of_year`` are returned.

    Args:
        day_of_year: index (or index array) into the 366 day slots.

    Returns:
        ``beta[day_of_year]``.
    """
    slab_df = 50  # 100 in original case study
    slab_scale = 2
    scale_global = 0.1
    tau = sample("tau", dist.HalfNormal(
        2 * scale_global))  # Original uses half-t with 100df
    c_aux = sample("c_aux", dist.InverseGamma(0.5 * slab_df, 0.5 * slab_df))
    # Slab *scale* c (not its square): c**2 ~ InvGamma(df/2, df*s**2/2).
    c = slab_scale * jnp.sqrt(c_aux)

    with plate("plate_day_of_year", 366):
        lam = sample("lam", dist.HalfCauchy(scale=1))
        # Regularized local scale: lam_tilde**2 = c**2 lam**2 /
        # (c**2 + tau**2 lam**2). c is a scale, so it enters squared here;
        # using sqrt(c) and c in place of c and c**2 would treat the scale as
        # a variance and give a much narrower slab than intended.
        lam_tilde = c * lam / jnp.sqrt(c**2 + (tau * lam)**2)
        beta = sample("beta", dist.Normal(loc=0, scale=tau * lam_tilde))

    return beta[day_of_year]
Example #14
0
def gammadyn_single_model(beliefs, y, mask):
    """Single-component choice model with an AR(1)-driven ``gamma``.

    At each step ``dyn_gamma = softplus(mu + x)``, where ``x`` starts at 0
    and evolves as ``x_next = rho * x + sigma * dw`` with ``dw ~ N(0, 1)``.

    Args:
        beliefs: sequence whose first element has shape ``(T, _)``; at step
            ``t`` the rows ``beliefs[0][t]`` and ``beliefs[1][t]`` are passed
            to ``logits``.
        y: observations, bound to the per-step ``'y'`` site via
            ``condition``.
        mask: indexed by step; ``mask[t]`` masks the step-``t`` likelihood.
    """
    T, _ = beliefs[0].shape

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))
    # mu = softplus^{-1}(gamma), so dyn_gamma == gamma while x == 0.
    # Evaluated as gamma + log(1 - exp(-gamma)) rather than
    # log(exp(gamma) - 1): exp(gamma) overflows to inf for large draws from
    # the heavy-tailed InverseGamma prior, and exp(gamma) - 1 cancels
    # catastrophically for small gamma.
    mu = gamma + jnp.log(-jnp.expm1(-gamma))

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    # Spread the 3-simplex p over 4 outcomes; the result still sums to 1.
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2, p[..., :1] / 2 + p[..., 1:2], p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    scale = npyro.sample('scale', dist.Gamma(1., 1.))
    rho = npyro.sample('rho', dist.Beta(1., 2.))
    # NOTE(review): appears to be the exact unit-step discretization of an
    # Ornstein-Uhlenbeck process with rate -log(rho) and diffusion `scale`
    # — confirm intent.
    sigma = jnp.sqrt(-(1 - rho**2) / (2 * jnp.log(rho))) * scale

    U = jnp.log(p0)

    def transition_fn(carry, t):
        x_prev = carry

        dyn_gamma = npyro.deterministic('dyn_gamma', nn.softplus(mu + x_prev))

        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(dyn_gamma, -1), jnp.expand_dims(U, -2))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return x_next, None

    x0 = jnp.zeros(1)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, x0, jnp.arange(T))
Example #15
0
def prefdyn_single_model(beliefs, y, mask):
    """Single-agent choice model whose 4 preference weights grow over time.

    Preferences start at ``lam0 + eta * beliefs[-1]``. After each step, the
    option ``beliefs[2][t]`` gains ``eta`` (scaled by ``mask[t]``). The
    normalized log-preferences feed ``logits`` at every step.

    Args:
        beliefs: sequence; ``beliefs[0]`` has shape ``(T, _)``,
            ``beliefs[0][t]`` and ``beliefs[1][t]`` go to ``logits``,
            ``beliefs[2][t]`` is a one-hot index into 4 options, and
            ``beliefs[-1]`` is the initial preference offset.
        y: observations, bound to the per-step ``'y'`` site via
            ``condition``.
        mask: indexed by step; ``mask[t]`` masks the step-``t`` likelihood
            and scales that step's preference update.
    """
    T, _ = beliefs[0].shape
    c0 = beliefs[-1]

    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

    # Options 1-2 get cumulative (ordered) weights; options 3-4 share lam34.
    lam34_col = jnp.expand_dims(lam34, -1)
    lam0 = npyro.deterministic(
        'lam0', jnp.concatenate([lam12.cumsum(-1), lam34_col, lam34_col], -1))

    eta = npyro.sample('eta', dist.Beta(1, 10))
    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))
    gamma_col = jnp.expand_dims(gamma, -1)  # step-invariant broadcast

    def transition_fn(lam_prev, t):
        log_prefs = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
        logs = logits((beliefs[0][t], beliefs[1][t]), gamma_col,
                      jnp.expand_dims(log_prefs, -2))

        bump = nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1)
        lam_next = npyro.deterministic('lams', lam_prev + bump)

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        return lam_next, None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, lam_start, jnp.arange(T))
Example #16
0
def trend_gp(x, L, M):
    """Sample GP hyperparameters and return the trend from ``approx_se_ncp``.

    ``x``, ``L`` and ``M`` are forwarded unchanged to ``approx_se_ncp``.
    NOTE(review): presumably a basis-function approximation of a
    squared-exponential GP (``L`` boundary, ``M`` basis functions) — confirm.
    """
    amplitude = sample("alpha", dist.HalfNormal(1.0))
    lengthscale = sample("length", dist.InverseGamma(10.0, 2.0))
    return approx_se_ncp(x, amplitude, lengthscale, L, M)
    def model(self, t, X):
        """ODE-mean model for three outputs with independent GP residuals.

        ODE parameters ``par`` get a regularized-horseshoe prior. The ODE
        solution, read at ``self.i_t`` and normalized by ``self.max_X``, is
        the mean. The covariance is block-diagonal with one RBF block (plus
        noise and jitter) per output.

        Note: the ``t`` and ``X`` arguments are not read; the method uses
        ``self.t`` and ``self.X`` instead.
        """
        n_out = 3

        # Per-output kernel hyperparameters (length self.D).
        noise = sample('noise', dist.LogNormal(0.0, 1.0),
                       sample_shape=(self.D, ))
        hyp = sample('hyp', dist.Gamma(1.0, 0.5), sample_shape=(self.D, ))
        W = sample('W', dist.LogNormal(0.0, 1.0), sample_shape=(self.D, ))

        # Global horseshoe scale from the expected number of active params.
        m0 = self.M - 1
        noise_scale = 1
        tau0 = (m0 / (self.M - m0) *
                (noise_scale / np.sqrt(1.0 * sum(self.N))))
        tau_tilde = sample('tau_tilde', dist.HalfCauchy(1.),
                           sample_shape=(self.D, ))
        tau = np.repeat(tau0 * tau_tilde, self.M // self.D)

        # Slab (regularization) variance c2.
        slab_scale = 1
        slab_df = 1
        c2_tilde = sample('c2_tilde',
                          dist.InverseGamma(slab_df / 2, slab_df / 2))
        c2 = slab_scale**2 * c2_tilde

        # Local scales -> per-parameter prior variances.
        lambd = sample('lambd', dist.HalfCauchy(1.), sample_shape=(self.M, ))
        lambd_tilde = tau**2 * c2 * lambd**2 / (c2 + tau**2 * lambd**2)
        par = sample(
            'par',
            dist.MultivariateNormal(np.zeros(self.M, ), np.diag(lambd_tilde)))

        # Block-diagonal covariance: one kernel block per output.
        diag_blocks = [
            W[i] * self.RBF(self.t_t[i], self.t_t[i], hyp[i]) +
            np.eye(self.N[i]) * (noise[i] + self.jitter)
            for i in range(n_out)
        ]
        K = np.concatenate([
            np.concatenate([
                diag_blocks[i] if j == i else np.zeros((self.N[i], self.N[j]))
                for j in range(n_out)
            ], axis=1)
            for i in range(n_out)
        ], axis=0)

        # Mean: ODE trajectory picked per output and normalized.
        mut = odeint(self.dxdt, self.x0, self.t.flatten(), par)
        # NOTE(review): `ind` is not defined in this method; presumably a
        # module-level name (or meant to be self.ind) — confirm.
        mu = np.concatenate(
            [mut[self.i_t[i], ind[i]] / self.max_X[i] for i in range(n_out)],
            axis=0).flatten('F')

        X_obs = np.concatenate([self.X[i] for i in range(n_out)],
                               axis=0).flatten('F')

        sample("X",
               dist.MultivariateNormal(loc=mu, covariance_matrix=K),
               obs=X_obs)
Example #18
0
def horseshoe_model(
    y_vals,
    gid,
    cid,
    N,  # array of number of y_vals in each gene
    slab_df=1,
    slab_scale=1,
    expected_large_covar_num=5,  # prior guess: conditions affecting a gene
    condition_intercept=False):
    """Per-gene regression on condition with a Finnish-horseshoe prior.

    Args:
        y_vals: observed responses, one per observation.
        gid: gene index of each observation (0-based; ``gid.max() + 1``
            genes).
        cid: condition index of each observation (0-based;
            ``cid.max() + 1`` conditions).
        N: number of observations per gene; forwarded to
            ``finnish_horseshoe``.
        slab_df: slab degrees of freedom; ``slab_df / 2`` is used for both
            InverseGamma parameters of ``c2_tilde``.
        slab_scale: slab scale; forwarded squared as ``slab_scale2``.
        expected_large_covar_num: forwarded as ``m0`` — the number of
            conditions expected to affect a given gene.
        condition_intercept: if true, also add a per-condition intercept
            (site ``'a_condition'``).

    Returns:
        The result of the observed ``'obs'`` sample site.
    """
    n_genes = gid.max() + 1
    n_conditions = cid.max() + 1
    hs_shape = (n_genes, n_conditions)

    # Separate, weakly regularizing intercept for each gene.
    a = numpyro.sample("alpha", dist.Normal(10., 10.),
                       sample_shape=(n_genes, ))

    half_slab_df = slab_df / 2
    variance = y_vals.var()
    slab_scale2 = slab_scale**2

    # Local horseshoe components, one per (gene, condition) pair.
    beta_tilde = numpyro.sample('beta_tilde', dist.Normal(0., 1.),
                                sample_shape=hs_shape)
    lambd = numpyro.sample('lambd', dist.HalfCauchy(1.),
                           sample_shape=hs_shape)

    # Per-gene global hyperpriors (a column broadcast over conditions) that
    # regularize large effects so the sampler does not wander far from 0.
    tau_tilde = numpyro.sample('tau_tilde', dist.HalfCauchy(1.),
                               sample_shape=(n_genes, 1))
    c2_tilde = numpyro.sample('c2_tilde',
                              dist.InverseGamma(half_slab_df, half_slab_df),
                              sample_shape=(n_genes, 1))

    bC = finnish_horseshoe(M=n_conditions,
                           m0=expected_large_covar_num,
                           N=N,
                           var=variance,
                           half_slab_df=half_slab_df,
                           slab_scale2=slab_scale2,
                           tau_tilde=tau_tilde,
                           c2_tilde=c2_tilde,
                           lambd=lambd,
                           beta_tilde=beta_tilde)
    # Record the condition effects as an observed Delta site.
    numpyro.sample("b_condition", dist.Delta(bC), obs=bC)

    if not condition_intercept:
        # Implied log2(signal): gene intercept plus gene/condition effect.
        mu = a[gid] + bC[gid, cid]
    else:
        a_C = numpyro.sample('a_condition', dist.Normal(0., 1.),
                             sample_shape=(n_conditions, ))
        mu = a[gid] + a_C[cid] + bC[gid, cid]

    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    return numpyro.sample('obs', dist.Normal(mu, sigma), obs=y_vals)