def trace(self, *args, **kwargs):
    """
    Call the symbolic function and collect its symbolic inputs.

    The arguments are first passed through ``utils.clean_int_args``. The
    cleaned values are then given to ``self.symfn``. Afterwards the
    arguments are expanded via ``utils.expandedcallargs``. A leading bound
    ``self``/``cls`` argument is discarded, and each remaining argument is
    mapped through ``self.get_symbolic``.

    Returns
    -------
    (inputs, results)
        ``inputs`` is a tuple with one entry per (non ``self``/``cls``)
        argument, as produced by ``self.get_symbolic``. ``results`` is
        whatever ``self.symfn`` returned, unchanged.
    """
    cleaned_args, cleaned_kwargs = utils.clean_int_args(*args, **kwargs)

    # The symbolic function must run before the inputs are looked up.
    # Presumably calling it is what registers the shadowed variables that
    # get_symbolic later retrieves -- NOTE(review): confirm.
    results = self.symfn(*cleaned_args, **cleaned_kwargs)

    flat_args = utils.expandedcallargs(
        self.symfn, *cleaned_args, **cleaned_kwargs)

    # A bound method (or a classmethod whose first argument is a class)
    # contributes a leading 'self'/'cls' that is not a symbolic input.
    if inspect.ismethod(self.pyfn):
        flat_args = flat_args[1:]
    elif len(flat_args) > 0 and type(flat_args[0]) is type:
        flat_args = flat_args[1:]

    symbolic_inputs = tuple(map(self.get_symbolic, flat_args))
    return symbolic_inputs, results
def vector_from_args(self, args=None, kwargs=None):
    """
    Flatten every argument of ``self.pyfn`` into a single 1-D array.

    Parameters
    ----------
    args : sequence, optional
        Positional arguments. ``None`` is treated as no positional arguments.
    kwargs : dict, optional
        Keyword arguments. ``None`` is treated as no keyword arguments.

    Returns
    -------
    numpy.ndarray
        The concatenation of each expanded argument, flattened in C order.
    """
    positional = () if args is None else args
    keywords = {} if kwargs is None else kwargs

    expanded = utils.expandedcallargs(self.pyfn, *positional, **keywords)
    # asarray (not asanyarray) so matrix/masked subclasses are flattened to
    # plain 1-D data before concatenation.
    pieces = [np.asarray(item).ravel() for item in expanded]
    return np.concatenate(pieces)
def __init__(self,
             pyfn,
             init_args=None,
             init_kwargs=None,
             compile_fn=False,
             compile_grad=False,
             compile_hv=False,
             borrow=None,
             force_floatX=False,
             context=None):
    """
    Wrap ``pyfn`` and eagerly compile it for the given initial arguments.

    Parameters
    ----------
    pyfn : callable
        The Python function to wrap.
    init_args : sequence, optional
        Positional arguments used for the initial compilation.
        Defaults to no positional arguments.
    init_kwargs : dict, optional
        Keyword arguments used for the initial compilation.
        Defaults to no keyword arguments.
    compile_fn, compile_grad, compile_hv : bool
        Which compiled outputs are requested; at least one must be True.
    borrow, force_floatX, context
        Forwarded unchanged to the parent class constructor.

    Raises
    ------
    ValueError
        If none of ``compile_fn``, ``compile_grad`` or ``compile_hv`` is True.
    """
    if not (compile_fn or compile_grad or compile_hv):
        raise ValueError('At least one of \'compile_fn\', '
                         '\'compile_grad\', or \'compile_hv\' '
                         'must be True.')
    super(VectorArg, self).__init__(pyfn=pyfn,
                                    borrow=borrow,
                                    force_floatX=force_floatX,
                                    context=context)
    self.compile_fn = compile_fn
    self.compile_grad = compile_grad
    self.compile_hv = compile_hv

    # Normalize both defaults. Previously only init_kwargs was normalized.
    # The ``*init_args`` unpacking below then raised TypeError whenever
    # init_args was left at its default of None.
    if init_args is None:
        init_args = ()
    if init_kwargs is None:
        init_kwargs = dict()
    self.init_args = init_args
    self.init_kwargs = init_kwargs
    self.all_init_args = utils.expandedcallargs(self.pyfn,
                                                *init_args,
                                                **init_kwargs)
    self.cache['fn'] = self.get_theano_fn(init_args, init_kwargs)
