def make_all(self, input_storage=None, output_storage=None, storage_map=None):
    """Build a pure-Python callable that executes every node of ``self.fgraph``.

    Parameters
    ----------
    input_storage
        Optional list of storage cells, one per entry of ``fgraph.inputs``.
    output_storage
        Optional list of storage cells, one per entry of ``fgraph.outputs``.
    storage_map
        Optional variable -> storage mapping; forwarded to ``map_storage``.

    Returns
    -------
    tuple
        ``(fn, input_containers, output_containers, thunks, order)``:
        ``fn`` is the callable produced by ``streamline``; the containers wrap
        the input and output storage (output containers are read-only);
        ``thunks`` holds one thunk per scheduled node; ``order`` is the node
        schedule that was used.
    """
    fgraph = self.fgraph
    schedule = self.schedule(fgraph)
    # Keep the linker's raw setting: make_thunk receives it un-normalised,
    # the normalised list is only built further down for streamline.
    requested_no_recycling = self.no_recycling

    input_storage, output_storage, storage_map = map_storage(
        fgraph, schedule, input_storage, output_storage, storage_map
    )

    # Variables with no owner node start out marked as already computed.
    compute_map = {var: [var.owner is None] for var in storage_map}

    thunks = []
    for node in schedule:
        # "py" requests the Op's Python implementation rather than C code.
        thunk = node.op.make_thunk(
            node, storage_map, compute_map, requested_no_recycling, "py"
        )
        thunk.inputs = [storage_map[var] for var in node.inputs]
        thunk.outputs = [storage_map[var] for var in node.outputs]
        thunks.append(thunk)

    computed, last_user = gc_helper(schedule)
    if not self.allow_gc:
        post_thunk_old_storage = None
    else:
        # Per node: storage of computed, non-output inputs whose last user is
        # this node (presumably released by streamline after the thunk runs).
        post_thunk_old_storage = [
            [
                storage_map[var]
                for var in node.inputs
                if var in computed
                and var not in fgraph.outputs
                and node == last_user[var]
            ]
            for node in schedule
        ]

    if requested_no_recycling is True:
        # ``True`` selects every storage cell except the input storage.
        no_recycling = utils.difference(
            list(storage_map.values()), input_storage
        )
    else:
        no_recycling = [
            storage_map[var]
            for var in requested_no_recycling
            if var not in fgraph.inputs
        ]

    fn = streamline(
        fgraph, thunks, schedule, post_thunk_old_storage, no_recycling=no_recycling
    )
    # HACK: this is a way of passing an arg to Function.__call__
    fn.allow_gc = self.allow_gc
    fn.storage_map = storage_map

    input_containers = [
        Container(var, cell) for var, cell in zip(fgraph.inputs, input_storage)
    ]
    output_containers = [
        Container(var, cell, readonly=True)
        for var, cell in zip(fgraph.outputs, output_storage)
    ]
    return fn, input_containers, output_containers, thunks, schedule
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
    """Build a callable for ``self.fgraph``, preferring JAX-generated thunks.

    ``self.create_jax_thunks`` is tried first. If it raises
    ``NotImplementedError``, the error is re-raised unless
    ``self.allow_non_jax`` is true. In that case a warning is emitted and
    one Python ("py") thunk is built per scheduled node instead.

    Parameters
    ----------
    input_storage
        Optional list of storage cells, one per entry of ``fgraph.inputs``.
    output_storage
        Optional list of storage cells, one per entry of ``fgraph.outputs``.
    storage_map
        Optional variable -> storage mapping; forwarded to ``map_storage``.

    Returns
    -------
    tuple
        ``(fn, input_containers, output_containers, thunks, nodes)``, where
        ``nodes`` is the node list that accompanies the thunks actually used
        (as returned by ``create_jax_thunks``, or the schedule on fallback).
    """
    graph = self.fgraph
    nodes = self.schedule(graph)
    raw_no_recycling = self.no_recycling

    input_storage, output_storage, storage_map = map_storage(
        graph, nodes, input_storage, output_storage, storage_map
    )

    # Variables with no owner node start out marked as already computed.
    compute_map = {var: [var.owner is None] for var in storage_map}

    def build_py_thunk(node):
        # Fallback: the Op's Python implementation, wired to storage_map cells.
        thunk = node.op.make_thunk(
            node, storage_map, compute_map, raw_no_recycling, "py"
        )
        thunk.inputs = [storage_map[var] for var in node.inputs]
        thunk.outputs = [storage_map[var] for var in node.outputs]
        return thunk

    try:
        # The JAX thunks are expected to fill the output storage cells with
        # the JAX-computed values.
        thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
    except NotImplementedError as err:
        if not self.allow_non_jax:
            raise
        warn(f"JaxLinker could not JAXify graph: {err}")
        thunks = [build_py_thunk(node) for node in nodes]

    computed, last_user = gc_helper(nodes)
    if self.allow_gc:
        # Per node: storage of computed, non-output inputs whose last user is
        # this node (presumably released by streamline after the thunk runs).
        post_thunk_old_storage = [
            [
                storage_map[var]
                for var in node.inputs
                if var in computed
                and var not in graph.outputs
                and node == last_user[var]
            ]
            for node in nodes
        ]
    else:
        post_thunk_old_storage = None

    if raw_no_recycling is True:
        # ``True`` selects every storage cell except the input storage.
        no_recycling = utils.difference(list(storage_map.values()), input_storage)
    else:
        no_recycling = [
            storage_map[var] for var in raw_no_recycling if var not in graph.inputs
        ]

    fn = streamline(
        graph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
    )
    fn.allow_gc = self.allow_gc
    fn.storage_map = storage_map

    input_containers = [
        Container(var, cell) for var, cell in zip(graph.inputs, input_storage)
    ]
    output_containers = [
        Container(var, cell, readonly=True)
        for var, cell in zip(graph.outputs, output_storage)
    ]
    return fn, input_containers, output_containers, thunks, nodes
