Example #1
def test_haiku_state_dropout_smoke(dropout, batchnorm):
    import haiku as hk

    def fn(x):
        if dropout:
            x = hk.dropout(hk.next_rng_key(), 0.5, x)
        if batchnorm:
            x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)(
                x, is_training=True
            )
        return x

    def model():
        transform = hk.transform_with_state if batchnorm else hk.transform
        nn = haiku_module("nn", transform(fn), apply_rng=dropout, input_shape=(4, 3))
        x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
        if dropout:
            y = nn(numpyro.prng_key(), x)
        else:
            y = nn(x)
        numpyro.deterministic("y", y)

    with handlers.trace(model) as tr, handlers.seed(rng_seed=0):
        model()

    if batchnorm:
        assert set(tr.keys()) == {"nn$params", "nn$state", "x", "y"}
        assert tr["nn$state"]["type"] == "mutable"
    else:
        assert set(tr.keys()) == {"nn$params", "x", "y"}
Example #2
 def single_prediction(val):
     rng_key, samples = val
     model_trace = trace(seed(substitute(model, samples),
                              rng_key)).get_trace(*model_args,
                                                  **model_kwargs)
     if return_sites is not None:
         if return_sites == '':
             sites = {
                 k
                 for k, site in model_trace.items()
                 if site['type'] != 'plate'
             }
         else:
             sites = return_sites
     else:
         sites = {
             k
             for k, site in model_trace.items()
             if (site['type'] == 'sample' and k not in samples) or (
                 site['type'] == 'deterministic')
         }
     return {
         name: site['value']
         for name, site in model_trace.items() if name in sites
     }
Example #3
        def log_likelihood(params_flat, subsample_indices=None):
            if subsample_indices is None:
                subsample_indices = {
                    k: jnp.arange(v[0])
                    for k, v in subsample_plate_sizes.items()
                }
            params = unravel_fn(params_flat)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                params = {
                    name: biject_to(prototype_trace[name]["fn"].support)(value)
                    for name, value in params.items()
                }
                with block(), trace() as tr, substitute(
                        data=subsample_indices), substitute(data=params):
                    model(*model_args, **model_kwargs)

            log_lik = {}
            for site in tr.values():
                if site["type"] == "sample" and site["is_observed"]:
                    for frame in site["cond_indep_stack"]:
                        if frame.name in log_lik:
                            log_lik[frame.name] += _sum_all_except_at_dim(
                                site["fn"].log_prob(site["value"]), frame.dim)
                        else:
                            log_lik[frame.name] = _sum_all_except_at_dim(
                                site["fn"].log_prob(site["value"]), frame.dim)
            return log_lik
Example #4
def constrain_fn(model,
                 transforms,
                 model_args,
                 model_kwargs,
                 params,
                 return_deterministic=False):
    """
    (EXPERIMENTAL INTERFACE) Gets value at each latent site in `model` given
    unconstrained parameters `params`. The `transforms` is used to transform these
    unconstrained parameters to base values of the corresponding priors in `model`.
    If a prior is a transformed distribution, the corresponding base value lies in
    the support of base distribution. Otherwise, the base value lies in the support
    of the distribution.

    :param model: a callable containing NumPyro primitives.
    :param dict transforms: dictionary of transforms keyed by names. Names in
        `transforms` and `params` should align.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of unconstrained values keyed by site
        names.
    :param bool return_deterministic: whether to return the value of `deterministic`
        sites from the model. Defaults to `False`.
    :return: `dict` of transformed params.
    """
    params_constrained = transform_fn(transforms, params)
    substituted_model = substitute(model, base_param_map=params_constrained)
    model_trace = trace(substituted_model).get_trace(*model_args,
                                                     **model_kwargs)
    return {
        k: v['value']
        for k, v in model_trace.items() if (k in params) or (
            return_deterministic and v['type'] == 'deterministic')
    }
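
For orientation, a minimal usage sketch of the helper above. It assumes `constrain_fn` and the `transform_fn`/`substitute`/`trace` utilities it relies on (an older NumPyro API with `base_param_map`) are importable; `toy_model` and the site name `rate` are illustrative only.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.transforms import biject_to

def toy_model():
    # a single positive-valued latent site
    numpyro.sample("rate", dist.Gamma(2.0, 2.0))

# bijection from the unconstrained real line onto the prior's positive support
transforms = {"rate": biject_to(dist.Gamma(2.0, 2.0).support)}
constrained = constrain_fn(toy_model, transforms, (), {}, {"rate": jnp.array(0.0)})
# constrained["rate"] lies in the support of the Gamma prior (for an
# exponential-style bijection, roughly exp(0.0) == 1.0)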
Example #5
def get_model_transforms(rng_key, model, model_args=(), model_kwargs=None):
    model_kwargs = {} if model_kwargs is None else model_kwargs
    seeded_model = seed(model, rng_key if rng_key.ndim == 1 else rng_key[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of dynamic constraints
    # or deterministic sites
    replay_model = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                replay_model = True
            else:
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                inv_transforms[k] = transform.parts[0]
                replay_model = True
            else:
                inv_transforms[k] = transform
        elif v['type'] == 'deterministic':
            replay_model = True
    return inv_transforms, replay_model
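
A sketch of how the helper above might be called, assuming it and its imports (`biject_to`, `trace`, `seed`, ...) are in scope; `toy_model` is hypothetical.

import jax
import numpyro
import numpyro.distributions as dist

def toy_model():
    numpyro.sample("loc", dist.Normal(0.0, 1.0))    # unconstrained site
    numpyro.sample("rate", dist.Gamma(1.0, 1.0))    # positive-valued site

inv_transforms, replay_model = get_model_transforms(jax.random.PRNGKey(0), toy_model)
# inv_transforms["loc"] is an identity-style transform (real support),
# inv_transforms["rate"] maps reals onto the positive half-line, and
# replay_model stays False: no deterministic sites, params, or transformed priors.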
Example #6
def constrain_fn(model, model_args, model_kwargs, transforms, params):
    """
    Gets value at each latent site in `model` given unconstrained parameters `params`.
    The `transforms` is used to transform these unconstrained parameters to base values
    of the corresponding priors in `model`. If a prior is a transformed distribution,
    the corresponding base value lies in the support of base distribution. Otherwise,
    the base value lies in the support of the distribution.

    :param model: a callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict transforms: dictionary of transforms keyed by names. Names in
        `transforms` and `params` should align.
    :param dict params: dictionary of unconstrained values keyed by site
        names.
    :return: `dict` of transformed params.
    """
    params_constrained = transform_fn(transforms, params)
    substituted_model = substitute(model, base_param_map=params_constrained)
    model_trace = trace(substituted_model).get_trace(*model_args,
                                                     **model_kwargs)
    return {
        k: model_trace[k]['value']
        for k in params if k in model_trace
    }
Example #7
def _get_model_transforms(model, model_args=(), model_kwargs=None):
    model_kwargs = {} if model_kwargs is None else model_kwargs
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of deterministic sites
    replay_model = False
    has_enumerate_support = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['fn'].is_discrete:
                has_enumerate_support = True
                if not v['fn'].has_enumerate_support:
                    raise RuntimeError("MCMC only supports continuous sites or discrete sites "
                                       f"with enumerate support, but got {type(v['fn']).__name__}.")
            else:
                support = v['fn'].support
                inv_transforms[k] = biject_to(support)
                # XXX: the following code filters out most situations with dynamic supports
                args = ()
                if isinstance(support, constraints._GreaterThan):
                    args = ('lower_bound',)
                elif isinstance(support, constraints._Interval):
                    args = ('lower_bound', 'upper_bound')
                for arg in args:
                    if not isinstance(getattr(support, arg), (int, float)):
                        replay_model = True
        elif v['type'] == 'deterministic':
            replay_model = True
    return inv_transforms, replay_model, has_enumerate_support, model_trace
Example #8
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, param_map=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.array(0.)
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace
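
A usage sketch for the function above, assuming `log_density` and the helpers it uses (`substitute`, `trace`, `is_identically_one`, `jnp`) are in scope; `toy_model` is illustrative.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def toy_model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", data.shape[0]):
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

data = jnp.array([0.1, 0.2, 0.3])
# joint density evaluated at loc = 0.5; "obs" is clamped to the observed data
log_joint, model_trace = log_density(toy_model, (data,), {}, {"loc": jnp.array(0.5)})
# log_joint == log N(0.5 | 0, 1) + sum_i log N(data_i | 0.5, 1)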
Example #9
File: modelling.py  Project: byzhang/d3p
def sample_prior_predictive(rng_key, model, model_args,
        substitutes=None, with_intermediates=False, **kwargs):
    """ Samples once from the prior predictive distribution.

    Individual sample sites, as designated by `sample`, can be frozen to
    pre-determined values given in `substitutes`. In that case, values for these
    sites are not actually sampled but the value provided in `substitutes` is
    returned as the sample. This facilitates conditional sampling.

    Note that if the model function is written in such a way that it returns, e.g.,
    multiple observations from a single prior draw, the same is true for the
    values returned by this function.

    :param rng_key: Jax PRNG key
    :param model: Function representing the model using numpyro distributions
        and the `sample` primitive
    :param model_args: Arguments to the model function
    :param substitutes: An optional dictionary of frozen substitutes for
        sample sites.
    :param with_intermediates: If True, intermediate(/latent) samples from
        sample site distributions are included in the result.
    :param **kwargs: Keyword arguments passed to the model function.
    :return: Dictionary of sampled values associated with the names given
        via `sample()` in the model. If with_intermediates is True,
        dictionary values are tuples where the first element is the final
        sample values and the second element is a list of intermediate values.
    """
    if substitutes is None: substitutes = dict()
    model = seed(substitute(model, data=substitutes), rng_key)
    t = trace(model).get_trace(*model_args, **kwargs)
    return get_samples_from_trace(t, with_intermediates)
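
A usage sketch under the snippet's own assumptions (the d3p helpers `seed`, `substitute`, `trace`, and `get_samples_from_trace` are importable); the toy model and frozen value are illustrative.

import jax
import numpyro
import numpyro.distributions as dist

def toy_model(num_obs):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", num_obs):
        numpyro.sample("obs", dist.Normal(loc, 1.0))

# one prior predictive draw with the latent "loc" frozen at 0.0
samples = sample_prior_predictive(jax.random.PRNGKey(0), toy_model, (5,),
                                  substitutes={"loc": 0.0})
# samples["loc"] is the frozen value; samples["obs"] has shape (5,)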
Example #10
    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model, subkey), substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                if v['type'] == 'sample' and not v['is_observed'] and not v['fn'].is_discrete:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            params = transform_fn(inv_transforms,
                                  {k: v for k, v in constrained_values.items()},
                                  invert=True)
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(subkey, jnp.shape(v), minval=-radius, maxval=radius)
                    key, subkey = random.split(key)

        potential_fn = partial(potential_energy, model, model_args, model_kwargs, enum=enum)
        pe, z_grad = value_and_grad(potential_fn)(params)
        z_grad_flat = ravel_pytree(z_grad)[0]
        is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        return i + 1, key, (params, pe, z_grad), is_valid
Example #11
def test_counterfactual_query(intervene, observe, flip):
    # x -> y -> z -> w

    sites = ["x", "y", "z", "w"]
    observations = {"x": 1., "y": None, "z": 1., "w": 1.}
    interventions = {"x": None, "y": 0., "z": 2., "w": 1.}

    def model():
        with handlers.seed(rng_seed=0):
            x = numpyro.sample("x", dist.Normal(0, 1))
            y = numpyro.sample("y", dist.Normal(x, 1))
            z = numpyro.sample("z", dist.Normal(y, 1))
            w = numpyro.sample("w", dist.Normal(z, 1))
            return dict(x=x, y=y, z=z, w=w)

    if not flip:
        if intervene:
            model = handlers.do(model, data=interventions)
        if observe:
            model = handlers.condition(model, data=observations)
    elif flip and intervene and observe:
        model = handlers.do(handlers.condition(model, data=observations),
                            data=interventions)

    with handlers.trace() as tr:
        actual_values = model()
    for name in sites:
        # case 1: purely observational query like handlers.condition
        if not intervene and observe:
            if observations[name] is not None:
                assert tr[name]['is_observed']
                assert_allclose(observations[name], actual_values[name])
                assert_allclose(observations[name], tr[name]['value'])
            if interventions[name] != observations[name]:
                if interventions[name] is not None:
                    assert_raises(AssertionError, assert_allclose,
                                  interventions[name], actual_values[name])
        # case 2: purely interventional query like old handlers.do
        elif intervene and not observe:
            assert not tr[name]['is_observed']
            if interventions[name] is not None:
                assert_allclose(interventions[name], actual_values[name])
            if observations[name] is not None:
                assert_raises(AssertionError, assert_allclose,
                              observations[name], tr[name]['value'])
            if interventions[name] is not None:
                assert_raises(AssertionError, assert_allclose,
                              interventions[name], tr[name]['value'])
        # case 3: counterfactual query mixing intervention and observation
        elif intervene and observe:
            if observations[name] is not None:
                assert tr[name]['is_observed']
                assert_allclose(observations[name], tr[name]['value'])
            if interventions[name] is not None:
                assert_allclose(interventions[name], actual_values[name])
            if interventions[name] != observations[name]:
                if interventions[name] is not None:
                    assert_raises(AssertionError, assert_allclose,
                                  interventions[name], tr[name]['value'])
Example #12
def test_haiku_module():
    X = np.arange(100)
    Y = 2 * X + 2

    with handlers.trace() as haiku_tr, handlers.seed(rng_seed=1):
        haiku_model(X, Y)
    assert haiku_tr["nn$params"]['value']['linear']['w'].shape == (100, 100)
    assert haiku_tr["nn$params"]['value']['linear']['b'].shape == (100,)
Example #13
def predict(model, rng_key, samples, X, D_H, sigma_obs=None):
    model = handlers.substitute(handlers.seed(model, rng_key), samples)
    # note that Y will be sampled in the model because we pass Y=None here
    model_trace = handlers.trace(model).get_trace(X=X,
                                                  Y=None,
                                                  D_H=D_H,
                                                  sigma_obs=sigma_obs)
    return model_trace['Y']['value']
Example #14
 def test_load_numpyro_model_simple_working_model(self):
     """ only verifies that no errors occur and all returned functions are not None """
     model, guide, preprocess, postprocess = load_custom_numpyro_model(
         './tests/models/simple_gauss_model.py', Namespace(), [],
         pd.DataFrame())
     self.assertIsNotNone(model)
     self.assertIsNotNone(guide)
     self.assertIsNotNone(preprocess)
     self.assertIsNotNone(postprocess)
     z = np.ones((10, 2))
     samples_with_obs = trace(seed(model, jax.random.PRNGKey(0))).get_trace(
         z, num_obs_total=10)
     self.assertTrue(np.allclose(samples_with_obs['x']['value'], z))
     samples_no_obs = trace(seed(
         model, jax.random.PRNGKey(0))).get_trace(num_obs_total=10)
     self.assertEqual(samples_no_obs['x']['value'].shape, (1, 2))
     self.assertFalse(np.allclose(samples_no_obs['x']['value'], z))
Example #15
def test_distribution_1(temperature):
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    num_particles = 10000
    data = np.array([1.0, 2.0, 3.0])

    @config_enumerate
    def model(z=None):
        p = numpyro.param("p", np.array([0.75, 0.25]))
        iz = numpyro.sample("z", dist.Categorical(p), obs=z)
        z = jnp.array([0.0, 1.0])[iz]
        logger.info("z.shape = {}".format(z.shape))
        with numpyro.plate("data", 3):
            numpyro.sample("x", dist.Normal(z, 1.0), obs=data)

    first_available_dim = -3
    vectorized_model = (
        model if temperature == 0 else vectorize_model(model, num_particles, dim=-2)
    )
    sampled_model = infer_discrete(
        vectorized_model, first_available_dim, temperature, rng_key=random.PRNGKey(1)
    )
    sampled_trace = handlers.trace(sampled_model).get_trace()
    conditioned_traces = {
        z: handlers.trace(model).get_trace(z=np.array(z)) for z in [0, 1]
    }

    # Check posterior over z.
    actual_z_mean = sampled_trace["z"]["value"].astype(float).mean()
    if temperature:
        expected_z_mean = 1 / (
            1
            + np.exp(
                log_prob_sum(conditioned_traces[0])
                - log_prob_sum(conditioned_traces[1])
            )
        )
    else:
        expected_z_mean = (
            log_prob_sum(conditioned_traces[1]) > log_prob_sum(conditioned_traces[0])
        ).astype(float)
        expected_max = max(log_prob_sum(t) for t in conditioned_traces.values())
        actual_max = log_prob_sum(sampled_trace)
        assert_allclose(expected_max, actual_max, atol=1e-5)
    assert_allclose(actual_z_mean, expected_z_mean, atol=1e-2 if temperature else 1e-5)
Example #16
 def get_log_probs(sample, seed=0):
     with handlers.trace() as tr, handlers.seed(
             rng_seed=seed), handlers.substitute(data=sample):
         model(*model_args, **model_kwargs)
     return {
         name: site["fn"].log_prob(site["value"])
         for name, site in tr.items() if site["type"] == "sample"
     }
Example #17
def log_density(model, model_args, model_kwargs, params):
    model = substitute(model, params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = 0.
    for site in model_trace.values():
        if site['type'] == 'sample':
            log_joint = log_joint + np.sum(site['fn'].log_prob(site['value']))
    return log_joint, model_trace
Example #18
def test_flax_module():
    X = np.arange(100)
    Y = 2 * X + 2

    with handlers.trace() as flax_tr, handlers.seed(rng_seed=1):
        flax_model(X, Y)
    assert flax_tr["nn$params"]['value']['kernel'].shape == (100, 100)
    assert flax_tr["nn$params"]['value']['bias'].shape == (100,)
Example #19
def log_likelihood(
    rng_key: np.ndarray, params: np.ndarray, model: Callable, *args: Any, **kwargs: Any
) -> np.ndarray:

    model = handlers.condition(model, params)
    model_trace = handlers.trace(model).get_trace(*args, **kwargs)
    obs_node = model_trace["obs"]
    return obs_node["fn"].log_prob(obs_node["value"])
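
A small usage sketch; `toy_model` is made up, and note that the function above assumes the model exposes exactly one observed site named "obs".

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def toy_model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

# pointwise log-likelihood of two observations given loc = 0.0
lp = log_likelihood(jax.random.PRNGKey(0), {"loc": jnp.array(0.0)}, toy_model,
                    jnp.array([0.5, -0.5]))
# lp has the same shape as the observations, here (2,)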
Example #20
 def single_loglik(samples):
     model_trace = trace(substitute(model,
                                    samples)).get_trace(*args, **kwargs)
     return {
         name: site['fn'].log_prob(site['value'])
         for name, site in model_trace.items()
         if site['type'] == 'sample' and site['is_observed']
     }
Example #21
File: modelling.py  Project: byzhang/d3p
def sample_posterior_predictive(rng_key, model, model_args, guide, guide_args,
        params, with_intermediates=False, **kwargs):
    """ Samples once from the posterior predictive distribution.

    Note that if the model function is written in such a way that it returns, e.g.,
    multiple observations from a single posterior draw, the same is true for the
    values returned by this function.

    :param rng_key: Jax PRNG key
    :param model: Function representing the model using numpyro distributions
        and the `sample` primitive
    :param model_args: Arguments to the model function
    :param guide: Function representing the variational distribution (the guide)
        using numpyro distributions as well as the `sample` and `param` primitives
    :param guide_args: Arguments to the guide function
    :param params: A dictionary providing values for the parameters
        designated by call to `param` in the guide
    :param with_intermediates: If True, intermediate(/latent) samples from
        sample site distributions are included in the result.
    :param **kwargs: Keyword arguments passed to the model and guide functions.
    :return: Dictionary of sampled values associated with the names given
        via `sample()` in the model. If with_intermediates is True,
        dictionary values are tuples where the first element is the final
        sample values and the second element is a list of intermediate values.
    """
    model_rng_key, guide_rng_key = jax.random.split(rng_key)

    guide = seed(substitute(guide, data=params), guide_rng_key)
    guide_samples = get_samples_from_trace(
        trace(guide).get_trace(*guide_args, **kwargs), with_intermediates
    )

    model_params = dict(**params)
    if with_intermediates:
        model_params.update({k: v[0] for k, v in guide_samples.items()})
    else:
        model_params.update({k: v for k, v in guide_samples.items()})

    model = seed(substitute(model, data=model_params), model_rng_key)
    model_samples = get_samples_from_trace(
        trace(model).get_trace(*model_args, **kwargs), with_intermediates
    )

    guide_samples.update(model_samples)
    return guide_samples
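
A usage sketch with a hypothetical model/guide pair, again assuming the d3p helpers used by the snippet are importable; the parameter value stands in for the output of an SVI run.

import jax
import numpyro
import numpyro.distributions as dist

def toy_model(num_obs):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", num_obs):
        numpyro.sample("obs", dist.Normal(loc, 1.0))

def toy_guide(num_obs):
    loc_mean = numpyro.param("loc_mean", 0.0)
    numpyro.sample("loc", dist.Delta(loc_mean))

params = {"loc_mean": 1.5}  # e.g. taken from fitted SVI params
samples = sample_posterior_predictive(jax.random.PRNGKey(0), toy_model, (5,),
                                      toy_guide, (5,), params)
# samples["loc"] comes from the guide; samples["obs"] is drawn from the model
# conditioned on that value of "loc"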
Example #22
def _get_log_probs(model, model_args, model_kwargs, sample):
    # Note: We use seed 0 for parameter initialization.
    with handlers.trace() as tr, handlers.seed(
            rng_seed=0), handlers.substitute(data=sample):
        model(*model_args, **model_kwargs)
    return {
        name: site["fn"].log_prob(site["value"])
        for name, site in tr.items() if site["type"] == "sample"
    }
Example #23
def test_scope_frames():
    def model(y):
        mu = numpyro.sample("mu", dist.Normal())
        sigma = numpyro.sample("sigma", dist.HalfNormal())

        with numpyro.plate("plate1", y.shape[0]):
            numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

    scope_prefix = "scope"
    scoped_model = handlers.scope(model, prefix=scope_prefix)

    obs = np.random.normal(size=(10,))

    trace = handlers.trace(handlers.seed(model, 0)).get_trace(obs)
    scoped_trace = handlers.trace(handlers.seed(scoped_model, 0)).get_trace(obs)

    assert trace["y"]["cond_indep_stack"][0].name in trace
    assert scoped_trace[f"{scope_prefix}/y"]["cond_indep_stack"][0].name in scoped_trace
Example #24
def predict(model, at_bats, hits, z, rng, player_names, train=True):
    header = model.__name__ + (' - TRAIN' if train else ' - TEST')
    model = substitute(seed(model, rng), z)
    model_trace = trace(model).get_trace(at_bats)
    predictions = model_trace['obs']['value']
    print_results('=' * 30 + header + '=' * 30, predictions, player_names,
                  at_bats, hits)
    if not train:
        model = substitute(model, z)
        model_trace = trace(model).get_trace(at_bats, hits)
        log_joint = 0.
        for site in model_trace.values():
            site_log_prob = site['fn'].log_prob(site['value'])
            log_joint = log_joint + np.sum(
                site_log_prob.reshape(site_log_prob.shape[:1] + (-1, )), -1)
        log_post_density = logsumexp(log_joint) - np.log(
            np.shape(log_joint)[0])
        print('\nPosterior log density: {:.2f}\n'.format(log_post_density))
Example #25
def predict(model, rng_key, samples, p, t, D_H, data_type):
    model = handlers.substitute(handlers.seed(model, rng_key), samples)
    model_trace = handlers.trace(model).get_trace(p=p,
                                                  t=t,
                                                  Y=None,
                                                  F=None,
                                                  D_H=D_H,
                                                  data_type=data_type)
    return model_trace['Y']['value']
Example #26
 def single_prediction(rng, samples):
     model_trace = trace(seed(condition(model, samples),
                              rng)).get_trace(*args, **kwargs)
     sites = (model_trace.keys() - samples.keys()
              if return_sites is None else return_sites)
     return {
         name: site['value']
         for name, site in model_trace.items() if name in sites
     }
Example #27
 def _setup_prototype(self, *args, **kwargs):
     # run the model so we can inspect its structure
     rng_key = numpyro.sample("_{}_rng_key_setup".format(self.prefix),
                              dist.PRNGIdentity())
     model = handlers.seed(self.model, rng_key)
     self.prototype_trace = handlers.block(handlers.trace(model).get_trace)(
         *args, **kwargs)
     self._args = args
     self._kwargs = kwargs
Example #28
File: util.py  Project: jatentaki/numpyro
 def single_loglik(samples):
     substituted_model = (substitute(model, samples)
                          if isinstance(samples, dict) else model)
     model_trace = trace(substituted_model).get_trace(*args, **kwargs)
     return {
         name: site["fn"].log_prob(site["value"])
         for name, site in model_trace.items()
         if site["type"] == "sample" and site["is_observed"]
     }
Example #29
def initialize_model(rng, model, model_args, model_kwargs):
    model = seed(model, rng)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    init_params = {
        k: v['value']
        for k, v in model_trace.items()
        if v['type'] == 'sample' and not v['is_observed']
    }
    return init_params, potential_energy(model, model_args, model_kwargs)
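
A brief usage sketch, assuming `toy_model` is a made-up example and that the `potential_energy` helper referenced above comes from the same module as this simplified `initialize_model`.

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def toy_model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

init_params, pe_fn = initialize_model(jax.random.PRNGKey(0), toy_model,
                                      (jnp.array([0.1, 0.2]),), {})
# init_params holds a prior draw for the single latent site "loc";
# pe_fn is whatever the assumed potential_energy helper returns for this model.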
Example #30
def forecast(future, rng_key, sample, y, n_obs):
    Z_exp, obs_last = handlers.substitute(ar_k, sample)(n_obs, y)
    forecast_model = handlers.seed(_forecast, rng_key)
    forecast_trace = handlers.trace(forecast_model).get_trace(
        future, sample, Z_exp, n_obs)
    results = [
        np.clip(forecast_trace["yf[{}]".format(t)]["value"], a_min=1e-30, a_max=None)
        for t in range(future)
    ]
    return np.stack(results, axis=0)