def call(self, *args, **kwargs):
    """
    Evaluate the compiled function together with a set of vectors.

    The vectors must be supplied through the ``_vectors`` keyword. That
    keyword is removed from ``kwargs`` before the remaining arguments are
    expanded. A compiled function is cached per tuple of argument ranks,
    and the cache is bypassed when ``self.use_cache`` is false.

    Raises
    ------
    ValueError
        If ``_vectors`` is missing. Also raised if the number of vectors
        differs from ``len(self.wrt)``, or from ``len(self.s_inputs)`` when
        ``self.wrt`` is empty.
    """
    if '_vectors' not in kwargs:
        raise ValueError(
            'Vectors must be passed the keyword \'_vectors\'.')
    vectors = utils.as_seq(kwargs.pop('_vectors'), tuple)

    call_args = utils.expandedcallargs(self.pyfn, *args, **kwargs)

    # Discard a leading bound 'self' / 'cls' argument.
    drop_first = inspect.ismethod(self.pyfn) or (
        len(call_args) > 0 and type(call_args[0]) is type)
    if drop_first:
        call_args = call_args[1:]

    signature = tuple(np.asarray(item).ndim for item in call_args)
    if signature not in self.cache or not self.use_cache:
        self.cache[signature] = self.get_theano_fn(args, kwargs)
    compiled = self.cache[signature]

    # NOTE(review): self.s_inputs must be populated elsewhere; confirm it is
    # set for this class when self.wrt is empty.
    if len(self.wrt) > 0:
        expected = len(self.wrt)
    else:
        expected = len(self.s_inputs)
    if len(vectors) != expected:
        raise ValueError('Expected {0} items in _vectors; received '
                         '{1}.'.format(expected, len(vectors)))

    return compiled(*(call_args + vectors))
def __call__(self, *args, **kwargs):
    """
    Evaluate the compiled Hessian-vector product.

    The vectors must be supplied through the ``vectors`` keyword. A compiled
    function is cached per tuple of ``(ndim, dtype)`` of the arguments. On a
    cache miss (or when ``self.use_cache`` is false), the context is reset,
    ``self.trace`` is run and the result is compiled.

    Raises
    ------
    ValueError
        If ``vectors`` is missing. Also raised if the number of vectors
        differs from ``len(self.wrt)``, or from the number of (non
        ``self``/``cls``) arguments when ``self.wrt`` is empty.
    """
    if 'vectors' in kwargs:
        vectors = kwargs.pop('vectors')
    else:
        raise ValueError(
            'HessianVector must be called with the keyword \'vectors\'.')
    vectors = utils.as_seq(vectors, tuple)

    all_args = utils.expandedcallargs(self.symfn, *args, **kwargs)
    # Drop a bound 'self'/'cls' argument. The ``trace`` and plain
    # ``__call__`` methods in this file do the same when building and
    # calling the compiled function. Without this, a bound-method pyfn
    # would receive one extra positional argument.
    if (inspect.ismethod(self.pyfn) or
            (len(all_args) > 0 and type(all_args[0]) is type)):
        all_args = all_args[1:]

    # Key on dtype as well as rank (as the plain __call__ does), so inputs
    # of a different dtype get their own compiled function.
    key = tuple(
        (np.asarray(a).ndim, np.asarray(a).dtype) for a in all_args)
    if key not in self.cache or not self.use_cache:
        self.context.reset()
        inputs, outputs = self.trace(*args, **kwargs)
        self.cache[key] = self.get_theano_function(inputs, outputs)
    fn = self.cache[key]

    # Previously this compared against ``len(inputs)``, which is only bound
    # on a cache miss. A cache hit with empty self.wrt raised
    # UnboundLocalError. ``all_args`` (after dropping self/cls) has one
    # entry per traced input and exists on every call.
    if len(self.wrt) > 0 and len(vectors) != len(self.wrt):
        raise ValueError('Expected {0} items in `vectors`; received '
                         '{1}.'.format(len(self.wrt), len(vectors)))
    elif len(self.wrt) == 0 and len(vectors) != len(all_args):
        raise ValueError('Expected {0} items in `vectors`; received '
                         '{1}.'.format(len(all_args), len(vectors)))

    return fn(*(all_args + vectors))
def vector_from_args(self, args, kwargs):
    """
    Flatten the call arguments into a single 1-D array.

    With more than one argument in total, every argument of ``self.pyfn`` is
    expanded, flattened and concatenated. With exactly one argument (positional
    or keyword), that value alone is flattened.

    Parameters
    ----------
    args : sequence
        Positional arguments.
    kwargs : dict
        Keyword arguments.

    Returns
    -------
    numpy.ndarray or None
        The flattened vector, or ``None`` when no arguments were given.
    """
    if len(args) + len(kwargs) > 1:
        all_args = utils.expandedcallargs(self.pyfn, *args, **kwargs)
        return np.concatenate([np.asarray(a).flatten() for a in all_args])
    elif len(args) > 0:
        return np.asarray(args[0]).flatten()
    elif len(kwargs) > 0:
        # dict views are not subscriptable on Python 3, so
        # ``kwargs.values()[0]`` raised TypeError. Take the single value
        # from the iterator instead.
        return np.asarray(next(iter(kwargs.values()))).flatten()
    else:
        return None
def vector_from_args(self, args, kwargs):
    """
    Collapse the given call arguments into one flat 1-D array.

    Behavior by total argument count:

    * none: returns ``None``;
    * exactly one (positional or keyword): that value, flattened;
    * several: every argument of ``self.pyfn`` (as expanded by
      ``utils.expandedcallargs``), flattened and concatenated.
    """
    n_given = len(args) + len(kwargs)
    if n_given == 0:
        return None

    if n_given == 1:
        # Exactly one value was supplied; prefer the positional slot.
        sole = args[0] if len(args) > 0 else next(iter(kwargs.values()))
        return np.asarray(sole).flatten()

    expanded = utils.expandedcallargs(self.pyfn, *args, **kwargs)
    return np.concatenate([np.asarray(item).flatten() for item in expanded])
