Пример #1
0
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args,
                      **kwargs):

    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            list(terms["log_factors"].values()) +
            list(terms["log_measures"].values()),
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"],
                                     output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
Пример #2
0
def log_density(model, model_args, model_kwargs, params):
    """
    Evaluate the joint log density of ``model``, using :mod:`funsor` to
    marginalize discrete latent sites.

    This plays the role of :func:`numpyro.infer.util.log_density` for models
    that contain discrete latent variables.

    :param model: Python callable containing NumPyro primitives. Usually the
        model has already been wrapped with the
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

    :param tuple model_args: positional arguments passed to the model.
    :param dict model_kwargs: keyword arguments passed to the model.
    :param dict params: current parameter values keyed by site name; they are
        substituted into the model before it is traced.
    :return: tuple of the joint log density and the corresponding model trace.
    :raises ValueError: if the reduced result still has inputs, i.e. the joint
        log density is not a scalar.
    """
    substituted = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(substituted).get_trace(
            *model_args, **model_kwargs)

    log_factors = []  # factors of sites without a "_time" dimension
    time_to_factors = defaultdict(list)  # log prob factors per time variable
    time_to_init_vars = defaultdict(frozenset)  # _init/... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars = frozenset()
    prod_vars = frozenset()

    for site in model_trace.values():
        if site["type"] != "sample":
            continue

        value = site["value"]
        intermediates = site["intermediates"]
        scale = site["scale"]

        log_prob_args = (value, intermediates) if intermediates else (value,)
        site_log_prob = site["fn"].log_prob(*log_prob_args)
        if not (scale is None or is_identically_one(scale)):
            site_log_prob = scale * site_log_prob

        dim_to_name = site["infer"]["dim_to_name"]
        factor = funsor.to_funsor(site_log_prob,
                                  output=funsor.reals(),
                                  dim_to_name=dim_to_name)

        # Only the first "_time*" dimension found is used; sites without one
        # fall through to the plain factor list.
        for dim, dim_name in dim_to_name.items():
            if not dim_name.startswith("_time"):
                continue
            time_var = funsor.Variable(
                dim_name, funsor.domains.bint(value.shape[dim]))
            time_to_factors[time_var].append(factor)
            init_names = frozenset(
                other for other in dim_to_name.values()
                if other.startswith("_init"))
            time_to_init_vars[time_var] = (
                time_to_init_vars[time_var] | init_names)
            break
        else:
            log_factors.append(factor)

        if not site["is_observed"]:
            sum_vars = sum_vars.union((site["name"],))
        prod_vars = prod_vars.union(
            frame.name for frame in site["cond_indep_stack"]
            if frame.dim is not None)

    for time_var, init_vars in time_to_init_vars.items():
        for init_var in init_vars:
            # Drop the leading "<prefix>/" segment to get the current site name.
            _, _, current_name = init_var.partition("/")
            current_dims = model_trace[current_name]["infer"]["dim_to_name"]
            # i.e. _init (i.e. prev) appears among the current site's dims
            if init_var in current_dims.values():
                time_to_markov_dims[time_var] = (
                    time_to_markov_dims[time_var]
                    | frozenset(current_dims.values()))

    if time_to_factors:
        log_factors = log_factors + compute_markov_factors(
            time_to_factors, time_to_init_vars, time_to_markov_dims,
            sum_vars, prod_vars)

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        lazy_result = funsor.sum_product.sum_product(
            funsor.ops.logaddexp,
            funsor.ops.add,
            log_factors,
            eliminate=sum_vars | prod_vars,
            plates=prod_vars)
    result = funsor.optimizer.apply_optimizer(lazy_result)

    if len(result.inputs) == 0:
        return result.data, model_trace

    result_shape = result.data.shape
    leftover_sites = {key.split("__BOUND")[0] for key in result.inputs}
    raise ValueError(
        "Expected the joint log density is a scalar, but got {}. "
        "There seems to be something wrong at the following sites: {}.".
        format(result_shape, leftover_sites))
Пример #3
0
def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
    """
    Trace ``model`` with ``params`` substituted and reduce the log-probability
    factors of its sample sites with ``sum_product`` under ``(sum_op, prod_op)``.

    :param model: callable to trace.
    :param model_args: positional arguments passed to the model.
    :param model_kwargs: keyword arguments passed to the model.
    :param params: values substituted into the model before tracing.
    :param sum_op: first operator handed to ``compute_markov_factors`` and
        ``funsor.sum_product.sum_product``.
    :param prod_op: second operator handed to the same two calls.
    :return: tuple ``(result, model_trace, log_measures)`` where ``result`` is
        the optimized funsor, and ``log_measures`` maps each non-observed
        sample site name to its log-probability funsor.
    :raises ValueError: if the reduced result still has inputs.
    """
    substituted = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(substituted).get_trace(
            *model_args, **model_kwargs
        )

    log_factors = []  # factors of sites without a "_time" dimension
    time_to_factors = defaultdict(list)  # log prob factors per time variable
    time_to_init_vars = defaultdict(frozenset)  # _PREV_* variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars = frozenset()
    prod_vars = frozenset()
    history = 1
    log_measures = {}

    for site in model_trace.values():
        if site["type"] != "sample":
            continue

        value = site["value"]
        intermediates = site["intermediates"]
        scale = site["scale"]

        log_prob_args = (value, intermediates) if intermediates else (value,)
        log_prob = site["fn"].log_prob(*log_prob_args)
        if not (scale is None or is_identically_one(scale)):
            log_prob = scale * log_prob

        dim_to_name = site["infer"]["dim_to_name"]
        factor = funsor.to_funsor(
            log_prob, output=funsor.Real, dim_to_name=dim_to_name
        )

        # Only the first "_time*" dimension found is used; sites without one
        # fall through to the plain factor list.
        for dim, dim_name in dim_to_name.items():
            if not dim_name.startswith("_time"):
                continue
            time_var = funsor.Variable(dim_name, funsor.Bint[log_prob.shape[dim]])
            time_to_factors[time_var].append(factor)
            largest_shift = max(_get_shift(other) for other in dim_to_name.values())
            history = max(history, largest_shift)
            prev_names = frozenset(
                other for other in dim_to_name.values()
                if other.startswith("_PREV_")
            )
            time_to_init_vars[time_var] = time_to_init_vars[time_var] | prev_names
            break
        else:
            log_factors.append(factor)

        if not site["is_observed"]:
            log_measures[site["name"]] = factor
            sum_vars = sum_vars.union((site["name"],))

        prod_vars = prod_vars.union(
            frame.name for frame in site["cond_indep_stack"]
            if frame.dim is not None
        )

    for time_var, init_vars in time_to_init_vars.items():
        for prev_var in init_vars:
            # Undo the shift encoded in the name to find the current site.
            current_name = _shift_name(prev_var, -_get_shift(prev_var))
            current_dims = model_trace[current_name]["infer"]["dim_to_name"]
            # i.e. _PREV_* (i.e. prev) appears among the current site's dims
            if prev_var in current_dims.values():
                time_to_markov_dims[time_var] = (
                    time_to_markov_dims[time_var]
                    | frozenset(current_dims.values())
                )

    if time_to_factors:
        log_factors = log_factors + compute_markov_factors(
            time_to_factors,
            time_to_init_vars,
            time_to_markov_dims,
            sum_vars,
            prod_vars,
            history,
            sum_op,
            prod_op,
        )

    with funsor.interpretations.lazy:
        lazy_result = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            log_factors,
            eliminate=sum_vars | prod_vars,
            plates=prod_vars,
        )
    result = funsor.optimizer.apply_optimizer(lazy_result)

    if len(result.inputs) == 0:
        return result, model_trace, log_measures

    result_shape = result.data.shape
    leftover_sites = {key.split("__BOUND")[0] for key in result.inputs}
    raise ValueError(
        "Expected the joint log density is a scalar, but got {}. "
        "There seems to be something wrong at the following sites: {}.".format(
            result_shape, leftover_sites
        )
    )