def make_all(self, profiler=None, input_storage=None, output_storage=None):
    """Build Python-only thunks for every node and wrap them in one callable.

    :param profiler: forwarded unchanged to ``streamline``.
    :param input_storage: optional list of storage cells for
        ``fgraph.inputs``; forwarded to ``map_storage``.
    :param output_storage: optional list of storage cells for
        ``fgraph.outputs``; forwarded to ``map_storage``.
    :returns: a 5-tuple ``(f, input_containers, output_containers,
        thunks, order)`` where ``f`` is the callable built by
        ``streamline`` that runs all nodes, ``input_containers`` and
        ``output_containers`` are ``Container`` objects wrapping the input
        and output storage cells (output containers are created read-only),
        ``thunks`` holds one thunk per node and ``order`` is the node
        schedule those thunks follow.
    """
    fgraph = self.fgraph
    order = self.schedule(fgraph)
    no_recycling = self.no_recycling

    input_storage, output_storage, storage_map = map_storage(
        fgraph, order, input_storage, output_storage)

    # Variables without an owner start out flagged as computed; every other
    # variable is flagged as not yet computed.
    compute_map = {}
    for k in storage_map:
        compute_map[k] = [k.owner is None]

    thunks = []
    for node in order:
        # Make sure we don't use the C version of the code, but rather only
        # the Python version.
        # Note: ops that implement their own make_thunk don't usually have
        # this attribute defined!
        old_value = getattr(node.op, '_op_use_c_code', False)
        try:
            node.op._op_use_c_code = False
            thunk = node.op.make_thunk(node, storage_map, compute_map,
                                       no_recycling)
            thunk.inputs = [storage_map[v] for v in node.inputs]
            thunk.outputs = [storage_map[v] for v in node.outputs]
            thunks.append(thunk)
        finally:
            # Restore whatever value the op had before (False if absent).
            node.op._op_use_c_code = old_value

    computed, last_user = gc_helper(order)
    if self.allow_gc:
        # Per node: storage cells of computed, non-output inputs whose last
        # user is this node. These lists are handed to ``streamline``
        # (presumably so the cells can be released after the node runs).
        post_thunk_old_storage = []
        for node in order:
            post_thunk_old_storage.append([
                storage_map[var] for var in node.inputs
                if (var in computed) and (var not in fgraph.outputs) and
                (node == last_user[var])
            ])
    else:
        post_thunk_old_storage = None

    if no_recycling is True:
        # True seems like some special code for *everything*?? -JB
        # FunctionMaker always passes a list I think -JB
        no_recycling = list(storage_map.values())
        no_recycling = utils.difference(no_recycling, input_storage)
    else:
        no_recycling = [storage_map[r] for r in no_recycling
                        if r not in fgraph.inputs]

    # The function that actually runs your program is one of the f's in
    # streamline.
    f = streamline(fgraph, thunks, order, post_thunk_old_storage,
                   no_recycling=no_recycling, profiler=profiler)

    # HACK: this is a way of passing an arg to Function.__call__
    f.allow_gc = self.allow_gc
    add_clear_storage(f, computed, storage_map)
    # Expose the storage map on the returned callable so callers can reach
    # the storage cells it operates on.
    f.storage_map = storage_map

    return (f,
            [Container(var, storage)
             for var, storage in zip(fgraph.inputs, input_storage)],
            [Container(var, storage, True)
             for var, storage in zip(fgraph.outputs, output_storage)],
            thunks,
            order)
