예제 #1
0
    def make_all(self,
                 input_storage=None,
                 output_storage=None,
                 storage_map=None):
        """
        Build the function that runs all nodes, together with its containers.

        Parameters
        ----------
        input_storage
            list of storages corresponding to fgraph.inputs
        output_storage
            list of storages corresponding to fgraph.outputs
        storage_map
            optional existing storage mapping; it is forwarded to
            ``map_storage`` along with the two lists above, and the mapping
            returned by that call is what the thunks are wired to.

        Returns
        -------
        tuple
            ``(f, input_containers, output_containers, thunks, order)``:
            the function returned by ``streamline``, one ``Container`` per
            graph input, one read-only ``Container`` per graph output, the
            list of thunks (one per scheduled node) and the scheduled nodes.

        """
        fgraph = self.fgraph
        order = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, order, input_storage, output_storage, storage_map)

        # A variable without an owner starts out flagged as computed.
        compute_map = {var: [var.owner is None] for var in storage_map}

        thunks = []
        for node in order:
            # Make sure we don't use the C version of the code, but rather
            # only the python version.
            # Note: ops that implement their own make_thunk don't usually
            # have this attribute defined!
            thunk = node.op.make_thunk(node, storage_map, compute_map,
                                       no_recycling, "py")
            thunk.inputs = [storage_map[v] for v in node.inputs]
            thunk.outputs = [storage_map[v] for v in node.outputs]
            thunks.append(thunk)

        computed, last_user = gc_helper(order)

        # With gc enabled, record per node the storage of every input that
        # is in `computed`, is not a graph output, and whose last user is
        # that node. Without gc, None is handed to streamline instead.
        post_thunk_old_storage = None
        if self.allow_gc:
            post_thunk_old_storage = [
                [
                    storage_map[var] for var in node.inputs
                    if (var in computed) and (var not in fgraph.outputs)
                    and (node == last_user[var])
                ]
                for node in order
            ]

        if no_recycling is True:
            # True seems like some special code for *everything*?? -JB
            # FunctionMaker always passes a list I think   -JB
            every_storage = list(storage_map.values())
            no_recycling = utils.difference(every_storage, input_storage)
        else:
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]

        # The function that actually runs your program is one of the f's in
        # streamline.
        f = streamline(fgraph,
                       thunks,
                       order,
                       post_thunk_old_storage,
                       no_recycling=no_recycling)

        # HACK: this is a way of passing an arg to Function.__call__
        f.allow_gc = self.allow_gc
        f.storage_map = storage_map

        input_containers = [
            Container(variable, storage)
            for variable, storage in zip(fgraph.inputs, input_storage)
        ]
        output_containers = [
            Container(variable, storage, readonly=True)
            for variable, storage in zip(fgraph.outputs, output_storage)
        ]

        return (f, input_containers, output_containers, thunks, order)
예제 #2
0
    def make_all(self,
                 input_storage=None,
                 output_storage=None,
                 storage_map=None):
        """
        Build the function that runs the graph, preferring JAX thunks.

        The thunks come from ``self.create_jax_thunks``. If that call raises
        ``NotImplementedError`` and ``self.allow_non_jax`` is set, a warning
        is emitted and one "py" thunk per scheduled node is built instead;
        otherwise the error is re-raised.

        Parameters
        ----------
        input_storage
            list of storages corresponding to fgraph.inputs
        output_storage
            list of storages corresponding to fgraph.outputs
        storage_map
            optional existing storage mapping, forwarded to ``map_storage``.

        Returns
        -------
        tuple
            ``(fn, input_containers, output_containers, thunks, nodes)``:
            the function returned by ``streamline``, one ``Container`` per
            graph input, one read-only ``Container`` per graph output, the
            thunks and the nodes they correspond to.

        """
        fgraph = self.fgraph
        nodes = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, nodes, input_storage, output_storage, storage_map)

        # A variable without an owner starts out flagged as computed.
        compute_map = {var: [var.owner is None] for var in storage_map}

        try:
            # We need to create thunk functions that will populate the output
            # storage arrays with the JAX-computed values.
            thunks, nodes = self.create_jax_thunks(compute_map, storage_map)

        except NotImplementedError as e:
            if not self.allow_non_jax:
                raise

            warn(f"JaxLinker could not JAXify graph: {e}")

            # Fall back to one "py" thunk per scheduled node; `nodes` still
            # holds the schedule because the assignment above did not happen.
            thunks = []
            for node in nodes:
                fallback = node.op.make_thunk(node, storage_map, compute_map,
                                              no_recycling, "py")
                fallback.inputs = [storage_map[v] for v in node.inputs]
                fallback.outputs = [storage_map[v] for v in node.outputs]
                thunks.append(fallback)

        computed, last_user = gc_helper(nodes)

        # With gc enabled, record per node the storage of every input that
        # is in `computed`, is not a graph output, and whose last user is
        # that node. Without gc, None is handed to streamline instead.
        post_thunk_old_storage = None
        if self.allow_gc:
            post_thunk_old_storage = [
                [
                    storage_map[var] for var in node.inputs
                    if (var in computed) and (var not in fgraph.outputs)
                    and (node == last_user[var])
                ]
                for node in nodes
            ]

        if no_recycling is True:
            every_storage = list(storage_map.values())
            no_recycling = utils.difference(every_storage, input_storage)
        else:
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]

        fn = streamline(fgraph,
                        thunks,
                        nodes,
                        post_thunk_old_storage,
                        no_recycling=no_recycling)

        fn.allow_gc = self.allow_gc
        fn.storage_map = storage_map

        input_containers = [
            Container(variable, storage)
            for variable, storage in zip(fgraph.inputs, input_storage)
        ]
        output_containers = [
            Container(variable, storage, readonly=True)
            for variable, storage in zip(fgraph.outputs, output_storage)
        ]

        return (fn, input_containers, output_containers, thunks, nodes)
