Beispiel #1
0
def make_dist(backend_dist_class,
              param_names=(),
              generate_eager=True,
              generate_to_funsor=True):
    """Build a funsor ``Distribution`` subclass wrapping a backend class.

    :param backend_dist_class: Backend distribution class to wrap. Its
        ``__name__``, with everything up to the last ``"Wrapper_"`` removed,
        becomes the generated class name.
    :param tuple param_names: Names of the generated ``__init__`` parameters.
        When empty, they are inferred from the arguments of
        ``backend_dist_class.__init__`` (as reported by
        ``inspect.getfullargspec``, skipping the first one) that are keys of
        ``backend_dist_class.arg_constraints``.
    :param bool generate_eager: If true, register
        ``dist_class.eager_log_prob`` with ``eager`` for ``dist_class`` with
        ``Tensor`` for every parameter and for the value.
    :param bool generate_to_funsor: If true, register
        ``functools.partial(backenddist_to_funsor, dist_class)`` as the
        ``to_funsor`` handler for ``backend_dist_class``.
    :return: The generated distribution class.
    """
    if not param_names:
        init_spec = inspect.getfullargspec(backend_dist_class.__init__)
        # Drop the leading ``self`` and keep only constrained parameters.
        param_names = tuple(
            arg for arg in init_spec.args[1:]
            if arg in backend_dist_class.arg_constraints)

    # One argument per parameter plus a trailing ``value`` whose default is
    # the string 'value'.
    init_signature = "__init__(self, {}, value='value')".format(
        ", ".join(param_names))

    @makefun.with_signature(init_signature)
    def dist_init(self, **kwargs):
        # Forward the arguments positionally, in the class's AST field order.
        field_values = [kwargs[field] for field in self._ast_fields]
        return Distribution.__init__(self, *field_values)

    class_name = backend_dist_class.__name__.split("Wrapper_")[-1]
    namespace = {
        'dist_class': backend_dist_class,
        '__init__': dist_init,
    }
    dist_class = DistributionMeta(class_name, (Distribution, ), namespace)

    if generate_eager:
        # One Tensor per parameter, plus one for the value.
        arg_types = (Tensor, ) * (len(param_names) + 1)
        eager.register(dist_class, *arg_types)(dist_class.eager_log_prob)

    if generate_to_funsor:
        converter = functools.partial(backenddist_to_funsor, dist_class)
        to_funsor.register(backend_dist_class)(converter)

    return dist_class
Beispiel #2
0
# Convert Delta **distribution** to raw data
@to_data.register(Delta)  # noqa: F821
def deltadist_to_data(funsor_dist, name_to_dim=None):
    """Convert a funsor Delta distribution into a backend ``dist.Delta``.

    Both ``v`` and ``log_density`` are converted with ``to_data`` using the
    same ``name_to_dim`` mapping. ``event_dim`` is the number of dimensions
    in the output shape of ``funsor_dist.v``.
    """
    raw_v = to_data(funsor_dist.v, name_to_dim=name_to_dim)
    raw_log_density = to_data(funsor_dist.log_density,
                              name_to_dim=name_to_dim)
    n_event_dims = len(funsor_dist.v.output.shape)
    return dist.Delta(raw_v, raw_log_density, event_dim=n_event_dims)


###############################################
# Converting NumPyro Distributions to funsors
###############################################

# Wrapper distributions that get dedicated converters.
to_funsor.register(dist.Independent)(indepdist_to_funsor)
# MaskedDistribution is registered only when the backend ``dist`` module
# provides it.
if hasattr(dist, "MaskedDistribution"):
    to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(dist.TransformedDistribution)(transformeddist_to_funsor)


@to_funsor.register(dist.BinomialProbs)
@to_funsor.register(dist.BinomialLogits)
def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
    """Convert a NumPyro Binomial distribution (probs or logits form).

    NOTE(review): despite the name, this converter is registered for
    ``dist.BinomialProbs`` and ``dist.BinomialLogits``.

    The distribution is rebuilt as a ``_NumPyroWrapper_Binomial`` from its
    ``total_count`` and ``probs``. That wrapper is then passed, together with
    ``Binomial``, ``output`` and ``dim_to_name``, to
    ``backenddist_to_funsor``.
    """
    wrapped = _NumPyroWrapper_Binomial(
        total_count=numpyro_dist.total_count,
        probs=numpyro_dist.probs,
    )
    return backenddist_to_funsor(
        Binomial, wrapped, output, dim_to_name)  # noqa: F821


@to_funsor.register(dist.CategoricalProbs)
Beispiel #3
0
from funsor.util import quote


@adjoint_ops.register(Tensor, ops.AssociativeOp, ops.AssociativeOp, Funsor,
                      (DeviceArray, Tracer), tuple, object)
def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
    """Return an empty adjoint mapping.

    Registered for ``Tensor`` terms whose data is a JAX ``DeviceArray`` or
    ``Tracer``; every argument is ignored.
    """
    return dict()


@recursion_reinterpret.register(DeviceArray)
@recursion_reinterpret.register(Tracer)
def _recursion_reinterpret_ground(x):
    """Leave raw JAX arrays and tracers unchanged under reinterpretation."""
    unchanged = x
    return unchanged


@children.register(DeviceArray)
@children.register(Tracer)
def _children_ground(x):
    """Report that raw JAX arrays and tracers have no children."""
    return tuple()


# Raw JAX arrays and tracers are converted by ``tensor_to_funsor``.
for _array_type in (DeviceArray, Tracer):
    to_funsor.register(_array_type)(tensor_to_funsor)
