Exemplo n.º 1
0
def make_dist(backend_dist_class,
              param_names=(),
              generate_eager=True,
              generate_to_funsor=True):
    """Create a funsor ``Distribution`` subclass that wraps a backend class.

    :param backend_dist_class: Backend distribution class to wrap. The part of
        its ``__name__`` after the last ``"Wrapper_"`` becomes the name of the
        generated class.
    :param param_names: Names of the distribution parameters. When empty, they
        are inferred from the arguments of ``backend_dist_class.__init__``
        (excluding ``self``) that are keys of its ``arg_constraints``.
    :param bool generate_eager: If true, register
        ``dist_class.eager_log_prob`` with ``eager`` for the case where every
        parameter and the value are ``Tensor``.
    :param bool generate_to_funsor: If true, register a ``to_funsor``
        conversion for ``backend_dist_class`` instances, built from
        ``backenddist_to_funsor`` with the generated class bound first.
    :return: The generated ``Distribution`` subclass.
    """
    if not param_names:
        init_args = inspect.getfullargspec(backend_dist_class.__init__).args
        param_names = tuple(name for name in init_args[1:]
                            if name in backend_dist_class.arg_constraints)

    # Join the pieces instead of formatting "self, {}, value" so that an empty
    # ``param_names`` still yields the valid "__init__(self, value='value')".
    init_signature = "__init__({})".format(
        ", ".join(("self",) + tuple(param_names) + ("value='value'",)))

    @makefun.with_signature(init_signature)
    def dist_init(self, **kwargs):
        # Forward the keyword arguments positionally, in ``_ast_fields`` order.
        return Distribution.__init__(
            self, *tuple(kwargs[k] for k in self._ast_fields))

    dist_class = DistributionMeta(
        backend_dist_class.__name__.split("Wrapper_")[-1], (Distribution, ), {
            'dist_class': backend_dist_class,
            '__init__': dist_init,
        })

    if generate_eager:
        # One ``Tensor`` per parameter plus one for the trailing ``value``.
        eager.register(dist_class, *((Tensor, ) * (len(param_names) + 1)))(
            dist_class.eager_log_prob)

    if generate_to_funsor:
        to_funsor.register(backend_dist_class)(functools.partial(
            backenddist_to_funsor, dist_class))

    return dist_class
Exemplo n.º 2
0
def make_dist(pyro_dist_class, param_names=()):
    """Create a funsor ``Distribution`` subclass that wraps a Pyro class.

    :param pyro_dist_class: Pyro distribution class to wrap. The part of its
        ``__name__`` after the last ``"_PyroWrapper_"`` becomes the name of the
        generated class.
    :param param_names: Names of the distribution parameters. When empty, they
        are inferred from the arguments of ``pyro_dist_class.__init__``
        (excluding ``self``) that are keys of its ``arg_constraints``.
    :return: The generated ``Distribution`` subclass. Its
        ``eager_log_prob`` is registered with ``eager`` for the case where
        every parameter and the value are ``Tensor``.
    """
    if not param_names:
        init_args = inspect.getfullargspec(pyro_dist_class.__init__).args
        param_names = tuple(name for name in init_args[1:]
                            if name in pyro_dist_class.arg_constraints)

    # Join the pieces so an empty ``param_names`` still produces a valid
    # signature instead of "__init__(self, , value='value')".
    init_params = ", ".join(("self", *param_names, "value='value'"))

    @makefun.with_signature(f"__init__({init_params})")
    def dist_init(self, **kwargs):
        # Forward the keyword arguments positionally, in ``_ast_fields`` order.
        return Distribution.__init__(self, *tuple(kwargs[k] for k in self._ast_fields))

    dist_class = DistributionMeta(pyro_dist_class.__name__.split("_PyroWrapper_")[-1], (Distribution,), {
        'dist_class': pyro_dist_class,
        '__init__': dist_init,
    })

    # One ``Tensor`` per parameter plus one for the trailing ``value``.
    eager.register(dist_class, *((Tensor,) * (len(param_names) + 1)))(dist_class.eager_log_prob)

    return dist_class
Exemplo n.º 3
0
    v = to_funsor(pyro_dist.v,
                  output=Reals[pyro_dist.event_shape],
                  dim_to_name=dim_to_name)
    log_density = to_funsor(pyro_dist.log_density,
                            output=Real,
                            dim_to_name=dim_to_name)
    return Delta(v, log_density)  # noqa: F821


