def test_to_funsor(shape, dtype):
    """Round-trip a raw backend tensor through funsor.to_funsor.

    Checks that conversion yields a Tensor funsor, that converting again with
    the matching output domain returns the very same object, and that a
    mismatched domain is rejected with ValueError.
    """
    raw = ops.astype(randn(shape), dtype)
    converted = funsor.to_funsor(raw)
    assert isinstance(converted, Tensor)
    # Supplying the correct output domain must return the identical funsor.
    assert funsor.to_funsor(raw, reals(*shape)) is converted
    # An output domain whose shape disagrees with the data is an error.
    with pytest.raises(ValueError):
        funsor.to_funsor(raw, reals(5, *shape))
def __exit__(self, *args, **kwargs):
    """Exit the messenger and collapse the recorded trace into one factor.

    Pops and checks this messenger's coercion, exits the parent messenger,
    then converts every traced site to a funsor log-prob term and
    marginalizes out all unobserved variables (and any plates not in
    ``self.preserved_plates``) via ``funsor.sum_product.sum_product``,
    emitting the result as a single ``numpyro.factor``.
    """
    import funsor
    _coerce = COERCIONS.pop()
    assert _coerce is self._coerce
    super().__exit__(*args, **kwargs)

    # Convert delayed statements to pyro.factor()
    reduced_vars = []
    log_prob_terms = []
    plates = frozenset()
    for name, site in self.trace.items():
        if not site["is_observed"]:
            # Unobserved sites are the variables to marginalize out.
            reduced_vars.append(name)
        dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
        fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name)
        value = site["value"]
        # A string value is already a deferred funsor variable name;
        # anything else is raw data that must be converted.
        if not isinstance(value, str):
            value = funsor.to_funsor(site["value"], fn.inputs["value"], dim_to_name)
        log_prob_terms.append(fn(value=value))
        plates |= frozenset(f.name for f in site["cond_indep_stack"])
    assert log_prob_terms, "nothing to collapse"
    # Only plates NOT preserved by this messenger are eliminated.
    reduced_plates = plates - self.preserved_plates
    log_prob = funsor.sum_product.sum_product(
        funsor.ops.logaddexp,
        funsor.ops.add,
        log_prob_terms,
        eliminate=frozenset(reduced_vars) | reduced_plates,
        plates=plates,
    )
    # Reuse the first collapsed site's name for the emitted factor.
    name = reduced_vars[0]
    numpyro.factor(name, log_prob.data)
def test_to_funsor(shape, dtype):
    """Round-trip a raw numpy array through funsor.to_funsor.

    Conversion must yield an Array funsor; re-converting with the matching
    output domain must return the identical object; a wrong-shaped domain
    must raise ValueError.
    """
    raw = np.random.normal(size=shape).astype(dtype)
    converted = funsor.to_funsor(raw)
    assert isinstance(converted, Array)
    # The correct output domain returns the very same funsor object.
    assert funsor.to_funsor(raw, reals(*shape)) is converted
    # A domain with an extra leading dimension is rejected.
    with pytest.raises(ValueError):
        funsor.to_funsor(raw, reals(5, *shape))
def test_tensor_to_funsor_ambiguous_output():
    """output=None and output=reals() must agree for a (2, 1) tensor.

    With dim -2 named 'a', both conversions should produce a scalar-output
    funsor with a single bint(2) input.
    """
    data = randn((2, 1))
    implicit = funsor.to_funsor(data, output=None, dim_to_name=OrderedDict({-2: 'a'}))
    explicit = funsor.to_funsor(data, output=reals(), dim_to_name=OrderedDict({-2: 'a'}))
    assert implicit.inputs == explicit.inputs == OrderedDict(a=bint(2))
    assert implicit.output.shape == () == explicit.output.shape
def _pyro_sample(self, msg):
    """Eagerly convert a sample site's ``fn`` and ``value`` to funsors.

    Builds the dim-to-name map from the site's plate stack plus this
    messenger's preserved plates, converts the distribution, and either
    substitutes a free funsor Variable (delayed sample) or converts the
    provided value.  Marks the site done so downstream handlers skip it.
    """
    name_by_dim = {f.dim: f.name for f in msg["cond_indep_stack"]}
    name_by_dim.update(self.preserved_plates)
    msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, name_by_dim)
    domain = msg["fn"].inputs["value"]
    value = msg["value"]
    msg["value"] = (
        funsor.Variable(msg["name"], domain)
        if value is None
        else funsor.to_funsor(value, domain, name_by_dim)
    )
    msg["done"] = True
    msg["stop"] = True
