예제 #1
0
    def trace(self, *args, **kwargs):
        """
        Run the symbolic function on args and kwargs and return the symbolic
        inputs together with the symbolic result.

        Returns a 2-tuple ``(inputs, results)``. ``inputs`` is a tuple made by
        calling ``self.get_symbolic`` on every expanded call argument, except
        for a leading bound 'self'/'cls' argument. ``results`` is the value
        returned by ``self.symfn``, passed through unchanged.
        """
        cleaned_args, cleaned_kwargs = utils.clean_int_args(*args, **kwargs)

        # the symbolic function is evaluated before any symbolic lookups
        results = self.symfn(*cleaned_args, **cleaned_kwargs)

        call_args = utils.expandedcallargs(
            self.symfn, *cleaned_args, **cleaned_kwargs)

        # a bound method, or a class sitting in first position, means the
        # first argument is 'self'/'cls' rather than a symbolic input
        if inspect.ismethod(self.pyfn):
            skip_first = True
        else:
            skip_first = len(call_args) > 0 and type(call_args[0]) is type
        if skip_first:
            call_args = call_args[1:]

        # nothing is stored on self here; the caller receives both pieces
        inputs = tuple(self.get_symbolic(arg) for arg in call_args)
        return inputs, results
예제 #2
0
    def trace(self, *args, **kwargs):
        """
        Evaluate the symbolic function for the given call and hand back the
        symbolic inputs alongside its result.

        The return value is ``(inputs, results)``: ``inputs`` is the tuple of
        ``self.get_symbolic(a)`` for each expanded call argument (a leading
        bound 'self'/'cls' argument is left out), and ``results`` is exactly
        what ``self.symfn`` returned.
        """
        # normalise the incoming arguments before they reach the symfn
        sym_args, sym_kwargs = utils.clean_int_args(*args, **kwargs)
        results = self.symfn(*sym_args, **sym_kwargs)

        # expand the call into the full argument sequence of the symfn
        expanded = utils.expandedcallargs(self.symfn, *sym_args, **sym_kwargs)

        # 'self' and 'cls' bound arguments are not symbolic inputs: drop the
        # first entry for methods and when the first entry is a class
        has_bound_first = inspect.ismethod(self.pyfn) or (
            len(expanded) > 0 and type(expanded[0]) is type)
        symbolic_args = expanded[1:] if has_bound_first else expanded

        return tuple(self.get_symbolic(a) for a in symbolic_args), results
예제 #3
0
    def recompile(self, f, nested=False):
        """
        Accepts a function f that operates on numerical objects and
        returns a function that operates on Theano objects.

        Parameters
        ----------
        f : function or method
            The callable to transform. Its AST is rewritten by a
            `TheanoTransformer` bound to this context and then compiled
            against a shadowed copy of its globals.

        nested : bool
            `recompile` resets the context and sets the 'top_node' of the
            function, which helps in tracing arguments. By passing nested=True,
            this reset can be bypassed. This is used, for example, when
            transforming nested functions. In this case, we want to use the
            same context but keep it when calling recompile.

        Returns
        -------
        The recompiled callable. If `f` is a method, the result is re-wrapped
        as a method with the same `im_self` / `im_class`.

        Raises
        ------
        SyntaxError
            If the transformed AST cannot be compiled and the
            "'return' with argument inside generator" workaround does not
            apply.
        """
        transformer = TheanoTransformer(context=self)

        f_ast = get_ast(f)

        if not nested:
            # top-level call: remember the definition and drop stale tags
            self._top_def = f_ast
            self.tags.clear()

        transformed_ast = fix_missing_locations(transformer.visit(f_ast))

        # work on a copy so the original function's globals are not modified
        f_globals = f.func_globals.copy()
        f_globals.update(_ctx__=transformer,
                         _functions__=autodiff.functions,
                         _T__=theano.tensor,
                         _utils__=autodiff.utils)

        # closure variables become (shadowed) globals of the new function
        if f.func_closure:
            f_globals.update((v, transformer.shadow(c.cell_contents))
                             for v, c in
                             zip(f.func_code.co_freevars, f.func_closure))

        # shadow every global that the function body refers to by name
        for name in f.func_code.co_names:
            if name in f_globals:
                f_globals[name] = transformer.shadow(f_globals[name])

        try:
            new_f = meta.decompiler.compile_func(ast_node=transformed_ast,
                                                 filename='<Context-AST>',
                                                 globals=f_globals)
        except SyntaxError as err:
            # The only recoverable failure is a trailing `return <value>` in
            # what is compiled as a generator. Anything else, including a
            # matching message without a trailing Return node, is re-raised
            # so the caller sees the real compile error rather than a later
            # UnboundLocalError on `new_f`.
            recoverable = (
                "'return' with argument inside generator" in err.message and
                isinstance(transformed_ast.body[-1], Return))
            if not recoverable:
                raise
            transformed_ast.body.pop(-1)
            new_f = meta.decompiler.compile_func(ast_node=transformed_ast,
                                                 filename='<Context-AST>',
                                                 globals=f_globals)

        # add defaults, if necessary (meta erases them and won't recompile!)
        if f.func_defaults:
            new_f.func_defaults = utils.clean_int_args(*f.func_defaults)[0]

        # recreate method, if necessary
        if isinstance(f, types.MethodType):
            new_f = types.MethodType(new_f, f.im_self, f.im_class)

        return new_f