Exemple #1
0
def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4):
    """Assert that two losses and their parameter gradients agree.

    Adapted from the equivalent helper in pyro's test-suite.

    :param expected_loss: reference loss; converted with ``funsor.to_data``.
    :param actual_loss: loss under test; converted with ``funsor.to_data``.
    :param float atol: absolute tolerance forwarded to ``ops.allclose``.
    :param float rtol: relative tolerance forwarded to ``ops.allclose``.
    :raises AssertionError: if the losses, or any pair of gradients that are
        both available, differ beyond the tolerances.
    """
    expected_loss = funsor.to_data(expected_loss)
    actual_loss = funsor.to_data(actual_loss)
    assert ops.allclose(actual_loss, expected_loss, atol=atol, rtol=rtol)

    # Differentiate both losses w.r.t. every unconstrained param in the store.
    params = [
        funsor.to_data(pyro.param(name)).unconstrained()
        for name in pyro.get_param_store().keys()
    ]
    actual_grads = grad(actual_loss, params, allow_unused=True, retain_graph=True)
    expected_grads = grad(expected_loss, params, allow_unused=True, retain_graph=True)

    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        # allow_unused=True yields None for params a loss does not depend on;
        # such pairs are not compared.
        if actual_grad is not None and expected_grad is not None:
            assert ops.allclose(actual_grad, expected_grad, atol=atol, rtol=rtol)
Exemple #2
0
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    """Compare TraceEnum_ELBO against an analytic KL on a two-plate model.

    The model and guide each sample four Categorical sites: one outside any
    plate, one in the "outer" plate, one in the "inner" plate and one in both.
    The expected loss is the per-site KL multiplied by the total number of
    site instances; the expected gradient is its derivative w.r.t. ``q``.
    """
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        # value for which kl(Categorical(q), Categorical(p)) = 0.5
        p_first = 0.2693204236205713
        p = torch.tensor([p_first, 1 - p_first])

        def model():
            prior = dist.Categorical(p)
            outer_plate = pyro.plate("outer", outer_dim, dim=-1)
            inner_plate = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", prior)
            with outer_plate:
                pyro.sample("x", prior)
            with inner_plate:
                pyro.sample("y", prior)
            with outer_plate, inner_plate:
                pyro.sample("z", prior)

        def guide():
            posterior = dist.Categorical(pyro.param("q"))
            outer_plate = pyro.plate("outer", outer_dim, dim=-1)
            inner_plate = pyro.plate("inner", inner_dim, dim=-2)

            def sample_enumerated(site_name):
                # A fresh infer dict is built for every site.
                return pyro.sample(site_name, posterior,
                                   infer={"enumerate": "parallel"})

            sample_enumerated("w")
            with outer_plate:
                sample_enumerated("x")
            with inner_plate:
                sample_enumerated("y")
            with outer_plate, inner_plate:
                sample_enumerated("z")

        # Analytic reference: one KL term per site instance.
        q_dist = torch.distributions.Categorical(funsor.to_data(q))
        p_dist = torch.distributions.Categorical(funsor.to_data(p))
        kl_node = kl_divergence(q_dist, p_dist)
        num_site_instances = 1 + outer_dim + inner_dim + outer_dim * inner_dim
        expected_loss = num_site_instances * kl_node
        expected_grad = grad(expected_loss, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(num_particles=1,
                                    vectorize_particles=True,
                                    strict_enumeration_warning=True)
        # The pyro backend exposes the differentiable loss as a method.
        loss_fn = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(loss_fn(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param('q')).grad

        assert_close(actual_loss, expected_loss, atol=1e-5)
        assert_close(actual_grad, expected_grad, atol=1e-5)
Exemple #3
0
def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL):
    """
    Effectful primitive that converts a :class:`~funsor.terms.Funsor` into a
    plain python object by dispatching a ``to_data`` message through the
    handler stack.

    :param ~funsor.terms.Funsor x: The funsor to convert.
    :param OrderedDict name_to_dim: Optional hint mapping dimension names of
        `x` to dimension positions in the returned value. Defaults to an
        empty ``OrderedDict``.
    :param int dim_type: Either 0, 1, or 2, marking a dimension as 'local',
        'global', or 'visible' when interacting with the global
        :class:`DimStack`.
    :return: The ``'value'`` entry of the message after it has been
        processed by ``apply_stack``.
    """
    if name_to_dim is None:
        name_to_dim = OrderedDict()

    def _default_to_data(x, name_to_dim, dim_type):
        # dim_type is part of the message kwargs but is not forwarded to
        # funsor.to_data.
        return funsor.to_data(x, name_to_dim=name_to_dim)

    msg = apply_stack({
        'type': 'to_data',
        'fn': _default_to_data,
        'args': (x, ),
        'kwargs': {"name_to_dim": name_to_dim, "dim_type": dim_type},
        'value': None,
        'mask': None,
    })
    return msg['value']
def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL):
    """
    Convert the funsor `x` to data via ``funsor.to_data``.

    :param x: Object to convert.
    :param name_to_dim: Optional mapping from dimension names to dimension
        positions. When it is empty or ``None`` and
        ``pyro.poutine.runtime.am_i_wrapped()`` is true, a copy of the global
        frame's mapping from ``_DIM_STACK`` is used instead.
    :param dim_type: Accepted but not read in this function body.
    :return: The result of ``funsor.to_data(x, name_to_dim=name_to_dim)``.
    """
    import funsor

    wrapped = pyro.poutine.runtime.am_i_wrapped()
    if wrapped and not name_to_dim:
        name_to_dim = _DIM_STACK.global_frame.name_to_dim.copy()

    if name_to_dim:
        # Every dim must already be concrete, not a pending DimRequest.
        assert not any(
            isinstance(dim, DimRequest) for dim in name_to_dim.values())

    return funsor.to_data(x, name_to_dim=name_to_dim)