def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL):
    """Convert ``x`` to a funsor, defaulting the dim-to-name mapping.

    When running inside a poutine stack and no explicit mapping is given,
    the global dim stack's current frame supplies ``dim_to_name``.  Any
    mapping used here must already contain concrete names, not DimRequests.
    """
    import funsor

    if pyro.poutine.runtime.am_i_wrapped() and not dim_to_name:
        dim_to_name = _DIM_STACK.global_frame.dim_to_name.copy()
    # All dim requests must have been resolved to plain names by this point.
    if dim_to_name:
        assert all(
            not isinstance(name, DimRequest) for name in dim_to_name.values()
        )
    return funsor.to_funsor(x, output=output, dim_to_name=dim_to_name)
def model(data):
    """Accumulate a lazy joint log-density for a multi-variable state-space
    model over ``data``, with Gaussian transitions per variable in
    ``var_names`` and a single Gaussian emission per time step.

    Returns the fully marginalized log probability together with its
    ``.gaussian`` component.
    NOTE(review): indentation reconstructed from a whitespace-mangled
    source; statement nesting inferred from variable usage.
    """
    log_prob = funsor.to_funsor(0.)
    # Initial state: one zero tensor per latent variable.
    xs_curr = [funsor.Tensor(torch.tensor(0.)) for var in var_names]
    for t, y in enumerate(data):
        xs_prev = xs_curr
        # A delayed sample statement.
        xs_curr = [
            funsor.Variable(name + '_{}'.format(t), funsor.reals())
            for name in var_names
        ]
        # One transition factor per latent variable; each may depend on
        # all previous-step variables via trans_eqs.
        for i, x_curr in enumerate(xs_curr):
            log_prob += dist.Normal(trans_eqs[var_names[i]](xs_prev),
                                    torch.exp(trans_noises[i]),
                                    value=x_curr)
        # Marginalize out the previous step once all its factors are in.
        if t > 0:
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([x_prev.name for x_prev in xs_prev]))
        # An observe statement.
        log_prob += dist.Normal(emit_eq(xs_curr), torch.exp(emit_noise),
                                value=y)
    # Marginalize out all remaining delayed variables.
    return log_prob.reduce(ops.logaddexp), log_prob.gaussian
def model(data):
    """Compute the joint log-probability of a discrete-state HMM over
    ``data`` using funsor, marginalizing latent states either eagerly
    (per step, unless ``args.lazy``) or all at once at the end.
    """
    log_prob = funsor.to_funsor(0.)
    # Transition kernel: categorical over hidden states given 'prev'.
    trans = dist.Categorical(probs=funsor.Tensor(
        trans_probs,
        inputs=OrderedDict([('prev', funsor.bint(args.hidden_dim))]),
    ))
    # Emission kernel: categorical over observations given 'latent'.
    emit = dist.Categorical(probs=funsor.Tensor(
        emit_probs,
        inputs=OrderedDict([('latent', funsor.bint(args.hidden_dim))]),
    ))
    # Initial state is the concrete number 0.
    x_curr = funsor.Number(0, args.hidden_dim)
    for t, y in enumerate(data):
        x_prev = x_curr
        # A delayed sample statement.
        x_curr = funsor.Variable('x_{}'.format(t), funsor.bint(args.hidden_dim))
        log_prob += trans(prev=x_prev, value=x_curr)
        # Eagerly marginalize the previous state unless running lazily.
        # (The initial Number is not a Variable, so t=0 is skipped.)
        if not args.lazy and isinstance(x_prev, funsor.Variable):
            log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)
        log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))
    # Marginalize any remaining delayed variables.
    log_prob = log_prob.reduce(ops.logaddexp)
    return log_prob
