Пример #1
0
def test_postorder_traversal():
    """Check postorder_traversal on sums/products, Piecewise, Integral and tuples."""
    expr = z+w*(x+y)
    # Add/Mul argument order is not fixed, so several visit orders are valid.
    expected1 = [z, w, y, x, x + y, w*(x + y), z + w*(x + y)]
    expected2 = [z, w, x, y, x + y, w*(x + y), z + w*(x + y)]
    expected3 = [w, y, x, x + y, w*(x + y), z, z + w*(x + y)]
    expected4 = [w, x, y, x + y, w*(x + y), z, z + w*(x + y)]
    # Orders that descend into (x + y) before w; z and the root come last.
    expected5 = [x, y, x + y, w, w*(x + y), z, z + w*(x + y)]
    expected6 = [y, x, x + y, w, w*(x + y), z, z + w*(x + y)]
    assert list(postorder_traversal(expr)) in [expected1, expected2,
                                               expected3, expected4,
                                               expected5, expected6]

    expr = Piecewise((x,x<1),(x**2,True))
    assert list(postorder_traversal(expr)) == [
        x, x, 1, x < 1, ExprCondPair(x, x < 1), x, 2, x**2,
        ExprCondPair.true_sentinel,
        ExprCondPair(x**2, True), Piecewise((x, x < 1), (x**2, True))
    ]

    assert list(postorder_traversal(Integral(x**2, (x, 0, 1)))) == [
        x, 2, x**2, x, 0, 1, Tuple(x, 0, 1),
        Integral(x**2, Tuple(x, 0, 1))
    ]
    # Plain (non-Basic) nested tuples are traversed too.
    assert list(postorder_traversal(('abc', ('d', 'ef')))) == [
        'abc', 'd', 'ef', ('d', 'ef'), ('abc', ('d', 'ef'))]
Пример #2
0
def test_preorder_traversal():
    """Check preorder_traversal (and a few postorder cases) plus skip()."""
    # Add/Mul argument order is implementation-defined, so any of these
    # visit orders is accepted.
    e = z+w*(x+y)
    accepted = [
        [z + w*(x + y), z, w*(x + y), w, x + y, y, x],
        [z + w*(x + y), z, w*(x + y), w, x + y, x, y],
        [z + w*(x + y), w*(x + y), w, x + y, y, x, z],
    ]
    assert list(preorder_traversal(e)) in accepted

    # Each ExprCondPair is yielded before its expression and condition.
    e = Piecewise((x,x<1),(x**2,True))
    assert list(preorder_traversal(e)) == [
        Piecewise((x, x < 1), (x**2, True)), ExprCondPair(x, x < 1), x, x < 1,
        x, 1, ExprCondPair(x**2, True), x**2, x, 2, True
    ]

    integral = Integral(x**2, (x, 0, 1))
    assert list(postorder_traversal(integral)) == [
        x, 2, x**2, x, 0, 1, Tuple(x, 0, 1),
        Integral(x**2, Tuple(x, 0, 1))
    ]
    nested = ('abc', ('d', 'ef'))
    assert list(postorder_traversal(nested)) == [
        'abc', 'd', 'ef', ('d', 'ef'), ('abc', ('d', 'ef'))]

    # Calling skip() after a node is yielded stops descent into that node.
    inner = x**(y**z)
    walker = preorder_traversal(inner**inner)
    visited = []
    for node in walker:
        visited.append(node)
        if node == inner:
            walker.skip()
    assert visited == [(x**(y**z))**(x**(y**z)), x**(y**z), x**(y**z)]
Пример #3
0
def test_postorder_traversal():
    """Check postorder_traversal with key=default_sort_key (deterministic order)."""
    def walk(obj):
        # Sorted postorder walk, materialized for comparison.
        return list(postorder_traversal(obj, key=default_sort_key))

    e = z + w * (x + y)
    assert walk(e) == [z, w, x, y, x + y, w * (x + y), w * (x + y) + z]

    e = Piecewise((x, x < 1), (x**2, True))
    pw_order = [
        x, 1, x, x < 1,
        ExprCondPair(x, x < 1), ExprCondPair.true_sentinel, 2, x, x**2,
        ExprCondPair(x**2, True),
        Piecewise((x, x < 1), (x**2, True))
    ]
    assert walk(e) == pw_order
    # A list container is yielded after everything it holds.
    assert walk([e]) == pw_order + [[e]]

    assert walk(Integral(x**2, (x, 0, 1))) == [
        2, x, x**2, 0, 1, x,
        Tuple(x, 0, 1),
        Integral(x**2, Tuple(x, 0, 1))
    ]
    assert list(postorder_traversal(('abc', ('d', 'ef')))) == [
        'abc', 'd', 'ef', ('d', 'ef'), ('abc', ('d', 'ef'))]
Пример #4
0
def test_preorder_traversal():
    """Check preorder_traversal ordering and the skip() method."""
    e = z + w * (x + y)
    # Any of these orders is acceptable (Add/Mul argument order may vary).
    options = (
        [z + w * (x + y), z, w * (x + y), w, x + y, y, x],
        [z + w * (x + y), z, w * (x + y), w, x + y, x, y],
        [z + w * (x + y), w * (x + y), w, x + y, y, x, z],
    )
    assert list(preorder_traversal(e)) in options

    e = Piecewise((x, x < 1), (x**2, True))
    pw_expected = [
        Piecewise((x, x < 1), (x**2, True)),
        ExprCondPair(x, x < 1), x, x < 1, x, 1,
        ExprCondPair(x**2, True), x**2, x, 2, True,
    ]
    assert list(preorder_traversal(e)) == pw_expected

    integral_expected = [
        x, 2, x**2, x, 0, 1,
        Tuple(x, 0, 1),
        Integral(x**2, Tuple(x, 0, 1)),
    ]
    assert list(postorder_traversal(Integral(x**2, (x, 0, 1)))) == integral_expected
    nested = ('abc', ('d', 'ef'))
    assert list(postorder_traversal(nested)) == [
        'abc', 'd', 'ef', ('d', 'ef'), nested]

    # After skip(), the children of the current node are not visited.
    power = x**(y**z)
    tower = power**power
    traversal = preorder_traversal(tower)
    seen = []
    for item in traversal:
        seen.append(item)
        if item == power:
            traversal.skip()
    assert seen == [tower, power, power]
