Example #1
0
def general_toposort(r_out, deps, debug_print=False):
    """Return every node reachable from `r_out`, dependencies first.

    :param r_out: tuple, list or deque of nodes to start the search from.
    :param deps: callable mapping a node to an iterable of the nodes it
        depends on (a false value means "no dependencies").
    :param debug_print: if true, print the reachable nodes and the partial
        ordering before raising on a cyclic graph.
    :returns: list of the reachable nodes in which every node appears after
        all of its dependencies.
    :raises ValueError: if not every reachable node could be ordered, i.e.
        the graph contains a cycle.

    :note:
        deps(i) should behave like a pure function (no funny business with internal state)

    :note:
        deps(i) will be cached by this function (to be fast)

    :note:
        The order of the return value list is determined by the order of nodes returned by the deps() function.
    """
    deps_cache = {}

    def _deps(io):
        if io not in deps_cache:
            d = deps(io)
            if d:
                # Snapshot into a list: the cache is rewritten below, and a
                # one-shot iterable must not be handed back already consumed.
                deps_cache[io] = list(d)
            else:
                deps_cache[io] = d
        # Always hand back the cached value so the first call and later
        # calls agree.
        return deps_cache[io]

    assert isinstance(r_out, (tuple, list, deque))

    # NOTE(review): stack_search is defined elsewhere; from its use here it
    # is presumed to return the visited nodes plus an inverse map
    # node -> nodes that listed it as a dependency -- confirm.
    reachable, clients = stack_search(deque(r_out), _deps, 'dfs', True)
    # Nodes without (remaining) dependencies can be emitted right away.
    sources = deque([r for r in reachable if not deps_cache.get(r, None)])

    rset = set()
    rlist = []
    while sources:
        node = sources.popleft()
        if node not in rset:
            rlist.append(node)
            rset.add(node)
            for client in clients.get(node, []):
                # `node` is now satisfied: drop it from the client's pending
                # dependencies (identity comparison, not equality).
                deps_cache[client] = [
                    a for a in deps_cache[client] if a is not node
                ]
                if not deps_cache[client]:
                    sources.append(client)

    if len(rlist) != len(reachable):
        if debug_print:
            # Single-argument parenthesised form prints the same text under
            # both Python 2 and Python 3.
            print('')
            print(reachable)
            print(rlist)
        raise ValueError('graph contains cycles')

    return rlist
Example #2
0
def general_toposort(r_out, deps, debug_print=False):
    """Topologically sort the nodes reachable from `r_out`.

    :param r_out: tuple, list or deque of starting nodes.
    :param deps: callable returning, for a node, an iterable of the nodes it
        depends on; a false return value means the node has no dependency.
    :param debug_print: when true, the reachable nodes and the partial result
        are printed before the cycle error is raised.
    :returns: list of reachable nodes, each one placed after its dependencies.
    :raises ValueError: when some reachable node never becomes free of
        dependencies, i.e. the graph contains a cycle.

    :note:
        deps(i) should behave like a pure function (no funny business with internal state)

    :note:
        deps(i) will be cached by this function (to be fast)

    :note:
        The order of the return value list is determined by the order of nodes returned by the deps() function.
    """
    deps_cache = {}

    def _deps(io):
        if io not in deps_cache:
            d = deps(io)
            if d:
                # Keep a private list copy; it is pruned during the sort and
                # protects against iterables that can only be walked once.
                deps_cache[io] = list(d)
            else:
                deps_cache[io] = d
        # Return the cached object on every call, first one included.
        return deps_cache[io]

    assert isinstance(r_out, (tuple, list, deque))

    # NOTE(review): stack_search lives outside this block; it is assumed to
    # return (visited nodes, inverse dependency map) -- confirm.
    reachable, clients = stack_search(deque(r_out), _deps, 'dfs', True)
    # Start from the nodes that have nothing left to wait for.
    sources = deque([r for r in reachable if not deps_cache.get(r, None)])

    rset = set()
    rlist = []
    while sources:
        node = sources.popleft()
        if node not in rset:
            rlist.append(node)
            rset.add(node)
            for client in clients.get(node, []):
                # Remove the emitted node (by identity) from what the client
                # is still waiting for.
                deps_cache[client] = [a for a in deps_cache[client] if a is not node]
                if not deps_cache[client]:
                    sources.append(client)

    if len(rlist) != len(reachable):
        if debug_print:
            # Parenthesised single-argument print behaves identically on
            # Python 2 and Python 3.
            print('')
            print(reachable)
            print(rlist)
        raise ValueError('graph contains cycles')

    return rlist
