示例#1
0
def predict(model, at_bats, hits, z, rng, player_names, train=True):
    """Print predictions of ``model`` under the values ``z``.

    The model is seeded with ``rng``, its sites are substituted with ``z`` and
    the value recorded at the ``'obs'`` site is handed to ``print_results``.
    When ``train`` is false, the model is additionally traced with ``hits``
    supplied and the log of the mean (over the leading axis) of the
    exponentiated joint log probability is printed.

    :param model: callable model; its ``__name__`` is used in the header.
    :param at_bats: first positional argument passed to the model.
    :param hits: second positional argument, only passed in the test branch.
    :param z: values handed to ``substitute``.
    :param rng: value handed to ``seed``.
    :param player_names: forwarded to ``print_results``.
    :param bool train: selects the header suffix and skips the density report.
    """
    suffix = ' - TRAIN' if train else ' - TEST'
    rule = '=' * 30
    title = rule + model.__name__ + suffix + rule

    conditioned = substitute(seed(model, rng), z)
    predictions = trace(conditioned).get_trace(at_bats)['obs']['value']
    print_results(title, predictions, player_names, at_bats, hits)
    if train:
        return

    # Re-trace with the observations supplied so every site can be scored.
    scored_trace = trace(substitute(conditioned, z)).get_trace(at_bats, hits)
    log_joint = 0.
    for site in scored_trace.values():
        per_draw = site['fn'].log_prob(site['value'])
        # keep the leading axis, collapse everything else before summing
        flattened = per_draw.reshape(per_draw.shape[:1] + (-1, ))
        log_joint = log_joint + np.sum(flattened, -1)
    num_draws = np.shape(log_joint)[0]
    log_post_density = logsumexp(log_joint) - np.log(num_draws)
    print('\nPosterior log density: {:.2f}\n'.format(log_post_density))
示例#2
0
def test_substitute():
    """Check nested ``substitute`` handlers.

    The inner handler binds ``'y'`` to the value the outer handler supplied
    for ``'x'``, so with ``x = 3`` the model returns ``3 + 3 * 3``.
    """
    def product():
        return numpyro.param('y', None) * numpyro.param('x', None)

    def model():
        x = numpyro.param('x', None)
        y = handlers.substitute(product, {'y': x})()
        return x + y

    outcome = handlers.substitute(model, {'x': 3.})()
    assert outcome == 12.
示例#3
0
 def unpack_single_latent(latent):
     """Unpack one latent value and map it through the inverse transforms.

     ``self``, ``params``, ``model_args`` and ``model_kwargs`` are taken from
     the enclosing scope.
     """
     unpacked = self._unpack_latent(latent)
     if not self._has_transformed_dist:
         return transform_fn(self._inv_transforms, unpacked)
     # `param` statements in the model must see the supplied values first
     substituted_model = handlers.substitute(self.model, params)
     return constrain_fn(substituted_model, self._inv_transforms, model_args,
                         model_kwargs, unpacked)
示例#4
0
 def single_loglik(samples):
     """Return ``{site name: log_prob}`` for every observed sample site.

     ``samples`` is substituted into ``model`` only when it is a dict;
     otherwise the model is traced unchanged. ``model``, ``args`` and
     ``kwargs`` come from the enclosing scope.
     """
     if isinstance(samples, dict):
         target = substitute(model, samples)
     else:
         target = model
     observed_log_probs = {}
     for name, site in trace(target).get_trace(*args, **kwargs).items():
         if site["type"] != "sample" or not site["is_observed"]:
             continue
         observed_log_probs[name] = site["fn"].log_prob(site["value"])
     return observed_log_probs
示例#5
0
def test_substitute():
    """Nested ``substitute``: the inner handler maps ``"y"`` to the outer ``"x"``.

    With ``x = 3.0`` the model evaluates to ``3.0 + 3.0 * 3.0``.
    """
    def model():
        outer_x = numpyro.param("x", None)
        product_fn = handlers.substitute(
            lambda: numpyro.param("y", None) * numpyro.param("x", None),
            {"y": outer_x},
        )
        y = product_fn()
        return outer_x + y

    substituted_model = handlers.substitute(model, {"x": 3.0})
    assert substituted_model() == 12.0
