示例#1
0
    def make_all(self,
                 input_storage=None,
                 output_storage=None,
                 storage_map=None):
        """
        Assemble the runnable function for ``self.fgraph`` and its containers.

        The nodes returned by ``self.schedule`` are mapped to storage with
        ``map_storage``, turned into thunks by ``self.create_jitable_thunk``
        and combined into one callable by ``streamline``.

        Parameters
        ----------
        input_storage
            Forwarded to ``map_storage``; the list it returns is zipped with
            ``fgraph.inputs`` to build the input containers.
        output_storage
            Forwarded to ``map_storage``; the list it returns is zipped with
            ``fgraph.outputs`` to build the output containers.
        storage_map
            Forwarded to ``map_storage``.

        Returns
        -------
        tuple
            ``(fn, input_containers, output_containers, thunks, nodes)``.
            Output containers are created with ``readonly=True``. ``fn``
            additionally carries ``jit_fn``, ``allow_gc`` and ``storage_map``
            attributes.
        """
        graph = self.fgraph
        nodes = self.schedule(graph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            graph, nodes, input_storage, output_storage, storage_map)

        # One single-element list per variable; the flag is True only for
        # variables that have no owner.
        compute_map = {var: [var.owner is None] for var in storage_map}

        thunks, nodes, jit_fn = self.create_jitable_thunk(
            compute_map, nodes, input_storage, output_storage, storage_map)

        computed, last_user = gc_helper(nodes)

        post_thunk_old_storage = None
        if self.allow_gc:
            # For every node, collect the storage cells of computed,
            # non-output inputs for which this node is the last user.
            post_thunk_old_storage = [
                [
                    storage_map[var] for var in node.inputs
                    if var in computed
                    and var not in graph.outputs
                    and node == last_user[var]
                ]
                for node in nodes
            ]

        if no_recycling is True:
            # ``True`` selects every storage cell except the input ones.
            no_recycling = difference(list(storage_map.values()),
                                      input_storage)
        else:
            no_recycling = [
                storage_map[var] for var in no_recycling
                if var not in graph.inputs
            ]

        fn = streamline(graph,
                        thunks,
                        nodes,
                        post_thunk_old_storage,
                        no_recycling=no_recycling)

        fn.jit_fn = jit_fn
        fn.allow_gc = self.allow_gc
        fn.storage_map = storage_map

        input_containers = [
            Container(var, cell)
            for var, cell in zip(graph.inputs, input_storage)
        ]
        output_containers = [
            Container(var, cell, readonly=True)
            for var, cell in zip(graph.outputs, output_storage)
        ]

        return fn, input_containers, output_containers, thunks, nodes
示例#2
0
文件: basic.py 项目: geofiber/aesara
    def make_all(self,
                 input_storage=None,
                 output_storage=None,
                 storage_map=None):
        """
        Build the function that runs all nodes, plus its I/O containers.

        Parameters
        ----------
        input_storage
            list of storages corresponding to fgraph.inputs
        output_storage
            list of storages corresponding to fgraph.outputs
        storage_map
            Forwarded to ``map_storage`` together with the two lists above;
            the map it returns is attached to the result as
            ``f.storage_map``.

        Returns
        -------
        tuple
            Function to run all nodes, list of input containers, list of
            output containers (created with ``readonly=True``), list of
            thunks (for all programs), list of nodes (for all programs).

        """
        fgraph = self.fgraph
        order = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, order, input_storage, output_storage, storage_map)

        # One single-element list per variable; the flag is True only for
        # variables that have no owner.
        compute_map = {k: [k.owner is None] for k in storage_map}

        thunks = []
        for node in order:
            # Make sure we don't use the C version of the code, but rather
            # only the Python version.
            # Note: ops that implement their own make_thunk don't usually
            # have this attribute defined!
            thunk = node.op.make_thunk(node, storage_map, compute_map,
                                       no_recycling, "py")
            thunk.inputs = [storage_map[v] for v in node.inputs]
            thunk.outputs = [storage_map[v] for v in node.outputs]
            thunks.append(thunk)

        computed, last_user = gc_helper(order)

        if self.allow_gc:
            # For every node, collect the storage cells of computed,
            # non-output inputs for which this node is the last user.
            post_thunk_old_storage = [
                [
                    storage_map[var] for var in node.inputs
                    if (var in computed) and (var not in fgraph.outputs)
                    and (node == last_user[var])
                ]
                for node in order
            ]
        else:
            post_thunk_old_storage = None

        if no_recycling is True:
            # True seems like some special code for *everything*?? -JB
            # FunctionMaker always passes a list I think   -JB
            no_recycling = list(storage_map.values())
            no_recycling = difference(no_recycling, input_storage)
        else:
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]

        # The function that actually runs your program is one of the f's in
        # streamline.
        f = streamline(fgraph,
                       thunks,
                       order,
                       post_thunk_old_storage,
                       no_recycling=no_recycling)

        # HACK: this is a way of passing an arg to Function.__call__
        f.allow_gc = self.allow_gc
        f.storage_map = storage_map

        return (
            f,
            [
                Container(var, storage)
                for var, storage in zip(fgraph.inputs, input_storage)
            ],
            [
                Container(var, storage, readonly=True)
                for var, storage in zip(fgraph.outputs, output_storage)
            ],
            thunks,
            order,
        )