Exemple #5
0
 def step(self, *args, **kwargs):
     """Take one optimization step and return the loss as a python number.

     The loss is evaluated with ``self.loss(self.model, self.guide, ...)``
     while every ``param`` site is recorded, then differentiated, and the
     recorded unconstrained parameters are handed to ``self.optim``.

     :param args: positional arguments forwarded to ``self.loss``.
     :param kwargs: keyword arguments forwarded to ``self.loss``.
     :return: ``loss.item()``.
     """
     # This wraps both the call to `model` and `guide` in a `trace` so that
     # we can record all the parameters that are encountered. Note that
     # further tracing occurs inside of `loss`.
     with trace() as param_capture:
         # We use block here to allow tracing to record parameters only.
         with block(hide_fn=lambda msg: msg["type"] != "param"):
             loss = self.loss(self.model, self.guide, *args, **kwargs)
     # Differentiate the loss.
     funsor.to_data(loss).backward()
     # Grab all the parameters from the trace.
     params = [site["value"].data.unconstrained()
               for site in param_capture.values()]
     # Take a step w.r.t. each parameter in params.
     self.optim(params)
     # Zero out the gradients so that they don't accumulate.
     for p in params:
         # A parameter that was traced but did not contribute to the loss
         # has no gradient; torch.zeros_like(None) would raise TypeError.
         if p.grad is not None:
             p.grad = torch.zeros_like(p.grad)
     return loss.item()
Exemple #6
0
 def compiled(*params_and_args):
     """Replay ``self.fn`` against the stored param trace and return data.

     The leading ``len(self._param_trace)`` positional arguments are the
     unconstrained parameters (in trace order); the remainder are forwarded
     to ``self.fn``.
     """
     num_params = len(self._param_trace)
     leading_params = params_and_args[:num_params]
     fn_args = params_and_args[num_params:]

     for name, expected_unconstrained in zip(self._param_trace, leading_params):
         constrained_param = param(name)  # assume param has been initialized
         # The tensor passed in must be the very object backing the param.
         assert constrained_param.data.unconstrained() is expected_unconstrained
         self._param_trace[name]["value"] = constrained_param

     replayed_fn = replay(self.fn, guide_trace=self._param_trace)
     result = replayed_fn(*fn_args)
     # The result must have no free inputs and a real scalar output.
     assert not result.inputs
     assert result.output == funsor.reals()
     return funsor.to_data(result)
Exemple #7
0
def test_to_data_error():
    """``funsor.to_data`` raises ValueError for an Array built with an input."""
    backing = np.zeros((3, 3))
    named_array = Array(backing, OrderedDict(i=bint(3)))
    with pytest.raises(ValueError):
        funsor.to_data(named_array)
Exemple #8
0
def test_to_data():
    """``funsor.to_data`` on an input-free Array returns the wrapped object."""
    backing = np.zeros((3, 3))
    assert funsor.to_data(Array(backing)) is backing
Exemple #9
0
def test_to_data():
    """``funsor.to_data`` on an input-free Tensor returns the wrapped object."""
    backing = zeros((3, 3))
    assert funsor.to_data(Tensor(backing)) is backing
Exemple #10
0
def test_to_data_error():
    """``funsor.to_data`` raises ValueError for a Tensor built with an input."""
    backing = zeros((3, 3))
    named_tensor = Tensor(backing, OrderedDict(i=Bint[3]))
    with pytest.raises(ValueError):
        funsor.to_data(named_tensor)
Exemple #11
0
def test_to_data_error():
    """``funsor.to_data`` raises ValueError for a Tensor built with an input."""
    backing = torch.zeros(3, 3)
    named_tensor = Tensor(backing, OrderedDict(i=bint(3)))
    with pytest.raises(ValueError):
        funsor.to_data(named_tensor)
