def make_dist(backend_dist_class, param_names=(), generate_eager=True, generate_to_funsor=True):
    if not param_names:
        param_names = tuple(name for name in inspect.getfullargspec(backend_dist_class.__init__)[0][1:]
                            if name in backend_dist_class.arg_constraints)

    @makefun.with_signature("__init__(self, {}, value='value')".format(', '.join(param_names)))
    def dist_init(self, **kwargs):
        return Distribution.__init__(self, *tuple(kwargs[k] for k in self._ast_fields))

    dist_class = DistributionMeta(backend_dist_class.__name__.split("Wrapper_")[-1], (Distribution,), {
        'dist_class': backend_dist_class,
        '__init__': dist_init,
    })

    if generate_eager:
        eager.register(dist_class, *((Tensor,) * (len(param_names) + 1)))(dist_class.eager_log_prob)

    if generate_to_funsor:
        to_funsor.register(backend_dist_class)(functools.partial(backenddist_to_funsor, dist_class))

    return dist_class
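# Illustrative usage sketch (not part of this module): `make_dist` is normally
# called once per backend distribution at import time to generate a funsor
# wrapper class.  The wrapper class name `_WrappedBeta` and its parameter names
# below are assumptions for illustration only:
#
#     Beta = make_dist(_WrappedBeta, param_names=("concentration1", "concentration0"))
#     b = Beta(concentration1=1.5, concentration0=0.5)   # funsor term with a free "value"
#     b(value=0.3)                                        # would eagerly evaluate to a log-prob Tensor
#
# With the defaults above, eager log_prob evaluation is registered for Tensor
# arguments, and `to_funsor` on a raw `_WrappedBeta` instance is routed through
# `backenddist_to_funsor`.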
# Convert Delta **distribution** to raw data
@to_data.register(Delta)  # noqa: F821
def deltadist_to_data(funsor_dist, name_to_dim=None):
    v = to_data(funsor_dist.v, name_to_dim=name_to_dim)
    log_density = to_data(funsor_dist.log_density, name_to_dim=name_to_dim)
    return dist.Delta(v, log_density, event_dim=len(funsor_dist.v.output.shape))


###############################################
# Converting NumPyro Distributions to funsors
###############################################

to_funsor.register(dist.Independent)(indepdist_to_funsor)
if hasattr(dist, "MaskedDistribution"):
    to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(dist.TransformedDistribution)(transformeddist_to_funsor)


@to_funsor.register(dist.BinomialProbs)
@to_funsor.register(dist.BinomialLogits)
def binomial_to_funsor(numpyro_dist, output=None, dim_to_name=None):
    new_pyro_dist = _NumPyroWrapper_Binomial(total_count=numpyro_dist.total_count, probs=numpyro_dist.probs)
    return backenddist_to_funsor(Binomial, new_pyro_dist, output, dim_to_name)  # noqa: F821


@to_funsor.register(dist.CategoricalProbs)
from funsor.util import quote


@adjoint_ops.register(Tensor, ops.AssociativeOp, ops.AssociativeOp, Funsor, (DeviceArray, Tracer), tuple, object)
def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
    return {}


@recursion_reinterpret.register(DeviceArray)
@recursion_reinterpret.register(Tracer)
def _recursion_reinterpret_ground(x):
    return x


@children.register(DeviceArray)
@children.register(Tracer)
def _children_ground(x):
    return ()


to_funsor.register(DeviceArray)(tensor_to_funsor)
to_funsor.register(Tracer)(tensor_to_funsor)


@quote.register(DeviceArray)
def _quote(x, indent, out):
    """
    Work around JAX's DeviceArray not supporting reproducible repr.
    """
    out.append((indent, f"np.array({repr(x.copy().tolist())}, dtype=np.{x.dtype})"))
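# Illustrative sketch (assumes JAX and the funsor JAX backend are installed;
# the snippet is not executed as part of this module):
#
#     import jax.numpy as jnp
#     from funsor import to_funsor
#
#     x = jnp.ones((3, 2))      # a DeviceArray / jax.Array
#     fx = to_funsor(x)         # dispatches to tensor_to_funsor registered above
#
# `quote(fx)` then renders the underlying data through the `np.array(...)`
# workaround above, giving a reproducible repr that a raw DeviceArray does not
# guarantee.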
        for k, domain in kwargs.items()
    }, validate_args=False)
    return reals(*instance.event_shape)


Binomial._infer_value_domain = classmethod(_multinomial_infer_value_domain)  # noqa: F821
Multinomial._infer_value_domain = classmethod(_multinomial_infer_value_domain)  # noqa: F821


###############################################
# Converting NumPyro Distributions to funsors
###############################################

to_funsor.register(dist.Distribution)(backenddist_to_funsor)
to_funsor.register(dist.Independent)(indepdist_to_funsor)
to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(dist.TransformedDistribution)(transformeddist_to_funsor)
to_funsor.register(dist.MultivariateNormal)(mvndist_to_funsor)


@to_funsor.register(dist.BinomialProbs)
@to_funsor.register(dist.BinomialLogits)
def binomial_to_funsor(numpyro_dist, output=None, dim_to_name=None):
    new_pyro_dist = _NumPyroWrapper_Binomial(probs=numpyro_dist.probs)
    return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)


@to_funsor.register(dist.CategoricalProbs)
# XXX: in Pyro backend, we always convert pyro.distributions.Categorical
                               output=None, dim_to_name=None, real_inputs=None):
    name = next(iter(real_inputs)) if real_inputs else "value"
    expr = Variable(name, output)
    for part in tfm.parts:
        expr = to_funsor(part, output=output, dim_to_name=dim_to_name, real_inputs=real_inputs)(**{name: expr})
    return expr


to_funsor.register(torch.distributions.Independent)(indepdist_to_funsor)
to_funsor.register(MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(torch.distributions.TransformedDistribution)(transformeddist_to_funsor)


@to_funsor.register(torch.distributions.Bernoulli)
def bernoulli_to_funsor(pyro_dist, output=None, dim_to_name=None):
    new_pyro_dist = _PyroWrapper_BernoulliLogits(logits=pyro_dist.logits)
    return backenddist_to_funsor(BernoulliLogits, new_pyro_dist, output, dim_to_name)  # noqa: F821


@to_funsor.register(dist.Delta)  # Delta **distribution**
def deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    v = to_funsor(pyro_dist.v,
@children.register(torch.Tensor)
@children.register(torch.nn.Module)
def _children_ground(x):
    return ()


@quote.register(torch.Tensor)
def _quote(x, indent, out):
    """
    Work around PyTorch not supporting reproducible repr.
    """
    out.append((indent, f"torch.tensor({repr(x.tolist())}, dtype={x.dtype})"))


to_funsor.register(torch.Tensor)(tensor_to_funsor)


@dispatch(torch.Tensor, torch.Tensor, [float])
def allclose(a, b, rtol=1e-05, atol=1e-08):
    return torch.allclose(a, b, rtol=rtol, atol=atol)


################################################################################
# Register Ops
################################################################################

ops.abs.register(torch.Tensor)(torch.abs)
ops.cholesky_solve.register(torch.Tensor, torch.Tensor)(torch.cholesky_solve)
ops.clamp.register(torch.Tensor, object, object)(torch.clamp)
ops.exp.register(torch.Tensor)(torch.exp)
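# Illustrative sketch (assumes the funsor PyTorch backend; not executed here):
#
#     import torch
#     x = torch.randn(3)
#     ops.exp(x)                        # dispatches to torch.exp via the registration above
#     ops.abs(x)                        # dispatches to torch.abs
#     allclose(ops.abs(x), x.abs())     # uses the multipledispatch allclose defined above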