Example #1: a single-particle ELBO that scores the model at the base values of the guide's sample sites (note the `skip_dist_transforms` argument, from an older `log_density` API).
        def single_particle_elbo(rng_key):
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            guide_log_density, guide_trace = log_density(
                seeded_guide, args, kwargs, param_map)
            # first, substitute `param_map` into the `param` primitives of `model`
            seeded_model = substitute(seeded_model, param_map)
            # then create a new `param_map` that holds the base values of `sample` primitives
            base_param_map = {}
            # in autoguide, a site's value holds the intermediate (base) value
            for name, site in guide_trace.items():
                if site['type'] == 'sample':
                    base_param_map[name] = site['value']
            model_log_density, _ = log_density(seeded_model,
                                               args,
                                               kwargs,
                                               base_param_map,
                                               skip_dist_transforms=True)

            # log p(z) - log q(z)
            elbo = model_log_density - guide_log_density
            # Return (-elbo) since by convention we do gradient descent on a loss and
            # the ELBO is a lower bound that needs to be maximized.
            return -elbo
Example #2: a single-particle ELBO with support for `mutable` state sites and model/guide consistency checks.
        def single_particle_elbo(rng_key):
            params = param_map.copy()
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            guide_log_density, guide_trace = log_density(
                seeded_guide, args, kwargs, param_map)
            mutable_params = {
                name: site["value"]
                for name, site in guide_trace.items()
                if site["type"] == "mutable"
            }
            params.update(mutable_params)
            seeded_model = replay(seeded_model, guide_trace)
            model_log_density, model_trace = log_density(
                seeded_model, args, kwargs, params)
            check_model_guide_match(model_trace, guide_trace)
            _validate_model(model_trace, plate_warning="loose")
            mutable_params.update({
                name: site["value"]
                for name, site in model_trace.items()
                if site["type"] == "mutable"
            })

            # log p(z) - log q(z)
            elbo_particle = model_log_density - guide_log_density
            if mutable_params:
                if self.num_particles == 1:
                    return elbo_particle, mutable_params
                else:
                    raise ValueError(
                        "Currently, we only support mutable states with num_particles=1."
                    )
            else:
                return elbo_particle, None
Example #3: a test that `handlers.mask` and `handlers.scale` are reflected in the `log_density` result.
def test_mask(mask_last, use_jit):
    N = 10
    mask = np.ones(N, dtype=bool)  # the `np.bool` alias was removed in NumPy 1.24
    mask[-mask_last] = 0

    def model(data, mask):
        with numpyro.plate('N', N):
            x = numpyro.sample('x', dist.Normal(0, 1))
            with handlers.mask(mask=mask):
                numpyro.sample('y', dist.Delta(x, log_density=1.))
                with handlers.scale(scale=2):
                    numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (N, ))
    x = random.normal(random.PRNGKey(1), (N, ))
    if use_jit:
        log_joint = jit(lambda *args: log_density(*args)[0],
                        static_argnums=(0,))(model, (data, mask), {},
                                             {'x': x, 'y': x})
    else:
        log_joint = log_density(model, (data, mask), {}, {'x': x, 'y': x})[0]
    log_prob_x = dist.Normal(0, 1).log_prob(x)
    log_prob_y = mask  # Delta(x, log_density=1.) contributes 1 at each unmasked site
    log_prob_z = dist.Normal(x, 1).log_prob(data)
    expected = (log_prob_x +
                jnp.where(mask, log_prob_y + 2 * log_prob_z, 0.)).sum()
    assert_allclose(log_joint, expected, atol=1e-4)
Example #4: the basic single-particle ELBO, replaying the guide trace against the model.
        def single_particle_elbo(rng_key):
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
            seeded_model = replay(seeded_model, guide_trace)
            model_log_density, _ = log_density(seeded_model, args, kwargs, param_map)

            # log p(z) - log q(z)
            elbo = model_log_density - guide_log_density
            return elbo
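
All of the ELBO variants above follow the same core pattern: seed the guide, score it with `log_density`, replay its trace against the model, and subtract. Below is a minimal, self-contained sketch of that pattern; the toy `model`, `guide`, data value, and `param_map` are illustrative assumptions, while `log_density(fn, args, kwargs, params)` and its `(log_joint, trace)` return value match `numpyro.infer.util.log_density`.

# Minimal sketch of the replay-based ELBO pattern; the toy model/guide
# below are assumptions, not taken from the examples above.
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import replay, seed
from numpyro.infer.util import log_density

def model(data):
    z = numpyro.sample('z', dist.Normal(0., 1.))
    numpyro.sample('obs', dist.Normal(z, 1.), obs=data)

def guide(data):
    loc = numpyro.param('loc', 0.)
    numpyro.sample('z', dist.Normal(loc, 1.))

param_map = {'loc': 0.1}
model_seed, guide_seed = random.split(random.PRNGKey(0))
guide_log_density, guide_trace = log_density(
    seed(guide, guide_seed), (0.5,), {}, param_map)
model_log_density, _ = log_density(
    replay(seed(model, model_seed), guide_trace), (0.5,), {}, param_map)
elbo = model_log_density - guide_log_density  # log p(x, z) - log q(z)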
Example #5: a single-particle ELBO that substitutes into the model only the params absent from the guide trace.
        def single_particle_elbo(rng_key):
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
            # NB: we only want to substitute params not available in guide_trace
            model_param_map = {k: v for k, v in param_map.items() if k not in guide_trace}
            seeded_model = replay(seeded_model, guide_trace)
            model_log_density, _ = log_density(seeded_model, args, kwargs, model_param_map)

            # log p(z) - log q(z)
            elbo = model_log_density - guide_log_density
            return elbo
Example #6: a test that a `-inf` factor under an all-`False` mask contributes zero to the log density.
def test_mask_inf():
    def model():
        with handlers.mask(mask=jnp.zeros(10, dtype=bool)):
            numpyro.factor('inf', -jnp.inf)

    log_joint = log_density(model, (), {}, {})[0]
    assert_allclose(log_joint, 0.)
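
The zero result holds because `handlers.mask` replaces a masked site's log probability rather than multiplying it, so even a `-inf` factor contributes nothing. A small sketch under the same imports as the tests above (`masked_model` is a hypothetical helper):

def masked_model(flag):
    # when flag is False, the site's log probability is dropped entirely
    with handlers.mask(mask=flag):
        numpyro.sample('x', dist.Normal(0., 1.), obs=2.)

unmasked = log_density(masked_model, (True,), {}, {})[0]  # == Normal(0, 1).log_prob(2.)
masked = log_density(masked_model, (False,), {}, {})[0]   # == 0.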
Example #7: an HMCECS test comparing subsampled potential energies against full-data `log_density` values.
def test_estimate_likelihood(kernel_cls):
    data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = 0.1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (10_000,)
    )
    n, _ = data.shape
    num_warmup = 200
    num_samples = 200
    num_blocks = 20

    def model(data):
        mean = numpyro.sample(
            "mean", dist.Normal(ref_params, jnp.ones_like(ref_params))
        )
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-2) as idx:
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])

    proxy_fn = HMCECS.taylor_proxy({"mean": ref_params})
    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.run(random.PRNGKey(0), data, extra_fields=["hmc_state.potential_energy"])

    pes = mcmc.get_extra_fields()["hmc_state.potential_energy"]
    samples = mcmc.get_samples()
    pes_full = vmap(
        lambda sample: log_density(
            model, (data,), {}, {**sample, "N": jnp.arange(n)}
        )[0]
    )(samples)

    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.0
Example #8: the prior log density evaluated with empty dummy subsample indices substituted for the subsample plates.
    def log_prior(params):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            dummy_subsample = {
                k: jnp.array([], dtype=jnp.int32)
                for k in subsample_plate_sizes
            }
            with block(), substitute(data=dummy_subsample):
                # NB: `log_density` takes (model, model_args, model_kwargs, params)
                prior_prob, _ = log_density(model, model_args, model_kwargs,
                                            params)
        return prior_prob
Example #9: a minibatch test checking that the data term is scaled by `num_obs_total / batch_size`.
    def run_minibatch_test_for_batch_size(self, batch_size):
        batch = self.X[:batch_size]
        self.assertEqual(batch_size, jnp.shape(batch)[0])

        prior_log_prob = dist.Normal(1.).log_prob(self.mu)
        data_log_prob = jnp.sum(dist.Normal(self.mu).log_prob(batch))
        expected_log_joint = prior_log_prob + (self.num_samples /
                                               batch_size) * data_log_prob

        log_joint, _ = log_density(self.model, (batch, ),
                                   {'num_obs_total': self.num_samples},
                                   {"theta": self.mu})
        self.assertTrue(jnp.allclose(expected_log_joint, log_joint))
Example #10: a Metropolis-Hastings-style `sample` step that uses `log_density` in the acceptance ratio.
    def sample(self):
        rng_key, rng_key_sample, rng_key_accept = split(
            self.nmc_status.rng_key, 3)
        params = self.nmc_status.params
        # start from the stored log likelihood and carry accepted updates
        # across sites (re-reading it inside the loop would discard them)
        ll_current = self.nmc_status.log_likelihood

        for site in params.keys():
            # Collect accepted trace
            for i in range(len(params[site])):
                self.acc_trace[site + str(i)].append(params[site][i])

            tr_current = trace(substitute(self.model, params)).get_trace(
                *self.model_args, **self.model_kwargs)

            val_current = tr_current[site]["value"]
            dist_curr = tr_current[site]["fn"]

            def log_den_fun(var):
                return log_density(self.model, self.model_args,
                                   self.model_kwargs, var)[0]

            val_proposal, dist_proposal = self.proposal(
                site, log_den_fun, self.get_params(tr_current), dist_curr,
                rng_key_sample)

            tr_proposal = self.retrace(site, tr_current, dist_proposal,
                                       val_proposal, self.model_args,
                                       self.model_kwargs)
            ll_proposal = log_density(self.model, self.model_args,
                                      self.model_kwargs,
                                      self.get_params(tr_proposal))[0]

            # Hastings correction: score each value under the other kernel's
            # proposal distribution
            ll_proposal_val = dist_proposal.log_prob(val_current).sum()
            ll_current_val = dist_curr.log_prob(val_proposal).sum()

            hastings_ratio = (ll_proposal + ll_proposal_val) - \
                (ll_current + ll_current_val)

            accept_prob = np.minimum(1, np.exp(hastings_ratio))
            u = sample("u", dist.Uniform(0, 1), rng_key=rng_key_accept)

            if u <= accept_prob:
                params, ll_current = self.get_params(tr_proposal), ll_proposal
            else:
                params, ll_current = self.get_params(tr_current), ll_current

        num_iter = self.nmc_status.i + 1  # avoid shadowing the builtin `iter`
        mean_accept_prob = self.nmc_status.accept_prob + \
            (accept_prob - self.nmc_status.accept_prob) / num_iter

        return NMC_STATUS(num_iter, params, ll_current, mean_accept_prob, rng_key)
Example #11: a test that `handlers.scale`, used as a context manager or applied to the whole model, scales `log_density` accordingly.
def test_scale(use_context_manager):
    def model(data):
        x = numpyro.sample('x', dist.Normal(0, 1))
        with optional(use_context_manager, handlers.scale(scale=10)):
            numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    model = model if use_context_manager else handlers.scale(model, 10.)
    data = random.normal(random.PRNGKey(0), (3, ))
    x = random.normal(random.PRNGKey(1))
    log_joint = log_density(model, (data, ), {}, {'x': x})[0]
    log_prob1 = dist.Normal(0, 1).log_prob(x)
    log_prob2 = dist.Normal(x, 1).log_prob(data).sum()
    expected = (log_prob1 + 10 * log_prob2 if use_context_manager
                else 10 * (log_prob1 + log_prob2))
    assert_allclose(log_joint, expected)
Example #12: a sampler `__init__` that traces the model and computes the initial log likelihood with `log_density`.
    def __init__(self, model, *model_args, rng_key=PRNGKey(0), **model_kwargs):
        self.model = model
        self.model_args = model_args
        self.rng_key = rng_key
        self.model_kwargs = model_kwargs

        # seed the model so sample sites can draw initial values, and unpack
        # the stored args/kwargs when tracing (assumes `seed` is imported
        # from numpyro.handlers alongside `trace`)
        tr = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs)
        log_likelihood = log_density(self.model, self.model_args,
                                     self.model_kwargs, self.get_params(tr))[0]

        self.nmc_status = NMC_STATUS(i=0,
                                     params=self.get_params(tr),
                                     log_likelihood=log_likelihood,
                                     accept_prob=0.,
                                     rng_key=rng_key)

        self.props = {}
        self.acc_trace = {}
        self.init_trace()