Example #3
0
    def apply(self, env, start_from=None):
        """Apply the global, then the local optimizers until nothing changes.

        Each pass runs every optimizer in `self.global_optimizers` on `env`,
        then takes the nodes returned by
        `graph.io_toposort(env.inputs, start_from)` from the end of that
        ordering and offers each one to every optimizer in
        `self.local_optimizers`.  Passes repeat while the previous pass
        reported a change.

        A local optimizer that has already produced more than
        `len(q) * self.max_use_ratio` changes is skipped; the current pass
        still runs to completion, after which the loop stops and an error is
        logged.  The change counts accumulate across passes while the
        threshold is recomputed on every pass.

        :param env: environment handed to the optimizers; this method reads
            its `inputs`, `outputs`, `nodes` and `change_tracker`.
        :param start_from: nodes to start the ordering from; defaults to
            `env.outputs`.  Every entry must be in `env.outputs`.
        """
        if start_from is None:
            start_from = env.outputs
        changed = True
        max_use_abort = False
        opt_name = None
        process_count = {}
        # Bound up front so the importer/pruner closures never hit an
        # unbound free variable if they fire before the first node is taken.
        current_node = None

        while changed and not max_use_abort:
            changed = False

            #apply global optimizer
            env.change_tracker.reset()
            for gopt in self.global_optimizers:
                gopt.apply(env)
            if env.change_tracker.changed:
                changed = True

            #apply local optimizer
            for node in start_from:
                assert node in env.outputs

            q = deque(graph.io_toposort(env.inputs, start_from))

            max_use = len(q) * self.max_use_ratio

            def importer(node):
                # Queue nodes reported by the updater, except the one
                # currently being processed.
                if node is not current_node:
                    q.append(node)

            def pruner(node):
                # Unqueue nodes reported by the updater; a node that is not
                # queued (any more) is fine.
                if node is not current_node:
                    try:
                        q.remove(node)
                    except ValueError:
                        pass

            u = self.attach_updater(env, importer, pruner)
            try:
                while q:
                    node = q.pop()
                    current_node = node
                    for lopt in self.local_optimizers:
                        process_count.setdefault(lopt, 0)
                        if process_count[lopt] > max_use:
                            max_use_abort = True
                            opt_name = (getattr(lopt, "name", None)
                                        or getattr(lopt, "__name__", None) or "")
                        else:
                            lopt_change = self.process_node(env, node, lopt)
                            if lopt_change:
                                process_count[lopt] += 1
                                changed = True
                                if node not in env.nodes:
                                    break  # go to next node
            finally:
                # Single point of cleanup: runs on normal exit and on any
                # exception, so no second detach is needed afterwards.
                self.detach_updater(env, u)
        if max_use_abort:
            _logger.error("EquilibriumOptimizer max'ed out by '%s'"
                          ". You can safely raise the current threshold of "
                          "%f with the theano flag 'optdb.max_use_ratio'."
                          % (opt_name, config.optdb.max_use_ratio))
Example #4
0
def variables_and_orphans(i, o):
    """Search backward from `o` and report the variables found and the orphans.

    :param i: container of variables treated as inputs; the search does not
        expand past them.
    :param o: iterable of variables the search starts from.
    :returns:
        a pair `(variables, orphans)`.  `variables` is the result of the
        'dfs' `stack_search`; `orphans` lists the members of `variables`
        whose `owner` is None and that are not in `i`.
    """
    def expand(r):
        # Nothing to expand for ownerless variables or declared inputs.
        if not r.owner or r in i:
            return None
        apply_node = r.owner
        neighbours = list(apply_node.inputs) + list(apply_node.outputs)
        return neighbours[::-1]

    variables = stack_search(deque(o), expand, 'dfs')
    orphans = []
    for candidate in variables:
        if candidate.owner is None and candidate not in i:
            orphans.append(candidate)
    return variables, orphans