Пример #5
0
    def time_substitutions(self, sympy_expr):
        """Replace the time index with a cyclic index in ``sympy_expr``.

        Only ``Indexed`` terms whose base label maps to a falsy value in
        ``self.save_vars`` (variables not saved in the time domain) are
        rewritten, using ``xreplace`` with ``self.t_replace``.

        :param sympy_expr: The Sympy expression to process, or an
            ``Iteration`` whose ``expressions`` stencils are processed
            recursively (the ``Iteration`` is updated in place and returned).
        :returns: The expression after the substitutions
        :raises ValueError: if an indexed variable's label is not a key of
            ``self.save_vars``.
        """
        # For Iteration objects we apply time subs to the stencil list
        if isinstance(sympy_expr, Iteration):
            sympy_expr.expressions = [
                Expression(self.time_substitutions(s.stencil))
                for s in sympy_expr.expressions
            ]
            return sympy_expr

        subs_dict = {}
        for arg in postorder_traversal(sympy_expr):
            if not isinstance(arg, Indexed):
                continue
            label = str(arg.base.label)
            if label not in self.save_vars:
                raise ValueError(
                    "Invalid variable '%s' in sympy expression."
                    " Did you add it to the operator's params?" % label)
            # Saved variables keep their full time index.
            if not self.save_vars[label]:
                subs_dict[arg] = arg.xreplace(self.t_replace)

        return sympy_expr.xreplace(subs_dict)
Пример #6
0
def test_postorder_traversal():
    """Check postorder_traversal, then preorder_traversal on Integral/tuples."""
    e = z+w*(x+y)
    orders = [
        [z, w, y, x, x + y, w*(x + y), z + w*(x + y)],
        [z, w, x, y, x + y, w*(x + y), z + w*(x + y)],
        [w, y, x, x + y, w*(x + y), z, z + w*(x + y)],
    ]
    assert list(postorder_traversal(e)) in orders

    pw = Piecewise((x,x<1),(x**2,True))
    assert list(postorder_traversal(pw)) == [
        x, x, 1, x < 1, ExprCondPair(x, x < 1), x, 2, x**2, True,
        ExprCondPair(x**2, True), Piecewise((x, x < 1), (x**2, True))
    ]

    # The remaining checks exercise preorder_traversal (root first).
    integral = Integral(x**2, (x, 0, 1))
    assert list(preorder_traversal(integral)) == [
        Integral(x**2, (x, 0, 1)), x**2, x, 2, Tuple(x, 0, 1), x, 0, 1
    ]
    nested = ('abc', ('d', 'ef'))
    assert list(preorder_traversal(nested)) == [
        ('abc', ('d', 'ef')), 'abc', ('d', 'ef'), 'd', 'ef']
Пример #7
0
def test_postorder_traversal():
    """Postorder over an Add/Mul tree and Piecewise; preorder over Integral/tuples."""
    summed = z+w*(x+y)
    assert list(postorder_traversal(summed)) in [
        [z, w, y, x, x + y, w*(x + y), z + w*(x + y)],
        [z, w, x, y, x + y, w*(x + y), z + w*(x + y)],
        [w, y, x, x + y, w*(x + y), z, z + w*(x + y)],
    ]

    piecewise = Piecewise((x,x<1),(x**2,True))
    expected_pw = [
        x, x, 1, x < 1, ExprCondPair(x, x < 1), x, 2, x**2, True,
        ExprCondPair(x**2, True), Piecewise((x, x < 1), (x**2, True)),
    ]
    assert list(postorder_traversal(piecewise)) == expected_pw

    # (tree, expected preorder) pairs, checked in order.
    cases = [
        (Integral(x**2, (x, 0, 1)),
         [Integral(x**2, (x, 0, 1)), x**2, x, 2, Tuple(x, 0, 1), x, 0, 1]),
        (('abc', ('d', 'ef')),
         [('abc', ('d', 'ef')), 'abc', ('d', 'ef'), 'd', 'ef']),
    ]
    for tree, expected in cases:
        assert list(preorder_traversal(tree)) == expected
Пример #8
0
def test_preorder_traversal():
    """Check preorder_traversal plus postorder on Integral (tuple-style limits)."""
    e = z+w*(x+y)
    assert list(preorder_traversal(e)) in [
        [z + w*(x + y), z, w*(x + y), w, x + y, y, x],
        [z + w*(x + y), z, w*(x + y), w, x + y, x, y],
        [z + w*(x + y), w*(x + y), w, x + y, y, x, z],
    ]

    pw = Piecewise((x,x<1),(x**2,True))
    pw_pre = [
        Piecewise((x, x < 1), (x**2, True)), ExprCondPair(x, x < 1), x, x < 1,
        x, 1, ExprCondPair(x**2, True), x**2, x, 2, True,
    ]
    assert list(preorder_traversal(pw)) == pw_pre

    # Integral limits are stored here as plain nested tuples.
    integral_post = [
        x, 2, x**2, x, 0, 1, (0, 1), (x, (0, 1)), ((x, (0, 1)),),
        Integral(x**2, (x, 0, 1)),
    ]
    assert list(postorder_traversal(Integral(x**2, (x, 0, 1)))) == integral_post
    nested = ('abc', ('d', 'ef'))
    assert list(postorder_traversal(nested)) == [
        'abc', 'd', 'ef', ('d', 'ef'), ('abc', ('d', 'ef'))]
Пример #9
0
def fcode(expr, assign_to=None, precision=15, user_functions=None, human=True):
    """Converts an expr to a string of Fortran 77 code

       Arguments:
         expr  --  a sympy expression to be converted

       Optional arguments:
         assign_to  --  When given, the argument is used as the name of the
                        variable to which the Fortran expression is assigned.
                        (This is helpful in case of line-wrapping.)
         precision  --  the precision for numbers such as pi [default=15]
         user_functions  --  A dictionary where keys are FunctionClass instances
                             and values are their string representations.
                             [default=None, meaning no user functions]
         human  --  If True, the result is a single string that may contain
                    some parameter statements for the number symbols. If
                    False, the same information is returned in a more
                    programmer-friendly data structure: the tuple
                    (number_symbols, not_fortran, code).

       >>> from sympy import fcode, symbols, Rational, pi, sin
       >>> x, tau = symbols(["x", "tau"])
       >>> fcode((2*tau)**Rational(7,2))
       '      8*sqrt(2)*tau**(7.0/2.0)'
       >>> fcode(sin(x), assign_to="s")
       '      s = sin(x)'
       >>> print fcode(pi)
             parameter (pi = 3.14159265358979)
             pi

    """
    # A fresh dict per call: a shared mutable default would be reused (and
    # could accumulate changes) across calls.
    if user_functions is None:
        user_functions = {}
    # Collect the NumberSymbols (e.g. pi) so they can be emitted as
    # parameter statements.
    found = set()
    for sub in postorder_traversal(expr):
        if isinstance(sub, NumberSymbol):
            found.add(sub)
    number_symbols = [(str(ns), ns.evalf(precision)) for ns in sorted(found)]
    # run the printer
    profile = {
        "full_prec": False,  # programmers don't care about trailing zeros.
        "assign_to": assign_to,
        "user_functions": user_functions,
    }
    printer = FCodePrinter(profile)
    result = printer.doprint(expr)
    if not human:
        return number_symbols, printer.not_fortran, result
    # format the output
    lines = []
    if len(printer.not_fortran) > 0:
        lines.append("C     Not Fortran 77:")
        for item in sorted(printer.not_fortran):
            lines.append("C     %s" % item)
    for name, value in number_symbols:
        lines.append("      parameter (%s = %s)" % (name, value))
    lines.extend(result.split("\n"))
    lines = wrap_fortran(lines)
    return "\n".join(lines)
