def _function(inputs, output, fn):
    """Wrap ``fn`` via ``_nested_function`` using named ``Variable`` inputs.

    :param inputs: domains, paired positionally with the parameter names of
        ``fn``. When ``fn`` is a ``torch.nn.Module``, they are paired with the
        parameter names of ``fn.forward`` after dropping its first parameter
        (presumably ``self``).
    :param output: either a ``Domain`` (single output) or a ``tuple``
        (multiple outputs).
    :param fn: the callable or ``torch.nn.Module`` being wrapped.
    :return: ``_nested_function(fn, variables, output)``. For multiple outputs
        ``fn`` is first wrapped in ``_Memoized``.
    :raises AssertionError: if fewer parameter names than ``inputs`` are
        found, or if a non-``Domain`` ``output`` is not a ``tuple``.
    """
    if isinstance(fn, torch.nn.Module):
        # Skip the leading parameter of ``forward`` (presumably ``self``).
        param_names = getargspec(fn.forward)[0][1:]
    else:
        param_names = getargspec(fn)[0]

    # map() over two iterables stops at the shorter one, exactly like zip(),
    # so each Variable is built as Variable(param_name, domain).
    variables = tuple(map(Variable, param_names, inputs))
    # Every input domain must have been matched to a parameter name.
    assert len(variables) == len(inputs)

    if isinstance(output, Domain):
        return _nested_function(fn, variables, output)

    assert isinstance(output, tuple)
    # Memoize multiple-output functions so that invocations can be shared among
    # all outputs. This is not foolproof, but does work in simple situations.
    return _nested_function(_Memoized(fn), variables, output)
def _of_shape(fn, shape):
    """Evaluate ``fn`` on fresh ``Variable`` objects and align the result.

    Each positional parameter name of ``fn`` is paired with the domain at the
    same position in ``shape``. ``fn`` is called with one ``Variable`` per
    pair, its result is converted with ``to_funsor``, and the result is
    aligned to the parameter order of ``fn``.

    :param fn: callable with only named positional parameters (no ``*args``
        or ``**kwargs``).
    :param shape: sequence of domains, one per parameter of ``fn``.
    :return: ``to_funsor(fn(*variables)).align(names)``.
    :raises AssertionError: if ``fn`` takes ``*args``/``**kwargs``, or if
        ``shape`` does not have exactly one entry per parameter of ``fn``.
    """
    arg_names, varargs, varkw, _ = getargspec(fn)
    assert not varargs
    assert not varkw
    names = tuple(arg_names)
    shape = tuple(shape)
    # zip() below would silently drop surplus shape entries, and a short
    # shape would leave trailing parameters without a Variable. Either way
    # the declaration does not match fn, so fail early with a clear message.
    assert len(shape) == len(names), (
        "expected {} domain(s) in shape for arguments {}, got {}".format(
            len(names), names, len(shape)))
    variables = [Variable(name, domain) for name, domain in zip(names, shape)]
    return to_funsor(fn(*variables)).align(names)
def _function(inputs, output, fn):
    """Wrap ``fn`` via ``_nested_function`` using named ``Variable`` inputs.

    :param inputs: either a ``dict`` mapping parameter names to domains, or a
        sequence of domains paired positionally with the parameter names.
        Parameter names come from ``fn``, or from ``fn.forward`` minus its
        first parameter (presumably ``self``) when ``is_nn_module(fn)``.
    :param output: an ``ArrayType`` for a single output. Otherwise it must
        have ``__origin__`` equal to ``tuple`` or ``typing.Tuple``
        (presumably a ``typing.Tuple[...]`` annotation).
    :param fn: the callable or module being wrapped.
    :return: ``_nested_function(fn, variables, output)``. For tuple outputs
        ``fn`` is first wrapped in ``_Memoized``.
    :raises AssertionError: if some input cannot be matched to a parameter
        name, or if a non-``ArrayType`` ``output`` is not tuple-like.
    """
    if is_nn_module(fn):
        # Skip the leading parameter of ``forward`` (presumably ``self``).
        param_names = getargspec(fn.forward)[0][1:]
    else:
        param_names = getargspec(fn)[0]

    if isinstance(inputs, dict):
        # Keep fn's parameter order; only names present in ``inputs`` get a
        # Variable.
        variables = tuple(
            Variable(param, inputs[param])
            for param in param_names
            if param in inputs
        )
    else:
        # map() over two iterables stops at the shorter one, like zip().
        variables = tuple(map(Variable, param_names, inputs))
    # Every input (dict key or positional domain) must match a parameter.
    assert len(variables) == len(inputs)

    if isinstance(output, ArrayType):
        return _nested_function(fn, variables, output)

    assert output.__origin__ in (tuple, typing.Tuple)
    # Memoize multiple-output functions so that invocations can be shared among
    # all outputs. This is not foolproof, but does work in simple situations.
    return _nested_function(_Memoized(fn), variables, output)
