コード例 #1
0
ファイル: elbo.py プロジェクト: xidulu/numpyro
        def single_particle_elbo(rng_key):
            """Estimate the ELBO from a single particle.

            The guide is seeded and evaluated first; the model is then replayed
            against the guide's trace so both log densities refer to the same
            latent draw.

            :param rng_key: PRNG key, split into independent model/guide seeds.
            :return: model log density minus guide log density for this draw.
            """
            model_key, guide_key = random.split(rng_key)
            guide_lp, guide_tr = log_density(seed(guide, guide_key), args, kwargs, param_map)
            # Replay the seeded model against the guide trace.
            replayed_model = replay(seed(model, model_key), guide_tr)
            model_lp, _ = log_density(replayed_model, args, kwargs, param_map)
            # log p(z) - log q(z)
            return model_lp - guide_lp
コード例 #2
0
ファイル: test_handlers.py プロジェクト: mhashemi0873/numpyro
def test_subsample_replay():
    """A subsample recorded in a guide trace is reused when replayed."""
    data = jnp.arange(100.)
    subsample_size = 7

    # Record a plate that subsamples `subsample_size` indices of `data`.
    with handlers.trace() as guide_trace, handlers.seed(rng_seed=0):
        with numpyro.plate("a", len(data), subsample_size=subsample_size):
            pass

    # Under replay (and a different seed) the plate reuses the recorded
    # subsample, even though it does not declare a subsample_size itself.
    with handlers.seed(rng_seed=1), handlers.replay(guide_trace=guide_trace), \
            numpyro.plate("a", len(data)):
        subsample_data = numpyro.subsample(data, event_dim=0)
        assert subsample_data.shape == (subsample_size,)
コード例 #3
0
ファイル: test_handlers.py プロジェクト: uiuc-arc/numpyro
def test_prng_key_with_vmap():
    """numpyro.prng_key() behaves consistently under vmap and seed handlers."""
    def model(x=None):
        return numpyro.prng_key()

    inputs = jnp.arange(10)
    # Seed applied outside vmap: every mapped call sees the same key.
    outer_seeded = handlers.seed(vmap(model), 0)(inputs)
    assert (outer_seeded == outer_seeded[0]).all()
    # Seed applied inside vmap with a fixed seed gives the same keys.
    inner_seeded = vmap(handlers.seed(model, 0))(inputs)
    assert (outer_seeded == inner_seeded).all()
    # Seeding with the mapped index gives a different key for each i != 0.
    per_index = vmap(lambda i: handlers.seed(model, i)())(inputs)
    key_seed0 = handlers.seed(model, 0)()
    assert (per_index[1:] != key_seed0).all()
    assert (per_index[0] == key_seed0).all()
コード例 #4
0
ファイル: svi.py プロジェクト: dirmeier/numpyro
    def init(self, rng_key, *args, **kwargs):
        """
        Gets the initial SVI state.

        The guide is run once, and the model is run replayed against the guide
        trace. ``param`` sites from both traces are mapped to unconstrained space
        and used to initialize the optimizer. ``mutable`` sites become the
        initial mutable state. A warning is emitted for each unobserved discrete
        sample site when the loss cannot infer discrete variables. As a side
        effect, this method sets ``self.constrain_fn``.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        """
        # The first key is carried in the returned state; the other two seed the
        # model and guide for this initialization run only.
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
        # The model is replayed against the guide trace.
        model_trace = trace(replay(model_init, guide_trace)).get_trace(
            *args, **kwargs, **self.static_kwargs
        )
        params = {}
        inv_transforms = {}
        mutable_state = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site["type"] == "param":
                # NB: pop() removes the constraint from the site's kwargs (it
                # mutates the trace entry); a missing constraint means real.
                constraint = site["kwargs"].pop("constraint", constraints.real)
                with helpful_support_errors(site):
                    transform = biject_to(constraint)
                inv_transforms[site["name"]] = transform
                # The optimizer state holds unconstrained values.
                params[site["name"]] = transform.inv(site["value"])
            elif site["type"] == "mutable":
                mutable_state[site["name"]] = site["value"]
            elif (
                site["type"] == "sample"
                and (not site["is_observed"])
                and site["fn"].support.is_discrete
                and not self.loss.can_infer_discrete
            ):
                # Warns once per offending site (the loop does not break).
                s_name = type(self.loss).__name__
                warnings.warn(
                    f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
                )

        # Use None rather than an empty dict when there are no mutable sites.
        if not mutable_state:
            mutable_state = None
        # Presumably maps unconstrained params back to constrained space using
        # the collected transforms -- see transform_fn to confirm.
        self.constrain_fn = partial(transform_fn, inv_transforms)
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        params, mutable_state = tree_map(
            lambda x: lax.convert_element_type(x, jnp.result_type(x)),
            (params, mutable_state),
        )
        return SVIState(self.optim.init(params), mutable_state, rng_key)