Пример #10
0
    def _fix_integer_power(self, expr):
        """Rewrite positive integer powers in ``expr`` as repeated products.

        Every subexpression ``b**n`` with ``n`` a positive Integer is
        substituted (via ``subs``) by ``Mul(b, b, ..., b)`` with ``n`` factors.

        NOTE(review): SymPy's ``Mul`` evaluates by default and normally folds
        identical factors back into ``b**n``; confirm the substitution has the
        intended effect (``evaluate=False`` may have been meant).
        """
        # Snapshot the traversal of the original expression; ``expr`` is
        # rebound inside the loop.
        for subexpr in list(postorder_traversal(expr)):
            if not Is(subexpr).Pow:
                continue
            base, exponent = subexpr.args[0], subexpr.args[1]
            if Is(exponent).Integer and exponent > 0:
                expr = expr.subs(subexpr, Mul(*([base] * exponent)))

        return expr
Пример #11
0
    def doprint(self, expr):
        """Return Fortran code for ``expr``.

        A top-level ``Piecewise`` is emitted as an ``if``/``else if``/``else``
        chain; anything else is printed as a single statement, optionally
        assigned to ``self._settings["assign_to"]``.  When
        ``self._settings["human"]`` is true the result is the formatted
        source text, preceded by comment lines for constructs that are not
        Fortran 77 and ``parameter`` statements for number symbols (e.g. pi).
        Otherwise the tuple ``(number_symbols, not_fortran, text)`` is
        returned.
        """
        # Gather number symbols so they can be declared as parameters.
        found = set([])
        for node in postorder_traversal(expr):
            if isinstance(node, NumberSymbol):
                found.add(node)
        number_symbols = [(str(sym), sym.evalf(self._settings["precision"]))
                          for sym in sorted(found)]

        # Printing helpers record expressions that are not strictly
        # translatable to Fortran here.
        self._not_fortran = set([])

        if isinstance(expr, Piecewise):
            # support for top-level Piecewise function
            out = []
            for idx, (value, cond) in enumerate(expr.args):
                if idx == 0:
                    out.append("      if (%s) then" % self._print(cond))
                elif idx == len(expr.args) - 1 and cond == True:
                    # ``== True`` (not ``is``) so SymPy's true also matches.
                    out.append("      else")
                else:
                    out.append("      else if (%s) then" % self._print(cond))
                target = self._settings["assign_to"]
                if target is None:
                    out.append("        %s" % self._print(value))
                else:
                    out.append("        %s = %s" % (target, self._print(value)))
            out.append("      end if")
            text = "\n".join(out)
        else:
            line = StrPrinter.doprint(self, expr)
            target = self._settings["assign_to"]
            if target is None:
                text = "      %s" % line
            else:
                text = "      %s = %s" % (target, line)

        if self._settings["human"]:
            header = []
            if len(self._not_fortran) > 0:
                header.append("C     Not Fortran 77:")
                header.extend("C     %s" % item
                              for item in sorted(self._not_fortran))
            header.extend("      parameter (%s = %s)" % (name, value)
                          for name, value in number_symbols)
            header.extend(text.split("\n"))
            result = "\n".join(wrap_fortran(header))
        else:
            result = number_symbols, self._not_fortran, text

        del self._not_fortran
        return result
Пример #12
0
    def doprint(self, expr):
        """Translate ``expr`` into Fortran 77 source.

        Returns a formatted string when ``self._settings["human"]`` is set;
        otherwise returns ``(number_symbols, not_fortran, text)`` where
        ``number_symbols`` is a list of ``(name, value)`` pairs evaluated to
        ``self._settings["precision"]`` digits.
        """
        symbols_found = set([])
        for part in postorder_traversal(expr):
            if isinstance(part, NumberSymbol):
                symbols_found.add(part)
        number_symbols = [
            (str(sym), sym.evalf(self._settings["precision"]))
            for sym in sorted(symbols_found)
        ]

        # Filled by the printing helpers with non-Fortran-77 constructs.
        self._not_fortran = set([])

        if isinstance(expr, Piecewise):
            # A top-level Piecewise becomes an if / else if / else chain.
            code_lines = []
            n_pieces = len(expr.args)
            for pos, (piece, condition) in enumerate(expr.args):
                if pos == 0:
                    opener = "      if (%s) then" % self._print(condition)
                elif pos == n_pieces - 1 and condition == True:
                    opener = "      else"
                else:
                    opener = "      else if (%s) then" % self._print(condition)
                code_lines.append(opener)
                dest = self._settings["assign_to"]
                if dest is None:
                    body = "        %s" % self._print(piece)
                else:
                    body = "        %s = %s" % (dest, self._print(piece))
                code_lines.append(body)
            code_lines.append("      end if")
            text = "\n".join(code_lines)
        else:
            line = StrPrinter.doprint(self, expr)
            dest = self._settings["assign_to"]
            text = ("      %s" % line if dest is None
                    else "      %s = %s" % (dest, line))

        if not self._settings["human"]:
            result = number_symbols, self._not_fortran, text
        else:
            formatted = []
            if len(self._not_fortran) > 0:
                formatted.append("C     Not Fortran 77:")
                for item in sorted(self._not_fortran):
                    formatted.append("C     %s" % item)
            for name, value in number_symbols:
                formatted.append("      parameter (%s = %s)" % (name, value))
            formatted.extend(text.split("\n"))
            formatted = wrap_fortran(formatted)
            result = "\n".join(formatted)

        del self._not_fortran
        return result
Пример #13
0
def test_postorder_traversal():
    """Sorted postorder traversal over sums, Piecewise, lists, Integral, tuples."""
    e = z + w*(x+y)
    assert list(postorder_traversal(e, key=default_sort_key)) == [
        z, w, x, y, x + y, w*(x + y), w*(x + y) + z]

    e = Piecewise((x, x < 1), (x**2, True))
    pw_order = [
        x, 1, x, x < 1, ExprCondPair(x, x < 1),
        ExprCondPair.true_sentinel, 2, x, x**2,
        ExprCondPair(x**2, True), Piecewise((x, x < 1), (x**2, True)),
    ]
    assert list(postorder_traversal(e, key=default_sort_key)) == pw_order
    # The list wrapper itself is yielded last.
    wrapped = list(postorder_traversal([e], key=default_sort_key))
    assert wrapped == pw_order + [[e]]

    integral = Integral(x**2, (x, 0, 1))
    assert list(postorder_traversal(integral, key=default_sort_key)) == [
        2, x, x**2, 0, 1, x, Tuple(x, 0, 1),
        Integral(x**2, Tuple(x, 0, 1)),
    ]
    assert list(postorder_traversal(('abc', ('d', 'ef')))) == [
        'abc', 'd', 'ef', ('d', 'ef'), ('abc', ('d', 'ef'))]
