def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4):
    """Assert that two losses and their parameter gradients agree.

    Both losses are converted with ``funsor.to_data`` and compared with
    ``ops.allclose``. Each loss is then differentiated with respect to the
    unconstrained value of every parameter currently in the pyro param
    store, and the two gradient lists are compared entry by entry.

    :param expected_loss: Reference loss.
    :param actual_loss: Loss under test.
    :param float atol: Absolute tolerance passed to ``ops.allclose``.
    :param float rtol: Relative tolerance passed to ``ops.allclose``.
    :raises AssertionError: If the losses or any comparable pair of
        gradients differ beyond the tolerances.
    """
    # copied from pyro
    expected_loss = funsor.to_data(expected_loss)
    actual_loss = funsor.to_data(actual_loss)
    assert ops.allclose(actual_loss, expected_loss, atol=atol, rtol=rtol)

    params = [
        funsor.to_data(pyro.param(param_name)).unconstrained()
        for param_name in pyro.get_param_store().keys()
    ]

    # retain_graph=True so the same graph can be differentiated twice.
    actual_grads = grad(actual_loss, params, allow_unused=True, retain_graph=True)
    expected_grads = grad(expected_loss, params, allow_unused=True, retain_graph=True)

    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        # A missing gradient on either side cannot be compared; skip the pair.
        if actual_grad is None or expected_grad is None:
            continue
        assert ops.allclose(actual_grad, expected_grad, atol=atol, rtol=rtol)
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    """Compare ``TraceEnum_ELBO`` with an analytic KL for two crossing plates.

    The model and guide each have four Categorical sites: one outside any
    plate, one in the "outer" plate, one in the "inner" plate, and one in
    both. The expected loss is therefore the single-site KL multiplied by
    ``1 + outer_dim + inner_dim + outer_dim * inner_dim``. Both the loss and
    its gradient with respect to the param ``q`` are checked.
    """
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        # for which kl(Categorical(q), Categorical(p)) = 0.5
        p_head = 0.2693204236205713
        p = torch.tensor([p_head, 1 - p_head])

        def model():
            prior = dist.Categorical(p)
            outer_plate = pyro.plate("outer", outer_dim, dim=-1)
            inner_plate = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", prior)
            with outer_plate:
                pyro.sample("x", prior)
            with inner_plate:
                pyro.sample("y", prior)
            with outer_plate, inner_plate:
                pyro.sample("z", prior)

        def guide():
            posterior = dist.Categorical(pyro.param("q"))
            outer_plate = pyro.plate("outer", outer_dim, dim=-1)
            inner_plate = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", posterior, infer={"enumerate": "parallel"})
            with outer_plate:
                pyro.sample("x", posterior, infer={"enumerate": "parallel"})
            with inner_plate:
                pyro.sample("y", posterior, infer={"enumerate": "parallel"})
            with outer_plate, inner_plate:
                pyro.sample("z", posterior, infer={"enumerate": "parallel"})

        # Analytic reference: one KL term per (expanded) sample site.
        kl_node = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(q)),
            torch.distributions.Categorical(funsor.to_data(p)),
        )
        site_count = 1 + outer_dim + inner_dim + outer_dim * inner_dim
        expected_loss = site_count * kl_node
        expected_grad = grad(expected_loss, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(
            num_particles=1,
            vectorize_particles=True,
            strict_enumeration_warning=True,
        )
        # Only the "pyro" backend goes through `differentiable_loss`;
        # for other backends the ELBO object is called directly.
        loss_fn = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(loss_fn(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param("q")).grad

        assert_close(actual_loss, expected_loss, atol=1e-5)
        assert_close(actual_grad, expected_grad, atol=1e-5)
def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL):
    """
    A primitive to extract a python object from a :class:`~funsor.terms.Funsor`.

    The conversion is not performed directly: a ``'to_data'`` message is
    built and sent through ``apply_stack``, and the ``'value'`` entry of the
    returned message is the result.

    :param ~funsor.terms.Funsor x: A funsor object
    :param OrderedDict name_to_dim: An optional inputs hint which maps
        dimension names from `x` to dimension positions of the returned value.
        ``None`` is replaced by a fresh empty ``OrderedDict``.
    :param int dim_type: Either 0, 1, or 2. This optional argument indicates
        a dimension should be treated as 'local', 'global', or 'visible',
        which can be used to interact with the global :class:`DimStack`.
    :return: A non-funsor equivalent to `x`.
    """
    if name_to_dim is None:
        name_to_dim = OrderedDict()

    def _convert(x, name_to_dim, dim_type):
        # dim_type is part of the message kwargs but is not forwarded to
        # funsor.to_data.
        return funsor.to_data(x, name_to_dim=name_to_dim)

    message = {
        'type': 'to_data',
        'fn': _convert,
        'args': (x,),
        'kwargs': {"name_to_dim": name_to_dim, "dim_type": dim_type},
        'value': None,
        'mask': None,
    }
    return apply_stack(message)['value']
def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL):
    """Convert ``x`` with ``funsor.to_data``, defaulting dims from the global frame.

    When running under an effect handler (``am_i_wrapped()`` is true) and no
    non-empty ``name_to_dim`` was supplied, a copy of the global frame's
    ``name_to_dim`` mapping from ``_DIM_STACK`` is used instead.

    :param x: Object to convert.
    :param name_to_dim: Optional mapping from dimension names to positions.
        It must not contain ``DimRequest`` values.
    :param dim_type: Accepted for signature compatibility; not read by this
        function body.
    :return: The result of ``funsor.to_data(x, name_to_dim=name_to_dim)``.
    """
    import funsor

    if pyro.poutine.runtime.am_i_wrapped():
        if not name_to_dim:
            name_to_dim = _DIM_STACK.global_frame.name_to_dim.copy()

    # Every dimension must already be a concrete position, not a pending request.
    assert not (
        name_to_dim
        and any(isinstance(dim, DimRequest) for dim in name_to_dim.values())
    )
    return funsor.to_data(x, name_to_dim=name_to_dim)