示例#6
0
 def get_log_probs(sample):
     """Run the model with ``sample`` substituted and score each sample site.

     ``model``, ``model_args`` and ``model_kwargs`` come from the enclosing
     scope. Returns a dict mapping site name to ``log_prob`` of its value.
     """
     # Note: seed 0 is used for parameter initialization.
     with handlers.trace() as tr:
         with handlers.seed(rng_seed=0):
             with handlers.substitute(data=sample):
                 model(*model_args, **model_kwargs)
     log_probs = {}
     for name, site in tr.items():
         if site["type"] == "sample":
             log_probs[name] = site["fn"].log_prob(site["value"])
     return log_probs
示例#7
0
文件: util.py 项目: ucals/numpyro
 def single_loglik(samples):
     """Score the observed sample sites of ``model`` for one set of samples.

     A dict ``samples`` is substituted into the model before tracing; any
     other value leaves the model untouched. ``model``, ``args`` and
     ``kwargs`` are closure variables.
     """
     use_substitution = isinstance(samples, dict)
     traced = trace(substitute(model, samples) if use_substitution else model)
     sites = traced.get_trace(*args, **kwargs)
     return {
         site_name: info['fn'].log_prob(info['value'])
         for site_name, info in sites.items()
         if info['type'] == 'sample' and info['is_observed']
     }
示例#8
0
def predict(model, rng_key, samples, p, t, D_H, data_type):
    """Return the value recorded at site ``'Y'`` for one set of samples.

    The model is seeded with ``rng_key``, ``samples`` are substituted in, and
    the model is traced with ``Y=None`` and ``F=None``.
    """
    conditioned = handlers.substitute(handlers.seed(model, rng_key), samples)
    call_kwargs = dict(p=p, t=t, Y=None, F=None, D_H=D_H,
                       data_type=data_type)
    sites = handlers.trace(conditioned).get_trace(**call_kwargs)
    return sites['Y']['value']
示例#9
0
 def log_prior(params):
     """Evaluate ``log_density`` with every subsample plate given empty indices.

     ``model``, ``model_kwargs`` and ``subsample_plate_sizes`` come from the
     enclosing scope. ``UserWarning`` is silenced while the model runs.
     """
     # NOTE(review): log_density receives (model, model_kwargs, params) here;
     # confirm this matches the signature of the log_density in scope.
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", category=UserWarning)
         # one zero-length int32 index array per subsample plate
         dummy_subsample = {}
         for plate_name in subsample_plate_sizes:
             dummy_subsample[plate_name] = jnp.array([], dtype=jnp.int32)
         with block():
             with substitute(data=dummy_subsample):
                 prior_prob, _ = log_density(model, model_kwargs, params)
     return prior_prob
