Example #1
0
def _function(inputs, output, fn):
    """Wrap a callable (or ``torch.nn.Module``) as a funsor function.

    Builds one ``Variable`` per argument name/domain pair, then defers to
    ``_nested_function`` to construct the result.
    """
    # For nn.Modules the real signature lives on .forward; drop ``self``.
    if isinstance(fn, torch.nn.Module):
        arg_names = getargspec(fn.forward)[0][1:]
    else:
        arg_names = getargspec(fn)[0]
    variables = tuple(Variable(n, d) for n, d in zip(arg_names, inputs))
    assert len(variables) == len(inputs)
    if not isinstance(output, Domain):
        assert isinstance(output, tuple)
        # Multiple-output functions are memoized so a single invocation is
        # shared among all outputs. Not foolproof, but works in simple cases.
        fn = _Memoized(fn)
    return _nested_function(fn, variables, output)
Example #2
0
def _of_shape(fn, shape):
    """Evaluate ``fn`` on fresh variables sized by ``shape`` and align.

    ``fn`` must take only plain positional arguments (no ``*args``/``**kwargs``);
    each argument becomes a ``Variable`` whose size comes from ``shape``.
    """
    spec = getargspec(fn)
    assert not spec[1]  # reject *varargs
    assert not spec[2]  # reject **keywords
    names = tuple(spec[0])
    variables = [Variable(n, s) for n, s in zip(names, shape)]
    return to_funsor(fn(*variables)).align(names)
Example #3
0
def _function(inputs, output, fn):
    """Wrap ``fn`` as a funsor function over the given input domains.

    ``inputs`` may be a dict mapping argument name to domain, or a sequence
    of domains paired positionally with the callable's argument names.
    """
    # nn.Module-like objects carry the signature on .forward; drop ``self``.
    if is_nn_module(fn):
        names = getargspec(fn.forward)[0][1:]
    else:
        names = getargspec(fn)[0]
    if isinstance(inputs, dict):
        # Keyword form: keep only the argument names present in the mapping.
        args = tuple(Variable(name, inputs[name])
                     for name in names if name in inputs)
    else:
        # Positional form: pair names with domains in declaration order.
        args = tuple(Variable(name, domain)
                     for name, domain in zip(names, inputs))
    assert len(args) == len(inputs)
    if not isinstance(output, ArrayType):
        assert output.__origin__ in (tuple, typing.Tuple)
        # Memoize multiple-output functions so one invocation is shared among
        # all outputs. Not foolproof, but works in simple situations.
        fn = _Memoized(fn)
    return _nested_function(fn, args, output)
Example #4
0
 def __init__(cls, name, bases, dct):
     """Metaclass initializer: set up AST metadata and instance caches.

     Parametrized subclasses (nonempty ``__args__``) record their single
     base as ``__origin__``; plain classes record constructor field names
     and weak caches instead.
     """
     super(FunsorMeta, cls).__init__(name, bases, dct)
     if not hasattr(cls, "__args__"):
         cls.__args__ = ()
     if not cls.__args__:
         # Constructor field names, minus the leading ``self``.
         cls._ast_fields = getargspec(cls.__init__)[0][1:]
         # Weak caches so interned instances/types can be garbage collected.
         cls._cons_cache = WeakValueDictionary()
         cls._type_cache = WeakValueDictionary()
     else:
         base, = bases
         cls.__origin__ = base
Example #5
0
    def __call__(self, cls, args, kwargs):
        """Intercept a distribution constructor call and coerce to a funsor.

        Returns a funsor when ``cls`` has tensor inputs, at least one of
        them is a str/Funsor, and a matching funsor class exists in
        ``self.module``; otherwise returns ``None`` to fall through.
        """
        # Distributions with no tensor inputs cannot be coerced.
        arg_constraints = getattr(cls, "arg_constraints", None)
        if not arg_constraints:
            return

        # Lazily cache the constructor's field names (minus ``self``) on cls.
        # NOTE: try/except (not getattr-with-default) so the cached value is
        # computed exactly once per class.
        try:
            ast_fields = cls._funsor_ast_fields
        except AttributeError:
            ast_fields = cls._funsor_ast_fields = getargspec(
                cls.__init__)[0][1:]

        # Normalize the call to keyword form; explicit kwargs win over
        # positional args bound by field name.
        merged = dict(zip(ast_fields, args))
        merged.update(kwargs)

        # Bail out unless some constrained input is actually a funsor (or a
        # str naming one).
        has_funsor_input = any(
            isinstance(value, (str, Funsor))
            for name, value in merged.items()
            if name in arg_constraints)
        if not has_funsor_input:
            return

        # Look up (and cache, even when None) the corresponding funsor class.
        try:
            funsor_cls = cls._funsor_cls
        except AttributeError:
            funsor_cls = getattr(self.module, cls.__name__, None)
            # Binomial/Multinomial are functions in NumPyro that fall back to
            # either *Probs or *Logits classes; strip the "Probs" suffix.
            if funsor_cls is None and cls.__name__.endswith("Probs"):
                funsor_cls = getattr(self.module, cls.__name__[:-5], None)
            cls._funsor_cls = funsor_cls
        if funsor_cls is None:
            warnings.warn("missing funsor for {}".format(cls.__name__),
                          RuntimeWarning)
            return

        # Coerce to funsor.
        return funsor_cls(**merged)