コード例 #5
0
ファイル: elbo.py プロジェクト: ziatdinovmax/numpyro
        def single_particle_elbo(rng_key):
            """Estimate the ELBO from a single particle.

            The guide is evaluated first. The model is then replayed against the
            guide trace. Only params absent from the guide trace are substituted
            into the model.

            :param rng_key: PRNG key, split into independent model/guide seeds.
            :return: model log density minus guide log density for this draw.
            """
            model_key, guide_key = random.split(rng_key)
            guide_lp, guide_tr = log_density(seed(guide, guide_key), args, kwargs, param_map)
            # NB: we only want to substitute params not available in guide_trace
            model_params = {
                name: value for name, value in param_map.items() if name not in guide_tr
            }
            replayed_model = replay(seed(model, model_key), guide_tr)
            model_lp, _ = log_density(replayed_model, args, kwargs, model_params)
            # log p(z) - log q(z)
            return model_lp - guide_lp
コード例 #6
0
def test_flax_module():
    """Both flax model builders register `nn$params` with the expected shapes."""
    X = np.arange(100).astype(np.float32)
    Y = 2 * X + 2

    for build_model in (flax_model_by_shape, flax_model_by_kwargs):
        with handlers.trace() as flax_tr, handlers.seed(rng_seed=1):
            build_model(X, Y)
        nn_params = flax_tr["nn$params"]["value"]
        assert nn_params["kernel"].shape == (100, 100)
        assert nn_params["bias"].shape == (100,)
コード例 #7
0
def test_discrete_helpful_error(auto_class, init_loc_fn):
    """Seeding an autoguide for a model with a discrete latent raises ValueError."""
    def model():
        p = numpyro.sample("p", dist.Beta(2.0, 2.0))
        # `x` is an unobserved Bernoulli (discrete) latent site.
        x = numpyro.sample("x", dist.Bernoulli(p))
        with numpyro.plate("N", 2):
            obs_probs = p * x + (1 - p) * (1 - x)
            numpyro.sample("obs", dist.Bernoulli(obs_probs), obs=jnp.array([1.0, 0.0]))

    guide = auto_class(model, init_loc_fn=init_loc_fn)
    with pytest.raises(ValueError, match=".*handle discrete.*"):
        handlers.seed(guide, 0)()
コード例 #8
0
ファイル: test_handlers.py プロジェクト: mhashemi0873/numpyro
def test_seed():
    """Seeding in a Python loop matches seeding under vmap for the same seeds."""
    def _sample():
        x = numpyro.sample('x', dist.Normal(0., 1.))
        y = numpyro.sample('y', dist.Normal(1., 2.))
        return jnp.stack([x, y])

    def _draw_with_seed(seed_value):
        with handlers.seed(rng_seed=seed_value):
            return _sample()

    looped = jnp.stack([_draw_with_seed(i) for i in range(100)])
    vmapped = vmap(lambda rng_key: handlers.seed(_sample, rng_key)())(jnp.arange(100))
    assert_allclose(looped, vmapped, atol=1e-6)
コード例 #9
0
 def test_load_numpyro_model_model_not_allowing_None_arguments(self):
     """Errors from a model that rejects None arguments surface as ModelException.

     The exception title must mention MODEL and its message must explain the
     None problem.
     """
     model, _, _, _ = load_custom_numpyro_model(
         './tests/models/simple_gauss_model_no_none.py', Namespace(), [],
         pd.DataFrame())
     try:
         seed(model, jax.random.PRNGKey(0))(num_obs_total=100)
     except ModelException as e:
         if 'MODEL' in e.title and 'None for synthesising data' in e.msg:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for error in model, but did not correctly pass causal exception; got: {e.format_message('')}"
         )
     else:
         self.fail(
             "load_custom_numpyro_model did not raise for error in model")
