Example #1
def test_to_funsor(shape, dtype):
    t = ops.astype(randn(shape), dtype)
    f = funsor.to_funsor(t)
    assert isinstance(f, Tensor)
    assert funsor.to_funsor(t, reals(*shape)) is f
    with pytest.raises(ValueError):
        funsor.to_funsor(t, reals(5, *shape))
Example #2
File: handlers.py Project: xidulu/numpyro
    def __exit__(self, *args, **kwargs):
        import funsor

        _coerce = COERCIONS.pop()
        assert _coerce is self._coerce
        super().__exit__(*args, **kwargs)

        # Convert delayed statements to pyro.factor()
        reduced_vars = []
        log_prob_terms = []
        plates = frozenset()
        for name, site in self.trace.items():
            if not site["is_observed"]:
                reduced_vars.append(name)
            dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
            fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name)
            value = site["value"]
            if not isinstance(value, str):
                value = funsor.to_funsor(site["value"], fn.inputs["value"],
                                         dim_to_name)
            log_prob_terms.append(fn(value=value))
            plates |= frozenset(f.name for f in site["cond_indep_stack"])
        assert log_prob_terms, "nothing to collapse"
        reduced_plates = plates - self.preserved_plates
        log_prob = funsor.sum_product.sum_product(
            funsor.ops.logaddexp,
            funsor.ops.add,
            log_prob_terms,
            eliminate=frozenset(reduced_vars) | reduced_plates,
            plates=plates,
        )
        name = reduced_vars[0]
        numpyro.factor(name, log_prob.data)
Example #3
def test_to_funsor(shape, dtype):
    t = np.random.normal(size=shape).astype(dtype)
    f = funsor.to_funsor(t)
    assert isinstance(f, Array)
    assert funsor.to_funsor(t, reals(*shape)) is f
    with pytest.raises(ValueError):
        funsor.to_funsor(t, reals(5, *shape))
Example #4
def test_tensor_to_funsor_ambiguous_output():
    x = randn((2, 1))
    f = funsor.to_funsor(x, output=None, dim_to_name=OrderedDict({-2: 'a'}))
    f2 = funsor.to_funsor(x,
                          output=reals(),
                          dim_to_name=OrderedDict({-2: 'a'}))
    assert f.inputs == f2.inputs == OrderedDict(a=bint(2))
    assert f.output.shape == () == f2.output.shape
Example #5
    def _pyro_sample(self, msg):
        # Eagerly convert fn and value to Funsor.
        dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]}
        dim_to_name.update(self.preserved_plates)
        msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name)
        domain = msg["fn"].inputs["value"]
        if msg["value"] is None:
            msg["value"] = funsor.Variable(msg["name"], domain)
        else:
            msg["value"] = funsor.to_funsor(msg["value"], domain, dim_to_name)

        msg["done"] = True
        msg["stop"] = True
Example #6
def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL):
    import funsor
    if pyro.poutine.runtime.am_i_wrapped() and not dim_to_name:
        dim_to_name = _DIM_STACK.global_frame.dim_to_name.copy()
    assert not dim_to_name or not any(
        isinstance(name, DimRequest) for name in dim_to_name.values())
    return funsor.to_funsor(x, output=output, dim_to_name=dim_to_name)
Example #7
    def model(data):
        log_prob = funsor.to_funsor(0.)
        xs_curr = [funsor.Tensor(torch.tensor(0.)) for var in var_names]

        for t, y in enumerate(data):
            xs_prev = xs_curr

            # A delayed sample statement.
            xs_curr = [
                funsor.Variable(name + '_{}'.format(t), funsor.reals())
                for name in var_names
            ]

            for i, x_curr in enumerate(xs_curr):
                log_prob += dist.Normal(trans_eqs[var_names[i]](xs_prev),
                                        torch.exp(trans_noises[i]),
                                        value=x_curr)

            if t > 0:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([x_prev.name for x_prev in xs_prev]))

            # An observe statement.
            log_prob += dist.Normal(emit_eq(xs_curr),
                                    torch.exp(emit_noise),
                                    value=y)

        # Marginalize out all remaining delayed variables.
        return log_prob.reduce(ops.logaddexp), log_prob.gaussian
Example #8
    def model(data):
        log_prob = funsor.to_funsor(0.)

        trans = dist.Categorical(probs=funsor.Tensor(
            trans_probs,
            inputs=OrderedDict([('prev', funsor.bint(args.hidden_dim))]),
        ))

        emit = dist.Categorical(probs=funsor.Tensor(
            emit_probs,
            inputs=OrderedDict([('latent', funsor.bint(args.hidden_dim))]),
        ))

        x_curr = funsor.Number(0, args.hidden_dim)
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t),
                                     funsor.bint(args.hidden_dim))
            log_prob += trans(prev=x_prev, value=x_curr)

            if not args.lazy and isinstance(x_prev, funsor.Variable):
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))

        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
Example #9
    def _get_log_prob(self):
        # Convert delayed statements to pyro.factor()
        reduced_vars = []
        log_prob_terms = []
        plates = frozenset()
        for name, site in self.trace.nodes.items():
            if not site["is_observed"]:
                reduced_vars.append(name)
            dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
            fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name)
            value = site["value"]
            if not isinstance(value, str):
                value = funsor.to_funsor(site["value"], fn.inputs["value"],
                                         dim_to_name)
            log_prob_terms.append(fn(value=value))
            plates |= frozenset(f.name for f in site["cond_indep_stack"]
                                if f.vectorized)
        name = reduced_vars[0]
        reduced_vars = frozenset(reduced_vars)
        assert log_prob_terms, "nothing to collapse"
        self.trace.nodes.clear()
        reduced_plates = plates - self.preserved_plates
        if reduced_plates:
            log_prob = funsor.sum_product.sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                log_prob_terms,
                eliminate=reduced_vars | reduced_plates,
                plates=plates,
            )
            log_joint = NotImplemented
        else:
            log_joint = reduce(funsor.ops.add, log_prob_terms)
            log_prob = log_joint.reduce(funsor.ops.logaddexp, reduced_vars)

        return name, log_prob, log_joint, reduced_vars
Example #10
    def model(data):
        log_prob = funsor.to_funsor(0.)

        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
            log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

            # Optionally marginalize out the previous state.
            if t > 0 and not args.lazy:
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # An observe statement.
            log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

        # Marginalize out all remaining delayed variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
Example #11
def terms_from_trace(tr):
    """Helper function to extract elbo components from execution traces."""
    log_factors = {}
    log_measures = {}
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in tr.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob_factor = funsor.to_funsor(log_prob,
                                               output=funsor.Real,
                                               dim_to_name=dim_to_name)

            if site["is_observed"]:
                log_factors[site["name"]] = log_prob_factor
            else:
                log_measures[site["name"]] = log_prob_factor
                sum_vars |= frozenset({site["name"]})
            prod_vars |= frozenset(f.name for f in site["cond_indep_stack"]
                                   if f.dim is not None)

    return {
        "log_factors": log_factors,
        "log_measures": log_measures,
        "measure_vars": sum_vars,
        "plate_vars": prod_vars,
    }
Example #12
def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL):
    """
    A primitive to convert a Python object to a :class:`~funsor.terms.Funsor`.

    :param x: An object.
    :param funsor.domains.Domain output: An optional output hint used to
        uniquely convert the data to a Funsor (e.g. when `x` is a string).
    :param OrderedDict dim_to_name: An optional mapping from negative
        batch dimensions to name strings.
    :param int dim_type: Either 0, 1, or 2. This optional argument indicates
        whether a dimension should be treated as 'local', 'global', or 'visible',
        which controls how it interacts with the global :class:`DimStack`.
    :return: A Funsor equivalent to `x`.
    :rtype: funsor.terms.Funsor
    """
    dim_to_name = OrderedDict() if dim_to_name is None else dim_to_name

    initial_msg = {
        'type': 'to_funsor',
        'fn': lambda x, output, dim_to_name, dim_type: funsor.to_funsor(
            x, output=output, dim_to_name=dim_to_name),
        'args': (x,),
        'kwargs': {
            "output": output,
            "dim_to_name": dim_to_name,
            "dim_type": dim_type,
        },
        'value': None,
        'mask': None,
    }

    msg = apply_stack(initial_msg)
    return msg['value']
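
The handler-based primitive above ultimately defers to funsor.to_funsor itself, where dim_to_name maps negative batch dimensions to named inputs and output fixes the event shape. Below is a minimal sketch of that underlying call, not taken from any of the projects above; it assumes a recent funsor release with the NumPy backend, and the input names "i" and "j" are arbitrary placeholders.

from collections import OrderedDict

import numpy as np
import funsor

funsor.set_backend("numpy")  # funsor also ships "torch" and "jax" backends

x = np.random.randn(3, 4)
# Name the second-to-last batch dim "i" and the last dim "j"; output=Real makes the event shape scalar.
f = funsor.to_funsor(x, output=funsor.Real,
                     dim_to_name=OrderedDict([(-2, "i"), (-1, "j")]))
assert f.inputs["i"] == funsor.Bint[3]
assert f.inputs["j"] == funsor.Bint[4]
assert f.output == funsor.Real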
Example #13
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args,
                      **kwargs):

    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with funsor.adjoint.AdjointTape() as tape:
        with block(), enum(first_available_dim=first_available_dim):
            log_prob, model_tr, log_measures = _enum_log_density(
                model, args, kwargs, {}, sum_op, prod_op)

    with approx:
        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"],
                                     output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[log_measures[name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    data = {
        name: site["value"]
        for name, site in sample_tr.items() if site["type"] == "sample"
    }

    # concatenate _PREV_foo to foo
    time_vars = defaultdict(list)
    for name in data:
        if name.startswith("_PREV_"):
            root_name = _shift_name(name, -_get_shift(name))
            time_vars[root_name].append(name)
    for name in time_vars:
        if name in data:
            time_vars[name].append(name)
        time_vars[name] = sorted(time_vars[name], key=len, reverse=True)

    for root_name, vars in time_vars.items():
        prototype_shape = model_trace[root_name]["value"].shape
        values = [data.pop(name) for name in vars]
        if len(values) == 1:
            data[root_name] = values[0].reshape(prototype_shape)
        else:
            assert len(prototype_shape) >= 1
            values = [v.reshape((-1, ) + prototype_shape[1:]) for v in values]
            data[root_name] = jnp.concatenate(values)

    return data
Example #14
def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
    """Helper function to compute elbo and extract its components from execution traces."""
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _PREV_... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    history = 1
    log_measures = {}
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob_factor = funsor.to_funsor(
                log_prob, output=funsor.Real, dim_to_name=dim_to_name
            )

            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
                    time_to_factors[time_dim].append(log_prob_factor)
                    history = max(
                        history, max(_get_shift(s) for s in dim_to_name.values())
                    )
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values() if s.startswith("_PREV_")
                    )
                    break
            if time_dim is None:
                log_factors.append(log_prob_factor)

            if not site["is_observed"]:
                log_measures[site["name"]] = log_prob_factor
                sum_vars |= frozenset({site["name"]})

            prod_vars |= frozenset(
                f.name for f in site["cond_indep_stack"] if f.dim is not None
            )

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = _shift_name(var, -_get_shift(var))
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _PREV_* (i.e. prev) in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(
                    name for name in dim_to_name.values()
                )

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(
            time_to_factors,
            time_to_init_vars,
            time_to_markov_dims,
            sum_vars,
            prod_vars,
            history,
            sum_op,
            prod_op,
        )
        log_factors = log_factors + markov_factors

    with funsor.interpretations.lazy:
        lazy_result = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            log_factors,
            eliminate=sum_vars | prod_vars,
            plates=prod_vars,
        )
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density is a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
            )
        )
    return result, model_trace, log_measures
Example #15
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args,
                      **kwargs):

    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            list(terms["log_factors"].values()) +
            list(terms["log_measures"].values()),
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"],
                                     output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
Example #16
def test_empty_tensor_possible():
    funsor.to_funsor(randn(3, 0),
                     dim_to_name=OrderedDict([(-1, "a"), (-2, "b")]))
Example #17
def log_density(model, model_args, model_kwargs, params):
    """
    Similar to :func:`numpyro.infer.util.log_density` but works for models
    with discrete latent variables. Internally, this uses :mod:`funsor`
    to marginalize discrete latent sites and evaluate the joint log probability.

    :param model: Python callable containing NumPyro primitives. Typically,
        the model has been enumerated by using
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args,
                                                    **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _init/... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob = funsor.to_funsor(log_prob,
                                        output=funsor.reals(),
                                        dim_to_name=dim_to_name)

            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(
                        name, funsor.domains.bint(site["value"].shape[dim]))
                    time_to_factors[time_dim].append(log_prob)
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values()
                        if s.startswith("_init"))
                    break
            if time_dim is None:
                log_factors.append(log_prob)

            if not site['is_observed']:
                sum_vars |= frozenset({site['name']})
            prod_vars |= frozenset(f.name for f in site['cond_indep_stack']
                                   if f.dim is not None)

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = "/".join(var.split("/")[1:])
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _init (i.e. prev) in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(
                    name for name in dim_to_name.values())

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(time_to_factors,
                                                time_to_init_vars,
                                                time_to_markov_dims, sum_vars,
                                                prod_vars)
        log_factors = log_factors + markov_factors

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        lazy_result = funsor.sum_product.sum_product(
            funsor.ops.logaddexp,
            funsor.ops.add,
            log_factors,
            eliminate=sum_vars | prod_vars,
            plates=prod_vars,
        )
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density is a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
            )
        )
    return result.data, model_trace