示例#10
0
文件: svi.py 项目: leej35/numpyro
    def init_fn(rng, model_args=(), guide_args=(), params=None):
        """
        Trace the seeded model and guide once, collect their `param` sites and
        build the initial optimizer state.

        `model`, `guide`, `kwargs`, `optim_init` and `constrain_fn` are taken
        from the enclosing scope; `constrain_fn` is rebound there via
        ``nonlocal``.

        :param jax.random.PRNGKey rng: random number generator seed.
        :param tuple model_args: arguments to the model (these can possibly vary during
            the course of fitting).
        :param tuple guide_args: arguments to the guide (these can possibly vary during
            the course of fitting).
        :param dict params: initial parameter values to condition on. This can be
            useful for initializing neural networks using more specialized methods
            rather than sampling from the prior. When given, this dict is updated
            in place with the values found in the traces.
        :return: tuple containing initial optimizer state, and `constrain_fn`, a callable
            that transforms unconstrained parameter values from the optimizer to the
            specified constrained domain
        """
        assert isinstance(model_args, tuple)
        assert isinstance(guide_args, tuple)
        model_init, guide_init = _seed(model, guide, rng)
        if params is None:
            params = {}
        else:
            # user-supplied values take the place of the matching sites
            model_init = substitute(model_init, params)
            guide_init = substitute(guide_init, params)
        guide_trace = trace(guide_init).get_trace(*guide_args, **kwargs)
        model_trace = trace(model_init).get_trace(*model_args, **kwargs)
        inv_transforms = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site['type'] == 'param':
                # `pop` removes the constraint from the recorded site kwargs;
                # sites without one default to `constraints.real`
                constraint = site['kwargs'].pop('constraint', constraints.real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    # only the first part of a composed transform is kept as
                    # the stored transform; the stored value is the site value
                    # pulled back through the full transform and pushed
                    # forward through that first part
                    base_transform = transform.parts[0]
                    inv_transforms[site['name']] = base_transform
                    params[site['name']] = base_transform(
                        transform.inv(site['value']))
                else:
                    inv_transforms[site['name']] = transform
                    params[site['name']] = site['value']

        # publish the constraining function to the enclosing scope
        nonlocal constrain_fn
        constrain_fn = jax.partial(transform_fn, inv_transforms)
        return optim_init(params), constrain_fn
示例#11
0
    def median(self, params):
        """
        Returns the posterior median value of each latent variable.

        The location is obtained by calling ``self._loc_scale`` with ``params``
        substituted; the accompanying scale is discarded.

        :param dict params: A dict containing parameter values.
        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        loc_scale_fn = handlers.substitute(self._loc_scale, params)
        loc, _ = loc_scale_fn()
        return self._unpack_and_constrain(loc, params)
示例#12
0
    def get_transform(self, params):
        """
        Returns the transformation learned by the guide to generate samples
        from the unconstrained (approximate) posterior.

        The result is whatever ``self._get_transform`` returns when it is run
        with ``params`` substituted.

        :param dict params: Current parameters of model and autoguide.
        :return: the transform of posterior distribution
        :rtype: :class:`~numpyro.distributions.constraints.Transform`
        """
        build_transform = handlers.substitute(self._get_transform, params)
        return build_transform()
示例#13
0
def forecast(future, rng_key, sample, y, n_obs):
    """Collect the ``yf[t]`` site values for ``t`` in ``range(future)``.

    ``ar_k`` is run with ``sample`` substituted to obtain ``Z_exp`` (its
    second return value is unused), then ``_forecast`` is seeded with
    ``rng_key`` and traced. Each ``yf[t]`` value is clipped from below at
    ``1e-30`` and the results are stacked along axis 0.
    """
    Z_exp, _ = handlers.substitute(ar_k, sample)(n_obs, y)
    seeded_forecast = handlers.seed(_forecast, rng_key)
    sites = handlers.trace(seeded_forecast).get_trace(
        future, sample, Z_exp, n_obs)
    clipped = []
    for step in range(future):
        site_name = "yf[{}]".format(step)
        clipped.append(np.clip(sites[site_name]["value"], a_min=1e-30))
    return np.stack(clipped, axis=0)
示例#14
0
 def get_log_probs(sample, seed=0):
     """Score every sample site of the model with ``sample`` substituted.

     ``model``, ``model_args`` and ``model_kwargs`` are closure variables.
     Returns a dict mapping site name to the ``log_prob`` of its value.
     """
     with handlers.trace() as tr:
         with handlers.seed(model, seed):
             with handlers.substitute(data=sample):
                 model(*model_args, **model_kwargs)
     return dict(
         (name, site["fn"].log_prob(site["value"]))
         for name, site in tr.items()
         if site["type"] == "sample"
     )
示例#15
0
def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False):
    """
    Computes log of joint density for the model given latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :param bool skip_dist_transforms: whether to compute log probability of a site
        (if its prior is a transformed distribution) in its base distribution
        domain.
    :return: log of joint density and a corresponding model trace
    """
    # Transforms are skipped for
    #   + autoguide's model
    #   + hmc's model
    # and applied for
    #   + autoguide's guide
    #   + svi's model + guide
    map_keyword = 'base_param_map' if skip_dist_transforms else 'param_map'
    model = substitute(model, **{map_keyword: params})
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)

    log_joint = 0.
    for site in model_trace.values():
        if site['type'] != 'sample':
            continue
        value = site['value']
        intermediates = site['intermediates']
        if not intermediates:
            site_log_prob = site['fn'].log_prob(value)
        elif skip_dist_transforms:
            # score the first intermediate under the base distribution
            site_log_prob = site['fn'].base_dist.log_prob(intermediates[0][0])
        else:
            site_log_prob = site['fn'].log_prob(value, intermediates)
        site_log_prob = np.sum(site_log_prob)
        if 'scale' in site:
            site_log_prob = site['scale'] * site_log_prob
        log_joint = log_joint + site_log_prob
    return log_joint, model_trace
示例#16
0
def log_density(model, model_args, model_kwargs, params):
    """Sum the log probability of every sample site with ``params`` substituted.

    A site's summed log probability is multiplied by ``site['scale']`` when
    the site carries a ``'scale'`` entry.

    :return: tuple of the accumulated log joint and the model trace.
    """
    substituted = substitute(model, params)
    model_trace = trace(substituted).get_trace(*model_args, **model_kwargs)
    log_joint = 0.
    sample_sites = (s for s in model_trace.values() if s['type'] == 'sample')
    for site in sample_sites:
        site_log_prob = np.sum(site['fn'].log_prob(site['value']))
        if 'scale' in site:
            site_log_prob = site['scale'] * site_log_prob
        log_joint = log_joint + site_log_prob
    return log_joint, model_trace
 def _bnn_predict(self, rng_key, samples, X):
     """Produce predictions from the samples of a "trained" BNN.

     ``self._model`` is seeded with ``rng_key`` and ``samples`` are
     substituted in. The model is traced with ``Y=None`` and ``train=False``
     and the value recorded at site ``'Y'`` is returned.
     """
     seeded_model = handlers.seed(self._model, rng_key)
     # hand the posterior-sampled parameters to the model
     conditioned = handlers.substitute(seeded_model, samples)
     # Y is passed as None, so the value at 'Y' comes from running the model
     sites = handlers.trace(conditioned).get_trace(
         X=X, Y=None, D_H=self.bnn_dh, train=False)
     return sites['Y']['value']
示例#18
0
def test_subsample_substitute():
    """A substituted plate must yield exactly the supplied subsample indices.

    Also checks that the recorded ``rng_key`` kwarg of the plate site is
    ``None``.
    """
    subsample = jnp.array([13, 3, 30, 4, 1, 68, 5])
    subsample_size = 7
    data = jnp.arange(100.)
    with handlers.trace() as tr:
        with handlers.seed(rng_seed=0):
            with handlers.substitute(data={"a": subsample}):
                with numpyro.plate("a", len(data),
                                   subsample_size=subsample_size) as idx:
                    assert data[idx].shape == (subsample_size, )
                    assert_allclose(idx, subsample)
    plate_site = tr["a"]
    assert plate_site["kwargs"]["rng_key"] is None
    def _blr_predict(self, rng_key, samples, X, predict=False):
        """Produce predictions from the samples of a "trained" BLR.

        ``self._model`` is seeded with ``rng_key``, ``samples`` are
        substituted in, and the model is traced with ``Y=None``. The value
        recorded at site ``'obs'`` is returned.
        """
        seeded_model = handlers.seed(self._model, rng_key)
        # hand the posterior-sampled parameters to the model
        conditioned = handlers.substitute(seeded_model, samples)
        sites = handlers.trace(conditioned).get_trace(
            X=X, Y=None, predict=predict)
        return sites['obs']['value']
示例#20
0
文件: __init__.py 项目: juvu/numpyro
 def __init__(self,
              rng,
              model,
              get_params_fn,
              prefix="auto",
              init_loc_fn=init_to_median):
     """Wrap ``model`` so its sites are initialized through ``init_loc_fn``.

     ``init_loc_fn`` is seeded with ``rng`` and wrapped in ``block`` so the
     sample primitives it runs are not recorded. The result is installed as
     the ``substitute_fn`` of a ``substitute`` handler around ``model``
     before delegating to the parent constructor.
     """
     hidden_init = block(seed(init_loc_fn, rng))
     wrapped_model = substitute(model, substitute_fn=hidden_init)
     super(AutoContinuous, self).__init__(wrapped_model,
                                          get_params_fn,
                                          prefix=prefix)
示例#21
0
        def log_likelihood(params_flat, subsample_indices=None):
            """Return per-plate log likelihood terms of the observed sites.

            ``params_flat`` is unravelled with ``unravel_fn`` and each value
            is mapped through ``biject_to`` of the support of its prototype
            site before the model is run. When ``subsample_indices`` is
            ``None``, each plate gets ``jnp.arange`` over the first entry of
            its ``subsample_plate_sizes`` value.
            """
            if subsample_indices is None:
                subsample_indices = {}
                for plate_name, sizes in subsample_plate_sizes.items():
                    subsample_indices[plate_name] = jnp.arange(sizes[0])
            params = unravel_fn(params_flat)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                params = {
                    name: biject_to(prototype_trace[name]["fn"].support)(value)
                    for name, value in params.items()
                }
                with block():
                    with trace() as tr:
                        with substitute(data=subsample_indices):
                            with substitute(data=params):
                                model(*model_args, **model_kwargs)

            log_lik = {}
            for site in tr.values():
                if site["type"] != "sample" or not site["is_observed"]:
                    continue
                for frame in site["cond_indep_stack"]:
                    term = _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
                    if frame.name in log_lik:
                        log_lik[frame.name] += term
                    else:
                        log_lik[frame.name] = term
            return log_lik
示例#22
0
文件: util.py 项目: jatentaki/numpyro
    def body_fn(state):
        """Draw one candidate set of initial parameters and test its validity.

        ``state`` is a 4-tuple whose first two entries are a counter and a
        PRNG key; the remaining two entries are ignored. Returns
        ``(i + 1, key, (params, pe, z_grad), is_valid)``.

        ``radius``, ``prototype_params``, ``init_values``, ``init_strategy``,
        ``model``, ``model_args``, ``model_kwargs``, ``enum``,
        ``validate_grad`` and ``forward_mode_differentiation`` come from the
        enclosing scope.
        """
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model, subkey),
                                      substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(
                *model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            # keep only latent (unobserved), non-discrete sample sites
            for k, v in model_trace.items():
                if (v["type"] == "sample" and not v["is_observed"]
                        and not v["fn"].is_discrete):
                    constrained_values[k] = v["value"]
                    inv_transforms[k] = biject_to(v["fn"].support)
            # presumably maps the constrained values to unconstrained space
            # (invert=True) -- confirm against transform_fn
            params = transform_fn(
                inv_transforms,
                {k: v
                 for k, v in constrained_values.items()},
                invert=True,
            )
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(subkey,
                                               jnp.shape(v),
                                               minval=-radius,
                                               maxval=radius)
                    # fresh subkey for the next uniformly drawn parameter
                    key, subkey = random.split(key)

        potential_fn = partial(potential_energy,
                               model,
                               model_args,
                               model_kwargs,
                               enum=enum)
        if validate_grad:
            if forward_mode_differentiation:
                pe = potential_fn(params)
                z_grad = jacfwd(potential_fn)(params)
            else:
                pe, z_grad = value_and_grad(potential_fn)(params)
            z_grad_flat = ravel_pytree(z_grad)[0]
            # valid only if both the energy and every gradient entry are finite
            is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        else:
            pe = potential_fn(params)
            is_valid = jnp.isfinite(pe)
            z_grad = None

        return i + 1, key, (params, pe, z_grad), is_valid
示例#23
0
 def single_prediction(val):
     """Trace the model for one ``(rng_key, samples)`` pair and pick site values.

     ``model``, ``model_args``, ``model_kwargs`` and ``return_sites`` are
     closure variables. With ``return_sites`` set to ``None`` the result holds
     deterministic sites plus sample sites not named in ``samples``; with
     ``''`` it holds every site that is not a plate; otherwise it holds the
     sites named in ``return_sites``.
     """
     rng_key, samples = val
     conditioned = seed(substitute(model, samples), rng_key)
     model_trace = trace(conditioned).get_trace(*model_args, **model_kwargs)
     if return_sites is None:
         sites = {
             k for k, site in model_trace.items()
             if (site['type'] == 'sample' and k not in samples)
             or (site['type'] == 'deterministic')
         }
     elif return_sites == '':
         sites = {
             k for k, site in model_trace.items() if site['type'] != 'plate'
         }
     else:
         sites = return_sites
     selected = {}
     for name, site in model_trace.items():
         if name in sites:
             selected[name] = site['value']
     return selected
示例#24
0
def log_density(model, model_args, model_kwargs, params):
    """Sum the log probability of every sample site with ``params`` substituted.

    :return: tuple of the accumulated log joint and the model trace.
    """
    def logp(d, val):
        # `logpdf` when the wrapped `dist` is a `jax_continuous`, else `logpmf`
        with validation_disabled():
            if isinstance(d.dist, jax_continuous):
                return d.logpdf(val)
            return d.logpmf(val)

    substituted = substitute(model, params)
    model_trace = trace(substituted).get_trace(*model_args, **model_kwargs)
    log_joint = 0.
    for site in model_trace.values():
        if site['type'] != 'sample':
            continue
        log_joint = log_joint + np.sum(logp(site['fn'], site['value']))
    return log_joint, model_trace
示例#25
0
    def __call__(self, rng_key, *args, **kwargs):
        """
        Returns dict of samples from the predictive distribution. By default, only sample sites not
        contained in `posterior_samples` are returned. This can be modified by changing the
        `return_sites` keyword argument of this :class:`Predictive` instance.

        When a guide is set, `rng_key` is split and the guide is run first (with
        `self.params` substituted); its output replaces `posterior_samples` for the
        model run.

        :param jax.random.PRNGKey rng_key: random key to draw samples.
        :param args: model arguments.
        :param kwargs: model kwargs.
        """
        shared = dict(parallel=self.parallel, model_args=args, model_kwargs=kwargs)
        posterior_samples = self.posterior_samples
        if self.guide is not None:
            rng_key, guide_rng_key = random.split(rng_key)
            guide = substitute(self.guide, self.params)
            # return_sites='' is a special signal to return all sites
            posterior_samples = _predictive(guide_rng_key, guide, posterior_samples,
                                            self._batch_shape, return_sites='', **shared)
        model = substitute(self.model, self.params)
        return _predictive(rng_key, model, posterior_samples, self._batch_shape,
                           return_sites=self.return_sites, **shared)
示例#26
0
    def sample_posterior(self, rng_key, params, sample_shape=()):
        """
        Get samples from the learned posterior.

        :param jax.random.PRNGKey rng_key: random key to be used to draw samples.
        :param dict params: Current parameters of model and autoguide.
        :param tuple sample_shape: batch shape of each latent sample, defaults to ().
        :return: a dict containing samples drawn from this guide.
        :rtype: dict
        """
        seeded_sampler = handlers.seed(self._sample_latent, rng_key)
        draw_latent = handlers.substitute(seeded_sampler, params)
        latent_sample = draw_latent(self.base_dist, sample_shape=sample_shape)
        return self._unpack_and_constrain(latent_sample, params)
示例#27
0
    def sample(self):
        """Sweep once over the sites in the current params and return a new status.

        For each site a proposal is requested from ``self.proposal``, the
        model is retraced with the proposed value, and the proposal is
        accepted or rejected by comparing a uniform draw with the acceptance
        probability. Returns ``NMC_STATUS(iteration, params, log likelihood,
        running mean of the acceptance probability, rng_key)``.
        """
        rng_key, rng_key_sample, rng_key_accept = split(
            self.nmc_status.rng_key, 3)
        params = self.nmc_status.params

        # NOTE(review): rng_key_sample and rng_key_accept are split once above
        # and reused unchanged for every site in this loop -- confirm intended.
        for site in params.keys():
            # Collect accepted trace
            for i in range(len(params[site])):
                self.acc_trace[site + str(i)].append(params[site][i])

            tr_current = trace(substitute(self.model, params)).get_trace(
                *self.model_args, **self.model_kwargs)
            # NOTE(review): ll_current is re-read from the stored status on
            # every iteration, so a value accepted for an earlier site is not
            # carried into later sites -- confirm intended.
            ll_current = self.nmc_status.log_likelihood

            val_current = tr_current[site]["value"]
            dist_curr = tr_current[site]["fn"]

            def log_den_fun(var):
                # log joint density (first element of log_density's result)
                return partial(log_density, self.model, self.model_args,
                               self.model_kwargs)(var)[0]

            val_proposal, dist_proposal = self.proposal(
                site, log_den_fun, self.get_params(tr_current), dist_curr,
                rng_key_sample)

            tr_proposal = self.retrace(site, tr_current, dist_proposal,
                                       val_proposal, self.model_args,
                                       self.model_kwargs)
            ll_proposal = log_density(self.model, self.model_args,
                                      self.model_kwargs,
                                      self.get_params(tr_proposal))[0]

            # current value scored under the proposal distribution, and the
            # proposed value scored under the current site distribution
            ll_proposal_val = dist_proposal.log_prob(val_current).sum()
            ll_current_val = dist_curr.log_prob(val_proposal).sum()

            hastings_ratio = (ll_proposal + ll_proposal_val) - \
                (ll_current + ll_current_val)

            accept_prob = np.minimum(1, np.exp(hastings_ratio))
            u = sample("u", dist.Uniform(0, 1), rng_key=rng_key_accept)

            if u <= accept_prob:
                params, ll_current = self.get_params(tr_proposal), ll_proposal
            else:
                params, ll_current = self.get_params(tr_current), ll_current

        # NOTE(review): `iter` shadows the builtin, and `accept_prob` /
        # `ll_current` are only bound if the loop ran at least once.
        iter = self.nmc_status.i + 1
        # running mean update, using the acceptance probability of the last site
        mean_accept_prob = self.nmc_status.accept_prob + \
            (accept_prob - self.nmc_status.accept_prob) / iter

        return NMC_STATUS(iter, params, ll_current, mean_accept_prob, rng_key)
示例#28
0
def run_model_on_samples_and_data(modelfn, samples, data):
    """Run ``modelfn`` once per posterior sample and restore the chain axis.

    Every array in ``samples`` must share the same two leading dimensions
    ``(num_chains, num_samples)``. The arrays are passed through ``flatten``,
    ``modelfn`` is run under ``handler.substitute`` for each sample via
    ``vmap`` (called with ``**data`` and ``mode='prior_and_mu'``), and each
    output is passed through ``unflatten`` with the original
    ``(num_chains, num_samples)``.

    :param modelfn: model callable accepting ``**data`` and a ``mode`` keyword.
    :param dict samples: non-empty mapping from site name to array; dict
        subclasses are accepted.
    :param dict data: keyword arguments forwarded to ``modelfn``.
    :return: dict mapping each key of the model output to its unflattened array.
    :raises AssertionError: if ``samples`` is not a dict, is empty, or its
        arrays disagree on the two leading dimensions.
    """
    # isinstance instead of an exact type match, so dict subclasses
    # (e.g. OrderedDict) are not rejected
    assert isinstance(samples, dict), 'samples must be a dict'
    assert len(samples) > 0, 'samples must not be empty'
    num_chains, num_samples = next(iter(samples.values())).shape[0:2]
    assert all(arr.shape[0:2] == (num_chains, num_samples)
               for arr in samples.values()), \
        'all sample arrays must share (num_chains, num_samples)'
    flat_samples = {k: flatten(arr) for k, arr in samples.items()}

    def run_one(sample):
        return handler.substitute(modelfn, sample)(**data, mode='prior_and_mu')

    out = vmap(run_one)(flat_samples)
    # Restore chain dim.
    return {
        k: unflatten(arr, num_chains, num_samples)
        for k, arr in out.items()
    }
示例#29
0
    def model(batch, subsample, full_size):
        """Scan a drifting-state model over ``batch`` inside a subsampled plate.

        The ``"data"`` plate is created while ``subsample`` is substituted for
        its indices, and is then re-entered in every scan step.
        ``num_time_steps`` comes from the enclosing scope.
        """
        drift = numpyro.sample("drift", dist.LogNormal(-1, 0.5))
        num_selected = len(subsample)
        with handlers.substitute(data={"data": subsample}):
            plate = numpyro.plate("data", full_size, subsample_size=num_selected)
        assert plate.size == 50

        def transition_fn(z_prev, y_curr):
            with plate:
                state = numpyro.sample("state", dist.Normal(z_prev, drift))
                observed = numpyro.sample(
                    "obs", dist.Bernoulli(logits=state), obs=y_curr)
            return state, observed

        initial_state = jnp.zeros(len(subsample))
        _, result = scan(transition_fn, initial_state, batch, length=num_time_steps)
        return result
示例#30
0
        def log_likelihood(params, subsample_indices=None):
            """Accumulate log likelihood terms per subsample plate.

            ``params`` is ravelled and unravelled again before being handed to
            ``_unconstrain_reparam`` as the ``substitute_fn``. Only frames
            whose name is a key of ``subsample_plate_sizes`` contribute.
            When ``subsample_indices`` is ``None``, each plate gets
            ``jnp.arange`` over the first entry of its size tuple.

            :return: ``defaultdict(float)`` keyed by plate name.
            """
            params_flat, unravel_fn = ravel_pytree(params)
            if subsample_indices is None:
                subsample_indices = {}
                for plate_name, sizes in subsample_plate_sizes.items():
                    subsample_indices[plate_name] = jnp.arange(sizes[0])
            params = unravel_fn(params_flat)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                with block():
                    with trace() as tr:
                        with substitute(data=subsample_indices):
                            with substitute(substitute_fn=partial(
                                    _unconstrain_reparam, params)):
                                model(*model_args, **model_kwargs)

            log_lik = defaultdict(float)
            for site in tr.values():
                if site["type"] != "sample" or not site["is_observed"]:
                    continue
                for frame in site["cond_indep_stack"]:
                    if frame.name not in subsample_plate_sizes:
                        continue
                    log_lik[frame.name] += _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
            return log_lik