def make_all(
    self,
    profiler=None,
    input_storage=None,
    output_storage=None,
    storage_map=None,
):
    """Compile every node of ``self.fgraph`` into a thunk and wrap them in a VM.

    Parameters
    ----------
    profiler
        Accepted for interface compatibility only; this method never reads
        it. Timing information is written to ``self.profile`` when that is set.
    input_storage, output_storage, storage_map
        Optional pre-existing storage; forwarded to ``map_storage``.

    Returns
    -------
    tuple
        ``(vm, input_containers, output_containers, thunks, order)`` where
        ``vm`` comes from ``self.make_vm`` and output containers are read-only.

    Raises
    ------
    Exception
        Any exception raised by an ``Op.make_thunk`` call is re-raised with a
        message and the failing node prepended to its ``args``.
    """
    fgraph = self.fgraph
    order = self.schedule(fgraph)

    input_storage, output_storage, storage_map = map_storage(
        fgraph, order, input_storage, output_storage, storage_map
    )

    # Variables with no owner node start out marked as already computed.
    compute_map = {var: [var.owner is None] for var in storage_map}

    # Collect reallocation info: pairs of variables whose storage may be
    # shared (the sharing itself is only applied further down, when safe).
    compute_map_re = defaultdict(lambda: [0])
    for var in fgraph.inputs:
        compute_map_re[var][0] = 1

    if getattr(fgraph.profile, "dependencies", None):
        dependencies = fgraph.profile.dependencies
    else:
        dependencies = self.compute_gc_dependencies(storage_map)

    reallocated_info = calculate_reallocate_info(
        order, fgraph, storage_map, compute_map_re, dependencies
    )

    # ``"py"`` forces Python implementations; ``None`` leaves the choice to
    # each Op's make_thunk.
    impl = "py" if self.c_thunks is False else None

    thunks = []
    linker_make_thunk_time = {}
    # perf_counter is monotonic, so measured durations cannot go negative
    # or jump when the wall clock is adjusted (unlike time.time()).
    t0 = time.perf_counter()
    for node in order:
        try:
            thunk_start = time.perf_counter()
            # no-recycling is done at each VM.__call__, so there is no need
            # to cause duplicate C code by passing no_recycling here.
            thunk = node.op.make_thunk(
                node, storage_map, compute_map, [], impl=impl
            )
            thunks.append(thunk)
            linker_make_thunk_time[node] = time.perf_counter() - thunk_start
            if not hasattr(thunk, "lazy"):
                # Ops that don't declare laziness are treated as non-lazy;
                # ``thunk.lazy`` is read unconditionally later on.
                thunk.lazy = False
        except Exception as e:
            e.args = (
                "The following error happened while compiling the node",
                node,
                "\n",
            ) + e.args
            raise
    t1 = time.perf_counter()

    if self.profile:
        self.profile.linker_node_make_thunks += t1 - t0
        self.profile.linker_make_thunk_time = linker_make_thunk_time

    for node, thunk in zip(order, thunks):
        thunk.inputs = [storage_map[var] for var in node.inputs]
        thunk.outputs = [storage_map[var] for var in node.outputs]

    lazy = self.lazy
    if lazy is None:
        lazy = config.vm__lazy
    if lazy is None:
        # Auto-detect: lazy iff at least one thunk reports itself as lazy.
        lazy = any(thunk.lazy for thunk in thunks)

    if not (
        lazy
        or (
            (config.profile or config.print_global_stats)
            and config.profile_memory
        )
        or self.use_cloop
        or self.callback
        or self.callback_input
    ):
        # Apply the storage sharing computed above: pair[1] reuses pair[0]'s cell.
        for pair in reallocated_info.values():
            storage_map[pair[1]] = storage_map[pair[0]]

    computed, last_user = gc_helper(order)
    if self.allow_gc:
        # NOTE(review): variables in ``reallocated_info`` are excluded here
        # even when the storage sharing above was skipped (e.g. lazy mode),
        # so their storage is not cleared early in that case — confirm intent.
        post_thunk_clear = [
            [
                storage_map[var]
                for var in node.inputs
                if var in computed
                and var not in fgraph.outputs
                and node == last_user[var]
                and var not in reallocated_info
            ]
            for node in order
        ]
    else:
        post_thunk_clear = None

    vm = self.make_vm(
        order,
        thunks,
        input_storage,
        output_storage,
        storage_map,
        post_thunk_clear,
        computed,
        compute_map,
        self.updated_vars,
    )

    vm.storage_map = storage_map
    vm.compute_map = compute_map

    input_containers = [
        Container(var, cell) for var, cell in zip(fgraph.inputs, input_storage)
    ]
    output_containers = [
        Container(var, cell, readonly=True)
        for var, cell in zip(fgraph.outputs, output_storage)
    ]
    return vm, input_containers, output_containers, thunks, order