# Pattern type: a ``Contraction`` parametrized by
# Union[LogAddExpOp, NullOp], AddOp, frozenset and Tuple[Dirichlet, Multinomial].
# NOTE(review): the slots are presumably (reduction op, binary op,
# reduced vars, terms) -- confirm against Contraction's definition.
# Dirichlet and Multinomial are not defined in the visible source (hence F821).
JointDirichletMultinomial = Contraction[Union[ops.LogAddExpOp, ops.NullOp],
                                        ops.AddOp, frozenset,
                                        Tuple[Dirichlet,
                                              Multinomial],  # noqa: F821
                                        ]

# Register eager rules for the distribution classes, keyed on argument types.
# The ``eager_*`` handlers and distribution classes flagged with F821 are not
# defined in the visible source; presumably they are created elsewhere in
# the module at import time -- confirm.
eager.register(Beta, Funsor, Funsor, Funsor)(eager_beta)  # noqa: F821
eager.register(Binomial, Funsor, Funsor, Funsor)(eager_binomial)  # noqa: F821
eager.register(Multinomial, Tensor, Tensor,
               Tensor)(eager_multinomial)  # noqa: F821
eager.register(Categorical, Funsor,
               Tensor)(eager_categorical_funsor)  # noqa: F821
eager.register(Categorical, Tensor,
               Variable)(eager_categorical_tensor)  # noqa: F821
eager.register(Delta, Tensor, Tensor, Tensor)(eager_delta_tensor)  # noqa: F821
# The same handler serves both (Funsor, Funsor, Variable) and
# (Variable, Funsor, Variable) argument patterns.
eager.register(Delta, Funsor, Funsor,
               Variable)(eager_delta_funsor_variable)  # noqa: F821
eager.register(Delta, Variable, Funsor,
               Variable)(eager_delta_funsor_variable)  # noqa: F821
eager.register(Delta, Variable, Funsor,
               Funsor)(eager_delta_funsor_funsor)  # noqa: F821
eager.register(Delta, Variable, Variable,
Exemplo n.º 4
0
        rhs = rhs(**{lhs.name: lhs.point})
        return op(lhs, rhs)

    return None  # defer to default implementation


@eager.register(Binary, AddOp, (Funsor, Align), Delta)
def eager_add(op, lhs, rhs):
    """Eagerly add a funsor and a ``Delta`` by substituting the delta's point.

    When the delta's variable ``rhs.name`` is an input of ``lhs``, that input
    is replaced by ``rhs.point`` and ``op`` is re-applied. Otherwise ``None``
    is returned so the default implementation handles the sum.
    """
    if rhs.name not in lhs.inputs:
        return None  # defer to default implementation
    substituted = lhs(**{rhs.name: rhs.point})
    return op(substituted, rhs)


# Delta + Reduce, in either operand order, reuses the generic helpers from
# funsor.terms (presumably distributing the addition into the Reduce --
# see their definitions).
eager.register(Binary, AddOp, Delta,
               Reduce)(funsor.terms.eager_distribute_other_reduce)
eager.register(Binary, AddOp, Reduce,
               Delta)(funsor.terms.eager_distribute_reduce_other)


@eager.register(Independent, Delta, str, str)
def eager_independent(delta, reals_var, bint_var):
    """Eagerly evaluate ``Independent`` applied to a ``Delta``.

    The rule fires only when the delta's name is ``reals_var`` or starts with
    ``reals_var + "__BOUND"``. When it fires, it builds a new ``Delta`` named
    ``reals_var``:

    - the point is wrapped in a ``Lambda`` over the ``bint_var`` variable;
    - the log density is summed over ``bint_var`` when it depends on it, and
      is otherwise scaled by ``delta.inputs[bint_var].dtype``.

    Otherwise ``None`` is returned so other rules can apply.
    """
    name = delta.name
    if not (name == reals_var or name.startswith(reals_var + "__BOUND")):
        return None

    index_var = Variable(bint_var, delta.inputs[bint_var])
    lifted_point = Lambda(index_var, delta.point)

    density = delta.log_density
    if bint_var in density.inputs:
        density = density.reduce(ops.add, bint_var)
    else:
        # Constant in bint_var: multiply by the input domain's ``dtype``
        # (presumably the bounded-int size -- confirm).
        density = density * delta.inputs[bint_var].dtype
    return Delta(reals_var, lifted_point, density)
Exemplo n.º 5
0
# XXX: mirror the Pyro backend, where pyro.distributions.Categorical is always
# converted to funsor.torch.distributions.Categorical: here a NumPyro
# CategoricalLogits is rebuilt from its ``probs`` through
# _NumPyroWrapper_Categorical before conversion.
@to_funsor.register(dist.CategoricalLogits)
def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
    """Convert a NumPyro ``CategoricalLogits`` into a funsor.

    The distribution is re-expressed through ``_NumPyroWrapper_Categorical``
    using its ``probs`` and then passed to ``backenddist_to_funsor``.
    """
    probs_based = _NumPyroWrapper_Categorical(probs=numpyro_dist.probs)
    return backenddist_to_funsor(probs_based, output, dim_to_name)


@to_funsor.register(dist.MultinomialProbs)
@to_funsor.register(dist.MultinomialLogits)
def multinomial_to_funsor(numpyro_dist, output=None, dim_to_name=None):
    """Convert a NumPyro multinomial distribution into a funsor.

    Both the probs- and logits-parameterized variants are re-wrapped through
    ``_NumPyroWrapper_Multinomial`` using their ``probs`` and ``total_count``,
    then passed to ``backenddist_to_funsor``.
    """
    # Named ``multinomial_to_funsor`` so it no longer rebinds the module-level
    # ``categorical_to_funsor`` name used by the Categorical converter.
    # ``total_count`` is forwarded explicitly: building the wrapper from
    # ``probs`` alone would use the wrapper's default total count instead
    # (NOTE(review): NumPyro's MultinomialProbs defaults it to 1).
    new_pyro_dist = _NumPyroWrapper_Multinomial(
        probs=numpyro_dist.probs, total_count=numpyro_dist.total_count)
    return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)