Example #5
0
def variables_and_orphans(i, o):
    """Return the variables found searching back from `o`, plus the orphans.

    :param i: container of input variables; they are not expanded further.
    :param o: iterable of variables to start the search from.
    :returns:
        `(variables, orphans)`, where `variables` is what the 'dfs'
        `stack_search` returns and `orphans` are its members that have no
        owner and are not in `i`.
    """
    def expand(r):
        owner = r.owner
        if owner and r not in i:
            # Inputs followed by outputs of the owner, in reversed order.
            return list(reversed(list(owner.inputs) + list(owner.outputs)))

    found = stack_search(deque(o), expand, 'dfs')
    return found, [v for v in found if v.owner is None and v not in i]
Example #6
0
def ancestors(variable_list, blockers=None):
    """Return the variables that contribute to those in variable_list (inclusive).

    :type variable_list: list of `Variable` instances
    :param variable_list:
        output `Variable` instances from which to search backward through owners
    :param blockers:
        optional container of variables; a variable found in it is not
        expanded through its owner
    :rtype: list of `Variable` instances
    :returns:
        the result of the 'dfs' `stack_search` started at the nodes in
        `variable_list`, expanding each variable into its owner's inputs.

    """
    def expand(r):
        if not r.owner:
            return None
        if blockers and r in blockers:
            return None
        return list(r.owner.inputs)[::-1]

    return stack_search(deque(variable_list), expand, 'dfs')
Example #7
0
def ancestors(variable_list, blockers=None):
    """Collect `variable_list` together with everything it is computed from.

    :type variable_list: list of `Variable` instances
    :param variable_list:
        output `Variable` instances from which to search backward through owners
    :param blockers:
        optional container of variables at which the backward search stops
        (a blocker is not expanded into its owner's inputs)
    :rtype: list of `Variable` instances
    :returns:
        whatever the 'dfs' `stack_search` returns when started at the nodes
        in `variable_list`.

    """
    def expand(r):
        blocked = bool(blockers) and r in blockers
        if r.owner and not blocked:
            parents = list(r.owner.inputs)
            return list(reversed(parents))

    start = deque(variable_list)
    return stack_search(start, expand, 'dfs')
Example #8
0
def inputs(variable_list, blockers=None):
    """Return the inputs required to compute the given Variables.

    :type variable_list: list of `Variable` instances
    :param variable_list:
        output `Variable` instances from which to search backward through owners
    :param blockers:
        optional container of variables that are not expanded through
        their owner
    :rtype: list of `Variable` instances
    :returns:
        the variables with no owner among those the 'dfs' `stack_search`
        returns when started at the nodes in `variable_list`, in the order
        `stack_search` returned them.

    """
    def expand(r):
        if not r.owner:
            return None
        if blockers and r in blockers:
            return None
        return list(r.owner.inputs)[::-1]

    visited = stack_search(deque(variable_list), expand, 'dfs')
    rval = []
    for var in visited:
        if var.owner is None:
            rval.append(var)
    return rval
Example #9
0
    def apply(self, env, start_from=None):
        """Call `self.process_node` on each node of `env`, in toposort order.

        The work queue starts as `graph.io_toposort(env.inputs, start_from)`.
        While nodes are being processed an updater is attached to `env`; its
        importer callback queues nodes and its pruner callback unqueues
        them, in both cases ignoring the node currently being processed.

        :param env: environment whose nodes are processed.
        :param start_from: nodes to start the ordering from; defaults to
            `env.outputs`.
        """
        if start_from is None:
            start_from = env.outputs
        q = deque(graph.io_toposort(env.inputs, start_from))
        # Bound before the callbacks exist so they can never read an
        # unbound free variable.
        current_node = None

        def importer(node):
            if node is not current_node:
                q.append(node)

        def pruner(node):
            if node is not current_node:
                try:
                    q.remove(node)
                except ValueError:
                    # The node was not queued (any more); nothing to do.
                    pass

        u = self.attach_updater(env, importer, pruner)
        try:
            while q:
                # 'out_to_in' consumes the ordering from its end, any other
                # value from its start.
                if self.order == 'out_to_in':
                    node = q.pop()
                else:
                    node = q.popleft()
                current_node = node
                self.process_node(env, node)
        finally:
            # Detach on every exit path, including KeyboardInterrupt and
            # SystemExit, which `except Exception` does not catch.
            self.detach_updater(env, u)
