def general_toposort(r_out, deps, debug_print=False):
    """Return the nodes reachable from `r_out`, ordered dependencies-first.

    :param r_out: tuple, list or deque of nodes at which the search starts.
    :param deps: callable mapping a node to the nodes it depends on; a falsy
        result means the node has no dependencies.
    :param debug_print: when true, print the reachable nodes and the partial
        ordering just before raising on a cycle.

    :returns: list in which a node is only emitted once every node that
        `deps` reported for it has been emitted.
    :raises ValueError: if not every reachable node could be ordered
        ('graph contains cycles').

    :note: deps(i) should behave like a pure function (no funny business
        with internal state)
    :note: deps(i) will be cached by this function (to be fast)
    :note: The order of the return value list is determined by the order of
        nodes returned by the deps() function.
    """
    deps_cache = {}

    def _deps(io):
        if io not in deps_cache:
            d = deps(io)
            if d:
                # Keep a private list: it is filtered later on, and copying
                # also materialises one-shot iterables exactly once.
                deps_cache[io] = list(d)
            else:
                deps_cache[io] = d
        # Always hand back the cached value; returning the raw `d` would
        # yield an already-consumed iterator when deps() is lazy.
        return deps_cache[io]

    assert isinstance(r_out, (tuple, list, deque))

    # With the final True flag stack_search is expected to return the visited
    # nodes together with the inverse of _deps (node -> dependants).
    reachable, clients = stack_search(deque(r_out), _deps, 'dfs', True)

    # Nodes without pending dependencies can be emitted right away.
    sources = deque([r for r in reachable if not deps_cache.get(r, None)])

    rset = set()
    rlist = []
    while sources:
        node = sources.popleft()
        if node not in rset:
            rlist.append(node)
            rset.add(node)
            for client in clients.get(node, []):
                # Drop the emitted node from the client's pending list.
                deps_cache[client] = [a for a in deps_cache[client]
                                      if a is not node]
                if not deps_cache[client]:
                    sources.append(client)

    if len(rlist) != len(reachable):
        if debug_print:
            # Single-argument print calls behave the same on Python 2 and 3.
            print('')
            print(reachable)
            print(rlist)
        raise ValueError('graph contains cycles')

    return rlist
def general_toposort(r_out, deps, debug_print=False):
    """Return the nodes reachable from `r_out`, ordered dependencies-first.

    :param r_out: tuple, list or deque of nodes at which the search starts.
    :param deps: callable mapping a node to the nodes it depends on; a falsy
        result means the node has no dependencies.
    :param debug_print: when true, print the reachable nodes and the partial
        ordering just before raising on a cycle.

    :returns: list in which a node is only emitted once every node that
        `deps` reported for it has been emitted.
    :raises ValueError: if not every reachable node could be ordered
        ('graph contains cycles').

    :note: deps(i) should behave like a pure function (no funny business
        with internal state)
    :note: deps(i) will be cached by this function (to be fast)
    :note: The order of the return value list is determined by the order of
        nodes returned by the deps() function.
    """
    deps_cache = {}

    def _deps(io):
        if io not in deps_cache:
            d = deps(io)
            if d:
                # Keep a private list: it is filtered later on, and copying
                # also materialises one-shot iterables exactly once.
                deps_cache[io] = list(d)
            else:
                deps_cache[io] = d
        # Always hand back the cached value; returning the raw `d` would
        # yield an already-consumed iterator when deps() is lazy.
        return deps_cache[io]

    assert isinstance(r_out, (tuple, list, deque))

    # With the final True flag stack_search is expected to return the visited
    # nodes together with the inverse of _deps (node -> dependants).
    reachable, clients = stack_search(deque(r_out), _deps, 'dfs', True)

    # Nodes without pending dependencies can be emitted right away.
    sources = deque([r for r in reachable if not deps_cache.get(r, None)])

    rset = set()
    rlist = []
    while sources:
        node = sources.popleft()
        if node not in rset:
            rlist.append(node)
            rset.add(node)
            for client in clients.get(node, []):
                # Drop the emitted node from the client's pending list.
                deps_cache[client] = [a for a in deps_cache[client]
                                      if a is not node]
                if not deps_cache[client]:
                    sources.append(client)

    if len(rlist) != len(reachable):
        if debug_print:
            # Single-argument print calls behave the same on Python 2 and 3.
            print('')
            print(reachable)
            print(rlist)
        raise ValueError('graph contains cycles')

    return rlist