def make_all(self, input_storage=None, output_storage=None, storage_map=None): """ Returns Function to run all nodes, list of input containers, list of outputs Parameters ---------- input_storage list of storages corresponding to fgraph.inputs output_storage list of storages corresponding to fgraph.outputs Returns ------- object Function to run all nodes, list of input containers, list of output containers, list of thunks (for all programs), list of nodes (for all programs). """ fgraph = self.fgraph order = self.schedule(fgraph) no_recycling = self.no_recycling input_storage, output_storage, storage_map = map_storage( fgraph, order, input_storage, output_storage, storage_map ) compute_map = {} for k in storage_map: compute_map[k] = [k.owner is None] thunks = [] for node in order: # Maker sure we don't use C version of the code, but rather only # the python version # Note : ops that implement their own make thunk don't usually # have this attribute defiend !! thunks += [ node.op.make_thunk(node, storage_map, compute_map, no_recycling, "py") ] thunks[-1].inputs = [storage_map[v] for v in node.inputs] thunks[-1].outputs = [storage_map[v] for v in node.outputs] computed, last_user = gc_helper(order) if self.allow_gc: post_thunk_old_storage = [] else: post_thunk_old_storage = None for node in order: if self.allow_gc: post_thunk_old_storage.append( [ storage_map[input] for input in node.inputs if (input in computed) and (input not in fgraph.outputs) and (node == last_user[input]) ] ) if no_recycling is True: # True seems like some special code for *everything*?? -JB # FunctionMaker always passes a list I think -JB no_recycling = list(storage_map.values()) no_recycling = utils.difference(no_recycling, input_storage) else: no_recycling = [ storage_map[r] for r in no_recycling if r not in fgraph.inputs ] # The function that actually runs your program is one of the f's in streamline. 
f = streamline( fgraph, thunks, order, post_thunk_old_storage, no_recycling=no_recycling ) f.allow_gc = ( self.allow_gc ) # HACK: this is a way of passing an arg to Function.__call__ add_clear_storage(f, computed, storage_map) f.storage_map = storage_map return ( f, [ Container(input, storage) for input, storage in zip(fgraph.inputs, input_storage) ], [ Container(output, storage, True) for output, storage in zip(fgraph.outputs, output_storage) ], thunks, order, )
def make_all(self, input_storage=None, output_storage=None): """ :param input_storage: WRITEME :param output_storage: WRITEME :returns: function to run all nodes, list of input containers, list of output containers, list of thunks (for all of program), list of nodes (for all of program) """ fgraph = self.fgraph order = self.schedule(fgraph) no_recycling = self.no_recycling input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage) compute_map = {} for k in storage_map: compute_map[k] = [k.owner is None] thunks = [] for node in order: # Maker sure we don't use C version of the code, but rather only # the python version # Note : ops that implement their own make thunk don't usually # have this attribute defiend !! old_value = getattr(node.op, '_op_use_c_code', False) try: node.op._op_use_c_code = False thunks += [node.op.make_thunk(node, storage_map, compute_map, no_recycling)] thunks[-1].inputs = [storage_map[v] for v in node.inputs] thunks[-1].outputs = [storage_map[v] for v in node.outputs] finally: node.op._op_use_c_code = old_value computed, last_user = gc_helper(order) if self.allow_gc: post_thunk_old_storage = [] else: post_thunk_old_storage = None for node in order: if self.allow_gc: post_thunk_old_storage.append( [storage_map[input] for input in node.inputs if (input in computed) and ( input not in fgraph.outputs) and ( node == last_user[input])]) if no_recycling is True: # True seems like some special code for *everything*?? -JB # FunctionMaker always passes a list I think -JB no_recycling = list(storage_map.values()) no_recycling = utils.difference(no_recycling, input_storage) else: no_recycling = [storage_map[r] for r in no_recycling if r not in fgraph.inputs] # The function that actually runs your program is one of the f's in streamline. 
f = streamline(fgraph, thunks, order, post_thunk_old_storage, no_recycling=no_recycling) f.allow_gc = self.allow_gc # HACK: this is a way of passing an arg to Function.__call__ add_clear_storage(f, computed, storage_map) f.storage_map = storage_map return (f, [Container(input, storage) for input, storage in izip(fgraph.inputs, input_storage)], [Container(output, storage, True) for output, storage in izip(fgraph.outputs, output_storage)], thunks, order)
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
    """
    Build thunks for the graph (JAX-compiled when possible) and wrap them in
    one callable.

    ``self.create_jax_thunks`` is tried first. If it raises
    ``NotImplementedError`` and ``self.allow_non_jax`` is false, the error is
    re-raised; otherwise a warning is issued and one Python-implementation
    thunk is built per scheduled node instead.

    Parameters
    ----------
    input_storage
        Optional list of storage cells corresponding to ``fgraph.inputs``.
    output_storage
        Optional list of storage cells corresponding to ``fgraph.outputs``.
    storage_map
        Optional mapping from variables to storage cells; forwarded to
        ``map_storage``.

    Returns
    -------
    tuple
        ``(fn, input_containers, output_containers, thunks, nodes)``: the
        callable built by ``streamline``, ``Container`` objects for the
        inputs and read-only outputs, the thunks, and the nodes they
        correspond to (as returned by ``create_jax_thunks`` on the JAX path,
        or the original schedule on the fallback path).

    """
    fgraph = self.fgraph
    nodes = self.schedule(fgraph)
    no_recycling = self.no_recycling

    input_storage, output_storage, storage_map = map_storage(
        fgraph, nodes, input_storage, output_storage, storage_map
    )

    # Variables without an owner start out flagged as computed.
    compute_map = {}
    for k in storage_map:
        compute_map[k] = [k.owner is None]

    try:
        # We need to create thunk functions that will populate the output
        # storage arrays with the JAX-computed values.
        thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
    except NotImplementedError as e:
        if not self.allow_non_jax:
            raise

        warn(f"JaxLinker could not JAXify graph: {e}")

        # Fallback: one Python-implementation thunk per scheduled node.
        thunks = []
        for node in nodes:
            thunk = node.op.make_thunk(
                node, storage_map, compute_map, no_recycling, "py"
            )
            thunk.inputs = [storage_map[v] for v in node.inputs]
            thunk.outputs = [storage_map[v] for v in node.outputs]
            thunks.append(thunk)

    computed, last_user = gc_helper(nodes)

    if self.allow_gc:
        # Per node: storage cells of computed, non-output inputs whose last
        # user is this node; handed to ``streamline`` (presumably so they can
        # be released once the node has run).
        post_thunk_old_storage = []
        for node in nodes:
            post_thunk_old_storage.append(
                [
                    storage_map[var]
                    for var in node.inputs
                    if (var in computed)
                    and (var not in fgraph.outputs)
                    and (node == last_user[var])
                ]
            )
    else:
        post_thunk_old_storage = None

    if no_recycling is True:
        no_recycling = list(storage_map.values())
        no_recycling = utils.difference(no_recycling, input_storage)
    else:
        no_recycling = [
            storage_map[r] for r in no_recycling if r not in fgraph.inputs
        ]

    fn = streamline(
        fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
    )

    fn.allow_gc = self.allow_gc
    # Attach the same ``clear_storage`` hook every other linker's make_all
    # attaches, so the returned callable exposes a consistent interface.
    add_clear_storage(fn, computed, storage_map)
    fn.storage_map = storage_map

    return (
        fn,
        [
            Container(var, storage)
            for var, storage in zip(fgraph.inputs, input_storage)
        ],
        [
            Container(var, storage, readonly=True)
            for var, storage in zip(fgraph.outputs, output_storage)
        ],
        thunks,
        nodes,
    )