コード例 #10
0
ファイル: util.py プロジェクト: synabreu/numpyro
 def single_prediction(val):
     """Run the model once with ``samples`` substituted and collect site values.

     :param val: pair ``(rng_key, samples)``; ``samples`` maps site names to values.
     :return: dict of site name -> value for the selected sites. If
         ``return_sites`` is None, the selected sites are the sample sites not
         in ``samples`` plus the deterministic sites. If it is ``''``, every
         non-plate site is selected. Otherwise ``return_sites`` itself is used.
     """
     rng_key, samples = val
     seeded_model = seed(substitute(model, samples), rng_key)
     model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
     if return_sites is None:
         sites = {
             name
             for name, site in model_trace.items()
             if (site['type'] == 'sample' and name not in samples)
             or site['type'] == 'deterministic'
         }
     elif return_sites == '':
         sites = {
             name for name, site in model_trace.items() if site['type'] != 'plate'
         }
     else:
         sites = return_sites
     return {
         name: site['value'] for name, site in model_trace.items() if name in sites
     }
コード例 #11
0
ファイル: util.py プロジェクト: synabreu/numpyro
def get_model_transforms(rng_key, model, model_args=(), model_kwargs=None):
    """Collect per-site transforms for a model's latent and param sites.

    :param rng_key: a single PRNG key, or a batch of keys (only the first is used).
    :param model: the model callable to trace.
    :param model_args: positional arguments for the model.
    :param model_kwargs: keyword arguments for the model (defaults to ``{}``).
    :return: ``(inv_transforms, replay_model)``, where ``replay_model`` says
        whether the model must be replayed to recover some site values.
    """
    if model_kwargs is None:
        model_kwargs = {}
    key = rng_key if rng_key.ndim == 1 else rng_key[0]
    model_trace = trace(seed(model, key)).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of dynamic constraints
    # or deterministic sites
    replay_model = False
    for name, site in model_trace.items():
        site_type = site['type']
        if site_type == 'sample':
            if site['is_observed']:
                continue
            if site['intermediates']:
                # Sites with intermediates use the support of fn.base_dist.
                inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                replay_model = True
            else:
                inv_transforms[name] = biject_to(site['fn'].support)
        elif site_type == 'param':
            # NB: pop() mutates the traced site's kwargs.
            transform = biject_to(site['kwargs'].pop('constraint', real))
            if isinstance(transform, ComposeTransform):
                inv_transforms[name] = transform.parts[0]
                replay_model = True
            else:
                inv_transforms[name] = transform
        elif site_type == 'deterministic':
            replay_model = True
    return inv_transforms, replay_model
コード例 #12
0
ファイル: test_handlers.py プロジェクト: mhashemi0873/numpyro
 def model():
     """Sample a chain x -> y -> z -> w of unit-scale normals with seed 0.

     Each site is centred on the previous draw (x is centred on 0).
     """
     with handlers.seed(rng_seed=0):
         draws = {}
         loc = 0
         for name in ("x", "y", "z", "w"):
             loc = numpyro.sample(name, dist.Normal(loc, 1))
             draws[name] = loc
         return draws
コード例 #13
0
ファイル: test_handlers.py プロジェクト: mhashemi0873/numpyro
def test_plate_stack(shape):
    """A sample inside plate_stack gets the batch shape of the stacked plates."""
    def guide():
        with numpyro.plate_stack("plates", shape):
            return numpyro.sample("x", dist.Normal(0, 1))

    seeded_guide = handlers.seed(guide, 0)
    assert seeded_guide().shape == shape
コード例 #14
0
ファイル: __init__.py プロジェクト: cnheider/numpyro
 def _setup_prototype(self, *args, **kwargs):
     """Trace the model once so its structure can be inspected.

     The trace is stored in ``self.prototype_trace``, and the call arguments in
     ``self._args`` / ``self._kwargs``.
     """
     # run the model so we can inspect its structure
     key_site = "_{}_rng_key_setup".format(self.prefix)
     rng_key = numpyro.sample(key_site, dist.PRNGIdentity())
     get_trace = handlers.trace(handlers.seed(self.model, rng_key)).get_trace
     # block() keeps the traced sites from reaching outer handlers.
     self.prototype_trace = handlers.block(get_trace)(*args, **kwargs)
     self._args = args
     self._kwargs = kwargs
