def trace(self, *args, **kwargs):
    """
    Call the symbolic function on args and kwargs and return the symbolic
    inputs together with the symbolic result.

    The arguments are first normalized with `utils.clean_int_args`
    (presumably converting integer arguments -- see that helper) and then
    passed to `self.symfn`.

    Returns
    -------
    inputs : tuple
        `self.get_symbolic(a)` for every argument of the call, as expanded
        by `utils.expandedcallargs(self.symfn, ...)`. A leading bound
        'self' / 'cls' argument is excluded.
    results
        Whatever `self.symfn` returned, unmodified.
    """
    # normalize args and kwargs before calling into the symbolic function
    c_args, c_kwargs = utils.clean_int_args(*args, **kwargs)

    # call the symfn
    results = self.symfn(*c_args, **c_kwargs)

    # Collect every argument the call received, then drop a bound 'self'
    # (when pyfn is a method) or 'cls' (first argument is itself a class).
    # NOTE(review): `type(x) is type` only matches classes whose metaclass is
    # exactly `type`; it would also drop a class passed as a genuine first
    # argument to a plain function -- confirm this heuristic is intended.
    all_args = utils.expandedcallargs(self.symfn, *c_args, **c_kwargs)
    if (inspect.ismethod(self.pyfn) or
            (len(all_args) > 0 and type(all_args[0]) is type)):
        all_args = all_args[1:]

    inputs = tuple(self.get_symbolic(a) for a in all_args)
    return inputs, results
def recompile(self, f, nested=False):
    """
    Accepts a function f that operates on numerical objects and
    returns a function that operates on Theano objects.

    Parameters
    ----------
    f : function or method
        The function to transform. Its AST is rewritten by a
        `TheanoTransformer` bound to this context and then recompiled.
    nested : bool
        `recompile` resets the context and sets the 'top_node' of the
        function, which helps in tracing arguments. By passing nested=True,
        this reset can be bypassed. This is used, for example, when
        transforming nested functions. In this case, we want to use the same
        context but keep it when calling recompile.

    Returns
    -------
    new_f : function or method
        The recompiled function. The default argument values of `f` are
        re-attached after passing through `utils.clean_int_args`. If `f` is
        a `types.MethodType`, the result is re-wrapped with the same
        `im_self` and `im_class`.

    Raises
    ------
    SyntaxError
        If the transformed AST cannot be compiled. The one tolerated case is
        "'return' with argument inside generator" when the last statement of
        the transformed body is a `return`. That statement is dropped and
        compilation is retried once.
    """
    transformer = TheanoTransformer(context=self)
    f_ast = get_ast(f)

    # Top-level call: reset per-function context state before transforming.
    if not nested:
        self._top_def = f_ast
        self.tags.clear()

    transformed_ast = fix_missing_locations(transformer.visit(f_ast))

    # Globals for the recompiled function: the original globals plus the
    # helper names injected for the transformed code.
    f_globals = f.func_globals.copy()
    f_globals.update(dict(_ctx__=transformer,
                          _functions__=autodiff.functions,
                          _T__=theano.tensor,
                          _utils__=autodiff.utils))

    # Closed-over values are exposed (shadowed) as globals under their
    # free-variable names.
    if f.func_closure:
        f_globals.update((v, transformer.shadow(c.cell_contents))
                         for v, c in zip(f.func_code.co_freevars,
                                         f.func_closure))

    # Shadow every global name the function's code references. Membership
    # is tested on the dict itself (O(1)) rather than by scanning its keys.
    for name in f.func_code.co_names:
        if name in f_globals:
            f_globals[name] = transformer.shadow(f_globals[name])

    def _compile():
        return meta.decompiler.compile_func(ast_node=transformed_ast,
                                            filename='<Context-AST>',
                                            globals=f_globals)

    try:
        new_f = _compile()
    except SyntaxError as err:
        # `err.message` is deprecated (PEP 352) and is empty for exceptions
        # built with several args; str(err) always includes the message.
        body = transformed_ast.body
        retryable = ("'return' with argument inside generator" in str(err)
                     and body
                     and isinstance(body[-1], Return))
        if not retryable:
            # Re-raise the original error rather than leaving new_f unbound.
            raise
        body.pop(-1)
        new_f = _compile()

    # add defaults, if necessary (meta erases them and won't recompile!)
    # func_defaults must be set to a tuple, hence the tuple() wrapper.
    if f.func_defaults:
        new_f.func_defaults = tuple(
            utils.clean_int_args(*f.func_defaults)[0])

    # recreate method, if necessary
    if isinstance(f, types.MethodType):
        new_f = types.MethodType(new_f, f.im_self, f.im_class)

    return new_f