示例#3
0
    def make_all(self, input_storage=None, output_storage=None, storage_map=None):
        """
        Assemble the runnable function for ``self.fgraph`` and its containers.

        Thunks are first requested from ``self.create_jax_thunks``. If that
        raises ``NotImplementedError`` and ``self.allow_non_jax`` is truthy, a
        warning is issued and one thunk per scheduled node is built with
        ``node.op.make_thunk(..., "py")`` instead; otherwise the error is
        re-raised.

        Parameters
        ----------
        input_storage
            Forwarded to ``map_storage``; the list it returns is zipped with
            ``fgraph.inputs`` to build the input containers.
        output_storage
            Forwarded to ``map_storage``; the list it returns is zipped with
            ``fgraph.outputs`` to build the output containers.
        storage_map
            Forwarded to ``map_storage``.

        Returns
        -------
        tuple
            ``(fn, input_containers, output_containers, thunks, nodes)``.
            Output containers are created with ``readonly=True``. ``fn``
            additionally carries ``allow_gc`` and ``storage_map`` attributes.
        """
        graph = self.fgraph
        nodes = self.schedule(graph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            graph, nodes, input_storage, output_storage, storage_map
        )

        # One single-element list per variable; the flag is True only for
        # variables that have no owner.
        compute_map = {var: [var.owner is None] for var in storage_map}

        try:
            # These thunks fill the output storage cells with the
            # JAX-computed values.
            thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
        except NotImplementedError as exc:
            if not self.allow_non_jax:
                raise

            warn(f"JaxLinker could not JAXify graph: {exc}")

            # Fallback: one "py" thunk per scheduled node.
            thunks = []
            for node in nodes:
                py_thunk = node.op.make_thunk(
                    node, storage_map, compute_map, no_recycling, "py"
                )
                py_thunk.inputs = [storage_map[v] for v in node.inputs]
                py_thunk.outputs = [storage_map[v] for v in node.outputs]
                thunks.append(py_thunk)

        computed, last_user = gc_helper(nodes)

        post_thunk_old_storage = None
        if self.allow_gc:
            # For every node, collect the storage cells of computed,
            # non-output inputs for which this node is the last user.
            post_thunk_old_storage = [
                [
                    storage_map[var]
                    for var in node.inputs
                    if var in computed
                    and var not in graph.outputs
                    and node == last_user[var]
                ]
                for node in nodes
            ]

        if no_recycling is True:
            # ``True`` selects every storage cell except the input ones.
            no_recycling = difference(list(storage_map.values()), input_storage)
        else:
            no_recycling = [
                storage_map[var] for var in no_recycling if var not in graph.inputs
            ]

        fn = streamline(
            graph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
        )

        fn.allow_gc = self.allow_gc
        fn.storage_map = storage_map

        input_containers = [
            Container(var, cell) for var, cell in zip(graph.inputs, input_storage)
        ]
        output_containers = [
            Container(var, cell, readonly=True)
            for var, cell in zip(graph.outputs, output_storage)
        ]

        return fn, input_containers, output_containers, thunks, nodes