def _get_log_prob(self):
    """Collapse the recorded trace into a single marginal log-probability.

    Returns ``(name, log_prob, log_joint, reduced_vars)`` where ``name`` is
    the first unobserved site's name, ``log_prob`` marginalizes all
    unobserved variables, ``log_joint`` is the un-marginalized sum of
    factors (or ``NotImplemented`` when plates had to be eliminated via
    sum_product), and ``reduced_vars`` is the frozenset of eliminated
    variable names.  Side effect: clears ``self.trace.nodes``.
    """
    # Convert delayed statements to pyro.factor()
    reduced_vars = []
    log_prob_terms = []
    plates = frozenset()
    for name, site in self.trace.nodes.items():
        if not site["is_observed"]:
            reduced_vars.append(name)
        dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
        fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name)
        value = site["value"]
        # A string value names a deferred funsor variable; raw data
        # must be converted to a funsor first.
        if not isinstance(value, str):
            value = funsor.to_funsor(site["value"], fn.inputs["value"], dim_to_name)
        log_prob_terms.append(fn(value=value))
        # Only vectorized frames count as plates here.
        plates |= frozenset(
            f.name for f in site["cond_indep_stack"] if f.vectorized)
    name = reduced_vars[0]
    reduced_vars = frozenset(reduced_vars)
    assert log_prob_terms, "nothing to collapse"
    # The trace has been consumed; drop it.
    self.trace.nodes.clear()
    reduced_plates = plates - self.preserved_plates
    if reduced_plates:
        # Plates must be eliminated jointly with variables, so no
        # standalone joint factor is available in this branch.
        log_prob = funsor.sum_product.sum_product(
            funsor.ops.logaddexp,
            funsor.ops.add,
            log_prob_terms,
            eliminate=reduced_vars | reduced_plates,
            plates=plates,
        )
        log_joint = NotImplemented
    else:
        log_joint = reduce(funsor.ops.add, log_prob_terms)
        log_prob = log_joint.reduce(funsor.ops.logaddexp, reduced_vars)
    return name, log_prob, log_joint, reduced_vars
def model(data):
    """Joint log-density of a scalar linear-Gaussian state-space model.

    Transition: x_t ~ Normal(1 + x_{t-1}/2, trans_noise).
    Emission:   y_t ~ Normal(0.5 + 3*x_t, emit_noise).
    Previous states are marginalized eagerly unless ``args.lazy``.
    """
    log_prob = funsor.to_funsor(0.)
    # Initial state is a concrete zero tensor.
    x_curr = funsor.Tensor(torch.tensor(0.))
    for t, y in enumerate(data):
        x_prev = x_curr
        # A delayed sample statement.
        x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
        log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)
        # Optionally marginalize out the previous state.
        if t > 0 and not args.lazy:
            log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)
        # An observe statement.
        log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)
    # Marginalize out all remaining delayed variables.
    log_prob = log_prob.reduce(ops.logaddexp)
    return log_prob
def terms_from_trace(tr):
    """Helper function to extract elbo components from execution traces.

    Walks every sample site in ``tr``, converts its (scaled) log-prob to a
    funsor, and sorts it into observed factors vs. latent measures.
    Returns a dict with:

    - ``log_factors``: site name -> log-prob funsor for observed sites
    - ``log_measures``: site name -> log-prob funsor for latent sites
    - ``measure_vars``: frozenset of latent site names (to be summed out)
    - ``plate_vars``: frozenset of vectorized plate names (product dims)
    """
    log_factors = {}
    log_measures = {}
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in tr.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            # Some distributions expose intermediate values that log_prob
            # can reuse.
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)
            # Apply subsampling / poutine scale unless it is trivially 1.
            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob
            dim_to_name = site["infer"]["dim_to_name"]
            log_prob_factor = funsor.to_funsor(
                log_prob, output=funsor.Real, dim_to_name=dim_to_name)
            if site["is_observed"]:
                log_factors[site["name"]] = log_prob_factor
            else:
                log_measures[site["name"]] = log_prob_factor
                sum_vars |= frozenset({site["name"]})
            # Only plates with an actual batch dim contribute product vars.
            prod_vars |= frozenset(
                f.name for f in site["cond_indep_stack"] if f.dim is not None)
    return {
        "log_factors": log_factors,
        "log_measures": log_measures,
        "measure_vars": sum_vars,
        "plate_vars": prod_vars,
    }
