Code Example #1
    def init(self, rng_key, *args, **kwargs):
        """
        Gets the initial SVI state.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        """
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs,
                                                  **self.static_kwargs)
        model_trace = trace(replay(model_init, guide_trace)).get_trace(
            *args, **kwargs, **self.static_kwargs)
        params = {}
        inv_transforms = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site['type'] == 'param':
                constraint = site['kwargs'].pop('constraint', constraints.real)
                transform = biject_to(constraint)
                inv_transforms[site['name']] = transform
                params[site['name']] = transform.inv(site['value'])

        self.constrain_fn = partial(transform_fn, inv_transforms)
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        params = tree_map(
            lambda x: lax.convert_element_type(x, jnp.result_type(x)), params)
        return SVIState(self.optim.init(params), rng_key)
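
For context, here is a minimal sketch of how this init method is reached through NumPyro's public SVI API. The toy model/guide pair and the hyperparameters below are illustrative assumptions, not part of the snippet above.

# Minimal usage sketch (assumed toy model/guide; recent NumPyro API).
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

def model(data):
    # prior over the latent location
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    # variational parameters registered via numpyro.param
    loc_q = numpyro.param("loc_q", 0.0)
    scale_q = numpyro.param("scale_q", 1.0, constraint=constraints.positive)
    numpyro.sample("loc", dist.Normal(loc_q, scale_q))

svi = SVI(model, guide, Adam(step_size=1e-2), Trace_ELBO())
# init traces the model and guide once to collect params, as shown above
svi_state = svi.init(jax.random.PRNGKey(0), jnp.ones(10))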
Code Example #2
File: elbo.py Project: xidulu/numpyro
        def single_particle_elbo(rng_key):
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            subs_guide = substitute(seeded_guide, data=param_map)
            guide_trace = trace(subs_guide).get_trace(*args, **kwargs)
            subs_model = substitute(replay(seeded_model, guide_trace), data=param_map)
            model_trace = trace(subs_model).get_trace(*args, **kwargs)
            _check_mean_field_requirement(model_trace, guide_trace)

            elbo_particle = 0
            for name, model_site in model_trace.items():
                if model_site["type"] == "sample":
                    if model_site["is_observed"]:
                        elbo_particle = elbo_particle + _get_log_prob_sum(model_site)
                    else:
                        guide_site = guide_trace[name]
                        try:
                            kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
                            kl_qp = scale_and_mask(kl_qp, scale=guide_site["scale"])
                            elbo_particle = elbo_particle - jnp.sum(kl_qp)
                        except NotImplementedError:
                            elbo_particle = elbo_particle + _get_log_prob_sum(model_site) \
                                - _get_log_prob_sum(guide_site)

            # handle auxiliary sites in the guide
            for name, site in guide_trace.items():
                if site["type"] == "sample" and name not in model_trace:
                    assert site["infer"].get("is_auxiliary")
                    elbo_particle = elbo_particle - _get_log_prob_sum(site)

            return elbo_particle
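
For reference, the per-particle quantity assembled in the loop above is the mean-field decomposition of the ELBO (a standard identity, stated here in LaTeX):

    \mathrm{ELBO} = \sum_{x \in \mathrm{observed}} \mathbb{E}_{q}\bigl[\log p(x \mid z)\bigr] - \sum_{z \in \mathrm{latent}} \mathrm{KL}\bigl(q(z) \,\|\, p(z)\bigr)

At sites where no analytic KL is registered, the NotImplementedError branch falls back to the single-sample estimate \log p(z) - \log q(z) of -\mathrm{KL}.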
Code Example #3
def get_importance_trace(model, guide, args, kwargs, params):
    """
    (EXPERIMENTAL) Returns traces from the guide and the model that is run against it.
    The returned traces also store the log probability at each site.

    .. note:: Gradients are blocked at latent sites which do not have reparametrized samplers.
    """
    guide = substitute(guide, data=params)
    with _without_rsample_stop_gradient():
        guide_trace = trace(guide).get_trace(*args, **kwargs)
    model = substitute(replay(model, guide_trace), data=params)
    model_trace = trace(model).get_trace(*args, **kwargs)
    for tr in (guide_trace, model_trace):
        for site in tr.values():
            if site["type"] == "sample":
                if "log_prob" not in site:
                    value = site["value"]
                    intermediates = site["intermediates"]
                    scale = site["scale"]
                    if intermediates:
                        log_prob = site["fn"].log_prob(value, intermediates)
                    else:
                        log_prob = site["fn"].log_prob(value)

                    if (scale is not None) and (not is_identically_one(scale)):
                        log_prob = scale * log_prob
                    site["log_prob"] = log_prob
    return model_trace, guide_trace
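
Since both returned traces carry precomputed log_prob entries, a single-sample log importance weight can be assembled directly from them. A minimal sketch; model, guide, args, kwargs, and params are assumed to be defined as in the surrounding code:

# Hypothetical usage: log w = log p(x, z) - log q(z), summed over sample sites.
model_trace, guide_trace = get_importance_trace(model, guide, args, kwargs, params)
log_weight = sum(
    site["log_prob"].sum()
    for site in model_trace.values()
    if site["type"] == "sample"
) - sum(
    site["log_prob"].sum()
    for site in guide_trace.values()
    if site["type"] == "sample"
)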
Code Example #4
File: svi.py Project: vanAmsterdam/numpyro
    def init(self, rng_key, *args, **kwargs):
        """

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        """
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
        model_trace = trace(replay(model_init, guide_trace)).get_trace(*args, **kwargs, **self.static_kwargs)
        params = {}
        inv_transforms = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site['type'] == 'param':
                constraint = site['kwargs'].pop('constraint', constraints.real)
                transform = biject_to(constraint)
                inv_transforms[site['name']] = transform
                params[site['name']] = transform.inv(site['value'])

        self.constrain_fn = partial(transform_fn, inv_transforms)
        return SVIState(self.optim.init(params), rng_key)
Code Example #5
File: elbo.py Project: hessammehr/numpyro
        def single_particle_elbo(rng_key):
            params = param_map.copy()
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            guide_log_density, guide_trace = log_density(
                seeded_guide, args, kwargs, param_map)
            mutable_params = {
                name: site["value"]
                for name, site in guide_trace.items()
                if site["type"] == "mutable"
            }
            params.update(mutable_params)
            seeded_model = replay(seeded_model, guide_trace)
            model_log_density, model_trace = log_density(
                seeded_model, args, kwargs, params)
            check_model_guide_match(model_trace, guide_trace)
            _validate_model(model_trace, plate_warning="loose")
            mutable_params.update({
                name: site["value"]
                for name, site in model_trace.items()
                if site["type"] == "mutable"
            })

            # log p(z) - log q(z)
            elbo_particle = model_log_density - guide_log_density
            if mutable_params:
                if self.num_particles == 1:
                    return elbo_particle, mutable_params
                else:
                    raise ValueError(
                        "Currently, we only support mutable states with num_particles=1."
                    )
            else:
                return elbo_particle, None
Code Example #6
File: elbo.py Project: hessammehr/numpyro
        def single_particle_elbo(rng_key):
            params = param_map.copy()
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            subs_guide = substitute(seeded_guide, data=param_map)
            guide_trace = trace(subs_guide).get_trace(*args, **kwargs)
            mutable_params = {
                name: site["value"]
                for name, site in guide_trace.items()
                if site["type"] == "mutable"
            }
            params.update(mutable_params)
            subs_model = substitute(replay(seeded_model, guide_trace),
                                    data=params)
            model_trace = trace(subs_model).get_trace(*args, **kwargs)
            mutable_params.update({
                name: site["value"]
                for name, site in model_trace.items()
                if site["type"] == "mutable"
            })
            check_model_guide_match(model_trace, guide_trace)
            _validate_model(model_trace, plate_warning="loose")
            _check_mean_field_requirement(model_trace, guide_trace)

            elbo_particle = 0
            for name, model_site in model_trace.items():
                if model_site["type"] == "sample":
                    if model_site["is_observed"]:
                        elbo_particle = elbo_particle + _get_log_prob_sum(
                            model_site)
                    else:
                        guide_site = guide_trace[name]
                        try:
                            kl_qp = kl_divergence(guide_site["fn"],
                                                  model_site["fn"])
                            kl_qp = scale_and_mask(kl_qp,
                                                   scale=guide_site["scale"])
                            elbo_particle = elbo_particle - jnp.sum(kl_qp)
                        except NotImplementedError:
                            elbo_particle = (elbo_particle +
                                             _get_log_prob_sum(model_site) -
                                             _get_log_prob_sum(guide_site))

            # handle auxiliary sites in the guide
            for name, site in guide_trace.items():
                if site["type"] == "sample" and name not in model_trace:
                    assert site["infer"].get(
                        "is_auxiliary") or site["is_observed"]
                    elbo_particle = elbo_particle - _get_log_prob_sum(site)

            if mutable_params:
                if self.num_particles == 1:
                    return elbo_particle, mutable_params
                else:
                    raise ValueError(
                        "Currently, we only support mutable states with num_particles=1."
                    )
            else:
                return elbo_particle, None
Code Example #7
def elbo(param_map, model, guide, model_args, guide_args, kwargs):
    """
    This is the most basic implementation of the Evidence Lower Bound, which is the
    fundamental objective in Variational Inference. This implementation has various
    limitations (for example it only supports random variables with reparameterized
    samplers) but can be used as a template to build more sophisticated loss
    objectives.

    For more details, refer to http://pyro.ai/examples/svi_part_i.html.

    :param dict param_map: dictionary of current parameter values keyed by site
        name.
    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide
        (recognition network).
    :param tuple model_args: arguments to the model (these can possibly vary during
        the course of fitting).
    :param tuple guide_args: arguments to the guide (these can possibly vary during
        the course of fitting).
    :param dict kwargs: static keyword arguments to the model / guide.
    :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
    """
    guide_log_density, guide_trace = log_density(guide, guide_args, kwargs,
                                                 param_map)
    model_log_density, _ = log_density(replay(model, guide_trace), model_args,
                                       kwargs, param_map)
    # log p(z) - log q(z)
    elbo = model_log_density - guide_log_density
    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    return -elbo
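
In symbols, the returned loss is a single-sample Monte Carlo estimate of the negative ELBO (in LaTeX):

    -\widehat{\mathrm{ELBO}} = -\bigl(\log p(x, z) - \log q(z)\bigr), \qquad z \sim q(z)

The expectation of \log p(x, z) - \log q(z) under q lower-bounds the evidence \log p(x), which is why the ELBO is maximized (equivalently, the returned negative is minimized).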
Code Example #8
File: svi.py Project: leej35/numpyro
def get_param(opt_state,
              model,
              guide,
              get_params,
              constrain_fn,
              rng,
              model_args=None,
              guide_args=None,
              **kwargs):
    # map the optimizer's unconstrained values back to their constrained domains
    params = constrain_fn(get_params(opt_state))
    model, guide = _seed(model, guide, rng)
    if guide_args is not None:
        guide = substitute(guide, base_param_map=params)
        guide_trace = trace(guide).get_trace(*guide_args, **kwargs)
        # params that never appear in the guide trace belong to the model
        model_params = {
            k: v
            for k, v in params.items() if k not in guide_trace
        }
        # prefer values recorded in the guide trace over the raw params
        params = {
            k: guide_trace[k]['value'] if k in guide_trace else v
            for k, v in params.items()
        }

        if model_args is not None:
            model = substitute(replay(model, guide_trace),
                               base_param_map=model_params)
            model_trace = trace(model).get_trace(*model_args, **kwargs)
            # likewise, prefer model-trace values for model-only params
            params = {
                k: model_trace[k]['value'] if k in model_params else v
                for k, v in params.items()
            }

    return params
Code Example #9
File: svi.py Project: ColCarroll/numpyro
def elbo(param_map, model, guide, model_args, guide_args, kwargs):
    guide_log_density, guide_trace = log_density(guide, guide_args, kwargs, param_map)
    model_log_density, _ = log_density(replay(model, guide_trace), model_args, kwargs, param_map)
    # log p(z) - log q(z)
    elbo = model_log_density - guide_log_density
    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    return -elbo
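
Because the returned loss is differentiable in param_map, it can be minimized directly with jax.grad. A minimal sketch; model, guide, data, and the initial param_map are illustrative assumptions:

import jax
import jax.tree_util as jtu

# one plain gradient-descent step on the negative ELBO
loss_fn = lambda params: elbo(params, model, guide, (data,), (data,), {})
grads = jax.grad(loss_fn)(param_map)
param_map = jtu.tree_map(lambda p, g: p - 1e-2 * g, param_map, grads)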
Code Example #10
File: svi.py Project: leej35/numpyro
def elbo(param_map,
         model,
         guide,
         model_args,
         guide_args,
         kwargs,
         constrain_fn,
         is_autoguide=False):
    """
    This is the most basic implementation of the Evidence Lower Bound, which is the
    fundamental objective in Variational Inference. This implementation has various
    limitations (for example it only supports random variables with reparameterized
    samplers) but can be used as a template to build more sophisticated loss
    objectives.

    For more details, refer to http://pyro.ai/examples/svi_part_i.html.

    :param dict param_map: dictionary of current parameter values keyed by site
        name.
    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide
        (recognition network).
    :param tuple model_args: arguments to the model (these can possibly vary during
        the course of fitting).
    :param tuple guide_args: arguments to the guide (these can possibly vary during
        the course of fitting).
    :param dict kwargs: static keyword arguments to the model / guide.
    :param constrain_fn: a callable that transforms unconstrained parameter values
        from the optimizer to the specified constrained domain.
    :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
    """
    param_map = constrain_fn(param_map)
    guide_log_density, guide_trace = log_density(guide, guide_args, kwargs,
                                                 param_map)
    if is_autoguide:
        # in an autoguide, a site's value holds an intermediate value
        for name, site in guide_trace.items():
            if site['type'] == 'sample':
                param_map[name] = site['value']
    else:
        # NB: we only want to substitute params not available in guide_trace
        param_map = {
            k: v
            for k, v in param_map.items() if k not in guide_trace
        }
        model = replay(model, guide_trace)
    model_log_density, _ = log_density(model,
                                       model_args,
                                       kwargs,
                                       param_map,
                                       skip_dist_transforms=is_autoguide)
    # log p(z) - log q(z)
    elbo = model_log_density - guide_log_density
    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    return -elbo
Code Example #11
File: test_infer_util.py Project: mjbajwa/numpyro
def test_get_mask_optimization():
    def model():
        with numpyro.handlers.seed(rng_seed=0):
            x = numpyro.sample("x", dist.Normal(0, 1))
            numpyro.sample("y", dist.Normal(x, 1), obs=0.)
            called.add("model-always")
            if numpyro.get_mask() is not False:
                called.add("model-sometimes")
                numpyro.factor("f", x + 1)

    def guide():
        with numpyro.handlers.seed(rng_seed=1):
            x = numpyro.sample("x", dist.Normal(0, 1))
            called.add("guide-always")
            if numpyro.get_mask() is not False:
                called.add("guide-sometimes")
                numpyro.factor("g", 2 - x)

    called = set()
    trace = handlers.trace(guide).get_trace()
    handlers.replay(model, trace)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" in called
    assert "guide-sometimes" in called

    called = set()
    with handlers.mask(mask=False):
        trace = handlers.trace(guide).get_trace()
        handlers.replay(model, trace)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" not in called
    assert "guide-sometimes" not in called

    called = set()
    Predictive(model, guide=guide, num_samples=2,
               parallel=True)(random.PRNGKey(2))
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" not in called
    assert "guide-sometimes" not in called
Code Example #12
File: elbo.py Project: xidulu/numpyro
        def single_particle_elbo(rng_key):
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
            seeded_model = replay(seeded_model, guide_trace)
            model_log_density, _ = log_density(seeded_model, args, kwargs, param_map)

            # log p(z) - log q(z)
            elbo = model_log_density - guide_log_density
            return elbo
Code Example #13
File: test_handlers.py Project: mhashemi0873/numpyro
def test_subsample_replay():
    data = jnp.arange(100.)
    subsample_size = 7

    with handlers.trace() as guide_trace, handlers.seed(rng_seed=0):
        with numpyro.plate("a", len(data), subsample_size=subsample_size):
            pass

    with handlers.seed(rng_seed=1), handlers.replay(guide_trace=guide_trace):
        with numpyro.plate("a", len(data)):
            subsample_data = numpyro.subsample(data, event_dim=0)
            assert subsample_data.shape == (subsample_size, )
Code Example #14
File: svi.py Project: dirmeier/numpyro
    def init(self, rng_key, *args, **kwargs):
        """
        Gets the initial SVI state.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        """
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
        model_trace = trace(replay(model_init, guide_trace)).get_trace(
            *args, **kwargs, **self.static_kwargs
        )
        params = {}
        inv_transforms = {}
        mutable_state = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site["type"] == "param":
                constraint = site["kwargs"].pop("constraint", constraints.real)
                with helpful_support_errors(site):
                    transform = biject_to(constraint)
                inv_transforms[site["name"]] = transform
                params[site["name"]] = transform.inv(site["value"])
            elif site["type"] == "mutable":
                mutable_state[site["name"]] = site["value"]
            elif (
                site["type"] == "sample"
                and (not site["is_observed"])
                and site["fn"].support.is_discrete
                and not self.loss.can_infer_discrete
            ):
                s_name = type(self.loss).__name__
                warnings.warn(
                    f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
                )

        if not mutable_state:
            mutable_state = None
        self.constrain_fn = partial(transform_fn, inv_transforms)
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        params, mutable_state = tree_map(
            lambda x: lax.convert_element_type(x, jnp.result_type(x)),
            (params, mutable_state),
        )
        return SVIState(self.optim.init(params), mutable_state, rng_key)
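
Sites of type "mutable" originate from the mutable primitive. The toy model below is an illustrative sketch (assuming numpyro.primitives.mutable, available in recent NumPyro releases) of how such non-learnable state ends up in the returned SVIState rather than in the optimizer:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.primitives import mutable
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

def model(data):
    # a "mutable" site: carried in the SVI state, never optimized
    stats = mutable("stats", {"count": jnp.array(0)})
    stats["count"] = stats["count"] + 1
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    loc_q = numpyro.param("loc_q", 0.0)
    numpyro.sample("loc", dist.Normal(loc_q, 1.0))

svi = SVI(model, guide, Adam(1e-2), Trace_ELBO())
svi_state = svi.init(jax.random.PRNGKey(0), jnp.ones(5))  # mutable state captured here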
Code Example #15
File: elbo.py Project: ziatdinovmax/numpyro
        def single_particle_elbo(rng_key):
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
            # NB: we only want to substitute params not available in guide_trace
            model_param_map = {k: v for k, v in param_map.items() if k not in guide_trace}
            seeded_model = replay(seeded_model, guide_trace)
            model_log_density, _ = log_density(seeded_model, args, kwargs, model_param_map)

            # log p(z) - log q(z)
            elbo = model_log_density - guide_log_density
            return elbo
Code Example #16
    def retrace(self, name, tr, dist_proposal, val_proposal, model_args,
                model_kwargs):
        # save the current distribution and value at the named site
        fn_current = tr[name]["fn"]
        val_current = tr[name]["value"]

        # temporarily swap in the proposal distribution and value
        tr[name]["fn"] = dist_proposal
        tr[name]["value"] = val_proposal

        # replay the model against the modified trace to obtain a second
        # trace conditioned on the proposal
        second_trace = trace(replay(self.model,
                                    tr)).get_trace(*model_args, **model_kwargs)

        # restore the original site so `tr` is left unchanged
        tr[name]["fn"] = fn_current
        tr[name]["value"] = val_current

        return second_trace
Code Example #17
def test_compute_downstream_costs_plate_reuse(dim1, dim2):
    seeded_guide = handlers.seed(plate_reuse_model_guide, rng_seed=0)
    guide_trace = handlers.trace(seeded_guide).get_trace(include_obs=False,
                                                         dim1=dim1,
                                                         dim2=dim2)
    model_trace = handlers.trace(handlers.replay(
        seeded_guide, guide_trace)).get_trace(include_obs=True,
                                              dim1=dim1,
                                              dim2=dim2)

    for trace in (model_trace, guide_trace):
        for site in trace.values():
            if site["type"] == "sample":
                site["log_prob"] = site["fn"].log_prob(site["value"])
    non_reparam_nodes = set(
        name for name, site in guide_trace.items()
        if site["type"] == "sample" and (
            site["is_observed"] or not site["fn"].has_rsample))

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)

    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    for k in dc:
        assert guide_trace[k]["log_prob"].shape == dc[k].shape
        assert_allclose(dc[k], dc_brute[k], rtol=1e-6)

    expected_c1 = model_trace["c1"]["log_prob"] - guide_trace["c1"]["log_prob"]
    expected_c1 += (model_trace["b1"]["log_prob"] -
                    guide_trace["b1"]["log_prob"]).sum()
    expected_c1 += model_trace["c2"]["log_prob"] - guide_trace["c2"]["log_prob"]
    expected_c1 += model_trace["obs"]["log_prob"]
    assert_allclose(expected_c1, dc["c1"], rtol=1e-6)
Code Example #18
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args,
                      **kwargs):

    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            list(terms["log_factors"].values()) +
            list(terms["log_measures"].values()),
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"],
                                     output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
Code Example #19
def test_compute_downstream_costs_big_model_guide_pair(include_inner_1,
                                                       include_single,
                                                       flip_c23,
                                                       include_triple,
                                                       include_z1):
    seeded_guide = handlers.seed(big_model_guide, rng_seed=0)
    guide_trace = handlers.trace(seeded_guide).get_trace(
        include_obs=False,
        include_inner_1=include_inner_1,
        include_single=include_single,
        flip_c23=flip_c23,
        include_triple=include_triple,
        include_z1=include_z1,
    )
    model_trace = handlers.trace(handlers.replay(
        seeded_guide, guide_trace)).get_trace(
            include_obs=True,
            include_inner_1=include_inner_1,
            include_single=include_single,
            flip_c23=flip_c23,
            include_triple=include_triple,
            include_z1=include_z1,
        )

    for trace in (model_trace, guide_trace):
        for site in trace.values():
            if site["type"] == "sample":
                site["log_prob"] = site["fn"].log_prob(site["value"])
    non_reparam_nodes = set(
        name for name, site in guide_trace.items()
        if site["type"] == "sample" and (
            site["is_observed"] or not site["fn"].has_rsample))

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)

    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    expected_nodes_full_model = {
        "a1": {"c2", "a1", "d1", "c1", "obs", "b1", "d2", "c3", "b0"},
        "d2": {"obs", "d2"},
        "d1": {"obs", "d1", "d2"},
        "c3": {"d2", "obs", "d1", "c3"},
        "b0": {"b0", "d1", "c1", "obs", "b1", "d2", "c3", "c2"},
        "b1": {"obs", "b1", "d1", "d2", "c3", "c1", "c2"},
        "c1": {"d1", "c1", "obs", "d2", "c3", "c2"},
        "c2": {"obs", "d1", "c3", "d2", "c2"},
    }
    if not include_triple and include_inner_1 and include_single and not flip_c23:
        assert dc_nodes == expected_nodes_full_model

    expected_b1 = model_trace["b1"]["log_prob"] - guide_trace["b1"]["log_prob"]
    expected_b1 += (model_trace["d2"]["log_prob"] -
                    guide_trace["d2"]["log_prob"]).sum(0)
    expected_b1 += (model_trace["d1"]["log_prob"] -
                    guide_trace["d1"]["log_prob"]).sum(0)
    expected_b1 += model_trace["obs"]["log_prob"].sum(0, keepdims=False)
    if include_inner_1:
        expected_b1 += (model_trace["c1"]["log_prob"] -
                        guide_trace["c1"]["log_prob"]).sum(0)
        expected_b1 += (model_trace["c2"]["log_prob"] -
                        guide_trace["c2"]["log_prob"]).sum(0)
        expected_b1 += (model_trace["c3"]["log_prob"] -
                        guide_trace["c3"]["log_prob"]).sum(0)
    assert_allclose(expected_b1, dc["b1"], atol=1.0e-6)

    if include_single:
        expected_b0 = model_trace["b0"]["log_prob"] - guide_trace["b0"][
            "log_prob"]
        expected_b0 += (model_trace["b1"]["log_prob"] -
                        guide_trace["b1"]["log_prob"]).sum()
        expected_b0 += (model_trace["d2"]["log_prob"] -
                        guide_trace["d2"]["log_prob"]).sum()
        expected_b0 += (model_trace["d1"]["log_prob"] -
                        guide_trace["d1"]["log_prob"]).sum()
        expected_b0 += model_trace["obs"]["log_prob"].sum()
        if include_inner_1:
            expected_b0 += (model_trace["c1"]["log_prob"] -
                            guide_trace["c1"]["log_prob"]).sum()
            expected_b0 += (model_trace["c2"]["log_prob"] -
                            guide_trace["c2"]["log_prob"]).sum()
            expected_b0 += (model_trace["c3"]["log_prob"] -
                            guide_trace["c3"]["log_prob"]).sum()
        assert_allclose(expected_b0, dc["b0"], atol=1.0e-6)
        assert dc["b0"].shape == (5, )

    if include_inner_1:
        expected_c3 = model_trace["c3"]["log_prob"] - guide_trace["c3"][
            "log_prob"]
        expected_c3 += (model_trace["d1"]["log_prob"] -
                        guide_trace["d1"]["log_prob"]).sum(0)
        expected_c3 += (model_trace["d2"]["log_prob"] -
                        guide_trace["d2"]["log_prob"]).sum(0)
        expected_c3 += model_trace["obs"]["log_prob"].sum(0)

        expected_c2 = model_trace["c2"]["log_prob"] - guide_trace["c2"][
            "log_prob"]
        expected_c2 += (model_trace["d1"]["log_prob"] -
                        guide_trace["d1"]["log_prob"]).sum(0)
        expected_c2 += (model_trace["d2"]["log_prob"] -
                        guide_trace["d2"]["log_prob"]).sum(0)
        expected_c2 += model_trace["obs"]["log_prob"].sum(0)

        expected_c1 = model_trace["c1"]["log_prob"] - guide_trace["c1"][
            "log_prob"]

        if flip_c23:
            expected_c3 += model_trace["c2"]["log_prob"] - guide_trace["c2"][
                "log_prob"]
            expected_c2 += model_trace["c3"]["log_prob"]
        else:
            expected_c2 += model_trace["c3"]["log_prob"] - guide_trace["c3"][
                "log_prob"]
            expected_c2 += model_trace["c2"]["log_prob"] - guide_trace["c2"][
                "log_prob"]
        expected_c1 += expected_c3

        assert_allclose(expected_c1, dc["c1"], atol=1.0e-6)
        assert_allclose(expected_c2, dc["c2"], atol=1.0e-6)
        assert_allclose(expected_c3, dc["c3"], atol=1.0e-6)

    expected_d1 = model_trace["d1"]["log_prob"] - guide_trace["d1"]["log_prob"]
    expected_d1 += model_trace["d2"]["log_prob"] - guide_trace["d2"]["log_prob"]
    expected_d1 += model_trace["obs"]["log_prob"]

    expected_d2 = model_trace["d2"]["log_prob"] - guide_trace["d2"]["log_prob"]
    expected_d2 += model_trace["obs"]["log_prob"]

    if include_triple:
        expected_z0 = (dc["a1"] + model_trace["z0"]["log_prob"] -
                       guide_trace["z0"]["log_prob"])
        assert_allclose(expected_z0, dc["z0"], atol=1.0e-6)
    assert_allclose(expected_d2, dc["d2"], atol=1.0e-6)
    assert_allclose(expected_d1, dc["d1"], atol=1.0e-6)

    assert dc["b1"].shape == (2, )
    assert dc["d2"].shape == (4, 2)

    for k in dc:
        assert guide_trace[k]["log_prob"].shape == dc[k].shape
        assert_allclose(dc[k], dc_brute[k], rtol=2e-7)