Пример #14
0
def test_preorder_traversal():
    """Preorder over sums and Piecewise; postorder over Integral and tuples."""
    e = z + w * (x + y)
    assert list(preorder_traversal(e)) in (
        [z + w * (x + y), z, w * (x + y), w, x + y, y, x],
        [z + w * (x + y), z, w * (x + y), w, x + y, x, y],
        [z + w * (x + y), w * (x + y), w, x + y, y, x, z],
    )

    pw = Piecewise((x, x < 1), (x**2, True))
    pw_pre = [
        Piecewise((x, x < 1), (x**2, True)),
        ExprCondPair(x, x < 1), x, x < 1, x, 1,
        ExprCondPair(x**2, True), x**2, x, 2, True,
    ]
    assert list(preorder_traversal(pw)) == pw_pre

    # Integral limits appear here as plain nested tuples.
    integral = Integral(x**2, (x, 0, 1))
    assert list(postorder_traversal(integral)) == [
        x, 2, x**2, x, 0, 1, (0, 1), (x, (0, 1)), ((x, (0, 1)), ),
        Integral(x**2, (x, 0, 1)),
    ]
    nested = ('abc', ('d', 'ef'))
    assert list(postorder_traversal(nested)) == [
        'abc', 'd', 'ef', ('d', 'ef'), ('abc', ('d', 'ef'))]
Пример #15
0
def test_postorder_traversal():
    """Sorted postorder traversal using the ``keys`` argument."""
    e = z + w * (x + y)
    sum_order = [z, w, x, y, x + y, w * (x + y), w * (x + y) + z]
    # An explicit key function and keys=True are expected to agree.
    for sort_keys in (default_sort_key, True):
        assert list(postorder_traversal(e, keys=sort_keys)) == sum_order

    e = Piecewise((x, x < 1), (x ** 2, True))
    pw_order = [
        x, 1, x, x < 1, ExprCondPair(x, x < 1),
        2, x, x ** 2, true, ExprCondPair(x ** 2, True),
        Piecewise((x, x < 1), (x ** 2, True)),
    ]
    assert list(postorder_traversal(e, keys=default_sort_key)) == pw_order
    assert list(postorder_traversal([e], keys=default_sort_key)) == pw_order + [[e]]

    integral = Integral(x ** 2, (x, 0, 1))
    assert list(postorder_traversal(integral, keys=default_sort_key)) == [
        2, x, x ** 2, 0, 1, x, Tuple(x, 0, 1), Integral(x ** 2, Tuple(x, 0, 1)),
    ]
    nested = ("abc", ("d", "ef"))
    assert list(postorder_traversal(nested)) == ["abc", "d", "ef", ("d", "ef"), nested]
Пример #16
0
def test_postorder_traversal():
    """Postorder traversal with ``keys`` sorting, including containers."""
    def sorted_walk(obj, keys=default_sort_key):
        # Materialize a keyed postorder traversal for comparison.
        return list(postorder_traversal(obj, keys=keys))

    e = z + w * (x + y)
    expected_sum = [z, w, x, y, x + y, w * (x + y), w * (x + y) + z]
    assert sorted_walk(e) == expected_sum
    assert sorted_walk(e, keys=True) == expected_sum

    e = Piecewise((x, x < 1), (x ** 2, True))
    expected_pw = [x, 1, x, x < 1, ExprCondPair(x, x < 1)]
    expected_pw += [2, x, x ** 2, true, ExprCondPair(x ** 2, True)]
    expected_pw += [Piecewise((x, x < 1), (x ** 2, True))]
    assert sorted_walk(e) == expected_pw
    assert sorted_walk([e]) == expected_pw + [[e]]

    assert sorted_walk(Integral(x ** 2, (x, 0, 1))) == [
        2, x, x ** 2, 0, 1, x, Tuple(x, 0, 1), Integral(x ** 2, Tuple(x, 0, 1))]
    assert list(postorder_traversal(("abc", ("d", "ef")))) == [
        "abc", "d", "ef", ("d", "ef"), ("abc", ("d", "ef"))]
Пример #17
0
    def _needed_symbols(self, expr, _visited=None):
        """Return the declared symbols that ``expr`` depends on, transitively.

        Free symbols are gathered from every subexpression of ``expr``.
        Those that are keys of ``self._decls`` are included, along with the
        symbols their own declarations need.

        :param expr: expression to inspect.
        :param _visited: internal; declared symbols whose declarations have
            already been expanded.  It prevents infinite recursion on cyclic
            declarations and avoids re-walking shared dependencies.
        :returns: set of symbols that are keys of ``self._decls``.
        """
        if _visited is None:
            _visited = set()

        # Walk every subexpression, not just ``expr.free_symbols``: inner
        # nodes can expose symbols that are bound at an outer level.
        found = set()
        for subexpr in postorder_traversal(expr):
            if hasattr(subexpr, 'free_symbols'):
                found |= subexpr.free_symbols

        needed = set()
        for symb in found:
            if symb not in self._decls:
                continue
            needed.add(symb)
            if symb not in _visited:
                _visited.add(symb)
                needed |= self._needed_symbols(self._decls[symb], _visited)
        return needed
Пример #18
0
def all_back_sub(eqns, knowns, levels=-1, multiple_sols=False, sub_all=True):
    unks = get_eqns_unk(eqns, knowns)
    print "Knowns:", knowns
    print "Unknowns:", unks
    ord_unk_iter = UpdatingPermutationIterator(
        unks, levels if levels != -1 else len(unks))
    sols = []
    tot_to_test = len(list(ord_unk_iter))
    print "Searching a possible %d orders" % tot_to_test
    print "Hit control-C to stop searching and return solutions already found."
    ord_unk_iter.reset()
    num_tested = 0
    for ord_unks in ord_unk_iter:
        try:
            #        print "Testing order:", ord_unks
            num_tested += 1
            if num_tested % (tot_to_test / 10 if tot_to_test > 10 else 2) == 0:
                print "Tested: ", num_tested, ", Solutions:", len(sols)
            sol_dict, failed_var = backward_sub(eqns, knowns, ord_unks,
                                                multiple_sols, sub_all)
            #        print "  result:", sol_dict, failed_var
            if sol_dict is None:
                if failed_var in ord_unks:
                    ord_unk_iter.bad_pos(ord_unks.index(failed_var))
            else:
                #                for var in sol_dict:
                #                    sol_dict[var] = sol_dict[var].expand()
                if len(filter(lambda x: x[0] == sol_dict, sols)) == 0:
                    sols.append((sol_dict, ord_unks))
                    print "Found new solution:\n%s" % pprint.pformat(
                        sol_dict, 4, 80)
        except KeyboardInterrupt:
            break
    print "Tested %d orders" % num_tested
    print "Found %d unique solutions" % len(sols)
    sols.sort(key=lambda s: sum([len(list(postorder_traversal(v))) \
                                      for _, v in s[0].iteritems()]))
    return sols