Example #10
0
    def apply(self, env, start_from=None):
        """Process every node of `env` with `self.process_node`.

        Nodes come from `graph.io_toposort(env.inputs, start_from)`.  An
        updater is attached to `env` for the duration of the walk: its
        importer callback appends nodes to the work queue and its pruner
        callback removes them, both skipping the node being processed.

        :param env: environment whose nodes are processed.
        :param start_from: nodes to start the ordering from; defaults to
            `env.outputs`.
        """
        if start_from is None:
            start_from = env.outputs
        q = deque(graph.io_toposort(env.inputs, start_from))
        # Give the callbacks a defined value to compare against from the
        # very beginning.
        current_node = None

        def importer(node):
            if node is not current_node:
                q.append(node)

        def pruner(node):
            if node is not current_node:
                try:
                    q.remove(node)
                except ValueError:
                    # Not in the queue: already processed or never queued.
                    pass

        u = self.attach_updater(env, importer, pruner)
        try:
            while q:
                # Take from the end for 'out_to_in', from the start otherwise.
                if self.order == 'out_to_in':
                    node = q.pop()
                else:
                    node = q.popleft()
                current_node = node
                self.process_node(env, node)
        finally:
            # One cleanup point for success and for every kind of exception
            # (the former `except Exception` missed KeyboardInterrupt).
            self.detach_updater(env, u)
Example #11
0
def _dfs_toposort(i, r_out, orderings):
    """
    i - list of inputs
    o - list of outputs
    orderings - dict of additions to the normal inputs and outputs

    Returns nothing.  Raises exception for graph with cycles
    """
    #this is hard-coded reimplementation of functions  from graph.py
    # reason: go faster, prepare for port to C.

    assert isinstance(r_out, (tuple, list, deque))

    # TODO: For more speed - use a defaultdict for the orderings

    iset = set(i)

    if 0:

        def expand(obj):
            rval = []
            if obj not in iset:
                if isinstance(obj, graph.Variable):
                    if obj.owner:
                        rval = [obj.owner]
                if isinstance(obj, graph.Apply):
                    rval = list(obj.inputs)
                rval.extend(orderings.get(obj, []))
            else:
                assert not orderings.get(obj, [])
            return rval

    expand_cache = {}
    # reachable, clients = stack_search( deque(r_out), deps, 'dfs', True)
    start = deque(r_out)
    rval_set = set()
    rval_set.add(id(None))
    rval_list = list()
    expand_inv = {}
    sources = deque()
    while start:
        l = start.pop()  # this makes the search dfs
        if id(l) not in rval_set:
            rval_list.append(l)
            rval_set.add(id(l))
            if l in iset:
                assert not orderings.get(l, [])
                expand_l = []
            else:
                try:
                    if l.owner:
                        expand_l = [l.owner]
                    else:
                        expand_l = []
                except AttributeError:
                    expand_l = list(l.inputs)
                expand_l.extend(orderings.get(l, []))
            if expand_l:
                for r in expand_l:
                    expand_inv.setdefault(r, []).append(l)
                start.extend(expand_l)
            else:
                sources.append(l)
            expand_cache[l] = expand_l
    assert len(rval_list) == len(rval_set) - 1

    rset = set()
    rlist = []
    while sources:
        node = sources.popleft()
        if node not in rset:
            rlist.append(node)
            rset.add(node)
            for client in expand_inv.get(node, []):
                expand_cache[client] = [
                    a for a in expand_cache[client] if a is not node
                ]
                if not expand_cache[client]:
                    sources.append(client)

    if len(rlist) != len(rval_list):
        raise ValueError('graph contains cycles')