示例#1
0
def _create_traces_check_error_string(model, guide, expected_string):
    """Trace ``model`` and ``guide`` and assert that matching them fails.

    Each callable is wrapped in ``numpyro.handlers.seed`` with ``rng_seed=42``
    and traced with no arguments (model first, then guide).
    ``check_model_guide_match`` is then expected to raise a ``ValueError``
    whose message matches ``expected_string`` (used as the ``match`` pattern
    of ``pytest.raises``).
    """
    model_trace, guide_trace = [
        numpyro.handlers.trace(
            numpyro.handlers.seed(fn, rng_seed=42)).get_trace()
        for fn in (model, guide)
    ]
    with pytest.raises(ValueError, match=expected_string):
        check_model_guide_match(model_trace, guide_trace)
示例#2
0
        def single_particle_elbo(rng_key):
            """Compute one ELBO particle, using analytic KL terms where available.

            The seeded guide is run with ``param_map`` substituted. The seeded
            model is then replayed against the guide trace, with ``param_map``
            plus any guide-produced mutable state substituted. The particle is
            accumulated from three kinds of site:

            * Observed model sample sites: ``+ log p``.
            * Latent model sample sites: ``- KL(q || p)`` (scaled by the guide
              site's ``scale``). If ``kl_divergence`` raises
              ``NotImplementedError``, the single-sample estimate
              ``+ log p(z) - log q(z)`` is used instead.
            * Guide sample sites that do not appear in the model: ``- log q``.

            :param rng_key: PRNG key, split into separate model and guide seeds.
            :return: ``(elbo_particle, mutable_params)``. ``mutable_params`` is
                ``None`` when neither trace contains a ``"mutable"`` site.
            :raises ValueError: if mutable sites exist while
                ``self.num_particles != 1``, or if a guide sample site absent
                from the model is neither observed nor marked
                ``is_auxiliary`` in its ``infer`` dict.
            """
            params = param_map.copy()
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            subs_guide = substitute(seeded_guide, data=param_map)
            guide_trace = trace(subs_guide).get_trace(*args, **kwargs)
            mutable_params = {
                name: site["value"]
                for name, site in guide_trace.items()
                if site["type"] == "mutable"
            }
            # The model sees the current params plus any mutable state
            # written by the guide.
            params.update(mutable_params)
            subs_model = substitute(replay(seeded_model, guide_trace),
                                    data=params)
            model_trace = trace(subs_model).get_trace(*args, **kwargs)
            mutable_params.update({
                name: site["value"]
                for name, site in model_trace.items()
                if site["type"] == "mutable"
            })
            check_model_guide_match(model_trace, guide_trace)
            _validate_model(model_trace, plate_warning="loose")
            _check_mean_field_requirement(model_trace, guide_trace)

            elbo_particle = 0
            for name, model_site in model_trace.items():
                if model_site["type"] != "sample":
                    continue
                if model_site["is_observed"]:
                    elbo_particle = elbo_particle + _get_log_prob_sum(
                        model_site)
                    continue
                guide_site = guide_trace[name]
                # The fallback is intended for a missing KL implementation,
                # so only the kl_divergence call sits inside the try. A
                # NotImplementedError raised later must not be mistaken for it.
                try:
                    kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
                except NotImplementedError:
                    # No closed-form KL: use the single-sample estimate
                    # log p(z) - log q(z) instead.
                    elbo_particle = (elbo_particle +
                                     _get_log_prob_sum(model_site) -
                                     _get_log_prob_sum(guide_site))
                else:
                    kl_qp = scale_and_mask(kl_qp, scale=guide_site["scale"])
                    elbo_particle = elbo_particle - jnp.sum(kl_qp)

            # Guide sample sites with no model counterpart contribute -log q.
            for name, site in guide_trace.items():
                if site["type"] != "sample" or name in model_trace:
                    continue
                # Explicit check rather than `assert`, so the validation is
                # not stripped when Python runs with -O.
                if not (site["infer"].get("is_auxiliary")
                        or site["is_observed"]):
                    raise ValueError(
                        f"Guide sample site '{name}' does not appear in the "
                        "model; it must be observed or marked with "
                        "infer={'is_auxiliary': True}.")
                elbo_particle = elbo_particle - _get_log_prob_sum(site)

            if not mutable_params:
                return elbo_particle, None
            if self.num_particles != 1:
                raise ValueError(
                    "Currently, we only support mutable states with num_particles=1."
                )
            return elbo_particle, mutable_params