コード例 #15
0
def test_update_params():
    """_update_params replaces prior-covered leaves (dotted paths) in both trees."""
    params = {"a": {"b": {"c": {"d": 1}, "e": np.array(2)}, "f": np.ones(4)}}
    prior = {"a.b.c.d": dist.Delta(4), "a.f": dist.Delta(5)}
    new_params = deepcopy(params)
    with handlers.seed(rng_seed=0):
        _update_params(params, new_params, prior)

    # In `params`, leaves named in `prior` become ParamShape placeholders.
    expected_params = {
        "a": {
            "b": {"c": {"d": ParamShape(())}, "e": 2},
            "f": ParamShape((4,)),
        }
    }
    assert params == expected_params

    # In `new_params`, those leaves hold the values drawn from the Delta priors.
    expected_new_params = {
        "a": {
            "b": {"c": {"d": np.array(4.0)}, "e": np.array(2)},
            "f": np.full((4,), 5.0),
        }
    }
    test_util.check_eq(new_params, expected_new_params)
コード例 #16
0
    def body_fn(wrapped_carry, x):
        """Scan step: run ``f`` on ``(carry, x)`` and record that step's trace.

        ``wrapped_carry`` is unpacked as ``(i, rng_key, carry)``. Returns the
        advanced carry ``(i + 1, rng_key, carry)`` and, as the per-step output,
        ``(PytreeTrace(trace), y)``.
        """
        i, rng_key, carry = wrapped_carry
        # Split off a per-step subkey; when rng_key is None the step is unseeded.
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        # block() stops this step's sites from reaching outer handlers; they are
        # captured by the inner trace and returned in the step output instead.
        with handlers.block():

            # we need to tell unconstrained messenger in potential energy computation
            # that only the item at time `i` is needed when transforming
            fn = handlers.infer_config(
                f, config_fn=lambda msg: {"_scan_current_index": i})

            seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
            # Re-apply the enclosing condition/substitute handlers. _subs_wrapper
            # receives `i` and `length`, presumably to pick out the step-i values.
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == "condition":
                    seeded_fn = handlers.condition(seeded_fn,
                                                   condition_fn=subs_fn)
                elif subs_type == "substitute":
                    seeded_fn = handlers.substitute(seeded_fn,
                                                    substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        return (i + 1, rng_key, carry), (PytreeTrace(trace), y)
コード例 #17
0
ファイル: scan.py プロジェクト: mhashemi0873/numpyro
    def body_fn(wrapped_carry, x, prefix=None):
        """Scan step that handles both the initial step and later steps.

        ``wrapped_carry`` is unpacked as ``(i, rng_key, carry)``. When ``i`` is a
        concrete (non-tracer) 0, ``f`` runs under the ``"_init"`` scope and an
        empty trace is returned. Otherwise ``f`` runs with enumeration inside a
        blocked ``packed_trace``. The resulting shapes of ``new_carry`` are
        recorded in ``carry_shape_at_t1``, and ``new_carry`` is reshaped to match
        ``carry``. ``prefix`` is not referenced in this body.

        Returns ``(i + 1, rng_key, new_carry)`` and ``(PytreeTrace(trace), y)``.
        """
        i, rng_key, carry = wrapped_carry
        # Only a concrete step counter can be compared to 0 here.
        init = True if (not_jax_tracer(i) and i == 0) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        # Re-apply enclosing condition/substitute handlers for step `i`.
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry in a variable of the enclosing scope (nonlocal)
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
コード例 #18
0
 def test_load_numpyro_model_model_without_num_obs_total(self):
     """Calling the model with num_obs_total surfaces a ModelException.

     The exception title must mention MODEL and its message must mention
     num_obs_total.
     """
     model, _, _, _ = load_custom_numpyro_model(
         './tests/models/simple_gauss_model_no_num_obs_total.py',
         Namespace(), [], pd.DataFrame())
     z = np.ones((10, 2))
     try:
         seed(model, jax.random.PRNGKey(0))(z, num_obs_total=100)
     except ModelException as e:
         if 'MODEL' in e.title and 'num_obs_total' in e.msg:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for error in model, but did not correctly pass causal exception; got: {e.format_message('')}"
         )
     else:
         self.fail(
             "load_custom_numpyro_model did not raise for error in model")
