def test_sort_apply_nodes():
    """Check that sort_apply_nodes honours a lexicographic ``str()`` comparator.

    Builds ``dot(x * 2, x + 1)``, sorts its apply nodes with a comparator on
    their string forms, and asserts the result is in non-decreasing
    ``str()`` order.
    """
    x = tensor.matrix('x')
    y = tensor.dot(x * 2, x + 1)

    def str_cmp(a, b):
        # Lexicographical comparison of the string forms. ``(p > q) - (p < q)``
        # stands in for the ``cmp`` builtin, which Python 3 removed.
        sa, sb = str(a), str(b)
        return (sa > sb) - (sa < sb)

    nodes = sort_apply_nodes([x], [y], cmps=[str_cmp])
    # NOTE(review): unlike test_sort_schedule_fn, this asserts on every
    # adjacent pair without exempting dependency-forced pairs; presumably
    # that only holds because the lexicographic order agrees with the graph's
    # dependencies for this expression -- confirm if the graph changes.
    for a, b in zip(nodes[:-1], nodes[1:]):
        assert str(a) <= str(b)
def posort(l, *cmps):
    """Partially ordered sort with multiple comparators.

    Given a list of comparators, order the elements of ``l`` so that the
    comparators are satisfied as much as possible, giving precedence to
    earlier comparators.

    inputs:
        l - an iterable of nodes in a graph. It is consumed exactly once, so
            one-shot iterators such as generators are accepted.
        cmps - a sequence of comparator functions that describe which nodes
               should come before which others; ``cmp(a, b) < 0`` means
               ``a`` wants to come before ``b``.

    outputs:
        a list of nodes which satisfy the comparators as much as possible.

    >>> lower_tens = lambda a, b: a // 10 - b // 10  # prefer lower numbers div 10
    >>> prefer_evens = lambda a, b: a % 2 - b % 2  # prefer even numbers
    >>> posort(range(20), lower_tens, prefer_evens)
    [0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15]

    implemented with _toposort
    """
    # Materialize once: the nodes are iterated many times below. A generator
    # would be exhausted by the first pass and silently yield a wrong result.
    nodes = list(l)

    # comes_before[x]: nodes already constrained to precede x.
    # comes_after[x]: nodes already constrained to follow x.
    comes_before = dict((a, set()) for a in nodes)
    comes_after = dict((a, set()) for a in nodes)

    def add_links(a, b):
        """Record that ``b`` comes after ``a`` and propagate the new ordering
        to every node already known to be before ``a`` / after ``b``."""
        comes_after[a].add(b)
        comes_after[a].update(comes_after[b])
        for c in comes_before[a]:
            comes_after[c].update(comes_after[a])
        comes_before[b].add(a)
        comes_before[b].update(comes_before[a])
        for c in comes_after[b]:
            comes_before[c].update(comes_before[b])

    # Earlier comparators are applied first, so their edges win any conflict
    # with later comparators.
    for compare in cmps:
        for a in nodes:
            for b in nodes:
                if compare(a, b) < 0:  # a wants to come before b
                    # Add the edge only if it would not create a cycle
                    # (b already precedes a) and is not already implied.
                    if b not in comes_before[a] and b not in comes_after[a]:
                        add_links(a, b)

    return _toposort(comes_after)
def posort(l, *cmps):
    """Partially ordered sort with multiple comparators.

    Given a list of comparators, order the elements of ``l`` so that the
    comparators are satisfied as much as possible, giving precedence to
    earlier comparators.

    inputs:
        l - an iterable of nodes in a graph. It is consumed exactly once, so
            one-shot iterators such as generators are accepted.
        cmps - a sequence of comparator functions that describe which nodes
               should come before which others; ``cmp(a, b) < 0`` means
               ``a`` wants to come before ``b``.

    outputs:
        a list of nodes which satisfy the comparators as much as possible.

    >>> lower_tens = lambda a, b: a // 10 - b // 10  # prefer lower numbers div 10
    >>> prefer_evens = lambda a, b: a % 2 - b % 2  # prefer even numbers
    >>> posort(range(20), lower_tens, prefer_evens)
    [0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15]

    implemented with _toposort
    """
    # Materialize once: the nodes are iterated many times below. A generator
    # would be exhausted by the first pass and silently yield a wrong result.
    nodes = list(l)

    # comes_before[x]: nodes already constrained to precede x.
    # comes_after[x]: nodes already constrained to follow x.
    comes_before = dict((a, set()) for a in nodes)
    comes_after = dict((a, set()) for a in nodes)

    def add_links(a, b):
        """Record that ``b`` comes after ``a`` and propagate the new ordering
        to every node already known to be before ``a`` / after ``b``."""
        comes_after[a].add(b)
        comes_after[a].update(comes_after[b])
        for c in comes_before[a]:
            comes_after[c].update(comes_after[a])
        comes_before[b].add(a)
        comes_before[b].update(comes_before[a])
        for c in comes_after[b]:
            comes_before[c].update(comes_before[b])

    # Earlier comparators are applied first, so their edges win any conflict
    # with later comparators.
    for compare in cmps:
        for a in nodes:
            for b in nodes:
                if compare(a, b) < 0:  # a wants to come before b
                    # Add the edge only if it would not create a cycle
                    # (b already precedes a) and is not already implied.
                    if b not in comes_before[a] and b not in comes_after[a]:
                        add_links(a, b)

    return _toposort(comes_after)
