def make_all(self, input_storage=None, output_storage=None, storage_map=None):
    """Assemble the streamlined callable for ``self.fgraph`` and its containers.

    The graph is scheduled, storage is laid out with ``map_storage``, and the
    thunks (plus a JIT-compiled function) are obtained from
    ``self.create_jitable_thunk``.  The result of ``streamline`` is then
    decorated with ``jit_fn``, ``allow_gc`` and ``storage_map`` attributes.

    Parameters
    ----------
    input_storage
        Optional storage forwarded to ``map_storage``; the returned value is
        zipped with ``fgraph.inputs`` to build the input containers.
    output_storage
        Optional storage forwarded to ``map_storage``; the returned value is
        zipped with ``fgraph.outputs`` to build the output containers.
    storage_map
        Optional mapping forwarded to ``map_storage``; the returned mapping is
        indexed by graph variable.

    Returns
    -------
    tuple
        ``(fn, input_containers, output_containers, thunks, nodes)`` where
        ``fn`` is the object returned by ``streamline`` and the output
        containers are created with ``readonly=True``.
    """
    fgraph = self.fgraph
    scheduled = self.schedule(fgraph)
    no_recycling = self.no_recycling

    input_storage, output_storage, storage_map = map_storage(
        fgraph, scheduled, input_storage, output_storage, storage_map
    )

    # One-element list per variable; the flag starts out True exactly for
    # variables that have no owner.
    compute_map = {var: [var.owner is None] for var in storage_map}

    thunks, nodes, jit_fn = self.create_jitable_thunk(
        compute_map, scheduled, input_storage, output_storage, storage_map
    )

    computed, last_user = gc_helper(nodes)

    post_thunk_old_storage = None
    if self.allow_gc:
        # For each node, the storage cells of computed, non-output inputs
        # for which this node is the recorded last user.
        post_thunk_old_storage = [
            [
                storage_map[var]
                for var in node.inputs
                if var in computed
                and var not in fgraph.outputs
                and node == last_user[var]
            ]
            for node in nodes
        ]

    if no_recycling is True:
        # ``True`` selects every storage cell except the input storage.
        recycle_exempt = difference(list(storage_map.values()), input_storage)
    else:
        recycle_exempt = [
            storage_map[var] for var in no_recycling if var not in fgraph.inputs
        ]

    fn = streamline(
        fgraph,
        thunks,
        nodes,
        post_thunk_old_storage,
        no_recycling=recycle_exempt,
    )
    fn.jit_fn = jit_fn
    fn.allow_gc = self.allow_gc
    fn.storage_map = storage_map

    input_containers = [
        Container(var, cell) for var, cell in zip(fgraph.inputs, input_storage)
    ]
    output_containers = [
        Container(var, cell, readonly=True)
        for var, cell in zip(fgraph.outputs, output_storage)
    ]
    return fn, input_containers, output_containers, thunks, nodes
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
    """Build the function that runs every node, plus its containers.

    Parameters
    ----------
    input_storage
        List of storages corresponding to ``fgraph.inputs``.
    output_storage
        List of storages corresponding to ``fgraph.outputs``.
    storage_map
        Optional mapping forwarded to ``map_storage``; the returned mapping is
        indexed by graph variable.

    Returns
    -------
    tuple
        The function that runs all nodes, the list of input containers, the
        list of output containers (created with ``readonly=True``), the list
        of thunks, and the scheduled list of nodes.
    """
    fgraph = self.fgraph
    order = self.schedule(fgraph)
    no_recycling = self.no_recycling

    input_storage, output_storage, storage_map = map_storage(
        fgraph, order, input_storage, output_storage, storage_map
    )

    # One-element list per variable; the flag starts out True exactly for
    # variables that have no owner.
    compute_map = {var: [var.owner is None] for var in storage_map}

    thunks = []
    for node in order:
        # Request the "py" implementation so that the Python version of the
        # op is used rather than a C version.
        # NOTE: ops that implement their own ``make_thunk`` do not usually
        # define this attribute.
        thunk = node.op.make_thunk(
            node, storage_map, compute_map, no_recycling, "py"
        )
        thunk.inputs = [storage_map[var] for var in node.inputs]
        thunk.outputs = [storage_map[var] for var in node.outputs]
        thunks.append(thunk)

    computed, last_user = gc_helper(order)

    post_thunk_old_storage = None
    if self.allow_gc:
        # For each node, the storage cells of computed, non-output inputs
        # for which this node is the recorded last user.
        post_thunk_old_storage = [
            [
                storage_map[var]
                for var in node.inputs
                if var in computed
                and var not in fgraph.outputs
                and node == last_user[var]
            ]
            for node in order
        ]

    if no_recycling is True:
        # ``True`` acts as a special value meaning *everything* (minus the
        # input storage); per the original author's note, FunctionMaker is
        # believed to always pass a list instead.
        recycle_exempt = difference(list(storage_map.values()), input_storage)
    else:
        recycle_exempt = [
            storage_map[var] for var in no_recycling if var not in fgraph.inputs
        ]

    # The function that actually runs the program is produced by streamline.
    f = streamline(
        fgraph,
        thunks,
        order,
        post_thunk_old_storage,
        no_recycling=recycle_exempt,
    )
    # HACK: attribute used as a way of passing an arg to Function.__call__.
    f.allow_gc = self.allow_gc
    f.storage_map = storage_map

    input_containers = [
        Container(var, cell) for var, cell in zip(fgraph.inputs, input_storage)
    ]
    output_containers = [
        Container(var, cell, readonly=True)
        for var, cell in zip(fgraph.outputs, output_storage)
    ]
    return f, input_containers, output_containers, thunks, order
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
    """Assemble the streamlined callable for ``self.fgraph``, preferring JAX.

    Thunks are first requested from ``self.create_jax_thunks``.  If that
    raises ``NotImplementedError`` and ``self.allow_non_jax`` is truthy, a
    warning is emitted and one "py" thunk per scheduled node is built through
    ``node.op.make_thunk`` instead; otherwise the exception propagates.

    Parameters
    ----------
    input_storage
        Optional storage forwarded to ``map_storage``; the returned value is
        zipped with ``fgraph.inputs`` to build the input containers.
    output_storage
        Optional storage forwarded to ``map_storage``; the returned value is
        zipped with ``fgraph.outputs`` to build the output containers.
    storage_map
        Optional mapping forwarded to ``map_storage``; the returned mapping is
        indexed by graph variable.

    Returns
    -------
    tuple
        ``(fn, input_containers, output_containers, thunks, nodes)`` where
        ``fn`` is the object returned by ``streamline`` and the output
        containers are created with ``readonly=True``.

    Raises
    ------
    NotImplementedError
        Re-raised from ``self.create_jax_thunks`` when ``self.allow_non_jax``
        is falsy.
    """
    fgraph = self.fgraph
    nodes = self.schedule(fgraph)
    no_recycling = self.no_recycling

    input_storage, output_storage, storage_map = map_storage(
        fgraph, nodes, input_storage, output_storage, storage_map
    )

    # One-element list per variable; the flag starts out True exactly for
    # variables that have no owner.
    compute_map = {var: [var.owner is None] for var in storage_map}

    try:
        # These thunks are expected to populate the output storage arrays
        # with the JAX-computed values.
        thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
    except NotImplementedError as err:
        if not self.allow_non_jax:
            raise

        warn(f"JaxLinker could not JAXify graph: {err}")

        # Fallback: one "py" thunk per scheduled node.
        thunks = []
        for node in nodes:
            thunk = node.op.make_thunk(
                node, storage_map, compute_map, no_recycling, "py"
            )
            # Both lists are built before either attribute is assigned.
            thunk.inputs, thunk.outputs = (
                [storage_map[var] for var in node.inputs],
                [storage_map[var] for var in node.outputs],
            )
            thunks.append(thunk)

    computed, last_user = gc_helper(nodes)

    post_thunk_old_storage = None
    if self.allow_gc:
        # For each node, the storage cells of computed, non-output inputs
        # for which this node is the recorded last user.
        post_thunk_old_storage = [
            [
                storage_map[var]
                for var in node.inputs
                if var in computed
                and var not in fgraph.outputs
                and node == last_user[var]
            ]
            for node in nodes
        ]

    if no_recycling is True:
        # ``True`` selects every storage cell except the input storage.
        recycle_exempt = difference(list(storage_map.values()), input_storage)
    else:
        recycle_exempt = [
            storage_map[var] for var in no_recycling if var not in fgraph.inputs
        ]

    fn = streamline(
        fgraph,
        thunks,
        nodes,
        post_thunk_old_storage,
        no_recycling=recycle_exempt,
    )
    fn.allow_gc = self.allow_gc
    fn.storage_map = storage_map

    input_containers = [
        Container(var, cell) for var, cell in zip(fgraph.inputs, input_storage)
    ]
    output_containers = [
        Container(var, cell, readonly=True)
        for var, cell in zip(fgraph.outputs, output_storage)
    ]
    return fn, input_containers, output_containers, thunks, nodes