def call(self, *args, **kwargs):
    """
    Evaluate the compiled function on the given arguments.

    A compiled function is obtained from ``self.get_theano_fn`` and cached
    per tuple of argument ranks (``ndim``). The cache is bypassed when
    ``self.use_cache`` is false. A leading bound ``self``/``cls`` argument
    is not passed to the compiled function.
    """
    flat = utils.expandedcallargs(self.pyfn, *args, **kwargs)

    leading_is_bound = inspect.ismethod(self.pyfn) or (
        len(flat) > 0 and type(flat[0]) is type)
    if leading_is_bound:
        flat = flat[1:]

    ranks = tuple(np.asarray(item).ndim for item in flat)
    if ranks not in self.cache or not self.use_cache:
        self.cache[ranks] = self.get_theano_fn(args, kwargs)

    return self.cache[ranks](*flat)
def __call__(self, *args, **kwargs):
    """
    Evaluate the compiled symbolic function on the given arguments.

    A compiled function is cached per tuple of ``(ndim, dtype)`` of the
    arguments. On a cache miss (or when ``self.use_cache`` is false), the
    context is reset, ``self.trace`` is run and the traced graph is
    compiled with ``self.get_theano_function``. A leading bound
    ``self``/``cls`` argument is not passed to the compiled function.
    """
    call_args = utils.expandedcallargs(self.symfn, *args, **kwargs)

    drop_first = inspect.ismethod(self.pyfn) or (
        len(call_args) > 0 and type(call_args[0]) is type)
    if drop_first:
        call_args = call_args[1:]

    signature = []
    for arg in call_args:
        as_array = np.asarray(arg)
        signature.append((as_array.ndim, as_array.dtype))
    signature = tuple(signature)

    if signature not in self.cache or not self.use_cache:
        self.context.reset()
        sym_inputs, sym_outputs = self.trace(*args, **kwargs)
        self.cache[signature] = self.get_theano_function(sym_inputs,
                                                         sym_outputs)

    compiled = self.cache[signature]
    return compiled(*call_args)
def __init__(self,
             pyfn,
             init_args=None,
             init_kwargs=None,
             context=None,
             force_floatX=False,
             borrowable=None,
             ignore=None,
             infer_updates=False,
             escape_on_error=False,
             function=False,
             gradient=False,
             hessian_vector=False):
    """
    Compile ``pyfn`` so that it is driven by a single flat input vector.

    ``pyfn`` is wrapped so that its arguments are first packed with
    ``self.vector_from_args`` and then unpacked with
    ``self.args_from_vector``. The wrapper is traced with the initial
    arguments, and the result is compiled with the packed vector as the
    sole input. Both helpers are expected to be defined on this class —
    NOTE(review): confirm.

    Parameters
    ----------
    pyfn : callable or Symbolic
        The function to compile. A ``Symbolic`` is unwrapped to its
        ``pyfn``.
    init_args, init_kwargs : optional
        Example arguments used for tracing, normalized with
        ``utils.as_seq``.
    context, force_floatX, borrowable, ignore, infer_updates,
    escape_on_error
        Forwarded unchanged to ``Symbolic``.
    function, gradient, hessian_vector : bool
        Forwarded unchanged to ``Symbolic.compile``.
    """
    if isinstance(pyfn, Symbolic):
        pyfn = pyfn.pyfn
    self.pyfn = pyfn

    init_args = utils.as_seq(init_args, tuple)
    init_kwargs = utils.as_seq(init_kwargs, dict)
    self.init_args = utils.expandedcallargs(
        pyfn, *init_args, **init_kwargs)

    # (An unused helper ``wrapped_function`` was removed here. It was never
    # called and referenced an undefined name ``escaped_call``.)

    def wrapper(*args, **kwargs):
        # Route the arguments through a flat vector and back. The traced
        # graph then depends on that vector, which is returned alongside
        # the result so it can be used as the compiled input.
        vector = self.vector_from_args(args, kwargs)
        v_args = self.args_from_vector(vector)
        return vector, pyfn(*v_args)

    symbolic = Symbolic(pyfn=wrapper,
                        context=context,
                        force_floatX=force_floatX,
                        infer_updates=infer_updates,
                        borrowable=borrowable,
                        ignore=ignore,
                        escape_on_error=escape_on_error)
    # trace returns (symbolic_inputs, results). Here results is the
    # (vector, output) pair produced by ``wrapper``.
    _, (sym_vector, result) = symbolic.trace(*init_args, **init_kwargs)

    fn = symbolic.compile(function=function,
                          gradient=gradient,
                          hessian_vector=hessian_vector,
                          inputs=sym_vector,
                          outputs=result)
    self.fn = fn