def step(self, *args, **kwargs):
    """Run one optimization step and return the loss as a Python number.

    The loss is evaluated on ``self.model`` and ``self.guide``, then
    differentiated. ``self.optim`` is applied to the unconstrained values of
    all parameters encountered, and their gradients are reset afterwards.

    :param args: Positional arguments forwarded to ``self.loss`` after the
        model and guide.
    :param kwargs: Keyword arguments forwarded to ``self.loss``.
    :return: ``loss.item()``.
    """
    # This wraps both the call to `model` and `guide` in a `trace` so that
    # we can record all the parameters that are encountered. Note that
    # further tracing occurs inside of `loss`.
    with trace() as param_capture:
        # We use block here to allow tracing to record parameters only.
        with block(hide_fn=lambda msg: msg["type"] != "param"):
            loss = self.loss(self.model, self.guide, *args, **kwargs)
    # Differentiate the loss.
    funsor.to_data(loss).backward()
    # Grab all the parameters from the trace.
    params = [site["value"].data.unconstrained() for site in param_capture.values()]
    # Take a step w.r.t. each parameter in params.
    self.optim(params)
    # Zero out the gradients so that they don't accumulate.
    for p in params:
        # A captured parameter that did not contribute to the loss has no
        # gradient; torch.zeros_like(None) would raise, so leave it alone.
        if p.grad is not None:
            p.grad = torch.zeros_like(p.grad)
    return loss.item()
def compiled(*params_and_args):
    """Replay ``self.fn`` against the recorded parameter trace and return raw data.

    This is a closure: ``self`` comes from the enclosing scope. The leading
    ``len(self._param_trace)`` positional arguments are the unconstrained
    parameters, in the iteration order of ``self._param_trace``; the
    remaining ones are forwarded to ``self.fn``.

    :return: ``funsor.to_data`` of the replayed result, which must have no
        inputs and an output equal to ``funsor.reals()``.
    """
    num_params = len(self._param_trace)
    unconstrained_params = params_and_args[:num_params]
    fn_args = params_and_args[num_params:]

    for name, unconstrained_param in zip(self._param_trace, unconstrained_params):
        constrained_param = param(name)  # assume param has been initialized
        # The caller must pass the very same unconstrained tensor object.
        assert constrained_param.data.unconstrained() is unconstrained_param
        self._param_trace[name]["value"] = constrained_param

    replayed_fn = replay(self.fn, guide_trace=self._param_trace)
    result = replayed_fn(*fn_args)
    assert not result.inputs
    assert result.output == funsor.reals()
    return funsor.to_data(result)
def test_to_data_error():
    """``funsor.to_data`` raises ``ValueError`` for an ``Array`` with a named input."""
    named_inputs = OrderedDict(i=bint(3))
    funsor_array = Array(np.zeros((3, 3)), named_inputs)
    # No name_to_dim hint is supplied for input "i".
    with pytest.raises(ValueError):
        funsor.to_data(funsor_array)
def test_to_data():
    """``funsor.to_data`` on an input-free ``Array`` returns the wrapped object itself."""
    raw = np.zeros((3, 3))
    wrapped = Array(raw)
    unwrapped = funsor.to_data(wrapped)
    # Identity, not just equality: the same array object must come back.
    assert unwrapped is raw
def test_to_data():
    """``funsor.to_data`` on an input-free ``Tensor`` returns the wrapped object itself."""
    raw = zeros((3, 3))
    wrapped = Tensor(raw)
    unwrapped = funsor.to_data(wrapped)
    # Identity, not just equality: the same data object must come back.
    assert unwrapped is raw
def test_to_data_error():
    """``funsor.to_data`` raises ``ValueError`` for a ``Tensor`` with a named input."""
    named_inputs = OrderedDict(i=Bint[3])
    funsor_tensor = Tensor(zeros((3, 3)), named_inputs)
    # No name_to_dim hint is supplied for input "i".
    with pytest.raises(ValueError):
        funsor.to_data(funsor_tensor)
def test_to_data_error():
    """``funsor.to_data`` raises ``ValueError`` for a torch ``Tensor`` with a named input."""
    named_inputs = OrderedDict(i=bint(3))
    funsor_tensor = Tensor(torch.zeros(3, 3), named_inputs)
    # No name_to_dim hint is supplied for input "i".
    with pytest.raises(ValueError):
        funsor.to_data(funsor_tensor)