コード例 #19
0
 def test_load_numpyro_model_broken_model(self):
     """A NameError inside the model surfaces as a ModelException wrapping it."""
     model, _, _, _ = load_custom_numpyro_model(
         './tests/models/simple_gauss_model_broken.py', Namespace(), [],
         pd.DataFrame())
     z = np.ones((10, 2))
     try:
         seed(model, jax.random.PRNGKey(0))(z)
     except ModelException as e:
         if isinstance(e.base, NameError) and 'MODEL' in e.title:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for error in model, but did not correctly pass causal exception; got: {e.format_message('')}"
         )
     else:
         self.fail(
             "load_custom_numpyro_model did not raise for error in model")
コード例 #20
0
def test_scan_hmm_smoke(length, temperature):
    """Generate data from a scan-based HMM, then decode it with infer_discrete."""

    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
        means = jnp.arange(float(hidden_dim))

        def transition_fn(state, y):
            state = numpyro.sample("states", dist.Categorical(transition[state]))
            y = numpyro.sample("obs", dist.Normal(means[state], 1.0), obs=y)
            return state, (state, y)

        _, (states, data) = scan(transition_fn, 0, data, length=length)

        # Prepend the fixed initial state 0 to the scanned states.
        return [0] + list(states), data

    true_states, data = handlers.seed(hmm, 0)(None)
    assert len(data) == length
    assert len(true_states) == len(data) + 1

    decoder = infer_discrete(
        config_enumerate(hmm), temperature=temperature, rng_key=random.PRNGKey(1)
    )
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format([int(s) for s in true_states]))
    logger.info("inferred states: {}".format([int(s) for s in inferred_states]))
コード例 #21
0
ファイル: test_brm.py プロジェクト: neerajprad/brmp
def test_numpyro_codegen(N, formula_str, non_real_cols, contrasts, family,
                         priors, expected):
    """The generated numpyro model has the expected latent sites and parameters.

    ``expected`` is a list of ``(site, family_name, maybe_params)`` triples. When
    ``maybe_params`` is falsy, ``default_params[family_name]`` is used instead.
    """
    # Make dummy data.
    formula = parse(formula_str)
    cols = expand_columns(formula, non_real_cols)
    metadata = metadata_from_cols(cols)
    desc = makedesc(formula, metadata, family, priors, code_lengths(contrasts))

    # Generate model function and data.
    modelfn = numpyro_backend.gen(desc).fn
    df = dummy_df(cols, N)
    data = data_from_numpy(numpyro_backend, makedata(formula, df, metadata, contrasts))

    # Check sample sites.
    trace = numpyro.trace(numpyro.seed(modelfn, random.PRNGKey(0))).get_trace(**data)
    latent_sites = {name for name, node in trace.items() if not node['is_observed']}
    assert latent_sites == {site for site, _, _ in expected}

    for site, family_name, maybe_params in expected:
        # brmp's "LKJ" family is called "LKJCholesky" in numpyro.
        numpyro_family_name = 'LKJCholesky' if family_name == 'LKJ' else family_name
        fn = trace[site]['fn']
        params = maybe_params or default_params[family_name]
        assert type(fn).__name__ == numpyro_family_name
        for name, expected_val in params.items():
            if family_name == 'LKJ':
                # brmp's LKJ `eta` corresponds to numpyro's `concentration`.
                assert name == 'eta'
                name = 'concentration'
            val = getattr(fn, name)
            assert_equal(val._value, np.broadcast_to(expected_val, val.shape))
コード例 #22
0
def test_hmm_smoke(length, temperature):
    """Sample from a markov-loop HMM, then decode the data with infer_discrete."""

    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
        means = jnp.arange(float(hidden_dim))
        states = [0]
        for t in markov(range(len(data))):
            state = numpyro.sample(
                f"states_{t}", dist.Categorical(transition[states[-1]])
            )
            states.append(state)
            # NB: writes each observed/sampled value back into `data` in place.
            data[t] = numpyro.sample(
                f"obs_{t}", dist.Normal(means[state], 1.0), obs=data[t]
            )
        return states, data

    true_states, data = handlers.seed(hmm, 0)([None] * length)
    assert len(data) == length
    assert len(true_states) == len(data) + 1

    decoder = infer_discrete(
        config_enumerate(hmm), temperature=temperature, rng_key=random.PRNGKey(1)
    )
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format([int(s) for s in true_states]))
    logger.info("inferred states: {}".format([int(s) for s in inferred_states]))