def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL):
    """
    A primitive to convert a Python object to a :class:`~funsor.terms.Funsor`.

    :param x: An object.
    :param funsor.domains.Domain output: An optional output hint to uniquely
        convert a data to a Funsor (e.g. when `x` is a string).
    :param OrderedDict dim_to_name: An optional mapping from negative
        batch dimensions to name strings.
    :param int dim_type: Either 0, 1, or 2. This optional argument indicates
        a dimension should be treated as 'local', 'global', or 'visible',
        which can be used to interact with the global :class:`DimStack`.
    :return: A Funsor equivalent to `x`.
    :rtype: funsor.terms.Funsor
    """
    if dim_to_name is None:
        dim_to_name = OrderedDict()
    # Route the conversion through the effect-handler stack so that
    # messengers may intercept and rewrite the message.
    initial_msg = {
        'type': 'to_funsor',
        'fn': lambda x, output, dim_to_name, dim_type: funsor.to_funsor(
            x, output=output, dim_to_name=dim_to_name),
        'args': (x,),
        'kwargs': {
            "output": output,
            "dim_to_name": dim_to_name,
            "dim_type": dim_type,
        },
        'value': None,
        'mask': None,
    }
    return apply_stack(initial_msg)['value']
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs):
    """Draw posterior values for enumerated discrete latents.

    ``temperature == 0`` gives MAP (max/add with argmax approximation);
    ``temperature == 1`` gives a posterior sample (logaddexp/add with a
    Monte Carlo approximation).  Uses an AdjointTape over the enumerated
    log-density to recover per-site measures, then converts them back to
    data and re-assembles time-sliced ``_PREV_``-prefixed variables into
    their root sites.  Returns a dict of site name -> sampled value.
    """
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        # Run the model once (blocked from outer handlers) to discover
        # how deeply plates nest, then enumerate below that.
        with block():
            model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with funsor.adjoint.AdjointTape() as tape:
        with block(), enum(first_available_dim=first_available_dim):
            log_prob, model_tr, log_measures = _enum_log_density(
                model, args, kwargs, {}, sum_op, prod_op)

    with approx:
        # Backward pass: one adjoint factor per recorded measure.
        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on
            # enumerated values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(
                node["value"], output,
                dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[log_measures[name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    data = {
        name: site["value"]
        for name, site in sample_tr.items()
        if site["type"] == "sample"
    }

    # concatenate _PREV_foo to foo
    time_vars = defaultdict(list)
    for name in data:
        if name.startswith("_PREV_"):
            # Strip the _PREV_ prefixes to find the root site name.
            root_name = _shift_name(name, -_get_shift(name))
            time_vars[root_name].append(name)
    for name in time_vars:
        if name in data:
            time_vars[name].append(name)
        # Longer names (= more _PREV_ prefixes) are earlier in time.
        time_vars[name] = sorted(time_vars[name], key=len, reverse=True)

    for root_name, vars in time_vars.items():
        prototype_shape = model_trace[root_name]["value"].shape
        values = [data.pop(name) for name in vars]
        if len(values) == 1:
            data[root_name] = values[0].reshape(prototype_shape)
        else:
            # Concatenate the time slices along the leading (time) axis.
            assert len(prototype_shape) >= 1
            values = [v.reshape((-1, ) + prototype_shape[1:]) for v in values]
            data[root_name] = jnp.concatenate(values)

    return data
def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
    """Helper function to compute elbo and extract its components from execution traces.

    Runs the model under packed tracing, converts each sample site's
    (scaled) log-prob to a funsor, groups factors inside ``_time``-named
    markov dimensions for :func:`compute_markov_factors`, then lazily
    sum-products everything down to a scalar.

    :return: ``(result, model_trace, log_measures)`` — the scalar joint
        log-density funsor, the packed trace, and a map from latent site
        name to its log-prob funsor.
    :raises ValueError: if the eliminated result still has free inputs.
    """
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _PREV_... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    history = 1
    log_measures = {}
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)
            # Apply subsampling / handler scale unless it is trivially 1.
            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob_factor = funsor.to_funsor(
                log_prob, output=funsor.Real, dim_to_name=dim_to_name
            )

            # A "_time..." named dim marks a markov (scan) site; collect
            # its factors separately for compute_markov_factors.
            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
                    time_to_factors[time_dim].append(log_prob_factor)
                    # Markov history depth = max _PREV_ nesting seen.
                    history = max(
                        history, max(_get_shift(s) for s in dim_to_name.values())
                    )
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values() if s.startswith("_PREV_")
                    )
                    break
            if time_dim is None:
                log_factors.append(log_prob_factor)

            if not site["is_observed"]:
                log_measures[site["name"]] = log_prob_factor
                sum_vars |= frozenset({site["name"]})
            # Only plates with a real batch dim become product vars.
            prod_vars |= frozenset(
                f.name for f in site["cond_indep_stack"] if f.dim is not None
            )

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            # Map _PREV_x back to its current-step site x.
            curr_var = _shift_name(var, -_get_shift(var))
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _PREV_* (i.e. prev) in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(
                    name for name in dim_to_name.values()
                )

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(
            time_to_factors,
            time_to_init_vars,
            time_to_markov_dims,
            sum_vars,
            prod_vars,
            history,
            sum_op,
            prod_op,
        )
        log_factors = log_factors + markov_factors

    with funsor.interpretations.lazy:
        lazy_result = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            log_factors,
            eliminate=sum_vars | prod_vars,
            plates=prod_vars,
        )
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density is a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
            )
        )
    return result, model_trace, log_measures
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs):
    """Sample (or MAP-decode) enumerated discrete latents, then replay.

    ``temperature == 0`` gives MAP via max/add with an argmax
    approximation; ``temperature == 1`` samples via logaddexp/add with a
    Monte Carlo approximation.  Computes an adjoint of the lazily built
    joint log-density to obtain each latent's measure, fills the sampled
    values into a copy of the trace, and finally replays the model
    against that trace, returning the model's return value.
    """
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        # One blocked model run to measure plate nesting depth.
        with block():
            model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            list(terms["log_factors"].values()) + list(terms["log_measures"].values()),
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        # Backward pass: one adjoint factor per latent measure.
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on
            # enumerated values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(
                node["value"], output,
                dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