Пример #19
0
def all_back_sub(eqns, knowns, levels= -1, multiple_sols=False, sub_all=True):
    unks = get_eqns_unk(eqns, knowns)
    print "Knowns:", knowns
    print "Unknowns:", unks
    ord_unk_iter = UpdatingPermutationIterator(unks,
                                       levels if levels != -1 else len(unks))
    sols = []
    tot_to_test = len(list(ord_unk_iter))
    print "Searching a possible %d orders" % tot_to_test
    print "Hit control-C to stop searching and return solutions already found."
    ord_unk_iter.reset()
    num_tested = 0
    for ord_unks in ord_unk_iter:
        try:
    #        print "Testing order:", ord_unks
            num_tested += 1
            if num_tested % (tot_to_test / 10 if tot_to_test > 10 else 2) == 0:
                print "Tested: ", num_tested, ", Solutions:", len(sols)
            sol_dict, failed_var = backward_sub(eqns, knowns, ord_unks,
                                                multiple_sols, sub_all)
    #        print "  result:", sol_dict, failed_var
            if sol_dict is None:
                if failed_var in ord_unks:
                    ord_unk_iter.bad_pos(ord_unks.index(failed_var))
            else:
#                for var in sol_dict:
#                    sol_dict[var] = sol_dict[var].expand()
                if len(filter(lambda x: x[0] == sol_dict, sols)) == 0:
                    sols.append((sol_dict, ord_unks))
                    print "Found new solution:\n%s" % pprint.pformat(sol_dict, 4, 80)
        except KeyboardInterrupt:
            break
    print "Tested %d orders" % num_tested
    print "Found %d unique solutions" % len(sols)
    sols.sort(key=lambda s: sum([len(list(postorder_traversal(v))) \
                                      for _, v in s[0].iteritems()]))
    return sols
Пример #20
0
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a stream
        of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.  ``to_eliminate`` keeps them
    # in first-repeat order; ``eliminate_set`` mirrors it so the membership
    # test is O(1) instead of a linear list scan per node.
    seen_subexp = set()
    to_eliminate = []
    eliminate_set = set()
    for expr in exprs:
        for subtree in postorder_traversal(expr):
            if subtree.args == ():
                # Exclude atoms, since there is no point in renaming them.
                continue
            if subtree in seen_subexp and subtree not in eliminate_set:
                to_eliminate.append(subtree)
                eliminate_set.add(subtree)
            seen_subexp.add(subtree)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        sym = next(symbols)
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        reduced_exprs = [e.subs(subtree, sym) for e in reduced_exprs]
        # Rewrite the not-yet-processed candidates as well, so later
        # replacements refer to ``sym`` rather than repeating ``subtree``.
        # enumerate() reads the list by index, so it sees these updates.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    replacements = [(sym, postprocess_for_cse(subtree, optimizations))
                    for sym, subtree in replacements]
    reduced_exprs = [
        postprocess_for_cse(e, optimizations) for e in reduced_exprs
    ]

    return replacements, reduced_exprs
Пример #21
0
    def doprint(self, expr):
        """Return Fortran code for ``expr``.

        Loops returned by ``_get_loop_opening_ending_ints`` for ``expr``, and
        for the ``assign_to`` target when one is set, wrap the generated
        statement(s).  A top-level ``Piecewise`` becomes an ``if``/``else``
        chain.  When ``self._settings["human"]`` is true the result is a
        single string, preceded by ``! Not Fortran:`` comment lines and
        ``parameter`` statements for number symbols.  Otherwise the result
        is the tuple ``(number_symbols, not_fortran, code)``.
        """
        # Collect number symbols (e.g. pi) for ``parameter`` declarations.
        constants = set([])
        for node in postorder_traversal(expr):
            if isinstance(node, NumberSymbol):
                constants.add(node)
        number_symbols = [(str(sym), sym.evalf(self._settings["precision"]))
                          for sym in sorted(constants)]

        # keep a set of expressions that are not strictly translatable to
        # Fortran.
        self._not_fortran = set([])

        # Setup loops if expression contain Indexed objects; the loop index
        # symbols are recorded in _not_fortran as well.
        openloop, closeloop, local_ints = self._get_loop_opening_ending_ints(expr)
        self._not_fortran |= set(local_ints)

        # The lhs may index over variables the rhs does not: add loops for
        # those, opened before (and closed after) the existing ones.
        lhs = self._settings['assign_to']
        if lhs:
            open_lhs, close_lhs, lhs_ints = self._get_loop_opening_ending_ints(lhs)
            for pos, index in enumerate(lhs_ints):
                if index in self._not_fortran:
                    continue
                self._not_fortran.add(index)
                openloop.insert(0, open_lhs[pos])
                closeloop.append(close_lhs[pos])
            lhs_printed = self._print(lhs)

        lines = []
        if isinstance(expr, Piecewise):
            # support for top-level Piecewise function
            last = len(expr.args) - 1
            for pos, (piece, cond) in enumerate(expr.args):
                if pos == 0:
                    lines.append("if (%s) then" % self._print(cond))
                elif pos == last and cond == True:
                    lines.append("else")
                else:
                    lines.append("else if (%s) then" % self._print(cond))
                lines.extend(openloop)
                if self._settings["assign_to"] is None:
                    lines.append("  %s" % self._print(piece))
                else:
                    lines.append("  %s = %s" % (lhs_printed, self._print(piece)))
                lines.extend(closeloop)
            lines.append("end if")
        else:
            lines.extend(openloop)
            line = StrPrinter.doprint(self, expr)
            if self._settings["assign_to"] is None:
                lines.append("%s" % line)
            else:
                lines.append("%s = %s" % (lhs_printed, line))
            lines.extend(closeloop)

        # format the output
        human = self._settings["human"]
        if human:
            header = []
            if len(self._not_fortran) > 0:
                header.append("! Not Fortran:")
                for item in sorted(self._not_fortran, key=self._print):
                    header.append("! %s" % item)
            for name, value in number_symbols:
                header.append("parameter (%s = %s)" % (name, value))
            lines = header + lines

        # Both output modes share the same layout pipeline.
        lines = self._pad_leading_columns(lines)
        lines = self._wrap_fortran(lines)
        lines = self.indent_code(lines)
        code = "\n".join(lines)
        if human:
            result = code
        else:
            result = number_symbols, self._not_fortran, code

        del self._not_fortran
        return result
