Example #1
def generate_sin(n_obs, noise_level, seed):
    """
    Generate noisy and noise-free data from 0.5*sin(3*x).
    :param n_obs: number of noisy observations to draw
    :param noise_level: half-width of the uniform observation noise
    :param seed: seed for the JAX random number generator
    :return: tuple of (noisy data, original data without noise),
            each a dict of the form data = {'X': X, 'y': y}
    """
    # jax random generator
    rng_key1, rng_key2 = random.split(random.PRNGKey(seed))
    # uniform sample from X
    X = dist.Uniform(0.0, 5.0).sample(rng_key1, sample_shape=(n_obs, 1))
    # generate noise
    noise = dist.Uniform(-noise_level, noise_level).sample(
        rng_key2, sample_shape=(X.shape[0],)
    )
    # generate y
    y = 0.5 * np.sin(3 * X[:, 0]) + noise
    noisy_data = {
        "X": X,
        "y": y,
    }

    # generate real observation from 0.5*sin(3*x)
    X_ = np.linspace(0.0, 5.0, 400)
    y_ = 0.5 * np.sin(3 * X_)
    data = {
        "X": X_,
        "y": y_,
    }

    return noisy_data, data
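A minimal usage sketch for the generator above; it only assumes the imports the snippet itself relies on (jax.random as random, numpyro.distributions as dist, numpy as np).

# hypothetical call: 50 noisy points with noise half-width 0.2
noisy_data, true_curve = generate_sin(n_obs=50, noise_level=0.2, seed=0)
print(noisy_data["X"].shape, noisy_data["y"].shape)  # (50, 1) (50,)
print(true_curve["X"].shape, true_curve["y"].shape)  # (400,) (400,)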
Example #2
def model(data):
    alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
    with handlers.reparam(config={'loc': TransformReparam()}):
        loc = numpyro.sample('loc', dist.TransformedDistribution(
            dist.Uniform(0, 1).mask(False),
            AffineTransform(0, alpha)))
    numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)
Example #3
def model_3(capture_history, sex):
    N, T = capture_history.shape
    phi_mean = numpyro.sample("phi_mean", dist.Uniform(0.0, 1.0))  # mean survival probability
    phi_logit_mean = logit(phi_mean)
    # controls temporal variability of survival probability
    phi_sigma = numpyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
            phi_logit_t = numpyro.sample("phi_logit", dist.Normal(phi_logit_mean, phi_sigma))
        phi_t = expit(phi_logit_t)
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                mu_y_t = rho * z
                numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y)

        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(transition_fn, (first_capture_mask, z), jnp.swapaxes(capture_history[:, 1:], 0, 1))
Example #4
def _init_to_uniform(site, radius=2, skip_param=False):
    if site['type'] == 'sample' and not site['is_observed']:
        if isinstance(site['fn'], dist.TransformedDistribution):
            fn = site['fn'].base_dist
        else:
            fn = site['fn']
        value = numpyro.sample('_init',
                               fn,
                               sample_shape=site['kwargs']['sample_shape'])
        base_transform = biject_to(fn.support)
        unconstrained_value = numpyro.sample('_unconstrained_init',
                                             dist.Uniform(-radius, radius),
                                             sample_shape=np.shape(
                                                 base_transform.inv(value)))
        return base_transform(unconstrained_value)

    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        value = site['args'][0]
        unconstrained_value = numpyro.sample('_unconstrained_init',
                                             dist.Uniform(-radius, radius),
                                             sample_shape=np.shape(
                                                 transform.inv(value)))
        if isinstance(transform, ComposeTransform):
            base_transform = transform.parts[0]
        else:
            base_transform = transform
        return base_transform(unconstrained_value)
Example #5
def model_1(capture_history, sex):
    N, T = capture_history.shape
    phi = numpyro.sample("phi", dist.Uniform(0.0, 1.0))  # survival probability
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                mu_y_t = rho * z
                numpyro.sample(
                    "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                )

        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(
        transition_fn,
        (first_capture_mask, z),
        jnp.swapaxes(capture_history[:, 1:], 0, 1),
    )
Example #6
def model(data):
    alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
    with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
        loc = numpyro.sample(
            "loc",
            dist.TransformedDistribution(
                dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)),
        )
    numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
Example #7
def model(x):
    numpyro.sample(
        "x",
        dist.Normal(
            numpyro.sample("loc", dist.Uniform(0, 20)),
            numpyro.sample("scale", dist.Uniform(0, 20)),
        ),
        obs=x,
    )
Example #8
def model(X, y=None):
    sd_randwalk = numpyro.sample('sd_randwalk', dist.Uniform(low=0,
                                                             high=100.0))
    randwalk = numpyro.sample('mu', dist.Normal(loc=0, scale=sd_randwalk))
    value = X + randwalk
    sd_value = numpyro.sample('sd', dist.Uniform(low=0.0, high=100.0))
    pred_y = numpyro.sample('pred_y',
                            dist.Normal(loc=value, scale=sd_value),
                            obs=y)
Example #9
def sgt(y: jnp.ndarray, seasonality: int, future: int = 0) -> None:

    cauchy_sd = jnp.max(y) / 150

    nu = numpyro.sample("nu", dist.Uniform(2, 20))
    powx = numpyro.sample("powx", dist.Uniform(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
    offset_sigma = numpyro.sample(
        "offset_sigma",
        dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd))

    coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
    pow_trend = 1.5 * pow_trend_beta - 0.5
    pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

    level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
    s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
    init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))

    num_lim = y.shape[0]

    def transition_fn(
        carry: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], t: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:

        level, s, moving_sum = carry
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        exp_val = jnp.clip(exp_val, a_min=0)
        y_t = jnp.where(t >= num_lim, exp_val, y[t])

        moving_sum = moving_sum + y[t] - jnp.where(t >= seasonality,
                                                   y[t - seasonality], 0.0)
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality,
                            y_t - season)
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)

        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        new_s = jnp.where(t >= num_lim, s[0], new_s)
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        omega = sigma * exp_val**powx + offset_sigma
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))

        return (level, s, moving_sum), y_

    level_init = y[0]
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(transition_fn, (level_init, s_init, moving_sum),
                     jnp.arange(1, num_lim + future))

    numpyro.deterministic("y_forecast", ys)
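A forecasting sketch for the SGT model above, following the usual NumPyro fit-then-predict pattern; y_train, the seasonality of 12, and the forecast horizon of 24 are hypothetical placeholders.

from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

mcmc = MCMC(NUTS(sgt), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), y_train, seasonality=12)  # y_train: hypothetical 1-D training series
predictive = Predictive(sgt, mcmc.get_samples(), return_sites=["y_forecast"])
forecast = predictive(random.PRNGKey(1), y_train, seasonality=12, future=24)["y_forecast"]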
Example #10
def actual_model(data):
    alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
    with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
        loc = numpyro.sample(
            "loc",
            dist.TransformedDistribution(
                dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)),
        )
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
Example #11
def model_c(nu1, y1):
    Rp = numpyro.sample('Rp', dist.Uniform(0.4, 1.2))
    RV = numpyro.sample('RV', dist.Uniform(5.0, 15.0))
    MMR_CO = numpyro.sample('MMR_CO', dist.Uniform(0.0, 0.015))
    vsini = numpyro.sample('vsini', dist.Uniform(15.0, 25.0))
    g = 2478.57730044555 * Mp / Rp**2  # gravity
    u1 = 0.0
    u2 = 0.0

    # Layer-by-layer T-P model
    lnsT = 6.0
    #    lnsT = numpyro.sample('lnsT', dist.Uniform(3.0,5.0))
    sT = 10**lnsT
    lntaup = 0.5
    #    lntaup =  numpyro.sample('lntaup', dist.Uniform(0,1))
    taup = 10**lntaup
    cov = modelcov(lnParr, taup, sT)

    #    T0=numpyro.sample('T0', dist.Uniform(1000.0,1100.0))
    T0 = numpyro.sample('T0', dist.Uniform(1000, 2000))
    Tarr = numpyro.sample(
        'Tarr', dist.MultivariateNormal(loc=ONEARR,
                                        covariance_matrix=cov)) + T0
    # line computation CO
    qt_CO = vmap(mdbCO.qr_interp)(Tarr)

    def obyo(y, tag, nusd, nus, numatrix_CO, mdbCO, cdbH2H2):
        # CO
        SijM_CO = jit(vmap(SijT,
                           (0, None, None, None, 0)))(Tarr, mdbCO.logsij0,
                                                      mdbCO.dev_nu_lines,
                                                      mdbCO.elower, qt_CO)
        gammaLMP_CO = jit(vmap(gamma_exomol,
                               (0, 0, None, None)))(Parr, Tarr, mdbCO.n_Texp,
                                                    mdbCO.alpha_ref)
        gammaLMN_CO = gamma_natural(mdbCO.A)
        gammaLM_CO = gammaLMP_CO + gammaLMN_CO[None, :]

        sigmaDM_CO = jit(vmap(doppler_sigma,
                              (None, 0, None)))(mdbCO.dev_nu_lines, Tarr,
                                                molmassCO)
        xsm_CO = xsmatrix(numatrix_CO, sigmaDM_CO, gammaLM_CO, SijM_CO)
        dtaumCO = dtauM(dParr, xsm_CO, MMR_CO * ONEARR, molmassCO, g)
        # CIA
        dtaucH2H2 = dtauCIA(nus, Tarr, Parr, dParr, vmrH2, vmrH2, mmw, g,
                            cdbH2H2.nucia, cdbH2H2.tcia, cdbH2H2.logac)
        dtau = dtaumCO + dtaucH2H2
        sourcef = planck.piBarr(Tarr, nus)
        F0 = rtrun(dtau, sourcef) / norm

        Frot = response.rigidrot(nus, F0, vsini, u1, u2)
        mu = response.ipgauss_sampling(nusd, nus, Frot, beta, RV)
        numpyro.sample(tag, dist.Normal(mu, sigmain), obs=y)

    obyo(y1, 'y1', nu1, nus, numatrix_CO, mdbCO, cdbH2H2)
Example #12
def test_elbo_dynamic_support():
    x_prior = dist.TransformedDistribution(
        dist.Normal(),
        [AffineTransform(0, 2),
         SigmoidTransform(),
         AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    adam = optim.Adam(0.01)
    # set the base value of x_guide to 0.9
    x_base = 0.9
    guide = substitute(guide, base_param_map={'x': x_base})
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(random.PRNGKey(0), (), ())
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)
    x, _ = x_guide.transform_with_intermediates(x_base)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)
Example #13
def model(data):
    alpha = 1 / jnp.mean(data.astype(np.float32))
    lambda1 = numpyro.sample("lambda1", dist.Exponential(alpha))
    lambda2 = numpyro.sample("lambda2", dist.Exponential(alpha))
    tau = numpyro.sample("tau", dist.Uniform(0, 1))
    lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
    numpyro.sample("obs", dist.Poisson(lambda12), obs=data)
Example #14
def model(data):
    alpha = 1 / jnp.mean(data)
    lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
    lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
    tau = numpyro.sample('tau', dist.Uniform(0, 1))
    lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
    numpyro.sample('obs', dist.Poisson(lambda12), obs=data)
Example #15
    def init_agents(self,
                    num_steps=1000,
                    pop_size=(1e2, 1e2),
                    initial_infections=2):
        self.age = ny.sample('age', dist.Uniform(0, 100))
        self.sex = ny.sample('sex', dist.Binomial(1, .5))
        self.risk_tolerance = ny.sample('risk', dist.Beta(2, 5))
        # self.risk_factors = ny.sample('health', dist.Binomial(5, .3))
        self.hygiene = ny.sample('hygiene', dist.Beta(2, 5))
        # self.worker_type = ny.sample('worker_type', dist.Categorical((.6, .1, .2, .1)))

        self.epidemic_state = ny.sample(
            'state',
            dist.Binomial(1, initial_infections / pop_size[0] * pop_size[1]))
        self.social_radius = ny.sample('radius', dist.Binomial(10, .2))
        self.base_isolation = ny.sample('base_isolation', dist.Beta(2, 2))

        # TODO: make these depend on risk factors as well
        self.will_be_hospitalized = ny.sample(
            'hosp', dist.Binomial(1, self.params.HOSP_AGE_MAP[self.age]))
        self.will_die = ny.sample(
            'die', dist.Binomial(1, self.params.DEATH_MAP[self.age]))

        # The lengths of the infection are handled on a per agent basis via scenarios, these are just placeholders.
        self.date_infected = np.where(
            self.epidemic_state > 0, np.zeros(shape=pop_size),
            np.full(shape=pop_size, fill_value=np.inf))

        self.date_contagious = np.where(
            self.epidemic_state > 0, np.ceil(self.params.EXPOSED_PERIOD),
            np.full(shape=pop_size, fill_value=np.inf))
        self.date_symptomatic = np.full(shape=pop_size, fill_value=np.inf)
        self.date_recovered = np.full(shape=pop_size, fill_value=np.inf)
        self.date_hospitalized = np.full(shape=pop_size, fill_value=np.inf)
        self.date_died = np.full(shape=pop_size, fill_value=np.inf)
Example #16
def sample_y(dist_y, theta, y, sigma_obs=None):
    if not sigma_obs:
        if dist_y == 'gamma':
            sigma_obs = numpyro.sample('sigma_obs', dist.Exponential(1))
        else:
            sigma_obs = numpyro.sample('sigma_obs', dist.HalfNormal(1))

    if dist_y == 'student':
        numpyro.sample('y', dist.StudentT(numpyro.sample('nu_y', dist.Gamma(1, .1)), theta, sigma_obs), obs=y)
    elif dist_y == 'normal':
        numpyro.sample('y', dist.Normal(theta, sigma_obs), obs=y)
    elif dist_y == 'lognormal':
        numpyro.sample('y', dist.LogNormal(theta, sigma_obs), obs=y)
    elif dist_y == 'gamma':
        numpyro.sample('y', dist.Gamma(jnp.exp(theta), sigma_obs), obs=y)
    elif dist_y == 'gamma_raw':
        numpyro.sample('y', dist.Gamma(theta, sigma_obs), obs=y)
    elif dist_y == 'poisson':
        numpyro.sample('y', dist.Poisson(theta), obs=y)
    elif dist_y == 'exponential':
        numpyro.sample('y', dist.Exponential(jnp.exp(theta)), obs=y)
    elif dist_y == 'exponential_raw':
        numpyro.sample('y', dist.Exponential(theta), obs=y)
    elif dist_y == 'uniform':
        numpyro.sample('y', dist.Uniform(0, 1), obs=y)
    else:
        raise NotImplementedError
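A minimal sketch of how a likelihood dispatcher like sample_y might be used inside a regression model; the linear model and its priors below are hypothetical, not taken from the source.

def linear_model(X, y=None, dist_y='normal'):
    # hypothetical linear predictor feeding the dispatcher above
    a = numpyro.sample('a', dist.Normal(0.0, 1.0))
    b = numpyro.sample('b', dist.Normal(0.0, 1.0))
    theta = a + b * X
    sample_y(dist_y, theta, y)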
Example #17
def init_to_uniform(site=None, radius=2):
    """
    Initialize to a random point in the area `(-radius, radius)` of unconstrained domain.

    :param float radius: specifies the range to draw an initial point in the unconstrained domain.
    """
    if site is None:
        return partial(init_to_uniform, radius=radius)

    if (site["type"] == "sample" and not site["is_observed"]
            and not site["fn"].support.is_discrete):
        if site["value"] is not None:
            warnings.warn(
                f"init_to_uniform() skipping initialization of site '{site['name']}'"
                " which already stores a value.",
                stacklevel=find_stack_level(),
            )
            return site["value"]

        # XXX: we import here to avoid circular import
        from numpyro.infer.util import helpful_support_errors

        rng_key = site["kwargs"].get("rng_key")
        sample_shape = site["kwargs"].get("sample_shape")

        with helpful_support_errors(site):
            transform = biject_to(site["fn"].support)
        unconstrained_shape = transform.inverse_shape(site["fn"].shape())
        unconstrained_samples = dist.Uniform(-radius, radius)(
            rng_key=rng_key, sample_shape=sample_shape + unconstrained_shape)
        return transform(unconstrained_samples)
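A usage sketch for the initialization strategy above, assuming the standard NumPyro pattern of passing init_strategy to an MCMC kernel; my_model is a hypothetical model that takes no arguments.

from jax import random
from numpyro.infer import MCMC, NUTS, init_to_uniform

# draw initial unconstrained values from Uniform(-1, 1) instead of the default radius of 2
kernel = NUTS(my_model, init_strategy=init_to_uniform(radius=1))
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0))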
Example #18
def pacs_model(priors):
    pointing_matrices = [([p.amat_row, p.amat_col], p.amat_data)
                         for p in priors]
    flux_lower = np.asarray([p.prior_flux_lower for p in priors]).T
    flux_upper = np.asarray([p.prior_flux_upper for p in priors]).T

    bkg_mu = np.asarray([p.bkg[0] for p in priors]).T
    bkg_sig = np.asarray([p.bkg[1] for p in priors]).T

    with numpyro.plate('bands', len(priors)):
        sigma_conf = numpyro.sample('sigma_conf', dist.HalfCauchy(1.0, 0.5))
        bkg = numpyro.sample('bkg', dist.Normal(bkg_mu, bkg_sig))

        with numpyro.plate('nsrc', priors[0].nsrc):
            src_f = numpyro.sample('src_f',
                                   dist.Uniform(flux_lower, flux_upper))
    db_hat_psw = sp_matmul(pointing_matrices[0], src_f[:, 0][:, None],
                           priors[0].snpix).reshape(-1) + bkg[0]
    db_hat_pmw = sp_matmul(pointing_matrices[1], src_f[:, 1][:, None],
                           priors[1].snpix).reshape(-1) + bkg[1]

    sigma_tot_psw = jnp.sqrt(
        jnp.power(priors[0].snim, 2) + jnp.power(sigma_conf[0], 2))
    sigma_tot_pmw = jnp.sqrt(
        jnp.power(priors[1].snim, 2) + jnp.power(sigma_conf[1], 2))

    with numpyro.plate('psw_pixels', priors[0].sim.size):  # as ind_psw:
        numpyro.sample("obs_psw",
                       dist.Normal(db_hat_psw, sigma_tot_psw),
                       obs=priors[0].sim)
    with numpyro.plate('pmw_pixels', priors[1].sim.size):  # as ind_pmw:
        numpyro.sample("obs_pmw",
                       dist.Normal(db_hat_pmw, sigma_tot_pmw),
                       obs=priors[1].sim)
Example #19
def model(center1, center2, radius, width, enum=False):
    z = numpyro.sample("z",
                       dist.Bernoulli(0.5),
                       infer={"enumerate": "parallel"} if enum else {})
    x = numpyro.sample("x", dist.Uniform(-6.0, 6.0).expand([2]).to_event(1))
    center = jnp.stack([center1, center2])[z]
    numpyro.sample("shell", GaussianShell(center, radius, width), obs=x)
Example #20
def model_c(t1, y1, e1):
    P = numpyro.sample('P', dist.Uniform(8.0, 12.0))
    # should be a modified Jeffreys prior later
    Ksini = numpyro.sample('Ksini', dist.Exponential(0.1))
    T0 = numpyro.sample('T0', dist.Uniform(-6.0, 6.0))
    sesinw = numpyro.sample('sesinw', dist.Uniform(-1.0, 1.0))
    secosw = numpyro.sample('secosw', dist.Uniform(-1.0, 1.0))
    etmp = sesinw**2 + secosw**2
    e = jnp.where(etmp > 1.0, 1.0, etmp)
    omegaA = jnp.arctan2(sesinw, secosw)
    #    sigmajit=numpyro.sample('sigmajit', dist.Uniform(0.1,100.0))
    sigmajit = numpyro.sample('sigmajit', dist.Exponential(1.0))
    Vsys = numpyro.sample('Vsys', dist.Uniform(-10, 10.0))
    mu = rvf(t1, T0, P, e, omegaA, Ksini, Vsys)
    errall = jnp.sqrt(e1**2 + sigmajit**2)
    numpyro.sample('y1', dist.Normal(mu, errall), obs=y1)
Example #21
def init_to_uniform(site=None, radius=2):
    """
    Initialize to a random point in the area `(-radius, radius)` of unconstrained domain.

    :param float radius: specifies the range to draw an initial point in the unconstrained domain.
    """
    if site is None:
        return partial(init_to_uniform, radius=radius)

    if site['type'] == 'sample' and not site['is_observed'] and not site['fn'].is_discrete:
        rng_key = site['kwargs'].get('rng_key')
        sample_shape = site['kwargs'].get('sample_shape')
        rng_key, subkey = random.split(rng_key)

        # this is used to interpret the changes of event_shape in
        # domain and codomain spaces
        try:
            prototype_value = site['fn'].sample(subkey, sample_shape=())
        except NotImplementedError:
            # XXX: this works for ImproperUniform prior,
            # we can't use this logic for general priors
            # because some distributions such as TransformedDistribution might
            # have wrong event_shape.
            prototype_value = jnp.full(site['fn'].shape(), jnp.nan)

        transform = biject_to(site['fn'].support)
        unconstrained_shape = jnp.shape(transform.inv(prototype_value))
        unconstrained_samples = dist.Uniform(-radius, radius).sample(
            rng_key, sample_shape=sample_shape + unconstrained_shape)
        return transform(unconstrained_samples)
Example #22
def test_elbo_dynamic_support():
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {'_auto_latent': x_unconstrained})(*args, **kwargs)

    adam = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    guide_log_prob = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_unconstrained).sum()
    transform = transforms.biject_to(constraints.interval(0, 5))
    x = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #23
def model(X, y=None):
    ndims = np.shape(X)[-1]
    ws = numpyro.sample('betas', dist.Normal(0.0,10.0*np.ones(ndims)))
    b = numpyro.sample('b', dist.Normal(0.0, 10.0))
    sigma = numpyro.sample('sigma', dist.Uniform(0.0, 10.0))
    f = numpyro.sample('f', dist.Normal(0.0, 2.5))
    mu = f * (X @ ws + b)
    return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
Example #24
def partially_pooled(at_bats: jnp.ndarray, hits: Optional[jnp.ndarray] = None) -> None:

    m = numpyro.sample("m", dist.Uniform(0, 1))
    kappa = numpyro.sample("kappa", dist.Pareto(1, 1.5))
    num_players = at_bats.shape[0]
    with numpyro.plate("num_players", num_players):
        phi = numpyro.sample("phi", dist.Beta(m * kappa, (1 - m) * kappa))
        numpyro.sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)
Example #25
def model(data):
    xc = numpyro.sample('xc', dist.Normal(0, 10))
    yc = numpyro.sample('yc', dist.Normal(0, 10))
    w = numpyro.sample('w', dist.LogNormal(0, 10))
    h = numpyro.sample('h', dist.LogNormal(0, 10))
    phi = numpyro.sample('phi', dist.Uniform(0, np.pi))
    px, py = forward(xc, yc, w, h, phi)
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    return numpyro.sample('obs', dist.Normal(np.vstack([px, py]), sigma), obs=data)
Example #26
def comp_Tarr(okey):
    okey, key = random.split(okey)
    lnsT = numpyro.sample('lnsT', dist.Uniform(3.0, 5.0), rng_key=key)
#    lnsT=4.0
    sT = 10**lnsT
    okey, key = random.split(okey)
    lntaup = numpyro.sample('lntaup', dist.Uniform(0, 1), rng_key=key)
#    lntaup=0.5
    taup = 10**lntaup
    cov = modelcov(lnParr, taup, sT)

    okey, key = random.split(okey)
    T0 = numpyro.sample('T0', dist.Uniform(800, 1500), rng_key=key)
    okey, key = random.split(okey)
    Tarr = numpyro.sample('Tarr', dist.MultivariateNormal(
        loc=ONEARR, covariance_matrix=cov), rng_key=key)+T0

    # lnT0=3.0 #1000K
    #lnTarr=numpyro.sample("Tarr", dist.MultivariateNormal(loc=lnT0*ONEARR, covariance_matrix=cov),rng_key=key)
    # Tarr=10**lnTarr
    return Tarr
Example #27
def fully_pooled(at_bats, hits=None):
    r"""
    Number of hits in $K$ at bats for each player has a Binomial
    distribution with a common probability of success, $\phi$.

    :param (np.DeviceArray) at_bats: Number of at bats for each player.
    :param (np.DeviceArray) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    phi_prior = dist.Uniform(np.array([0.]), np.array([1.]))
    phi = sample("phi", phi_prior)
    return sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)
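A hedged sketch of fitting the fully pooled model above with NUTS; the at_bats and hits arrays are hypothetical.

import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS

at_bats = jnp.array([100, 150, 200])  # hypothetical at-bat counts
hits = jnp.array([30, 45, 50])        # hypothetical hit counts
mcmc = MCMC(NUTS(fully_pooled), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), at_bats, hits)
mcmc.print_summary()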
Example #28
def sample_data2():
    N = 100  # number of individuals
    # sim total height of each
    height = dist.Normal(10, 2).sample(random.PRNGKey(0), (N,))
    # leg as proportion of height
    leg_prop = dist.Uniform(0.4, 0.5).sample(random.PRNGKey(1), (N,))
    # sim left leg as proportion + error
    leg_left = leg_prop * height + dist.Normal(0, 0.02).sample(random.PRNGKey(2), (N,))
    # sim right leg as proportion + error
    leg_right = leg_prop * height + dist.Normal(0, 0.02).sample(random.PRNGKey(3), (N,))
    # combine into data frame
    d = pd.DataFrame({"height": height, "leg_left": leg_left, "leg_right": leg_right})
    return d
Example #29
def model_4(capture_history, sex):
    N, T = capture_history.shape
    # survival probabilities for males/females
    phi_male = numpyro.sample("phi_male", dist.Uniform(0.0, 1.0))
    phi_female = numpyro.sample("phi_female", dist.Uniform(0.0, 1.0))
    # we construct a N-dimensional vector that contains the appropriate
    # phi for each individual given its sex (female = 0, male = 1)
    phi = sex * phi_male + (1.0 - sex) * phi_female
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample(
                    "z",
                    dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                    infer={"enumerate": "parallel"},
                )
                mu_y_t = rho * z
                numpyro.sample(
                    "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                )

        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(
        transition_fn,
        (first_capture_mask, z),
        jnp.swapaxes(capture_history[:, 1:], 0, 1),
    )
Example #30
    def sample(self):
        rng_key, rng_key_sample, rng_key_accept = split(
            self.nmc_status.rng_key, 3)
        params = self.nmc_status.params

        for site in params.keys():
            # Collect accepted trace
            for i in range(len(params[site])):
                self.acc_trace[site + str(i)].append(params[site][i])

            tr_current = trace(substitute(self.model, params)).get_trace(
                *self.model_args, **self.model_kwargs)
            ll_current = self.nmc_status.log_likelihood

            val_current = tr_current[site]["value"]
            dist_curr = tr_current[site]["fn"]

            def log_den_fun(var):
                return partial(log_density, self.model, self.model_args,
                               self.model_kwargs)(var)[0]

            val_proposal, dist_proposal = self.proposal(
                site, log_den_fun, self.get_params(tr_current), dist_curr,
                rng_key_sample)

            tr_proposal = self.retrace(site, tr_current, dist_proposal,
                                       val_proposal, self.model_args,
                                       self.model_kwargs)
            ll_proposal = log_density(self.model, self.model_args,
                                      self.model_kwargs,
                                      self.get_params(tr_proposal))[0]

            ll_proposal_val = dist_proposal.log_prob(val_current).sum()
            ll_current_val = dist_curr.log_prob(val_proposal).sum()

            hastings_ratio = (ll_proposal + ll_proposal_val) - \
                (ll_current + ll_current_val)

            accept_prob = np.minimum(1, np.exp(hastings_ratio))
            u = sample("u", dist.Uniform(0, 1), rng_key=rng_key_accept)

            if u <= accept_prob:
                params, ll_current = self.get_params(tr_proposal), ll_proposal
            else:
                params, ll_current = self.get_params(tr_current), ll_current

        iter = self.nmc_status.i + 1
        mean_accept_prob = self.nmc_status.accept_prob + \
            (accept_prob - self.nmc_status.accept_prob) / iter

        return NMC_STATUS(iter, params, ll_current, mean_accept_prob, rng_key)