def test_sort_schedule_fn():
    """Check that a linker scheduled with sort_schedule_fn respects the comparator.

    Compiles a small graph with an OpWiseCLinker whose schedule sorts nodes
    by ``str()``. Every adjacent pair of scheduled nodes that is not exempted
    by ``make_depends()`` must be in strictly increasing ``str()`` order.
    """
    import theano
    from theano.gof.sched import sort_schedule_fn, make_depends

    x = theano.tensor.matrix('x')
    y = theano.tensor.dot(x[:5] * 2, x.T + 1).T

    def str_cmp(a, b):
        # Lexicographical comparison of the string forms; Python 3 has no
        # ``cmp`` builtin, so use the equivalent ``(p > q) - (p < q)``.
        sa, sb = str(a), str(b)
        return (sa > sb) - (sa < sb)

    linker = theano.OpWiseCLinker(schedule=sort_schedule_fn(str_cmp))
    mode = theano.Mode(linker=linker)
    f = theano.function((x,), (y,), mode=mode)

    # The last element of make_all()'s result is taken as the scheduled
    # node order (presumably the linker's node list -- confirm).
    nodes = f.maker.linker.make_all()[-1]
    depends = make_depends()
    for a, b in zip(nodes[:-1], nodes[1:]):
        # Skip pairs make_depends() reports as dependent; presumably b
        # depends on a, so their order is forced regardless of str_cmp.
        if not depends((b, a)):
            assert str(a) < str(b)
def test_sort_schedule_fn():
    """Check that a linker scheduled with sort_schedule_fn respects the comparator.

    Compiles a small graph with an OpWiseCLinker whose schedule sorts nodes
    by ``str()``. Every adjacent pair of scheduled nodes that is not exempted
    by ``make_depends()`` must be in strictly increasing ``str()`` order.
    """
    import theano
    from theano.gof.sched import sort_schedule_fn, make_depends

    x = theano.tensor.matrix('x')
    y = theano.tensor.dot(x[:5] * 2, x.T + 1).T

    def str_cmp(a, b):
        # Lexicographical comparison of the string forms; Python 3 has no
        # ``cmp`` builtin, so use the equivalent ``(p > q) - (p < q)``.
        sa, sb = str(a), str(b)
        return (sa > sb) - (sa < sb)

    linker = theano.OpWiseCLinker(schedule=sort_schedule_fn(str_cmp))
    mode = theano.Mode(linker=linker)
    f = theano.function((x,), (y,), mode=mode)

    # The last element of make_all()'s result is taken as the scheduled
    # node order (presumably the linker's node list -- confirm).
    nodes = f.maker.linker.make_all()[-1]
    depends = make_depends()
    for a, b in zip(nodes[:-1], nodes[1:]):
        # Skip pairs make_depends() reports as dependent; presumably b
        # depends on a, so their order is forced regardless of str_cmp.
        if not depends((b, a)):
            assert str(a) < str(b)
def key_cmp(a, b):
    """Three-way compare ``a`` and ``b`` by ``key(a)`` and ``key(b)``.

    Returns -1, 0 or 1 as ``key(a)`` is less than, equal to or greater than
    ``key(b)``.

    NOTE(review): ``key`` is not defined in this function; presumably it is
    supplied by an enclosing scope or the module -- confirm.
    """
    # ``(p > q) - (p < q)`` replaces the ``cmp`` builtin removed in Python 3.
    ka, kb = key(a), key(b)
    return (ka > kb) - (ka < kb)
def str_cmp(a, b):
    """Lexicographically compare the string forms of ``a`` and ``b``.

    Returns -1, 0 or 1 as ``str(a)`` is less than, equal to or greater than
    ``str(b)``. Note that this is string order, so e.g. ``str_cmp(10, 9)``
    is -1 because ``"10" < "9"``.
    """
    # ``(p > q) - (p < q)`` replaces the ``cmp`` builtin removed in Python 3.
    sa, sb = str(a), str(b)
    return (sa > sb) - (sa < sb)