Пример #22
0
def codegen(name_expr,
            language,
            prefix,
            project="project",
            to_files=False,
            header=True,
            empty=True):
    """Generate source code for one or more named expressions.

       Each (name, expression) pair becomes one routine. The routine's input
       arguments are the Symbols occurring in the expression, in sorted
       order, and its single result is the expression itself.

       Mandatory Arguments:
         name_expr  --  Either one (name, expression) tuple or a list of such
                        tuples; every tuple yields one routine.
         language  --  Name of the target language, matched case
                       insensitively. Only 'C' is available at the moment.
         prefix  --  Common prefix of the generated file names; the
                     language-specific suffixes are appended to it.

       Optional Arguments:
         project  --  Project name, used to build unique preprocessor
                      guards. [DEFAULT="project"]
         to_files  --  If True, the code is written to files named after
                       prefix; otherwise the file names and their contents
                       are returned as strings. [DEFAULT=False]
         header  --  If True, every source file starts with a header.
                     [DEFAULT=True]
         empty  --  If True, blank lines are inserted to structure the code.
                    [DEFAULT=True]

       Raises ValueError when the requested language is not supported.

       >>> from sympy import symbols
       >>> from sympy.utilities.codegen import codegen
       >>> from sympy.abc import x, y, z
       >>> [(c_name, c_code), (h_name, c_header)] = \\
       ...     codegen(("f", x+y*z), "C", "test", header=False, empty=False)
       >>> print c_name
       test.c
       >>> print c_code,
       #include "test.h"
       #include <math.h>
       double f(double x, double y, double z) {
         return x + y*z;
       }
       >>> print h_name
       test.h
       >>> print c_header,
       #ifndef PROJECT__TEST__H
       #define PROJECT__TEST__H
       double f(double x, double y, double z);
       #endif

    """
    # Supported languages, keyed by their upper-cased name.
    generators = {"C": CCodeGen}
    language_key = language.upper()
    if language_key not in generators:
        raise ValueError("Language '%s' is not supported." % language)
    code_gen = generators[language_key](project)

    # A lone (name, expr) tuple is recognised by its string first element.
    if isinstance(name_expr[0], basestring):
        pairs = [name_expr]
    else:
        pairs = name_expr

    routines = []
    for routine_name, routine_expr in pairs:
        # Every Symbol found in the expression becomes an input argument.
        found = set(node for node in postorder_traversal(routine_expr)
                    if isinstance(node, Symbol))
        arguments = [InputArgument(sym) for sym in sorted(found)]
        routines.append(Routine(routine_name, arguments, [Result(routine_expr)]))

    return code_gen.write(routines, prefix, to_files, header, empty)
Пример #23
0
    def _cse(self, expr):
        """Run common-subexpression elimination on ``expr``.

        User-defined variables are substituted first. For every pair
        ``(v, e)`` from ``zip(self._some_vars, self._user_exprs)``, the
        expression ``e[0]`` is replaced by ``v`` in ``expr``.

        The following subexpressions of the result are then passed to
        ``cse`` as extra candidates:
        - function applications (except Piecewise);
        - powers, together with their bases;
        - relationals.

        Side effect: sets ``self._var_asserts`` to a dict that maps each user
        variable ``v`` to the strings ``'<v> <spec>'``, one per ``spec`` in
        ``e[1]``. Presumably each ``self._user_exprs`` entry is an
        ``(expression, specs)`` pair; verify against the caller.

        Returns a pair:
        - first, the user substitutions followed by the replacements
          produced by ``cse``;
        - second, the expressions returned by ``cse`` (called with the
          substituted ``expr`` first) followed by ``self._user_exprs``.
        """
        exprs = set()

        expr_n = expr
        # NOTE: a preprocessing pass that rewrote fractional powers of the
        # form y**(k/2) in terms of ``fsqrt`` used to live here; it is
        # currently disabled.

        # The pairs are iterated twice below. On Python 3 ``zip`` returns a
        # one-shot iterator, so it must be materialised; otherwise the
        # second pass sees nothing and ``_var_asserts`` ends up empty.
        user_defs = list(zip(self._some_vars, self._user_exprs))
        user_subs = [(v, e[0]) for v, e in user_defs]
        # Only definitions whose spec list is present yield assertions. The
        # former filter (``c[1] is not None or c[1] != []``) was always true,
        # so a ``None`` spec list made the comprehension fail.
        self._var_asserts = {
            v: ['{0} {1}'.format(v, spec) for spec in e[1]]
            for v, e in user_defs
            if e[1] is not None
        }

        for (v, e) in user_subs:
            expr_n = expr_n.replace(e, v)

        # Collect extra CSE candidates. Also adding Piecewise branches and
        # the arguments of Mul/Add nodes was tried and found to be slower.
        for subexpr in postorder_traversal(expr_n):
            if Is(subexpr).Function and not subexpr.is_Piecewise:
                exprs.add(subexpr)

            elif Is(subexpr).Pow:
                exprs.add(subexpr)
                exprs.add(subexpr.args[0])

            elif Is(subexpr).Relational:
                exprs.add(subexpr)

        (vs, exprs) = cse([expr_n] + list(exprs), self._some_vars)

        return ((user_subs + vs), (exprs + self._user_exprs))