Exemple #12
0
def test_to_data():
    """``funsor.to_data`` on an input-free Tensor returns the wrapped object."""
    backing = torch.zeros(3, 3)
    assert funsor.to_data(Tensor(backing)) is backing
Exemple #13
0
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args,
                      **kwargs):
    """
    Pick values for the model's latent sample sites with funsor's
    sum-product / adjoint machinery, then replay the model against them.

    :param model: callable that is traced and finally replayed.
    :param first_available_dim: forwarded to ``enum``; when ``None`` it is
        derived from a preliminary trace via ``_guess_max_plate_nesting``.
    :param temperature: ``0`` selects ``funsor.ops.max`` with
        ``argmax_approximate``; ``1`` selects ``funsor.ops.logaddexp`` with a
        ``MonteCarlo`` interpretation.
    :param rng_key: key used to seed the preliminary trace and, when
        ``temperature == 1``, split to obtain the ``MonteCarlo`` key.
    :param args: positional arguments forwarded to ``model``.
    :param kwargs: keyword arguments forwarded to ``model``.
    :return: the return value of ``model(*args, **kwargs)`` executed under
        ``replay`` with the constructed trace.
    :raises ValueError: if ``temperature`` is neither 0 nor 1.
    """

    # Choose the semiring and the interpretation used for the adjoint pass.
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    # No enumeration dim supplied: run the model once (hidden from outer
    # handlers by block) and derive it from the trace.
    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    # Trace the model under enumeration, hidden from outer handlers.
    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    # Build the sum-product expression lazily, then run funsor's optimizer
    # over it.
    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            list(terms["log_factors"].values()) +
            list(terms["log_measures"].values()),
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    # Run the adjoint under the interpretation chosen above.
    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    # Values chosen so far, keyed by site name; substituted into later
    # observed values.
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"],
                                     output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            # Latent site: read its value off the approximated log-measure.
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
Exemple #14
0
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args,
                      **kwargs):
    """
    Pick values for the model's latent sample sites with funsor's adjoint
    tape and return them keyed by site name.

    :param model: callable that is traced under enumeration.
    :param first_available_dim: forwarded to ``enum``; when ``None`` it is
        derived from a preliminary trace via ``_guess_max_plate_nesting``.
    :param temperature: ``0`` selects ``funsor.ops.max`` with
        ``argmax_approximate``; ``1`` selects ``funsor.ops.logaddexp`` with a
        ``MonteCarlo`` interpretation.
    :param rng_key: key used to seed the prototype trace and, when
        ``temperature == 1``, split to obtain the ``MonteCarlo`` key.
    :param args: positional arguments forwarded to ``model``.
    :param kwargs: keyword arguments forwarded to ``model``.
    :return: dict mapping sample site names to values. Entries whose names
        start with ``"_PREV_"`` are merged into their root name and reshaped
        to match the prototype trace.
    :raises ValueError: if ``temperature`` is neither 0 nor 1.
    """

    # Choose the semiring and the interpretation used for the adjoint pass.
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    # The prototype trace is needed to guess the enumeration dim and, further
    # down, to recover the shapes of time-shifted sites. It is built lazily
    # so that callers passing first_available_dim only pay for it if needed.
    model_trace = None
    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    # Record the enumerated log-density computation on an adjoint tape.
    with funsor.adjoint.AdjointTape() as tape:
        with block(), enum(first_available_dim=first_available_dim):
            log_prob, model_tr, log_measures = _enum_log_density(
                model, args, kwargs, {}, sum_op, prod_op)

    with approx:
        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    # Values chosen so far, keyed by site name; substituted into later
    # observed values.
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"],
                                     output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[log_measures[name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    data = {
        name: site["value"]
        for name, site in sample_tr.items() if site["type"] == "sample"
    }

    # concatenate _PREV_foo to foo
    time_vars = defaultdict(list)
    for name in data:
        if name.startswith("_PREV_"):
            root_name = _shift_name(name, -_get_shift(name))
            time_vars[root_name].append(name)
    for name in time_vars:
        if name in data:
            time_vars[name].append(name)
        # Longest name first, so the root name comes last.
        time_vars[name] = sorted(time_vars[name], key=len, reverse=True)

    # Previously model_trace was only bound when first_available_dim was
    # None, so an explicit first_available_dim combined with _PREV_ sites
    # raised UnboundLocalError below. Build the prototype trace on demand.
    if time_vars and model_trace is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)

    for root_name, var_names in time_vars.items():
        prototype_shape = model_trace[root_name]["value"].shape
        values = [data.pop(name) for name in var_names]
        if len(values) == 1:
            data[root_name] = values[0].reshape(prototype_shape)
        else:
            assert len(prototype_shape) >= 1
            # Flatten the leading dim of each piece before concatenating.
            values = [v.reshape((-1, ) + prototype_shape[1:]) for v in values]
            data[root_name] = jnp.concatenate(values)

    return data