コード例 #23
0
    def body_fn(state):
        """Draw one candidate set of initial (unconstrained) params and validate it.

        ``state`` is unpacked as ``(i, key, _, _)``. Returns
        ``(i + 1, key, (params, pe, z_grad), is_valid)``. ``is_valid`` is true
        when the potential energy, and its gradient if ``validate_grad`` is set,
        is finite. ``z_grad`` is None when gradients are not validated.
        """
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # XXX: we don't want to apply enum to draw latent samples
            # Unwrap an enum handler (optionally nested inside substitute) so
            # the model is traced without enumeration.
            model_ = model
            if enum:
                from numpyro.contrib.funsor import enum as enum_handler

                if isinstance(model, substitute) and isinstance(model.fn, enum_handler):
                    model_ = substitute(model.fn.fn, data=model.data)
                elif isinstance(model, enum_handler):
                    model_ = model.fn

            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model_, subkey), substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            # Only continuous, unobserved sample sites get initial params.
            for k, v in model_trace.items():
                if (
                    v["type"] == "sample"
                    and not v["is_observed"]
                    and not v["fn"].is_discrete
                ):
                    constrained_values[k] = v["value"]
                    inv_transforms[k] = biject_to(v["fn"].support)
            # Map the constrained draws to unconstrained space (invert=True).
            params = transform_fn(
                inv_transforms,
                {k: v for k, v in constrained_values.items()},
                invert=True,
            )
        else:  # this branch doesn't require tracing the model
            # User-provided init_values win; other sites are drawn uniformly
            # between -radius and radius with the prototype's shape.
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(
                        subkey, jnp.shape(v), minval=-radius, maxval=radius
                    )
                    # Fresh subkey so the next drawn site uses a different key.
                    key, subkey = random.split(key)

        potential_fn = partial(
            potential_energy, model, model_args, model_kwargs, enum=enum
        )
        if validate_grad:
            if forward_mode_differentiation:
                pe = potential_fn(params)
                z_grad = jacfwd(potential_fn)(params)
            else:
                pe, z_grad = value_and_grad(potential_fn)(params)
            z_grad_flat = ravel_pytree(z_grad)[0]
            is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        else:
            pe = potential_fn(params)
            is_valid = jnp.isfinite(pe)
            z_grad = None

        return i + 1, key, (params, pe, z_grad), is_valid