Пример #24
0
    def simplify(self, deep=False):
        """Simplify this quantified boolean expression.

        Rewrites are tried in the order below. Whenever one applies, the
        rewritten object is returned, usually after simplifying it again:

        1. a body of the same type is merged into this one (its limits are
           placed in front of ours);
        2. ``delete_independent_variables`` is applied;
        3. for an ``And``/``Or`` body:
           - a limit variable that occurs in exactly one operand is moved
             into that operand;
           - an operand that becomes true/false once ``limits_cond`` is
             and-ed in either decides the whole result or is dropped;
        4. with ``deep=True``, variables whose domain is an integer set are
           temporarily replaced by copies carrying that domain, so the body
           can be simplified with it;
        5. single-domain limits are normalised (singleton ``FiniteSet``,
           ``Contains``, ``ConditionSet``, ``UniversalSet``);
        6. set- and condition-valued limits are rewritten (Piecewise sets,
           image sets, ``Equal`` conditions, a condition equal to the body).

        Otherwise defers to ``ExprWithLimits.simplify``.

        :param deep: also run the domain-aware simplification of step 4,
            and pass ``deep`` on to ``ExprWithLimits.simplify``.
        """
        from sympy import S
        # Step 1: merge a directly nested quantifier of the same kind.
        if self.function.func == self.func:
            exists = self.function
            return self.func(exists.function, *exists.limits + self.limits).simplify()

        # Step 2.
        this = self.delete_independent_variables()
        if this is not None:
            return this

        # Step 3: And/Or bodies.
        function = self.function
        if function.is_And or function.is_Or:
            for t in range(len(self.limits)):
                x, *domain = self.limits[t]
                # Indices of the operands that mention x.
                index = []
                for i, eq in enumerate(function.args):
                    if eq._has(x):
                        index.append(i)
                if len(index) == 1:

                    # Leave x alone if an earlier limit refers to it.
                    if any(limit._has(x) for limit in self.limits[:t]):
                        continue

                    if len(domain) == 1 and domain[0].is_boolean:
                        continue

                    # Quantify x over the single operand that mentions it.
                    index = index[0]
                    eqs = [*function.args]

                    eqs[index] = self.func(eqs[index], (x, *domain)).simplify()
                    limits = self.limits_delete(x)

                    if limits:
                        function = function.func(*eqs)
                        return self.func(function, *limits).simplify()
                    else:
                        return function.func(*eqs)
            limits_cond = self.limits_cond
            for i, eq in enumerate(self.function.args):
                eq &= limits_cond
                # copy:   this operand alone determines the result.
                # shrink: this operand is redundant and can be removed.
                copy = False
                shrink = False
                if eq:
                    if self.function.is_Or:
                        copy = True
                    else:
                        shrink = True
                elif eq.is_BooleanFalse:
                    if self.function.is_And:
                        copy = True
                    else:
                        shrink = True

                if copy:
                    return eq
                if shrink:
                    args = [*self.function.args]
                    del args[i]
                    function = self.function.func(*args)
                    return self.func(function, *self.limits).simplify()

        # Step 4: domain-aware simplification of the body.
        if deep:
            function = self.function
            reps = {}
            # The bare name ``limits_dict`` used here was undefined, so every
            # deep=True call raised NameError.
            # NOTE(review): assumes ``self.limits_dict`` maps each limit
            # variable to its domain; confirm.
            for x, domain in self.limits_dict.items():
                if domain.is_set and domain.is_integer:
                    _x = x.copy(domain=domain)
                    function = function._subs(x, _x).simplify(deep=deep)
                    reps[_x] = x
            if reps:
                # Put the original variables back.
                for _x, x in reps.items():
                    function = function._subs(_x, x)
                if function != self.function:
                    return self.func(function, *self.limits).simplify()

        # Step 5: normalise limits that carry a single domain.
        for i, (x, *domain) in enumerate(self.limits):
            if len(domain) == 1:
                domain = domain[0]
                if domain.is_FiniteSet and len(domain) == 1:
                    if len(self.limits) == 1:
                        return self.func(self.finite_aggregate(x, domain), *self.limits_delete(x)).simplify()
                if domain.is_Contains:
                    if domain.lhs == x:
                        domain = domain.rhs
                        limits = self.limits_update({x:domain})
                        return self.func(self.function, *limits).simplify()
                elif domain.is_ConditionSet:
                    if x == domain.variable:
                        condition = domain.condition
                    else:
                        condition = domain.condition._subs(domain.variable, x)

                    limits = [*self.limits]
                    if domain.base_set.is_UniversalSet:
                        limits[i] = (x, condition)
                    else:
                        limits[i] = (x, condition, domain.base_set)
                    return self.func(self.function, *limits).simplify()
                elif domain.is_UniversalSet:
                    limits = [*self.limits]
                    limits[i] = (x,)
                    return self.func(self.function, *limits).simplify()

        # Step 6: rewrite set- and condition-valued limits.
        for i, limit in enumerate(self.limits):
            if len(limit) == 1:
                continue
            if len(limit) == 3:
                e, cond, baseset = limit
                if baseset.is_set:
                    # The body restates the limit's own condition.
                    if cond == self.function:
                        return S.BooleanTrue
            else:
                e, s = limit
                if s.is_set:
                    if s.is_Symbol or s.is_Indexed:
                        continue

                    if s.is_Piecewise:
                        # Drop a trailing empty-set branch; the previous
                        # branch becomes the default one.
                        if s.args[-1][0].is_EmptySet:
                            s = s.func(*s.args[:-2], (s.args[-2][0], True))

                            limits = [*self.limits]
                            limits[i] = (e, s)
                            return self.func(self.function, *limits).simplify()
                        continue

                    image_set = s.image_set()
                    if image_set is not None:
                        sym, expr, base_set = image_set
                        if self.function.is_ExprWithLimits:
                            # Avoid clashing with variables bound by the body.
                            if sym in self.function.bound_symbols:
                                _sym = base_set.element_symbol(self.function.variables_set)
                                assert sym.shape == _sym.shape
                                _expr = expr.subs(sym, _sym)
                                if _expr == expr:
                                    # subs() left expr unchanged; presumably
                                    # sym occurs only inside a definition, so
                                    # expand the first node that _has(sym).
                                    # NOTE(review): if no node matches, ``var``
                                    # is left as the last visited node.
                                    for var in postorder_traversal(expr):
                                        if var._has(sym):
                                            break
                                    expr = expr._subs(var, var.definition)
                                    _expr = expr._subs(sym, _sym)

                                expr = _expr
                                sym = _sym
                            assert sym not in self.function.bound_symbols

                        function = self.function
                        if e != expr:
                            if sym.type == e.type:
                                _expr = expr._subs(sym, e)
                                if _expr != e:
                                    _function = function._subs(e, expr)
                                    if _function == function:
                                        return self
                                    limits = self.limits_update({e: (sym, base_set)})
                                    function = _function
                                else:
                                    base_set = base_set._subs(sym, e)
                                    limits = self.limits_update(e, base_set)
                            else:
                                _function = function._subs(e, expr)
                                if _function == function:
                                    return self
                                limits = self.limits_update({e: (sym, base_set)})
                                function = _function
                        else:
                            limits = self.limits_update({e: (sym, base_set)})
                        return self.func(function, *limits).simplify()
                else:  # s.type.is_condition:
                    if s.is_Equal:
                        # e is pinned to y by the condition: substitute it.
                        if e == s.lhs:
                            y = s.rhs
                        elif e == s.rhs:
                            y = s.lhs
                        else:
                            y = None
                        if y is not None and not y.has(e):
                            function = function._subs(e, y)
                            if function.is_BooleanAtom:
                                return function
                            limits = self.limits_delete(e)
                            if limits:
                                return self.func(function, *limits)
                            return function
                    if s == self.function or s.dummy_eq(self.function):  # s.invert() | self.function
                        return S.BooleanTrue

        return ExprWithLimits.simplify(self, deep=deep)
Пример #25
0
def test_postorder_traversal():
    """Children must be yielded before the node that contains them.

    Commutative arguments have no guaranteed order, so any of the listed
    orderings of ``z + w*(x + y)`` is accepted.
    """
    total = z + w*(x + y)
    visited = list(postorder_traversal(total))
    accepted = [
        [z, w, y, x, x + y, w*(x + y), total],
        [z, w, x, y, x + y, w*(x + y), total],
        [w, y, x, x + y, w*(x + y), z, total],
    ]
    assert visited in accepted