# Register eager rules for the distribution classes, keyed on argument types.
# The ``eager_*`` handlers and distribution classes flagged with F821 are not
# defined in the visible source; presumably they are created elsewhere in
# the module at import time -- confirm.
eager.register(Beta, Funsor, Funsor, Funsor)(eager_beta)  # noqa: F821
eager.register(Binomial, Funsor, Funsor, Funsor)(eager_binomial)  # noqa: F821
eager.register(Multinomial, Tensor, Tensor,
               Tensor)(eager_multinomial)  # noqa: F821
eager.register(Categorical, Funsor,
               Tensor)(eager_categorical_funsor)  # noqa: F821
eager.register(Categorical, Tensor,
               Variable)(eager_categorical_tensor)  # noqa: F821
eager.register(Delta, Tensor, Tensor, Tensor)(eager_delta_tensor)  # noqa: F821
# The same handler serves both (Funsor, Funsor, Variable) and
# (Variable, Funsor, Variable) argument patterns.
eager.register(Delta, Funsor, Funsor,
               Variable)(eager_delta_funsor_variable)  # noqa: F821
eager.register(Delta, Variable, Funsor,
               Variable)(eager_delta_funsor_variable)  # noqa: F821
eager.register(Delta, Variable, Funsor,
               Funsor)(eager_delta_funsor_funsor)  # noqa: F821
eager.register(Delta, Variable, Variable,
Exemplo n.º 6
0
@eager.register(Binary, AddOp, Joint, Gaussian)
def eager_add(op, joint, other):
    """Fold a (delayed) Gaussian random variable into a ``Joint``.

    Steps:

    1. Points of the joint's deltas are substituted into ``other`` for every
       delta name that ``other`` takes as input.
    2. The joint's existing gaussian is added, unless it is ``Number(0)``.
    3. If the result is still a ``Gaussian``, it becomes the new joint's
       gaussian part. Otherwise it is added to a ``Joint`` of the deltas and
       discrete part alone.
    """
    bound_points = tuple(
        (d.name, d.point) for d in joint.deltas if d.name in other.inputs)
    combined = Subs(other, bound_points) if bound_points else other

    # Identity test against Number(0); presumably relies on funsor's
    # hash-consing making Number(0) a shared instance -- confirm.
    if joint.gaussian is not Number(0):
        combined = joint.gaussian + combined

    if isinstance(combined, Gaussian):
        return Joint(joint.deltas, joint.discrete, combined)
    return Joint(joint.deltas, joint.discrete) + combined


# Reduce + Joint reuses the generic funsor.terms helper (presumably
# distributing the addition into the Reduce -- see its definition).
eager.register(Binary, AddOp, Reduce,
               Joint)(funsor.terms.eager_distribute_reduce_other)


@eager.register(Binary, AddOp, (Funsor, Align, Delta), Joint)
def eager_add(op, other, joint):
    """Handle ``other + joint`` by swapping the operands.

    The sum is recomputed as ``joint + other``, so it is presumably handled
    by the rules whose left operand is a ``Joint``.
    """
    swapped_sum = joint + other
    return swapped_sum


################################################################################
# Patterns to create a Joint from elementary funsors
################################################################################


@eager.register(Binary, AddOp, Delta, Delta)
def eager_add(op, lhs, rhs):
    if lhs.name == rhs.name: