Example 1
    def __call__(self, name, fn, obs):
        assert obs is None, "LocScaleReparam does not support observe statements"
        centered = self.centered
        if is_identically_one(centered):
            return fn, obs
        event_shape = fn.event_shape
        fn, batch_shape, event_dim = self._unwrap(fn)

        # Apply a partial decentering transform.
        params = {key: getattr(fn, key) for key in self.shape_params}
        if self.centered is None:
            centered = numpyro.param("{}_centered".format(name),
                                     jnp.full(event_shape, 0.5),
                                     constraint=constraints.unit_interval)
        params["loc"] = fn.loc * centered
        params["scale"] = fn.scale**centered
        decentered_fn = self._wrap(type(fn)(**params), batch_shape, event_dim)

        # Draw decentered noise.
        decentered_value = numpyro.sample("{}_decentered".format(name),
                                          decentered_fn)

        # Differentiably transform.
        delta = decentered_value - centered * fn.loc
        value = fn.loc + jnp.power(fn.scale, 1 - centered) * delta

        # Simulate a pyro.deterministic() site.
        return None, value
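
For context, a minimal usage sketch (the toy model and site names are illustrative, not from the original): wrapping a model in numpyro.handlers.reparam with a LocScaleReparam config replaces the site "x" with an auxiliary "x_decentered" sample plus a deterministic "x" recovered by the transform above.

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam, seed, trace
from numpyro.infer.reparam import LocScaleReparam

def model():
    # Hierarchical site that often benefits from (de)centering.
    mu = numpyro.sample("mu", dist.Normal(0.0, 5.0))
    numpyro.sample("x", dist.Normal(mu, 2.0))

# centered=0 gives a fully decentered parameterization of site "x".
decentered_model = reparam(model, config={"x": LocScaleReparam(centered=0)})

with seed(rng_seed=0):
    tr = trace(decentered_model).get_trace()
# The trace should now contain "mu", an auxiliary "x_decentered" sample site,
# and a deterministic "x" equal to loc + scale**(1 - centered) * delta.
print(sorted(tr.keys()))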
Example 2
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.zeros(())
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace
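
A brief usage sketch (the model, data, and latent value below are illustrative): every latent site must appear in params, since substitute supplies its value instead of drawing it.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import log_density

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

data = jnp.array([0.5, -0.3, 1.2])
log_joint, model_trace = log_density(model, (data,), {}, {"loc": jnp.array(0.1)})
# log_joint = log N(0.1 | 0, 1) + sum_i log N(data_i | 0.1, 1)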
Example 3
def get_importance_trace(model, guide, args, kwargs, params):
    """
    (EXPERIMENTAL) Returns traces from the guide and the model that is run against it.
    The returned traces also store the log probability at each site.

    .. note:: Gradients are blocked at latent sites which do not have reparametrized samplers.
    """
    guide = substitute(guide, data=params)
    with _without_rsample_stop_gradient():
        guide_trace = trace(guide).get_trace(*args, **kwargs)
    model = substitute(replay(model, guide_trace), data=params)
    model_trace = trace(model).get_trace(*args, **kwargs)
    for tr in (guide_trace, model_trace):
        for site in tr.values():
            if site["type"] == "sample":
                if "log_prob" not in site:
                    value = site["value"]
                    intermediates = site["intermediates"]
                    scale = site["scale"]
                    if intermediates:
                        log_prob = site["fn"].log_prob(value, intermediates)
                    else:
                        log_prob = site["fn"].log_prob(value)

                    if (scale is not None) and (not is_identically_one(scale)):
                        log_prob = scale * log_prob
                    site["log_prob"] = log_prob
    return model_trace, guide_trace
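
A sketch of how the returned traces can be consumed (the model/guide pair and the single-sample ELBO estimate are illustrative, assuming get_importance_trace is in scope as defined above):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import seed

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    q_loc = numpyro.param("q_loc", jnp.array(0.0))
    numpyro.sample("loc", dist.Normal(q_loc, 1.0))

data = jnp.array([0.5, -0.3])
model_tr, guide_tr = get_importance_trace(
    seed(model, rng_seed=0), seed(guide, rng_seed=1), (data,), {}, {"q_loc": jnp.array(0.2)}
)
# Every sample site now carries "log_prob"; a single-sample ELBO estimate is
elbo = sum(
    jnp.sum(site["log_prob"]) for site in model_tr.values() if site["type"] == "sample"
) - sum(
    jnp.sum(site["log_prob"]) for site in guide_tr.values() if site["type"] == "sample"
)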
Example 4
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, param_map=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.array(0.)
    for site in model_trace.values():
        if site['type'] == 'sample' and not isinstance(site['fn'],
                                                       dist.PRNGIdentity):
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace
Example 5
def log_prob_sum(trace):
    log_joint = jnp.zeros(())
    for site in trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint
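
A small illustrative example (assuming log_prob_sum is in scope as defined above; the toy model is not from the original):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import seed, trace

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 1.0))
    numpyro.sample("y", dist.Normal(x, 1.0), obs=jnp.array(0.3))

tr = trace(seed(model, rng_seed=0)).get_trace()
total = log_prob_sum(tr)  # scalar: log p(x) + log p(y = 0.3 | x)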
Example 6
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.zeros(())
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                guide_shape = jnp.shape(value)
                # TensorShape from tfp needs casting to tuple
                model_shape = tuple(site["fn"].shape())
                try:
                    broadcast_shapes(guide_shape, model_shape)
                except ValueError:
                    raise ValueError(
                        "Model and guide shapes disagree at site: '{}': {} vs {}".format(
                            site["name"], model_shape, guide_shape))
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace
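
The extra broadcasting check in this variant turns a bare shape error into a site-specific message. A hypothetical trigger (names are illustrative; log_density refers to the variant above):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model():
    numpyro.sample("x", dist.Normal(jnp.zeros(3), 1.0))

# A value that cannot broadcast against the model's (3,)-shaped site raises:
# "Model and guide shapes disagree at site: 'x': (3,) vs (2,)"
log_density(model, (), {}, {"x": jnp.zeros(2)})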
Example 7
def terms_from_trace(tr):
    """Helper function to extract elbo components from execution traces."""
    log_factors = {}
    log_measures = {}
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in tr.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob_factor = funsor.to_funsor(log_prob,
                                               output=funsor.Real,
                                               dim_to_name=dim_to_name)

            if site["is_observed"]:
                log_factors[site["name"]] = log_prob_factor
            else:
                log_measures[site["name"]] = log_prob_factor
                sum_vars |= frozenset({site["name"]})
            prod_vars |= frozenset(f.name for f in site["cond_indep_stack"]
                                   if f.dim is not None)

    return {
        "log_factors": log_factors,
        "log_measures": log_measures,
        "measure_vars": sum_vars,
        "plate_vars": prod_vars,
    }
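
An illustrative call (the model is a toy; the import paths for the packed-trace utilities are assumptions based on how the snippets above use them):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import seed
from numpyro.contrib.funsor import plate_to_enum_plate  # assumed export
from numpyro.contrib.funsor.enum_messenger import trace as packed_trace  # assumed path

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", data.shape[0]):
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

data = jnp.array([0.5, -0.3, 1.2])
with plate_to_enum_plate():
    tr = packed_trace(seed(model, rng_seed=0)).get_trace(data)
terms = terms_from_trace(tr)
# The latent "loc" should land in log_measures/measure_vars, the observed "obs"
# in log_factors, and the plate name "data" in plate_vars.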
Example 8
def log_density(model, model_args, model_kwargs, params):
    """
    Similar to :func:`numpyro.infer.util.log_density` but works for models
    with discrete latent variables. Internally, this uses :mod:`funsor`
    to marginalize discrete latent sites and evaluate the joint log probability.

    :param model: Python callable containing NumPyro primitives. Typically,
        the model has been enumerated by using
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args,
                                                    **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _init/... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob = funsor.to_funsor(log_prob,
                                        output=funsor.reals(),
                                        dim_to_name=dim_to_name)

            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(
                        name, funsor.domains.bint(site["value"].shape[dim]))
                    time_to_factors[time_dim].append(log_prob)
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values()
                        if s.startswith("_init"))
                    break
            if time_dim is None:
                log_factors.append(log_prob)

            if not site['is_observed']:
                sum_vars |= frozenset({site['name']})
            prod_vars |= frozenset(f.name for f in site['cond_indep_stack']
                                   if f.dim is not None)

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = "/".join(var.split("/")[1:])
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _init (i.e. prev) in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(
                    name for name in dim_to_name.values())

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(time_to_factors,
                                                time_to_init_vars,
                                                time_to_markov_dims, sum_vars,
                                                prod_vars)
        log_factors = log_factors + markov_factors

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        lazy_result = funsor.sum_product.sum_product(funsor.ops.logaddexp,
                                                     funsor.ops.add,
                                                     log_factors,
                                                     eliminate=sum_vars
                                                     | prod_vars,
                                                     plates=prod_vars)
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density is a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".
            format(result.data.shape,
                   {k.split("__BOUND")[0]
                    for k in result.inputs}))
    return result.data, model_trace
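
A hypothetical end-to-end use of this funsor-based log_density, following the docstring's enum(config_enumerate(model)) pattern (the mixture model below is illustrative):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, enum
from numpyro.handlers import seed

def model(data):
    probs = jnp.array([0.4, 0.6])
    locs = jnp.array([-1.0, 1.0])
    z = numpyro.sample("z", dist.Categorical(probs))
    numpyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)

data = jnp.array(0.5)
log_joint, tr = log_density(
    enum(config_enumerate(seed(model, rng_seed=0))), (data,), {}, {}
)
# The discrete site "z" is enumerated and summed out, so
# log_joint should equal log(sum_z probs[z] * N(data | locs[z], 1)).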
Example 9
def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
    """Helper function to compute elbo and extract its components from execution traces."""
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _PREV_... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    history = 1
    log_measures = {}
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob_factor = funsor.to_funsor(
                log_prob, output=funsor.Real, dim_to_name=dim_to_name
            )

            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
                    time_to_factors[time_dim].append(log_prob_factor)
                    history = max(
                        history, max(_get_shift(s) for s in dim_to_name.values())
                    )
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values() if s.startswith("_PREV_")
                    )
                    break
            if time_dim is None:
                log_factors.append(log_prob_factor)

            if not site["is_observed"]:
                log_measures[site["name"]] = log_prob_factor
                sum_vars |= frozenset({site["name"]})

            prod_vars |= frozenset(
                f.name for f in site["cond_indep_stack"] if f.dim is not None
            )

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = _shift_name(var, -_get_shift(var))
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _PREV_* (i.e. prev) in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(
                    name for name in dim_to_name.values()
                )

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(
            time_to_factors,
            time_to_init_vars,
            time_to_markov_dims,
            sum_vars,
            prod_vars,
            history,
            sum_op,
            prod_op,
        )
        log_factors = log_factors + markov_factors

    with funsor.interpretations.lazy:
        lazy_result = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            log_factors,
            eliminate=sum_vars | prod_vars,
            plates=prod_vars,
        )
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density is a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
            )
        )
    return result, model_trace, log_measures
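
For reference, the enumeration-aware log_density shown earlier can be recovered as a thin wrapper over this helper; a sketch, assuming _enum_log_density is in scope as defined above:

import funsor

def log_density(model, model_args, model_kwargs, params):
    # Marginalize discrete latent variables with (logaddexp, add) and return the
    # scalar joint log density together with the model trace.
    result, model_trace, _ = _enum_log_density(
        model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
    )
    return result.data, model_trace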