Пример #26
0
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a stream
        of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Only non-atomic subexpressions that occur more than once (counted in a
    post-order walk over all of ``exprs``) are replaced. They are numbered
    in the order in which their second occurrence is met.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    to_eliminate = []
    # Mirror of ``to_eliminate`` for O(1) membership tests. The list itself
    # is kept because its order fixes the numbering of the new symbols.
    to_eliminate_set = set()

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    for expr in exprs:
        for subtree in postorder_traversal(expr):
            if subtree.args == ():
                # Exclude atoms, since there is no point in renaming them.
                continue
            if subtree in seen_subexp and subtree not in to_eliminate_set:
                to_eliminate.append(subtree)
                to_eliminate_set.add(subtree)
            seen_subexp.add(subtree)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        # next() works for any iterator; the .next() method is Python 2 only.
        sym = next(symbols)
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        reduced_exprs = [e.subs(subtree, sym) for e in reduced_exprs]
        # Make the substitution in the later candidates too. Only slots after
        # ``i`` are rewritten and the length never changes, so ``enumerate``
        # safely picks up the updated entries on later iterations.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    replacements = [(sym, postprocess_for_cse(subtree, optimizations))
                    for sym, subtree in replacements]
    reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs]

    return replacements, reduced_exprs
Пример #27
0
def test_postorder_traversal():
    """Post-order traversal of z + w*(x + y) lists every node after its args.

    The order among commutative arguments is unspecified, so the result
    only has to match one of the known-valid orderings.
    """
    inner = x + y
    product = w * inner
    expr = z + product
    orderings = (
        [z, w, y, x, inner, product, expr],
        [z, w, x, y, inner, product, expr],
        [w, y, x, inner, product, z, expr],
    )
    assert list(postorder_traversal(expr)) in orderings
Пример #28
0
            if num_tested % (tot_to_test / 10 if tot_to_test > 10 else 2) == 0:
                print "Tested: ", num_tested, ", Solutions:", len(sols)
            try:
                sol_dict, failed_var = backward_sub(eqns, knowns, ord_unks,
                                                    multiple_sols, sub_all)
            except FlameTensorError, e:
                print "Error for:", ord_unks
                traceback.print_exc()
                continue

    #        print "  result:", sol_dict, failed_var
            if sol_dict is None:
                if failed_var in ord_unks:
                    ord_unk_iter.bad_pos(ord_unks.index(failed_var))
            else:
#                for var in sol_dict:
#                    sol_dict[var] = sol_dict[var].expand()
                if len(filter(lambda x: x[0] == sol_dict, sols)) == 0:
                    sols.append((sol_dict, ord_unks))
                    print "Found new solution:\n%s" % pprint.pformat(sol_dict, 4, 80)
        except KeyboardInterrupt:
            break
    print "Tested %d orders" % num_tested
    print "Found %d unique solutions" % len(sols)
    sols.sort(key=lambda s: sum([len(list(postorder_traversal(v))) \
                                      for _, v in s[0].iteritems()]))
    if not allow_recompute:
        sols = filter(lambda x:x, map(sol_without_recomputes, sols))
        print "Found %d unique solutions without recomputation" % len(sols)
    return sols
Пример #29
0
def codegen(name_expr, language, prefix, project="project", to_files=False, header=True, empty=True):
    """Produce source code for named expressions in the requested language.

       Mandatory Arguments:
         name_expr  --  One (name, expression) tuple, or a list of them. A
                        routine is generated for every tuple.
         language  --  Target language name, case insensitive. Currently
                       only 'C' is accepted.
         prefix  --  Prefix for the generated file names; language specific
                     suffixes are added to it.

       Optional Arguments:
         project  --  Project name that makes the preprocessor guards
                      unique. [DEFAULT="project"]
         to_files  --  True writes the code to files with the given prefix;
                       False returns the file names and contents as strings.
                       [DEFAULT=False]
         header  --  True puts a header on top of every source file.
                     [DEFAULT=True]
         empty  --  True separates parts of the code with empty lines.
                    [DEFAULT=True]

       An unsupported language raises ValueError. Each routine takes the
       Symbols of its expression (sorted) as inputs and returns the
       expression.

       >>> from sympy import symbols
       >>> from sympy.utilities.codegen import codegen
       >>> x, y, z = symbols('xyz')
       >>> [(c_name, c_code), (h_name, c_header)] = \\
       ...     codegen(("f", x+y*z), "C", "test", header=False, empty=False)
       >>> print c_name
       test.c
       >>> print c_code,
       #include "test.h"
       #include <math.h>
       double f(double x, double y, double z) {
         return x + y*z;
       }
       >>> print h_name
       test.h
       >>> print c_header,
       #ifndef PROJECT__TEST__H
       #define PROJECT__TEST__H
       double f(double x, double y, double z);
       #endif

    """
    def _routine(routine_name, routine_expr):
        # Build one Routine whose inputs are the Symbols of routine_expr.
        free = set()
        for node in postorder_traversal(routine_expr):
            if isinstance(node, Symbol):
                free.add(node)
        inputs = [InputArgument(s) for s in sorted(free)]
        return Routine(routine_name, inputs, [Result(routine_expr)])

    generator_cls = {"C": CCodeGen}.get(language.upper())
    if generator_cls is None:
        raise ValueError("Language '%s' is not supported." % language)
    code_gen = generator_cls(project)

    # Wrap a single (name, expr) tuple so both call forms look alike.
    if isinstance(name_expr[0], basestring):
        name_expr = [name_expr]
    routines = [_routine(n, e) for n, e in name_expr]

    return code_gen.write(routines, prefix, to_files, header, empty)
Пример #30
0
                print "Tested: ", num_tested, ", Solutions:", len(sols)
            try:
                sol_dict, failed_var = backward_sub(eqns, knowns, ord_unks,
                                                    multiple_sols, sub_all)
            except FlameTensorError, e:
                print "Error for:", ord_unks
                traceback.print_exc()
                continue

    #        print "  result:", sol_dict, failed_var
            if sol_dict is None:
                if failed_var in ord_unks:
                    ord_unk_iter.bad_pos(ord_unks.index(failed_var))
            else:
                #                for var in sol_dict:
                #                    sol_dict[var] = sol_dict[var].expand()
                if len(filter(lambda x: x[0] == sol_dict, sols)) == 0:
                    sols.append((sol_dict, ord_unks))
                    print "Found new solution:\n%s" % pprint.pformat(
                        sol_dict, 4, 80)
        except KeyboardInterrupt:
            break
    print "Tested %d orders" % num_tested
    print "Found %d unique solutions" % len(sols)
    sols.sort(key=lambda s: sum([len(list(postorder_traversal(v))) \
                                      for _, v in s[0].iteritems()]))
    if not allow_recompute:
        sols = filter(lambda x: x, map(sol_without_recomputes, sols))
        print "Found %d unique solutions without recomputation" % len(sols)
    return sols