def apply(self, env, start_from=None):
    """Alternate global and local optimizers on `env` until nothing changes.

    Each round first applies every optimizer in `self.global_optimizers`,
    then walks the nodes given by `graph.io_toposort(env.inputs, start_from)`
    from the end of that ordering, offering each node to every optimizer in
    `self.local_optimizers`.

    A local optimizer may only report a change
    `len(queue) * self.max_use_ratio` times (the count is accumulated over
    all rounds); once one exceeds that budget the current round is finished
    without it, an error is logged and the method returns.

    :param env: the graph container being optimized.
    :param start_from: nodes the traversal starts from; defaults to
        `env.outputs`, and every entry must be one of `env.outputs`.
    """
    if start_from is None:
        start_from = env.outputs

    process_count = {}
    opt_name = None
    max_use_abort = False
    changed = True

    while changed and not max_use_abort:
        changed = False

        # --- global optimizers ---
        env.change_tracker.reset()
        for global_opt in self.global_optimizers:
            global_opt.apply(env)
            if env.change_tracker.changed:
                changed = True

        # --- local optimizers ---
        for out in start_from:
            assert out in env.outputs

        pending = deque(graph.io_toposort(env.inputs, start_from))
        max_use = len(pending) * self.max_use_ratio

        # The callbacks read `current_node` at call time, so they always
        # compare against the node currently being processed.
        def importer(node):
            if node is not current_node:
                pending.append(node)

        def pruner(node):
            if node is not current_node:
                try:
                    pending.remove(node)
                except ValueError:
                    pass

        updater = self.attach_updater(env, importer, pruner)
        try:
            while pending:
                node = pending.pop()
                current_node = node
                for local_opt in self.local_optimizers:
                    process_count.setdefault(local_opt, 0)
                    if process_count[local_opt] > max_use:
                        # Over budget: remember the culprit and skip it, but
                        # keep offering the node to the other optimizers.
                        max_use_abort = True
                        opt_name = (getattr(local_opt, "name", None)
                                    or getattr(local_opt, "__name__", None)
                                    or "")
                        continue
                    if not self.process_node(env, node, local_opt):
                        continue
                    process_count[local_opt] += 1
                    changed = True
                    if node not in env.nodes:
                        # The node is gone from the graph: go to next node.
                        break
        finally:
            self.detach_updater(env, updater)
        # TODO: erase this line, it's redundant at best
        self.detach_updater(env, updater)

    if max_use_abort:
        message = ("EquilibriumOptimizer max'ed out by '%s'" % opt_name
                   + ". You can safely raise the current threshold of "
                   + "%f with the theano flag 'optdb.max_use_ratio'."
                   % config.optdb.max_use_ratio)
        _logger.error(message)
def variables_and_orphans(i, o):
    """Return the variables found searching back from `o`, and the orphans.

    The search expands a variable through its owner (the owner's inputs and
    outputs) unless the variable has no owner or is listed in `i`.

    :param i: variables at which the search stops.
    :param o: variables from which the search starts.

    :returns: pair `(variables, orphans)`, where `variables` is whatever
        `stack_search` reports for the depth-first search and `orphans` are
        the entries of `variables` that have no owner and are not in `i`.
    """
    def expand(var):
        owner = var.owner
        if not owner or var in i:
            # Nothing to expand: owner-less variable or a stopping point.
            return None
        neighbours = list(owner.inputs) + list(owner.outputs)
        return neighbours[::-1]

    variables = stack_search(deque(o), expand, 'dfs')

    orphans = []
    for var in variables:
        if var.owner is None and var not in i:
            orphans.append(var)
    return variables, orphans
def ancestors(variable_list, blockers=None):
    """Return the variables that contribute to those in variable_list
    (inclusive).

    :type variable_list: list of `Variable` instances
    :param variable_list: output `Variable` instances from which to search
        backward through owners
    :param blockers: optional collection of variables; a variable found in
        it is not expanded through its owner
    :rtype: list of `Variable` instances
    :returns: the variables visited, in the order found by a left-recursive
        depth-first search started at the nodes in `variable_list`.
    """
    def expand(var):
        if not var.owner:
            return None
        if blockers and var in blockers:
            return None
        # Reversed so that the stack-based search visits the leftmost
        # input first.
        return list(var.owner.inputs)[::-1]

    return stack_search(deque(variable_list), expand, 'dfs')