def test_to_data():
    """``funsor.to_data`` on an input-free torch ``Tensor`` returns the wrapped tensor itself."""
    raw = torch.zeros(3, 3)
    wrapped = Tensor(raw)
    unwrapped = funsor.to_data(wrapped)
    # Identity, not just equality: the same tensor object must come back.
    assert unwrapped is raw
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs):
    """Fill in enumerated sites of ``model`` and re-run it against those values.

    A packed trace of ``model`` is collected under enumeration, its log
    factors and log measures are combined with ``sum_product``, and the
    adjoint of that computation (under the chosen approximation) yields a
    value for each latent sample site. The model is finally replayed
    against the resulting trace.

    :param model: Callable model.
    :param first_available_dim: Passed to ``enum``. If ``None`` it is derived
        from a seeded trace of the model via ``_guess_max_plate_nesting``.
    :param temperature: ``0`` uses max / ``argmax_approximate``; ``1`` uses
        logaddexp / ``MonteCarlo``.
    :param rng_key: Random key; split when ``temperature == 1``.
    :param args: Positional arguments for ``model``.
    :param kwargs: Keyword arguments for ``model``.
    :return: The return value of ``model`` under replay.
    :raises ValueError: If ``temperature`` is neither 0 nor 1.
    """
    if temperature == 0:
        sum_op = funsor.ops.max
        prod_op = funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op = funsor.ops.logaddexp
        prod_op = funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            seeded_model = seed(model, rng_key)
            model_trace = trace(seeded_model).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]
    factors = list(terms["log_factors"].values()) + list(terms["log_measures"].values())
    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            factors,
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if not node["is_observed"]:
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"]
            )
            continue
        # "observed" values may be collapsed samples that depend on enumerated
        # values, so we have to slice them down
        # TODO this should really be handled entirely under the hood by adjoint
        output = funsor.Reals[node["fn"].event_shape]
        value = funsor.to_funsor(
            node["value"], output, dim_to_name=node["infer"]["dim_to_name"]
        )
        # Substitute the values chosen so far for earlier latent sites.
        value = value(**sample_subs)
        node["value"] = funsor.to_data(
            value, name_to_dim=node["infer"]["name_to_dim"]
        )

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs):
    """Compute values for enumerated sites of ``model`` and return them by site name.

    The enumerated log density is recorded on an ``AdjointTape``; its adjoint
    under the chosen approximation yields a value for each latent sample
    site. Sites whose names start with ``"_PREV_"`` are merged back into
    their root site, reshaped according to a prototype trace of the model.

    :param model: Callable model.
    :param first_available_dim: Passed to ``enum``. If ``None`` it is derived
        from a seeded trace of the model via ``_guess_max_plate_nesting``.
    :param temperature: ``0`` uses max / ``argmax_approximate``; ``1`` uses
        logaddexp / ``MonteCarlo``.
    :param rng_key: Random key; split when ``temperature == 1``.
    :param args: Positional arguments for ``model``.
    :param kwargs: Keyword arguments for ``model``.
    :return: Dict mapping sample-site names to values.
    :raises ValueError: If ``temperature`` is neither 0 nor 1.
    """
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    def _prototype_trace():
        # Seeded, blocked run of the model, used for plate nesting and shapes.
        with block():
            prototype = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
        return prototype

    # Only computed eagerly when needed to infer first_available_dim; it is
    # otherwise built lazily below, so that a caller-supplied
    # first_available_dim no longer leaves the name unbound.
    model_trace = None
    if first_available_dim is None:
        model_trace = _prototype_trace()
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with funsor.adjoint.AdjointTape() as tape:
        with block(), enum(first_available_dim=first_available_dim):
            log_prob, model_tr, log_measures = _enum_log_density(
                model, args, kwargs, {}, sum_op, prod_op
            )

    with approx:
        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(
                node["value"], output, dim_to_name=node["infer"]["dim_to_name"]
            )
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"]
            )
        else:
            log_measure = approx_factors[log_measures[name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"]
            )

    data = {
        name: site["value"]
        for name, site in sample_tr.items()
        if site["type"] == "sample"
    }

    # concatenate _PREV_foo to foo
    time_vars = defaultdict(list)
    for name in data:
        if name.startswith("_PREV_"):
            root_name = _shift_name(name, -_get_shift(name))
            time_vars[root_name].append(name)
    for name in time_vars:
        if name in data:
            time_vars[name].append(name)
        # Longest names first; the root name (no prefix) ends up last.
        time_vars[name] = sorted(time_vars[name], key=len, reverse=True)

    if time_vars and model_trace is None:
        model_trace = _prototype_trace()

    for root_name, site_names in time_vars.items():
        prototype_shape = model_trace[root_name]["value"].shape
        values = [data.pop(name) for name in site_names]
        if len(values) == 1:
            data[root_name] = values[0].reshape(prototype_shape)
        else:
            assert len(prototype_shape) >= 1
            values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
            data[root_name] = jnp.concatenate(values)

    return data