示例#3
0
        def single_particle_elbo(rng_key):
            """Return one ELBO sample ``log p(z) - log q(z)`` plus mutable state.

            The guide is evaluated with ``param_map``. The model is replayed
            against the resulting guide trace and evaluated with ``param_map``
            overlaid by any ``"mutable"`` site values the guide produced.
            Mutable values from both traces are collected, with model values
            updating guide values.

            :param rng_key: PRNG key, split into separate model and guide seeds.
            :return: ``(elbo_particle, mutable_params)``. ``mutable_params`` is
                ``None`` when no ``"mutable"`` site was recorded.
            :raises ValueError: if mutable sites exist while
                ``self.num_particles != 1``.
            """
            def _collect_mutable(tr):
                # Map each "mutable" site name to its recorded value.
                return {
                    site_name: site["value"]
                    for site_name, site in tr.items()
                    if site["type"] == "mutable"
                }

            model_key, guide_key = random.split(rng_key)
            model_fn = seed(model, model_key)
            guide_fn = seed(guide, guide_key)

            q_log_density, q_trace = log_density(
                guide_fn, args, kwargs, param_map)
            state = _collect_mutable(q_trace)

            model_params = param_map.copy()
            model_params.update(state)
            p_log_density, p_trace = log_density(
                replay(model_fn, q_trace), args, kwargs, model_params)

            check_model_guide_match(p_trace, q_trace)
            _validate_model(p_trace, plate_warning="loose")
            state.update(_collect_mutable(p_trace))

            # log p(z) - log q(z)
            elbo_particle = p_log_density - q_log_density
            if not state:
                return elbo_particle, None
            if self.num_particles != 1:
                raise ValueError(
                    "Currently, we only support mutable states with num_particles=1."
                )
            return elbo_particle, state
示例#4
0
        def single_particle_elbo(rng_key):
            """Return one ELBO sample ``log p(z) - log q(z)``.

            The guide is evaluated with ``param_map``. The model is replayed
            against the guide trace and evaluated only with the entries of
            ``param_map`` whose names do not appear in the guide trace.

            :param rng_key: PRNG key, split into separate model and guide seeds.
            :return: the scalar ELBO sample.
            """
            model_key, guide_key = random.split(rng_key)
            model_fn = seed(model, model_key)
            guide_fn = seed(guide, guide_key)

            q_log_prob, q_trace = log_density(guide_fn, args, kwargs, param_map)

            # Substitute into the model only the params the guide trace does
            # not already contain.
            unseen_params = {}
            for param_name, param_value in param_map.items():
                if param_name not in q_trace:
                    unseen_params[param_name] = param_value

            p_log_prob, p_trace = log_density(
                replay(model_fn, q_trace), args, kwargs, unseen_params)
            check_model_guide_match(p_trace, q_trace)
            _validate_model(p_trace, plate_warning="loose")

            # log p(z) - log q(z)
            return p_log_prob - q_log_prob
示例#5
0
        def single_particle_elbo(rng_key):
            """Return one ELBO sample with score-function surrogates.

            Model and guide traces come from ``get_importance_trace``. The ELBO
            is the sum of ``log_prob`` over model sample sites minus the sum
            over guide sample sites.

            For latent guide sites whose distribution lacks ``has_rsample``,
            the guide term is rewritten as
            ``stop_gradient(log_q + surrogate) - surrogate``. This has the same
            forward value, but its gradient comes only from ``surrogate``, which
            weights ``log_prob`` by the stop-gradient downstream cost of that
            site.

            :param rng_key: PRNG key, split into separate model and guide seeds.
            :return: the scalar ELBO sample.
            """
            model_key, guide_key = random.split(rng_key)
            model_trace, guide_trace = get_importance_trace(
                seed(model, model_key), seed(guide, guide_key),
                args, kwargs, param_map)
            check_model_guide_match(model_trace, guide_trace)
            _validate_model(model_trace, plate_warning="strict")

            # XXX: different from Pyro, we don't support baseline_loss here
            non_reparam_nodes = set()
            for site_name, site in guide_trace.items():
                if (site["type"] == "sample" and not site["is_observed"]
                        and not site["fn"].has_rsample):
                    non_reparam_nodes.add(site_name)

            # Only read for names in non_reparam_nodes, so an empty default is
            # never consulted when there are no such nodes.
            downstream_costs = {}
            if non_reparam_nodes:
                downstream_costs, _ = _compute_downstream_costs(
                    model_trace, guide_trace, non_reparam_nodes)

            def _guide_term(site_name, site):
                # Guide contribution to subtract, with a surrogate for
                # non-reparameterizable sites.
                log_q = jnp.sum(site["log_prob"])
                if site_name not in non_reparam_nodes:
                    return log_q
                surrogate = jnp.sum(
                    site["log_prob"] *
                    stop_gradient(downstream_costs[site_name]))
                return stop_gradient(log_q + surrogate) - surrogate

            elbo = sum(
                (jnp.sum(site["log_prob"])
                 for site in model_trace.values() if site["type"] == "sample"),
                0.0)
            for site_name, site in guide_trace.items():
                if site["type"] == "sample":
                    elbo = elbo - _guide_term(site_name, site)
            return elbo