def inputs(variable_list, blockers=None):
    """Return the inputs required to compute the given Variables.

    :type variable_list: list of `Variable` instances
    :param variable_list: output `Variable` instances from which to search
        backward through owners
    :param blockers: optional collection of variables; a variable found in
        it is not expanded through its owner
    :rtype: list of `Variable` instances
    :returns: input nodes with no owner, in the order found by a
        left-recursive depth-first search started at the nodes in
        `variable_list`.
    """
    def expand(var):
        if not var.owner:
            return None
        if blockers and var in blockers:
            return None
        # Reversed so that the stack-based search visits the leftmost
        # input first.
        return list(var.owner.inputs)[::-1]

    visited = stack_search(deque(variable_list), expand, 'dfs')

    # Keep only the variables that no operation produced.
    rval = []
    for var in visited:
        if var.owner is None:
            rval.append(var)
    return rval
def apply(self, env, start_from=None):
    """Process every node of `env` once, in the order set by `self.order`.

    The work queue is seeded with
    `graph.io_toposort(env.inputs, start_from)`. While the queue is being
    drained an updater is attached to `env` so that imported nodes are
    appended to the queue and pruned nodes are removed from it (the node
    currently being processed is ignored by both callbacks).

    :param env: the graph container being optimized.
    :param start_from: nodes the traversal starts from; defaults to
        `env.outputs`.
    """
    if start_from is None:
        start_from = env.outputs
    q = deque(graph.io_toposort(env.inputs, start_from))

    # Both callbacks read `current_node` at call time, i.e. they see the
    # node the loop below is currently working on.
    def importer(node):
        if node is not current_node:
            q.append(node)

    def pruner(node):
        if node is not current_node:
            try:
                q.remove(node)
            except ValueError:
                # The node was never queued or has already been processed.
                pass

    u = self.attach_updater(env, importer, pruner)
    try:
        while q:
            if self.order == 'out_to_in':
                node = q.pop()
            else:
                node = q.popleft()
            current_node = node
            self.process_node(env, node)
    finally:
        # Detach on every exit path, including exceptions that do not
        # derive from Exception (e.g. KeyboardInterrupt), so the updater
        # never stays registered on env after this method returns.
        self.detach_updater(env, u)
def _dfs_toposort(i, r_out, orderings): """ i - list of inputs o - list of outputs orderings - dict of additions to the normal inputs and outputs Returns nothing. Raises exception for graph with cycles """ #this is hard-coded reimplementation of functions from graph.py # reason: go faster, prepare for port to C. assert isinstance(r_out, (tuple, list, deque)) # TODO: For more speed - use a defaultdict for the orderings iset = set(i) if 0: def expand(obj): rval = [] if obj not in iset: if isinstance(obj, graph.Variable): if obj.owner: rval = [obj.owner] if isinstance(obj, graph.Apply): rval = list(obj.inputs) rval.extend(orderings.get(obj, [])) else: assert not orderings.get(obj, []) return rval expand_cache = {} # reachable, clients = stack_search( deque(r_out), deps, 'dfs', True) start = deque(r_out) rval_set = set() rval_set.add(id(None)) rval_list = list() expand_inv = {} sources = deque() while start: l = start.pop() # this makes the search dfs if id(l) not in rval_set: rval_list.append(l) rval_set.add(id(l)) if l in iset: assert not orderings.get(l, []) expand_l = [] else: try: if l.owner: expand_l = [l.owner] else: expand_l = [] except AttributeError: expand_l = list(l.inputs) expand_l.extend(orderings.get(l, [])) if expand_l: for r in expand_l: expand_inv.setdefault(r, []).append(l) start.extend(expand_l) else: sources.append(l) expand_cache[l] = expand_l assert len(rval_list) == len(rval_set) - 1 rset = set() rlist = [] while sources: node = sources.popleft() if node not in rset: rlist.append(node) rset.add(node) for client in expand_inv.get(node, []): expand_cache[client] = [ a for a in expand_cache[client] if a is not node ] if not expand_cache[client]: sources.append(client) if len(rlist) != len(rval_list): raise ValueError('graph contains cycles')