del _array_type  # keep the module namespace unchanged


@quote.register(DeviceArray)
def _quote(x, indent, out):
    """
    Work around JAX's DeviceArray not supporting reproducible repr.

    Appends ``(indent, source)`` to ``out``, where ``source`` is an
    ``np.array(...)`` expression built from the nested list of ``x``'s
    values and its dtype.
    """
    values = repr(x.copy().tolist())
    source = "np.array({}, dtype=np.{})".format(values, x.dtype)
    out.append((indent, source))
Beispiel #4
0
        for k, domain in kwargs.items()
    },
                              validate_args=False)
    return reals(*instance.event_shape)


# Binomial and Multinomial share the same value-domain inference helper;
# each class gets its own classmethod wrapper.
for _cls in (Binomial, Multinomial):  # noqa: F821
    _cls._infer_value_domain = classmethod(_multinomial_infer_value_domain)
del _cls  # keep the module namespace unchanged

###############################################
# Converting NumPyro Distributions to funsors
###############################################

# Fallback converter for any backend distribution, followed by dedicated
# converters for wrapper and special-cased distribution classes.
to_funsor.register(dist.Distribution)(backenddist_to_funsor)
to_funsor.register(dist.Independent)(indepdist_to_funsor)
to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(dist.TransformedDistribution)(transformeddist_to_funsor)
to_funsor.register(dist.MultivariateNormal)(mvndist_to_funsor)


@to_funsor.register(dist.BinomialProbs)
@to_funsor.register(dist.BinomialLogits)
def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
    """Convert a NumPyro Binomial distribution (probs or logits form).

    NOTE(review): despite the name, this converter is registered for
    ``dist.BinomialProbs`` and ``dist.BinomialLogits``.

    The distribution is rebuilt as a probs-parameterized
    ``_NumPyroWrapper_Binomial`` and converted with ``backenddist_to_funsor``.
    """
    # Forward total_count as well as probs. Passing only probs would discard
    # the source distribution's number of trials.
    new_pyro_dist = _NumPyroWrapper_Binomial(
        total_count=numpyro_dist.total_count, probs=numpyro_dist.probs)
    return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)


@to_funsor.register(dist.CategoricalProbs)
# XXX: in Pyro backend, we always convert pyro.distributions.Categorical
Beispiel #5
0
                               output=None,
                               dim_to_name=None,
                               real_inputs=None):
    name = next(real_inputs.keys()) if real_inputs else "value"
    expr = Variable(name, output)
    for part in tfm.parts:
        expr = to_funsor(part,
                         output=output,
                         dim_to_name=dim_to_name,
                         real_inputs=real_inputs)(**{
                             name: expr
                         })
    return expr


# Wrapper distributions that get dedicated converters.
for _wrapper_cls, _converter in (
        (torch.distributions.Independent, indepdist_to_funsor),
        (MaskedDistribution, maskeddist_to_funsor),
        (torch.distributions.TransformedDistribution,
         transformeddist_to_funsor),
):
    to_funsor.register(_wrapper_cls)(_converter)
del _wrapper_cls, _converter  # keep the module namespace unchanged


@to_funsor.register(torch.distributions.Bernoulli)
def bernoulli_to_funsor(pyro_dist, output=None, dim_to_name=None):
    """Convert a ``torch.distributions.Bernoulli`` to a funsor.

    The distribution is rebuilt from its ``logits`` as a
    ``_PyroWrapper_BernoulliLogits``. That wrapper is then passed, together
    with ``BernoulliLogits``, ``output`` and ``dim_to_name``, to
    ``backenddist_to_funsor``.
    """
    logits_dist = _PyroWrapper_BernoulliLogits(logits=pyro_dist.logits)
    return backenddist_to_funsor(
        BernoulliLogits, logits_dist, output, dim_to_name)  # noqa: F821


@to_funsor.register(dist.Delta)  # Delta **distribution**
def deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    v = to_funsor(pyro_dist.v,
Beispiel #6
0
@children.register(torch.Tensor)
@children.register(torch.nn.Module)
def _children_ground(x):
    """Report that raw torch tensors and modules have no children."""
    return tuple()


@quote.register(torch.Tensor)
def _quote(x, indent, out):
    """
    Work around PyTorch not supporting reproducible repr.

    Appends ``(indent, source)`` to ``out``, where ``source`` is a
    ``torch.tensor(...)`` expression built from the nested list of ``x``'s
    values and its dtype.
    """
    values = repr(x.tolist())
    source = "torch.tensor({}, dtype={})".format(values, x.dtype)
    out.append((indent, source))


# Raw torch tensors are converted to funsors by ``tensor_to_funsor``.
to_funsor.register(torch.Tensor)(tensor_to_funsor)


@dispatch(torch.Tensor, torch.Tensor, [float])
def allclose(a, b, rtol=1e-05, atol=1e-08):
    """Return whether torch tensors ``a`` and ``b`` are elementwise close.

    Delegates to ``torch.allclose`` with the given relative (``rtol``) and
    absolute (``atol``) tolerances.
    """
    return torch.allclose(a, b, atol=atol, rtol=rtol)


################################################################################
# Register Ops
################################################################################

# Bind funsor ops to their torch implementations for torch.Tensor arguments.
ops.abs.register(torch.Tensor)(torch.abs)
ops.cholesky_solve.register(torch.Tensor, torch.Tensor)(torch.cholesky_solve)
# clamp dispatches on the tensor only; both bounds accept any object type.
ops.clamp.register(torch.Tensor, object, object)(torch.clamp)
ops.exp.register(torch.Tensor)(torch.exp)