def test_empty_tensor_possible():
    """Converting a tensor with a zero-sized batch dimension must not raise."""
    names = OrderedDict([(-1, "a"), (-2, "b")])
    funsor.to_funsor(randn(3, 0), dim_to_name=names)
def log_density(model, model_args, model_kwargs, params):
    """
    Similar to :func:`numpyro.infer.util.log_density` but works for models
    with discrete latent variables. Internally, this uses :mod:`funsor`
    to marginalize discrete latent sites and evaluate the joint log probability.

    :param model: Python callable containing NumPyro primitives. Typically,
        the model has been enumerated by using
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _init/... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)
            # Apply subsampling / handler scale unless it is trivially 1.
            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob = funsor.to_funsor(log_prob, output=funsor.reals(),
                                        dim_to_name=dim_to_name)

            # A "_time..." named dim marks a markov (scan) site; collect
            # its factors separately for compute_markov_factors.
            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(
                        name, funsor.domains.bint(site["value"].shape[dim]))
                    time_to_factors[time_dim].append(log_prob)
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values() if s.startswith("_init"))
                    break
            if time_dim is None:
                log_factors.append(log_prob)

            if not site['is_observed']:
                sum_vars |= frozenset({site['name']})
            # Only plates with a real batch dim become product vars.
            prod_vars |= frozenset(f.name for f in site['cond_indep_stack']
                                   if f.dim is not None)

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            # Strip the "_init/" prefix to find the current-step site.
            curr_var = "/".join(var.split("/")[1:])
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _init (i.e. prev) in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(
                    name for name in dim_to_name.values())

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(time_to_factors,
                                                time_to_init_vars,
                                                time_to_markov_dims,
                                                sum_vars, prod_vars)
        log_factors = log_factors + markov_factors

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        lazy_result = funsor.sum_product.sum_product(funsor.ops.logaddexp,
                                                     funsor.ops.add,
                                                     log_factors,
                                                     eliminate=sum_vars | prod_vars,
                                                     plates=prod_vars)
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density is a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".
            format(result.data.shape,
                   {k.split("__BOUND")[0] for k in result.inputs}))
    return result.data, model_trace