コード例 #24
0
    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Initialize the HMCECS state.

        Traces the model once to find the subsampled plates (size > subsample
        size), which become the Gibbs sites. It then sets up the Gibbs update
        for their subsample indices, optionally using ``self._proxy``, and
        delegates to the parent kernel's ``init`` with ``_gibbs_state`` added to
        the model kwargs.

        :return: ``HMCECSState`` built from the parent state plus the Gibbs state.
        :raises AssertionError: if the model contains no subsampled plate.
        """
        # Copy so that adding "_gibbs_state" below does not mutate the caller's dict.
        model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
        rng_key, key_u = random.split(rng_key)
        self._prototype_trace = trace(seed(self.model, key_u)).get_trace(*model_args, **model_kwargs)
        # plate site args are read as (size, subsample_size, ...).
        self._subsample_plate_sizes = {
            name: site["args"]
            for name, site in self._prototype_trace.items()
            if site["type"] == "plate" and site["args"][0] > site["args"][1]  # i.e. size > subsample_size
        }
        self._gibbs_sites = list(self._subsample_plate_sizes.keys())
        assert self._gibbs_sites, "Cannot detect any subsample statements in the model."
        if self._proxy is not None:
            # The proxy returns the proxy function, an initializer for the Gibbs
            # state, and the Gibbs update function.
            proxy_fn, gibbs_init, self._gibbs_update = self._proxy(self._prototype_trace,
                                                                   self._subsample_plate_sizes,
                                                                   self.model,
                                                                   model_args,
                                                                   model_kwargs.copy(),
                                                                   num_blocks=self._num_blocks)
            method = perturbed_method(self._subsample_plate_sizes, proxy_fn)
            # Replace the inner kernel's model with the likelihood-estimating wrapper.
            self.inner_kernel._model = estimate_likelihood(self.inner_kernel._model, method)

            z_gibbs = {name: site["value"] for name, site in self._prototype_trace.items() if name in self._gibbs_sites}
            rng_key, rng_state = random.split(rng_key)
            gibbs_state = gibbs_init(rng_state, z_gibbs)
        else:
            self._gibbs_update = partial(_block_update, self._subsample_plate_sizes, self._num_blocks)
            gibbs_state = ()

        model_kwargs["_gibbs_state"] = gibbs_state
        state = super().init(rng_key, num_warmup, init_params, model_args, model_kwargs)
        return HMCECSState(state.z, state.hmc_state, state.rng_key, gibbs_state, jnp.array(0.))
コード例 #25
0
ファイル: bnn.py プロジェクト: ramonemiliani93/uncertainty
 def predict(bnn_model, rng_key, samples, x, num_hidden):
     """Draw ``y`` from ``bnn_model`` with latent sites fixed to ``samples``.

     :return: the value recorded at the model's ``'y'`` site.
     """
     conditioned_model = handlers.substitute(handlers.seed(bnn_model, rng_key), samples)
     # note that y will be sampled in the bnn_model because we pass y=None here
     model_trace = handlers.trace(conditioned_model).get_trace(
         x=x, y=None, num_hidden=num_hidden)
     return model_trace['y']['value']
コード例 #26
0
ファイル: test_handlers.py プロジェクト: uiuc-arc/numpyro
def test_prng_key():
    """prng_key() is None with no seed handler and a uint32 key of shape (2,) inside one."""
    assert numpyro.prng_key() is None

    with handlers.seed(rng_seed=0):
        key = numpyro.prng_key()

    assert key.shape == (2,)
    assert key.dtype == "uint32"
コード例 #27
0
ファイル: util.py プロジェクト: jatentaki/numpyro
 def single_prediction(val):
     """Run the model once with ``samples`` substituted and collect site values.

     :param val: pair ``(rng_key, samples)``; ``samples`` maps site names to values.
     :return: dict of site name -> value for the selected sites. If
         ``return_sites`` is None, the selected sites are the sample sites not
         in ``samples`` plus the deterministic sites. If it is ``""``, every
         non-plate site is selected. Otherwise ``return_sites`` itself is used.
     """
     rng_key, samples = val
     seeded_model = seed(substitute(model, samples), rng_key)
     model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
     if return_sites is None:
         sites = {
             name
             for name, site in model_trace.items()
             if (site["type"] == "sample" and name not in samples)
             or site["type"] == "deterministic"
         }
     elif return_sites == "":
         sites = {
             name for name, site in model_trace.items() if site["type"] != "plate"
         }
     else:
         sites = return_sites
     return {
         name: site["value"] for name, site in model_trace.items() if name in sites
     }
コード例 #28
0
 def get_model_trace(rng):
     """Trace ``assets.fn`` in ``prior_only`` mode and return ``{site name: value}``."""
     seeded_fn = handler.seed(assets.fn, rng)
     model_tr = handler.trace(seeded_fn).get_trace(mode="prior_only", **data)
     # A plain dict is returned (not the trace's OrderedDict) so the result
     # supports vectorization.
     site_values = {}
     for name, node in model_tr.items():
         site_values[name] = node['value']
     return site_values
コード例 #29
0
def predict(model, rng_key, samples, X):
    """Numpyro's helper function for prediction.

    Substitutes ``samples`` into the seeded model and runs it with ``Y=None``.

    :return: the value recorded at the model's ``"Y"`` site.
    """
    conditioned_model = handlers.substitute(handlers.seed(model, rng_key), samples)
    # note that Y will be sampled in the model because we pass Y=None here
    model_trace = handlers.trace(conditioned_model).get_trace(X=X, Y=None)
    return model_trace["Y"]["value"]
コード例 #30
0
def sample_response(assets, seed, *args):
    """Call ``assets.sample_response_fn(*args)`` under a seed handler.

    :param assets: must be exactly an ``Assets`` instance (checked by assert).
    :param seed: int seed, or None to draw one via ``sample_rng_seed()``.
    :return: whatever ``assets.sample_response_fn`` returns.
    """
    assert type(assets) is Assets
    assert seed is None or type(seed) is int
    rng_seed = sample_rng_seed() if seed is None else seed
    return handler.seed(assets.sample_response_fn, random.PRNGKey(rng_seed))(*args)