예제 #3
0
파일: vm.py 프로젝트: michaelosthege/aesara
    def make_all(
        self,
        profiler=None,
        input_storage=None,
        output_storage=None,
        storage_map=None,
    ):
        """
        Build the VM that runs the graph, together with its containers.

        Parameters
        ----------
        profiler
            Accepted for interface compatibility; it is not referenced in
            this method body.
        input_storage
            list of storages corresponding to fgraph.inputs
        output_storage
            list of storages corresponding to fgraph.outputs
        storage_map
            optional existing storage mapping, forwarded to ``map_storage``.

        Returns
        -------
        tuple
            ``(vm, input_containers, output_containers, thunks, order)``:
            the object returned by ``self.make_vm``, one ``Container`` per
            graph input, one read-only ``Container`` per graph output, the
            list of thunks (one per scheduled node) and the scheduled nodes.

        Raises
        ------
        Exception
            Any exception raised while building a node's thunk is re-raised
            with a message and the offending node prepended to its ``args``.

        """
        fgraph = self.fgraph
        order = self.schedule(fgraph)

        input_storage, output_storage, storage_map = map_storage(
            fgraph, order, input_storage, output_storage, storage_map)

        # A variable without an owner starts out flagged as computed.
        compute_map = {var: [var.owner is None] for var in storage_map}

        # Collect Reallocation Info
        compute_map_re = defaultdict(lambda: [0])
        for var in fgraph.inputs:
            compute_map_re[var][0] = 1

        # Prefer dependencies already stored on the graph's profile, if any.
        dependencies = getattr(fgraph.profile, "dependencies", None)
        if not dependencies:
            dependencies = self.compute_gc_dependencies(storage_map)

        reallocated_info = calculate_reallocate_info(order, fgraph,
                                                     storage_map,
                                                     compute_map_re,
                                                     dependencies)

        t0 = time.time()
        linker_make_thunk_time = {}
        impl = "py" if self.c_thunks is False else None

        thunks = []
        for node in order:
            try:
                thunk_start = time.time()
                # no-recycling is done at each VM.__call__ So there is
                # no need to cause duplicate c code by passing
                # no_recycling here.
                thunk = node.op.make_thunk(node,
                                           storage_map,
                                           compute_map, [],
                                           impl=impl)
                thunks.append(thunk)
                linker_make_thunk_time[node] = time.time() - thunk_start
                if not hasattr(thunk, "lazy"):
                    # We don't want all ops maker to think about lazy Ops.
                    # So if they didn't specify that its lazy or not, it
                    # isn't. If this member isn't present, it will crash
                    # later.
                    thunk.lazy = False
            except Exception as e:
                e.args = (
                    "The following error happened while"
                    " compiling the node",
                    node,
                    "\n",
                ) + e.args
                raise
        t1 = time.time()

        if self.profile:
            self.profile.linker_node_make_thunks += t1 - t0
            self.profile.linker_make_thunk_time = linker_make_thunk_time

        for node, thunk in zip(order, thunks):
            thunk.inputs = [storage_map[v] for v in node.inputs]
            thunk.outputs = [storage_map[v] for v in node.outputs]

        # Resolve laziness: linker setting, then config, then the thunks.
        lazy = self.lazy
        if lazy is None:
            lazy = config.vm__lazy
        if lazy is None:
            lazy = any(th.lazy for th in thunks)

        # Storage reuse is skipped when running lazily, when memory
        # profiling is on, with the C loop, or when a callback is installed.
        reuse_blocked = (lazy or ((config.profile or config.print_global_stats)
                                  and config.profile_memory) or self.use_cloop
                         or self.callback or self.callback_input)
        if not reuse_blocked:
            # Point the second variable of each entry at the first one's
            # storage cell.
            for entry in reallocated_info.values():
                storage_map[entry[1]] = storage_map[entry[0]]

        computed, last_user = gc_helper(order)

        # With gc enabled, record per node the storage of every input that
        # is in `computed`, is not a graph output, has that node as its last
        # user, and is not a key of `reallocated_info`.
        post_thunk_clear = None
        if self.allow_gc:
            post_thunk_clear = [
                [
                    storage_map[var] for var in node.inputs
                    if (var in computed and var not in fgraph.outputs
                        and node == last_user[var]
                        and var not in reallocated_info)
                ]
                for node in order
            ]

        vm = self.make_vm(
            order,
            thunks,
            input_storage,
            output_storage,
            storage_map,
            post_thunk_clear,
            computed,
            compute_map,
            self.updated_vars,
        )

        vm.storage_map = storage_map
        vm.compute_map = compute_map

        input_containers = [
            Container(variable, storage)
            for variable, storage in zip(fgraph.inputs, input_storage)
        ]
        output_containers = [
            Container(variable, storage, readonly=True)
            for variable, storage in zip(fgraph.outputs, output_storage)
        ]

        return (vm, input_containers, output_containers, thunks, order)