예제 #1
0
파일: op.py 프로젝트: mgorny/aesara
    def make_py_thunk(
        self,
        node: Apply,
        storage_map: StorageMapType,
        compute_map: ComputeMapType,
        no_recycling: List[Variable],
        debug: bool = False,
    ) -> ThunkType:
        """Make a Python thunk.

        Like :meth:`Op.make_thunk` but only makes Python thunks.

        Parameters
        ----------
        node
            The `Apply` node to build a thunk for. The thunk calls
            ``node.op.perform`` (or ``node.op.debug_perform``, see `debug`).
        storage_map
            Maps each of ``node``'s input and output variables to its storage
            container; the value lives in slot ``[0]`` of that container.
        compute_map
            Maps each output variable to a container whose slot ``[0]`` is set
            to ``True`` after the thunk has run.
        no_recycling
            Not used by this implementation.
        debug
            When true and ``node.op`` defines ``debug_perform``, the thunk
            calls that method instead of ``perform``.

        Returns
        -------
        ThunkType
            A callable taking no required arguments, with ``inputs``,
            ``outputs``, ``perform`` and ``lazy`` attributes attached.

        """
        node_input_storage = [storage_map[r] for r in node.inputs]
        node_output_storage = [storage_map[r] for r in node.outputs]

        # Check for the debug hook on the same object it is fetched from.
        if debug and hasattr(node.op, "debug_perform"):
            p = node.op.debug_perform  # type: ignore
        else:
            p = node.op.perform

        params = node.run_params()

        if params is NoParams:
            # default arguments are stored in the closure of `rval`
            @is_thunk_type
            def rval(
                p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
            ):
                r = p(n, [x[0] for x in i], o)
                # Use a distinct name so the `o` (output storage) parameter is
                # not rebound by the loop.
                for out_var in node.outputs:
                    compute_map[out_var][0] = True
                return r

        else:
            params_val = node.params_type.filter(params)

            @is_thunk_type
            def rval(
                p=p,
                i=node_input_storage,
                o=node_output_storage,
                n=node,
                params=params_val,
            ):
                r = p(n, [x[0] for x in i], o, params)
                # Use a distinct name so the `o` (output storage) parameter is
                # not rebound by the loop.
                for out_var in node.outputs:
                    compute_map[out_var][0] = True
                return r

        rval.inputs = node_input_storage
        rval.outputs = node_output_storage
        setattr(rval, "perform", p)
        rval.lazy = False
        return rval
예제 #2
0
    def make_py_thunk(
        self,
        node: Apply,
        storage_map: StorageMapType,
        compute_map: ComputeMapType,
        no_recycling: list,
        debug: bool = False,
    ) -> ThunkType:
        """Make a Python thunk.

        Like `Op.make_thunk` but only makes python thunks.

        Parameters
        ----------
        node
            The `Apply` node to build a thunk for. The thunk calls
            ``node.op.perform`` (or ``node.op.debug_perform``, see `debug`).
        storage_map
            Maps each of ``node``'s input and output variables to its storage
            container; the value lives in slot ``[0]`` of that container.
        compute_map
            Maps each output variable to a container whose slot ``[0]`` is set
            to ``True`` after the thunk has run.
        no_recycling
            Collection of variables; not used by this implementation.
        debug
            When true and ``node.op`` defines ``debug_perform``, the thunk
            calls that method instead of ``perform``. Ops without a debug hook
            fall back to ``perform`` rather than raising ``AttributeError``.

        Returns
        -------
        ThunkType
            A callable taking no required arguments, with ``inputs``,
            ``outputs``, ``perform`` and ``lazy`` attributes attached.

        """
        node_input_storage = [storage_map[r] for r in node.inputs]
        node_output_storage = [storage_map[r] for r in node.outputs]

        # Not every op defines a debug hook; fall back to `perform` instead of
        # failing with AttributeError.
        if debug and hasattr(node.op, "debug_perform"):
            p = node.op.debug_perform
        else:
            p = node.op.perform

        params = node.run_params()

        if params is NoParams:
            # default arguments are stored in the closure of `rval`
            @rval_decorator
            def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
                r = p(n, [x[0] for x in i], o)
                # Use a distinct name so the `o` (output storage) parameter is
                # not rebound by the loop.
                for out_var in node.outputs:
                    compute_map[out_var][0] = True
                return r

        else:
            params_val = node.params_type.filter(params)

            @rval_decorator
            def rval(
                p=p,
                i=node_input_storage,
                o=node_output_storage,
                n=node,
                params=params_val,
            ):
                r = p(n, [x[0] for x in i], o, params)
                # Use a distinct name so the `o` (output storage) parameter is
                # not rebound by the loop.
                for out_var in node.outputs:
                    compute_map[out_var][0] = True
                return r

        rval.inputs = node_input_storage
        rval.outputs = node_output_storage
        rval.perform = p
        rval.lazy = False
        return rval