def __init__(cls, name, bases, dct):
    """Initialize a class created by the ``FunsorMeta`` metaclass.

    Ensures the class has an ``__args__`` attribute. It defaults to ``()``
    when neither the class nor any base provides one. The method then
    branches on ``__args__``:

    * Empty ``__args__``: stores the names reported by ``getargspec`` for
      ``cls.__init__`` (minus the first, presumably ``self``) as
      ``_ast_fields``, and attaches fresh ``WeakValueDictionary`` instances
      as ``_cons_cache`` and ``_type_cache``.
    * Non-empty ``__args__``: records the single base class as
      ``__origin__``. Unpacking ``bases`` raises ``ValueError`` unless there
      is exactly one base.
    """
    super(FunsorMeta, cls).__init__(name, bases, dct)
    if not hasattr(cls, "__args__"):
        cls.__args__ = ()

    if not cls.__args__:
        cls._ast_fields = getargspec(cls.__init__)[0][1:]
        cls._cons_cache = WeakValueDictionary()
        cls._type_cache = WeakValueDictionary()
        return

    (origin,) = bases
    cls.__origin__ = origin
def __call__(self, cls, args, kwargs):
    """Try to build a funsor in place of the backend distribution ``cls``.

    :param cls: the distribution class being instantiated.
    :param tuple args: positional constructor arguments.
    :param dict kwargs: keyword constructor arguments.
    :return: ``funsor_cls(**merged_kwargs)`` when coercion applies, else
        ``None``. Presumably ``None`` signals the caller to construct ``cls``
        normally; confirm against the caller.

    Coercion applies only when all of the following hold:

    * ``cls`` has a non-empty ``arg_constraints``;
    * at least one argument whose name appears in ``arg_constraints`` is a
      ``str`` or a ``Funsor``;
    * ``self.module`` has an attribute named ``cls.__name__``, or, for names
      ending in ``"Probs"``, the name without that suffix.

    If that last lookup fails, a ``RuntimeWarning`` is emitted and ``None``
    is returned.

    The positional-argument names and the resolved funsor class (possibly
    ``None``) are cached on ``cls`` itself as ``_funsor_ast_fields`` and
    ``_funsor_cls``.
    """
    # Check whether distribution class takes any tensor inputs.
    arg_constraints = getattr(cls, "arg_constraints", None)
    if not arg_constraints:
        return

    # The per-class caches are read from cls's OWN __dict__ only. Plain
    # attribute access walks the MRO, so a subclass (e.g. one whose
    # ``__init__`` has a different signature, or whose name maps to a
    # different funsor) would silently reuse whatever its parent happened to
    # cache first, making the outcome depend on construction order.
    own_attrs = vars(cls)

    # Check whether any tensor inputs are actually funsors.
    if "_funsor_ast_fields" in own_attrs:
        ast_fields = own_attrs["_funsor_ast_fields"]
    else:
        ast_fields = cls._funsor_ast_fields = getargspec(cls.__init__)[0][1:]
    # Name positional args after __init__'s parameters. Explicit keyword
    # arguments are merged afterwards and so win on a name clash.
    kwargs = {
        name: value
        for pairs in (zip(ast_fields, args), kwargs.items())
        for name, value in pairs
    }
    if not any(
        isinstance(value, (str, Funsor))
        for name, value in kwargs.items()
        if name in arg_constraints
    ):
        return

    # Check for a corresponding funsor class.
    if "_funsor_cls" in own_attrs:
        funsor_cls = own_attrs["_funsor_cls"]
    else:
        funsor_cls = getattr(self.module, cls.__name__, None)
        # Binomial/Multinomial are functions in NumPyro that fall back to
        # e.g. BinomialProbs or BinomialLogits, so map a "...Probs" class
        # name to the unsuffixed funsor name.
        if funsor_cls is None and cls.__name__.endswith("Probs"):
            funsor_cls = getattr(self.module, cls.__name__[:-5], None)
        # A None result is cached too, so the lookup runs once per class.
        cls._funsor_cls = funsor_cls
    if funsor_cls is None:
        warnings.warn("missing funsor for {}".format(cls.__name__), RuntimeWarning)
        return

    # Coerce to funsor.
    return funsor_cls(**kwargs)